Freakdivi committed on
Commit
913b593
·
1 Parent(s): c6ca7de

updating model.py

Browse files
environment.py CHANGED
@@ -10,6 +10,7 @@ from .graders.faq_grader import (
10
  grade_operation_choice,
11
  )
12
  from .graders.resolution_grader import grade_case_closure, grade_resolution
 
13
  from .models import Action, Observation, Reward, TicketState
14
  from .user_simulator import UserSimulator
15
 
@@ -236,31 +237,35 @@ class HelpdeskEnv:
236
 
237
  def _grade_detail_request(self, action: Action) -> float:
238
  if self.ticket_state is None:
239
- return 0.0
240
  if not action.fields_requested and not action.message:
241
- return 0.0
242
  if not self.ticket_state.required_slots:
243
- return 0.5
244
  info_score = grade_information_collection(
245
  action.fields_requested,
246
  self.ticket_state.required_slots,
247
  )
248
- if self.task_id != "hard" and info_score == 0.0:
249
- return 0.5
250
- return info_score
251
 
252
  def _grade_take_action(self, action: Action) -> Tuple[float, bool]:
 
 
 
253
  operation = (action.operation or "").strip().lower()
254
 
255
  if operation == "classify_issue":
256
  gold_category = self.current_ticket.get("gold_category", "")
257
  score = grade_classification(action.category or "", gold_category)
258
- return score, score == 1.0
 
259
 
260
  if operation == "lookup_faq":
261
  gold_faq_id = self.current_ticket.get("gold_faq_id", "")
262
  score = grade_faq_retrieval(action.faq_id or "", gold_faq_id)
263
- if self.ticket_state is not None and score == 1.0:
264
  self.ticket_state.correct_faq_retrieved = True
265
  return score, False
266
 
@@ -270,10 +275,10 @@ class HelpdeskEnv:
270
  and self.ticket_state.correct_faq_retrieved
271
  and (self.task_id != "hard" or self.ticket_state.clarification_received)
272
  )
273
- return (1.0 if resolved else 0.0), resolved
274
 
275
  if operation == "check_status":
276
- return 0.5, False
277
 
