openenv-email-triage / inference.py
Aneesha Das
Updated
fba0197
#!/usr/bin/env python3
"""
inference.py — OpenEnv Hackathon Submission
Email Triage Environment
STDOUT FORMAT (mandatory):
[START] task=<task_name> env=<benchmark> model=<model_name>
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
[END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
"""
from __future__ import annotations
import json
import os
import sys
from typing import Any, Dict, List, Optional
from openai import OpenAI
# ── Local project modules (environment.py / models.py must already be importable) ──
from environment import EmailTriageEnv
from models import Priority, Category, RouteTo, Action
# ── Runtime configuration (overridable through environment variables) ─────────
# Inference endpoint and model; both default to the Hugging Face router setup.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
# Credential lookup order: HF_TOKEN, then OPENAI_API_KEY, then API_KEY
# (the first truthy value wins).
HF_TOKEN = (
    os.getenv("HF_TOKEN")
    or os.getenv("OPENAI_API_KEY")
    or os.getenv("API_KEY")
)
# Benchmark identifier printed in every [START] line.
BENCHMARK = "email-triage-v1"
# Task ids run in this order by main().
TASKS = ["easy", "medium", "hard"]
# A graded episode counts as a success when score >= this threshold.
SUCCESS_THRESHOLD = 0.5
# ── Logging helpers (exact format required by hackathon) ──────────────────────
def log_start(task: str, env: str, model: str) -> None:
    """Print the mandatory ``[START]`` record that opens an episode on stdout."""
    fields = {"task": task, "env": env, "model": model}
    record = "[START] " + " ".join(f"{name}={value}" for name, value in fields.items())
    print(record, flush=True)
def _single_line(value: object) -> str:
    """Join any line breaks in ``str(value)`` with spaces so a record stays on one line."""
    return " ".join(str(value).splitlines())


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one mandatory ``[STEP]`` record on stdout.

    Args:
        step: 1-based step counter within the episode.
        action: Compact rendering of the action that was taken.
        reward: Step reward, printed with two decimals.
        done: Whether the episode finished on this step (printed as true/false).
        error: Error text for this step, or a falsy value for ``null``.
    """
    # Free-text fields may carry embedded newlines (e.g. multi-line validation
    # errors); collapse them so each record remains exactly one stdout line.
    error_val = _single_line(error) if error else "null"
    print(
        f"[STEP] step={step} action={_single_line(action)} reward={reward:.2f} "
        f"done={str(done).lower()} error={error_val}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the mandatory ``[END]`` record summarising a finished episode.

    Rewards are rendered with two decimals and comma-joined (empty when no
    step was taken).
    """
    formatted_rewards = ",".join(map("{:.2f}".format, rewards))
    record = " ".join((
        "[END]",
        f"success={str(success).lower()}",
        f"steps={steps}",
        f"score={score:.2f}",
        f"rewards={formatted_rewards}",
    ))
    print(record, flush=True)
