"""Unit tests for the ``inference`` baseline runner.

Every test imports a fresh copy of ``inference`` through
``_load_inference_module`` so that module-level settings (``API_KEY``,
``MODEL_NAME``, ``SEED``, ...) reflect exactly the environment supplied by
that test rather than the developer's shell.
"""

from __future__ import annotations

import contextlib
import importlib
import io
import os
import sys
import types
import unittest
from types import SimpleNamespace
from unittest import mock

# Imported only for its side effects (hence the noqa); presumably it installs
# stub modules that ``models`` depends on, so keep it above that import.
import openenv_test_stubs  # noqa: F401
from models import HelpdeskTicketObservation


def _load_inference_module(env: dict[str, str] | None = None) -> types.ModuleType:
    """Import a fresh ``inference`` module under a controlled environment.

    ``os.environ`` is replaced wholesale by *env* (``clear=True``) for the
    duration of the import, and the ``client`` module is replaced by a stub
    whose ``HelpdeskTicketEnvClient.sync()`` raises ``NotImplementedError``.
    Tests that actually drive an episode must therefore patch
    ``HelpdeskTicketEnvClient`` on the returned module themselves.

    Both ``mock.patch.dict`` patches are undone on return, so the freshly
    imported module is not left registered in ``sys.modules`` (any previously
    cached ``inference`` entry is restored).

    Args:
        env: Environment variables visible during import; ``None`` or an
            empty dict means "no environment variables at all".

    Returns:
        The newly executed ``inference`` module object.
    """
    env = env or {}
    client_stub = types.ModuleType("client")

    class PlaceholderEnvClient:
        """Inert stand-in for the real environment client."""

        def __init__(self, *args, **kwargs) -> None:
            pass

        def sync(self):
            raise NotImplementedError

    client_stub.HelpdeskTicketEnvClient = PlaceholderEnvClient
    with mock.patch.dict(os.environ, env, clear=True):
        with mock.patch.dict(sys.modules, {"client": client_stub}):
            # Drop any cached copy so import_module re-executes the module
            # body under the patched environment.
            sys.modules.pop("inference", None)
            return importlib.import_module("inference")


def _pricing_inquiry_ticket() -> dict[str, str]:
    """Return a fresh enterprise-pricing question ticket shared by several tests.

    A new dict is built on every call so that mutation by one test (or by the
    code under test) cannot leak into another.
    """
    return {
        "ticket_id": "ticket-018",
        "title": "Question about enterprise tier pricing",
        "requester": "finance@urbanstack.io",
        "description": (
            "We're comparing your enterprise plan against two competitors. "
            "Can you send over a detailed pricing breakdown?"
        ),
    }


class FakeResponse:
    """Minimal HTTP response double returning a fixed JSON payload."""

    def __init__(self, payload):
        self._payload = payload

    def raise_for_status(self) -> None:
        return None

    def json(self):
        return self._payload


class FakeHttpClient:
    """Stand-in for ``httpx.Client`` serving only ``/health`` and ``/tasks``.

    Any other path fails the test with an ``AssertionError`` so unexpected
    HTTP calls from ``inference.run()`` are caught.
    """

    def __init__(self, *args, **kwargs) -> None:
        pass

    def get(self, path: str) -> FakeResponse:
        if path == "/health":
            return FakeResponse({"status": "ok"})
        if path == "/tasks":
            return FakeResponse(
                {
                    "tasks": [
                        {
                            "id": 1,
                            "name": "Issue Type Classification",
                            "difficulty": "easy",
                            "instructions": "Classify the issue type.",
                            "allowed_fields": ["issue_type"],
                        }
                    ]
                }
            )
        raise AssertionError(f"Unexpected path: {path}")

    def close(self) -> None:
        return None


class FakeSyncClient:
    """Context-managed env session modelling a one-ticket episode.

    ``reset`` yields a single pending ticket; any ``step`` finishes the
    episode with reward 1.0 regardless of the submitted action.
    """

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        return None

    def reset(self, seed: int, task_id: int):
        observation = HelpdeskTicketObservation(
            task_id=task_id,
            task_name="Issue Type Classification",
            instructions="Classify the issue type.",
            allowed_fields=["issue_type"],
            current_ticket={
                "ticket_id": "ticket-001",
                "title": "Invoice issue",
                "requester": "user@example.com",
                "description": "Customer was charged twice.",
            },
            queue_size=1,
            tickets_remaining=1,
            tickets_processed=0,
            history=[],
            done=False,
            reward=None,
            metadata={},
        )
        return SimpleNamespace(observation=observation, done=False, reward=None)

    def step(self, action):
        observation = HelpdeskTicketObservation(
            task_id=1,
            task_name="Issue Type Classification",
            instructions="Classify the issue type.",
            allowed_fields=["issue_type"],
            current_ticket=None,
            queue_size=1,
            tickets_remaining=0,
            tickets_processed=1,
            history=[],
            done=True,
            reward=1.0,
            metadata={},
        )
        return SimpleNamespace(observation=observation, done=True, reward=1.0)


