Spaces:
Sleeping
Sleeping
| """ | |
| Healthcare Reason for Visit Classifier | |
| This module implements a classifier for healthcare clinic queries | |
| using real healthcare data from clinic appointment records. | |
| Categories based on the actual data: | |
| - ROUTINE_CARE: Routine care, maintenance visits | |
| - PAIN_CONDITIONS: Various pain-related conditions | |
| - INJURIES: Sprains, wounds, trauma-related visits | |
| - SKIN_CONDITIONS: Skin-related conditions and issues | |
| - STRUCTURAL_ISSUES: Structural problems and conditions | |
| - PROCEDURES: Injections, surgical consults, postop care | |
| """ | |
| import os | |
| import torch | |
| import pandas as pd | |
| import numpy as np | |
| from typing import List, Dict, Tuple, Optional | |
| from sentence_transformers import SentenceTransformer | |
| from setfit import SetFitModel | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.metrics import classification_report, confusion_matrix | |
| from datasets import Dataset | |
| import json | |
| from ..head import ClassifierHead | |
| # Healthcare reason categories based on real data analysis | |
| REASON_CATEGORIES = { | |
| 0: "ROUTINE_CARE", | |
| 1: "PAIN_CONDITIONS", | |
| 2: "INJURIES", | |
| 3: "SKIN_CONDITIONS", | |
| 4: "STRUCTURAL_ISSUES", | |
| 5: "PROCEDURES" | |
| } | |
| CATEGORY_DESCRIPTIONS = { | |
| "ROUTINE_CARE": "Routine healthcare, maintenance visits, general care", | |
| "PAIN_CONDITIONS": "Various pain-related conditions and discomfort", | |
| "INJURIES": "Sprains, wounds, trauma-related conditions", | |
| "SKIN_CONDITIONS": "Skin-related issues and conditions", | |
| "STRUCTURAL_ISSUES": "Structural problems and related conditions", | |
| "PROCEDURES": "Injections, surgical consultations, post-operative care" | |
| } | |
| class ReasonClassifier: | |
| """ | |
| Healthcare Reason Classifier that uses real clinic data to classify | |
| patient queries into specific healthcare reason categories. | |
| """ | |
| def __init__(self, data_file: str = "data/reason_for_visit_data.xlsx"): | |
| self.model_name = "sentence-transformers/embeddinggemma-300m-medical" | |
| self.num_classes = len(REASON_CATEGORIES) | |
| self.categories = REASON_CATEGORIES | |
| self.data_file = data_file | |
| self.model = None | |
| self.device = self._get_device() | |
| # Load and process real data | |
| self.healthcare_df = self._load_data() | |
| self._initialize_model() | |
| def _get_device(self): | |
| """Get the best available device for training/inference.""" | |
| if torch.backends.mps.is_available(): | |
| return torch.device("mps") | |
| elif torch.cuda.is_available(): | |
| return torch.device("cuda") | |
| else: | |
| return torch.device("cpu") | |
| def _load_data(self) -> pd.DataFrame: | |
| """Load the real healthcare dataset.""" | |
| try: | |
| df = pd.read_excel(self.data_file) | |
| print(f"Loaded {len(df)} healthcare records from {self.data_file}") | |
| print(f"Unique reasons: {df['Reason For Visit'].nunique()}") | |
| return df | |
| except Exception as e: | |
| print(f"Error loading data: {e}") | |
| raise RuntimeError(f"Failed to load healthcare data from {self.data_file}") | |
| def _initialize_model(self): | |
| """Initialize the model with the existing infrastructure.""" | |
| try: | |
| model_body = SentenceTransformer( | |
| self.model_name, | |
| prompts={ | |
| 'classification': 'task: healthcare reason classification | query: ', | |
| 'retrieval (query)': 'task: search result | query: ', | |
| 'retrieval (document)': 'title: {title | "none"} | text: ', | |
| }, | |
| default_prompt_name='classification', | |
| ) | |
| model_head = ClassifierHead(self.num_classes, embedding_dim=768) | |
| self.model = SetFitModel(model_body, model_head) | |
| self.model.freeze("body") # Freeze embedding weights | |
| self.model = self.model.to(self.device) | |
| print(f"Initialized ReasonClassifier on {self.device}") | |
| except Exception as e: | |
| print(f"Error initializing model: {e}") | |
| raise RuntimeError("Failed to initialize reason classifier") | |
| def _map_reason_to_category(self, reason: str) -> int: | |
| """ | |
| Map real healthcare reasons to categories using keyword matching. | |
| Based on the actual data distribution. | |
| """ | |
| reason_lower = reason.lower() | |
| # ROUTINE_CARE (routine foot care, nail care, calluses) | |
| if any(word in reason_lower for word in ['routine', 'nail care', 'calluses']): | |
| return 0 | |
| # PAIN_CONDITIONS (heel pain, ankle pain, foot pain, etc.) | |
| if any(word in reason_lower for word in ['pain', 'ache', 'sore']): | |
| return 1 | |
| # INJURIES (ankle sprain, wounds, trauma) | |
| if any(word in reason_lower for word in ['sprain', 'wound', 'injury', 'trauma']): | |
| return 2 | |
| # SKIN_CONDITIONS (ingrown toenail, calluses, skin issues) | |
| if any(word in reason_lower for word in ['ingrown', 'toenail', 'callus', 'skin']): | |
| return 3 | |
| # STRUCTURAL_ISSUES (flat feet, plantar fasciitis, achilles) | |
| if any(word in reason_lower for word in ['flat feet', 'plantar', 'fasciitis', 'achilles', 'tendon']): | |
| return 4 | |
| # PROCEDURES (injection, surgical consult, postop) | |
| if any(word in reason_lower for word in ['injection', 'surgical', 'consult', 'postop', 'procedure']): | |
| return 5 | |
| # Default to pain conditions (most common category) | |
| return 1 | |
| def create_real_dataset(self) -> pd.DataFrame: | |
| """ | |
| Create training dataset from real healthcare data. | |
| """ | |
| training_data = [] | |
| for _, row in self.healthcare_df.iterrows(): | |
| reason = row['Reason For Visit'] | |
| appointment_type = row['Appointment Type'] | |
| # Map reason to category | |
| category_id = self._map_reason_to_category(reason) | |
| # Create enhanced text with context | |
| enhanced_text = reason | |
| if pd.notna(appointment_type): | |
| enhanced_text += f" | {appointment_type}" | |
| training_data.append({ | |
| 'text': enhanced_text, | |
| 'label': category_id, | |
| 'category': self.categories[category_id], | |
| 'original_reason': reason, | |
| 'appointment_type': appointment_type | |
| }) | |
| df = pd.DataFrame(training_data) | |
| # Show category distribution | |
| print("\nCategory distribution in training data:") | |
| for cat_id, cat_name in self.categories.items(): | |
| count = len(df[df['label'] == cat_id]) | |
| percentage = (count / len(df)) * 100 | |
| print(f" {cat_name}: {count} samples ({percentage:.1f}%)") | |
| return df.sample(frac=1).reset_index(drop=True) # Shuffle | |
| def train(self, train_data: pd.DataFrame = None, eval_data: Optional[pd.DataFrame] = None, | |
| epochs: int = 16, output_dir: str = "classifier/reason_checkpoints"): | |
| """Train the healthcare reason classifier.""" | |
| if train_data is None: | |
| train_data = self.create_real_dataset() | |
| if eval_data is None: | |
| train_data, eval_data = train_test_split(train_data, test_size=0.2, | |
| stratify=train_data['label'], | |
| random_state=42) | |
| train_dataset = Dataset.from_pandas(train_data) | |
| eval_dataset = Dataset.from_pandas(eval_data) | |
| from setfit import Trainer, TrainingArguments | |
| args = TrainingArguments( | |
| output_dir=output_dir, | |
| num_epochs=(0, epochs), # Skip contrastive learning, only train head | |
| eval_strategy='epoch', | |
| eval_steps=100, | |
| save_strategy='epoch', | |
| logging_steps=50, | |
| ) | |
| trainer = Trainer( | |
| model=self.model, | |
| train_dataset=train_dataset, | |
| eval_dataset=eval_dataset, | |
| metric='accuracy', | |
| column_mapping={"text": "text", "label": "label"}, | |
| args=args, | |
| ) | |
| print("Starting training...") | |
| trainer.train() | |
| # Evaluate | |
| metrics = trainer.evaluate(eval_dataset) | |
| print(f"Training completed. Final metrics: {metrics}") | |
| return metrics | |
| def predict(self, queries: List[str]) -> List[Dict]: | |
| """ | |
| Predict healthcare reason categories for a list of queries. | |
| Returns: | |
| List of dictionaries with 'query', 'category', 'confidence', 'probabilities' | |
| """ | |
| if not self.model: | |
| raise RuntimeError("Model not initialized. Train or load a model first.") | |
| predictions = [] | |
| for query in queries: | |
| # Get prediction using SetFit's built-in methods | |
| pred_label = self.model.predict([query])[0] | |
| pred_proba = self.model.predict_proba([query])[0] | |
| category = self.categories[int(pred_label)] | |
| confidence = float(pred_proba[int(pred_label)]) | |
| predictions.append({ | |
| 'query': query, | |
| 'category': category, | |
| 'confidence': confidence, | |
| 'probabilities': {self.categories[i]: float(prob) | |
| for i, prob in enumerate(pred_proba)} | |
| }) | |
| return predictions | |
| def save_model(self, path: str): | |
| """Save the trained model.""" | |
| os.makedirs(os.path.dirname(path), exist_ok=True) | |
| self.model.save_pretrained(path) | |
| # Save category mapping | |
| with open(os.path.join(path, 'categories.json'), 'w') as f: | |
| json.dump(self.categories, f) | |
| print(f"Model saved to {path}") | |
| def load_model(self, path: str): | |
| """Load a trained model.""" | |
| self.model = SetFitModel.from_pretrained(path) | |
| self.model = self.model.to(self.device) | |
| # Load category mapping | |
| with open(os.path.join(path, 'categories.json'), 'r') as f: | |
| self.categories = {int(k): v for k, v in json.load(f).items()} | |
| print(f"Model loaded from {path}") | |
| def evaluate_on_test_set(self, test_data: pd.DataFrame) -> Dict: | |
| """Evaluate the model on a test dataset.""" | |
| predictions = self.predict(test_data['text'].tolist()) | |
| y_true = test_data['label'].tolist() | |
| y_pred = [list(self.categories.keys())[list(self.categories.values()).index(p['category'])] | |
| for p in predictions] | |
| # Classification report | |
| report = classification_report(y_true, y_pred, | |
| target_names=list(self.categories.values()), | |
| output_dict=True) | |
| # Confusion matrix | |
| cm = confusion_matrix(y_true, y_pred) | |
| return { | |
| 'classification_report': report, | |
| 'confusion_matrix': cm.tolist(), | |
| 'accuracy': report['accuracy'] | |
| } | |
| def analyze_real_data(self): | |
| """Analyze the real healthcare data to understand patterns.""" | |
| print("Real Data Analysis:") | |
| print("=" * 50) | |
| print(f"Total records: {len(self.healthcare_df)}") | |
| print(f"Unique reasons: {self.healthcare_df['Reason For Visit'].nunique()}") | |
| print("\nTop 15 reasons for visit:") | |
| top_reasons = self.healthcare_df['Reason For Visit'].value_counts().head(15) | |
| for reason, count in top_reasons.items(): | |
| category_id = self._map_reason_to_category(reason) | |
| category_name = self.categories[category_id] | |
| print(f" {reason}: {count} ({category_name})") | |
| print(f"\nAppointment types:") | |
| print(self.healthcare_df['Appointment Type'].value_counts()) | |
| def main(): | |
| """Example usage and training script for healthcare reason data.""" | |
| print("Initializing Healthcare Reason Classifier...") | |
| # Initialize classifier with real data | |
| classifier = ReasonClassifier() | |
| # Analyze the real data | |
| classifier.analyze_real_data() | |
| # Create training dataset from real data | |
| print("\nCreating training dataset from real healthcare data...") | |
| dataset = classifier.create_real_dataset() | |
| print(f"Dataset created with {len(dataset)} real examples") | |
| # Train the model | |
| print("\nTraining classifier...") | |
| metrics = classifier.train(dataset, epochs=20) | |
| # Save the model | |
| model_path = "classifier/reason_model" | |
| classifier.save_model(model_path) | |
| # Test predictions on healthcare reason queries | |
| test_queries = [ | |
| "I have heel pain when I walk", | |
| "My toenail is ingrown and painful", | |
| "I need routine foot care", | |
| "I sprained my ankle playing sports", | |
| "I have flat feet and need evaluation", | |
| "I need a cortisone injection for my foot", | |
| "I have plantar fasciitis", | |
| "My foot wound is not healing" | |
| ] | |
| print("\nTesting predictions on healthcare reason queries:") | |
| predictions = classifier.predict(test_queries) | |
| for pred in predictions: | |
| print(f"Query: {pred['query']}") | |
| print(f"Category: {pred['category']} (confidence: {pred['confidence']:.3f})") | |
| print("---") | |
| if __name__ == "__main__": | |
| main() |