# ── System prompt ─────────────────────────────────────────────────────────────
# System message sent with every LLM request in run_task(); the per-email
# details come from build_user_prompt().
# NOTE(review): the allowed priority/category/route_to values listed below are
# presumably kept in sync with the Priority/Category/RouteTo enums in models.py
# — verify whenever those enums change, since unknown values make the Action
# fail validation and fall back to a default action.
# NOTE(review): the SLA rule asks the model to process urgent emails "FIRST",
# but run_task always supplies obs["current_email"]; parse_action does keep a
# model-chosen email_id, so confirm whether the env accepts a different email.
SYSTEM_PROMPT = """You are an expert email triage assistant for a B2B SaaS company.
You must classify each email and return ONLY a valid JSON object.
Return exactly this structure (no markdown, no explanation, just JSON):
{
"email_id": "<copy exactly from the email>",
"priority": "<urgent|high|medium|low|spam>",
"category": "<customer_complaint|billing_inquiry|technical_support|sales_lead|internal_hr|legal_compliance|spam_phishing|general_inquiry>",
"route_to": "<support_tier1|support_tier2|billing|sales|legal|hr|management|trash|archive>",
"summary": "<max 280 chars concise summary>",
"flag_review": <true|false>,
"reasoning": "<brief reasoning>"
}
Priority rules:
- urgent: legal deadline <72h, security incident, production outage, regulatory action
- high: legal threat, important sales, billing dispute, confidential HR, board matters
- medium: routine billing question, general inquiry, standard support
- low: internal social, scheduling, low-priority updates
- spam: phishing, scams, unsolicited commercial email
Routing rules:
- support_tier1: simple password resets, basic how-to
- support_tier2: production outages, security incidents, complex technical issues
- billing: invoices, subscriptions, payment failures
- sales: new business, enterprise leads, partnerships
- legal: regulatory notices, legal threats, compliance, contracts
- hr: employee relations, PIP, hiring, misconduct
- management: exec decisions, acquisition, crisis, major SLA breach
- trash: spam, phishing — delete immediately
- archive: low-priority non-actionable
IMPORTANT sequential constraints visible in the observation:
- escalation_budget_remaining: only flag_review=true if budget > 0 and email truly needs escalation
- team_queue_remaining: avoid routing to teams with 0 remaining capacity
- active_sla_warnings: process emails with steps_left=0 or 1 FIRST"""
def build_user_prompt(obs: dict) -> str:
    """Render the per-step user message for the LLM from an observation dict.

    Args:
        obs: Observation dict (``model_dump()`` of the env observation). Keys
            read here: ``current_email`` (with ``header`` and ``body``),
            ``escalation_budget_remaining``, ``team_queue_remaining``,
            ``active_sla_warnings``, ``step_number`` and ``remaining``.

    Returns:
        The prompt text, or ``"No email to process."`` when there is no
        current email.
    """
    current = obs.get("current_email")
    if not current:
        return "No email to process."
    header = current["header"]
    body = current["body"]
    budget = obs.get("escalation_budget_remaining", 0)
    queues = obs.get("team_queue_remaining") or {}
    warnings = obs.get("active_sla_warnings") or []
    step = obs.get("step_number", 0)
    remaining = obs.get("remaining", 0)

    sla_note = f"\n⚠️ SLA WARNINGS: {json.dumps(warnings)}" if warnings else ""
    # A queue with no capacity left (or over-committed) is full; this mirrors
    # the `> 0` availability check used by rule_based_action.
    full_queues = [team for team, left in queues.items() if left <= 0]
    queue_note = f"\n🚫 FULL QUEUES (do not route here): {full_queues}" if full_queues else ""

    lines = [
        # Separator between step number and remaining count: without it
        # "Step 3" and "12" fuse into the misleading "Step 312".
        f"Step {step} — {remaining} emails remaining.",
        f"Escalation budget left: {budget}{sla_note}{queue_note}",
        "EMAIL TO TRIAGE:",
        f"ID: {header['email_id']}",
        f"From: {header['sender']}",
        f"Subject: {header['subject']}",
        f"Time: {header['timestamp']}",
        f"{body}",
        "Return ONLY a JSON action object.",
    ]
    return "\n".join(lines)
def _strip_code_fence(text: str) -> str:
    """Return the contents of a leading markdown code fence (```json ... ```), else *text*."""
    if not text.startswith("```"):
        return text
    inner = text.split("```")[1]
    if inner.startswith("json"):
        inner = inner[4:]
    return inner.strip()


def _load_json_object(text: str) -> Any:
    """Parse *text* as JSON, retrying on the outermost ``{...}`` span if that fails.

    Raises:
        ValueError: if neither attempt yields valid JSON.
    """
    try:
        return json.loads(text)
    except ValueError:
        # Models sometimes wrap the object in prose ("Here is the JSON: {...}").
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end <= start:
            raise
        return json.loads(text[start:end + 1])


