"""
Query Router System

This module integrates the medical/insurance classifier with the reason
classification system to provide intelligent routing of healthcare portal
queries.

The router first determines if a query is medical or insurance-related, then
routes accordingly:
- Insurance queries -> Direct to insurance department
- Medical queries -> Reason classification -> Appropriate medical department routing
"""

import os
import sys
from collections import Counter
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# Add project root to path for imports
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from classifier.infer import predict_query
from classifier.utils import get_models, CATEGORIES
from classifier.reason import predict_single_reason
from retriever.search import Retriever
from team.candidates import get_candidates


class HealthcareQueryRouter:
    """
    Intelligent routing system for healthcare portal queries.

    Routes queries through a two-stage process:
    1. Medical vs Insurance classification
    2. For medical queries: Reason classification for department routing

    Attributes:
        embedding_model: First object returned by ``get_models()``.
        classifier_head: Second object returned by ``get_models()``; optionally
            loaded with trained weights.
        retriever: ``Retriever`` instance, or ``None`` when retrieval is
            disabled or could not be initialized.
        insurance_routing: Routing record returned for every insurance query.
        medical_department_routing: Routing records keyed by reason category.
    """

    # Primary-classification confidence below this value is treated as "low":
    # it triggers the manual-review recommendation and is counted in
    # get_routing_statistics()["low_confidence_queries"].
    LOW_CONFIDENCE_THRESHOLD = 0.7

    # Reason category used when reason classification fails, and whose routing
    # record is the fallback for categories missing from the routing table.
    DEFAULT_REASON_CATEGORY = "ROUTINE_CARE"

    # Extra recommendation appended for specific medical reason categories.
    _REASON_RECOMMENDATIONS = {
        "PAIN_CONDITIONS": "Assess pain level and duration for urgency determination",
        "INJURIES": "Determine if immediate medical attention is required",
        "PROCEDURES": "Verify insurance pre-authorization requirements",
    }

    def __init__(self, medical_model_path: Optional[str] = None, use_retrieval: bool = True):
        """
        Initialize the query router.

        Args:
            medical_model_path: Path to trained medical/insurance classifier
            use_retrieval: Whether to use retrieval system for medical queries

        Raises:
            Exception: Whatever ``get_models()`` or weight loading raises; the
                error is printed and re-raised unchanged.
        """
        # Initialize medical/insurance classifier (failure here is fatal)
        self._load_classifier(medical_model_path)

        # Initialize retrieval system if requested (failure here is non-fatal)
        self.retriever = self._build_retriever() if use_retrieval else None

        # Routing rules for insurance queries
        self.insurance_routing = {
            "department": "Insurance Department",
            "priority": "normal",
            "estimated_response": "1-2 business days",
            "contact_method": "phone_or_email",
            "description": "Insurance coverage, claims, and benefits inquiries",
        }

        # Medical department routing based on reason categories
        self.medical_department_routing = {
            "ROUTINE_CARE": {
                "department": "Primary Care",
                "priority": "normal",
                "estimated_response": "1-7 days",
                "contact_method": "standard_scheduling",
                "description": "Routine healthcare and maintenance visits",
            },
            "PAIN_CONDITIONS": {
                "department": "Pain Management",
                "priority": "high",
                "estimated_response": "same day to 3 days",
                "contact_method": "phone_preferred",
                "description": "Pain-related conditions and discomfort",
            },
            "INJURIES": {
                "department": "Urgent Care",
                "priority": "high",
                "estimated_response": "same day",
                "contact_method": "phone_immediate",
                "description": "Injuries, sprains, and trauma-related conditions",
            },
            "SKIN_CONDITIONS": {
                "department": "Dermatology",
                "priority": "normal",
                "estimated_response": "3-7 days",
                "contact_method": "standard_scheduling",
                "description": "Skin-related issues and conditions",
            },
            "STRUCTURAL_ISSUES": {
                "department": "Orthopedics",
                "priority": "normal",
                "estimated_response": "1-14 days",
                "contact_method": "standard_scheduling",
                "description": "Structural problems and musculoskeletal conditions",
            },
            "PROCEDURES": {
                "department": "Surgical Services",
                "priority": "normal",
                "estimated_response": "3-14 days",
                "contact_method": "scheduling_coordinator",
                "description": "Surgical consultations and procedures",
            },
        }

    def _load_classifier(self, medical_model_path: Optional[str]) -> None:
        """Create the embedding model and classifier head, loading weights if available."""
        try:
            self.embedding_model, self.classifier_head = get_models()

            # Load trained model if available
            if medical_model_path and os.path.exists(medical_model_path):
                import torch

                # map_location="cpu" lets a checkpoint saved on a GPU load on a
                # CPU-only host; load_state_dict then copies the values into
                # the head's existing parameters.
                state_dict = torch.load(
                    medical_model_path, map_location="cpu", weights_only=True
                )
                self.classifier_head.load_state_dict(state_dict)
                print(f"Loaded medical/insurance classifier from {medical_model_path}")
            else:
                if medical_model_path:
                    # A path was given but does not exist: say so instead of
                    # silently falling back to untrained weights.
                    print(f"Model file not found: {medical_model_path}")
                print("Using untrained medical/insurance classifier")
        except Exception as e:
            print(f"Error initializing medical/insurance classifier: {e}")
            raise

    def _build_retriever(self) -> Optional[Retriever]:
        """Build a Retriever over whichever default corpora exist; None if unavailable."""
        try:
            # Use default corpora configuration.
            # NOTE: these paths are relative to the current working directory,
            # not to REPO_ROOT.
            corpora_config = {
                "medical_qa": {
                    "path": "data/corpora/medical_qa.jsonl",
                    "text_fields": ["question", "answer", "title"],
                },
                "miriad": {
                    "path": "data/corpora/miriad_text.jsonl",
                    "text_fields": ["text", "title"],
                },
            }

            # Only use available corpora
            available_config = {
                name: config
                for name, config in corpora_config.items()
                if Path(config["path"]).exists()
            }

            if available_config:
                retriever = Retriever(available_config)
                print(f"Retrieval system initialized with {len(available_config)} corpora")
                return retriever

            print("No corpora files found. Retrieval disabled.")
        except Exception as e:
            # Retrieval is optional: report the problem and continue without it.
            print(f"Could not initialize retrieval system: {e}")
        return None

    @staticmethod
    def _is_sequence_like(value: Any) -> bool:
        """Return True for lists/tuples and for array-likes with at least one dimension."""
        return isinstance(value, (list, tuple)) or getattr(value, "ndim", 0) >= 1

    def route_query(self, query: str, include_retrieval: bool = True) -> Dict:
        """
        Route a healthcare query through the classification and routing system.

        Args:
            query: The user's query text
            include_retrieval: Whether to include retrieval results for medical queries

        Returns:
            Dictionary with routing decision, confidence, and additional context.
            Keys: ``query``, ``primary_classification``, ``confidence``,
            ``all_probabilities``, ``routing_decision``,
            ``reason_classification`` (``None`` for insurance queries),
            ``retrieval_results`` (``None`` when retrieval did not run) and
            ``recommendations``.
        """
        # Step 1: Medical vs Insurance classification
        medical_prediction = predict_query([query], self.embedding_model, self.classifier_head)

        # Extract prediction details.
        # NOTE(review): predict_query's return schema is not visible here; the
        # code below accepts both scalar and batch-of-one shapes — confirm.
        primary_category = medical_prediction['prediction'][0]

        confidence = medical_prediction['confidence']
        if self._is_sequence_like(confidence):
            confidence = confidence[0]

        probabilities = medical_prediction['probabilities']
        if self._is_sequence_like(probabilities[0]):
            # Nested result (one row per query): take the row for our single query.
            probabilities = probabilities[0]

        routing_result = {
            "query": query,
            "primary_classification": primary_category,
            "confidence": confidence,
            "all_probabilities": {
                category: float(probabilities[index])
                for index, category in enumerate(CATEGORIES)
            },
            "routing_decision": None,
            "reason_classification": None,
            "retrieval_results": None,
            "recommendations": [],
        }

        # Step 2: Route based on classification
        if primary_category.lower() == "medical":
            routing_decision, reason_classification = self._route_medical_query(
                query, include_retrieval
            )
            routing_result["routing_decision"] = routing_decision
            routing_result["reason_classification"] = reason_classification
            # Mirror retrieval output at the top level; it also remains inside
            # routing_decision for existing consumers.
            routing_result["retrieval_results"] = routing_decision.get("retrieval_results")
        else:
            routing_result["routing_decision"] = self._route_insurance_query()

        # Step 3: Add contextual recommendations
        routing_result["recommendations"] = self._generate_recommendations(
            primary_category,
            confidence,
            routing_result.get("reason_classification"),
        )

        return routing_result

    def _route_medical_query(self, query: str, include_retrieval: bool = True) -> Tuple[Dict, Dict]:
        """
        Route medical queries through reason classification.

        Args:
            query: The user's query text
            include_retrieval: Whether to attach retrieval results to the routing record

        Returns:
            Tuple of (routing record copy, reason classification details). The
            routing record gains a ``retrieval_results`` key only when
            retrieval was requested and a retriever is available.
        """
        # Get reason classification
        try:
            reason_result = predict_single_reason(query)
            reason_category = reason_result['category']
            reason_confidence = reason_result['confidence']
            reason_probabilities = reason_result['probabilities']
        except Exception as e:
            print(f"Reason classification failed: {e}")
            # Fallback to general medical routing
            reason_category = self.DEFAULT_REASON_CATEGORY
            reason_confidence = 0.5
            reason_probabilities = {}

        # Get department routing based on reason; copy so per-query additions
        # never leak into the shared routing table.
        routing = self.medical_department_routing.get(
            reason_category,
            self.medical_department_routing[self.DEFAULT_REASON_CATEGORY],
        ).copy()

        # Add reason classification details
        reason_classification = {
            "category": reason_category,
            "confidence": reason_confidence,
            "probabilities": reason_probabilities,
        }

        # Add retrieval results if available and requested
        if include_retrieval and self.retriever:
            try:
                routing["retrieval_results"] = self.retriever.retrieve(query, k=5, for_ui=True)
            except Exception as e:
                # Retrieval is supplementary: routing still succeeds without it.
                print(f"Retrieval failed: {e}")
                routing["retrieval_results"] = []

        return routing, reason_classification

    def _route_insurance_query(self) -> Dict:
        """Route insurance queries to insurance department (returns a copy of the record)."""
        return self.insurance_routing.copy()

    def _generate_recommendations(
        self,
        primary_category: str,
        confidence: float,
        reason_classification: Optional[Dict] = None,
    ) -> List[str]:
        """
        Generate contextual recommendations based on classification.

        Args:
            primary_category: Primary classification label (compared case-insensitively)
            confidence: Confidence of the primary classification
            reason_classification: Reason details for medical queries, if any

        Returns:
            List of recommendation strings; the low-confidence warning, when
            present, is always first.
        """
        recommendations = []

        # Low confidence warning
        if confidence < self.LOW_CONFIDENCE_THRESHOLD:
            recommendations.append(
                "Classification confidence is low. Consider manual review or "
                "asking the user to clarify their request."
            )

        # Category-specific recommendations
        category = primary_category.lower()
        if category == "medical":
            recommendations.extend([
                "Consider asking follow-up questions about symptoms",
                "Verify if this requires immediate attention",
                "Check if patient has existing appointments or conditions",
            ])

            # Reason-specific recommendations
            if reason_classification:
                extra = self._REASON_RECOMMENDATIONS.get(reason_classification.get('category'))
                if extra is not None:
                    recommendations.append(extra)
        elif category == "insurance":
            recommendations.extend([
                "Have patient account information ready",
                "Verify current insurance information and benefits",
                "Prepare to explain coverage details and requirements",
            ])

        return recommendations

    def batch_route_queries(self, queries: List[str], include_retrieval: bool = True) -> List[Dict]:
        """
        Route multiple queries, one at a time, preserving input order.

        Args:
            queries: Query texts to route
            include_retrieval: Forwarded to ``route_query`` for every query

        Returns:
            One ``route_query`` result per input query.
        """
        return [self.route_query(query, include_retrieval) for query in queries]

    def get_routing_statistics(self, queries: List[str]) -> Dict:
        """
        Analyze routing patterns for a batch of queries.

        Args:
            queries: Query texts to route and summarize

        Returns:
            Dictionary with ``total_queries``, ``primary_distribution``,
            ``reason_distribution``, ``average_confidence``,
            ``low_confidence_queries``, ``primary_percentages`` and
            ``reason_percentages``. Both percentage maps are relative to the
            total number of queries. For an empty batch the average is 0.0 and
            all distributions are empty.
        """
        # Retrieval output is not used by any statistic, so skip that work.
        results = self.batch_route_queries(queries, include_retrieval=False)
        total = len(results)

        # Count categories (Counter keeps first-seen order, like the dicts it replaces)
        primary_counts = Counter(result["primary_classification"] for result in results)
        # Reason classification counts (only medical queries carry one)
        reason_counts = Counter(
            result["reason_classification"]["category"]
            for result in results
            if result["reason_classification"]
        )
        confidence_scores = [result["confidence"] for result in results]

        def as_percentages(counts: Dict[str, int]) -> Dict[str, float]:
            # counts is empty whenever total is 0, so no division by zero occurs.
            return {cat: (count / total) * 100 for cat, count in counts.items()}

        return {
            "total_queries": total,
            "primary_distribution": dict(primary_counts),
            "reason_distribution": dict(reason_counts),
            "average_confidence": sum(confidence_scores) / total if total else 0.0,
            "low_confidence_queries": sum(
                1 for score in confidence_scores if score < self.LOW_CONFIDENCE_THRESHOLD
            ),
            "primary_percentages": as_percentages(primary_counts),
            "reason_percentages": as_percentages(reason_counts),
        }


