Spaces:
Runtime error
Runtime error
| """Headless batch test — run the agent on multiple seeds and print results.""" | |
| import asyncio | |
| import json | |
| import os | |
| import random | |
| from datetime import date, timedelta | |
| from dotenv import load_dotenv | |
| from openai import OpenAI | |
| import websockets | |
# Load variables such as GROQ_API_KEY (and the optional overrides below) from a
# local .env file, if one exists.
load_dotenv()
# Websocket endpoint of the calendar environment server. Override with the
# WS_URL environment variable to target a non-local server.
WS_URL = os.getenv("WS_URL", "ws://localhost:8000/ws")
# Chat model used as the agent. Override with the GROQ_MODEL environment variable.
MODEL = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
# First candidate episode date (2026-01-05 is a Monday); see _seed_to_episode_today.
EPISODE_BASE_DATE = date(2026, 1, 5)
# Number of consecutive weeks from which an episode's "today" may be drawn.
EPISODE_WEEKS = 3
# Upper bound on model turns (chat-completion calls) per seed.
MAX_STEPS = 60
# Instructions sent as the "system" message on every model turn. The prompt is
# assembled one line per entry and joined with newlines.
SYSTEM_PROMPT = "\n".join(
    (
        "You are a calendar personal assistant managing a team's calendar.",
        "Start by reading your inbox (read_inbox) and reviewing the calendar (list_events) to understand what needs to be done today.",
        "IMPORTANT workflow:",
        "1. Read your inbox to see pending requests from your boss and team.",
        "2. Review the calendar to understand the current schedule.",
        "3. BEFORE scheduling any meeting, call get_contact_preferences(person) to learn their constraints.",
        "4. Use check_availability before scheduling — don't guess times.",
        "5. Respect HARD constraints (must obey) and SOFT constraints (preferences).",
        "6. After making changes, call check_constraint_violations to verify.",
        "7. Keep checking your inbox — new messages may arrive while you work.",
        "8. When attendees decline a meeting, read their feedback and adjust your proposal.",
        "9. Personal events on the calendar are immovable — schedule work around them.",
        "10. Reply to messages that need responses.",
        "11. Inbox-driven requests are tracked in the inbox; if any task-style view omits them, use read_inbox as the source of truth.",
        "12. Think step by step about what tools to call and in what order.",
        "13. Pay attention to policy changes announced via interrupts. Rules may change mid-session.",
        "14. If someone's availability changes, re-validate any meetings you already scheduled with that person.",
        "15. Family may text mid-session about personal event changes. Re-check for new conflicts.",
    )
)
# Date vocabulary shared by the tools that take a date argument.
_DATE_VOCAB = "'today', 'tomorrow', 'next monday', or YYYY-MM-DD"


def _function_tool(name, description, properties=None, required=None):
    """Return one tool definition in the OpenAI function-calling envelope.

    Args:
        name: Tool name the model emits in a tool call.
        description: Model-facing description of the tool.
        properties: JSON-schema ``properties`` mapping; empty when omitted.
        required: Names of mandatory arguments; empty when omitted.

    Returns:
        ``{"type": "function", "function": {name, description, parameters}}``.
    """
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties if properties is not None else {},
                "required": list(required) if required is not None else [],
            },
        },
    }


