Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import sys | |
| from pathlib import Path | |
# Directory one level above the folder that contains this file. It is put at
# the front of sys.path (once) -- presumably so the ``src.…`` imports below
# resolve when this script is run directly; confirm against the repo layout.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
| from src.executive_assistant.agent import BaselineAgent, OpenRouterPolicy | |
| from src.executive_assistant.config import OpenRouterConfig, TrainingRuntimeConfig, load_env_file | |
| from src.executive_assistant.runner import export_traces_jsonl, run_policy_suite | |
# Names of the tasks evaluated by this script; passed as ``task_names`` to
# ``run_policy_suite``.
TASKS = [
    "easy_deadline_extraction",
    "medium_triage_and_negotiation",
    "hard_rag_reply",
]
def build_policy(provider: str, model_name: str) -> object:
    """Construct the policy object for the requested provider.

    Args:
        provider: Either ``"baseline"`` or ``"openrouter"``.
        model_name: Model identifier written into the OpenRouter config.
            Unused when ``provider`` is ``"baseline"``.

    Returns:
        A ``BaselineAgent`` for ``"baseline"``; for ``"openrouter"``, an
        ``OpenRouterPolicy`` whose config is read from the environment with
        ``model_name`` replaced by the argument.

    Raises:
        ValueError: If ``provider`` is not one of the supported values.
    """
    if provider == "baseline":
        return BaselineAgent()

    if provider == "openrouter":
        # Load the env file before reading the config from the environment.
        load_env_file(TrainingRuntimeConfig().env_file)
        env_config = OpenRouterConfig.from_env()
        # Rebuild the config so every environment-derived field is kept and
        # only the model name comes from the caller.
        config = OpenRouterConfig(
            api_key=env_config.api_key,
            model_name=model_name,
            base_url=env_config.base_url,
            site_url=env_config.site_url,
            app_name=env_config.app_name,
            temperature=env_config.temperature,
            max_tokens=env_config.max_tokens,
        )
        return OpenRouterPolicy(config=config)

    raise ValueError(f"Unsupported provider: {provider}")
def main(argv: list[str] | None = None) -> None:
    """Run the selected policy over every task in ``TASKS`` and print a summary.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which keeps plain ``main()`` calls
            working; passing a list allows programmatic invocation.

    Command-line options:
        --provider   ``baseline`` (default) or ``openrouter``.
        --model      Model name handed to ``build_policy``.
        --max-steps  Step limit forwarded to ``run_policy_suite`` (default 12).
        --output     If non-empty, path the traces are exported to as JSONL.
    """
    load_env_file(TrainingRuntimeConfig().env_file)

    parser = argparse.ArgumentParser(description="Evaluate a policy over all seeded tasks.")
    parser.add_argument("--provider", choices=["baseline", "openrouter"], default="baseline")
    parser.add_argument("--model", default="google/gemma-4-31b-it")
    parser.add_argument("--max-steps", type=int, default=12)
    parser.add_argument("--output", default="")
    args = parser.parse_args(argv)

    traces = run_policy_suite(
        policy=build_policy(args.provider, args.model),
        task_names=TASKS,
        max_steps=args.max_steps,
    )

    # One summary entry per task, keyed by task name.
    summary = {
        task_name: {
            "completed": trace.completed,
            "final_score": trace.final_score,
            "steps": len(trace.steps),
            "termination_reason": trace.termination_reason,
        }
        for task_name, trace in traces.items()
    }
    print(json.dumps(summary, indent=2))

    # An empty --output (the default) means "do not export traces".
    if args.output:
        export_traces_jsonl(list(traces.values()), args.output)
        print(f"Saved traces to {args.output}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()