def demo_router():
    """Demonstrate the query router functionality by routing sample queries and printing results."""
    print("Initializing Healthcare Query Router...")
    router = HealthcareQueryRouter()

    # Test queries covering different categories
    test_queries = [
        # Insurance queries
        "My insurance claim was denied, can you help?",
        "What does my insurance cover for this procedure?",
        "I need to verify my insurance benefits",
        # Medical queries - different reasons
        "I have heel pain when I walk",  # PAIN_CONDITIONS
        "I need routine foot care",  # ROUTINE_CARE
        "I sprained my ankle playing sports",  # INJURIES
        "My toenail is ingrown and infected",  # SKIN_CONDITIONS
        "I have flat feet and need evaluation",  # STRUCTURAL_ISSUES
        "I need a cortisone injection",  # PROCEDURES
    ]

    print(f"\nRouting {len(test_queries)} test queries...\n")

    for i, query in enumerate(test_queries, 1):
        print(f"Query {i}: {query}")
        result = router.route_query(query)

        print(f" Primary: {result['primary_classification']} "
              f"(confidence: {result['confidence']:.3f})")

        if result['reason_classification']:
            print(f" Reason: {result['reason_classification']['category']} "
                  f"(confidence: {result['reason_classification']['confidence']:.3f})")

        print(f" Department: {result['routing_decision']['department']}")
        print(f" Priority: {result['routing_decision']['priority']}")
        print(f" Response Time: {result['routing_decision']['estimated_response']}")

        if result['recommendations']:
            print(f" Recommendation: {result['recommendations'][0]}")
        print()

    # Show routing statistics
    print("Routing Statistics:")
    stats = router.get_routing_statistics(test_queries)

    print("Primary Classification:")
    for category, percentage in stats['primary_percentages'].items():
        print(f" {category}: {percentage:.1f}%")

    if stats['reason_percentages']:
        print("Reason Classification:")
        for category, percentage in stats['reason_percentages'].items():
            print(f" {category}: {percentage:.1f}%")

    print(f"Average Confidence: {stats['average_confidence']:.3f}")
    print(f"Low Confidence Queries: {stats['low_confidence_queries']}")


if __name__ == "__main__":
    demo_router()