File size: 6,285 Bytes
0387a1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
#!/usr/bin/env python3
"""
Inference loop for Email Assistant OpenEnv.

Features:
- Connect to env locally via OpenEnv API.
- Run multiple tasks (easy, medium, hard).
- Query an LLM for each action (OpenAI), with safe fallback policy.
- Log step and episode level metrics.
"""

from __future__ import annotations

import argparse
import json
import logging
import os
from typing import Any

from dotenv import load_dotenv
from openai import OpenAI

from app.openenv_env.models import Action, Observation
from client import EmailAssistantEnvClient


def configure_logging() -> None:
    """Set up root logging at INFO level with a ``time level message`` line format.

    Delegates to ``logging.basicConfig``, so it has no effect when the root
    logger already has handlers.
    """
    line_format = "%(asctime)s %(levelname)s %(message)s"
    logging.basicConfig(format=line_format, level=logging.INFO)


def build_prompt(observation: Observation) -> list[dict[str, str]]:
    """Assemble the chat messages that ask the LLM for the next action.

    The system message lists the allowed action types and the payload shape
    expected for each. The user message is a JSON string holding:
    - the current email;
    - one summary row per inbox item;
    - the ``model_dump()`` of each of the last (up to) five previous actions.

    Args:
        observation: Current environment observation.

    Returns:
        ``[system_message, user_message]``, each a ``{"role", "content"}`` dict.
    """
    action_menu = (
        "- type='classify', payload={'intent': 'Support|Sales|Spam|General', 'confidence': 0..1, 'reasoning': '...'}",
        "- type='prioritize', payload={'message_id': '...'}",
        "- type='reply', payload={'tone': 'formal|neutral|casual'}",
        "- type='send', payload={'to_email': '...', 'subject': '...', 'body': '...'}",
        "- type='skip', payload={'escalate': bool, 'reason': '...'}",
    )
    system_text = "\n".join(
        (
            "You are an email assistant agent. Return a single JSON object for the next action. "
            "The action must have a 'type' and a 'payload' object.",
            "Available actions:",
            *action_menu,
        )
    )

    email = observation.current_email
    history = [entry.action.model_dump() for entry in observation.previous_actions[-5:]]

    # Key order is significant: it fixes the layout of the JSON the model sees.
    email_fields = {
        "message_id": email.message_id,
        "from_email": email.from_email,
        "subject": email.subject,
        "body": email.body,
    }
    inbox_rows = [
        {
            "message_id": row.message_id,
            "subject": row.subject,
            "deadline": row.deadline_minutes,
            "urgency": row.urgency_score,
            "intent": row.predicted_intent,
            "handled": row.handled,
        }
        for row in observation.inbox_summary
    ]
    user_payload = {
        "current_email": email_fields,
        "inbox_summary": inbox_rows,
        "recent_actions": history,
    }

    return [
        {"role": "system", "content": system_text},
        {"role": "user", "content": json.dumps(user_payload)},
    ]


def fallback_policy(observation: Observation) -> Action:
    """Pick the next action with a fixed rule instead of an LLM.

    The first inbox-summary entry whose ``message_id`` matches the current
    email's is looked up.
    - If such an entry exists and has no ``predicted_intent``, a ``classify``
      action with an empty payload is returned.
    - Otherwise (entry missing, or already classified), a formal ``reply`` is
      returned. The entry's ``handled`` flag is not consulted.
    """
    email = observation.current_email

    match = None
    for item in observation.inbox_summary:
        if item.message_id == email.message_id:
            match = item
            break

    if not match or match.predicted_intent:
        return Action(type="reply", payload={"tone": "formal"})

    # Empty payload: presumably the environment performs the classification
    # itself ("tool-driven") — TODO confirm against the env's classify handler.
    return Action(type="classify", payload={})


def action_from_llm(observation: Observation, llm: OpenAI | None, model: str) -> Action:
    """Ask the LLM for the next action, falling back to ``fallback_policy``.

    The model receives the messages from ``build_prompt`` in JSON-object mode
    and must answer with an object containing at least a ``type`` key. That
    object is passed as keyword arguments to ``Action``. The deterministic
    fallback policy is used instead in any of these cases:
    - no client is configured;
    - the API call fails;
    - the reply is empty, not valid JSON, or not an object;
    - the object lacks ``type``;
    - ``Action`` rejects the object.

    Args:
        observation: Current environment observation.
        llm: OpenAI client, or ``None`` to always use the fallback policy.
        model: Chat-completions model name.

    Returns:
        The LLM-proposed ``Action``, or the fallback policy's action.
    """
    if llm is None:
        return fallback_policy(observation)

    try:
        resp = llm.chat.completions.create(
            model=model,
            messages=build_prompt(observation),
            temperature=0.2,
            response_format={"type": "json_object"},
        )
        content = (resp.choices[0].message.content or "").strip()
        data = json.loads(content) if content else {}
        # Only a JSON object carrying the action type can become an Action.
        # A bare payload or any other JSON value (list, string, number) is
        # rejected explicitly rather than failing later by accident.
        if isinstance(data, dict) and "type" in data:
            return Action(**data)
        logging.warning(
            "LLM reply is not an action object with a 'type'; using fallback policy: %.200r",
            content,
        )
    except Exception as exc:
        # Deliberate best-effort boundary: API, JSON and validation errors all
        # degrade to the deterministic policy instead of aborting the episode.
        logging.warning("LLM action generation failed, using fallback policy: %s", exc)
    # Called outside the try so that a failing fallback surfaces instead of
    # being caught and retried.
    return fallback_policy(observation)