class FakeEnvClient:
    """Replacement for ``HelpdeskTicketEnvClient`` whose ``sync()`` yields a ``FakeSyncClient``."""

    def __init__(self, *args, **kwargs) -> None:
        pass

    def sync(self) -> FakeSyncClient:
        return FakeSyncClient()


class InferenceUnitTests(unittest.TestCase):
    """Behavioural tests for configuration, logging, prompting and action building."""

    # ------------------------------------------------------------------
    # Configuration read from the environment at import time
    # ------------------------------------------------------------------

    def test_api_credentials_have_no_defaults_and_model_name_keeps_allowed_default(self) -> None:
        inference = _load_inference_module()
        self.assertEqual(
            inference.API_BASE_URL,
            "https://router.huggingface.co/v1",
        )
        self.assertEqual(inference.MODEL_NAME, "gpt-4o-mini")
        self.assertIsNone(inference.API_KEY)
        self.assertIsNone(inference.HF_TOKEN)
        self.assertFalse(inference.llm_mode_enabled())

    def test_api_key_enables_llm_mode_without_hf_token(self) -> None:
        inference = _load_inference_module({"API_KEY": "validator-proxy-key"})
        self.assertEqual(inference.API_KEY, "validator-proxy-key")
        self.assertIsNone(inference.HF_TOKEN)
        self.assertEqual(inference.MODEL_NAME, "gpt-4o-mini")
        self.assertTrue(inference.llm_mode_enabled())

    def test_seed_env_override_is_respected(self) -> None:
        inference = _load_inference_module({"SEED": "7"})
        self.assertEqual(inference.SEED, 7)

    def test_invalid_seed_env_falls_back_to_default(self) -> None:
        inference = _load_inference_module({"SEED": "not-an-int"})
        self.assertEqual(inference.SEED, 42)

    # ------------------------------------------------------------------
    # End-to-end run output
    # ------------------------------------------------------------------

    def test_run_uses_only_structured_start_step_end_logs(self) -> None:
        structured_prefixes = ("[START] ", "[STEP] ", "[END] ")
        inference = _load_inference_module()
        stdout = io.StringIO()
        with mock.patch.object(inference.httpx, "Client", FakeHttpClient):
            with mock.patch.object(inference, "HelpdeskTicketEnvClient", FakeEnvClient):
                with contextlib.redirect_stdout(stdout):
                    inference.run()

        lines = [line for line in stdout.getvalue().splitlines() if line.strip()]
        self.assertGreaterEqual(len(lines), 3)
        self.assertTrue(lines[0].startswith("[START] "))
        self.assertTrue(any(line.startswith("[STEP] ") for line in lines))
        self.assertTrue(lines[-1].startswith("[END] "))
        # Every non-blank stdout line must be one of the structured log kinds.
        self.assertTrue(all(line.startswith(structured_prefixes) for line in lines))

    # ------------------------------------------------------------------
    # Task selection
    # ------------------------------------------------------------------

    def test_default_task_selection_runs_all_tasks(self) -> None:
        # NOTE(review): this test was previously named
        # "..._runs_single_first_task" while asserting that all tasks run.
        # The name now matches the assertion; if the intended default is
        # "first task only", the expected value should become [1] instead.
        inference = _load_inference_module()
        self.assertEqual(
            inference.get_tasks_to_run({1: {}, 2: {}, 3: {}}),
            [1, 2, 3],
        )

    def test_run_all_tasks_override_keeps_local_batch_mode_available(self) -> None:
        inference = _load_inference_module({"RUN_ALL_TASKS": "1"})
        self.assertEqual(
            inference.get_tasks_to_run({1: {}, 2: {}, 3: {}}),
            [1, 2, 3],
        )

    # ------------------------------------------------------------------
    # Prompt construction
    # ------------------------------------------------------------------

    def test_build_llm_user_message_includes_recent_history_feedback(self) -> None:
        inference = _load_inference_module()
        ticket = {
            "ticket_id": "ticket-xyz",
            "title": "Contractor onboarding blocked by access issue",
            "requester": "pm@contractorco.com",
            "description": "Access permissions are blocking contractor setup.",
            "context_status": {
                "investigation_required": True,
                "hidden_context_remaining": True,
                "context_gap_count": 1,
                "revealed_context_count": 0,
                "context_completeness": 0.0,
                "investigations_used_for_ticket": 0,
            },
            "last_tool_result": {"tool_name": "lookup_requester_history", "found": False},
            "feedback_summary": "Ticket score=0.40; field_scores[issue_type=0.40]; reward=0.40",
            "last_reward_components": {"ticket_score": 0.4, "final_reward": 0.4},
            "investigation_budget_remaining": 2,
            "average_score_so_far": 0.7,
            "progress_fraction": 0.5,
            "recent_history": [
                {
                    "ticket_id": "ticket-prev",
                    "predicted": {"issue_type": "identity_access"},
                    "score": 0.4,
                    "breakdown": {"issue_type": 0.4},
                    "penalty_reason": "extra_fields: ['assignment_group']",
                    "feedback_summary": "Penalty applied: extra_fields: ['assignment_group']; reward=0.00",
                    "reward_components": {"reward_kind": "step_penalty", "final_reward": 0.0},
                }
            ],
            "queue_position": 2,
            "tickets_remaining": 4,
        }
        message = inference.build_llm_user_message(
            ticket,
            ["issue_type"],
            "Read the ticket and select the single best IT issue type.",
        )

        expected_fragments = (
            "Recent evaluation feedback",
            "score=0.4",
            "penalty_reason=extra_fields",
            "Latest environment feedback",
            "Context status",
            "Latest reward components",
            "Average score so far: 0.7",
            "Episode progress: 0.5",
            "Investigation budget remaining: 2",
            "Investigation result",
            "queue_position=2",
        )
        # subTest reports every missing fragment instead of stopping at the first.
        for fragment in expected_fragments:
            with self.subTest(fragment=fragment):
                self.assertIn(fragment, message)

    # ------------------------------------------------------------------
    # Action building (LLM output + heuristic backfill / domain overrides)
    # ------------------------------------------------------------------

    def test_build_action_backfills_missing_fields_from_heuristic(self) -> None:
        inference = _load_inference_module()
        # Any non-None client object; presumably lets build_action take the
        # LLM path, whose call_llm is patched below.
        inference.llm_client = object()
        ticket = _pricing_inquiry_ticket()
        with mock.patch.object(
            inference,
            "call_llm",
            return_value={"issue_type": "service_request"},
        ):
            action, action_source, fallback_reason = inference.build_action(
                ticket,
                ["issue_type", "priority", "assignment_group", "resolution_action"],
                "Perform full helpdesk routing.",
            )

        self.assertEqual(action.issue_type, "service_request")
        self.assertEqual(action.priority, "medium")
        self.assertEqual(action.assignment_group, "procurement")
        self.assertEqual(action.resolution_action, "assign")
        self.assertEqual(action_source, "llm_backfilled")
        self.assertIn("heuristic_backfill", fallback_reason or "")

    def test_build_action_ignores_invalid_llm_fields_and_keeps_valid_ones(self) -> None:
        inference = _load_inference_module()
        inference.llm_client = object()
        ticket = _pricing_inquiry_ticket()
        with mock.patch.object(
            inference,
            "call_llm",
            return_value={
                "issue_type": "service_request",
                "priority": "urgent",
            },
        ):
            action, action_source, fallback_reason = inference.build_action(
                ticket,
                ["issue_type", "priority"],
                "Read the ticket, select the best IT issue type, and estimate the priority.",
            )

        self.assertEqual(action.issue_type, "service_request")
        self.assertEqual(action.priority, "medium")
        self.assertEqual(action_source, "llm_backfilled")
        self.assertIn("invalid_llm_fields=['priority']", fallback_reason or "")

    def test_build_action_backfills_dependent_fields_from_llm_issue_type(self) -> None:
        inference = _load_inference_module()
        inference.llm_client = object()
        ticket = {
            "ticket_id": "ticket-002",
            "title": "Can not sign in after 2FA reset",
            "requester": "ops@laneeight.io",
            "description": (
                "I was forced to reset 2FA and now the account stays locked even "
                "with the backup code."
            ),
        }
        with mock.patch.object(
            inference,
            "call_llm",
            return_value={"issue_type": "identity_access"},
        ):
            action, action_source, fallback_reason = inference.build_action(
                ticket,
                ["issue_type", "assignment_group", "resolution_action"],
                "Perform full helpdesk routing.",
            )

        self.assertEqual(action.issue_type, "identity_access")
        self.assertEqual(action.assignment_group, "service_desk")
        self.assertEqual(action.resolution_action, "fulfill")
        self.assertEqual(action_source, "llm_backfilled")
        self.assertIn("heuristic_backfill", fallback_reason or "")

    def test_build_action_normalizes_pricing_request_issue_type(self) -> None:
        inference = _load_inference_module()
        inference.llm_client = object()
        ticket = _pricing_inquiry_ticket()
        with mock.patch.object(
            inference,
            "call_llm",
            return_value={
                "issue_type": "billing_license",
                "priority": "medium",
            },
        ):
            action, action_source, fallback_reason = inference.build_action(
                ticket,
                ["issue_type", "priority", "assignment_group", "resolution_action"],
                "Perform full helpdesk routing.",
            )

        self.assertEqual(action.issue_type, "service_request")
        self.assertEqual(action.assignment_group, "procurement")
        self.assertEqual(action.resolution_action, "assign")
        self.assertEqual(action.priority, "medium")
        self.assertEqual(action_source, "llm_backfilled")
        self.assertIn("domain_overrides", fallback_reason or "")

    def test_build_action_normalizes_onboarding_access_blocker(self) -> None:
        inference = _load_inference_module()
        inference.llm_client = object()
        ticket = {
            "ticket_id": "TKT-NONDEFAULT-003",
            "title": "Contractor onboarding blocked by access issue",
            "requester": "pm@contractorco.com",
            "description": (
                "A new contractor cannot complete onboarding because their account "
                "access is blocked by a permissions error. The onboarding team "
                "cannot resolve access issues; routing to service desk."
            ),
            "ambiguity_note": "Contractor onboarding blocked by access issue, routed to service desk",
        }
        with mock.patch.object(
            inference,
            "call_llm",
            return_value={
                "issue_type": "identity_access",
                "priority": "high",
            },
        ):
            action, action_source, fallback_reason = inference.build_action(
                ticket,
                ["issue_type", "priority", "assignment_group", "resolution_action"],
                "Perform full helpdesk routing.",
            )

        self.assertEqual(action.issue_type, "onboarding")
        self.assertEqual(action.priority, "medium")
        self.assertEqual(action.assignment_group, "service_desk")
        self.assertEqual(action.resolution_action, "fulfill")
        self.assertEqual(action_source, "llm_backfilled")
        self.assertIn("domain_overrides", fallback_reason or "")

    def test_build_action_deescalates_nonurgent_onboarding_priority(self) -> None:
        inference = _load_inference_module()
        inference.llm_client = object()
        ticket = {
            "ticket_id": "ticket-008",
            "title": "Kickoff onboarding session for newly activated account",
            "requester": "admin@brightpath.io",
            "description": (
                "We activated our account this week and need an onboarding call plus "
                "admin setup guidance for six internal users."
            ),
        }
        with mock.patch.object(
            inference,
            "call_llm",
            return_value={
                "issue_type": "onboarding",
                "priority": "high",
            },
        ):
            action, action_source, fallback_reason = inference.build_action(
                ticket,
                ["issue_type", "priority"],
                "Read the ticket, select the best IT issue type, and estimate the priority.",
            )

        self.assertEqual(action.issue_type, "onboarding")
        self.assertEqual(action.priority, "medium")
        self.assertEqual(action_source, "llm_backfilled")
        self.assertIn("domain_overrides", fallback_reason or "")

    # ------------------------------------------------------------------
    # Observation merging and investigation policy
    # ------------------------------------------------------------------

    def test_merge_ticket_context_carries_feedback_summary_from_observation(self) -> None:
        inference = _load_inference_module()
        observation = SimpleNamespace(
            last_tool_result={"tool_name": "lookup_requester_history", "found": True},
            history=[{"ticket_id": "ticket-prev", "score": 0.4}],
            queue_position=2,
            tickets_remaining=4,
            investigation_budget_remaining=1,
            average_score_so_far=0.55,
            progress_fraction=0.4,
            last_reward_components={"ticket_score": 0.4, "final_reward": 0.4},
            metadata={"last_feedback_summary": "Ticket score=0.40; reward=0.40"},
        )
        merged = inference.merge_ticket_context(
            {
                "ticket_id": "ticket-xyz",
                "title": "Contractor onboarding blocked by access issue",
            },
            observation,
        )

        self.assertEqual(merged["feedback_summary"], "Ticket score=0.40; reward=0.40")
        self.assertEqual(merged["investigation_budget_remaining"], 1)
        self.assertEqual(merged["average_score_so_far"], 0.55)
        self.assertEqual(merged["progress_fraction"], 0.4)
        self.assertEqual(merged["last_reward_components"]["final_reward"], 0.4)
        self.assertEqual(merged["queue_position"], 2)
        self.assertEqual(merged["tickets_remaining"], 4)
        self.assertEqual(merged["last_tool_result"]["tool_name"], "lookup_requester_history")

    def test_should_investigate_uses_hidden_context_and_ticket_cues(self) -> None:
        inference = _load_inference_module()
        investigate, tool_name = inference.should_investigate(
            {
                "ticket_id": "TKT-NONDEFAULT-003",
                "title": "Contractor onboarding blocked by access issue",
                "description": "Additional routing context is available via investigation.",
                "context_status": {
                    "hidden_context_remaining": True,
                    "context_gap_count": 1,
                },
            },
            [],
        )

        self.assertTrue(investigate)
        self.assertEqual(tool_name, "lookup_internal_routing_note")


if __name__ == "__main__":
    unittest.main()