atomwalk12's picture
initial commit
0dd6c2f
# Copyright Sierra
import argparse
import logging
from dotenv import load_dotenv
from litellm import provider_list
from linalg_zero.grpo.envs.user import UserStrategy
from linalg_zero.grpo.run import run
from linalg_zero.grpo.types import RunConfig
# Pull settings from a local .env file; override=True lets the file's values
# replace variables that are already set in the process environment.
load_dotenv(override=True)
# Raise the "LiteLLM" logger's threshold to WARNING so its lower-severity
# messages (the logging spam) are not emitted.
logging.getLogger(name="LiteLLM").setLevel(level=logging.WARNING)
def _build_parser() -> argparse.ArgumentParser:
    """Build the argument parser describing every option of an evaluation run."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--project-id", type=int, default=0)
    parser.add_argument("--project-name", type=str, default="linalgzero-grpo-eval")
    parser.add_argument("--num-trials", type=int, default=1)
    parser.add_argument("--env", type=str, choices=["linear_algebra"], default="linear_algebra")
    parser.add_argument(
        "--model",
        type=str,
        help="The model to use for the agent",
    )
    parser.add_argument(
        "--model-provider",
        type=str,
        choices=provider_list,
        help="The model provider for the agent",
    )
    parser.add_argument(
        "--user-model",
        type=str,
        default="gpt-4o",
        help="The model to use for the user simulator",
    )
    parser.add_argument(
        "--user-model-provider",
        type=str,
        choices=provider_list,
        help="The model provider for the user simulator",
    )
    parser.add_argument(
        "--agent-strategy",
        type=str,
        default="tool-calling",
        choices=["tool-calling", "act", "react", "few-shot"],
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="The sampling temperature for the action model",
    )
    parser.add_argument(
        "--task-split",
        type=str,
        default="test",
        choices=["train", "test", "dev"],
        # Fixed: the help text previously lacked its closing parenthesis.
        help="The split of tasks to run (only applies to the retail domain for now)",
    )
    parser.add_argument("--start-index", type=int, default=0)
    parser.add_argument("--end-index", type=int, default=-1, help="Run all tasks if -1")
    parser.add_argument(
        "--task-ids",
        type=int,
        nargs="+",
        help="(Optional) run only the tasks with the given IDs",
    )
    parser.add_argument("--log-dir", type=str, default="results")
    parser.add_argument(
        "--max-concurrency",
        type=int,
        default=1,
        help="Number of tasks to run in parallel",
    )
    parser.add_argument("--seed", type=int, default=10)
    parser.add_argument("--shuffle", type=int, default=0)
    parser.add_argument(
        "--user-strategy",
        type=str,
        default="llm",
        choices=[item.value for item in UserStrategy],
    )
    parser.add_argument(
        "--few-shot-displays-path",
        type=str,
        help="Path to a jsonlines file containing few shot displays",
    )
    parser.add_argument(
        "--dataset-path", default="atomwalk12/linalgzero-grpo", type=str, help="Path to the huggingface dataset"
    )
    return parser


def parse_args(argv: "list[str] | None" = None) -> RunConfig:
    """Parse command-line options into a ``RunConfig``.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            fall back to ``sys.argv[1:]``, which is the original behavior;
            passing a list allows programmatic use and testing.

    Returns:
        A ``RunConfig`` built from the parsed options. The parsed namespace is
        also printed to stdout.
    """
    args = _build_parser().parse_args(argv)
    print(args)
    # NOTE(review): --user-model and --user-model-provider are parsed but not
    # forwarded to RunConfig below; confirm whether RunConfig should take them.
    return RunConfig(
        project_id=args.project_id,
        dataset_path=args.dataset_path,
        project=args.project_name,
        model_provider=args.model_provider,
        model=args.model,
        num_trials=args.num_trials,
        env=args.env,
        agent_strategy=args.agent_strategy,
        temperature=args.temperature,
        task_split=args.task_split,
        start_index=args.start_index,
        end_index=args.end_index,
        task_ids=args.task_ids,
        log_dir=args.log_dir,
        max_concurrency=args.max_concurrency,
        seed=args.seed,
        shuffle=args.shuffle,
        user_strategy=args.user_strategy,
        few_shot_displays_path=args.few_shot_displays_path,
    )
def main():
    """Script entry point: turn the command line into a RunConfig and execute it."""
    run(parse_args())


if __name__ == "__main__":
    # Only launch a run when executed as a script, not when imported.
    main()