meta_hackathon / inference.py
afroimam's picture
Upload folder using huggingface_hub
1395b2e verified
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import os
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any
from openai import OpenAI
from support_triage_openenv import Action, SupportTriageEnv
# Mandatory variables requested by organizers.
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
HF_TOKEN = os.getenv("HF_TOKEN")
BENCHMARK = os.getenv("SUPPORT_TRIAGE_BENCHMARK", "support-triage-openenv")
SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.9"))
SYSTEM_PROMPT = (
"You are solving customer support ticket triage. "
"Return exactly one JSON object with keys: "
"action_type, ticket_id, priority, category, needs_escalation, message."
)
RULE_POLICY: dict[str, list[dict[str, Any]]] = {
"easy_password_reset": [
{"action_type": "read_ticket", "ticket_id": "T-1001"},
{
"action_type": "classify_ticket",
"ticket_id": "T-1001",
"priority": "medium",
"category": "account",
"needs_escalation": False,
},
{
"action_type": "draft_reply",
"message": (
"We will send a reset link to your email. For security, confirm the request "
"from your registered email before using the reset link."
),
},
{"action_type": "resolve_ticket", "ticket_id": "T-1001"},
],
"medium_billing_dispute": [
{"action_type": "read_ticket", "ticket_id": "T-2001"},
{"action_type": "read_ticket", "ticket_id": "T-2002"},
{
"action_type": "classify_ticket",
"ticket_id": "T-2001",
"priority": "high",
"category": "billing",
"needs_escalation": False,
},
{
"action_type": "draft_reply",
"message": (
"We confirmed a duplicate charge. We are issuing a refund and will share the invoice update. "
"Refund processing typically takes 3-5 business days."
),
},
{"action_type": "resolve_ticket", "ticket_id": "T-2001"},
],
"hard_outage_incident": [
{"action_type": "read_ticket", "ticket_id": "T-3001"},
{"action_type": "read_ticket", "ticket_id": "T-3002"},
{"action_type": "read_ticket", "ticket_id": "T-3003"},
{
"action_type": "classify_ticket",
"ticket_id": "T-3001",
"priority": "urgent",
"category": "technical",
"needs_escalation": True,
},
{
"action_type": "draft_reply",
"message": (
"We have escalated this incident and are investigating now. "
"The status page will carry updates while we continue incident response."
),
},
{"action_type": "resolve_ticket", "ticket_id": "T-3001"},
],
}
@dataclass
class EpisodeResult:
task_id: str
steps: int
score: float
success: bool
final_reward: float
rewards: list[float]
fallback_count: int
def log_start(task: str, env: str, model: str) -> None:
print(f"[START] task={task} env={env} model={model}", flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: str | None) -> None:
error_val = error if error else "null"
done_val = str(done).lower()
print(
f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
flush=True,
)
def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
def _extract_json(text: str) -> str:
text = text.strip()
start = text.find("{")
end = text.rfind("}")
if start == -1 or end == -1 or end <= start:
raise ValueError("No JSON object found in model response")
return text[start : end + 1]
def heuristic_action(task_id: str, step_idx: int) -> Action:
plan = RULE_POLICY[task_id]
idx = min(step_idx, len(plan) - 1)
return Action.model_validate(plan[idx])
def llm_action(client: OpenAI, observation: dict[str, Any], state: dict[str, Any]) -> Action:
prompt = json.dumps(
{
"instruction": "Pick the best next single action to maximize final task score.",
"observation": observation,
"state": state,
},
ensure_ascii=True,
)
completion = client.chat.completions.create(
model=MODEL_NAME,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt},
],
temperature=0,
max_tokens=220,
stream=False,
)
text = (completion.choices[0].message.content or "").strip()
payload = json.loads(_extract_json(text))
return Action.model_validate(payload)
def action_to_str(action: Action) -> str:
if action.action_type == "read_ticket":
return f"read_ticket({action.ticket_id})"
if action.action_type == "classify_ticket":
return (
f"classify_ticket({action.ticket_id},{action.priority},{action.category},"
f"{str(bool(action.needs_escalation)).lower()})"
)
if action.action_type == "draft_reply":
length = len((action.message or "").strip())
return f"draft_reply(len={length})"
if action.action_type == "resolve_ticket":
return f"resolve_ticket({action.ticket_id})"
return action.action_type
def run_episode(
env: SupportTriageEnv,
task_id: str,
mode: str,
client: OpenAI | None,
started_at: float,
runtime_limit_seconds: int,
) -> EpisodeResult:
obs = env.reset(task_id)
done = False
info: dict[str, Any] = {}
rewards: list[float] = []
steps_taken = 0
fallback_count = 0
success = False
score = 0.0
final_reward = 0.0
log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
while not done:
if time.monotonic() - started_at > runtime_limit_seconds:
raise TimeoutError(f"Runtime exceeded {runtime_limit_seconds}s")
step_idx = env.state()["step_count"]
if mode == "heuristic":
action = heuristic_action(task_id, step_idx)
else:
assert client is not None
try:
action = llm_action(client, obs.model_dump(), env.state())
except Exception:
fallback_count += 1
action = heuristic_action(task_id, step_idx)
step_error: str | None = None
try:
obs, reward, done, info = env.step(action)
reward_value = float(reward.value)
except Exception as exc:
step_error = str(exc)
reward_value = 0.0
done = True
steps_taken = step_idx + 1
rewards.append(reward_value)
final_reward = reward_value
log_step(
step=steps_taken,
action=action_to_str(action),
reward=reward_value,
done=done,
error=step_error,
)
if done:
break
score = max(0.0, min(1.0, float(info.get("grader_score", 0.0))))
success = score >= SUCCESS_SCORE_THRESHOLD
log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
return EpisodeResult(
task_id=task_id,
steps=steps_taken,
score=round(score, 4),
success=success,
final_reward=round(final_reward, 4),
rewards=[round(r, 4) for r in rewards],
fallback_count=fallback_count,
)
def main() -> None:
parser = argparse.ArgumentParser(description="Submission inference script.")
parser.add_argument("--mode", choices=["openai", "heuristic"], default="openai")
parser.add_argument("--output", default="scores/inference_scores.json")
parser.add_argument("--runtime-limit-seconds", type=int, default=1200)
parser.add_argument("--task-id", default="", help="Optional single task id; default runs all tasks")
args = parser.parse_args()
if args.mode == "openai" and not HF_TOKEN:
raise RuntimeError("HF_TOKEN is required for --mode openai")
env = SupportTriageEnv()
task_ids = [args.task_id] if args.task_id else env.task_ids
client = None
if args.mode == "openai":
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
started_at = time.monotonic()
episodes: list[EpisodeResult] = []
for task_id in task_ids:
if task_id not in env.task_ids:
raise ValueError(f"Unknown task_id '{task_id}'")
episodes.append(
run_episode(
env=env,
task_id=task_id,
mode=args.mode,
client=client,
started_at=started_at,
runtime_limit_seconds=args.runtime_limit_seconds,
)
)
summary = {
"mode": args.mode,
"api_base_url": API_BASE_URL,
"model_name": MODEL_NAME,
"avg_score": round(sum(e.score for e in episodes) / len(episodes), 4),
"avg_final_reward": round(sum(e.final_reward for e in episodes) / len(episodes), 4),
"total_steps": int(sum(e.steps for e in episodes)),
"episodes": [asdict(e) for e in episodes],
}
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
if __name__ == "__main__":
main()