supermail / inference.py
vicky1428's picture
Upload folder using huggingface_hub
b9ad6f9 verified
"""Async baseline inference runner for Supermail."""
from __future__ import annotations
import asyncio
import json
import os
from dataclasses import dataclass
from typing import Any, List, Optional
from openai import OpenAI
try:
    from dotenv import load_dotenv
except ImportError:  # pragma: no cover
    def load_dotenv(*args: Any, **kwargs: Any) -> bool:
        """No-op replacement used when python-dotenv is not installed.

        Any positional or keyword arguments are accepted and ignored so the
        stub can be called the same way as the real ``dotenv.load_dotenv``.

        Returns:
            Always ``False``: no ``.env`` file was read.
        """
        return False
from client import SupermailEnv
from models import SupportAction, SupportObservation
from server.environment import SupermailEnvironment
from sys_prompt import SYSTEM_PROMPT
from tasks import ALL_TASKS, TASKS_BY_ID
load_dotenv()
# --- Runtime configuration, read once from the environment at import time ---
# Every default is applied with ``or`` so that a variable that is set but
# empty behaves the same as an unset one.

# Chat-completions endpoint and model used by get_model_action().
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
# API key for the endpoint; when missing, build_client() returns None and the
# heuristic policy is used instead of the model.
HF_TOKEN = os.getenv("HF_TOKEN")
# Environment selection consumed by create_env(): docker image first, then a
# base URL, otherwise the in-process environment.
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
BASE_URL = os.getenv("SUPERMAIL_BASE_URL") or os.getenv("SUPPORT_SIM_BASE_URL")
# Task id to run, or "all" for every bundled task. The SUPPORT_SIM_* names are
# accepted as fallbacks for the SUPERMAIL_* ones.
TASK_NAME = os.getenv("SUPERMAIL_TASK") or os.getenv("SUPPORT_SIM_TASK") or "all"
# Benchmark label printed in the [START] log line.
BENCHMARK = os.getenv("SUPERMAIL_BENCHMARK") or os.getenv("SUPPORT_SIM_BENCHMARK") or "supermail"

# Upper bound on environment steps per episode.
MAX_STEPS = 12
# Sampling parameters passed to the chat-completions request.
TEMPERATURE = 0.4
# NOTE(review): large for a short JSON reply -- confirm the provider accepts it.
MAX_TOKENS = 25000
# An episode counts as successful when its clamped score reaches this value.
SUCCESS_SCORE_THRESHOLD = 0.95
# Bounds applied by clamp_score(); MIN_SCORE is also the initial episode score.
MIN_SCORE = 0.01
MAX_SCORE = 0.99
@dataclass
class LocalStepResult:
    """Minimal local stand-in for OpenEnv StepResult.

    Returned by ``LocalSupermailSession.reset`` and ``.step`` so the episode
    loop can read ``observation``, ``reward`` and ``done`` the same way it
    does for results coming from the environment client.
    """

    # Observation returned by the wrapped environment for this transition.
    observation: SupportObservation
    # Reward taken from the observation; the session substitutes 0.0 when the
    # observation's reward is missing or falsy.
    reward: float
    # Copied from the observation's ``done`` flag.
    done: bool
class LocalSupermailSession:
    """Run a ``SupermailEnvironment`` in-process behind async methods.

    The methods are coroutines only so the session exposes the same awaitable
    ``reset`` / ``step`` / ``close`` calls the episode loop uses on the other
    objects returned by ``create_env``; the wrapped environment itself is
    called synchronously.
    """

    def __init__(self, task_id: str):
        self._env = SupermailEnvironment(task_id=task_id)

    @staticmethod
    def _as_result(observation: SupportObservation) -> LocalStepResult:
        """Wrap an observation, substituting 0.0 for a missing/falsy reward."""
        reward = observation.reward or 0.0
        return LocalStepResult(
            observation=observation,
            reward=reward,
            done=observation.done,
        )

    async def reset(self) -> LocalStepResult:
        """Reset the wrapped environment and return its first result."""
        return self._as_result(self._env.reset())

    async def step(self, action: SupportAction) -> LocalStepResult:
        """Apply ``action`` to the wrapped environment and return the result."""
        return self._as_result(self._env.step(action))

    async def close(self) -> None:
        """Close the wrapped environment."""
        self._env.close()
