updating model.py
Browse files- environment.py +33 -28
- graders/category_grader.py +8 -6
- graders/faq_grader.py +9 -7
- graders/resolution_grader.py +9 -8
- graders/score_utils.py +24 -0
- openenv.yaml +13 -0
- server/app.py +6 -5
environment.py
CHANGED
|
@@ -10,6 +10,7 @@ from .graders.faq_grader import (
|
|
| 10 |
grade_operation_choice,
|
| 11 |
)
|
| 12 |
from .graders.resolution_grader import grade_case_closure, grade_resolution
|
|
|
|
| 13 |
from .models import Action, Observation, Reward, TicketState
|
| 14 |
from .user_simulator import UserSimulator
|
| 15 |
|
|
@@ -236,31 +237,35 @@ class HelpdeskEnv:
|
|
| 236 |
|
| 237 |
def _grade_detail_request(self, action: Action) -> float:
|
| 238 |
if self.ticket_state is None:
|
| 239 |
-
return 0.0
|
| 240 |
if not action.fields_requested and not action.message:
|
| 241 |
-
return 0.0
|
| 242 |
if not self.ticket_state.required_slots:
|
| 243 |
-
return 0.5
|
| 244 |
info_score = grade_information_collection(
|
| 245 |
action.fields_requested,
|
| 246 |
self.ticket_state.required_slots,
|
| 247 |
)
|
| 248 |
-
if self.task_id != "hard" and info_score =
|
| 249 |
-
return 0.5
|
| 250 |
-
return info_score
|
| 251 |
|
| 252 |
def _grade_take_action(self, action: Action) -> Tuple[float, bool]:
|
|
|
|
|
|
|
|
|
|
| 253 |
operation = (action.operation or "").strip().lower()
|
| 254 |
|
| 255 |
if operation == "classify_issue":
|
| 256 |
gold_category = self.current_ticket.get("gold_category", "")
|
| 257 |
score = grade_classification(action.category or "", gold_category)
|
| 258 |
-
|
|
|
|
| 259 |
|
| 260 |
if operation == "lookup_faq":
|
| 261 |
gold_faq_id = self.current_ticket.get("gold_faq_id", "")
|
| 262 |
score = grade_faq_retrieval(action.faq_id or "", gold_faq_id)
|
| 263 |
-
if self.ticket_state is not None and
|
| 264 |
self.ticket_state.correct_faq_retrieved = True
|
| 265 |
return score, False
|
| 266 |
|
|
@@ -270,10 +275,10 @@ class HelpdeskEnv:
|
|
| 270 |
and self.ticket_state.correct_faq_retrieved
|
| 271 |
and (self.task_id != "hard" or self.ticket_state.clarification_received)
|
| 272 |
)
|
| 273 |
-
return (1.0 if resolved else 0.0), resolved
|
| 274 |
|
| 275 |
if operation == "check_status":
|
| 276 |
-
return 0.5, False
|
| 277 |
|
| 278 |
banking_operations = {
|
| 279 |
"check_payment",
|
|
@@ -285,43 +290,41 @@ class HelpdeskEnv:
|
|
| 285 |
op_score = grade_operation_choice(operation, banking_operations)
|
| 286 |
return op_score, False
|
| 287 |
|
| 288 |
-
return 0.0, False
|
| 289 |
-
|
| 290 |
def _grade_response(self, action: Action) -> float:
|
| 291 |
if not action.message:
|
| 292 |
-
return 0.0
|
| 293 |
if self.task_id == "hard" and self.ticket_state and self.ticket_state.correct_faq_retrieved:
|
| 294 |
-
return 1.0
|
| 295 |
-
return 0.5
|
| 296 |
|
| 297 |
def _grade_safety(self, action: Action, metrics: Dict[str, float]) -> float:
|
| 298 |
text = (action.message or "").lower()
|
| 299 |
sensitive_markers = ["otp", "pin", "cvv", "password"]
|
| 300 |
if any(marker in text for marker in sensitive_markers):
|
| 301 |
metrics["penalties"] -= 0.50
|
| 302 |
-
return 0.0
|
| 303 |
|
| 304 |
-
if action.action_type == "close_case" and metrics["resolution"] =
|
| 305 |
-
return 0.25
|
| 306 |
|
| 307 |
if action.action_type == "escalate_case":
|
| 308 |
expected = bool(self.current_ticket.get("should_escalate", False))
|
| 309 |
-
return 1.0 if expected else 0.6
|
| 310 |
|
| 311 |
-
return 1.0
|
| 312 |
|
| 313 |
def _grade_efficiency(self, done: bool) -> float:
|
| 314 |
max_turns = 1 if self.task_id == "easy" else 2 if self.task_id == "medium" else 6
|
| 315 |
if not done:
|
| 316 |
remaining_ratio = max(0.0, 1.0 - (self.turn_number / max_turns))
|
| 317 |
-
return round(0.5 * remaining_ratio, 3)
|
| 318 |
-
return
|
| 319 |
|
| 320 |
def _calculate_reward(self, metrics: Dict[str, float], done: bool) -> Reward:
|
| 321 |
-
correctness = metrics.get("correctness", 0.0)
|
| 322 |
-
safety = metrics.get("safety", 0.0)
|
| 323 |
-
resolution = metrics.get("resolution", 0.0)
|
| 324 |
-
efficiency = metrics.get("efficiency", 0.0)
|
| 325 |
penalties = metrics.get("penalties", 0.0)
|
| 326 |
|
| 327 |
weighted = (
|
|
@@ -335,7 +338,7 @@ class HelpdeskEnv:
|
|
| 335 |
if len(recent_actions) >= 2 and len(set(recent_actions)) < len(recent_actions):
|
| 336 |
penalties -= 0.05
|
| 337 |
|
| 338 |
-
final_value =
|
| 339 |
return Reward(
|
| 340 |
value=final_value,
|
| 341 |
correctness=correctness,
|
|
@@ -347,7 +350,9 @@ class HelpdeskEnv:
|
|
| 347 |
info={
|
| 348 |
"turn_number": self.turn_number,
|
| 349 |
"task_id": self.task_id,
|
| 350 |
-
"escalation_accuracy":
|
|
|
|
|
|
|
| 351 |
},
|
| 352 |
)
|
| 353 |
|
|
|
|
| 10 |
grade_operation_choice,
|
| 11 |
)
|
| 12 |
from .graders.resolution_grader import grade_case_closure, grade_resolution
|
| 13 |
+
from .graders.score_utils import ensure_open_unit_interval
|
| 14 |
from .models import Action, Observation, Reward, TicketState
|
| 15 |
from .user_simulator import UserSimulator
|
| 16 |
|
|
|
|
| 237 |
|
| 238 |
def _grade_detail_request(self, action: Action) -> float:
|
| 239 |
if self.ticket_state is None:
|
| 240 |
+
return ensure_open_unit_interval(0.0)
|
| 241 |
if not action.fields_requested and not action.message:
|
| 242 |
+
return ensure_open_unit_interval(0.0)
|
| 243 |
if not self.ticket_state.required_slots:
|
| 244 |
+
return ensure_open_unit_interval(0.5)
|
| 245 |
info_score = grade_information_collection(
|
| 246 |
action.fields_requested,
|
| 247 |
self.ticket_state.required_slots,
|
| 248 |
)
|
| 249 |
+
if self.task_id != "hard" and info_score <= 0.001:
|
| 250 |
+
return ensure_open_unit_interval(0.5)
|
| 251 |
+
return ensure_open_unit_interval(info_score)
|
| 252 |
|
| 253 |
def _grade_take_action(self, action: Action) -> Tuple[float, bool]:
|
| 254 |
+
if self.current_ticket is None:
|
| 255 |
+
return ensure_open_unit_interval(0.0), False
|
| 256 |
+
|
| 257 |
operation = (action.operation or "").strip().lower()
|
| 258 |
|
| 259 |
if operation == "classify_issue":
|
| 260 |
gold_category = self.current_ticket.get("gold_category", "")
|
| 261 |
score = grade_classification(action.category or "", gold_category)
|
| 262 |
+
resolved = (action.category or "").strip().lower() == str(gold_category).strip().lower()
|
| 263 |
+
return score, resolved
|
| 264 |
|
| 265 |
if operation == "lookup_faq":
|
| 266 |
gold_faq_id = self.current_ticket.get("gold_faq_id", "")
|
| 267 |
score = grade_faq_retrieval(action.faq_id or "", gold_faq_id)
|
| 268 |
+
if self.ticket_state is not None and (action.faq_id or "").strip() == str(gold_faq_id).strip():
|
| 269 |
self.ticket_state.correct_faq_retrieved = True
|
| 270 |
return score, False
|
| 271 |
|
|
|
|
| 275 |
and self.ticket_state.correct_faq_retrieved
|
| 276 |
and (self.task_id != "hard" or self.ticket_state.clarification_received)
|
| 277 |
)
|
| 278 |
+
return ensure_open_unit_interval(1.0 if resolved else 0.0), resolved
|
| 279 |
|
| 280 |
if operation == "check_status":
|
| 281 |
+
return ensure_open_unit_interval(0.5), False
|
| 282 |
|
| 283 |
banking_operations = {
|
| 284 |
"check_payment",
|
|
|
|
| 290 |
op_score = grade_operation_choice(operation, banking_operations)
|
| 291 |
return op_score, False
|
| 292 |
|
|
|
|
|
|
|
| 293 |
def _grade_response(self, action: Action) -> float:
|
| 294 |
if not action.message:
|
| 295 |
+
return ensure_open_unit_interval(0.0)
|
| 296 |
if self.task_id == "hard" and self.ticket_state and self.ticket_state.correct_faq_retrieved:
|
| 297 |
+
return ensure_open_unit_interval(1.0)
|
| 298 |
+
return ensure_open_unit_interval(0.5)
|
| 299 |
|
| 300 |
def _grade_safety(self, action: Action, metrics: Dict[str, float]) -> float:
|
| 301 |
text = (action.message or "").lower()
|
| 302 |
sensitive_markers = ["otp", "pin", "cvv", "password"]
|
| 303 |
if any(marker in text for marker in sensitive_markers):
|
| 304 |
metrics["penalties"] -= 0.50
|
| 305 |
+
return ensure_open_unit_interval(0.0)
|
| 306 |
|
| 307 |
+
if action.action_type == "close_case" and metrics["resolution"] <= 0.001:
|
| 308 |
+
return ensure_open_unit_interval(0.25)
|
| 309 |
|
| 310 |
if action.action_type == "escalate_case":
|
| 311 |
expected = bool(self.current_ticket.get("should_escalate", False))
|
| 312 |
+
return ensure_open_unit_interval(1.0 if expected else 0.6)
|
| 313 |
|
| 314 |
+
return ensure_open_unit_interval(1.0)
|
| 315 |
|
| 316 |
def _grade_efficiency(self, done: bool) -> float:
|
| 317 |
max_turns = 1 if self.task_id == "easy" else 2 if self.task_id == "medium" else 6
|
| 318 |
if not done:
|
| 319 |
remaining_ratio = max(0.0, 1.0 - (self.turn_number / max_turns))
|
| 320 |
+
return ensure_open_unit_interval(round(0.5 * remaining_ratio, 3))
|
| 321 |
+
return ensure_open_unit_interval(1.0 - (0.1 * max(0, self.turn_number - 1)))
|
| 322 |
|
| 323 |
def _calculate_reward(self, metrics: Dict[str, float], done: bool) -> Reward:
|
| 324 |
+
correctness = ensure_open_unit_interval(metrics.get("correctness", 0.0))
|
| 325 |
+
safety = ensure_open_unit_interval(metrics.get("safety", 0.0))
|
| 326 |
+
resolution = ensure_open_unit_interval(metrics.get("resolution", 0.0))
|
| 327 |
+
efficiency = ensure_open_unit_interval(metrics.get("efficiency", 0.0))
|
| 328 |
penalties = metrics.get("penalties", 0.0)
|
| 329 |
|
| 330 |
weighted = (
|
|
|
|
| 338 |
if len(recent_actions) >= 2 and len(set(recent_actions)) < len(recent_actions):
|
| 339 |
penalties -= 0.05
|
| 340 |
|
| 341 |
+
final_value = ensure_open_unit_interval(weighted + penalties)
|
| 342 |
return Reward(
|
| 343 |
value=final_value,
|
| 344 |
correctness=correctness,
|
|
|
|
| 350 |
info={
|
| 351 |
"turn_number": self.turn_number,
|
| 352 |
"task_id": self.task_id,
|
| 353 |
+
"escalation_accuracy": ensure_open_unit_interval(
|
| 354 |
+
metrics.get("escalation_accuracy", correctness)
|
| 355 |
+
),
|
| 356 |
},
|
| 357 |
)
|
| 358 |
|
graders/category_grader.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
| 1 |
from typing import Iterable, List
|
| 2 |
|
|
|
|
|
|
|
| 3 |
|
| 4 |
def grade_track_classification(predicted_track: str, gold_track: str) -> float:
|
| 5 |
if predicted_track.strip().lower() == gold_track.strip().lower():
|
| 6 |
-
return 1.0
|
| 7 |
-
return 0.0
|
| 8 |
|
| 9 |
|
| 10 |
def grade_information_collection(
|
|
@@ -14,23 +16,23 @@ def grade_information_collection(
|
|
| 14 |
requested = {field.strip().lower() for field in requested_fields if field.strip()}
|
| 15 |
required = {field.strip().lower() for field in required_fields if field.strip()}
|
| 16 |
if not requested or not required:
|
| 17 |
-
return 0.0
|
| 18 |
|
| 19 |
overlap = requested & required
|
| 20 |
-
return len(overlap) / len(required)
|
| 21 |
|
| 22 |
|
| 23 |
def grade_batch_classification(predictions: List[str], gold_labels: List[str]) -> float:
|
| 24 |
if len(predictions) != len(gold_labels):
|
| 25 |
raise ValueError("predictions and gold_labels must have the same length")
|
| 26 |
if not predictions:
|
| 27 |
-
return 0.0
|
| 28 |
|
| 29 |
total = sum(
|
| 30 |
grade_track_classification(predicted, gold)
|
| 31 |
for predicted, gold in zip(predictions, gold_labels)
|
| 32 |
)
|
| 33 |
-
return total / len(predictions)
|
| 34 |
|
| 35 |
|
| 36 |
# Backward-compatible alias while the environment transitions from category to track naming.
|
|
|
|
| 1 |
from typing import Iterable, List
|
| 2 |
|
| 3 |
+
from .score_utils import ensure_open_unit_interval
|
| 4 |
+
|
| 5 |
|
| 6 |
def grade_track_classification(predicted_track: str, gold_track: str) -> float:
|
| 7 |
if predicted_track.strip().lower() == gold_track.strip().lower():
|
| 8 |
+
return ensure_open_unit_interval(1.0)
|
| 9 |
+
return ensure_open_unit_interval(0.0)
|
| 10 |
|
| 11 |
|
| 12 |
def grade_information_collection(
|
|
|
|
| 16 |
requested = {field.strip().lower() for field in requested_fields if field.strip()}
|
| 17 |
required = {field.strip().lower() for field in required_fields if field.strip()}
|
| 18 |
if not requested or not required:
|
| 19 |
+
return ensure_open_unit_interval(0.0)
|
| 20 |
|
| 21 |
overlap = requested & required
|
| 22 |
+
return ensure_open_unit_interval(len(overlap) / len(required))
|
| 23 |
|
| 24 |
|
| 25 |
def grade_batch_classification(predictions: List[str], gold_labels: List[str]) -> float:
|
| 26 |
if len(predictions) != len(gold_labels):
|
| 27 |
raise ValueError("predictions and gold_labels must have the same length")
|
| 28 |
if not predictions:
|
| 29 |
+
return ensure_open_unit_interval(0.0)
|
| 30 |
|
| 31 |
total = sum(
|
| 32 |
grade_track_classification(predicted, gold)
|
| 33 |
for predicted, gold in zip(predictions, gold_labels)
|
| 34 |
)
|
| 35 |
+
return ensure_open_unit_interval(total / len(predictions))
|
| 36 |
|
| 37 |
|
| 38 |
# Backward-compatible alias while the environment transitions from category to track naming.
|
graders/faq_grader.py
CHANGED
|
@@ -1,26 +1,28 @@
|
|
| 1 |
from typing import Iterable
|
| 2 |
|
|
|
|
|
|
|
| 3 |
|
| 4 |
def grade_operation_choice(selected_operation: str, valid_operations: Iterable[str]) -> float:
|
| 5 |
operation = selected_operation.strip().lower()
|
| 6 |
valid = {candidate.strip().lower() for candidate in valid_operations if candidate.strip()}
|
| 7 |
if not operation or not valid:
|
| 8 |
-
return 0.0
|
| 9 |
-
return 1.0 if operation in valid else 0.0
|
| 10 |
|
| 11 |
|
| 12 |
def grade_retrieval_or_action_match(selected_reference: str, gold_reference: str) -> float:
|
| 13 |
if selected_reference.strip() and selected_reference.strip() == gold_reference.strip():
|
| 14 |
-
return 1.0
|
| 15 |
-
return 0.0
|
| 16 |
|
| 17 |
|
| 18 |
def grade_escalation(agent_escalated: bool, should_escalate: bool, correct_target: bool = True) -> float:
|
| 19 |
if agent_escalated != should_escalate:
|
| 20 |
-
return 0.0
|
| 21 |
if agent_escalated and not correct_target:
|
| 22 |
-
return 0.5
|
| 23 |
-
return 1.0
|
| 24 |
|
| 25 |
|
| 26 |
# Backward-compatible alias from the old FAQ-focused environment.
|
|
|
|
| 1 |
from typing import Iterable
|
| 2 |
|
| 3 |
+
from .score_utils import ensure_open_unit_interval
|
| 4 |
+
|
| 5 |
|
| 6 |
def grade_operation_choice(selected_operation: str, valid_operations: Iterable[str]) -> float:
|
| 7 |
operation = selected_operation.strip().lower()
|
| 8 |
valid = {candidate.strip().lower() for candidate in valid_operations if candidate.strip()}
|
| 9 |
if not operation or not valid:
|
| 10 |
+
return ensure_open_unit_interval(0.0)
|
| 11 |
+
return ensure_open_unit_interval(1.0 if operation in valid else 0.0)
|
| 12 |
|
| 13 |
|
| 14 |
def grade_retrieval_or_action_match(selected_reference: str, gold_reference: str) -> float:
|
| 15 |
if selected_reference.strip() and selected_reference.strip() == gold_reference.strip():
|
| 16 |
+
return ensure_open_unit_interval(1.0)
|
| 17 |
+
return ensure_open_unit_interval(0.0)
|
| 18 |
|
| 19 |
|
| 20 |
def grade_escalation(agent_escalated: bool, should_escalate: bool, correct_target: bool = True) -> float:
|
| 21 |
if agent_escalated != should_escalate:
|
| 22 |
+
return ensure_open_unit_interval(0.0)
|
| 23 |
if agent_escalated and not correct_target:
|
| 24 |
+
return ensure_open_unit_interval(0.5)
|
| 25 |
+
return ensure_open_unit_interval(1.0)
|
| 26 |
|
| 27 |
|
| 28 |
# Backward-compatible alias from the old FAQ-focused environment.
|
graders/resolution_grader.py
CHANGED
|
@@ -1,29 +1,30 @@
|
|
| 1 |
from ..models import TicketState
|
|
|
|
| 2 |
|
| 3 |
|
| 4 |
def grade_resolution(ticket_state: TicketState, max_turns: int = 6) -> float:
|
| 5 |
if ticket_state.escalated:
|
| 6 |
-
return 1.0
|
| 7 |
|
| 8 |
if not ticket_state.issue_resolved:
|
| 9 |
-
return 0.0
|
| 10 |
|
| 11 |
if ticket_state.turns_used > max_turns:
|
| 12 |
-
return 0.0
|
| 13 |
|
| 14 |
slot_bonus = 0.1 if ticket_state.required_slots and ticket_state.collected_slots else 0.0
|
| 15 |
penalty_turns = max(0, ticket_state.turns_used - 3)
|
| 16 |
score = 0.9 + slot_bonus - (0.05 * penalty_turns)
|
| 17 |
-
return
|
| 18 |
|
| 19 |
|
| 20 |
def grade_case_closure(ticket_state: TicketState) -> float:
|
| 21 |
if ticket_state.issue_resolved or ticket_state.escalated:
|
| 22 |
-
return 1.0
|
| 23 |
-
return 0.0
|
| 24 |
|
| 25 |
|
| 26 |
def grade_clarification(asked_clarification: bool, ticket_needed_clarification: bool) -> float:
|
| 27 |
if asked_clarification == ticket_needed_clarification:
|
| 28 |
-
return 0.25
|
| 29 |
-
return 0.0
|
|
|
|
| 1 |
from ..models import TicketState
|
| 2 |
+
from .score_utils import ensure_open_unit_interval
|
| 3 |
|
| 4 |
|
| 5 |
def grade_resolution(ticket_state: TicketState, max_turns: int = 6) -> float:
|
| 6 |
if ticket_state.escalated:
|
| 7 |
+
return ensure_open_unit_interval(1.0)
|
| 8 |
|
| 9 |
if not ticket_state.issue_resolved:
|
| 10 |
+
return ensure_open_unit_interval(0.0)
|
| 11 |
|
| 12 |
if ticket_state.turns_used > max_turns:
|
| 13 |
+
return ensure_open_unit_interval(0.0)
|
| 14 |
|
| 15 |
slot_bonus = 0.1 if ticket_state.required_slots and ticket_state.collected_slots else 0.0
|
| 16 |
penalty_turns = max(0, ticket_state.turns_used - 3)
|
| 17 |
score = 0.9 + slot_bonus - (0.05 * penalty_turns)
|
| 18 |
+
return ensure_open_unit_interval(score)
|
| 19 |
|
| 20 |
|
| 21 |
def grade_case_closure(ticket_state: TicketState) -> float:
|
| 22 |
if ticket_state.issue_resolved or ticket_state.escalated:
|
| 23 |
+
return ensure_open_unit_interval(1.0)
|
| 24 |
+
return ensure_open_unit_interval(0.0)
|
| 25 |
|
| 26 |
|
| 27 |
def grade_clarification(asked_clarification: bool, ticket_needed_clarification: bool) -> float:
|
| 28 |
if asked_clarification == ticket_needed_clarification:
|
| 29 |
+
return ensure_open_unit_interval(0.25)
|
| 30 |
+
return ensure_open_unit_interval(0.0)
|
graders/score_utils.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Any
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
MIN_SCORE = 0.001
|
| 6 |
+
MAX_SCORE = 0.999
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def ensure_open_unit_interval(value: Any) -> float:
|
| 10 |
+
"""Return a native Python float strictly inside the open unit interval."""
|
| 11 |
+
try:
|
| 12 |
+
score = float(value)
|
| 13 |
+
except (TypeError, ValueError):
|
| 14 |
+
return MIN_SCORE
|
| 15 |
+
|
| 16 |
+
if not math.isfinite(score):
|
| 17 |
+
return MIN_SCORE
|
| 18 |
+
|
| 19 |
+
score = max(0.0, min(1.0, score))
|
| 20 |
+
if score <= 0.0:
|
| 21 |
+
return MIN_SCORE
|
| 22 |
+
if score >= 1.0:
|
| 23 |
+
return MAX_SCORE
|
| 24 |
+
return float(score)
|
openenv.yaml
CHANGED
|
@@ -4,3 +4,16 @@ type: space
|
|
| 4 |
runtime: fastapi
|
| 5 |
app: server.app:app
|
| 6 |
port: 8000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
runtime: fastapi
|
| 5 |
app: server.app:app
|
| 6 |
port: 8000
|
| 7 |
+
tasks:
|
| 8 |
+
- id: easy
|
| 9 |
+
description: "Classify a customer issue into the correct UPI banking support track."
|
| 10 |
+
difficulty: easy
|
| 11 |
+
max_turns: 1
|
| 12 |
+
- id: medium
|
| 13 |
+
description: "Choose the correct FAQ or escalate when manual review is required."
|
| 14 |
+
difficulty: medium
|
| 15 |
+
max_turns: 3
|
| 16 |
+
- id: hard
|
| 17 |
+
description: "Handle a multi-turn banking support conversation with clarification, safe guidance, and closure."
|
| 18 |
+
difficulty: hard
|
| 19 |
+
max_turns: 8
|
server/app.py
CHANGED
|
@@ -8,6 +8,7 @@ from pydantic import BaseModel
|
|
| 8 |
import uvicorn
|
| 9 |
|
| 10 |
from ..environment import HelpdeskEnv
|
|
|
|
| 11 |
from ..models import Action, Reward
|
| 12 |
|
| 13 |
app = FastAPI(title="Helpdesk OpenEnv")
|
|
@@ -874,11 +875,11 @@ class ResetBody(BaseModel):
|
|
| 874 |
|
| 875 |
def _zero_reward() -> Dict[str, Any]:
|
| 876 |
return Reward(
|
| 877 |
-
value=0.0,
|
| 878 |
-
correctness=0.0,
|
| 879 |
-
safety=1.0,
|
| 880 |
-
resolution=0.0,
|
| 881 |
-
efficiency=0.0,
|
| 882 |
penalties=0.0,
|
| 883 |
done=False,
|
| 884 |
info={},
|
|
|
|
| 8 |
import uvicorn
|
| 9 |
|
| 10 |
from ..environment import HelpdeskEnv
|
| 11 |
+
from ..graders.score_utils import ensure_open_unit_interval
|
| 12 |
from ..models import Action, Reward
|
| 13 |
|
| 14 |
app = FastAPI(title="Helpdesk OpenEnv")
|
|
|
|
| 875 |
|
| 876 |
def _zero_reward() -> Dict[str, Any]:
|
| 877 |
return Reward(
|
| 878 |
+
value=ensure_open_unit_interval(0.0),
|
| 879 |
+
correctness=ensure_open_unit_interval(0.0),
|
| 880 |
+
safety=ensure_open_unit_interval(1.0),
|
| 881 |
+
resolution=ensure_open_unit_interval(0.0),
|
| 882 |
+
efficiency=ensure_open_unit_interval(0.0),
|
| 883 |
penalties=0.0,
|
| 884 |
done=False,
|
| 885 |
info={},
|