File size: 1,502 Bytes
38c9982
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

# The project root is the directory one level above the folder holding this file.
# It is put at the front of the import path so that the ``src.executive_assistant``
# imports that follow can be resolved when this file is executed directly.
# The membership check avoids adding a duplicate entry if the root is already
# importable.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
_root_entry = str(PROJECT_ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)

from src.executive_assistant.agent import BaselineAgent
from src.executive_assistant.config import TrainingRuntimeConfig, load_env_file
from src.executive_assistant.training import evaluate_q_policy, train_q_learning


def main() -> None:
    load_env_file(TrainingRuntimeConfig().env_file)
    parser = argparse.ArgumentParser(description="Train a tabular RL policy for seeded tasks.")
    parser.add_argument("--episodes", type=int, default=300)
    parser.add_argument("--epsilon", type=float, default=0.15)
    parser.add_argument("--checkpoint", default="artifacts/checkpoints/q_policy.json")
    parser.add_argument("--no-teacher", action="store_true")
    args = parser.parse_args()

    teacher = None if args.no_teacher else BaselineAgent()
    policy, training_scores = train_q_learning(
        episodes=args.episodes,
        epsilon=args.epsilon,
        teacher=teacher,
    )
    checkpoint_path = policy.save(args.checkpoint)
    evaluation = evaluate_q_policy(policy)
    print(
        json.dumps(
            {
                "checkpoint": str(checkpoint_path),
                "training_scores": training_scores,
                "evaluation": evaluation,
            },
            indent=2,
        )
    )


if __name__ == "__main__":
    # Script entry point only; importing this module runs nothing. ``main``
    # returns None, which SystemExit maps to exit status 0.
    raise SystemExit(main())