# Tool schemas advertised to the model on every chat-completion call.
TOOLS = [
    _function_tool(
        "list_events",
        "List all calendar events for a given date.",
        {"date": {"type": "string", "description": f"Date string: {_DATE_VOCAB}"}},
    ),
    _function_tool(
        "create_event",
        "Create a calendar event.",
        {
            "title": {"type": "string"},
            "date": {"type": "string", "description": _DATE_VOCAB},
            "start_time": {"type": "string", "description": "HH:MM format"},
            "duration_minutes": {"type": "integer", "default": 60},
            "attendees": {"type": "string", "description": "Comma-separated names"},
            "description": {
                "type": "string",
                "description": "Meeting agenda/description (required for meetings >30 min after policy update)",
            },
        },
        ["title", "date", "start_time"],
    ),
    _function_tool(
        "delete_event",
        "Delete a calendar event by title.",
        {"title": {"type": "string"}},
        ["title"],
    ),
    _function_tool(
        "edit_event",
        "Edit an existing event. Only provided fields are changed.",
        {
            "title": {"type": "string", "description": "Current title of the event to edit"},
            "new_title": {"type": "string"},
            "new_date": {"type": "string"},
            "new_start_time": {"type": "string", "description": "HH:MM"},
            "new_duration_minutes": {"type": "integer"},
            "new_attendees": {"type": "string", "description": "Comma-separated, replaces all"},
            "new_description": {"type": "string", "description": "New description/agenda for the event"},
        },
        ["title"],
    ),
    _function_tool(
        "find_free_slots",
        "Find available time slots on a given date (8:00-18:00).",
        {
            "date": {"type": "string", "default": "today"},
            "duration_minutes": {"type": "integer", "default": 60},
        },
    ),
    _function_tool(
        "check_conflicts",
        "Check for scheduling conflicts on a date.",
        {"date": {"type": "string", "default": "today"}},
    ),
    _function_tool(
        "resolve_conflict",
        "Resolve a conflict by moving an event to a new time.",
        {
            "event_title": {"type": "string"},
            "new_start_time": {"type": "string", "description": "HH:MM format"},
        },
        ["event_title", "new_start_time"],
    ),
    _function_tool(
        "send_notification",
        "Send a notification to a person.",
        {"to": {"type": "string"}, "message": {"type": "string"}},
        ["to", "message"],
    ),
    _function_tool(
        "check_availability",
        "Check a person's availability on a given date.",
        {"person": {"type": "string"}, "date": {"type": "string", "default": "today"}},
        ["person"],
    ),
    _function_tool(
        "get_constraints",
        "Get scheduling constraints (hard and soft) that apply to the calendar. Note: individual people may have additional private constraints — use get_contact_preferences to discover them.",
    ),
    _function_tool(
        "get_contact_preferences",
        "Get a person's scheduling preferences, private constraints, role, and preferred notification method. Some constraints are only visible through this tool.",
        {"person": {"type": "string", "description": "Person's name (e.g. Alice, Bob, Charlie, Dave, Eve)"}},
        ["person"],
    ),
    _function_tool(
        "check_constraint_violations",
        "Check the current calendar for all constraint violations.",
    ),
    _function_tool(
        "read_inbox",
        "List inbox messages, filtered by all/unread/unreplied.",
        {"status": {"type": "string", "description": "Filter: 'all', 'unread', or 'unreplied'", "default": "all"}},
    ),
    _function_tool(
        "reply_message",
        "Reply to an inbox message. Reply must address the sender's concern.",
        {
            "message_id": {"type": "string", "description": "ID of the message to reply to"},
            "reply": {"type": "string", "description": "Your reply text (min 20 chars, must address the sender's ask)"},
        },
        ["message_id", "reply"],
    ),
    _function_tool(
        "check_personal_calendar",
        "Show personal (immovable) events that cannot be moved or deleted.",
    ),
]
def _seed_to_episode_today(seed: int) -> str:
    """Derive an episode's "today" from its seed, as an ISO ``YYYY-MM-DD`` string.

    The candidates are five consecutive days per week, for each of
    ``EPISODE_WEEKS`` weeks counted from ``EPISODE_BASE_DATE``. These are
    Monday-Friday, because the base date is a Monday. One candidate is drawn
    with a fresh ``random.Random(seed)``, so the result is deterministic per
    seed. NOTE(review): presumably this mirrors the environment server's own
    seed-to-date mapping; verify that the two stay in sync.
    """
    candidates = [
        EPISODE_BASE_DATE + timedelta(weeks=week_offset, days=day_offset)
        for week_offset in range(EPISODE_WEEKS)
        for day_offset in range(5)
    ]
    picked = random.Random(seed).choice(candidates)
    return picked.isoformat()
def _assistant_entry(msg) -> dict:
    """Convert an SDK assistant message into a plain chat-history dict.

    Only ``role``, ``content`` and ``tool_calls`` are kept. This gives the
    history a uniform dict shape that can be inspected by role, and it avoids
    replaying extra SDK-only fields to the API.
    """
    entry = {
        "role": "assistant",
        # Content may be null alongside tool_calls; a text-only turn needs a string.
        "content": msg.content if msg.tool_calls else (msg.content or ""),
    }
    if msg.tool_calls:
        entry["tool_calls"] = [
            {
                "id": call.id,
                "type": "function",
                "function": {
                    "name": call.function.name,
                    "arguments": call.function.arguments,
                },
            }
            for call in msg.tool_calls
        ]
    return entry


