taraky's picture
Upload folder using huggingface_hub
b7f3196 verified
raw
history blame
16.3 kB
"""
Query Router System
This module integrates the medical/insurance classifier with the reason
classification system to provide intelligent routing of healthcare portal queries.
The router first determines if a query is medical or insurance-related, then
routes accordingly:
- Insurance queries -> Direct to insurance department
- Medical queries -> Reason classification -> Appropriate medical department routing
"""
import os
import sys
from typing import Dict, List, Optional, Tuple
from pathlib import Path
# Put the repository root (the parent of this file's directory) at the front
# of sys.path so the project packages imported below resolve when this file
# is run directly as a script.
REPO_ROOT = Path(__file__).resolve().parents[1]
_repo_root_entry = str(REPO_ROOT)
if _repo_root_entry not in sys.path:
    sys.path.insert(0, _repo_root_entry)
del _repo_root_entry
from classifier.infer import predict_query
from classifier.utils import get_models, CATEGORIES
from classifier.reason import predict_single_reason
from retriever.search import Retriever
from team.candidates import get_candidates
class HealthcareQueryRouter:
    """
    Intelligent routing system for healthcare portal queries.

    Routes queries through a two-stage process:
    1. Medical vs Insurance classification
    2. For medical queries: Reason classification for department routing
    """

    # Confidence below this value counts as "low". It adds a manual-review
    # recommendation and is tallied by ``get_routing_statistics``.
    LOW_CONFIDENCE_THRESHOLD = 0.7

    # Extra recommendation appended for these medical reason categories.
    _REASON_RECOMMENDATIONS = {
        "PAIN_CONDITIONS": "Assess pain level and duration for urgency determination",
        "INJURIES": "Determine if immediate medical attention is required",
        "PROCEDURES": "Verify insurance pre-authorization requirements",
    }

    def __init__(self,
                 medical_model_path: Optional[str] = None,
                 use_retrieval: bool = True):
        """
        Initialize the query router.

        Args:
            medical_model_path: Path to a saved state dict for the
                medical/insurance classifier head. If it is None or does not
                exist, the head returned by ``get_models()`` is used as-is.
            use_retrieval: Whether to try to build a ``Retriever`` for medical
                queries. Retrieval is best-effort: any failure leaves
                ``self.retriever`` as None.

        Raises:
            Exception: Any error raised while creating or loading the
                medical/insurance classifier is printed and re-raised.
        """
        self._init_classifier(medical_model_path)
        self.retriever = self._build_retriever() if use_retrieval else None

        # Routing rules for insurance queries
        self.insurance_routing = {
            "department": "Insurance Department",
            "priority": "normal",
            "estimated_response": "1-2 business days",
            "contact_method": "phone_or_email",
            "description": "Insurance coverage, claims, and benefits inquiries"
        }

        # Medical department routing keyed by reason category. ROUTINE_CARE
        # also serves as the fallback for categories without an entry.
        self.medical_department_routing = {
            "ROUTINE_CARE": {
                "department": "Primary Care",
                "priority": "normal",
                "estimated_response": "1-7 days",
                "contact_method": "standard_scheduling",
                "description": "Routine healthcare and maintenance visits"
            },
            "PAIN_CONDITIONS": {
                "department": "Pain Management",
                "priority": "high",
                "estimated_response": "same day to 3 days",
                "contact_method": "phone_preferred",
                "description": "Pain-related conditions and discomfort"
            },
            "INJURIES": {
                "department": "Urgent Care",
                "priority": "high",
                "estimated_response": "same day",
                "contact_method": "phone_immediate",
                "description": "Injuries, sprains, and trauma-related conditions"
            },
            "SKIN_CONDITIONS": {
                "department": "Dermatology",
                "priority": "normal",
                "estimated_response": "3-7 days",
                "contact_method": "standard_scheduling",
                "description": "Skin-related issues and conditions"
            },
            "STRUCTURAL_ISSUES": {
                "department": "Orthopedics",
                "priority": "normal",
                "estimated_response": "1-14 days",
                "contact_method": "standard_scheduling",
                "description": "Structural problems and musculoskeletal conditions"
            },
            "PROCEDURES": {
                "department": "Surgical Services",
                "priority": "normal",
                "estimated_response": "3-14 days",
                "contact_method": "scheduling_coordinator",
                "description": "Surgical consultations and procedures"
            }
        }

    def _init_classifier(self, medical_model_path: Optional[str]) -> None:
        """Create the embedding model and classifier head, loading trained weights if available."""
        try:
            self.embedding_model, self.classifier_head = get_models()
            if medical_model_path and os.path.exists(medical_model_path):
                import torch
                # map_location="cpu" lets a checkpoint saved on a GPU load on a
                # CPU-only host. load_state_dict then copies the values into
                # the head's existing parameters.
                state_dict = torch.load(medical_model_path, map_location="cpu", weights_only=True)
                self.classifier_head.load_state_dict(state_dict)
                print(f"Loaded medical/insurance classifier from {medical_model_path}")
            else:
                if medical_model_path:
                    # A path was supplied but is missing: say so instead of
                    # silently falling back to the untrained head.
                    print(f"Medical/insurance classifier weights not found at {medical_model_path}")
                print("Using untrained medical/insurance classifier")
        except Exception as e:
            print(f"Error initializing medical/insurance classifier: {e}")
            raise

    def _build_retriever(self) -> Optional["Retriever"]:
        """Build a Retriever over the corpora files that exist; return None if none exist or setup fails."""
        try:
            # Default corpora configuration (paths are relative to the CWD)
            corpora_config = {
                "medical_qa": {
                    "path": "data/corpora/medical_qa.jsonl",
                    "text_fields": ["question", "answer", "title"],
                },
                "miriad": {
                    "path": "data/corpora/miriad_text.jsonl",
                    "text_fields": ["text", "title"],
                }
            }
            # Only use available corpora
            available_config = {
                name: cfg for name, cfg in corpora_config.items()
                if Path(cfg["path"]).exists()
            }
            if not available_config:
                print("No corpora files found. Retrieval disabled.")
                return None
            retriever = Retriever(available_config)
            print(f"Retrieval system initialized with {len(available_config)} corpora")
            return retriever
        except Exception as e:
            # Retrieval is optional: report and continue without it.
            print(f"Could not initialize retrieval system: {e}")
            return None

    def route_query(self, query: str, include_retrieval: bool = True) -> Dict:
        """
        Route a healthcare query through the classification and routing system.

        Args:
            query: The user's query text
            include_retrieval: Whether to include retrieval results for medical queries

        Returns:
            Dictionary with keys ``query``, ``primary_classification``,
            ``confidence`` (float), ``all_probabilities`` (category -> float),
            ``routing_decision``, ``reason_classification`` (None for
            non-medical queries), ``retrieval_results`` (None unless retrieval
            ran for a medical query) and ``recommendations``.
        """
        # Step 1: Medical vs Insurance classification (a batch of one query)
        medical_prediction = predict_query([query], self.embedding_model, self.classifier_head)
        primary_category, confidence, all_probabilities = self._parse_primary_prediction(
            medical_prediction, CATEGORIES
        )

        routing_result = {
            "query": query,
            "primary_classification": primary_category,
            "confidence": confidence,
            "all_probabilities": all_probabilities,
            "routing_decision": None,
            "reason_classification": None,
            "retrieval_results": None,
            "recommendations": []
        }

        # Step 2: Route based on classification
        if primary_category.lower() == "medical":
            routing_decision, reason_classification = self._route_medical_query(query, include_retrieval)
            routing_result["routing_decision"] = routing_decision
            routing_result["reason_classification"] = reason_classification
            # Also expose retrieval hits at the top level. They remain inside
            # the routing decision as well, for existing consumers.
            routing_result["retrieval_results"] = routing_decision.get("retrieval_results")
        else:
            routing_result["routing_decision"] = self._route_insurance_query()

        # Step 3: Add contextual recommendations
        routing_result["recommendations"] = self._generate_recommendations(
            primary_category, confidence, routing_result["reason_classification"]
        )
        return routing_result

    @staticmethod
    def _is_batched(value) -> bool:
        """Return True if ``value`` is a list/tuple or an array/tensor with ``ndim >= 1``."""
        return isinstance(value, (list, tuple)) or getattr(value, "ndim", 0) >= 1

    @classmethod
    def _parse_primary_prediction(cls, prediction: Dict,
                                  categories: List[str]) -> Tuple[str, float, Dict[str, float]]:
        """
        Extract the single query's (category, confidence, probabilities) from a ``predict_query`` result.

        ``confidence`` and ``probabilities`` may each be flat or wrapped in a
        batch dimension of size one (a list or an array); both shapes are accepted.

        Args:
            prediction: Mapping with ``'prediction'``, ``'confidence'`` and
                ``'probabilities'`` entries.
            categories: Category labels, in the same order as the probabilities.

        Returns:
            (primary category, confidence as float, {category: probability as float})
        """
        primary_category = prediction['prediction'][0]

        confidence = prediction['confidence']
        if cls._is_batched(confidence):
            confidence = confidence[0]

        row = prediction['probabilities']
        # A batched result ([[p0, p1, ...]] or a 2-D array) wraps the query's
        # row. Unwrap it so every category maps to a scalar probability.
        if len(row) > 0 and cls._is_batched(row[0]):
            row = row[0]

        all_probabilities = {cat: float(row[i]) for i, cat in enumerate(categories)}
        return primary_category, float(confidence), all_probabilities

    def _route_medical_query(self, query: str, include_retrieval: bool = True) -> Tuple[Dict, Dict]:
        """
        Route a medical query through reason classification.

        Returns:
            ``(routing, reason_classification)``.
            - ``routing`` is a copy of the department entry for the predicted
              reason, or the ROUTINE_CARE entry when the category is unknown.
              When ``include_retrieval`` is set and a retriever exists,
              ``routing["retrieval_results"]`` holds the top-5 hits, or ``[]``
              if retrieval fails.
            - ``reason_classification`` holds category, confidence and
              probabilities. If reason classification raises, it falls back to
              ROUTINE_CARE with confidence 0.5 and empty probabilities.
        """
        try:
            reason_result = predict_single_reason(query)
            reason_category = reason_result['category']
            reason_confidence = reason_result['confidence']
            reason_probabilities = reason_result['probabilities']
        except Exception as e:
            print(f"Reason classification failed: {e}")
            # Fallback to general medical routing
            reason_category = "ROUTINE_CARE"
            reason_confidence = 0.5
            reason_probabilities = {}

        # Copy so per-query additions (retrieval_results) never touch the table.
        routing = self.medical_department_routing.get(
            reason_category,
            self.medical_department_routing["ROUTINE_CARE"]
        ).copy()

        reason_classification = {
            "category": reason_category,
            "confidence": reason_confidence,
            "probabilities": reason_probabilities
        }

        if include_retrieval and self.retriever:
            try:
                routing["retrieval_results"] = self.retriever.retrieve(query, k=5, for_ui=True)
            except Exception as e:
                print(f"Retrieval failed: {e}")
                routing["retrieval_results"] = []

        return routing, reason_classification

    def _route_insurance_query(self) -> Dict:
        """Return a fresh copy of the insurance department routing."""
        return self.insurance_routing.copy()

    def _generate_recommendations(self, primary_category: str, confidence: float,
                                  reason_classification: Optional[Dict] = None) -> List[str]:
        """
        Generate contextual recommendations based on classification.

        Args:
            primary_category: Primary label, compared case-insensitively with
                "medical" / "insurance".
            confidence: Primary classification confidence.
            reason_classification: Optional dict with a ``'category'`` key.
                It is only consulted for medical queries.

        Returns:
            Recommendation strings. The low-confidence warning, when present,
            comes first.
        """
        recommendations = []

        if confidence < self.LOW_CONFIDENCE_THRESHOLD:
            recommendations.append(
                "Classification confidence is low. Consider manual review or "
                "asking the user to clarify their request."
            )

        category = primary_category.lower()
        if category == "medical":
            recommendations.extend([
                "Consider asking follow-up questions about symptoms",
                "Verify if this requires immediate attention",
                "Check if patient has existing appointments or conditions"
            ])
            if reason_classification:
                extra = self._REASON_RECOMMENDATIONS.get(reason_classification.get('category'))
                if extra:
                    recommendations.append(extra)
        elif category == "insurance":
            recommendations.extend([
                "Have patient account information ready",
                "Verify current insurance information and benefits",
                "Prepare to explain coverage details and requirements"
            ])

        return recommendations

    def batch_route_queries(self, queries: List[str], include_retrieval: bool = True) -> List[Dict]:
        """
        Route each query in order with ``route_query``.

        Args:
            queries: Query texts to route.
            include_retrieval: Forwarded to ``route_query`` (default True).
        """
        return [self.route_query(query, include_retrieval=include_retrieval) for query in queries]

    def get_routing_statistics(self, queries: List[str]) -> Dict:
        """
        Analyze routing patterns for a batch of queries.

        Every query is routed again without retrieval, because the statistics
        do not use retrieval hits. An empty ``queries`` list yields zero counts
        and an average confidence of 0.0 rather than raising ZeroDivisionError.

        Returns:
            Dict with the total count, per-category counts and percentages for
            the primary and reason classifications, the mean confidence, and
            the number of queries below ``LOW_CONFIDENCE_THRESHOLD``.
        """
        from collections import Counter

        results = self.batch_route_queries(queries, include_retrieval=False)
        total = len(queries)

        primary_counts = Counter(r["primary_classification"] for r in results)
        # Only medical queries carry a reason classification.
        reason_counts = Counter(
            r["reason_classification"]["category"]
            for r in results if r["reason_classification"]
        )
        confidence_scores = [r["confidence"] for r in results]

        def as_percentages(counts):
            # Only called with non-empty counts when total > 0.
            return {cat: (count / total) * 100 for cat, count in counts.items()}

        return {
            "total_queries": total,
            "primary_distribution": dict(primary_counts),
            "reason_distribution": dict(reason_counts),
            "average_confidence": (
                sum(confidence_scores) / len(confidence_scores) if confidence_scores else 0.0
            ),
            "low_confidence_queries": sum(
                1 for c in confidence_scores if c < self.LOW_CONFIDENCE_THRESHOLD
            ),
            "primary_percentages": as_percentages(primary_counts),
            "reason_percentages": as_percentages(reason_counts),
        }
