# NOTE: residue from the hosting page's file viewer, kept as comments only
# (the viewer's line-number gutter has been dropped; it is not part of the module).
# Spaces: Sleeping
# File size: 16,304 Bytes
# b7f3196
"""
Query Router System
This module integrates the medical/insurance classifier with the reason
classification system to provide intelligent routing of healthcare portal queries.
The router first determines if a query is medical or insurance-related, then
routes accordingly:
- Insurance queries -> Direct to insurance department
- Medical queries -> Reason classification -> Appropriate medical department routing
"""
import os
import sys
from typing import Dict, List, Optional, Tuple
from pathlib import Path
# Add project root to path for imports.
# The root is taken to be the directory one level above this file's own
# directory (parents[1]); it is prepended to sys.path only if not already
# present, so the project-local imports that follow can resolve.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
from classifier.infer import predict_query
from classifier.utils import get_models, CATEGORIES
from classifier.reason import predict_single_reason
from retriever.search import Retriever
from team.candidates import get_candidates
class HealthcareQueryRouter:
    """
    Intelligent routing system for healthcare portal queries.

    Routes queries through a two-stage process:
    1. Medical vs Insurance classification
    2. For medical queries: Reason classification for department routing

    Attributes set by ``__init__``:
        embedding_model, classifier_head: objects returned by ``get_models()``.
        retriever: a ``Retriever`` instance, or ``None`` when retrieval is
            disabled, unavailable, or failed to initialize.
        insurance_routing: routing record returned for every insurance query.
        medical_department_routing: routing records keyed by reason category.
    """

    def __init__(self,
                 medical_model_path: Optional[str] = None,
                 use_retrieval: bool = True):
        """
        Initialize the query router.

        Args:
            medical_model_path: Path to trained medical/insurance classifier
                weights. When it is not given, or the path does not exist,
                the classifier head from ``get_models()`` is used as-is.
            use_retrieval: Whether to try to set up the retrieval system for
                medical queries. Retrieval setup is best-effort: failures are
                printed and leave ``self.retriever`` as ``None``.

        Raises:
            Exception: Any error raised while creating or loading the
                medical/insurance classifier is printed and re-raised.
        """
        # Initialize medical/insurance classifier
        try:
            self.embedding_model, self.classifier_head = get_models()
            # Load trained model if available
            if medical_model_path and os.path.exists(medical_model_path):
                # Imported lazily so torch is only required when weights are loaded.
                import torch
                state_dict = torch.load(medical_model_path, weights_only=True)
                self.classifier_head.load_state_dict(state_dict)
                print(f"Loaded medical/insurance classifier from {medical_model_path}")
            else:
                print("Using untrained medical/insurance classifier")
        except Exception as e:
            print(f"Error initializing medical/insurance classifier: {e}")
            raise

        # Initialize retrieval system if requested
        self.retriever = None
        if use_retrieval:
            try:
                # Use default corpora configuration.
                # NOTE(review): paths are relative to the current working
                # directory, not to the repository root - confirm intended.
                corpora_config = {
                    "medical_qa": {
                        "path": "data/corpora/medical_qa.jsonl",
                        "text_fields": ["question", "answer", "title"],
                    },
                    "miriad": {
                        "path": "data/corpora/miriad_text.jsonl",
                        "text_fields": ["text", "title"],
                    },
                }
                # Only use available corpora
                available_config = {
                    name: cfg
                    for name, cfg in corpora_config.items()
                    if Path(cfg["path"]).exists()
                }
                if available_config:
                    self.retriever = Retriever(available_config)
                    print(f"Retrieval system initialized with {len(available_config)} corpora")
                else:
                    print("No corpora files found. Retrieval disabled.")
            except Exception as e:
                # Retrieval is optional: report the problem and continue without it.
                print(f"Could not initialize retrieval system: {e}")

        # Routing rules for insurance queries
        self.insurance_routing = {
            "department": "Insurance Department",
            "priority": "normal",
            "estimated_response": "1-2 business days",
            "contact_method": "phone_or_email",
            "description": "Insurance coverage, claims, and benefits inquiries",
        }

        # Medical department routing based on reason categories
        self.medical_department_routing = {
            "ROUTINE_CARE": {
                "department": "Primary Care",
                "priority": "normal",
                "estimated_response": "1-7 days",
                "contact_method": "standard_scheduling",
                "description": "Routine healthcare and maintenance visits",
            },
            "PAIN_CONDITIONS": {
                "department": "Pain Management",
                "priority": "high",
                "estimated_response": "same day to 3 days",
                "contact_method": "phone_preferred",
                "description": "Pain-related conditions and discomfort",
            },
            "INJURIES": {
                "department": "Urgent Care",
                "priority": "high",
                "estimated_response": "same day",
                "contact_method": "phone_immediate",
                "description": "Injuries, sprains, and trauma-related conditions",
            },
            "SKIN_CONDITIONS": {
                "department": "Dermatology",
                "priority": "normal",
                "estimated_response": "3-7 days",
                "contact_method": "standard_scheduling",
                "description": "Skin-related issues and conditions",
            },
            "STRUCTURAL_ISSUES": {
                "department": "Orthopedics",
                "priority": "normal",
                "estimated_response": "1-14 days",
                "contact_method": "standard_scheduling",
                "description": "Structural problems and musculoskeletal conditions",
            },
            "PROCEDURES": {
                "department": "Surgical Services",
                "priority": "normal",
                "estimated_response": "3-14 days",
                "contact_method": "scheduling_coordinator",
                "description": "Surgical consultations and procedures",
            },
        }

    def route_query(self, query: str, include_retrieval: bool = True) -> Dict:
        """
        Route a healthcare query through the classification and routing system.

        Args:
            query: The user's query text
            include_retrieval: Whether to include retrieval results for medical queries

        Returns:
            Dictionary with keys ``query``, ``primary_classification``,
            ``confidence``, ``all_probabilities``, ``routing_decision``,
            ``reason_classification`` (``None`` for non-medical queries),
            ``retrieval_results`` (``None`` when no retrieval was performed)
            and ``recommendations``.
        """
        # Step 1: Medical vs Insurance classification
        medical_prediction = predict_query([query], self.embedding_model, self.classifier_head)

        # Extract prediction details.
        # NOTE(review): assumes 'prediction' is a per-query sequence - confirm
        # against predict_query.
        primary_category = medical_prediction['prediction'][0]
        raw_confidence = medical_prediction['confidence']
        confidence = raw_confidence if isinstance(raw_confidence, float) else raw_confidence[0]

        probabilities = medical_prediction['probabilities']
        # The query is submitted as a one-element batch, so the probabilities
        # may come back nested (one row per query). Unwrap that single row;
        # calling float() on the row itself would raise TypeError.
        if isinstance(probabilities[0], (list, tuple)):
            probabilities = probabilities[0]

        routing_result = {
            "query": query,
            "primary_classification": primary_category,
            "confidence": confidence,
            "all_probabilities": {
                category: float(probabilities[i])
                for i, category in enumerate(CATEGORIES)
            },
            "routing_decision": None,
            "reason_classification": None,
            "retrieval_results": None,
            "recommendations": [],
        }

        # Step 2: Route based on classification
        if primary_category.lower() == "medical":
            routing_decision, reason_classification = self._route_medical_query(
                query, include_retrieval
            )
            routing_result["routing_decision"] = routing_decision
            routing_result["reason_classification"] = reason_classification
            # Surface retrieval output in the top-level slot reserved for it;
            # it also stays inside routing_decision for existing callers.
            routing_result["retrieval_results"] = routing_decision.get("retrieval_results")
        else:
            routing_result["routing_decision"] = self._route_insurance_query()

        # Step 3: Add contextual recommendations
        routing_result["recommendations"] = self._generate_recommendations(
            primary_category, confidence, routing_result.get("reason_classification")
        )
        return routing_result

    def _route_medical_query(self, query: str, include_retrieval: bool = True) -> Tuple[Dict, Dict]:
        """
        Route medical queries through reason classification.

        Args:
            query: The user's query text
            include_retrieval: Whether to attach retrieval results when a
                retriever is available

        Returns:
            Tuple of (routing record copy, reason classification details).
            The routing record gains a ``retrieval_results`` key only when
            retrieval was attempted (an empty list if it failed).
        """
        # Get reason classification
        try:
            reason_result = predict_single_reason(query)
            reason_category = reason_result['category']
            reason_confidence = reason_result['confidence']
            reason_probabilities = reason_result['probabilities']
        except Exception as e:
            print(f"Reason classification failed: {e}")
            # Fallback to general medical routing
            reason_category = "ROUTINE_CARE"
            reason_confidence = 0.5
            reason_probabilities = {}

        # Get department routing based on reason; unknown categories fall
        # back to routine care. Copy so callers cannot mutate the shared table.
        routing = self.medical_department_routing.get(
            reason_category,
            self.medical_department_routing["ROUTINE_CARE"]
        ).copy()

        # Add reason classification details
        reason_classification = {
            "category": reason_category,
            "confidence": reason_confidence,
            "probabilities": reason_probabilities,
        }

        # Add retrieval results if available and requested
        if include_retrieval and self.retriever:
            try:
                retrieval_results = self.retriever.retrieve(query, k=5, for_ui=True)
                routing["retrieval_results"] = retrieval_results
            except Exception as e:
                # Retrieval is supplementary; never fail routing because of it.
                print(f"Retrieval failed: {e}")
                routing["retrieval_results"] = []

        return routing, reason_classification

    def _route_insurance_query(self) -> Dict:
        """Route insurance queries to insurance department (returns a copy of the record)."""
        return self.insurance_routing.copy()

    def _generate_recommendations(self, primary_category: str, confidence: float,
                                  reason_classification: Optional[Dict] = None) -> List[str]:
        """
        Generate contextual recommendations based on classification.

        Args:
            primary_category: Primary classification label; compared
                case-insensitively against "medical" and "insurance"
            confidence: Confidence of the primary classification; values
                below 0.7 add a manual-review warning
            reason_classification: Optional reason details; only its
                ``category`` entry is read, and only for medical queries

        Returns:
            List of recommendation strings (empty if nothing applies).
        """
        recommendations = []

        # Low confidence warning
        if confidence < 0.7:
            recommendations.append(
                "Classification confidence is low. Consider manual review or "
                "asking the user to clarify their request."
            )

        # Category-specific recommendations
        category = primary_category.lower()
        if category == "medical":
            recommendations.extend([
                "Consider asking follow-up questions about symptoms",
                "Verify if this requires immediate attention",
                "Check if patient has existing appointments or conditions",
            ])
            # Reason-specific recommendations
            if reason_classification:
                reason_category = reason_classification.get('category')
                if reason_category == "PAIN_CONDITIONS":
                    recommendations.append("Assess pain level and duration for urgency determination")
                elif reason_category == "INJURIES":
                    recommendations.append("Determine if immediate medical attention is required")
                elif reason_category == "PROCEDURES":
                    recommendations.append("Verify insurance pre-authorization requirements")
        elif category == "insurance":
            recommendations.extend([
                "Have patient account information ready",
                "Verify current insurance information and benefits",
                "Prepare to explain coverage details and requirements",
            ])

        return recommendations

    def batch_route_queries(self, queries: List[str]) -> List[Dict]:
        """Route multiple queries, one ``route_query`` call per query, preserving order."""
        return [self.route_query(query) for query in queries]

    def get_routing_statistics(self, queries: List[str]) -> Dict:
        """
        Analyze routing patterns for a batch of queries.

        Every query is routed via ``batch_route_queries`` and the results are
        aggregated.

        Args:
            queries: Query texts to route and summarize; may be empty

        Returns:
            Dictionary with ``total_queries``, ``primary_distribution``,
            ``reason_distribution``, ``average_confidence`` (0.0 for an empty
            batch), ``low_confidence_queries`` (confidence below 0.7),
            ``primary_percentages`` and ``reason_percentages``. Both
            percentage maps are relative to the total number of queries.
        """
        results = self.batch_route_queries(queries)
        total = len(queries)

        # Count categories
        primary_counts = {}
        reason_counts = {}
        confidence_scores = []
        for result in results:
            # Primary classification counts
            primary_category = result["primary_classification"]
            primary_counts[primary_category] = primary_counts.get(primary_category, 0) + 1
            confidence_scores.append(result["confidence"])

            # Reason classification counts (for medical queries)
            if result["reason_classification"]:
                reason_category = result["reason_classification"]["category"]
                reason_counts[reason_category] = reason_counts.get(reason_category, 0) + 1

        # Guard the mean: an empty batch would otherwise divide by zero.
        if confidence_scores:
            average_confidence = sum(confidence_scores) / len(confidence_scores)
        else:
            average_confidence = 0.0

        return {
            "total_queries": total,
            "primary_distribution": primary_counts,
            "reason_distribution": reason_counts,
            "average_confidence": average_confidence,
            "low_confidence_queries": len([c for c in confidence_scores if c < 0.7]),
            # The count maps are empty whenever total is 0, so these
            # comprehensions never divide by zero.
            "primary_percentages": {
                cat: (count / total) * 100
                for cat, count in primary_counts.items()
            },
            "reason_percentages": {
                cat: (count / total) * 100
                for cat, count in reason_counts.items()
            },
        }