def run_episode(env: Any, task_id: str, llm: OpenAI | None, model: str, max_steps: int) -> dict[str, Any]:
    """Run one episode of ``task_id`` and return a summary of it.

    Resets ``env`` for the task. It then repeatedly picks an action with
    ``action_from_llm`` and applies it with ``env.step``. This continues until
    the environment reports ``done`` or ``max_steps`` steps have been taken.

    Args:
        env: Environment handle exposing ``reset(task_id=...)`` and
            ``step(action)``. ``step`` is unpacked as
            ``(observation, reward, done, info)``, and ``reward.value`` is
            summed.
        task_id: Task identifier passed to ``env.reset``.
        llm: OpenAI client, or ``None`` to use the fallback policy.
        model: Chat-completions model name.
        max_steps: Upper bound on steps. A value ``<= 0`` runs no steps.

    Returns:
        A dict with ``task_id``, the number of ``steps`` taken, the final
        ``done`` flag, the summed ``total_reward``, and the ``info`` from the
        last step (an empty dict if no step was taken).
    """
    logging.info("Starting task: %s", task_id)
    obs = env.reset(task_id=task_id)
    total_reward = 0.0
    done = False
    steps = 0
    # Bound up front so the summary is well-defined even when the loop body
    # never runs (max_steps <= 0).
    info: Any = {}

    while not done and steps < max_steps:
        action = action_from_llm(obs, llm, model)
        obs, reward, done, info = env.step(action)

        total_reward += reward.value
        steps += 1
        logging.info(
            "step=%d action=%s reward=%.3f done=%s",
            steps,
            action.type,
            reward.value,
            done,
        )

    return {
        "task_id": task_id,
        "steps": steps,
        "done": done,
        "total_reward": total_reward,
        "info": info,
    }


def main() -> None:
    """Command-line entry point: run the selected tasks and log a reward summary.

    Each task in ``--task`` (or all three built-in tasks when it is omitted)
    is run ``--episodes`` times against the environment at ``--base-url``.
    ``OPENAI_API_KEY`` is read after loading a ``.env`` file. The fallback
    policy drives every step if the key is missing or still contains the
    template placeholder.
    """
    cli = argparse.ArgumentParser(description="Run Email Assistant OpenEnv inference.")
    cli.add_argument("--base-url", default="http://127.0.0.1:7860/openenv")
    cli.add_argument("--episodes", type=int, default=1, help="Number of times to run each task")
    cli.add_argument("--max-steps", type=int, default=10)
    cli.add_argument("--model", default="gpt-4o-mini")
    cli.add_argument(
        "--task",
        default=None,
        help="Specific task ID to run (easy_classification, medium_prioritization, hard_workflow)",
    )
    opts = cli.parse_args()

    configure_logging()
    load_dotenv()

    key = os.getenv("OPENAI_API_KEY")
    # A key still containing the template placeholder counts as "not configured".
    if key and "your_openai_api_key_here" in key:
        key = None

    llm = None if not key else OpenAI(api_key=key)
    if llm is None:
        logging.warning("Valid OPENAI_API_KEY not found; running with fallback policy.")

    client = EmailAssistantEnvClient(base_url=opts.base_url)

    if opts.task:
        selected = [opts.task]
    else:
        selected = ["easy_classification", "medium_prioritization", "hard_workflow"]

    results = []
    with client.sync() as env:
        for task_id in selected:
            for episode in range(1, opts.episodes + 1):
                logging.info("=== Task %s | Episode %d/%d ===", task_id, episode, opts.episodes)
                results.append(run_episode(env, task_id, llm, opts.model, opts.max_steps))

    count = len(results)
    # max(1, ...) keeps the average at 0.0 when no episode ran.
    mean_reward = sum(outcome["total_reward"] for outcome in results) / max(1, count)
    logging.info("FINAL SUMMARY: episodes=%d average_reward=%.3f", count, mean_reward)

    for outcome in results:
        logging.info(
            "Result: task=%s reward=%.3f done=%s",
            outcome["task_id"],
            outcome["total_reward"],
            outcome["done"],
        )


if __name__ == "__main__":
    main()