def demo_router():
    """
    Demonstrate the query router end to end.

    Builds a ``HealthcareQueryRouter`` with default settings and routes a
    fixed set of sample insurance and medical queries, printing each routing
    decision. It then prints routing statistics for the same samples.
    """
    print("Initializing Healthcare Query Router...")
    router = HealthcareQueryRouter()

    # Sample insurance queries.
    insurance_samples = [
        "My insurance claim was denied, can you help?",
        "What does my insurance cover for this procedure?",
        "I need to verify my insurance benefits",
    ]
    # Sample medical queries. The trailing comment names the reason category
    # each one is meant to exercise.
    medical_samples = [
        "I have heel pain when I walk",          # PAIN_CONDITIONS
        "I need routine foot care",              # ROUTINE_CARE
        "I sprained my ankle playing sports",    # INJURIES
        "My toenail is ingrown and infected",    # SKIN_CONDITIONS
        "I have flat feet and need evaluation",  # STRUCTURAL_ISSUES
        "I need a cortisone injection",          # PROCEDURES
    ]
    test_queries = [*insurance_samples, *medical_samples]

    def report(outcome):
        """Print the classification and routing decision for one routed query."""
        print(f" Primary: {outcome['primary_classification']} "
              f"(confidence: {outcome['confidence']:.3f})")
        reason = outcome['reason_classification']
        if reason:
            print(f" Reason: {reason['category']} "
                  f"(confidence: {reason['confidence']:.3f})")
        decision = outcome['routing_decision']
        print(f" Department: {decision['department']}")
        print(f" Priority: {decision['priority']}")
        print(f" Response Time: {decision['estimated_response']}")
        tips = outcome['recommendations']
        if tips:
            print(f" Recommendation: {tips[0]}")

    def show_shares(heading, shares):
        """Print a heading followed by one percentage line per category."""
        print(heading)
        for label, share in shares.items():
            print(f" {label}: {share:.1f}%")

    print(f"\nRouting {len(test_queries)} test queries...\n")
    for number, text in enumerate(test_queries, start=1):
        # Echo the query before routing so any messages printed while routing
        # appear beneath it.
        print(f"Query {number}: {text}")
        report(router.route_query(text))
        print()

    # Show routing statistics
    print("Routing Statistics:")
    stats = router.get_routing_statistics(test_queries)
    show_shares("Primary Classification:", stats['primary_percentages'])
    if stats['reason_percentages']:
        show_shares("Reason Classification:", stats['reason_percentages'])
    print(f"Average Confidence: {stats['average_confidence']:.3f}")
    print(f"Low Confidence Queries: {stats['low_confidence_queries']}")


if __name__ == "__main__":
    demo_router()