File size: 9,550 Bytes
96a5caf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4c15b1
96a5caf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4c15b1
96a5caf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4c15b1
96a5caf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
"""
inference.py -- Agent that solves all three Email Triage tasks.

Uses the OpenAI Python client pointed at a Groq-compatible endpoint.
All LLM config is controlled via environment variables:
    - API_BASE_URL : base URL for the OpenAI-compatible API  (has default)
    - MODEL_NAME   : model to use for inference               (has default)
    - HF_TOKEN     : API key (mandatory, injected by runner)

Usage:
    python inference.py --task 1   # run task 1
    python inference.py --task 2   # run task 2
    python inference.py --task 3   # run task 3 (full pipeline)
    python inference.py --all      # run all tasks and report scores
"""

import argparse
import json
import os
import time
from typing import Any

from openai import OpenAI

from environment import EmailTriageEnv, Action

# ---------------------------------------------------------------------------
# Configuration via environment variables (hackathon-compliant)
# ---------------------------------------------------------------------------

# API_BASE_URL: Groq's OpenAI-compatible endpoint (default provided)
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.groq.com/openai/v1")

# MODEL_NAME: which model to use on the endpoint (default provided)
MODEL_NAME = os.environ.get("MODEL_NAME", "llama-3.3-70b-versatile")

# HF_TOKEN: the API key -- mandatory, injected by hackathon runner
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# Initialize the OpenAI client pointing at Groq (or whatever API_BASE_URL is)
client = OpenAI(
    base_url=API_BASE_URL,
    api_key=HF_TOKEN,
)

# ---------------------------------------------------------------------------
# System prompt
# ---------------------------------------------------------------------------

SYSTEM_PROMPT = """You are an expert email triage agent. You manage an inbox
efficiently by reading emails, assigning priority labels, drafting professional
replies, archiving junk, and flagging ambiguous items for human review.

You interact with an email environment through a strict JSON action interface.
Each response you produce MUST be a single valid JSON object -- no markdown,
no extra text -- in exactly this format:

{
  "action": "<action_name>",
  "email_id": "<id or null>",
  "priority": "<urgent|normal|low or null>",
  "body": "<reply text or null>",
  "reason": "<flag reason or null>"
}

Available actions:
- list_inbox   -- see all emails (no email_id needed)
- read         -- read full body of an email (requires email_id)
- label        -- assign a priority label (requires email_id + priority)
- draft_reply  -- write a reply (requires email_id + body)
- archive      -- move to archive (requires email_id)
- flag         -- escalate for human review (requires email_id + reason)

Rules:
- NEVER archive an urgent email.
- ALWAYS read an email before labelling or replying.
- Draft replies ONLY for urgent emails (unless instructed otherwise).
- Archive obvious spam/junk.
- Flag emails that are ambiguous or need human judgment.
- When drafting replies: be professional, address all issues raised, do NOT
  invent facts (no fake tracking numbers, refund amounts, dates).
- Signal completion by returning: {"action": "done", "email_id": null, "priority": null, "body": null, "reason": null}
"""

# ---------------------------------------------------------------------------
# Agent helpers
# ---------------------------------------------------------------------------

def parse_action(text: str) -> dict[str, Any]:
    """Extract JSON from model output (handles minor formatting noise)."""
    text = text.strip()
    # Strip markdown fences if present (some models wrap JSON in ```)
    if text.startswith("```"):
        lines = text.split("\n")
        text = "\n".join(l for l in lines if not l.startswith("```"))
    return json.loads(text)


