AIHack-ITHelpDesk / tests /test_tasks_unit.py
Roopalgn's picture
Add cluster-aware queue dynamics to helpdesk env
454cef3
from __future__ import annotations
import io
import unittest
from unittest import mock
import openenv_test_stubs # noqa: F401
from models import HelpdeskTicketRecord
from server import tasks as task_module
from server.tasks import CURATED_EXPANSION_RECORDS, TASKS, get_task_definition, load_dataset
from vocabulary import (
ASSIGNMENT_GROUPS,
ISSUE_TYPES,
PRIORITIES,
RESOLUTION_ACTIONS,
TASK_IDS,
)
class TasksAndDatasetUnitTests(unittest.TestCase):
    """Unit tests for the task definitions and the helpdesk ticket dataset."""

    def test_task_ids_match_frozen_contract(self) -> None:
        """TASKS must expose exactly the ids in TASK_IDS, in the same order."""
        self.assertEqual(tuple(TASKS.keys()), TASK_IDS)

    def test_task_allowed_fields_match_expected_ladder(self) -> None:
        """Tasks 1, 2 and 3 must all allow the same four fields, in order."""
        expected_fields = [
            "issue_type",
            "priority",
            "assignment_group",
            "resolution_action",
        ]
        for task_id in (1, 2, 3):
            # subTest keeps going after a failure and names the failing task.
            with self.subTest(task_id=task_id):
                self.assertEqual(
                    get_task_definition(task_id)["allowed_fields"],
                    expected_fields,
                )

    def test_task_difficulty_ladder_is_frozen(self) -> None:
        """Difficulties, read in TASK_IDS order, must be easy/medium/hard."""
        self.assertEqual(
            [TASKS[task_id]["difficulty"] for task_id in TASK_IDS],
            ["easy", "medium", "hard"],
        )

    def test_invalid_task_id_raises(self) -> None:
        """Task id 0 must be rejected with an 'Unsupported task_id' ValueError."""
        with self.assertRaisesRegex(ValueError, "Unsupported task_id"):
            get_task_definition(0)

    def test_load_dataset_returns_valid_records(self) -> None:
        """The dataset must hold at least 45 HelpdeskTicketRecord-like items."""
        dataset = load_dataset()
        self.assertGreaterEqual(len(dataset), 45)
        for index, record in enumerate(dataset):
            # Accept a real instance, or an object of the same class name that
            # exposes `model_dump` and `ticket_id`.
            # NOTE(review): the name-based fallback presumably covers the model
            # being imported under a second module path — confirm.
            is_ticket_record = isinstance(record, HelpdeskTicketRecord) or (
                record.__class__.__name__ == "HelpdeskTicketRecord"
                and hasattr(record, "model_dump")
                and hasattr(record, "ticket_id")
            )
            self.assertTrue(
                is_ticket_record,
                f"dataset[{index}] is not a HelpdeskTicketRecord: {record!r}",
            )

    def test_dataset_ticket_ids_are_unique(self) -> None:
        """No two records may share a ticket_id."""
        dataset = load_dataset()
        ticket_ids = [record.ticket_id for record in dataset]
        self.assertEqual(len(ticket_ids), len(set(ticket_ids)))

    def test_related_ticket_ids_reference_existing_records(self) -> None:
        """Every non-None related_ticket_id must be a ticket_id in the dataset."""
        dataset = load_dataset()
        ticket_ids = {record.ticket_id for record in dataset}
        missing_links = [
            record.related_ticket_id
            for record in dataset
            if record.related_ticket_id is not None
            and record.related_ticket_id not in ticket_ids
        ]
        self.assertEqual(missing_links, [])

    def test_dataset_covers_all_defined_issue_types(self) -> None:
        """The set of issue types in the dataset must equal ISSUE_TYPES."""
        dataset = load_dataset()
        issue_types = {record.issue_type for record in dataset}
        self.assertEqual(issue_types, set(ISSUE_TYPES))

    def test_dataset_covers_all_defined_priorities(self) -> None:
        """The set of priorities in the dataset must equal PRIORITIES."""
        dataset = load_dataset()
        priorities = {record.priority for record in dataset}
        self.assertEqual(priorities, set(PRIORITIES))

    def test_dataset_covers_all_assignment_groups(self) -> None:
        """The set of assignment groups must equal ASSIGNMENT_GROUPS."""
        dataset = load_dataset()
        assignment_groups = {record.assignment_group for record in dataset}
        self.assertEqual(assignment_groups, set(ASSIGNMENT_GROUPS))

    def test_dataset_covers_all_resolution_actions(self) -> None:
        """The set of resolution actions must equal RESOLUTION_ACTIONS."""
        dataset = load_dataset()
        resolution_actions = {record.resolution_action for record in dataset}
        self.assertEqual(resolution_actions, set(RESOLUTION_ACTIONS))

    def test_dataset_preserves_ambiguous_and_follow_up_cases(self) -> None:
        """Check minimum counts of ambiguous, linked, rerouted and clustered rows."""
        dataset = load_dataset()
        ambiguity_count = sum(1 for record in dataset if record.ambiguity_note)
        follow_up_count = sum(1 for record in dataset if record.related_ticket_id)
        alternate_route_count = sum(
            1 for record in dataset if record.alternate_route_score_multiplier > 0.0
        )
        clustered_case_count = sum(
            1 for record in dataset if record.service_cluster_id
        )
        self.assertGreaterEqual(ambiguity_count, 4)
        self.assertGreaterEqual(follow_up_count, 3)
        self.assertGreaterEqual(alternate_route_count, 10)
        self.assertGreaterEqual(clustered_case_count, 10)

    def test_load_dataset_accepts_utf8_bom(self) -> None:
        """A dataset file that starts with a UTF-8 BOM must still load.

        `Path.open` is patched on the `Path` name exposed by `server.tasks`,
        so the loader reads a one-record JSON array prefixed with the BOM
        bytes. The result must contain that record, and its length must be
        one plus the number of CURATED_EXPANSION_RECORDS.
        """
        sample = (
            b"\xef\xbb\xbf"
            b"["
            b"{"
            b'"ticket_id":"ticket-bom",'
            b'"title":"BOM test",'
            b'"requester":"user@example.com",'
            b'"description":"Dataset loader should tolerate UTF-8 BOM.",'
            b'"issue_type":"general_inquiry",'
            b'"priority":"low",'
            b'"assignment_group":"service_desk",'
            b'"resolution_action":"acknowledge",'
            b'"ambiguity_note":null,'
            b'"related_ticket_id":null'
            b"}"
            b"]"
        )

        def fake_open(  # type: ignore[no-untyped-def]
            self,
            mode="r",
            buffering=-1,
            encoding=None,
            errors=None,
            newline=None,
        ):
            # Mirror pathlib.Path.open's parameters so any keyword the loader
            # passes is accepted. The bytes are decoded with the encoding the
            # loader asked for, so BOM handling stays the loader's job.
            return io.TextIOWrapper(
                io.BytesIO(sample),
                encoding=encoding,
                errors=errors,
                newline=newline,
            )

        with mock.patch.object(task_module.Path, "open", fake_open):
            dataset = load_dataset()

        self.assertIn("ticket-bom", [record.ticket_id for record in dataset])
        self.assertEqual(len(dataset), 1 + len(CURATED_EXPANSION_RECORDS))


# Run the suite only when this file is executed directly, not when imported.
if __name__ == "__main__":
    unittest.main()