Coding Ninja commited on
Commit
eae2b1d
·
1 Parent(s): 3752981

March 30 - April 1st : sever/

Browse files
Files changed (2) hide show
  1. inference.py +104 -28
  2. server/environment.py +19 -4
inference.py CHANGED
@@ -163,6 +163,75 @@ KEYWORD_ISSUE_TYPES = {
163
  "export": "feature_request",
164
  }
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  def heuristic_action(ticket: dict, allowed_fields: list[str]) -> dict:
167
  text = (ticket.get("title", "") + " " + ticket.get("description", "")).lower()
168
 
@@ -172,13 +241,8 @@ def heuristic_action(ticket: dict, allowed_fields: list[str]) -> dict:
172
  issue_type = mapped_issue_type
173
  break
174
 
175
- priority = "medium"
176
- if any(w in text for w in ["urgent", "critical", "blocking", "asap", "immediately"]):
177
- priority = "critical"
178
- elif any(w in text for w in ["important", "high priority", "revenue"]):
179
- priority = "high"
180
- elif any(w in text for w in ["low", "whenever", "no rush"]):
181
- priority = "low"
182
 
183
  result: dict = {}
184
  if "issue_type" in allowed_fields:
@@ -190,12 +254,26 @@ def heuristic_action(ticket: dict, allowed_fields: list[str]) -> dict:
190
  issue_type, "service_desk"
191
  )
192
  if "resolution_action" in allowed_fields:
193
- result["resolution_action"] = ISSUE_TYPE_TO_RESOLUTION_ACTION.get(
194
- issue_type, "acknowledge"
195
- )
196
  return result
197
 
198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  # ---------------------------------------------------------------------------
200
  # Main loop using WebSocket client for multi-step episodes
201
  # ---------------------------------------------------------------------------
@@ -213,7 +291,7 @@ def run():
213
  print(f"Available tasks: {[t['name'] for t in available_tasks.values()]}")
214
  http.close()
215
 
216
- all_scores: dict[int, list[float]] = {}
217
 
218
  for task_id in TASKS:
219
  if task_id not in available_tasks:
@@ -229,7 +307,7 @@ def run():
229
  result = sync_client.reset(seed=SEED, task_id=task_id)
230
  obs = result.observation
231
 
232
- task_scores: list[float] = []
233
  step_num = 0
234
 
235
  while not result.done:
@@ -240,12 +318,7 @@ def run():
240
  allowed = obs.allowed_fields
241
  instructions = obs.instructions
242
 
243
- if llm_client is not None:
244
- action_dict = call_llm(ticket, allowed, instructions)
245
- else:
246
- action_dict = heuristic_action(ticket, allowed)
247
-
248
- action = HelpdeskTicketAction(**action_dict)
249
  result = sync_client.step(action)
250
  obs = result.observation
251
 
@@ -253,21 +326,24 @@ def run():
253
  print(f" Step {step_num}: reward={result.reward} done={result.done}")
254
 
255
  if result.reward is not None:
256
- task_scores.append(result.reward)
257
 
258
- all_scores[task_id] = task_scores
259
- final = task_scores[-1] if task_scores else 0.0
260
- print(f" Task {task_id} final reward: {final:.4f}")
 
 
 
261
 
262
  # Summary
263
  print("\n=== RESULTS ===")
264
- overall = []
265
  for tid in TASKS:
266
- if tid in all_scores:
267
- scores = all_scores[tid]
268
- avg = sum(scores) / len(scores) if scores else 0.0
269
- overall.append(avg)
270
- print(f"Task {tid}: avg_score={avg:.4f} ({len(scores)} steps)")
271
  if overall:
272
  print(f"Overall: {sum(overall) / len(overall):.4f}")
273
 
 
163
  "export": "feature_request",
164
  }
165
 
