""" Training script for Healthcare Reason Classification This script trains a classifier for healthcare visit reasons using real healthcare data. It creates a separate system from the medical/insurance classifier. """ from sentence_transformers import SentenceTransformer from setfit import SetFitModel, Trainer, TrainingArguments import sys from pathlib import Path # Add project root to path for imports REPO_ROOT = Path(__file__).resolve().parents[2] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) from classifier.head import ClassifierHead import os import pandas as pd from sklearn.model_selection import train_test_split from datasets import Dataset import torch from pathlib import Path from datetime import datetime # Reason-specific configuration REASON_CATEGORIES = { 0: "ROUTINE_CARE", 1: "PAIN_CONDITIONS", 2: "INJURIES", 3: "SKIN_CONDITIONS", 4: "STRUCTURAL_ISSUES", 5: "PROCEDURES" } REASON_CHECKPOINT_PATH = "classifier/reason_checkpoints" HEALTHCARE_DATA_PATH = "data/reason_for_visit_data.xlsx" MODEL_NAME = "sentence-transformers/embeddinggemma-300m-medical" def get_device(): """Get the best available device for training/inference.""" if torch.backends.mps.is_available(): return torch.device("mps") elif torch.cuda.is_available(): return torch.device("cuda") else: return torch.device("cpu") def get_reason_model(num_classes: int): """Get model for reason classification.""" try: model_body = SentenceTransformer( MODEL_NAME, prompts={ 'classification': 'task: healthcare reason classification | query: ', 'retrieval (query)': 'task: search result | query: ', 'retrieval (document)': 'title: {title | "none"} | text: ', }, default_prompt_name='classification', ) # Freeze weights of embedding model model_head = ClassifierHead(num_classes) model = SetFitModel(model_body, model_head) model.freeze("body") except Exception as e: print(f"Error loading model {MODEL_NAME}: {e}") raise RuntimeError("Failed to load the embedding model.") device = get_device() print(f"Using device: {device}") return model.to(device) def get_reason_dataset() -> pd.DataFrame: """Load the healthcare reason dataset from Excel file.""" try: if not os.path.exists(HEALTHCARE_DATA_PATH): raise FileNotFoundError(f"Healthcare data file not found: {HEALTHCARE_DATA_PATH}") print(f"Loading healthcare data from {HEALTHCARE_DATA_PATH}...") df = pd.read_excel(HEALTHCARE_DATA_PATH) print(f"Loaded {len(df)} healthcare records") return df except Exception as e: print(f"Error loading healthcare dataset: {e}") raise Exception(f"Failed to load healthcare data: {e}") def map_reason_to_category(reason: str) -> int: """Map healthcare reasons to categories using keyword matching.""" reason_lower = reason.lower() # ROUTINE_CARE (routine care, maintenance visits) if any(word in reason_lower for word in ['routine', 'nail care', 'calluses', 'maintenance']): return 0 # PAIN_CONDITIONS (various pain-related conditions) elif any(word in reason_lower for word in ['pain', 'ache', 'sore', 'hurt']): return 1 # INJURIES (sprains, wounds, trauma) elif any(word in reason_lower for word in ['sprain', 'wound', 'injury', 'trauma', 'cut', 'bruise']): return 2 # SKIN_CONDITIONS (skin-related issues) elif any(word in reason_lower for word in ['ingrown', 'toenail', 'callus', 'corn', 'skin']): return 3 # STRUCTURAL_ISSUES (structural problems) elif any(word in reason_lower for word in ['flat feet', 'plantar', 'fasciitis', 'achilles', 'tendon', 'arch']): return 4 # PROCEDURES (injections, surgical consultations) elif any(word in reason_lower for word in ['injection', 'surgical', 'consult', 'postop', 'surgery', 'procedure']): return 5 # Default to pain conditions (most common category) else: return 1 def preprocess_reason_data(df: pd.DataFrame) -> pd.DataFrame: """Preprocess the healthcare reason dataset for training.""" training_data = [] for _, row in df.iterrows(): reason = row['Reason For Visit'] appointment_type = row.get('Appointment Type', '') # Map reason to category using keyword matching category_id = map_reason_to_category(reason) # Create enhanced text with context enhanced_text = reason if pd.notna(appointment_type) and appointment_type: enhanced_text += f" | {appointment_type}" training_data.append({ 'text': enhanced_text, 'label': category_id, 'category': REASON_CATEGORIES[category_id], 'original_reason': reason }) processed_df = pd.DataFrame(training_data) # Show category distribution print("\nReason category distribution in training data:") for cat_id, cat_name in REASON_CATEGORIES.items(): count = len(processed_df[processed_df['label'] == cat_id]) percentage = (count / len(processed_df)) * 100 print(f" {cat_name}: {count} samples ({percentage:.1f}%)") return processed_df def main(): print("Healthcare Reason Classification - Training Pipeline") print("=" * 60) # Load and preprocess data df = get_reason_dataset() df = preprocess_reason_data(df) # Get model model = get_reason_model(len(REASON_CATEGORIES)) # Split data train, test = train_test_split( df, test_size=0.2, stratify=df['label'], random_state=42 ) print(f"\nData split:") print(f" Training: {len(train)} samples") print(f" Testing: {len(test)} samples") train_dataset = Dataset.from_pandas(train) test_dataset = Dataset.from_pandas(test) # Ensure checkpoint directory exists Path(REASON_CHECKPOINT_PATH).mkdir(parents=True, exist_ok=True) # Training arguments timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") output_dir = f"{REASON_CHECKPOINT_PATH}/training_{timestamp}" args = TrainingArguments( output_dir=output_dir, # Skip contrastive fine-tuning (body is frozen) num_epochs=(0, 20), eval_strategy='epoch', eval_steps=100, save_strategy='epoch', logging_steps=50, load_best_model_at_end=True, metric_for_best_model='accuracy', ) trainer = Trainer( model=model, train_dataset=train_dataset, eval_dataset=test_dataset, metric='accuracy', column_mapping={"text": "text", "label": "label"}, args=args, ) print("\nStarting reason classification training...") trainer.train() # Evaluate print("\nEvaluating reason classification model...") metrics = trainer.evaluate(test_dataset) print(f"Final evaluation metrics: {metrics}") # Save the trained classifier head model_save_path = f"{REASON_CHECKPOINT_PATH}/reason_classifier_head_{timestamp}.pt" torch.save(model.model_head.state_dict(), model_save_path) print(f"Reason classifier head saved to: {model_save_path}") return metrics if __name__ == "__main__": main()