Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass, field | |
| from typing import Any, Dict, List, Literal, Optional | |
| from pydantic import BaseModel, Field, model_validator | |
| class Observation(BaseModel): | |
| case_id: str | |
| track: str | |
| customer_message: str | |
| conversation_history: List[Dict[str, str]] | |
| known_facts: Dict[str, Any] | |
| required_slots: List[str] | |
| available_actions: List[str] | |
| turn_number: int | |
| def ticket_id(self) -> str: | |
| return self.case_id | |
| def task_id(self) -> str: | |
| return str(self.known_facts.get("difficulty", "")) | |
| def ticket_text(self) -> str: | |
| return self.customer_message | |
| def knowledge_base(self) -> List[Dict[str, Any]]: | |
| kb = self.known_facts.get("knowledge_base", []) | |
| return kb if isinstance(kb, list) else [] | |
| def available_categories(self) -> List[str]: | |
| categories = self.known_facts.get("available_categories", []) | |
| return categories if isinstance(categories, list) else [] | |
| class Action(BaseModel): | |
| action_type: Literal[ | |
| "ask_for_details", | |
| "take_action", | |
| "respond_to_user", | |
| "escalate_case", | |
| "close_case", | |
| ] | |
| message: Optional[str] = None | |
| fields_requested: List[str] = Field(default_factory=list) | |
| operation: Optional[str] = None | |
| target: Optional[str] = None | |
| # Legacy compatibility with the original helpdesk action schema. | |
| category: Optional[str] = None | |
| faq_id: Optional[str] = None | |
| def _validate_canonical_shape(self) -> "Action": | |
| if self.action_type == "take_action" and not self.operation: | |
| raise ValueError("take_action requires operation") | |
| return self | |
| LegacyActionType = Literal[ | |
| "classify", | |
| "lookup_faq", | |
| "ask_clarification", | |
| "reply", | |
| "escalate", | |
| "resolve_ticket", | |
| ] | |
| def normalize_action(raw: Dict[str, Any]) -> Action: | |
| action_type = str(raw.get("action_type", "")).strip() | |
| if action_type == "classify": | |
| return Action( | |
| action_type="take_action", | |
| operation="classify", | |
| category=raw.get("category"), | |
| message=raw.get("message"), | |
| faq_id=raw.get("faq_id"), | |
| ) | |
| if action_type == "lookup_faq": | |
| return Action( | |
| action_type="take_action", | |
| operation="lookup_faq", | |
| faq_id=raw.get("faq_id"), | |
| message=raw.get("message"), | |
| category=raw.get("category"), | |
| ) | |
| if action_type == "ask_clarification": | |
| return Action( | |
| action_type="ask_for_details", | |
| fields_requested=list(raw.get("fields_requested") or ["issue_details"]), | |
| message=raw.get("message"), | |
| ) | |
| if action_type == "reply": | |
| return Action( | |
| action_type="respond_to_user", | |
| message=raw.get("message"), | |
| ) | |
| if action_type == "escalate": | |
| return Action( | |
| action_type="escalate_case", | |
| target=raw.get("target") or "human_agent", | |
| message=raw.get("message"), | |
| ) | |
| if action_type == "resolve_ticket": | |
| return Action( | |
| action_type="close_case", | |
| operation=raw.get("operation") or "resolve_with_guidance", | |
| message=raw.get("message"), | |
| ) | |
| return Action(**raw) | |
| class Reward(BaseModel): | |
| value: float = Field(ge=0.0, le=1.0) | |
| correctness: float | |
| safety: float | |
| resolution: float | |
| efficiency: float | |
| penalties: float | |
| done: bool | |
| info: Dict[str, Any] | |
| def escalation_accuracy(self) -> float: | |
| return float(self.info.get("escalation_accuracy", self.correctness)) | |
| class TicketState: | |
| ticket_id: str | |
| track: str | |
| required_slots: List[str] = field(default_factory=list) | |
| collected_slots: Dict[str, Any] = field(default_factory=dict) | |
| issue_resolved: bool = False | |
| clarification_received: bool = False | |
| escalated: bool = False | |
| turns_used: int = 0 | |
| correct_faq_retrieved: bool = False | |