# Spaces: Sleeping
# File size: 1,502 Bytes
# Revision shown by the file viewer: 38c9982
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
# Directory two levels above this file (parents[1] of the resolved path).
# It is put at the front of sys.path, unless already listed, so that the
# `src.executive_assistant` imports further down can be resolved when this
# file is run directly as a script.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
from src.executive_assistant.agent import BaselineAgent
from src.executive_assistant.config import TrainingRuntimeConfig, load_env_file
from src.executive_assistant.training import evaluate_q_policy, train_q_learning
def main(argv: list[str] | None = None) -> None:
    """Train a tabular RL policy, save a checkpoint, and print a JSON report.

    Steps, in order:

    1. Load the env file named by ``TrainingRuntimeConfig().env_file``.
    2. Parse the command-line options listed below.
    3. Call ``train_q_learning`` with the parsed episode count, epsilon and
       the teacher (``None`` when ``--no-teacher`` is given, otherwise a
       fresh ``BaselineAgent``).
    4. Save the returned policy to the ``--checkpoint`` location.
    5. Run ``evaluate_q_policy`` on the policy.
    6. Print one JSON object with the keys ``checkpoint``,
       ``training_scores`` and ``evaluation`` to stdout.

    Command-line options:
        --episodes    int, default 300; forwarded as ``episodes``.
        --epsilon     float, default 0.15; forwarded as ``epsilon``.
        --checkpoint  default ``artifacts/checkpoints/q_policy.json``;
                      passed to ``policy.save``.
        --no-teacher  flag; when set, ``teacher=None`` is passed to training.

    Args:
        argv: Argument strings to parse instead of the process command line.
            ``None`` (the default) lets argparse read ``sys.argv[1:]``, so
            calling ``main()`` with no arguments behaves as before.
    """
    # The env file is loaded before the parser is built, as in the original.
    load_env_file(TrainingRuntimeConfig().env_file)

    parser = argparse.ArgumentParser(description="Train a tabular RL policy for seeded tasks.")
    parser.add_argument("--episodes", type=int, default=300)
    parser.add_argument("--epsilon", type=float, default=0.15)
    parser.add_argument("--checkpoint", default="artifacts/checkpoints/q_policy.json")
    parser.add_argument("--no-teacher", action="store_true")
    args = parser.parse_args(argv)

    teacher = None if args.no_teacher else BaselineAgent()
    policy, training_scores = train_q_learning(
        episodes=args.episodes,
        epsilon=args.epsilon,
        teacher=teacher,
    )

    # The checkpoint is written before evaluation runs; whatever
    # ``policy.save`` returns is stringified for the report.
    checkpoint_path = policy.save(args.checkpoint)
    evaluation = evaluate_q_policy(policy)

    report = {
        "checkpoint": str(checkpoint_path),
        "training_scores": training_scores,
        "evaluation": evaluation,
    }
    print(json.dumps(report, indent=2))
# Run the training CLI only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|