File size: 3,933 Bytes
bd67155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from __future__ import annotations

from typing import Dict, List

from ..models import StateModel, TaskGrade, TaskSpec, TicketSpec


def _ticket_component(
    ticket: TicketSpec,
    state: StateModel,
    weights: Dict[str, float],
) -> Dict[str, float]:
    """Return the weighted score of every grading component for one ticket.

    Components, each scored in [0, 1] before weighting:

    * ``context`` -- fraction of ``ticket.required_context`` present in the
      keys recorded for this ticket in ``state.discovered_keys``; 1.0 when
      the ticket requires no context.
    * ``priority``, ``route``, ``resolution`` -- 1.0 when the value recorded
      in ``state`` equals the ticket's gold value, otherwise 0.0.
    * ``escalation`` -- 1.0 when the recorded escalation equals
      ``ticket.gold_escalation_team`` (this includes the case where both
      are ``None``), otherwise 0.0.

    Args:
        ticket: Ticket carrying the gold answers.
        state: State holding the observed answers, keyed by ticket id.
        weights: Weight per component name; a missing name counts as 0.0.

    Returns:
        Mapping of component name to ``raw_score * weight``.
    """
    tid = ticket.ticket_id

    found = set(state.discovered_keys.get(tid, []))
    needed = set(ticket.required_context)
    if needed:
        coverage = len(found & needed) / len(needed)
    else:
        coverage = 1.0

    # Plain equality also covers "no escalation expected and none made",
    # because None == None.
    escalated_to = state.escalations.get(tid)
    expected_team = ticket.gold_escalation_team
    escalation_hit = 1.0 if escalated_to == expected_team else 0.0

    priority_hit = 1.0 if state.priorities.get(tid) == ticket.gold_priority else 0.0
    route_hit = 1.0 if state.routes.get(tid) == ticket.gold_route else 0.0
    resolution_hit = 1.0 if state.resolutions.get(tid) == ticket.gold_resolution else 0.0

    unweighted = (
        ("context", coverage),
        ("priority", priority_hit),
        ("route", route_hit),
        ("resolution", resolution_hit),
        ("escalation", escalation_hit),
    )
    return {label: value * weights.get(label, 0.0) for label, value in unweighted}


def grade_single_ticket(
    task: TaskSpec,
    state: StateModel,
    weights: Dict[str, float],
) -> TaskGrade:
    """Grade a task using only its first ticket.

    The score is the sum of the ticket's weighted component scores, rounded
    to four decimal places; the task passes when that score is at least 0.8.
    Only ``task.tickets[0]`` is inspected, so any further tickets are
    ignored and an empty ticket sequence fails at the index lookup.

    Args:
        task: Task whose first ticket is graded.
        state: State holding the observed answers.
        weights: Weight per component name.

    Returns:
        A ``TaskGrade`` with the total score, pass flag, the (unrounded)
        weighted component scores and the feedback notes for the ticket.
    """
    only_ticket = task.tickets[0]
    components = _ticket_component(only_ticket, state, weights)
    total = round(sum(components.values()), 4)
    feedback = _notes_for_ticket(only_ticket, state)
    passed = total >= 0.8
    return TaskGrade(
        task_id=task.task_id,
        score=total,
        passed=passed,
        component_scores=components,
        notes=feedback,
    )


def grade_queue_task(
    task: TaskSpec,
    state: StateModel,
    weights: Dict[str, float],
) -> TaskGrade:
    """Grade a multi-ticket task: averaged ticket components plus queue ranking.

    Each of the five per-ticket components (context, priority, route,
    resolution, escalation) is summed over all tickets and divided by the
    ticket count, then rounded to four decimal places. A ``ranking``
    component is added on top: the fraction of positions in
    ``task.gold_queue_order`` at which ``state.queue_order`` holds the same
    entry, multiplied by ``weights["ranking"]`` (0.0 when that weight is
    absent or the task defines no gold order).

    Args:
        task: Task providing the tickets and the optional gold queue order.
        state: State holding the observed answers and queue order.
        weights: Weight per component name, including ``"ranking"``.

    Returns:
        A ``TaskGrade`` whose score is the rounded sum of the component
        scores; the task passes when the score is at least 0.8. Notes from
        every ticket are concatenated in ticket order.
    """
    component_sums = {
        "context": 0.0,
        "priority": 0.0,
        "route": 0.0,
        "resolution": 0.0,
        "escalation": 0.0,
    }
    notes: List[str] = []
    for ticket in task.tickets:
        weighted = _ticket_component(ticket, state, weights)
        for name, value in weighted.items():
            component_sums[name] += value
        notes.extend(_notes_for_ticket(ticket, state))

    # Guard against division by zero when the task has no tickets.
    divisor = max(len(task.tickets), 1)
    averaged = {name: round(value / divisor, 4) for name, value in component_sums.items()}

    ranking_score = 0.0
    if task.gold_queue_order:
        # zip() stops at the shorter sequence, so positions missing from the
        # observed order simply count as mismatches against the gold length.
        matches = sum(
            1 for observed, expected in zip(state.queue_order, task.gold_queue_order) if observed == expected
        )
        ranking_score = round((matches / len(task.gold_queue_order)) * weights.get("ranking", 0.0), 4)

    averaged["ranking"] = ranking_score
    score = round(sum(averaged.values()), 4)
    return TaskGrade(
        task_id=task.task_id,
        score=score,
        passed=score >= 0.8,
        component_scores=averaged,
        notes=notes,
    )


def _notes_for_ticket(ticket: TicketSpec, state: StateModel) -> List[str]:
    """List human-readable reasons why a ticket's observed answers are wrong.

    One note is produced for each of priority, route, resolution and
    escalation whose value recorded in ``state`` differs from the ticket's
    gold value, followed by one note naming (sorted) any required context
    keys that were not discovered. A missing escalation matches a gold
    escalation of ``None``.

    Args:
        ticket: Ticket carrying the gold answers.
        state: State holding the observed answers, keyed by ticket id.

    Returns:
        The notes, each prefixed with the ticket id; empty when everything
        matches.
    """
    tid = ticket.ticket_id
    comparisons = (
        ("incorrect priority", state.priorities.get(tid), ticket.gold_priority),
        ("incorrect route", state.routes.get(tid), ticket.gold_route),
        ("incorrect resolution", state.resolutions.get(tid), ticket.gold_resolution),
        # None != None is False, so "nothing expected, nothing done" passes.
        ("incorrect escalation", state.escalations.get(tid), ticket.gold_escalation_team),
    )
    problems: List[str] = [
        f"{tid}: {complaint}"
        for complaint, observed, expected in comparisons
        if observed != expected
    ]

    needed = set(ticket.required_context)
    found = set(state.discovered_keys.get(tid, []))
    absent = needed - found
    if absent:
        problems.append(f"{tid}: missing required context {sorted(absent)}")
    return problems