helpdesk_env / models.py
Freakdivi's picture
Upload folder using huggingface_hub
026df2c verified
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field, model_validator
class Observation(BaseModel):
    # What the agent sees for a case on a given turn.
    # NOTE(review): field semantics are inferred from their names; confirm
    # against the environment that builds these observations.
    case_id: str
    track: str
    customer_message: str
    conversation_history: List[Dict[str, str]]
    known_facts: Dict[str, Any]
    required_slots: List[str]
    available_actions: List[str]
    turn_number: int

    # ---- Read-only convenience accessors (ticket-oriented names). ----

    @property
    def ticket_id(self) -> str:
        """Alias for ``case_id``."""
        return self.case_id

    @property
    def task_id(self) -> str:
        """Return ``known_facts["difficulty"]`` stringified, or ``""`` if absent."""
        difficulty = self.known_facts.get("difficulty", "")
        return str(difficulty)

    @property
    def ticket_text(self) -> str:
        """Alias for ``customer_message``."""
        return self.customer_message

    @property
    def knowledge_base(self) -> List[Dict[str, Any]]:
        """Return ``known_facts["knowledge_base"]`` when it is a list, else ``[]``."""
        entries = self.known_facts.get("knowledge_base", [])
        if not isinstance(entries, list):
            return []
        return entries

    @property
    def available_categories(self) -> List[str]:
        """Return ``known_facts["available_categories"]`` when it is a list, else ``[]``."""
        names = self.known_facts.get("available_categories", [])
        if not isinstance(names, list):
            return []
        return names
class Action(BaseModel):
    # Canonical agent action. `action_type` picks the kind of move; every
    # other field is optional payload. A `take_action` must also carry a
    # non-empty `operation` (enforced by `_validate_canonical_shape`).
    action_type: Literal[
        "ask_for_details",
        "take_action",
        "respond_to_user",
        "escalate_case",
        "close_case",
    ]
    message: Optional[str] = None
    # Field names requested from the user (filled for ask_for_details by
    # normalize_action).
    fields_requested: List[str] = Field(default_factory=list)
    operation: Optional[str] = None
    target: Optional[str] = None
    # Legacy compatibility with the original helpdesk action schema.
    category: Optional[str] = None
    faq_id: Optional[str] = None

    @model_validator(mode="after")
    def _validate_canonical_shape(self) -> "Action":
        """Reject a ``take_action`` whose ``operation`` is missing or empty.

        Raises:
            ValueError: surfaced by pydantic as a validation error.
        """
        is_take_action = self.action_type == "take_action"
        if is_take_action and not self.operation:
            raise ValueError("take_action requires operation")
        return self
# Action names from the pre-canonical helpdesk schema. normalize_action maps
# each of these onto a canonical Action.action_type.
LegacyActionType = Literal[
    "classify", "lookup_faq", "ask_clarification",
    "reply", "escalate", "resolve_ticket",
]
def _as_field_list(value: Any) -> List[str]:
    """Coerce a legacy ``fields_requested`` value into a list of field names.

    A missing/empty value falls back to ``["issue_details"]``. A bare string
    is treated as ONE field name; ``list("email")`` would otherwise split it
    into characters.
    """
    if not value:
        return ["issue_details"]
    if isinstance(value, str):
        return [value]
    return list(value)


def normalize_action(raw: Dict[str, Any]) -> Action:
    """Convert a raw action dict (legacy or canonical) into an :class:`Action`.

    Legacy names (``classify``, ``lookup_faq``, ``ask_clarification``,
    ``reply``, ``escalate``, ``resolve_ticket``) are mapped onto canonical
    action types. Any other ``action_type`` is passed through to ``Action``,
    with surrounding whitespace stripped so canonical and legacy names are
    matched the same way.

    Args:
        raw: Action payload, e.g. parsed from agent JSON output.

    Returns:
        A validated ``Action``.

    Raises:
        pydantic.ValidationError: if the resulting payload is not a valid
            ``Action`` (unknown ``action_type``, ``take_action`` without
            ``operation``, ...).
    """
    action_type = str(raw.get("action_type", "")).strip()
    message = raw.get("message")

    if action_type == "classify":
        return Action(
            action_type="take_action",
            operation="classify",
            category=raw.get("category"),
            message=message,
            faq_id=raw.get("faq_id"),
        )
    if action_type == "lookup_faq":
        return Action(
            action_type="take_action",
            operation="lookup_faq",
            faq_id=raw.get("faq_id"),
            message=message,
            category=raw.get("category"),
        )
    if action_type == "ask_clarification":
        return Action(
            action_type="ask_for_details",
            fields_requested=_as_field_list(raw.get("fields_requested")),
            message=message,
        )
    if action_type == "reply":
        return Action(
            action_type="respond_to_user",
            message=message,
        )
    if action_type == "escalate":
        return Action(
            action_type="escalate_case",
            target=raw.get("target") or "human_agent",
            message=message,
        )
    if action_type == "resolve_ticket":
        return Action(
            action_type="close_case",
            operation=raw.get("operation") or "resolve_with_guidance",
            message=message,
        )

    # Canonical path. Forward the stripped action_type so " respond_to_user "
    # is accepted just like " reply " is above. Copy rather than mutate the
    # caller's dict, and leave a missing key missing so pydantic reports it.
    payload = dict(raw)
    if "action_type" in payload:
        payload["action_type"] = action_type
    return Action(**payload)
class Reward(BaseModel):
    # Reward record. Only `value` is range-checked (to [0, 1] by the Field
    # constraint); the other floats are presumably component scores and are
    # not bounded here. `info` holds extra diagnostics.
    # NOTE(review): component semantics inferred from names — confirm.
    value: float = Field(ge=0.0, le=1.0)
    correctness: float
    safety: float
    resolution: float
    efficiency: float
    penalties: float
    done: bool
    info: Dict[str, Any]

    @property
    def escalation_accuracy(self) -> float:
        """Return ``info["escalation_accuracy"]`` as a float.

        Falls back to ``correctness`` when the key is absent.
        """
        fallback = self.correctness
        return float(self.info.get("escalation_accuracy", fallback))
@dataclass
class TicketState:
    """Mutable per-ticket bookkeeping.

    NOTE(review): the meaning of each flag/counter is inferred from its name;
    confirm against the code that updates this state.
    """

    # Identity of the ticket/case being tracked.
    ticket_id: str
    track: str
    # Mutable containers use default_factory so instances never share them.
    required_slots: List[str] = field(default_factory=list)
    collected_slots: Dict[str, Any] = field(default_factory=dict)
    # Progress flags; every flag starts False and the turn counter at 0.
    issue_resolved: bool = False
    clarification_received: bool = False
    escalated: bool = False
    turns_used: int = 0
    correct_faq_retrieved: bool = False