"""Tests for the improved hierarchical classifier (core.classification)."""
import pytest
import os
from core.classification import ImprovedHierarchicalClassifier
def test_classifier_initialization():
    """Constructing a keyword-only classifier loads the named hierarchy."""
    clf = ImprovedHierarchicalClassifier("hospital", use_llm=False)

    assert clf.hierarchy_name == "hospital"
    assert clf.hierarchy is not None
    # The level-1 keyword collection must not be empty.
    keyword_count = len(clf.level1_keywords)
    assert keyword_count > 0
def test_keyword_classification():
    """Keyword-only classify_text returns every hierarchy field and method 'keyword'."""
    clf = ImprovedHierarchicalClassifier("hospital", use_llm=False)
    sample = (
        "Patient admission procedures require proper documentation "
        "and identification verification."
    )
    outcome = clf.classify_text(sample)

    # Each hierarchy level plus the document type must be present.
    for field in ("level1", "level2", "level3", "doc_type"):
        assert field in outcome
    assert outcome["method"] == "keyword"
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OpenAI API key")
def test_llm_classification():
    """LLM-backed classify_text returns every field plus a confidence in [0, 1].

    Skipped unless the OPENAI_API_KEY environment variable is set.
    """
    clf = ImprovedHierarchicalClassifier("hospital", use_llm=True)
    # Spelled out line by line; includes the leading and trailing newline.
    document = (
        "\n"
        "Patient Admission Protocol\n"
        "All patients must present valid identification upon admission.\n"
        "Emergency cases follow expedited procedures.\n"
    )
    outcome = clf.classify_text(document)

    for field in ("level1", "level2", "level3", "doc_type", "confidence"):
        assert field in outcome
    confidence = outcome["confidence"]
    assert 0 <= confidence <= 1
def test_doc_type_inference():
    """_infer_doc_type recognises policy, manual and report wording."""
    clf = ImprovedHierarchicalClassifier("hospital", use_llm=False)
    samples = [
        "This policy outlines the requirements for patient care.",
        "This manual provides step-by-step instructions for procedures.",
        "This report summarizes the findings of our quality analysis.",
    ]
    # The private helper is fed lower-cased text, so lower-case each sample first.
    inferred = [clf._infer_doc_type(sample.lower()) for sample in samples]

    assert inferred == ["policy", "manual", "report"]
def test_classification_fallback():
    """Test _fallback_classification directly on empty text.

    The fallback is invoked directly rather than by forcing an LLM failure.
    It must still return every hierarchy field, a 0.3 confidence and the
    'keyword' method.
    """
    classifier = ImprovedHierarchicalClassifier("hospital", use_llm=False)
    # Even with empty text, the fallback should return a valid structure.
    result = classifier._fallback_classification("")
    assert "level1" in result
    assert "level2" in result
    assert "level3" in result
    assert "doc_type" in result
    # Compare the float with a tolerance rather than exact equality.
    assert result["confidence"] == pytest.approx(0.3)
    assert result["method"] == "keyword"
def test_classification_with_override():
    """An explicit doc_type argument is what classify_text reports back."""
    clf = ImprovedHierarchicalClassifier("hospital", use_llm=False)
    outcome = clf.classify_text(
        "Some document about patient care.",
        doc_type="manual",
    )
    assert outcome["doc_type"] == "manual"