def _recent_history(history: list, size: int) -> list:
    """Return the last ~``size`` history entries without orphaning tool results.

    A plain ``history[-size:]`` can start partway through a multi-call turn.
    That leaves ``role: "tool"`` messages whose assistant ``tool_calls``
    message was cut off, and chat-completion APIs reject such a request. The
    window start is therefore moved back to the owning assistant message, so
    the window may hold a few more than ``size`` entries.
    """
    start = max(0, len(history) - size)
    while 0 < start < len(history) and history[start]["role"] == "tool":
        start -= 1
    return history[start:]


async def _env_step(ws, tool_name: str, args: dict) -> dict:
    """Send one tool invocation as an environment ``step`` and return the decoded reply."""
    instruction = json.dumps({"tool": tool_name, "args": args})
    await ws.send(json.dumps({"type": "step", "data": {"instruction": instruction}}))
    return json.loads(await ws.recv())


async def _fetch_calendar(ws, episode_today: str) -> str:
    """Return the events listed for Mon-Fri of the episode's week and the following week.

    Days reported as having no events are omitted. Returns ``"No events."``
    when nothing is listed. Replies lacking ``data.observation.output`` are
    skipped rather than raising, so a finished episode's results are not lost.
    """
    today_dt = date.fromisoformat(episode_today)
    week_start = today_dt - timedelta(days=today_dt.weekday())  # Monday of that week
    calendar_lines = []
    for week in range(2):
        for day in range(5):
            d = (week_start + timedelta(weeks=week, days=day)).isoformat()
            resp = await _env_step(ws, "list_events", {"date": d})
            observation = (resp.get("data") or {}).get("observation") or {}
            output = str(observation.get("output") or "")
            if output and "No events" not in output:
                calendar_lines.append(output)
    return "\n".join(calendar_lines) if calendar_lines else "No events."