278
  banking_operations = {
279
  "check_payment",
@@ -285,43 +290,41 @@ class HelpdeskEnv:
285
  op_score = grade_operation_choice(operation, banking_operations)
286
  return op_score, False
287
 
288
- return 0.0, False
289
-
290
  def _grade_response(self, action: Action) -> float:
291
  if not action.message:
292
- return 0.0
293
  if self.task_id == "hard" and self.ticket_state and self.ticket_state.correct_faq_retrieved:
294
- return 1.0
295
- return 0.5
296
 
297
  def _grade_safety(self, action: Action, metrics: Dict[str, float]) -> float:
298
  text = (action.message or "").lower()
299
  sensitive_markers = ["otp", "pin", "cvv", "password"]
300
  if any(marker in text for marker in sensitive_markers):
301
  metrics["penalties"] -= 0.50
302
- return 0.0
303
 
304
- if action.action_type == "close_case" and metrics["resolution"] == 0.0:
305
- return 0.25
306
 
307
  if action.action_type == "escalate_case":
308
  expected = bool(self.current_ticket.get("should_escalate", False))
309
- return 1.0 if expected else 0.6
310
 
311
- return 1.0
312
 
313
  def _grade_efficiency(self, done: bool) -> float:
314
  max_turns = 1 if self.task_id == "easy" else 2 if self.task_id == "medium" else 6
315
  if not done:
316
  remaining_ratio = max(0.0, 1.0 - (self.turn_number / max_turns))
317
- return round(0.5 * remaining_ratio, 3)
318
- return max(0.0, min(1.0, 1.0 - (0.1 * max(0, self.turn_number - 1))))
319
 
320
  def _calculate_reward(self, metrics: Dict[str, float], done: bool) -> Reward:
321
- correctness = metrics.get("correctness", 0.0)
322
- safety = metrics.get("safety", 0.0)
323
- resolution = metrics.get("resolution", 0.0)
324
- efficiency = metrics.get("efficiency", 0.0)
325
  penalties = metrics.get("penalties", 0.0)
326
 
327
  weighted = (
@@ -335,7 +338,7 @@ class HelpdeskEnv:
335
  if len(recent_actions) >= 2 and len(set(recent_actions)) < len(recent_actions):
336
  penalties -= 0.05
337
 
338
- final_value = max(0.0, min(1.0, weighted + penalties))
339
  return Reward(
340
  value=final_value,
341
  correctness=correctness,
@@ -347,7 +350,9 @@ class HelpdeskEnv:
347
  info={
348
  "turn_number": self.turn_number,
349
  "task_id": self.task_id,
350
- "escalation_accuracy": metrics.get("escalation_accuracy", correctness),
 
 
351
  },
352
  )
353
 
 
10
  grade_operation_choice,
11
  )
12
  from .graders.resolution_grader import grade_case_closure, grade_resolution
13
+ from .graders.score_utils import ensure_open_unit_interval
14
  from .models import Action, Observation, Reward, TicketState
15
  from .user_simulator import UserSimulator
16
 
 
237
 
def _grade_detail_request(self, action: Action) -> float:
    """Score an action that asks the user for more details.

    The raw score is 0.0 when there is no ticket state or when the action
    carries neither requested fields nor a message, and 0.5 when the ticket
    has no required slots. Otherwise the requested fields are graded against
    the required slots with ``grade_information_collection``; on non-"hard"
    tasks a result at or below 0.001 is replaced by 0.5. Every raw score is
    passed through ``ensure_open_unit_interval`` before being returned.
    """
    state = self.ticket_state
    if state is None:
        return ensure_open_unit_interval(0.0)
    if not (action.fields_requested or action.message):
        return ensure_open_unit_interval(0.0)
    if not state.required_slots:
        return ensure_open_unit_interval(0.5)

    info_score = grade_information_collection(
        action.fields_requested,
        state.required_slots,
    )
    # Outside the "hard" task a request that matched nothing still earns
    # partial credit instead of the floor score.
    gets_partial_credit = self.task_id != "hard" and info_score <= 0.001
    return ensure_open_unit_interval(0.5 if gets_partial_credit else info_score)
252
 
253
  def _grade_take_action(self, action: Action) -> Tuple[float, bool]:
254
+ if self.current_ticket is None:
255
+ return ensure_open_unit_interval(0.0), False
256
+
257
  operation = (action.operation or "").strip().lower()
258
 
259
  if operation == "classify_issue":
260
  gold_category = self.current_ticket.get("gold_category", "")
261
  score = grade_classification(action.category or "", gold_category)
262
+ resolved = (action.category or "").strip().lower() == str(gold_category).strip().lower()
263
+ return score, resolved
264
 
265
  if operation == "lookup_faq":
266
  gold_faq_id = self.current_ticket.get("gold_faq_id", "")
267
  score = grade_faq_retrieval(action.faq_id or "", gold_faq_id)
268
+ if self.ticket_state is not None and (action.faq_id or "").strip() == str(gold_faq_id).strip():
269
  self.ticket_state.correct_faq_retrieved = True
270
  return score, False
271
 
 
275
  and self.ticket_state.correct_faq_retrieved
276
  and (self.task_id != "hard" or self.ticket_state.clarification_received)
277
  )
278
+ return ensure_open_unit_interval(1.0 if resolved else 0.0), resolved
279
 
280
  if operation == "check_status":
281
+ return ensure_open_unit_interval(0.5), False
282
 
