# Copyright Sierra
"""Command-line entry point: parse CLI flags into a ``RunConfig`` and call ``run``."""

import argparse
import logging
from typing import Optional, Sequence

from dotenv import load_dotenv
from litellm import provider_list

from linalg_zero.grpo.envs.user import UserStrategy
from linalg_zero.grpo.run import run
from linalg_zero.grpo.types import RunConfig

load_dotenv(override=True)

# Suppress LiteLLM logging spam
logging.getLogger("LiteLLM").setLevel(logging.WARNING)


def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser with every flag this script accepts."""
    parser = argparse.ArgumentParser()

    # Project / run identification.
    parser.add_argument("--project-id", type=int, default=0)
    parser.add_argument("--project-name", type=str, default="linalgzero-grpo-eval")
    parser.add_argument("--num-trials", type=int, default=1)
    parser.add_argument("--env", type=str, choices=["linear_algebra"], default="linear_algebra")

    # Agent model selection.
    parser.add_argument(
        "--model",
        type=str,
        help="The model to use for the agent",
    )
    parser.add_argument(
        "--model-provider",
        type=str,
        choices=provider_list,
        help="The model provider for the agent",
    )

    # User-simulator model selection.
    # NOTE(review): these two values are parsed but are not forwarded to
    # RunConfig by parse_args below -- confirm whether that is intentional.
    parser.add_argument(
        "--user-model",
        type=str,
        default="gpt-4o",
        help="The model to use for the user simulator",
    )
    parser.add_argument(
        "--user-model-provider",
        type=str,
        choices=provider_list,
        help="The model provider for the user simulator",
    )

    # Agent behaviour.
    parser.add_argument(
        "--agent-strategy",
        type=str,
        default="tool-calling",
        choices=["tool-calling", "act", "react", "few-shot"],
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="The sampling temperature for the action model",
    )

    # Task selection.
    parser.add_argument(
        "--task-split",
        type=str,
        default="test",
        choices=["train", "test", "dev"],
        help="The split of tasks to run (only applies to the retail domain for now)",
    )
    parser.add_argument("--start-index", type=int, default=0)
    parser.add_argument("--end-index", type=int, default=-1, help="Run all tasks if -1")
    parser.add_argument(
        "--task-ids",
        type=int,
        nargs="+",
        help="(Optional) run only the tasks with the given IDs",
    )

    # Execution settings.
    parser.add_argument("--log-dir", type=str, default="results")
    parser.add_argument(
        "--max-concurrency",
        type=int,
        default=1,
        help="Number of tasks to run in parallel",
    )
    parser.add_argument("--seed", type=int, default=10)
    parser.add_argument("--shuffle", type=int, default=0)
    parser.add_argument(
        "--user-strategy",
        type=str,
        default="llm",
        choices=[item.value for item in UserStrategy],
    )
    parser.add_argument(
        "--few-shot-displays-path",
        type=str,
        help="Path to a jsonlines file containing few shot displays",
    )
    parser.add_argument(
        "--dataset-path",
        default="atomwalk12/linalgzero-grpo",
        type=str,
        help="Path to the huggingface dataset",
    )
    return parser


def parse_args(argv: Optional[Sequence[str]] = None) -> RunConfig:
    """Parse command-line arguments and return them as a ``RunConfig``.

    Args:
        argv: Argument strings to parse. When ``None`` (the default),
            ``argparse`` reads ``sys.argv[1:]``, which matches the behaviour
            of calling this function with no arguments.

    Returns:
        A ``RunConfig`` populated from the parsed flags.

    The parsed namespace is printed to stdout before the config is built.
    """
    args = _build_parser().parse_args(argv)
    print(args)
    return RunConfig(
        project_id=args.project_id,
        dataset_path=args.dataset_path,
        project=args.project_name,
        model_provider=args.model_provider,
        model=args.model,
        num_trials=args.num_trials,
        env=args.env,
        agent_strategy=args.agent_strategy,
        temperature=args.temperature,
        task_split=args.task_split,
        start_index=args.start_index,
        end_index=args.end_index,
        task_ids=args.task_ids,
        log_dir=args.log_dir,
        max_concurrency=args.max_concurrency,
        seed=args.seed,
        shuffle=args.shuffle,
        user_strategy=args.user_strategy,
        few_shot_displays_path=args.few_shot_displays_path,
    )


def main() -> None:
    """Build the run configuration from the command line and execute the run."""
    config = parse_args()
    run(config)


if __name__ == "__main__":
    main()