def _coerce_bool(value: Any) -> bool:
    """Interpret *value* as a bool; strings other than true/yes/y/1 count as False.

    Needed because ``bool("false")`` is True in Python.
    """
    if isinstance(value, str):
        return value.strip().lower() in {"true", "yes", "y", "1"}
    return bool(value)


def _enum_token(data: Dict[str, Any], key: str, default: str) -> str:
    """Return ``data[key]`` as a stripped lower-case token, or *default* when missing/blank."""
    value = data.get(key)
    if value is None:
        return default
    token = str(value).strip().lower()
    return token or default


def parse_action(raw: str, fallback_email_id: str) -> Dict[str, Any]:
    """Parse the model's reply into an action dict with every required key present.

    Accepts bare JSON, JSON inside a markdown fence, or JSON embedded in
    surrounding text. Missing, null or blank fields are replaced with safe
    defaults; enum-like fields are lower-cased and ``summary`` is capped at
    280 characters.

    Args:
        raw: Raw text content returned by the chat model.
        fallback_email_id: Id of the email being triaged, used when the model
            omits ``email_id`` or the reply cannot be parsed.

    Returns:
        A dict with keys email_id, priority, category, route_to, summary,
        flag_review and reasoning. If the reply is not a JSON object, a
        default medium/general_inquiry/support_tier1 action is returned.
    """
    text = _strip_code_fence(raw.strip())
    try:
        data = _load_json_object(text)
    except ValueError:
        data = None
    if not isinstance(data, dict):
        # Fallback: safe default action
        return {
            "email_id": fallback_email_id,
            "priority": "medium",
            "category": "general_inquiry",
            "route_to": "support_tier1",
            "summary": "Unable to parse model response — defaulting to general inquiry.",
            "flag_review": False,
            "reasoning": f"Parse error on: {raw[:100]}",
        }
    raw_id = data.get("email_id")
    email_id = str(raw_id).strip() if raw_id is not None else ""
    return {
        "email_id": email_id or fallback_email_id,
        "priority": _enum_token(data, "priority", "medium"),
        "category": _enum_token(data, "category", "general_inquiry"),
        "route_to": _enum_token(data, "route_to", "support_tier1"),
        "summary": str(data.get("summary") or "Email processed.")[:280],
        "flag_review": _coerce_bool(data.get("flag_review", False)),
        "reasoning": str(data.get("reasoning") or ""),
    }
def rule_based_action(obs: dict) -> Dict[str, Any]:
    """Keyword-heuristic triage agent, used when no LLM client is set or a model call fails.

    Args:
        obs: Observation dict; reads ``current_email`` (``header`` + ``body``),
            ``escalation_budget_remaining`` and ``team_queue_remaining``.

    Returns:
        An action dict with the same keys ``parse_action`` produces, or ``{}``
        when the observation has no current email.
    """
    current = obs.get("current_email")
    if not current:
        return {}
    header = current["header"]
    email_id = header["email_id"]
    # Subject and body are searched together, case-insensitively.
    text = (header["subject"] + " " + current["body"]).lower()
    # Whole-word tokens for short keywords that would otherwise match inside
    # unrelated words ("hr" in "three"/"through", "pip" in "pipeline").
    words = set("".join(ch if ch.isalnum() else " " for ch in text).split())
    budget = obs.get("escalation_budget_remaining", 0)
    queues = obs.get("team_queue_remaining") or {}
    # Every escalating branch below is urgent/high, so flagging depends only
    # on whether escalation budget remains.
    can_escalate = budget > 0

    def mentions(*phrases: str) -> bool:
        """True if any phrase occurs as a substring of the email text."""
        return any(phrase in text for phrase in phrases)

    priority, category, route_to, flag_review = "medium", "general_inquiry", "support_tier1", False
    if mentions("spam", "phishing", "congratulations", "won $", "lottery", "verify your"):
        priority, category, route_to = "spam", "spam_phishing", "trash"
    elif mentions("legal", "lawsuit", "compliance", "regulation", "gdpr", "breach notice"):
        priority, category = "urgent", "legal_compliance"
        route_to = "legal" if queues.get("legal", 0) > 0 else "management"
        flag_review = can_escalate
    elif mentions("hacked", "outage", "security", "ransomware", "incident"):
        priority, category, route_to = "urgent", "technical_support", "support_tier2"
        flag_review = can_escalate
    elif mentions("invoice", "billing", "payment", "subscription", "overdue"):
        priority, category, route_to = "high", "billing_inquiry", "billing"
    elif mentions("enterprise", "sales", "pricing", "license", "acquisition"):
        priority, category, route_to = "high", "sales_lead", "sales"
    elif words & {"hr", "pip"} or mentions("performance", "misconduct", "termination", "wages"):
        priority, category = "high", "internal_hr"
        route_to = "hr" if queues.get("hr", 0) > 0 else "management"
        flag_review = can_escalate
    return {
        "email_id": email_id,
        "priority": priority,
        "category": category,
        "route_to": route_to,
        "summary": header["subject"][:280],
        "flag_review": flag_review,
        "reasoning": "Rule-based heuristic",
    }
