# afroimam's picture
# Upload folder using huggingface_hub
# 1395b2e verified
from __future__ import annotations
from dataclasses import dataclass
from .tasks import TaskSpec
@dataclass
class GradeBreakdown:
    """Per-component scores for one graded task, plus the combined total.

    Instances are built by ``grade_task``, which stores every value rounded
    to four decimal places.
    """

    # Credit for reading the target ticket and the required context tickets.
    read_score: float
    # Fraction of the classification fields that match the task's expectations.
    classify_score: float
    # Required-keyword coverage of the draft reply, less the forbidden-keyword
    # penalty, floored at 0.0.
    reply_score: float
    # 1.0 when the target ticket was resolved, otherwise 0.0.
    resolve_score: float
    # Weighted combination of the four component scores, clamped to [0.0, 1.0].
    total: float
def _keyword_coverage(message: str, required: tuple[str, ...]) -> float:
if not required:
return 1.0
lowered = message.lower()
found = sum(1 for k in required if k.lower() in lowered)
return found / len(required)
def _forbidden_penalty(message: str, forbidden: tuple[str, ...]) -> float:
lowered = message.lower()
count = sum(1 for k in forbidden if k.lower() in lowered)
return min(1.0, 0.5 * count)
def grade_task(task: TaskSpec, env_state: dict) -> GradeBreakdown:
    """Score the final environment state of an episode against ``task``.

    Four components are computed and combined with fixed weights:

    * read (0.2): 0.6 for having read the target ticket plus 0.4 times the
      fraction of required context tickets that were read (full context
      credit when the task requires none).
    * classify (0.35): fraction of the ``priority``, ``category`` and
      ``needs_escalation`` fields that equal the task's expected values.
    * reply (0.3): required-keyword coverage of the draft reply minus the
      forbidden-keyword penalty, floored at 0.0.
    * resolve (0.15): 1.0 only if the state is marked resolved and the
      resolved ticket is the task's target ticket.

    Args:
        task: Task definition. The attributes read are ``target_ticket_id``,
            ``required_context_ticket_ids``, ``expected_priority``,
            ``expected_category``, ``expected_escalation``,
            ``required_reply_keywords`` and ``forbidden_reply_keywords``.
        env_state: Mapping with the optional keys ``read_ticket_ids``,
            ``classification``, ``draft_reply``, ``resolved`` and
            ``resolved_ticket_id``. A missing or falsy value is treated as
            "nothing done" for that component.

    Returns:
        A ``GradeBreakdown`` whose scores are rounded to four decimal places
        and whose ``total`` is clamped to [0.0, 1.0].
    """
    # Read with .get() like every other state key, so a state that never
    # recorded any reads scores zero here instead of raising KeyError.
    read_ids = env_state.get("read_ticket_ids") or ()

    read_target = 1.0 if task.target_ticket_id in read_ids else 0.0
    context_ids = task.required_context_ticket_ids
    context_hits = sum(1 for tid in context_ids if tid in read_ids)
    context_total = len(context_ids)
    # A task with no required context gets full context credit.
    context_score = context_hits / context_total if context_total else 1.0
    read_score = 0.6 * read_target + 0.4 * context_score

    classification = env_state.get("classification") or {}
    expected_fields = (
        ("priority", task.expected_priority),
        ("category", task.expected_category),
        ("needs_escalation", task.expected_escalation),
    )
    fields_correct = sum(
        int(classification.get(name) == expected)
        for name, expected in expected_fields
    )
    classify_score = fields_correct / len(expected_fields)

    draft = env_state.get("draft_reply") or ""
    keyword_score = _keyword_coverage(draft, task.required_reply_keywords)
    forbidden_penalty = _forbidden_penalty(draft, task.forbidden_reply_keywords)
    # The penalty can exceed the coverage; never let the reply score go negative.
    reply_score = max(0.0, keyword_score - forbidden_penalty)

    resolved = bool(env_state.get("resolved"))
    resolved_target = env_state.get("resolved_ticket_id") == task.target_ticket_id
    resolve_score = 1.0 if resolved and resolved_target else 0.0

    total = (
        (0.2 * read_score)
        + (0.35 * classify_score)
        + (0.3 * reply_score)
        + (0.15 * resolve_score)
    )
    total = max(0.0, min(1.0, total))

    return GradeBreakdown(
        read_score=round(read_score, 4),
        classify_score=round(classify_score, 4),
        reply_score=round(reply_score, 4),
        resolve_score=round(resolve_score, 4),
        total=round(total, 4),
    )