Spaces:
Running
Running
File size: 5,406 Bytes
ae36543 8241eb5 6920aae ae36543 043d9e1 ae36543 043d9e1 ae36543 6920aae ae36543 6c5051f 8241eb5 ae36543 6920aae 8241eb5 454cef3 6920aae 8241eb5 454cef3 ae36543 8241eb5 ae36543 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 | from __future__ import annotations
import io
import unittest
from unittest import mock
import openenv_test_stubs # noqa: F401
from models import HelpdeskTicketRecord
from server import tasks as task_module
from server.tasks import CURATED_EXPANSION_RECORDS, TASKS, get_task_definition, load_dataset
from vocabulary import (
ASSIGNMENT_GROUPS,
ISSUE_TYPES,
PRIORITIES,
RESOLUTION_ACTIONS,
TASK_IDS,
)
class TasksAndDatasetUnitTests(unittest.TestCase):
    """Unit tests for the task registry (``TASKS``) and ``load_dataset``."""

    def test_task_ids_match_frozen_contract(self) -> None:
        """The registry keys, in iteration order, must equal ``TASK_IDS``."""
        self.assertEqual(tuple(TASKS.keys()), TASK_IDS)

    def test_task_allowed_fields_match_expected_ladder(self) -> None:
        """Tasks 1, 2 and 3 must all expose the same ordered ``allowed_fields``."""
        expected_fields = [
            "issue_type",
            "priority",
            "assignment_group",
            "resolution_action",
        ]
        for task_id in (1, 2, 3):
            # subTest reports every mismatching task instead of stopping at the first.
            with self.subTest(task_id=task_id):
                self.assertEqual(
                    get_task_definition(task_id)["allowed_fields"],
                    expected_fields,
                )

    def test_task_difficulty_ladder_is_frozen(self) -> None:
        """Difficulty must read easy -> medium -> hard when walked in ``TASK_IDS`` order."""
        self.assertEqual(
            [TASKS[task_id]["difficulty"] for task_id in TASK_IDS],
            ["easy", "medium", "hard"],
        )

    def test_invalid_task_id_raises(self) -> None:
        """An unknown task id (0) must raise ``ValueError`` mentioning 'Unsupported task_id'."""
        with self.assertRaisesRegex(ValueError, "Unsupported task_id"):
            get_task_definition(0)

    def test_load_dataset_returns_valid_records(self) -> None:
        """The dataset has at least 45 entries and each is a ticket record."""
        dataset = load_dataset()
        self.assertGreaterEqual(len(dataset), 45)
        self.assertTrue(
            all(
                isinstance(record, HelpdeskTicketRecord)
                # Duck-typed fallback: also accept an object whose class is named
                # HelpdeskTicketRecord and has model_dump/ticket_id. Presumably this is
                # needed when openenv_test_stubs yields a different class object
                # under the same name -- TODO confirm.
                or (
                    record.__class__.__name__ == "HelpdeskTicketRecord"
                    and hasattr(record, "model_dump")
                    and hasattr(record, "ticket_id")
                )
                for record in dataset
            )
        )

    def test_dataset_ticket_ids_are_unique(self) -> None:
        """No two records may share a ``ticket_id``."""
        dataset = load_dataset()
        ticket_ids = [record.ticket_id for record in dataset]
        self.assertEqual(len(ticket_ids), len(set(ticket_ids)))

    def test_related_ticket_ids_reference_existing_records(self) -> None:
        """Every non-None ``related_ticket_id`` must name a ticket in the dataset."""
        dataset = load_dataset()
        ticket_ids = {record.ticket_id for record in dataset}
        missing_links = [
            record.related_ticket_id
            for record in dataset
            if record.related_ticket_id is not None
            and record.related_ticket_id not in ticket_ids
        ]
        self.assertEqual(missing_links, [])

    def test_dataset_covers_all_defined_issue_types(self) -> None:
        """The set of issue types in the data must equal ``ISSUE_TYPES`` exactly."""
        dataset = load_dataset()
        issue_types = {record.issue_type for record in dataset}
        self.assertEqual(issue_types, set(ISSUE_TYPES))

    def test_dataset_covers_all_defined_priorities(self) -> None:
        """The set of priorities in the data must equal ``PRIORITIES`` exactly."""
        dataset = load_dataset()
        priorities = {record.priority for record in dataset}
        self.assertEqual(priorities, set(PRIORITIES))

    def test_dataset_covers_all_assignment_groups(self) -> None:
        """The set of assignment groups in the data must equal ``ASSIGNMENT_GROUPS``."""
        dataset = load_dataset()
        assignment_groups = {record.assignment_group for record in dataset}
        self.assertEqual(assignment_groups, set(ASSIGNMENT_GROUPS))

    def test_dataset_covers_all_resolution_actions(self) -> None:
        """The set of resolution actions in the data must equal ``RESOLUTION_ACTIONS``."""
        dataset = load_dataset()
        resolution_actions = {record.resolution_action for record in dataset}
        self.assertEqual(resolution_actions, set(RESOLUTION_ACTIONS))

    def test_dataset_preserves_ambiguous_and_follow_up_cases(self) -> None:
        """Minimum counts of ambiguous, follow-up, alternate-route and clustered records."""
        dataset = load_dataset()
        # Truthiness checks: empty strings / None do not count as present.
        ambiguity_count = sum(1 for record in dataset if record.ambiguity_note)
        follow_up_count = sum(1 for record in dataset if record.related_ticket_id)
        alternate_route_count = sum(
            1 for record in dataset if record.alternate_route_score_multiplier > 0.0
        )
        clustered_case_count = sum(1 for record in dataset if record.service_cluster_id)
        self.assertGreaterEqual(ambiguity_count, 4)
        self.assertGreaterEqual(follow_up_count, 3)
        self.assertGreaterEqual(alternate_route_count, 10)
        self.assertGreaterEqual(clustered_case_count, 10)

    def test_load_dataset_accepts_utf8_bom(self) -> None:
        """A dataset file starting with a UTF-8 BOM must still be parsed.

        ``Path.open`` is patched so ``load_dataset`` reads an in-memory JSON
        array (one record, prefixed with the BOM) instead of the real file.
        The fake decodes with whatever ``encoding`` the loader requests, so the
        test only passes if the loader itself handles the BOM.
        """
        sample = (
            b"\xef\xbb\xbf"
            b"["
            b"{"
            b'"ticket_id":"ticket-bom",'
            b'"title":"BOM test",'
            b'"requester":"user@example.com",'
            b'"description":"Dataset loader should tolerate UTF-8 BOM.",'
            b'"issue_type":"general_inquiry",'
            b'"priority":"low",'
            b'"assignment_group":"service_desk",'
            b'"resolution_action":"acknowledge",'
            b'"ambiguity_note":null,'
            b'"related_ticket_id":null'
            b"}"
            b"]"
        )

        def fake_open(
            path: object,
            mode: str = "r",
            buffering: int = -1,
            encoding: str | None = None,
            errors: str | None = None,
            newline: str | None = None,
        ) -> io.TextIOWrapper:
            # Mirrors pathlib.Path.open's signature so positional or keyword calls
            # (including Path.read_text, which forwards errors/newline) still work.
            # `path`, `mode` and `buffering` are intentionally ignored.
            return io.TextIOWrapper(
                io.BytesIO(sample),
                encoding=encoding,
                errors=errors,
                newline=newline,
            )

        with mock.patch.object(task_module.Path, "open", fake_open):
            dataset = load_dataset()
        self.assertIn("ticket-bom", [record.ticket_id for record in dataset])
        # The expected size implies load_dataset appends CURATED_EXPANSION_RECORDS
        # to whatever the file provides.
        self.assertEqual(len(dataset), 1 + len(CURATED_EXPANSION_RECORDS))
if __name__ == "__main__":
    # Direct execution entry point: discover and run this module's TestCases.
    # module="__main__" is unittest.main's default, spelled out for clarity.
    unittest.main(module="__main__")
|