from __future__ import annotations import argparse import json import sys from pathlib import Path from typing import Dict, List ROOT = Path(__file__).resolve().parent.parent if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) from support_ops_env.env import SupportOpsEnv from support_ops_env.models import Action, BaselineResult, Observation, TicketObservation from support_ops_env.tasks import list_task_ids CONTEXT_PRIORITY = [ "account_security", "billing_activity", "tax_status", "payout_hold", "appeal_state", "campaign_deadline", "payment_status", ] def choose_next_action(observation: Observation) -> Action: if observation.queue_mode and not observation.current_queue_order: ranking = rank_tickets(observation.tickets) return Action(action_type="rank_queue", value=",".join(ranking)) for ticket in observation.tickets: next_context = missing_high_value_context(ticket) if next_context: return Action(action_type="request_context", target=ticket.ticket_id, value=next_context) for ticket in observation.tickets: priority = infer_priority(ticket) if ticket.selected_priority != priority: return Action(action_type="set_priority", target=ticket.ticket_id, value=priority) for ticket in observation.tickets: route = infer_route(ticket) if ticket.selected_route != route: return Action(action_type="set_route", target=ticket.ticket_id, value=route) for ticket in observation.tickets: resolution = infer_resolution(ticket) if ticket.selected_resolution != resolution: return Action(action_type="set_resolution", target=ticket.ticket_id, value=resolution) for ticket in observation.tickets: escalation = infer_escalation(ticket) if ticket.escalation_team != escalation: return Action(action_type="escalate", target=ticket.ticket_id, value=escalation) return Action(action_type="finalize") def missing_high_value_context(ticket: TicketObservation) -> str | None: discovered = set(ticket.discovered_context) haystack = flattened_text(ticket) candidates: List[str] = infer_required_context(ticket) for key in CONTEXT_PRIORITY: if key in candidates and key not in discovered: return key return None def infer_required_context(ticket: TicketObservation) -> List[str]: text = flattened_text(ticket) if "payout" in text or "w-9" in text or "bank details" in text or "funds released" in text: return ["tax_status", "payout_hold"] if "appeal" in text or "auto-removed" in text or "monetization is paused" in text: return ["appeal_state", "campaign_deadline"] if "duplicate charge" in text or "refund" in text: return ["payment_status"] if ( "login" in text or "ad spend" in text or "unfamiliar campaigns" in text or "taken over" in text or "recovery email was changed" in text ): return ["account_security", "billing_activity"] return [] def infer_priority(ticket: TicketObservation) -> str: text = flattened_text(ticket) if ( "critical" in text or "$1,900" in text or "unauthorized ad spend" in text or "impossible travel" in text or "recovery email was changed" in text ): return "urgent" if "campaign begins in 18 hours" in text or "monetization is paused" in text: return "high" if "w-9 expired" in text or "monthly payout" in text: return "high" return "normal" def infer_route(ticket: TicketObservation) -> str: text = flattened_text(ticket) if ( "account takeover" in text or "new devices" in text or "recovery email was changed" in text or "unfamiliar campaigns" in text or "unauthorized ad spend" in text or "losing access" in text ): return "account_security" if "w-9 expired" in text or "compliance hold" in text: return "monetization_compliance" if "auto-removed" in text or "human yet" in text: return "policy_appeals" if "duplicate charge" in text or "automatically refundable" in text: return "billing_refunds" return "general_support" def infer_resolution(ticket: TicketObservation) -> str: text = flattened_text(ticket) if ( "account takeover" in text or "new devices" in text or "impossible travel" in text or "unfamiliar campaigns" in text or "losing access" in text ): return "temporary_lock_and_manual_recovery" if "w-9 expired" in text or "compliance hold" in text: return "request_tax_renewal" if "auto-removed" in text or "sponsored campaign begins" in text: return "expedited_human_review" if "duplicate charge" in text or "automatically refundable" in text: return "approve_refund" return "request_more_info" def infer_escalation(ticket: TicketObservation) -> str | None: text = flattened_text(ticket) if ( "account takeover" in text or "critical" in text or "impossible travel" in text or "unfamiliar campaigns" in text or "losing access" in text ): return "security_specialist" return None def rank_tickets(tickets: List[TicketObservation]) -> List[str]: scored = [] for ticket in tickets: text = flattened_text(ticket) score = 0 if "critical" in text or "account takeover" in text or "$1,900" in text or "unfamiliar campaigns" in text: score += 100 if "campaign begins in 18 hours" in text or "sponsored campaign" in text: score += 60 if "duplicate charge" in text: score += 20 if ticket.visible_context.get("sla_hours_remaining") == "1": score += 30 if ticket.visible_context.get("sla_hours_remaining") == "4": score += 10 scored.append((score, ticket.ticket_id)) scored.sort(key=lambda item: (-item[0], item[1])) return [ticket_id for _, ticket_id in scored] def flattened_text(ticket: TicketObservation) -> str: parts = [ ticket.summary, json.dumps(ticket.visible_context, sort_keys=True), json.dumps(ticket.discovered_context, sort_keys=True), ] return " ".join(parts).lower() def main() -> None: parser = argparse.ArgumentParser(description="Run a deterministic rule-based baseline over all tasks.") parser.add_argument("--output", default="rule_baseline_results.json", help="Path to write JSON results") args = parser.parse_args() results: List[BaselineResult] = [] for task_id in list_task_ids(): env = SupportOpsEnv(task_id=task_id) observation = env.reset() done = False transcript: List[Dict[str, object]] = [] last_info: Dict[str, object] = {} while not done: action = choose_next_action(observation) observation, reward, done, info = env.step(action) transcript.append( { "action": action.model_dump(), "reward": reward.model_dump(), "task_score": info["task_score"], "done": done, } ) last_info = info results.append( BaselineResult( task_id=task_id, difficulty=observation.difficulty, score=float(last_info.get("task_score", 0.0)), steps=int(last_info.get("step_count", 0)), transcript=transcript, ) ) payload = { "baseline": "rule_based", "average_score": round(sum(item.score for item in results) / len(results), 4), "results": [item.model_dump() for item in results], } output_path = Path(args.output) output_path.write_text(json.dumps(payload, indent=2), encoding="utf-8") print(json.dumps(payload, indent=2)) if __name__ == "__main__": main()