Spaces:
Paused
Paused
File size: 3,846 Bytes
0dd6c2f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 | # Copyright Sierra
import argparse
import logging
from collections.abc import Sequence
from typing import Optional

from dotenv import load_dotenv
from litellm import provider_list

from linalg_zero.grpo.envs.user import UserStrategy
from linalg_zero.grpo.run import run
from linalg_zero.grpo.types import RunConfig
# Load variables from a .env file into the process environment; with
# override=True the file's values replace variables that are already set.
# NOTE(review): this runs after the linalg_zero imports above, so any of those
# modules that read environment variables at import time will not see values
# from .env — confirm that is intended.
load_dotenv(override=True)

# Quiet the "LiteLLM" logger: records below WARNING are not emitted by it.
_litellm_logger = logging.getLogger("LiteLLM")
_litellm_logger.setLevel(logging.WARNING)
def parse_args(argv: Optional[Sequence[str]] = None) -> RunConfig:
    """Parse command-line options and build the evaluation ``RunConfig``.

    Args:
        argv: Argument strings to parse, excluding the program name. When
            ``None`` (the default), argparse reads ``sys.argv[1:]``, which is
            exactly the previous zero-argument behaviour. Passing an explicit
            list lets tests and other scripts drive the parser directly.

    Returns:
        A ``RunConfig`` built from the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--project-id", type=int, default=0)
    parser.add_argument("--project-name", type=str, default="linalgzero-grpo-eval")
    parser.add_argument("--num-trials", type=int, default=1)
    parser.add_argument("--env", type=str, choices=["linear_algebra"], default="linear_algebra")
    # NOTE(review): --model has no default and is not required, so args.model
    # is None unless passed explicitly — confirm RunConfig/run() accept that.
    parser.add_argument(
        "--model",
        type=str,
        help="The model to use for the agent",
    )
    parser.add_argument(
        "--model-provider",
        type=str,
        choices=provider_list,
        help="The model provider for the agent",
    )
    # NOTE(review): --user-model and --user-model-provider are parsed but are
    # never passed to RunConfig below, so they currently have no effect.
    # Confirm whether RunConfig has matching fields before wiring them through.
    parser.add_argument(
        "--user-model",
        type=str,
        default="gpt-4o",
        help="The model to use for the user simulator",
    )
    parser.add_argument(
        "--user-model-provider",
        type=str,
        choices=provider_list,
        help="The model provider for the user simulator",
    )
    parser.add_argument(
        "--agent-strategy",
        type=str,
        default="tool-calling",
        choices=["tool-calling", "act", "react", "few-shot"],
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="The sampling temperature for the action model",
    )
    parser.add_argument(
        "--task-split",
        type=str,
        default="test",
        choices=["train", "test", "dev"],
        # Fixed: the original help text was missing its closing parenthesis.
        help="The split of tasks to run (only applies to the retail domain for now)",
    )
    parser.add_argument("--start-index", type=int, default=0)
    parser.add_argument("--end-index", type=int, default=-1, help="Run all tasks if -1")
    parser.add_argument(
        "--task-ids",
        type=int,
        nargs="+",
        help="(Optional) run only the tasks with the given IDs",
    )
    parser.add_argument("--log-dir", type=str, default="results")
    parser.add_argument(
        "--max-concurrency",
        type=int,
        default=1,
        help="Number of tasks to run in parallel",
    )
    parser.add_argument("--seed", type=int, default=10)
    parser.add_argument("--shuffle", type=int, default=0)
    parser.add_argument(
        "--user-strategy",
        type=str,
        default="llm",
        # Valid choices are the string values of the UserStrategy enum.
        choices=[item.value for item in UserStrategy],
    )
    parser.add_argument(
        "--few-shot-displays-path",
        type=str,
        help="Path to a jsonlines file containing few shot displays",
    )
    parser.add_argument(
        "--dataset-path", default="atomwalk12/linalgzero-grpo", type=str, help="Path to the huggingface dataset"
    )

    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)
    # Echo the parsed options so each run's configuration appears in its output.
    print(args)
    return RunConfig(
        project_id=args.project_id,
        dataset_path=args.dataset_path,
        project=args.project_name,
        model_provider=args.model_provider,
        model=args.model,
        num_trials=args.num_trials,
        env=args.env,
        agent_strategy=args.agent_strategy,
        temperature=args.temperature,
        task_split=args.task_split,
        start_index=args.start_index,
        end_index=args.end_index,
        task_ids=args.task_ids,
        log_dir=args.log_dir,
        max_concurrency=args.max_concurrency,
        seed=args.seed,
        shuffle=args.shuffle,
        user_strategy=args.user_strategy,
        few_shot_displays_path=args.few_shot_displays_path,
    )
def main() -> None:
    """Command-line entry point: parse the options, then hand them to ``run``."""
    run(parse_args())


if __name__ == "__main__":
    main()
|