166
+
167
+ CRITICAL_PRIORITY_KEYWORDS = (
168
+ "urgent",
169
+ "critical",
170
+ "blocking",
171
+ "asap",
172
+ "immediately",
173
+ "locked out",
174
+ "outage",
175
+ )
176
+
177
+ HIGH_PRIORITY_KEYWORDS = (
178
+ "important",
179
+ "high priority",
180
+ "revenue",
181
+ "today",
182
+ "eod",
183
+ )
184
+
185
+ LOW_PRIORITY_KEYWORDS = ("low", "whenever", "no rush")
186
+
187
+ ESCALATE_KEYWORDS = (
188
+ "refund",
189
+ "charged twice",
190
+ "still haven't",
191
+ "following up",
192
+ "needs immediate resolution",
193
+ "locked out",
194
+ "suspended",
195
+ "legal",
196
+ )
197
+
198
+ FULFILL_KEYWORDS = (
199
+ "please provide",
200
+ "confirmation",
201
+ "data processing addendum",
202
+ "guidance",
203
+ "fix",
204
+ "reproducible",
205
+ "outage",
206
+ "policy",
207
+ "mfa enabled",
208
+ )
209
+
210
+
211
+ def heuristic_priority(text: str) -> str:
212
+ if any(word in text for word in CRITICAL_PRIORITY_KEYWORDS):
213
+ return "critical"
214
+ if any(word in text for word in HIGH_PRIORITY_KEYWORDS):
215
+ return "high"
216
+ if any(word in text for word in LOW_PRIORITY_KEYWORDS):
217
+ return "low"
218
+ return "medium"
219
+
220
+
221
+ def heuristic_resolution_action(text: str, issue_type: str) -> str:
222
+ if issue_type == "spam_phishing":
223
+ return "ignore"
224
+ if issue_type == "service_request":
225
+ return "assign"
226
+ if issue_type in {"general_inquiry", "feature_request"}:
227
+ return "acknowledge"
228
+ if any(keyword in text for keyword in ESCALATE_KEYWORDS):
229
+ return "escalate"
230
+ if any(keyword in text for keyword in FULFILL_KEYWORDS):
231
+ return "fulfill"
232
+ return ISSUE_TYPE_TO_RESOLUTION_ACTION.get(issue_type, "acknowledge")
233
+
234
+
235
  def heuristic_action(ticket: dict, allowed_fields: list[str]) -> dict:
236
  text = (ticket.get("title", "") + " " + ticket.get("description", "")).lower()
237
 
 
241
  issue_type = mapped_issue_type
242
  break
243
 
244
+ priority = heuristic_priority(text)
245
+ resolution_action = heuristic_resolution_action(text, issue_type)
 
 
 
 
 
246
 
247
  result: dict = {}
248
  if "issue_type" in allowed_fields:
 
254
  issue_type, "service_desk"
255
  )
256
  if "resolution_action" in allowed_fields:
257
+ result["resolution_action"] = resolution_action
 
 
258
  return result
259
 
260
 
261
+ def build_action(
262
+ ticket: dict, allowed_fields: list[str], instructions: str
263
+ ) -> HelpdeskTicketAction:
264
+ heuristic_dict = heuristic_action(ticket, allowed_fields)
265
+
266
+ if llm_client is None:
267
+ return HelpdeskTicketAction(**heuristic_dict)
268
+
269
+ llm_dict = call_llm(ticket, allowed_fields, instructions)
270
+ try:
271
+ return HelpdeskTicketAction(**llm_dict)
272
+ except Exception as exc:
273
+ print(f" Falling back to heuristic action due to invalid LLM output: {exc}")
274
+ return HelpdeskTicketAction(**heuristic_dict)
275
+
276
+
277
  # ---------------------------------------------------------------------------
278
  # Main loop using WebSocket client for multi-step episodes
279
  # ---------------------------------------------------------------------------
 
291
  print(f"Available tasks: {[t['name'] for t in available_tasks.values()]}")
292
  http.close()
293
 
294
+ all_results: dict[int, dict[str, float | int]] = {}
295
 
296
  for task_id in TASKS:
297
  if task_id not in available_tasks:
 
307
  result = sync_client.reset(seed=SEED, task_id=task_id)
308
  obs = result.observation
309
 
310
+ task_step_rewards: list[float] = []
311
  step_num = 0
312
 
313
  while not result.done:
 
318
  allowed = obs.allowed_fields
319
  instructions = obs.instructions
320
 
321
+ action = build_action(ticket, allowed, instructions)
 
 
 
 
 
322
  result = sync_client.step(action)
323
  obs = result.observation
