Addy897 committed on
Commit
841976b
·
1 Parent(s): ebc798b

inference update

Browse files
inference.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import json
2
  import os
3
  import textwrap
@@ -8,15 +10,14 @@ from openai import OpenAI
8
  from support_ops_env.env import SupportOpsEnv
9
  from support_ops_env.models import Action, Observation
10
  from support_ops_env.tasks import list_task_ids
11
- from dotenv import load_dotenv
12
- load_dotenv()
13
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
14
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
15
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
16
  MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
17
  TASK_NAME = os.getenv("SUPPORT_OPS_TASK", "easy_account_takeover")
18
  BENCHMARK = os.getenv("SUPPORT_OPS_BENCHMARK", "support_ops_env")
19
- MAX_STEPS = int(os.getenv("MAX_STEPS", "16"))
20
  TEMPERATURE = float(os.getenv("TEMPERATURE", "0.1"))
21
  MAX_TOKENS = int(os.getenv("MAX_TOKENS", "220"))
22
  SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.8"))
@@ -28,6 +29,7 @@ SYSTEM_PROMPT = textwrap.dedent(
28
  """
29
  You are operating a customer support triage environment.
30
  Return exactly one JSON object with keys: action_type, target, value.
 
31
  Allowed action_type values:
32
  - inspect_ticket
33
  - request_context
@@ -37,9 +39,49 @@ SYSTEM_PROMPT = textwrap.dedent(
37
  - escalate
38
  - rank_queue
39
  - finalize
40
- Choose only valid ticket ids from the observation.
41
- Use concise string values.
42
- Finalize only after enough evidence is gathered.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  """
44
  ).strip()
45
 
@@ -65,23 +107,28 @@ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> No
65
  )
66
 
67
 
68
- def build_user_prompt(observation: Observation, step: int, rewards: List[float]) -> str:
69
  reward_history = ",".join(f"{reward:.2f}" for reward in rewards[-5:]) if rewards else "none"
 
70
  return textwrap.dedent(
71
  f"""
72
  Step: {step}
73
  Task: {observation.task_id}
74
  Difficulty: {observation.difficulty}
75
  Reward history: {reward_history}
 
 
 
 
76
  Observation JSON:
77
  {json.dumps(observation.model_dump(), indent=2, sort_keys=True)}
78
- Return one JSON action.
79
  """
80
  ).strip()
81
 
82
 
83
- def get_model_action(client: OpenAI, observation: Observation, step: int, rewards: List[float]) -> tuple[Action, Optional[str]]:
84
- user_prompt = build_user_prompt(observation, step, rewards)
85
  try:
86
  completion = client.chat.completions.create(
87
  model=MODEL_NAME,
@@ -130,6 +177,7 @@ def run_task(client: OpenAI, task_name: str) -> dict:
130
  """Run a single task and return a result dict."""
131
  env = SupportOpsEnv(task_id=task_name)
132
  rewards: List[float] = []
 
133
  steps_taken = 0
134
  score = 0.0
135
  success = False
@@ -140,8 +188,9 @@ def run_task(client: OpenAI, task_name: str) -> dict:
140
  observation = env.reset(task_id=task_name)
141
 
142
  for step in range(1, MAX_STEPS + 1):
143
- action, action_error = get_model_action(client, observation, step, rewards)
144
  action_str = json.dumps(action.model_dump(), separators=(",", ":"))
 
145
 
146
  observation, reward, done, info = env.step(action)
147
  reward_value = reward.value
@@ -173,7 +222,9 @@ def main() -> None:
173
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
174
 
175
  # Fix 2: run at least MIN_TASKS tasks so the grader has enough scored entries
176
- tasks = select_tasks(TASK_NAME)
 
 
177
 
178
  all_results = []
179
  for task_name in tasks:
 
1
+ from dotenv import load_dotenv
2
+ load_dotenv()
3
  import json
4
  import os
5
  import textwrap
 
10
  from support_ops_env.env import SupportOpsEnv
11
  from support_ops_env.models import Action, Observation
12
  from support_ops_env.tasks import list_task_ids
13
+
 
14
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
15
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
16
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
17
  MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
18
  TASK_NAME = os.getenv("SUPPORT_OPS_TASK", "easy_account_takeover")
19
  BENCHMARK = os.getenv("SUPPORT_OPS_BENCHMARK", "support_ops_env")
20
+ MAX_STEPS = int(os.getenv("MAX_STEPS", "24"))
21
  TEMPERATURE = float(os.getenv("TEMPERATURE", "0.1"))
22
  MAX_TOKENS = int(os.getenv("MAX_TOKENS", "220"))
23
  SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.8"))
 
29
  """
30
  You are operating a customer support triage environment.
31
  Return exactly one JSON object with keys: action_type, target, value.
32
+
33
  Allowed action_type values:
34
  - inspect_ticket
35
  - request_context
 
39
  - escalate
40
  - rank_queue
41
  - finalize
42
+
43
+ VALID VALUES you MUST use these exact strings:
44
+
45
+ priority values: urgent, high, normal, low
46
+ route values: account_security, monetization_compliance, billing_refunds, policy_appeals
47
+ resolution values: temporary_lock_and_manual_recovery, request_tax_renewal, approve_refund, expedited_human_review
48
+ escalation teams: security_specialist (only when account compromise is confirmed; omit otherwise)
49
+
50
+ ACTION FORMAT EXAMPLES — copy these exactly:
51
+ {"action_type": "inspect_ticket", "target": "T1", "value": ""}
52
+ {"action_type": "request_context", "target": "T1", "value": "tax_status"}
53
+ {"action_type": "set_priority", "target": "T1", "value": "urgent"}
54
+ {"action_type": "set_route", "target": "T1", "value": "account_security"}
55
+ {"action_type": "set_resolution", "target": "T1", "value": "temporary_lock_and_manual_recovery"}
56
+ {"action_type": "escalate", "target": "T1", "value": "security_specialist"}
57
+ {"action_type": "rank_queue", "target": "T1", "value": "T2,T1,T3"}
58
+ {"action_type": "finalize", "target": "T1", "value": ""}
59
+
60
+ CRITICAL: For request_context, target = ticket ID (e.g. "T1"), value = context key name.
61
+ NEVER put the context key name in target. target is ALWAYS a ticket ID.
62
+
63
+ WORKFLOW PER TICKET:
64
+ 1. inspect_ticket once (target=ticket_id, value="").
65
+ 2. request_context ONLY for keys in required_context_keys first (these affect your score).
66
+ Use target=ticket_id, value=key_name. Request each key at most once.
67
+ Do NOT request optional keys from available_context_keys — they give tiny reward
68
+ but waste steps you need for set_resolution, escalate, rank_queue, and finalize.
69
+ 3. set_priority, set_route, set_resolution using the VALID VALUES listed above.
70
+ Use the context you discovered to choose correctly.
71
+ 4. escalate only when account takeover / security compromise is confirmed.
72
+ 5. For queue tasks: rank_queue after processing all tickets (most urgent first).
73
+ 6. finalize (target=ticket_id, value="") when all tickets are done.
74
+
75
+ PRIORITY HINTS:
76
+ - Account takeover / fraud / SLA <= 2h → urgent
77
+ - Tax/compliance holds, payment issues / SLA <= 12h → high
78
+ - Routine appeals, refunds / SLA >= 24h → normal
79
+
80
+ STRICT RULES:
81
+ - NEVER repeat an action you have already taken (check your history).
82
+ - inspect_ticket AT MOST ONCE per ticket.
83
+ - target is ALWAYS a ticket ID like "T1". NEVER put a context key in target.
84
+ - Each request_context must use a different value (key name).
85
  """
86
  ).strip()
87
 
 
107
  )
108
 
109
 
110
+ def build_user_prompt(observation: Observation, step: int, rewards: List[float], action_history: List[str]) -> str:
111
  reward_history = ",".join(f"{reward:.2f}" for reward in rewards[-5:]) if rewards else "none"
112
+ history_str = "\n".join(f" {a}" for a in action_history) if action_history else " none"
113
  return textwrap.dedent(
114
  f"""
115
  Step: {step}
116
  Task: {observation.task_id}
117
  Difficulty: {observation.difficulty}
118
  Reward history: {reward_history}
119
+
120
+ Actions you have ALREADY taken this episode (do NOT repeat these):
121
+ {history_str}
122
+
123
  Observation JSON:
124
  {json.dumps(observation.model_dump(), indent=2, sort_keys=True)}
125
+ Return one JSON action that you have NOT already taken.
126
  """
127
  ).strip()
128
 
129
 
130
+ def get_model_action(client: OpenAI, observation: Observation, step: int, rewards: List[float], action_history: List[str]) -> tuple[Action, Optional[str]]:
131
+ user_prompt = build_user_prompt(observation, step, rewards, action_history)
132
  try:
133
  completion = client.chat.completions.create(
134
  model=MODEL_NAME,
 
177
  """Run a single task and return a result dict."""
178
  env = SupportOpsEnv(task_id=task_name)
179
  rewards: List[float] = []
180
+ action_history: List[str] = []
181
  steps_taken = 0
182
  score = 0.0
183
  success = False
 
188
  observation = env.reset(task_id=task_name)
189
 
190
  for step in range(1, MAX_STEPS + 1):
191
+ action, action_error = get_model_action(client, observation, step, rewards, action_history)
192
  action_str = json.dumps(action.model_dump(), separators=(",", ":"))
193
+ action_history.append(action_str)
194
 
195
  observation, reward, done, info = env.step(action)
196
  reward_value = reward.value
 
222
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
223
 
224
  # Fix 2: run at least MIN_TASKS tasks so the grader has enough scored entries
225
+ # Run in reverse difficulty order (hard first) so the expensive tasks consume credits
226
+ # while the budget is fresh, rather than the run always dying on the last task.
227
+ tasks = list(reversed(select_tasks(TASK_NAME)))
228
 
229
  all_results = []
230
  for task_name in tasks:
support_ops_env/__pycache__/env.cpython-314.pyc CHANGED
Binary files a/support_ops_env/__pycache__/env.cpython-314.pyc and b/support_ops_env/__pycache__/env.cpython-314.pyc differ
 
support_ops_env/__pycache__/models.cpython-314.pyc CHANGED
Binary files a/support_ops_env/__pycache__/models.cpython-314.pyc and b/support_ops_env/__pycache__/models.cpython-314.pyc differ
 
support_ops_env/data/hard_cases.json CHANGED
@@ -5,7 +5,7 @@
5
  "title": "Mixed Support Queue Triage",
6
  "description": "Prioritize a small queue of heterogeneous support tickets under SLA pressure and route each one correctly.",
7
  "instruction": "Inspect the queue, gather missing context where useful, assign the right priority and route for each ticket, set a valid resolution, rank the queue from most urgent to least urgent, and finalize.",
8
- "max_steps": 16,
9
  "queue_mode": true,
10
  "gold_queue_order": [
11
  "T2",
@@ -81,4 +81,4 @@
81
  }
82
  ]
83
  }
84
- ]
 
5
  "title": "Mixed Support Queue Triage",
6
  "description": "Prioritize a small queue of heterogeneous support tickets under SLA pressure and route each one correctly.",
7
  "instruction": "Inspect the queue, gather missing context where useful, assign the right priority and route for each ticket, set a valid resolution, rank the queue from most urgent to least urgent, and finalize.",
8
+ "max_steps": 24,
9
  "queue_mode": true,
10
  "gold_queue_order": [
11
  "T2",
 
81
  }
82
  ]
83
  }
84
+ ]
support_ops_env/env.py CHANGED
@@ -109,12 +109,15 @@ class SupportOpsEnv:
109
  for ticket in self._task.tickets:
110
  keys = self._state.discovered_keys.get(ticket.ticket_id, [])
111
  discovered_context = {key: ticket.hidden_context[key] for key in keys}
 
112
  tickets.append(
113
  TicketObservation(
114
  ticket_id=ticket.ticket_id,
115
  summary=ticket.summary,
116
  visible_context=ticket.visible_context,
117
  discovered_context=discovered_context,
 
 
118
  selected_priority=self._state.priorities.get(ticket.ticket_id),
119
  selected_route=self._state.routes.get(ticket.ticket_id),
120
  selected_resolution=self._state.resolutions.get(ticket.ticket_id),
 
109
  for ticket in self._task.tickets:
110
  keys = self._state.discovered_keys.get(ticket.ticket_id, [])
111
  discovered_context = {key: ticket.hidden_context[key] for key in keys}
112
+ available_keys = [k for k in ticket.hidden_context if k not in keys]
113
  tickets.append(
114
  TicketObservation(
115
  ticket_id=ticket.ticket_id,
116
  summary=ticket.summary,
117
  visible_context=ticket.visible_context,
118
  discovered_context=discovered_context,
119
+ available_context_keys=available_keys,
120
+ required_context_keys=[k for k in ticket.required_context if k not in keys],
121
  selected_priority=self._state.priorities.get(ticket.ticket_id),
122
  selected_route=self._state.routes.get(ticket.ticket_id),
123
  selected_resolution=self._state.resolutions.get(ticket.ticket_id),
support_ops_env/graders/__pycache__/common.cpython-314.pyc CHANGED
Binary files a/support_ops_env/graders/__pycache__/common.cpython-314.pyc and b/support_ops_env/graders/__pycache__/common.cpython-314.pyc differ
 
support_ops_env/models.py CHANGED
@@ -34,6 +34,8 @@ class TicketObservation(BaseModel):
34
  summary: str
35
  visible_context: Dict[str, str]
36
  discovered_context: Dict[str, str] = Field(default_factory=dict)
 
 
37
  selected_priority: Optional[str] = None
38
  selected_route: Optional[str] = None
39
  selected_resolution: Optional[str] = None
 
34
  summary: str
35
  visible_context: Dict[str, str]
36
  discovered_context: Dict[str, str] = Field(default_factory=dict)
37
+ available_context_keys: List[str] = Field(default_factory=list)
38
+ required_context_keys: List[str] = Field(default_factory=list)
39
  selected_priority: Optional[str] = None
40
  selected_route: Optional[str] = None
41
  selected_resolution: Optional[str] = None