def demo_router():
    """
    Demonstrate the query router functionality.

    Builds a router with default settings, routes a fixed set of sample
    queries one by one while printing each decision, then prints aggregate
    statistics for the same set.
    """
    print("Initializing Healthcare Query Router...")
    router = HealthcareQueryRouter()

    # Sample queries: three insurance questions followed by one medical
    # question per expected reason category.
    insurance_samples = [
        "My insurance claim was denied, can you help?",
        "What does my insurance cover for this procedure?",
        "I need to verify my insurance benefits",
    ]
    medical_samples = [
        "I have heel pain when I walk",          # expected: PAIN_CONDITIONS
        "I need routine foot care",              # expected: ROUTINE_CARE
        "I sprained my ankle playing sports",    # expected: INJURIES
        "My toenail is ingrown and infected",    # expected: SKIN_CONDITIONS
        "I have flat feet and need evaluation",  # expected: STRUCTURAL_ISSUES
        "I need a cortisone injection",          # expected: PROCEDURES
    ]
    test_queries = insurance_samples + medical_samples

    print(f"\nRouting {len(test_queries)} test queries...\n")
    for number, text in enumerate(test_queries, 1):
        print(f"Query {number}: {text}")
        outcome = router.route_query(text)

        print(f" Primary: {outcome['primary_classification']} "
              f"(confidence: {outcome['confidence']:.3f})")

        reason = outcome['reason_classification']
        if reason:
            print(f" Reason: {reason['category']} "
                  f"(confidence: {reason['confidence']:.3f})")

        decision = outcome['routing_decision']
        print(f" Department: {decision['department']}")
        print(f" Priority: {decision['priority']}")
        print(f" Response Time: {decision['estimated_response']}")

        suggestions = outcome['recommendations']
        if suggestions:
            # Only the first recommendation is shown to keep the demo short.
            print(f" Recommendation: {suggestions[0]}")
        print()

    # Show routing statistics (this routes the sample queries a second time)
    print("Routing Statistics:")
    stats = router.get_routing_statistics(test_queries)

    print("Primary Classification:")
    for label, share in stats['primary_percentages'].items():
        print(f" {label}: {share:.1f}%")

    reason_shares = stats['reason_percentages']
    if reason_shares:
        print("Reason Classification:")
        for label, share in reason_shares.items():
            print(f" {label}: {share:.1f}%")

    print(f"Average Confidence: {stats['average_confidence']:.3f}")
    print(f"Low Confidence Queries: {stats['low_confidence_queries']}")
# Run the demonstration only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo_router()