324
 
 
326
  print(f" Step {step_num}: reward={result.reward} done={result.done}")
327
 
328
  if result.reward is not None:
329
+ task_step_rewards.append(float(result.reward))
330
 
331
+ final_reward = task_step_rewards[-1] if task_step_rewards else 0.0
332
+ all_results[task_id] = {
333
+ "final_reward": final_reward,
334
+ "step_count": step_num,
335
+ }
336
+ print(f" Task {task_id} final reward: {final_reward:.4f}")
337
 
338
  # Summary
339
  print("\n=== RESULTS ===")
340
+ overall: list[float] = []
341
  for tid in TASKS:
342
+ if tid in all_results:
343
+ final_reward = float(all_results[tid]["final_reward"])
344
+ step_count = int(all_results[tid]["step_count"])
345
+ overall.append(final_reward)
346
+ print(f"Task {tid}: final_reward={final_reward:.4f} ({step_count} steps)")
347
  if overall:
348
  print(f"Overall: {sum(overall) / len(overall):.4f}")
349
 
server/environment.py CHANGED
@@ -20,6 +20,19 @@ from server.tasks import get_task_definition, load_dataset
20
  QUEUE_SIZE_RANGE = (3, 5)
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  class HelpdeskTicketRoutingEnvironment(
24
  Environment[HelpdeskTicketAction, HelpdeskTicketObservation, HelpdeskTicketState]
25
  ):
@@ -40,11 +53,13 @@ class HelpdeskTicketRoutingEnvironment(
40
  episode_id: Optional[str] = None,
41
  **kwargs: Any,
42
  ) -> HelpdeskTicketObservation:
43
- task_id: int = kwargs.get("task_id", 1)
 
 
44
  task = get_task_definition(task_id)
45
 
46
- if seed is not None:
47
- self._rng.seed(seed)
48
 
49
  queue_size = self._rng.randint(*QUEUE_SIZE_RANGE)
50
  self._queue = self._rng.sample(self._dataset, min(queue_size, len(self._dataset)))
@@ -53,7 +68,7 @@ class HelpdeskTicketRoutingEnvironment(
53
  episode_id=episode_id or str(uuid.uuid4()),
54
  step_count=0,
55
  current_task_id=task_id,
56
- seed=seed,
57
  queue_ticket_ids=[t.ticket_id for t in self._queue],
58
  current_ticket_index=0,
59
  per_ticket_scores=[],
 
20
  QUEUE_SIZE_RANGE = (3, 5)
21
 
22
 
23
+ def _coerce_optional_int(value: Any, field_name: str) -> Optional[int]:
24
+ if value is None or value == "":
25
+ return None
26
+ if isinstance(value, bool):
27
+ raise ValueError(f"{field_name} must be an integer")
28
+ if isinstance(value, int):
29
+ return value
30
+ try:
31
+ return int(value)
32
+ except (TypeError, ValueError) as exc:
33
+ raise ValueError(f"{field_name} must be an integer") from exc
34
+
35
+
36
  class HelpdeskTicketRoutingEnvironment(
37
  Environment[HelpdeskTicketAction, HelpdeskTicketObservation, HelpdeskTicketState]
38
  ):
 
53
  episode_id: Optional[str] = None,
54
  **kwargs: Any,
55
  ) -> HelpdeskTicketObservation:
56
+ normalized_seed = _coerce_optional_int(seed, "seed")
57
+ task_id_value = _coerce_optional_int(kwargs.get("task_id", 1), "task_id")
58
+ task_id = 1 if task_id_value is None else task_id_value
59
  task = get_task_definition(task_id)
60
 
61
+ if normalized_seed is not None:
62
+ self._rng.seed(normalized_seed)
63
 
64
  queue_size = self._rng.randint(*QUEUE_SIZE_RANGE)
65
  self._queue = self._rng.sample(self._dataset, min(queue_size, len(self._dataset)))
 
68
  episode_id=episode_id or str(uuid.uuid4()),
69
  step_count=0,
70
  current_task_id=task_id,
71
+ seed=normalized_seed,
72
  queue_ticket_ids=[t.ticket_id for t in self._queue],
73
  current_ticket_index=0,
74
  per_ticket_scores=[],