Spaces:
Running on Zero
Running on Zero
| # Copyright Sierra | |
| import argparse | |
| import logging | |
| from dotenv import load_dotenv | |
| from litellm import provider_list | |
| from linalg_zero.grpo.envs.user import UserStrategy | |
| from linalg_zero.grpo.run import run | |
| from linalg_zero.grpo.types import RunConfig | |
# Read settings from a .env file at import time; override=True asks python-dotenv
# to let values from the file replace variables already present in the environment.
load_dotenv(override=True)

# Suppress LiteLLM logging spam
logging.getLogger(name="LiteLLM").setLevel(level=logging.WARNING)
def parse_args(argv: "list[str] | None" = None) -> RunConfig:
    """Parse command-line options and pack them into a ``RunConfig``.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            ``argparse`` fall back to ``sys.argv[1:]``, which matches the
            behaviour of calling this function with no arguments.

    Returns:
        A ``RunConfig`` populated from the parsed values.

    Side effects:
        Prints the parsed ``argparse.Namespace`` to stdout.
    """
    parser = argparse.ArgumentParser()

    # Project / experiment identification.
    parser.add_argument("--project-id", type=int, default=0)
    parser.add_argument("--project-name", type=str, default="linalgzero-grpo-eval")
    parser.add_argument("--num-trials", type=int, default=1)
    parser.add_argument("--env", type=str, choices=["linear_algebra"], default="linear_algebra")

    # Agent model selection. Neither option has a default, so both are None
    # when omitted on the command line.
    parser.add_argument(
        "--model",
        type=str,
        help="The model to use for the agent",
    )
    parser.add_argument(
        "--model-provider",
        type=str,
        choices=provider_list,
        help="The model provider for the agent",
    )

    # User-simulator model selection.
    # NOTE(review): --user-model and --user-model-provider are accepted here
    # but are not forwarded to RunConfig below. RunConfig is defined in another
    # module, so it is unclear whether it has matching fields -- confirm and
    # either wire them through or drop the options.
    parser.add_argument(
        "--user-model",
        type=str,
        default="gpt-4o",
        help="The model to use for the user simulator",
    )
    parser.add_argument(
        "--user-model-provider",
        type=str,
        choices=provider_list,
        help="The model provider for the user simulator",
    )

    parser.add_argument(
        "--agent-strategy",
        type=str,
        default="tool-calling",
        choices=["tool-calling", "act", "react", "few-shot"],
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="The sampling temperature for the action model",
    )

    # Task selection.
    # NOTE(review): the help text mentions a "retail domain" although the only
    # --env choice is "linear_algebra"; wording kept, only the missing closing
    # parenthesis was added.
    parser.add_argument(
        "--task-split",
        type=str,
        default="test",
        choices=["train", "test", "dev"],
        help="The split of tasks to run (only applies to the retail domain for now)",
    )
    parser.add_argument("--start-index", type=int, default=0)
    parser.add_argument("--end-index", type=int, default=-1, help="Run all tasks if -1")
    parser.add_argument(
        "--task-ids",
        type=int,
        nargs="+",
        help="(Optional) run only the tasks with the given IDs",
    )

    # Execution settings.
    parser.add_argument("--log-dir", type=str, default="results")
    parser.add_argument(
        "--max-concurrency",
        type=int,
        default=1,
        help="Number of tasks to run in parallel",
    )
    parser.add_argument("--seed", type=int, default=10)
    parser.add_argument("--shuffle", type=int, default=0)
    parser.add_argument(
        "--user-strategy",
        type=str,
        default="llm",
        choices=[item.value for item in UserStrategy],
    )
    parser.add_argument(
        "--few-shot-displays-path",
        type=str,
        help="Path to a jsonlines file containing few shot displays",
    )
    parser.add_argument(
        "--dataset-path", default="atomwalk12/linalgzero-grpo", type=str, help="Path to the huggingface dataset"
    )

    args = parser.parse_args(argv)
    print(args)

    return RunConfig(
        project_id=args.project_id,
        dataset_path=args.dataset_path,
        project=args.project_name,
        model_provider=args.model_provider,
        model=args.model,
        num_trials=args.num_trials,
        env=args.env,
        agent_strategy=args.agent_strategy,
        temperature=args.temperature,
        task_split=args.task_split,
        start_index=args.start_index,
        end_index=args.end_index,
        task_ids=args.task_ids,
        log_dir=args.log_dir,
        max_concurrency=args.max_concurrency,
        seed=args.seed,
        shuffle=args.shuffle,
        user_strategy=args.user_strategy,
        few_shot_displays_path=args.few_shot_displays_path,
    )
def main() -> None:
    """Build the run configuration from the command line and hand it to ``run``."""
    run(parse_args())
# Only start a run when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()