"""Tests for improved hierarchical classification."""

import os

import pytest

from core.classification import ImprovedHierarchicalClassifier

# Keys that every classification result checked in this module must expose
# (keyword path, LLM path, and the fallback helper alike).
HIERARCHY_KEYS = ("level1", "level2", "level3", "doc_type")


def _assert_has_hierarchy_keys(result):
    """Assert that *result* contains every key in ``HIERARCHY_KEYS``.

    All missing keys are reported together rather than failing on the first.
    """
    missing = [key for key in HIERARCHY_KEYS if key not in result]
    assert not missing, f"classification result is missing keys: {missing}"


@pytest.fixture
def keyword_classifier():
    """Return a fresh keyword-only (``use_llm=False``) hospital classifier.

    Function-scoped on purpose: each test gets its own instance, exactly as
    when every test constructed the classifier inline.
    """
    return ImprovedHierarchicalClassifier("hospital", use_llm=False)


def test_classifier_initialization(keyword_classifier):
    """The classifier records its hierarchy name and loads level-1 keywords."""
    assert keyword_classifier.hierarchy_name == "hospital"
    assert keyword_classifier.hierarchy is not None
    assert len(keyword_classifier.level1_keywords) > 0


def test_keyword_classification(keyword_classifier):
    """Keyword-only classification returns the full hierarchy, tagged 'keyword'."""
    text = (
        "Patient admission procedures require proper documentation "
        "and identification verification."
    )

    result = keyword_classifier.classify_text(text)

    _assert_has_hierarchy_keys(result)
    assert result["method"] == "keyword"


@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OpenAI API key")
def test_llm_classification():
    """LLM-based classification returns the hierarchy and a confidence in [0, 1].

    Skipped unless ``OPENAI_API_KEY`` is set in the environment.
    """
    classifier = ImprovedHierarchicalClassifier("hospital", use_llm=True)
    text = """
    Patient Admission Protocol

    All patients must present valid identification upon admission.
    Emergency cases follow expedited procedures.
    """

    result = classifier.classify_text(text)

    _assert_has_hierarchy_keys(result)
    assert "confidence" in result
    assert 0 <= result["confidence"] <= 1


@pytest.mark.parametrize(
    ("text", "expected"),
    [
        ("This policy outlines the requirements for patient care.", "policy"),
        (
            "This manual provides step-by-step instructions for procedures.",
            "manual",
        ),
        (
            "This report summarizes the findings of our quality analysis.",
            "report",
        ),
    ],
    ids=["policy", "manual", "report"],
)
def test_doc_type_inference(keyword_classifier, text, expected):
    """Each sample text is inferred as the document type it describes.

    Parametrized so every document type is reported independently instead of
    the first failing assertion hiding the remaining cases.
    """
    # The text is lower-cased before the call, as the original test did;
    # NOTE(review): presumably ``_infer_doc_type`` expects normalized input —
    # confirm against the classifier implementation.
    assert keyword_classifier._infer_doc_type(text.lower()) == expected


def test_classification_fallback(keyword_classifier):
    """The fallback helper returns a complete keyword result with confidence 0.3."""
    # Even with empty text, the fallback should return a valid structure.
    result = keyword_classifier._fallback_classification("")

    _assert_has_hierarchy_keys(result)
    # approx guards against representation noise if the value is ever computed.
    assert result["confidence"] == pytest.approx(0.3)
    assert result["method"] == "keyword"


def test_classification_with_override(keyword_classifier):
    """An explicit ``doc_type`` argument is reflected in the result."""
    text = "Some document about patient care."

    result = keyword_classifier.classify_text(text, doc_type="manual")

    assert result["doc_type"] == "manual"