async def run_seed(client: OpenAI, seed: int) -> dict:
    """Play one episode for ``seed`` against the environment and report the outcome.

    Each turn sends the model the system prompt, the latest environment state
    summary, and a short window of recent history. Every tool call the model
    makes is forwarded to the environment over the websocket. The loop stops
    when the environment reports ``done``, when the model API errors, or
    after ``MAX_STEPS`` turns.

    Args:
        client: OpenAI-compatible client used for chat completions.
        seed: Episode seed, sent with the environment reset and also used
            locally to derive the episode date.

    Returns:
        A dict with ``seed``, ``date``, ``reward``, ``flags`` (sorted),
        ``steps`` (model turns used) and ``total_tasks``.

    Raises:
        Websocket and JSON errors from the reset/step protocol propagate to
        the caller.
    """
    episode_today = _seed_to_episode_today(seed)
    print(f"\n{'='*60}")
    print(f"SEED {seed} | Episode date: {episode_today}")
    print(f"{'='*60}")
    sliding_window = 6  # recent history entries replayed after the state summary
    async with websockets.connect(WS_URL, ping_interval=20, ping_timeout=120) as ws:
        # Reset the environment for this seed.
        await ws.send(json.dumps({"type": "reset", "data": {"seed": seed}}))
        reset_resp = json.loads(await ws.recv())
        obs = reset_resp["data"]["observation"]
        history = []
        # `or` so that a null/empty summary also falls back to the raw output.
        latest_state_summary = obs.get("state_summary") or obs.get("output", "")
        total_tasks = len(obs.get("flags_found", [])) + obs.get("pending_tasks", 0)
        final_reward = 0
        final_flags = []
        steps_used = 0
        for step in range(MAX_STEPS):
            steps_used = step + 1
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": latest_state_summary},
            ] + _recent_history(history, sliding_window)
            try:
                # The SDK call is blocking; run it off the event loop so the
                # websocket keeps answering pings meanwhile.
                response = await asyncio.to_thread(
                    client.chat.completions.create,
                    model=MODEL, messages=messages, tools=TOOLS, tool_choice="auto",
                )
            except Exception as e:
                print(f" Step {step+1}: API error: {e}")
                break
            msg = response.choices[0].message
            history.append(_assistant_entry(msg))
            if not msg.tool_calls:
                content = msg.content or ""
                print(f" Step {step+1}: [text] {content[:80]}...")
                history.append({"role": "user", "content": "Continue completing the remaining tasks. Use the tools available to you."})
                continue
            done = False
            for tool_call in msg.tool_calls:
                tool_name = tool_call.function.name
                try:
                    # Models sometimes send "" for no-argument tools.
                    args = json.loads(tool_call.function.arguments or "{}")
                except json.JSONDecodeError as exc:
                    # Every tool_call still needs a matching tool message.
                    print(f" Step {step+1}: {tool_name}(<invalid JSON arguments>)")
                    history.append({
                        "role": "tool",
                        "tool_call_id": tool_call.id,
                        "content": f"Error: tool arguments were not valid JSON ({exc}). Retry with a JSON object.",
                    })
                    continue
                if not isinstance(args, dict):
                    args = {}
                print(f" Step {step+1}: {tool_name}({json.dumps(args, separators=(',',':'))})")
                data = (await _env_step(ws, tool_name, args))["data"]
                obs = data["observation"]
                reward = data.get("reward")
                if reward is not None:  # a null reward must not clobber the last real one
                    final_reward = reward
                final_flags = obs.get("flags_found", [])
                total_tasks = len(final_flags) + obs.get("pending_tasks", 0)
                latest_state_summary = obs.get("state_summary") or latest_state_summary
                history.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": str(obs.get("output") or ""),
                })
                if data.get("done", False):
                    print(f" >>> DONE at step {step+1}")
                    done = True
                    break
            if done:
                break
        # NOTE(review): the calendar dump issues further env "step" messages,
        # possibly after the episode is done; confirm the server still serves
        # list_events then and does not count these toward the episode.
        calendar = await _fetch_calendar(ws, episode_today)
    print(f"\n REWARD: {final_reward:.2f} ({len(final_flags)}/{total_tasks})")
    print(f" FLAGS: {sorted(final_flags)}")
    print(f" STEPS: {steps_used}")
    print(f" CALENDAR:\n{calendar}")
    return {
        "seed": seed,
        "date": episode_today,
        "reward": final_reward,
        "flags": sorted(final_flags),
        "steps": steps_used,
        "total_tasks": total_tasks,
    }
async def main(seeds=None):
    """Run the agent headlessly on several seeds and print a summary table.

    Args:
        seeds: Iterable of episode seeds to evaluate. Defaults to the fixed
            set ``[1, 3, 5, 9, 12]``.
    """
    if seeds is None:
        seeds = [1, 3, 5, 9, 12]
    # Groq exposes an OpenAI-compatible endpoint, so the stock client is pointed at it.
    client = OpenAI(
        base_url="https://api.groq.com/openai/v1",
        api_key=os.getenv("GROQ_API_KEY"),
    )
    results = []
    for seed in seeds:
        try:
            result = await run_seed(client, seed)
        except Exception as e:
            # Top-level boundary: one failing seed is recorded, not fatal to the batch.
            print(f"\nSEED {seed} FAILED: {e}")
            result = {"seed": seed, "reward": 0, "flags": [], "error": str(e)}
        results.append(result)
    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")
    rewards = []
    for r in results:
        # A present-but-null reward would crash the ':.2f' format; count it as 0.
        reward = r.get("reward") or 0
        rewards.append(reward)
        flags_count = len(r.get("flags", []))
        total_tasks = r.get("total_tasks", "?")
        print(f" Seed {r['seed']:3d} | {r.get('date','?'):>10s} | Reward: {reward:.2f} ({flags_count}/{total_tasks}) | Steps: {r.get('steps','?')}")
    if not rewards:
        # The average and min()/max() are undefined for an empty seed list.
        print("\n No seeds were run.")
        return
    print(f"\n Average reward: {sum(rewards)/len(rewards):.2f}")
    print(f" Min: {min(rewards):.2f} | Max: {max(rewards):.2f}")


if __name__ == "__main__":
    asyncio.run(main())