# hierarchical-rag-eval/tests/test_classification.py
# Commit c54dcef (hh786): "Deployment of Hierarchical RAG system"
"""Tests for improved hierarchical classification."""
import pytest
import os
from core.classification import ImprovedHierarchicalClassifier
def test_classifier_initialization():
    """A keyword-only classifier records its hierarchy name and has level-1 keywords."""
    clf = ImprovedHierarchicalClassifier("hospital", use_llm=False)

    # The constructor should remember which hierarchy it was asked for ...
    assert clf.hierarchy_name == "hospital"
    # ... populate the hierarchy itself, and expose a non-empty level-1 keyword set.
    assert clf.hierarchy is not None
    assert len(clf.level1_keywords) > 0
def test_keyword_classification():
    """Keyword mode returns every hierarchy field and reports the keyword method."""
    clf = ImprovedHierarchicalClassifier("hospital", use_llm=False)
    sample = (
        "Patient admission procedures require proper documentation "
        "and identification verification."
    )
    outcome = clf.classify_text(sample)

    # Every hierarchy level plus the document type must be present.
    for field in ("level1", "level2", "level3", "doc_type"):
        assert field in outcome
    # With use_llm=False the classifier must report the keyword path.
    assert outcome["method"] == "keyword"
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="No OpenAI API key")
def test_llm_classification():
    """LLM mode returns every hierarchy field plus a confidence within [0, 1].

    Skipped unless the ``OPENAI_API_KEY`` environment variable is set, since
    ``use_llm=True`` presumably needs OpenAI access.
    """
    clf = ImprovedHierarchicalClassifier("hospital", use_llm=True)
    sample = (
        "\n"
        "Patient Admission Protocol\n"
        "All patients must present valid identification upon admission.\n"
        "Emergency cases follow expedited procedures.\n"
    )
    outcome = clf.classify_text(sample)

    for field in ("level1", "level2", "level3", "doc_type", "confidence"):
        assert field in outcome
    # Confidence is expected to be a normalised score.
    score = outcome["confidence"]
    assert 0 <= score <= 1
def test_doc_type_inference():
    """``_infer_doc_type`` maps policy / manual / report wording to that label."""
    clf = ImprovedHierarchicalClassifier("hospital", use_llm=False)
    # expected label -> sample sentence containing that word
    samples = {
        "policy": "This policy outlines the requirements for patient care.",
        "manual": "This manual provides step-by-step instructions for procedures.",
        "report": "This report summarizes the findings of our quality analysis.",
    }
    # The private helper is called with lower-cased text; every inference is
    # computed before any assertion runs.
    inferred = {
        label: clf._infer_doc_type(sentence.lower())
        for label, sentence in samples.items()
    }
    for label in samples:
        assert inferred[label] == label
def test_classification_fallback():
    """The keyword fallback returns a complete, low-confidence result.

    ``_fallback_classification`` is called directly (no LLM involved) to cover
    the path presumably taken when LLM classification is unavailable or fails.
    """
    classifier = ImprovedHierarchicalClassifier("hospital", use_llm=False)
    # Even with empty text, should return valid structure
    result = classifier._fallback_classification("")

    for key in ("level1", "level2", "level3", "doc_type"):
        assert key in result
    # Compare the float with a tolerance rather than exact ``==`` so the test
    # does not depend on how the fallback confidence value is produced.
    assert result["confidence"] == pytest.approx(0.3)
    assert result["method"] == "keyword"
def test_classification_with_override():
    """An explicit ``doc_type`` argument is honoured and the result stays complete."""
    classifier = ImprovedHierarchicalClassifier("hospital", use_llm=False)
    text = "Some document about patient care."
    result = classifier.classify_text(text, doc_type="manual")

    # Overriding doc_type must not drop the hierarchy levels that
    # classify_text otherwise returns.
    for key in ("level1", "level2", "level3"):
        assert key in result
    assert result["doc_type"] == "manual"