def _request_model_action(client: OpenAI, obs: dict, fallback_id: str) -> Dict[str, Any]:
    """Ask the chat model to triage the current email and parse its reply.

    Exceptions from the API call propagate; the caller decides how to degrade.
    """
    completion = client.chat.completions.create(
        model=MODEL_NAME,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": build_user_prompt(obs)},
        ],
        temperature=0.2,
        max_tokens=400,
    )
    raw = (completion.choices[0].message.content or "").strip()
    return parse_action(raw, fallback_id)


def _to_validated_action(action_dict: Dict[str, Any], fallback_id: str) -> tuple[Action, Optional[str]]:
    """Convert *action_dict* into an ``Action``, substituting a safe default if it is invalid.

    Returns:
        ``(action, error)`` where ``error`` is ``None`` on success, otherwise an
        ``"invalid_action:..."`` description of why the dict was rejected
        (e.g. a missing key or a value outside the enum).
    """
    try:
        action = Action(
            email_id=action_dict["email_id"],
            priority=Priority(action_dict["priority"]),
            category=Category(action_dict["category"]),
            route_to=RouteTo(action_dict["route_to"]),
            summary=action_dict["summary"],
            flag_review=action_dict["flag_review"],
            reasoning=action_dict.get("reasoning", ""),
        )
        return action, None
    except Exception as exc:
        fallback = Action(
            email_id=fallback_id,
            priority=Priority.MEDIUM,
            category=Category.GENERAL_INQUIRY,
            route_to=RouteTo.SUPPORT_TIER1,
            summary="Fallback — action validation failed.",
            flag_review=False,
        )
        return fallback, f"invalid_action:{exc}"


def _format_action(action: Action) -> str:
    """Render *action* compactly for the ``action=`` field of a ``[STEP]`` record."""
    return (
        f"{{id={action.email_id},pri={action.priority.value},"
        f"cat={action.category.value},route={action.route_to.value},"
        f"flag={action.flag_review}}}"
    )


