Spaces:
Sleeping
Sleeping
| """ | |
| Inference module for Healthcare Reason Classification | |
| This module provides inference for the reason classification system, | |
| separate from the medical/insurance classifier. | |
| """ | |
| from ..head import ClassifierHead | |
| from datetime import datetime | |
| import os | |
| import pprint | |
| import torch | |
| from sentence_transformers import SentenceTransformer | |
| # Reason-specific configuration | |
| REASON_CATEGORIES = { | |
| 0: "ROUTINE_CARE", | |
| 1: "PAIN_CONDITIONS", | |
| 2: "INJURIES", | |
| 3: "SKIN_CONDITIONS", | |
| 4: "STRUCTURAL_ISSUES", | |
| 5: "PROCEDURES" | |
| } | |
| REASON_CHECKPOINT_PATH = "classifier/reason_checkpoints" | |
| DATETIME_FORMAT = "%Y%m%d_%H%M%S" | |
| MODEL_NAME = "sentence-transformers/embeddinggemma-300m-medical" | |
| def get_device(): | |
| """Get the best available device for inference.""" | |
| if torch.backends.mps.is_available(): | |
| return torch.device("mps") | |
| elif torch.cuda.is_available(): | |
| return torch.device("cuda") | |
| else: | |
| return torch.device("cpu") | |
| DEVICE = get_device() | |
| def get_reason_models(): | |
| """Get the embedding model and classifier head for reason inference.""" | |
| # Load embedding model | |
| embedding_model = SentenceTransformer( | |
| MODEL_NAME, | |
| prompts={ | |
| 'classification': 'task: healthcare reason classification | query: ', | |
| 'retrieval (query)': 'task: search result | query: ', | |
| 'retrieval (document)': 'title: {title | "none"} | text: ', | |
| }, | |
| default_prompt_name='classification', | |
| ) | |
| # Load classifier head (for 6 reason categories) | |
| classifier_head = ClassifierHead(len(REASON_CATEGORIES)) | |
| return embedding_model.to(DEVICE), classifier_head.to(DEVICE) | |
| def predict_reason_query( | |
| text: list[str], | |
| embedding_model: SentenceTransformer, | |
| classifier_head: ClassifierHead, | |
| ) -> dict: | |
| """ | |
| Runs the full inference pipeline for reason classification: Text -> Embedding -> Classification. | |
| """ | |
| # Set models to evaluation mode | |
| embedding_model.eval() | |
| classifier_head.eval() | |
| with torch.no_grad(): | |
| # Embed the text | |
| embeddings = embedding_model.encode( | |
| text, | |
| convert_to_tensor=True, | |
| device=DEVICE | |
| ).to(DEVICE) | |
| # Calculate probabilities and prediction | |
| probabilities = classifier_head.predict_proba(embeddings) | |
| # Get the predicted index and confidence | |
| predicted_indices = torch.argmax(probabilities, dim=1) | |
| # Convert tensors to Python types safely | |
| if predicted_indices.dim() == 0: # Single prediction | |
| predicted_indices = [predicted_indices.item()] | |
| else: | |
| predicted_indices = predicted_indices.cpu().tolist() | |
| # Get confidences | |
| confidences = [] | |
| for i, idx in enumerate(predicted_indices): | |
| conf = probabilities[i][idx].item() if probabilities.dim() > 1 else probabilities[idx].item() | |
| confidences.append(conf) | |
| # Get the predicted label names | |
| predicted_labels = [REASON_CATEGORIES[i] for i in predicted_indices] | |
| return { | |
| 'prediction': predicted_labels, | |
| 'confidence': confidences, | |
| 'probabilities': probabilities.cpu().tolist() | |
| } | |
| def predict_single_reason(query: str) -> dict: | |
| """Convenience function to predict a single reason query.""" | |
| try: | |
| embedding_model, classifier_head = get_reason_models() | |
| # Try to load the most recent trained checkpoint | |
| if os.path.exists(REASON_CHECKPOINT_PATH): | |
| for d in os.listdir(REASON_CHECKPOINT_PATH): | |
| if d.endswith('.pt'): | |
| checkpoint_path = f"{REASON_CHECKPOINT_PATH}/{d}" | |
| try: | |
| state_dict = torch.load(checkpoint_path, weights_only=True, map_location=DEVICE) | |
| classifier_head.load_state_dict(state_dict) | |
| print(f"Loaded trained weights from {checkpoint_path}") | |
| break | |
| except Exception as e: | |
| print(f"Could not load weights from {checkpoint_path}: {e}") | |
| result = predict_reason_query([query], embedding_model, classifier_head) | |
| # Extract values safely | |
| prediction = result['prediction'][0] if isinstance(result['prediction'], list) else str(result['prediction']) | |
| confidence = result['confidence'] if isinstance(result['confidence'], float) else (result['confidence'][0] if isinstance(result['confidence'], list) else float(result['confidence'])) | |
| # Handle probabilities - ensure it's a list | |
| probabilities = result['probabilities'] | |
| if isinstance(probabilities, list) and len(probabilities) > 0: | |
| if isinstance(probabilities[0], list): | |
| probabilities = probabilities[0] | |
| # Create probability dictionary | |
| prob_dict = {} | |
| for i, category in REASON_CATEGORIES.items(): | |
| if i < len(probabilities): | |
| prob_dict[category] = float(probabilities[i]) | |
| else: | |
| prob_dict[category] = 0.0 | |
| return { | |
| 'query': query, | |
| 'category': prediction, | |
| 'confidence': confidence, | |
| 'probabilities': prob_dict | |
| } | |
| except Exception as e: | |
| # Return a default classification if the model fails | |
| return { | |
| 'query': query, | |
| 'category': 'GENERAL_MEDICAL', | |
| 'confidence': 0.5, | |
| 'probabilities': {category: 1.0/len(REASON_CATEGORIES) for category in REASON_CATEGORIES.values()}, | |
| 'error': str(e) | |
| } | |
| def test_reason_classifier(): | |
| """Test the reason classifier with sample queries.""" | |
| latest = None | |
| path = "" | |
| # Try to load the most recent checkpoint | |
| if os.path.exists(REASON_CHECKPOINT_PATH): | |
| for d in os.listdir(REASON_CHECKPOINT_PATH): | |
| if d.endswith('.pt'): | |
| checkpoint_path = f"{REASON_CHECKPOINT_PATH}/{d}" | |
| print(f"Found checkpoint: {checkpoint_path}") | |
| path = checkpoint_path | |
| break | |
| if not path: | |
| print("No trained checkpoints found. Using untrained model.") | |
| else: | |
| print("No checkpoint directory found. Using untrained model.") | |
| embedding_model, classifier = get_reason_models() | |
| # Load trained weights if available | |
| if path and os.path.exists(path): | |
| try: | |
| state_dict = torch.load(path, weights_only=True, map_location=DEVICE) | |
| classifier.load_state_dict(state_dict) | |
| print(f"Loaded trained weights from {path}") | |
| except Exception as e: | |
| print(f"Could not load weights: {e}. Using untrained model.") | |
| # Test queries for reason classification | |
| queries = [ | |
| "I have heel pain when I walk", | |
| "My toenail is ingrown and painful", | |
| "I need routine foot care", | |
| "I sprained my ankle playing sports", | |
| "I have plantar fasciitis", | |
| "I need a cortisone injection" | |
| ] | |
| print("\nTesting reason classification:") | |
| pred = predict_reason_query( | |
| text=queries, | |
| embedding_model=embedding_model, | |
| classifier_head=classifier, | |
| ) | |
| pprint.pprint(pred, indent=4) | |
| if __name__ == "__main__": | |
| test_reason_classifier() |