Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| """ | |
| Inference loop for Email Assistant OpenEnv. | |
| Features: | |
| - Connect to env locally via OpenEnv API. | |
| - Run multiple tasks (easy, medium, hard). | |
| - Query an LLM for each action (OpenAI), with safe fallback policy. | |
| - Log step and episode level metrics. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import logging | |
| import os | |
| from typing import Any | |
| from dotenv import load_dotenv | |
| from openai import OpenAI | |
| from app.openenv_env.models import Action, Observation | |
| from client import EmailAssistantEnvClient | |
def configure_logging() -> None:
    """Initialise root logging at INFO level with a ``time LEVEL message`` line format.

    Delegates to ``logging.basicConfig``, so it has no effect if the root logger
    already has handlers attached.
    """
    line_format = "%(asctime)s %(levelname)s %(message)s"
    logging.basicConfig(format=line_format, level=logging.INFO)
def build_prompt(observation: Observation) -> list[dict[str, str]]:
    """Render an observation as chat messages asking the LLM for the next action.

    Args:
        observation: Current environment observation. Reads ``current_email``,
            ``inbox_summary`` and ``previous_actions``; each previous-action entry's
            ``.action`` must provide ``model_dump()``.

    Returns:
        Two messages, system first then user. The system message describes the JSON
        action schema. The user message is a JSON string holding the current email,
        a compact per-email inbox summary and the five most recent actions.
    """

    def _summarize(item):
        # Compact view of one inbox entry; some fields get shorter prompt keys.
        return {
            "message_id": item.message_id,
            "subject": item.subject,
            "deadline": item.deadline_minutes,
            "urgency": item.urgency_score,
            "intent": item.predicted_intent,
            "handled": item.handled,
        }

    email = observation.current_email
    last_actions = [entry.action.model_dump() for entry in observation.previous_actions[-5:]]

    instructions = (
        "You are an email assistant agent. Return a single JSON object for the next action. "
        "The action must have a 'type' and a 'payload' object.\n"
        "Available actions:\n"
        "- type='classify', payload={'intent': 'Support|Sales|Spam|General', 'confidence': 0..1, 'reasoning': '...'}\n"
        "- type='prioritize', payload={'message_id': '...'}\n"
        "- type='reply', payload={'tone': 'formal|neutral|casual'}\n"
        "- type='send', payload={'to_email': '...', 'subject': '...', 'body': '...'}\n"
        "- type='skip', payload={'escalate': bool, 'reason': '...'}"
    )

    context = {
        "current_email": {
            "message_id": email.message_id,
            "from_email": email.from_email,
            "subject": email.subject,
            "body": email.body,
        },
        "inbox_summary": list(map(_summarize, observation.inbox_summary)),
        "recent_actions": last_actions,
    }

    system_message = {"role": "system", "content": instructions}
    user_message = {"role": "user", "content": json.dumps(context)}
    return [system_message, user_message]
def fallback_policy(observation: Observation) -> Action:
    """Choose the next action with fixed deterministic rules (no LLM involved).

    Rules:
      1. If the current email has a matching inbox-summary entry whose
         ``predicted_intent`` is empty, return a ``classify`` action with an empty payload.
      2. Otherwise, including when no matching entry exists, return a ``reply``
         action with a formal tone.

    NOTE(review): the entry's ``handled`` flag is never consulted, so an already
    handled email still gets a reply action. Confirm this is intended.
    """
    email = observation.current_email
    entry = None
    for item in observation.inbox_summary:
        if item.message_id == email.message_id:
            entry = item
            break
    if not entry or entry.predicted_intent:
        return Action(type="reply", payload={"tone": "formal"})
    # Empty payload: the original author describes this as triggering tool-driven
    # classification (NOTE(review): presumably done by the environment; confirm).
    return Action(type="classify", payload={})
def action_from_llm(observation: Observation, llm: OpenAI | None, model: str) -> Action:
    """Ask the LLM for the next action, degrading to ``fallback_policy`` on any problem.

    Args:
        observation: Current environment observation, rendered with ``build_prompt``.
        llm: OpenAI client, or ``None`` to use the deterministic fallback directly.
        model: Chat-completions model name.

    Returns:
        The ``Action`` built from the model's JSON reply. ``fallback_policy(observation)``
        is returned instead when:
          - no client is configured;
          - the request, JSON parsing or ``Action`` validation raises;
          - the reply is not a JSON object containing a ``type`` key.
    """
    if llm is None:
        return fallback_policy(observation)
    try:
        resp = llm.chat.completions.create(
            model=model,
            messages=build_prompt(observation),
            temperature=0.2,
            response_format={"type": "json_object"},
        )
        content = (resp.choices[0].message.content or "").strip()
        data = json.loads(content) if content else {}
        # The reply must be the full action object. A bare payload, or any non-object
        # JSON (list, string, number), carries no action type and cannot be an Action.
        if isinstance(data, dict) and "type" in data:
            return Action(**data)
        logging.warning(
            "LLM reply is not an action object with a 'type'; using fallback policy: %.200r",
            content,
        )
    except Exception as exc:
        # Boundary handler: API errors, malformed JSON and Action validation errors
        # all degrade to the deterministic policy instead of aborting the episode.
        logging.warning("LLM action generation failed, using fallback policy: %s", exc)
    # Single fallback call outside the try, so an error raised by fallback_policy itself
    # is not mistaken for an LLM failure and retried.
    return fallback_policy(observation)
def run_episode(env: Any, task_id: str, llm: OpenAI | None, model: str, max_steps: int) -> dict[str, Any]:
    """Run one episode of ``task_id`` and return a summary of it.

    Args:
        env: Synchronous environment handle. It must provide ``reset(task_id=...)``
            returning an observation, and ``step(action)`` returning
            ``(obs, reward, done, info)`` where ``reward`` exposes ``.value``.
        task_id: Task identifier passed to ``env.reset``.
        llm: OpenAI client, or ``None`` to act with the fallback policy only.
        model: Chat-completions model name forwarded to ``action_from_llm``.
        max_steps: Upper bound on steps. The episode also stops as soon as ``done`` is true.

    Returns:
        A dict with these keys:
          - ``task_id``
          - ``steps`` taken
          - final ``done`` flag
          - summed ``total_reward``
          - ``info`` from the last ``env.step`` call (``{}`` if no step was taken).
    """
    logging.info("Starting task: %s", task_id)
    obs = env.reset(task_id=task_id)
    total_reward = 0.0
    done = False
    steps = 0
    # Initialised up front: with max_steps <= 0 the loop never runs, and the return
    # below would otherwise raise UnboundLocalError.
    info: Any = {}
    while not done and steps < max_steps:
        action = action_from_llm(obs, llm, model)
        obs, reward, done, info = env.step(action)
        total_reward += reward.value
        steps += 1
        logging.info(
            "step=%d action=%s reward=%.3f done=%s",
            steps,
            action.type,
            reward.value,
            done,
        )
    return {
        "task_id": task_id,
        "steps": steps,
        "done": done,
        "total_reward": total_reward,
        "info": info,
    }
def main() -> None:
    """Command-line entry point: run the requested tasks and log a reward summary.

    Steps:
      1. Parse CLI options, configure logging and load ``.env``.
      2. Build an OpenAI client when ``OPENAI_API_KEY`` is set and is not the template
         placeholder; otherwise the fallback policy is used.
      3. Run each task (``--task`` or all three defaults) ``--episodes`` times.
      4. Log the average reward and one line per episode.
    """
    default_tasks = ["easy_classification", "medium_prioritization", "hard_workflow"]

    cli = argparse.ArgumentParser(description="Run Email Assistant OpenEnv inference.")
    cli.add_argument("--base-url", default="http://127.0.0.1:7860/openenv")
    cli.add_argument("--episodes", type=int, default=1, help="Number of times to run each task")
    cli.add_argument("--max-steps", type=int, default=10)
    cli.add_argument("--model", default="gpt-4o-mini")
    cli.add_argument(
        "--task",
        default=None,
        help="Specific task ID to run (easy_classification, medium_prioritization, hard_workflow)",
    )
    opts = cli.parse_args()

    configure_logging()
    load_dotenv()

    # An unset/empty key, or the unedited template value, means "no LLM".
    key = os.getenv("OPENAI_API_KEY")
    usable_key = bool(key) and "your_openai_api_key_here" not in key
    llm = OpenAI(api_key=key) if usable_key else None
    if not usable_key:
        logging.warning("Valid OPENAI_API_KEY not found; running with fallback policy.")

    client = EmailAssistantEnvClient(base_url=opts.base_url)
    tasks = [opts.task] if opts.task else default_tasks

    results = []
    with client.sync() as env:
        for task_id in tasks:
            for episode in range(1, opts.episodes + 1):
                logging.info("=== Task %s | Episode %d/%d ===", task_id, episode, opts.episodes)
                results.append(run_episode(env, task_id, llm, opts.model, opts.max_steps))

    episode_count = len(results)
    # max(1, ...) keeps the division safe when no episode ran (average reported as 0).
    mean_reward = sum(r["total_reward"] for r in results) / max(1, episode_count)
    logging.info("FINAL SUMMARY: episodes=%d average_reward=%.3f", episode_count, mean_reward)
    for summary in results:
        logging.info(
            "Result: task=%s reward=%.3f done=%s",
            summary["task_id"],
            summary["total_reward"],
            summary["done"],
        )


if __name__ == "__main__":
    main()