def run_task(client: Optional[OpenAI], task_id: str) -> float:
    """Run one task episode, emitting [START]/[STEP]/[END] records, and return its score.

    Args:
        client: OpenAI-compatible client, or ``None`` to use the rule-based agent.
        task_id: Task name passed to ``EmailTriageEnv`` (e.g. "easy").

    Returns:
        The grader's ``label_score`` clamped to [0, 1], or 0.0 if the episode
        failed before grading completed. Agent-side and episode errors are
        reported on stderr so stdout keeps the mandatory format.
    """
    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False
    env = None
    # [START] is printed before touching the environment so that every
    # episode gets a matching [END], even if construction or reset() fails.
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME if client else "rule-based-demo")
    try:
        env = EmailTriageEnv(task_id=task_id, seed=42)
        obs = env.reset().model_dump()
        step = 0
        while not env.is_done:
            step += 1
            current = obs.get("current_email")
            if not current:
                break
            fallback_id = current["header"]["email_id"]

            # Get action from model or rule-based fallback
            if client:
                try:
                    action_dict = _request_model_action(client, obs, fallback_id)
                except Exception as exc:
                    # A failed model call must not end the episode; use heuristics for this step.
                    print(f"[DEBUG] step {step} model error: {str(exc)[:80]}", file=sys.stderr, flush=True)
                    action_dict = rule_based_action(obs)
            else:
                action_dict = rule_based_action(obs)

            action, action_error = _to_validated_action(action_dict, fallback_id)
            if action_error:
                print(f"[DEBUG] step {step} {action_error}", file=sys.stderr, flush=True)

            obs_obj, reward_obj, done, info = env.step(action)
            obs = obs_obj.model_dump()
            reward = reward_obj.total
            rewards.append(reward)
            steps_taken = step

            # Only environment-reported errors go into the [STEP] record.
            env_error = None
            if isinstance(info, dict):
                env_error = info.get("last_action_error") or info.get("error")
            log_step(step=step, action=_format_action(action), reward=reward,
                     done=done, error=env_error)

        # Compute final score from grader
        from grader import grade_episode
        grader_result = grade_episode(env._actions_log)
        score = min(max(grader_result.get("label_score", 0.0), 0.0), 1.0)
        success = score >= SUCCESS_THRESHOLD
    except Exception as exc:
        print(f"[DEBUG] Episode error: {exc}", file=sys.stderr, flush=True)
    finally:
        # env stays None if construction failed; getattr(None, ...) yields None.
        close_fn = getattr(env, "close", None)
        if callable(close_fn):
            try:
                close_fn()
            except Exception as exc:
                print(f"[DEBUG] env.close() error: {exc}", file=sys.stderr, flush=True)
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
    return score
def run(task_id: str = "easy") -> Dict[str, Any]:
    """OpenEnv entrypoint (called by server).

    Runs a single task episode. When no API token is configured
    (HF_TOKEN / OPENAI_API_KEY / API_KEY all unset), the rule-based agent is
    used instead of constructing an OpenAI client with ``api_key=None``,
    which the OpenAI SDK rejects.

    Args:
        task_id: Task to run; defaults to "easy" (the previous fixed choice).

    Returns:
        ``{"status": "success", "score": <float>}``.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN) if HF_TOKEN else None
    score = run_task(client, task_id)
    return {
        "status": "success",
        "score": score,
    }
def main() -> None:
    """Run every task in TASKS with the LLM agent and write baseline_results.json.

    Per-task and overall scores are printed to stderr; the results file is
    written next to this script.

    Raises:
        ValueError: if no API token is configured.
    """
    if not HF_TOKEN:
        raise ValueError("HF_TOKEN is not set")
    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
    all_scores: Dict[str, float] = {task_id: run_task(client, task_id) for task_id in TASKS}

    # Summary to stderr so it doesn't pollute the required stdout format
    print("\n=== FINAL SCORES ===", file=sys.stderr)
    for task_id, score in all_scores.items():
        print(f" {task_id:<8} {score:.4f}", file=sys.stderr)
    # Guard against an empty TASKS list instead of dividing by zero.
    overall = sum(all_scores.values()) / len(all_scores) if all_scores else 0.0
    print(f" OVERALL {overall:.4f}", file=sys.stderr)

    # Write results JSON
    script_dir = os.path.dirname(os.path.abspath(__file__))
    results_path = os.path.join(script_dir, "baseline_results.json")
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump({
            "model": MODEL_NAME,
            "benchmark": BENCHMARK,
            "tasks": all_scores,
            "overall": round(overall, 4),
        }, f, indent=2)
    print(f"\nResults written to {results_path}", file=sys.stderr)


if __name__ == "__main__":
    main()