Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import sys | |
| from pathlib import Path | |
| from typing import Dict, List | |
| ROOT = Path(__file__).resolve().parent.parent | |
| if str(ROOT) not in sys.path: | |
| sys.path.insert(0, str(ROOT)) | |
| from support_ops_env.env import SupportOpsEnv | |
| from support_ops_env.models import Action, BaselineResult, Observation, TicketObservation | |
| from support_ops_env.tasks import list_task_ids | |
# Context keys in the order the baseline asks for them: when a ticket is
# missing several required keys, the one listed earliest here is requested
# first.
CONTEXT_PRIORITY: List[str] = [
    "account_security", "billing_activity",
    "tax_status", "payout_hold",
    "appeal_state", "campaign_deadline",
    "payment_status",
]
def choose_next_action(observation: Observation) -> Action:
    """Pick the single next action the rule-based baseline should take.

    The decision runs through fixed stages and returns as soon as one of
    them still has work to do:

    1. In queue mode with no queue order yet, rank the whole queue.
    2. Request the first missing high-value context key of any ticket.
    3. Fix the first ticket whose selected priority differs from the inferred one.
    4. Same for the route.
    5. Same for the resolution.
    6. Same for the escalation team.
    7. Otherwise finalize.

    Within every stage the tickets are visited in observation order, and a
    stage is only reached once all earlier stages are satisfied for all
    tickets.
    """
    if observation.queue_mode and not observation.current_queue_order:
        order = rank_tickets(observation.tickets)
        return Action(action_type="rank_queue", value=",".join(order))

    tickets = observation.tickets

    for ticket in tickets:
        wanted_context = missing_high_value_context(ticket)
        if wanted_context:
            return Action(
                action_type="request_context",
                target=ticket.ticket_id,
                value=wanted_context,
            )

    # (action type, inference function, ticket attribute holding the current choice)
    stages = (
        ("set_priority", infer_priority, "selected_priority"),
        ("set_route", infer_route, "selected_route"),
        ("set_resolution", infer_resolution, "selected_resolution"),
        ("escalate", infer_escalation, "escalation_team"),
    )
    for action_type, infer, attribute in stages:
        for ticket in tickets:
            desired = infer(ticket)
            if getattr(ticket, attribute) != desired:
                return Action(action_type=action_type, target=ticket.ticket_id, value=desired)

    return Action(action_type="finalize")
def missing_high_value_context(ticket: TicketObservation) -> str | None:
    """Return the most important context key the ticket still lacks.

    The keys the ticket needs come from ``infer_required_context``. They are
    checked in ``CONTEXT_PRIORITY`` order, and the first one that is not
    already present in ``ticket.discovered_context`` is returned.

    Returns:
        The context key to request next, or ``None`` when every required key
        has already been discovered (or none is required).
    """
    already_known = set(ticket.discovered_context)
    required = infer_required_context(ticket)
    for key in CONTEXT_PRIORITY:
        if key in required and key not in already_known:
            return key
    return None
def infer_required_context(ticket: TicketObservation) -> List[str]:
    """Guess which context keys a ticket needs from phrases in its text.

    The lower-cased text from ``flattened_text`` is matched against phrase
    groups in a fixed order; the context keys of the first group with any
    matching phrase are returned. An empty list means no group matched.
    """
    text = flattened_text(ticket)

    # Order matters: the first matching group wins.
    phrase_groups = (
        (
            ("payout", "w-9", "bank details", "funds released"),
            ["tax_status", "payout_hold"],
        ),
        (
            ("appeal", "auto-removed", "monetization is paused"),
            ["appeal_state", "campaign_deadline"],
        ),
        (
            ("duplicate charge", "refund"),
            ["payment_status"],
        ),
        (
            (
                "login",
                "ad spend",
                "unfamiliar campaigns",
                "taken over",
                "recovery email was changed",
            ),
            ["account_security", "billing_activity"],
        ),
    )
    for phrases, context_keys in phrase_groups:
        if any(phrase in text for phrase in phrases):
            return context_keys
    return []
def infer_priority(ticket: TicketObservation) -> str:
    """Map phrases in the ticket text to a priority label.

    Returns ``"urgent"`` if any urgent phrase occurs in the flattened text,
    otherwise ``"high"`` if any high phrase occurs, otherwise ``"normal"``.
    """
    text = flattened_text(ticket)

    urgent_phrases = (
        "critical",
        "$1,900",
        "unauthorized ad spend",
        "impossible travel",
        "recovery email was changed",
    )
    high_phrases = (
        "campaign begins in 18 hours",
        "monetization is paused",
        "w-9 expired",
        "monthly payout",
    )

    if any(phrase in text for phrase in urgent_phrases):
        return "urgent"
    if any(phrase in text for phrase in high_phrases):
        return "high"
    return "normal"
def infer_route(ticket: TicketObservation) -> str:
    """Choose the routing queue for a ticket from phrases in its text.

    Phrase groups are tried in a fixed order against the flattened text and
    the route of the first group with a match is returned; when nothing
    matches the ticket goes to ``"general_support"``.
    """
    text = flattened_text(ticket)

    # Order matters: the first matching group wins.
    routing_rules = (
        (
            "account_security",
            (
                "account takeover",
                "new devices",
                "recovery email was changed",
                "unfamiliar campaigns",
                "unauthorized ad spend",
                "losing access",
            ),
        ),
        ("monetization_compliance", ("w-9 expired", "compliance hold")),
        ("policy_appeals", ("auto-removed", "human yet")),
        ("billing_refunds", ("duplicate charge", "automatically refundable")),
    )
    for route, phrases in routing_rules:
        if any(phrase in text for phrase in phrases):
            return route
    return "general_support"