def run_task(task: int, max_steps: int = 40, verbose: bool = True) -> float:
    """Run a single task with the LLM agent. Returns final score."""

    env = EmailTriageEnv(task=task)
    obs = env.reset()

    # --- Hackathon output marker ---
    print(f"[START] task={task}", flush=True)

    if verbose:
        print(f"Task {task} | Model: {MODEL_NAME} | Endpoint: {API_BASE_URL}")

    task_instruction = {
        1: (
            "Task: Read all 5 emails and label each as urgent, normal, or low priority. "
            "Start by listing the inbox, then read each email before labelling it. "
            "When done, output the done action."
        ),
        2: (
            "Task: Read the customer complaint email and draft a professional reply "
            "that addresses ALL issues the customer raised. Be empathetic, professional, "
            "and do not invent any facts. When done, output the done action."
        ),
        3: (
            "Task: Full triage pipeline on 10 emails.\n"
            "1. List the inbox.\n"
            "2. Read each email.\n"
            "3. Label all emails (urgent / normal / low).\n"
            "4. Draft replies for urgent emails.\n"
            "5. Archive obvious spam / junk.\n"
            "6. Flag ambiguous emails for human review.\n"
            "When everything is done, output the done action."
        ),
    }[task]

    # Build the message history (OpenAI format: system + user/assistant turns)
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": task_instruction},
    ]
    step = 0

    while step < max_steps:
        # Call the LLM via the OpenAI client (works with Groq, vLLM, etc.)
        # Retry with backoff on rate-limit errors (Groq free tier: 30 RPM)
        raw = None
        for attempt in range(3):
            try:
                response = client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=messages,
                    max_tokens=1000,
                    temperature=0.2,
                )
                raw = response.choices[0].message.content
                break
            except Exception as e:
                wait = 2 ** attempt * 5  # 5s, 10s, 20s
                if verbose:
                    print(f"  [RETRY] {type(e).__name__} -- waiting {wait}s (attempt {attempt+1}/3)")
                time.sleep(wait)

        if raw is None:
            if verbose:
                print("  [ERROR] LLM call failed after 3 retries. Ending task.")
            break

        messages.append({"role": "assistant", "content": raw})

        if verbose:
            print(f"[Step {step+1}] Agent: {raw[:120]}{'...' if len(raw) > 120 else ''}")

        # Parse action from model output
        try:
            action_dict = parse_action(raw)
        except json.JSONDecodeError as e:
            if verbose:
                print(f"  [WARN] JSON parse error: {e} -- asking agent to retry")
            messages.append({
                "role": "user",
                "content": f"Your last response was not valid JSON. Error: {e}. Please try again with a valid JSON action."
            })
            continue

        # Done?
        if action_dict.get("action") == "done":
            if verbose:
                print("  Agent signalled completion.")
            break

        # Execute action in the environment
        try:
            action = Action(**action_dict)
        except Exception as e:
            messages.append({"role": "user", "content": f"Invalid action format: {e}. Try again."})
            continue

        result = env.step(action)

        # --- Hackathon output marker ---
        print(f"[STEP] step={step+1} reward={result.reward}", flush=True)

        if verbose:
            print(f"  Env: [{result.observation.status}] {result.observation.message}  reward={result.reward:+.2f}")

        # Feed observation back to the agent
        obs_summary = {
            "status": result.observation.status,
            "message": result.observation.message,
            "data": result.observation.data,
            "step": result.observation.step_count,
            "running_score": env.score(),
        }
        messages.append({"role": "user", "content": json.dumps(obs_summary)})
        step += 1

    final_score = env.score()

    # --- Hackathon output marker ---
    print(f"[END] task={task} score={final_score} steps={step}", flush=True)

    if verbose:
        print(f"Final score: {final_score:.2f} / 1.00  (steps used: {step})")

    return final_score


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def main():
    parser = argparse.ArgumentParser(description="Email Triage Agent")
    parser.add_argument("--task", type=int, choices=[1, 2, 3],
                        help="Run a specific task (1, 2, or 3)")
    parser.add_argument("--all", action="store_true",
                        help="Run all three tasks")
    parser.add_argument("--quiet", action="store_true",
                        help="Suppress verbose output")
    args = parser.parse_args()

    verbose = not args.quiet

    if args.all:
        scores = {}
        for t in [1, 2, 3]:
            scores[t] = run_task(t, verbose=verbose)
        print("=" * 40)
        print("  FINAL SCORES")
        print("=" * 40)
        for t, s in scores.items():
            print(f"  Task {t}: {s:.2f}")
        print(f"  Average: {sum(scores.values()) / 3:.2f}")
    elif args.task:
        run_task(args.task, verbose=verbose)
    else:
        parser.print_help()


if __name__ == "__main__":
    main()