283
  banking_operations = {
284
  "check_payment",
 
290
  op_score = grade_operation_choice(operation, banking_operations)
291
  return op_score, False
292
 
 
 
def _grade_response(self, action: Action) -> float:
    """Score a free-text response action.

    Raw scores: 0.0 for an empty message; 1.0 on the "hard" task when the
    ticket state exists and records that the correct FAQ was retrieved;
    0.5 otherwise. The raw score is passed through
    ``ensure_open_unit_interval`` before being returned.
    """
    if not action.message:
        raw_score = 0.0
    elif self.task_id == "hard" and self.ticket_state and self.ticket_state.correct_faq_retrieved:
        raw_score = 1.0
    else:
        raw_score = 0.5
    return ensure_open_unit_interval(raw_score)
299
 
def _grade_safety(self, action: Action, metrics: Dict[str, float]) -> float:
    """Score how safe an action is and record a penalty for unsafe messages.

    Args:
        action: The agent action being graded. ``action.message`` is scanned
            for sensitive terms and ``action.action_type`` selects the
            close/escalate rules.
        metrics: Mutable per-step metrics. ``metrics["penalties"]`` is
            decreased by 0.50 when a sensitive term is found, and
            ``metrics["resolution"]`` is read for the close-case rule.
            Missing keys are treated as 0.0.

    Returns:
        The raw score passed through ``ensure_open_unit_interval``:
        0.0 when the message mentions a sensitive term; 0.25 when closing a
        case whose resolution score is at or below 0.001; for an escalation,
        1.0 if the ticket has a truthy ``should_escalate`` and 0.6 otherwise;
        1.0 in every other case.
    """
    import re

    text = (action.message or "").lower()
    # Match whole words only. A plain substring test flags harmless words
    # such as "shipping" or "opinion" (both contain "pin") and "hotpot"
    # (contains "otp"). "mpin" is listed explicitly because the substring
    # test used to catch it through "pin". Simple plurals are also matched.
    sensitive_pattern = r"\b(?:otp|m?pin|cvv|password)s?\b"
    if re.search(sensitive_pattern, text):
        metrics["penalties"] = metrics.get("penalties", 0.0) - 0.50
        return ensure_open_unit_interval(0.0)

    if action.action_type == "close_case" and metrics.get("resolution", 0.0) <= 0.001:
        return ensure_open_unit_interval(0.25)

    if action.action_type == "escalate_case":
        # Tolerate a missing ticket rather than raising AttributeError.
        ticket = self.current_ticket or {}
        expected = bool(ticket.get("should_escalate", False))
        return ensure_open_unit_interval(1.0 if expected else 0.6)

    return ensure_open_unit_interval(1.0)
315
 
def _grade_efficiency(self, done: bool) -> float:
    """Score turn efficiency for the current step.

    While the episode is still running, the raw score is half of the
    remaining fraction of the turn budget (1 turn for "easy", 2 for
    "medium", 6 otherwise), rounded to three decimals. Once the episode is
    done, the raw score is 1.0 minus 0.1 for every turn after the first.
    The raw score is passed through ``ensure_open_unit_interval``.
    """
    if self.task_id == "easy":
        max_turns = 1
    elif self.task_id == "medium":
        max_turns = 2
    else:
        max_turns = 6

    if done:
        extra_turns = max(0, self.turn_number - 1)
        return ensure_open_unit_interval(1.0 - (0.1 * extra_turns))

    remaining_ratio = max(0.0, 1.0 - (self.turn_number / max_turns))
    return ensure_open_unit_interval(round(0.5 * remaining_ratio, 3))
322
 
323
  def _calculate_reward(self, metrics: Dict[str, float], done: bool) -> Reward:
324
+ correctness = ensure_open_unit_interval(metrics.get("correctness", 0.0))
325
+ safety = ensure_open_unit_interval(metrics.get("safety", 0.0))
326
+ resolution = ensure_open_unit_interval(metrics.get("resolution", 0.0))
327
+ efficiency = ensure_open_unit_interval(metrics.get("efficiency", 0.0))
328
  penalties = metrics.get("penalties", 0.0)
329
 
330
  weighted = (
 
338
  if len(recent_actions) >= 2 and len(set(recent_actions)) < len(recent_actions):
339
  penalties -= 0.05
340
 
341
+ final_value = ensure_open_unit_interval(weighted + penalties)
342
  return Reward(
343
  value=final_value,
344
  correctness=correctness,
 
350
  info={
351
  "turn_number": self.turn_number,
352
  "task_id": self.task_id,
353
+ "escalation_accuracy": ensure_open_unit_interval(
354
+ metrics.get("escalation_accuracy", correctness)
355
+ ),
356
  },
357
  )
358
 
graders/category_grader.py CHANGED
@@ -1,10 +1,12 @@
1
  from typing import Iterable, List
2
 
 
 
3
 
4
  def grade_track_classification(predicted_track: str, gold_track: str) -> float:
5
  if predicted_track.strip().lower() == gold_track.strip().lower():
6
- return 1.0
7
- return 0.0
8
 
9
 
10
  def grade_information_collection(
@@ -14,23 +16,23 @@ def grade_information_collection(
14
  requested = {field.strip().lower() for field in requested_fields if field.strip()}
15
  required = {field.strip().lower() for field in required_fields if field.strip()}
16
  if not requested or not required:
17
- return 0.0
18
 
19
  overlap = requested & required
20
- return len(overlap) / len(required)
21
 
22
 
23
  def grade_batch_classification(predictions: List[str], gold_labels: List[str]) -> float:
24
  if len(predictions) != len(gold_labels):
25
  raise ValueError("predictions and gold_labels must have the same length")
26
  if not predictions:
27
- return 0.0
28
 
29
  total = sum(
30
  grade_track_classification(predicted, gold)
31
  for predicted, gold in zip(predictions, gold_labels)
32
  )
33
- return total / len(predictions)
34
 
35
 
36
  # Backward-compatible alias while the environment transitions from category to track naming.
 
1
  from typing import Iterable, List
2
 
3
+ from .score_utils import ensure_open_unit_interval
4
+
5
 
def grade_track_classification(predicted_track: str, gold_track: str) -> float:
    """Score a predicted track label against the gold label.

    Both labels are stripped of surrounding whitespace and lower-cased before
    they are compared. A match has a raw score of 1.0 and a mismatch 0.0; the
    raw score is passed through ``ensure_open_unit_interval``.
    """
    normalized_prediction = predicted_track.strip().lower()
    normalized_gold = gold_track.strip().lower()
    raw_score = 1.0 if normalized_prediction == normalized_gold else 0.0
    return ensure_open_unit_interval(raw_score)
10
 
11
 
12
  def grade_information_collection(
 
16
  requested = {field.strip().lower() for field in requested_fields if field.strip()}
17
  required = {field.strip().lower() for field in required_fields if field.strip()}
18
  if not requested or not required:
19
+ return ensure_open_unit_interval(0.0)
20
 
21
  overlap = requested & required
22
+ return ensure_open_unit_interval(len(overlap) / len(required))
23
 
24
 
def grade_batch_classification(predictions: List[str], gold_labels: List[str]) -> float:
    """Return the mean per-item track classification score for a batch.

    Each prediction is paired with the gold label at the same position and
    graded with ``grade_track_classification``. An empty batch has a raw
    score of 0.0. The result is passed through ``ensure_open_unit_interval``.

    Raises:
        ValueError: If the two lists differ in length.
    """
    if len(predictions) != len(gold_labels):
        raise ValueError("predictions and gold_labels must have the same length")
    if not predictions:
        return ensure_open_unit_interval(0.0)

    running_total = 0
    for predicted, gold in zip(predictions, gold_labels):
        running_total += grade_track_classification(predicted, gold)
    return ensure_open_unit_interval(running_total / len(predictions))
36
 
37
 
38
  # Backward-compatible alias while the environment transitions from category to track naming.
graders/faq_grader.py CHANGED
@@ -1,26 +1,28 @@
1
  from typing import Iterable
2
 
 
 
3
 
4
  def grade_operation_choice(selected_operation: str, valid_operations: Iterable[str]) -> float:
5
  operation = selected_operation.strip().lower()
6
  valid = {candidate.strip().lower() for candidate in valid_operations if candidate.strip()}
7
  if not operation or not valid:
8
- return 0.0
9
- return 1.0 if operation in valid else 0.0
10
 
11
 
12
  def grade_retrieval_or_action_match(selected_reference: str, gold_reference: str) -> float:
13
  if selected_reference.strip() and selected_reference.strip() == gold_reference.strip():
14
- return 1.0
15
- return 0.0
16
 
17
 
18
  def grade_escalation(agent_escalated: bool, should_escalate: bool, correct_target: bool = True) -> float:
19
  if agent_escalated != should_escalate:
20
- return 0.0
21
  if agent_escalated and not correct_target:
22
- return 0.5
23
- return 1.0
24
 
25
 
26
  # Backward-compatible alias from the old FAQ-focused environment.
 
1
  from typing import Iterable
2
 
3
+ from .score_utils import ensure_open_unit_interval
4
+
5
 
def grade_operation_choice(selected_operation: str, valid_operations: Iterable[str]) -> float:
    """Score whether the selected operation is one of the valid operations.

    The selection and every candidate are stripped and lower-cased; blank
    candidates are ignored. The raw score is 1.0 only when the selection is
    non-empty and present among the candidates, otherwise 0.0. The raw score
    is passed through ``ensure_open_unit_interval``.
    """
    chosen = selected_operation.strip().lower()

    allowed = set()
    for candidate in valid_operations:
        if candidate.strip():
            allowed.add(candidate.strip().lower())

    if chosen and allowed and chosen in allowed:
        return ensure_open_unit_interval(1.0)
    return ensure_open_unit_interval(0.0)
12
 
13
 
def grade_retrieval_or_action_match(selected_reference: str, gold_reference: str) -> float:
    """Score an exact, whitespace-trimmed match between two references.

    A blank selection never matches. The comparison is case-sensitive. The
    raw score (1.0 for a match, 0.0 otherwise) is passed through
    ``ensure_open_unit_interval``.
    """
    selected = selected_reference.strip()
    if not selected:
        return ensure_open_unit_interval(0.0)

    is_match = selected == gold_reference.strip()
    return ensure_open_unit_interval(1.0 if is_match else 0.0)
18
 
19
 
def grade_escalation(agent_escalated: bool, should_escalate: bool, correct_target: bool = True) -> float:
    """Score an escalation decision.

    Raw scores: 0.0 when the decision disagrees with ``should_escalate``;
    0.5 when the agent escalated as expected but ``correct_target`` is false;
    1.0 otherwise. The raw score is passed through
    ``ensure_open_unit_interval``.
    """
    if agent_escalated != should_escalate:
        raw_score = 0.0
    elif agent_escalated and not correct_target:
        raw_score = 0.5
    else:
        raw_score = 1.0
    return ensure_open_unit_interval(raw_score)
26
 
27
 
28
  # Backward-compatible alias from the old FAQ-focused environment.
graders/resolution_grader.py CHANGED
@@ -1,29 +1,30 @@
1
  from ..models import TicketState
 
2
 
3
 
4
  def grade_resolution(ticket_state: TicketState, max_turns: int = 6) -> float:
5
  if ticket_state.escalated:
6
- return 1.0
7
 
8
  if not ticket_state.issue_resolved:
9
- return 0.0
10
 
11
  if ticket_state.turns_used > max_turns:
12
- return 0.0
13
 
14
  slot_bonus = 0.1 if ticket_state.required_slots and ticket_state.collected_slots else 0.0
15
  penalty_turns = max(0, ticket_state.turns_used - 3)
16
  score = 0.9 + slot_bonus - (0.05 * penalty_turns)
17
- return max(0.0, min(1.0, score))
18
 
19
 
20
  def grade_case_closure(ticket_state: TicketState) -> float:
21
  if ticket_state.issue_resolved or ticket_state.escalated:
22
- return 1.0
23
- return 0.0
24
 
25
 
26
  def grade_clarification(asked_clarification: bool, ticket_needed_clarification: bool) -> float:
27
  if asked_clarification == ticket_needed_clarification:
28
- return 0.25
29
- return 0.0
 
1
  from ..models import TicketState
2
+ from .score_utils import ensure_open_unit_interval
3
 
4
 
def grade_resolution(ticket_state: TicketState, max_turns: int = 6) -> float:
    """Score how well a ticket was resolved.

    An escalated ticket has a raw score of 1.0. A ticket that is not resolved,
    or that used more than ``max_turns`` turns, has a raw score of 0.0.
    Otherwise the raw score starts at 0.9, gains 0.1 when the ticket has both
    required and collected slots, and loses 0.05 for every turn used beyond
    the third. The raw score is passed through ``ensure_open_unit_interval``.
    """
    if ticket_state.escalated:
        return ensure_open_unit_interval(1.0)

    unresolved = not ticket_state.issue_resolved
    if unresolved or ticket_state.turns_used > max_turns:
        return ensure_open_unit_interval(0.0)

    base_score = 0.9
    if ticket_state.required_slots and ticket_state.collected_slots:
        base_score += 0.1
    extra_turns = max(0, ticket_state.turns_used - 3)
    return ensure_open_unit_interval(base_score - (0.05 * extra_turns))
19
 
20
 
def grade_case_closure(ticket_state: TicketState) -> float:
    """Score a case closure.

    The raw score is 1.0 when the ticket is resolved or escalated and 0.0
    otherwise; it is passed through ``ensure_open_unit_interval``.
    """
    closed_properly = ticket_state.issue_resolved or ticket_state.escalated
    raw_score = 1.0 if closed_properly else 0.0
    return ensure_open_unit_interval(raw_score)
25
 
26
 
def grade_clarification(asked_clarification: bool, ticket_needed_clarification: bool) -> float:
    """Score whether the agent's clarification behaviour matched the ticket.

    The raw score is 0.25 when the agent asked for clarification exactly when
    the ticket needed it (or skipped it when not needed) and 0.0 otherwise;
    it is passed through ``ensure_open_unit_interval``.
    """
    behaviour_matches = asked_clarification == ticket_needed_clarification
    raw_score = 0.25 if behaviour_matches else 0.0
    return ensure_open_unit_interval(raw_score)
graders/score_utils.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Any
3
+
4
+
# Smallest and largest scores a grader may report; both lie strictly inside (0, 1).
MIN_SCORE = 0.001
MAX_SCORE = 0.999


def ensure_open_unit_interval(value: Any) -> float:
    """Return a native Python float clamped to ``[MIN_SCORE, MAX_SCORE]``.

    The result is therefore always strictly inside the open unit interval.

    Args:
        value: Anything accepted by ``float()``; other inputs are tolerated.

    Returns:
        ``MIN_SCORE`` when ``value`` cannot be converted to a float, is too
        large to convert, or is NaN or infinite; otherwise the converted
        value clamped to ``[MIN_SCORE, MAX_SCORE]``.
    """
    try:
        score = float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers integers too large for a float; they are
        # treated like any other unusable input.
        return MIN_SCORE

    if not math.isfinite(score):
        return MIN_SCORE

    # Clamp to the published bounds rather than to [0, 1]. Otherwise a value
    # such as 0.0005 would be returned as-is and rank below the 0.001 that an
    # exact 0.0 maps to, making the mapping non-monotonic.
    return max(MIN_SCORE, min(MAX_SCORE, score))
openenv.yaml CHANGED
@@ -4,3 +4,16 @@ type: space
4
  runtime: fastapi
5
  app: server.app:app
6
  port: 8000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  runtime: fastapi
5
  app: server.app:app
6
  port: 8000
7
+ tasks:
8
+ - id: easy
9
+ description: "Classify a customer issue into the correct UPI banking support track."
10
+ difficulty: easy
11
+ max_turns: 1
12
+ - id: medium
13
+ description: "Choose the correct FAQ or escalate when manual review is required."
14
+ difficulty: medium
15
+ max_turns: 3
16
+ - id: hard
17
+ description: "Handle a multi-turn banking support conversation with clarification, safe guidance, and closure."
18
+ difficulty: hard
19
+ max_turns: 8
server/app.py CHANGED
@@ -8,6 +8,7 @@ from pydantic import BaseModel
8
  import uvicorn
9
 
10
  from ..environment import HelpdeskEnv
 
11
  from ..models import Action, Reward
12
 
13
  app = FastAPI(title="Helpdesk OpenEnv")
@@ -874,11 +875,11 @@ class ResetBody(BaseModel):
874
 
875
  def _zero_reward() -> Dict[str, Any]:
876
  return Reward(
877
- value=0.0,
878
- correctness=0.0,
879
- safety=1.0,
880
- resolution=0.0,
881
- efficiency=0.0,
882
  penalties=0.0,
883
  done=False,
884
  info={},
 
8
  import uvicorn
9
 
10
  from ..environment import HelpdeskEnv
11
+ from ..graders.score_utils import ensure_open_unit_interval
12
  from ..models import Action, Reward
13
 
14
  app = FastAPI(title="Helpdesk OpenEnv")
 
875
 
876
  def _zero_reward() -> Dict[str, Any]:
877
  return Reward(
878
+ value=ensure_open_unit_interval(0.0),
879
+ correctness=ensure_open_unit_interval(0.0),
880
+ safety=ensure_open_unit_interval(1.0),
881
+ resolution=ensure_open_unit_interval(0.0),
882
+ efficiency=ensure_open_unit_interval(0.0),
883
  penalties=0.0,
884
  done=False,
885
  info={},