def infer_resolution(ticket: TicketObservation) -> str:
    """Choose the resolution code for a ticket from phrases in its text.

    Phrase groups are tried in a fixed order against the flattened text and
    the resolution of the first group with a match is returned; when nothing
    matches the result is ``"request_more_info"``.
    """
    text = flattened_text(ticket)

    # Order matters: the first matching group wins.
    resolution_rules = (
        (
            "temporary_lock_and_manual_recovery",
            (
                "account takeover",
                "new devices",
                "impossible travel",
                "unfamiliar campaigns",
                "losing access",
            ),
        ),
        ("request_tax_renewal", ("w-9 expired", "compliance hold")),
        ("expedited_human_review", ("auto-removed", "sponsored campaign begins")),
        ("approve_refund", ("duplicate charge", "automatically refundable")),
    )
    for resolution, phrases in resolution_rules:
        if any(phrase in text for phrase in phrases):
            return resolution
    return "request_more_info"
def infer_escalation(ticket: TicketObservation) -> str | None:
    """Decide whether a ticket should be escalated, based on its text.

    Returns ``"security_specialist"`` when any of the security phrases
    occurs in the flattened text, and ``None`` otherwise.
    """
    security_phrases = (
        "account takeover",
        "critical",
        "impossible travel",
        "unfamiliar campaigns",
        "losing access",
    )
    text = flattened_text(ticket)
    needs_specialist = any(phrase in text for phrase in security_phrases)
    return "security_specialist" if needs_specialist else None
def rank_tickets(tickets: List[TicketObservation]) -> List[str]:
    """Order ticket ids from most to least pressing.

    Each ticket gets an additive score from phrases in its flattened text
    (+100 security/critical, +60 campaign deadline, +20 duplicate charge)
    and from ``visible_context["sla_hours_remaining"]`` (+30 for ``"1"``,
    +10 for ``"4"``). Ids are returned by descending score; equal scores are
    ordered by ticket id ascending.
    """

    def urgency(ticket: TicketObservation) -> int:
        text = flattened_text(ticket)
        sla_remaining = ticket.visible_context.get("sla_hours_remaining")
        points = 0
        if any(
            phrase in text
            for phrase in ("critical", "account takeover", "$1,900", "unfamiliar campaigns")
        ):
            points += 100
        if any(phrase in text for phrase in ("campaign begins in 18 hours", "sponsored campaign")):
            points += 60
        if "duplicate charge" in text:
            points += 20
        if sla_remaining == "1":
            points += 30
        if sla_remaining == "4":
            points += 10
        return points

    ordered = sorted(tickets, key=lambda ticket: (-urgency(ticket), ticket.ticket_id))
    return [ticket.ticket_id for ticket in ordered]
def flattened_text(ticket: TicketObservation) -> str:
    """Collapse a ticket into one lower-cased string for phrase matching.

    The result is the summary, the visible context and the discovered
    context, space-separated; both contexts are JSON-encoded with sorted
    keys so the text is deterministic.
    """
    summary = ticket.summary
    visible = json.dumps(ticket.visible_context, sort_keys=True)
    discovered = json.dumps(ticket.discovered_context, sort_keys=True)
    return f"{summary} {visible} {discovered}".lower() if False else " ".join(
        (summary, visible, discovered)
    ).lower()
def main() -> None:
    """Run the rule-based baseline over every task and report the scores.

    For each id returned by ``list_task_ids()`` a fresh ``SupportOpsEnv`` is
    reset and then stepped with ``choose_next_action`` until the environment
    reports ``done``. Every step is recorded in a transcript, and the final
    ``info`` dict supplies the task score and step count.

    The combined payload (baseline name, average score, per-task results) is
    written as JSON to the ``--output`` path (default
    ``rule_baseline_results.json``) and also printed to stdout.
    """
    parser = argparse.ArgumentParser(description="Run a deterministic rule-based baseline over all tasks.")
    parser.add_argument("--output", default="rule_baseline_results.json", help="Path to write JSON results")
    args = parser.parse_args()

    results: List[BaselineResult] = []
    for task_id in list_task_ids():
        env = SupportOpsEnv(task_id=task_id)
        observation = env.reset()
        done = False
        transcript: List[Dict[str, object]] = []
        last_info: Dict[str, object] = {}
        while not done:
            action = choose_next_action(observation)
            observation, reward, done, info = env.step(action)
            transcript.append(
                {
                    "action": action.model_dump(),
                    "reward": reward.model_dump(),
                    "task_score": info["task_score"],
                    "done": done,
                }
            )
            last_info = info
        results.append(
            BaselineResult(
                task_id=task_id,
                difficulty=observation.difficulty,
                score=float(last_info.get("task_score", 0.0)),
                steps=int(last_info.get("step_count", 0)),
                transcript=transcript,
            )
        )

    # With no registered tasks there is nothing to average; report 0.0
    # instead of failing with ZeroDivisionError.
    if results:
        average_score = round(sum(item.score for item in results) / len(results), 4)
    else:
        average_score = 0.0

    payload = {
        "baseline": "rule_based",
        "average_score": average_score,
        "results": [item.model_dump() for item in results],
    }

    # Serialize once and reuse the text for both the file and stdout.
    rendered = json.dumps(payload, indent=2)
    output_path = Path(args.output)
    output_path.write_text(rendered, encoding="utf-8")
    print(rendered)


if __name__ == "__main__":
    main()