def sanitize(value: Any) -> str:
    """Render ``value`` as one line of text for log output.

    The value is converted with ``str`` and every run of whitespace
    (including carriage returns and newlines) is collapsed to a single space;
    leading and trailing whitespace is dropped.
    """
    words = str(value).split()
    return " ".join(words)
def clamp_score(score: float) -> float:
    """Limit ``score`` to the range [MIN_SCORE, MAX_SCORE].

    Both bounds lie strictly between 0 and 1, so the result is never exactly
    0 or 1. Values already inside the range are returned unchanged.
    """
    if score < MIN_SCORE:
        return MIN_SCORE
    if score > MAX_SCORE:
        return MAX_SCORE
    return score
def compact_action(action: Optional[SupportAction]) -> str:
    """Render an action as compact, key-sorted JSON for the log lines.

    Only the ``priority``, ``category``, ``action`` and ``notes`` attributes
    are considered, and an attribute is included only when it is present and
    truthy. ``None`` is rendered as the literal string ``"null"``.
    """
    if action is None:
        return "null"

    payload = {}
    for field_name in ("priority", "category", "action", "notes"):
        value = getattr(action, field_name, None)
        if value:
            payload[field_name] = value
    return json.dumps(payload, separators=(",", ":"), sort_keys=True)
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` line that opens an episode's log."""
    line = f"[START] task={task} env={env} model={model}"
    print(line, flush=True)
def log_step(
    *,
    step: int,
    action: Optional[SupportAction],
    reward: float,
    done: bool,
    error: Optional[str],
) -> None:
    """Print one ``[STEP]`` line describing a single environment step.

    The action is serialized with ``compact_action``; a missing or empty
    error is printed as ``null``. Action and error text are passed through
    ``sanitize`` so the record always stays on one line.
    """
    error_text = error or "null"
    done_text = "true" if done else "false"
    fields = (
        f"step={step}",
        f"action={sanitize(compact_action(action))}",
        f"reward={reward:.2f}",
        f"done={done_text}",
        f"error={sanitize(error_text)}",
    )
    print("[STEP] " + " ".join(fields), flush=True)
def log_end(*, success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` line that closes an episode's log.

    ``score`` and each reward are formatted with two decimals; rewards are
    joined with commas and no spaces.
    """
    success_text = "true" if success else "false"
    reward_text = ",".join(format(value, ".2f") for value in rewards)
    print(
        f"[END] success={success_text} "
        f"steps={steps} score={score:.2f} rewards={reward_text}",
        flush=True,
    )
def build_client() -> Optional[OpenAI]:
    """Return an OpenAI client for ``API_BASE_URL``, or ``None`` without a token.

    ``HF_TOKEN`` is used as the API key. When it is unset or empty no client
    is built, which makes ``choose_action`` use the heuristic policy.
    """
    if HF_TOKEN:
        return OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
    return None
def heuristic_action(observation: SupportObservation) -> SupportAction:
    """Pick an action by keyword matching; used when no model answer is available.

    The email text and the JSON-serialized context are lower-cased and
    searched for fixed substrings:

    * priority: spam markers win over urgency markers; otherwise ``normal``.
    * category: billing, then delivery, then technical; otherwise ``general``.
    * action: spam is ignored; technical issues go to a team; remaining urgent
      mail is answered immediately; remaining delivery mail goes to a team;
      everything else is answered immediately.

    Only the fields listed in ``observation.required_fields`` are set on the
    returned action.
    """
    spam_markers = (
        "click here",
        "gift card",
        "crypto",
        "lottery",
        "unsubscribe",
        "bypass all metrics",
        "encrypted emergency",
        "decrypt tool",
        "emergency slot",
        "override the normal queue",
        "sender_verified\": \"false",
        "spoofed sender",
    )
    urgent_markers = (
        "today",
        "payroll closes",
        "500 error",
        "blocked",
        "backing up",
        "immediately",
        "double",
        "charged again",
    )
    billing_markers = ("charge", "charged", "invoice", "refund", "billing", "subscription")
    delivery_markers = ("tracking", "shipment", "delivery", "delivered", "ship")
    technical_markers = ("error", "login", "outage", "crash", "bug", "sign in")

    haystack = f"{observation.email} {json.dumps(observation.context, sort_keys=True)}".lower()

    def mentions(markers: tuple) -> bool:
        """Return True when any marker occurs as a substring of the haystack."""
        return any(marker in haystack for marker in markers)

    # Spam markers are checked first so they override urgency wording.
    if mentions(spam_markers):
        priority = "spam"
    elif mentions(urgent_markers):
        priority = "urgent"
    else:
        priority = "normal"

    if mentions(billing_markers):
        category = "billing"
    elif mentions(delivery_markers):
        category = "delivery"
    elif mentions(technical_markers):
        category = "technical"
    else:
        category = "general"

    if priority == "spam":
        next_action = "ignore"
    elif category == "technical" or (priority != "urgent" and category == "delivery"):
        # Technical mail is always routed; delivery mail only when not urgent.
        next_action = "assign_to_team"
    else:
        next_action = "respond_immediately"

    decided = {"priority": priority, "category": category, "action": next_action}
    payload: dict[str, str] = {
        field_name: value
        for field_name, value in decided.items()
        if field_name in observation.required_fields
    }
    return SupportAction(**payload)
def _parse_json_object(content: str) -> dict:
    """Parse a JSON object from raw model text, tolerating fences and prose.

    The text is first parsed as-is. If that fails, the span from the first
    ``{`` to the last ``}`` is parsed instead, which covers replies wrapped in
    Markdown code fences or surrounded by explanatory text.

    Raises:
        json.JSONDecodeError: no parseable JSON could be found.
        ValueError: the parsed JSON is not an object.
    """
    text = content.strip()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        start = text.find("{")
        end = text.rfind("}")
        if start == -1 or end < start:
            raise
        parsed = json.loads(text[start : end + 1])
    if not isinstance(parsed, dict):
        raise ValueError(
            f"Model reply must be a JSON object, got {type(parsed).__name__}"
        )
    return parsed


def get_model_action(
    client: OpenAI,
    observation: SupportObservation,
    history: List[str],
) -> SupportAction:
    """Ask the chat model for the next action and parse its JSON reply.

    The observation fields and the step history are sent as one JSON user
    message after ``SYSTEM_PROMPT``. Only the ``priority``, ``category``,
    ``action`` and ``notes`` keys of the reply are passed on to
    ``SupportAction``; any other keys are dropped.

    Args:
        client: OpenAI-compatible client used for the request.
        observation: Current observation from the environment.
        history: One-line summaries of the steps taken so far.

    Returns:
        The action built from the model's reply.

    Raises:
        Exception: request failures, an empty ``choices`` list, unparseable
            or non-object replies, and ``SupportAction`` construction errors
            all propagate; ``choose_action`` catches them and falls back to
            the heuristic policy.
    """
    prompt = {
        "task_id": observation.task_id,
        "benchmark": observation.benchmark,
        "objective": observation.objective,
        "required_fields": observation.required_fields,
        "allowed_values": observation.allowed_values,
        "email": observation.email,
        "context": observation.context,
        "history": history,
        "feedback": observation.feedback,
    }
    response = client.chat.completions.create(
        model=MODEL_NAME,
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": json.dumps(prompt, ensure_ascii=True)},
        ],
    )
    content = (response.choices[0].message.content or "").strip()
    # Models often wrap JSON in ```json fences or add a sentence around it;
    # extracting the object avoids discarding an otherwise valid answer.
    payload = _parse_json_object(content)
    allowed_keys = {"priority", "category", "action", "notes"}
    filtered_payload = {
        key: value for key, value in payload.items() if key in allowed_keys
    }
    return SupportAction(**filtered_payload)
def choose_action(
    client: Optional[OpenAI],
    observation: SupportObservation,
    history: List[str],
) -> SupportAction:
    """Return the model's action, falling back to the heuristic policy.

    The heuristic is used both when no client is configured and when the
    model call or the parsing of its reply raises any ``Exception``. The
    fallback is silent: nothing is logged about the failure.
    """
    if client is not None:
        try:
            return get_model_action(client, observation, history)
        except Exception:
            # Best-effort: any model-side failure degrades to the heuristic.
            pass
    return heuristic_action(observation)
async def create_env(task_id: str):
    """Open an environment session for ``task_id``.

    Selection order:

    1. ``LOCAL_IMAGE_NAME`` set: start the docker image through
       ``SupermailEnv.from_docker_image``, passing the task id in the
       ``SUPERMAIL_TASK`` environment variable.
    2. ``BASE_URL`` set: connect a ``SupermailEnv`` client to that URL
       (``task_id`` is not sent in this branch).
    3. Otherwise: an in-process ``LocalSupermailSession``.
    """
    if LOCAL_IMAGE_NAME:
        container_env = {"SUPERMAIL_TASK": task_id}
        return await SupermailEnv.from_docker_image(
            LOCAL_IMAGE_NAME,
            env_vars=container_env,
        )

    if not BASE_URL:
        return LocalSupermailSession(task_id=task_id)

    remote = SupermailEnv(base_url=BASE_URL)
    await remote.connect()
    return remote
async def run_episode(task_id: str, client: Optional[OpenAI]) -> None:
    """Play one episode of ``task_id`` and print its [START]/[STEP]/[END] log.

    The environment is stepped until it reports ``done`` or ``MAX_STEPS``
    steps have been taken. Any exception raised after the [START] line is
    reported as one extra [STEP] line carrying the error text (with the last
    completed step number and a reward of 0.00), and the episode ends as
    unsuccessful. The environment is always closed and the [END] line is
    always printed.

    Raises:
        ValueError: ``task_id`` is not a key of ``TASKS_BY_ID`` (raised before
            anything is logged).
    """
    if task_id not in TASKS_BY_ID:
        raise ValueError(f"Unknown task: {task_id}")

    env = None
    history: List[str] = []
    rewards: List[float] = []
    steps_taken = 0
    score = MIN_SCORE
    success = False
    action_for_log: Optional[SupportAction] = None

    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    try:
        env = await create_env(task_id)
        result = await env.reset()
        observation = result.observation

        step = 1
        while step <= MAX_STEPS and not result.done:
            action_for_log = choose_action(client, observation, history)
            result = await env.step(action_for_log)
            observation = result.observation

            reward = result.reward or 0.0
            done = result.done
            error = observation.metadata.get("last_action_error")

            rewards.append(reward)
            steps_taken = step
            # Observations without a ``score`` attribute count as 0.0 before clamping.
            score = clamp_score(float(getattr(observation, "score", 0.0)))

            log_step(
                step=step,
                action=action_for_log,
                reward=reward,
                done=done,
                error=error,
            )
            history.append(
                f"step={step} action={compact_action(action_for_log)} "
                f"reward={reward:.2f} score={score:.2f}"
            )
            step += 1

        # Only reached when the loop finished without raising.
        success = score >= SUCCESS_SCORE_THRESHOLD
    except Exception as exc:
        log_step(
            step=steps_taken,
            action=action_for_log,
            reward=0.0,
            done=True,
            error=str(exc),
        )
    finally:
        if env is not None:
            try:
                await env.close()
            except Exception:
                # Cleanup is best-effort; a failing close must not hide the [END] line.
                pass
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
def task_sequence() -> List[str]:
    """Return the task ids to run, based on ``TASK_NAME``.

    The exact value ``"all"`` expands to every task in ``ALL_TASKS`` in order;
    any other value is returned as a single-element list without validation.
    """
    if TASK_NAME != "all":
        return [TASK_NAME]
    return [task.task_id for task in ALL_TASKS]
async def main() -> None:
    """Build the optional model client and run the selected tasks in order.

    Episodes run sequentially and share the same client (or ``None`` when no
    token is configured).
    """
    client = build_client()
    selected_tasks = task_sequence()
    for task_id in selected_tasks:
        await run_episode(task_id, client)
# Script entry point: run all selected episodes on a fresh asyncio event loop.
if __name__ == "__main__":
    asyncio.run(main())