# Provenance: uploaded via huggingface_hub by user "taraky" (commit b7f3196, verified).
"""
Training script for Healthcare Reason Classification
This script trains a classifier for healthcare visit reasons using real
healthcare data. It creates a separate system from the medical/insurance
classifier.
"""
from sentence_transformers import SentenceTransformer
from setfit import SetFitModel, Trainer, TrainingArguments
import sys
from pathlib import Path
# Add project root to path for imports.
# Assumes this file lives two directory levels below the repo root
# (e.g. <root>/classifier/scripts/this_file.py) — TODO confirm; prepending
# the root lets `classifier.*` imports resolve when run as a script.
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
from classifier.head import ClassifierHead
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from datasets import Dataset
import torch
from pathlib import Path
from datetime import datetime
# Reason-specific configuration.
# Category ids must stay contiguous from 0: they double as classifier
# labels and as indices into this mapping.
REASON_CATEGORIES = {
    0: "ROUTINE_CARE",
    1: "PAIN_CONDITIONS",
    2: "INJURIES",
    3: "SKIN_CONDITIONS",
    4: "STRUCTURAL_ISSUES",
    5: "PROCEDURES"
}
# Where trained classifier-head checkpoints are written (relative to CWD).
REASON_CHECKPOINT_PATH = "classifier/reason_checkpoints"
# Excel source of raw visit records (relative to CWD).
HEALTHCARE_DATA_PATH = "data/reason_for_visit_data.xlsx"
# Embedding backbone; presumably a medical fine-tune of EmbeddingGemma —
# NOTE(review): confirm this model id exists on the Hub.
MODEL_NAME = "sentence-transformers/embeddinggemma-300m-medical"
def get_device():
    """Return the preferred torch device, trying MPS, then CUDA, then CPU."""
    # Probe accelerators in priority order; fall through to CPU.
    candidates = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for device_name, is_available in candidates:
        if is_available():
            return torch.device(device_name)
    return torch.device("cpu")
def get_reason_model(num_classes: int):
    """Build a SetFit model for healthcare reason classification.

    Loads the SentenceTransformer body with task-specific prompts, attaches
    a fresh ClassifierHead, and freezes the body so only the head trains.

    Args:
        num_classes: Number of output categories for the classifier head.

    Returns:
        A SetFitModel moved to the best available device.

    Raises:
        RuntimeError: If the embedding model cannot be loaded (chained to
            the original failure).
    """
    try:
        model_body = SentenceTransformer(
            MODEL_NAME,
            prompts={
                'classification': 'task: healthcare reason classification | query: ',
                'retrieval (query)': 'task: search result | query: ',
                'retrieval (document)': 'title: {title | "none"} | text: ',
            },
            default_prompt_name='classification',
        )
        # Freeze weights of embedding model: only the head is trainable.
        model_head = ClassifierHead(num_classes)
        model = SetFitModel(model_body, model_head)
        model.freeze("body")
    except Exception as e:
        print(f"Error loading model {MODEL_NAME}: {e}")
        # Chain the original exception so the root cause is preserved
        # instead of being discarded by a bare re-raise.
        raise RuntimeError("Failed to load the embedding model.") from e
    device = get_device()
    print(f"Using device: {device}")
    return model.to(device)
def get_reason_dataset() -> pd.DataFrame:
    """Load the healthcare reason dataset from the Excel file.

    Returns:
        DataFrame of raw healthcare visit records.

    Raises:
        FileNotFoundError: If the data file is missing (propagated directly
            instead of being re-wrapped in a generic Exception).
        RuntimeError: If the file exists but cannot be read/parsed, chained
            to the underlying error.
    """
    # Fail fast on a missing file; keep the specific exception type so
    # callers can distinguish "missing" from "unreadable".
    if not os.path.exists(HEALTHCARE_DATA_PATH):
        raise FileNotFoundError(f"Healthcare data file not found: {HEALTHCARE_DATA_PATH}")
    print(f"Loading healthcare data from {HEALTHCARE_DATA_PATH}...")
    try:
        # Keep the try body minimal: only the read can legitimately fail here.
        df = pd.read_excel(HEALTHCARE_DATA_PATH)
    except Exception as e:
        print(f"Error loading healthcare dataset: {e}")
        # Chain the cause instead of raising a bare generic Exception.
        raise RuntimeError(f"Failed to load healthcare data: {e}") from e
    print(f"Loaded {len(df)} healthcare records")
    return df
def map_reason_to_category(reason: str) -> int:
    """Map a visit-reason string to a category id via keyword matching.

    Categories are tested in a fixed priority order; the first category
    with a matching keyword substring wins. Reasons matching nothing
    default to PAIN_CONDITIONS (1), the most common class.
    """
    text = reason.lower()
    # (category_id, keyword substrings) in priority order.
    keyword_table = (
        (0, ('routine', 'nail care', 'calluses', 'maintenance')),                    # ROUTINE_CARE
        (1, ('pain', 'ache', 'sore', 'hurt')),                                       # PAIN_CONDITIONS
        (2, ('sprain', 'wound', 'injury', 'trauma', 'cut', 'bruise')),               # INJURIES
        (3, ('ingrown', 'toenail', 'callus', 'corn', 'skin')),                       # SKIN_CONDITIONS
        (4, ('flat feet', 'plantar', 'fasciitis', 'achilles', 'tendon', 'arch')),    # STRUCTURAL_ISSUES
        (5, ('injection', 'surgical', 'consult', 'postop', 'surgery', 'procedure')), # PROCEDURES
    )
    for category_id, keywords in keyword_table:
        if any(keyword in text for keyword in keywords):
            return category_id
    # Default to pain conditions (most common category).
    return 1
def preprocess_reason_data(df: pd.DataFrame) -> pd.DataFrame:
    """Preprocess the healthcare reason dataset for training.

    Maps each visit reason to a category id via keyword matching, appends
    the appointment type as extra context, and prints the resulting class
    distribution.

    Args:
        df: Raw records with a 'Reason For Visit' column and an optional
            'Appointment Type' column.

    Returns:
        DataFrame with columns 'text', 'label', 'category',
        'original_reason'.
    """
    training_data = []
    for _, row in df.iterrows():
        reason = row['Reason For Visit']
        # Skip records with a missing reason instead of crashing on
        # `float('nan').lower()` inside the keyword mapper.
        if pd.isna(reason):
            continue
        appointment_type = row.get('Appointment Type', '')
        # Map reason to category using keyword matching.
        category_id = map_reason_to_category(reason)
        # Create enhanced text with the appointment type as context.
        enhanced_text = reason
        if pd.notna(appointment_type) and appointment_type:
            enhanced_text += f" | {appointment_type}"
        training_data.append({
            'text': enhanced_text,
            'label': category_id,
            'category': REASON_CATEGORIES[category_id],
            'original_reason': reason
        })
    processed_df = pd.DataFrame(training_data)
    # Show category distribution; guard against an empty dataset, which
    # would otherwise raise KeyError ('label' missing) / ZeroDivisionError.
    total = len(processed_df)
    print("\nReason category distribution in training data:")
    for cat_id, cat_name in REASON_CATEGORIES.items():
        count = len(processed_df[processed_df['label'] == cat_id]) if total else 0
        percentage = (count / total) * 100 if total else 0.0
        print(f" {cat_name}: {count} samples ({percentage:.1f}%)")
    return processed_df
def main():
    """Run the full reason-classification training pipeline.

    Loads and preprocesses the Excel data, builds the frozen-body SetFit
    model, trains the classifier head, evaluates on a held-out split, and
    saves the head weights to a timestamped checkpoint file.

    Returns:
        The evaluation metrics from ``trainer.evaluate`` on the test split.
    """
    print("Healthcare Reason Classification - Training Pipeline")
    print("=" * 60)
    # Load and preprocess data
    df = get_reason_dataset()
    df = preprocess_reason_data(df)
    # Get model
    model = get_reason_model(len(REASON_CATEGORIES))
    # Stratified split keeps per-category ratios identical in both
    # partitions; fixed seed makes the split reproducible.
    train, test = train_test_split(
        df, test_size=0.2, stratify=df['label'], random_state=42
    )
    print(f"\nData split:")
    print(f" Training: {len(train)} samples")
    print(f" Testing: {len(test)} samples")
    train_dataset = Dataset.from_pandas(train)
    test_dataset = Dataset.from_pandas(test)
    # Ensure checkpoint directory exists
    Path(REASON_CHECKPOINT_PATH).mkdir(parents=True, exist_ok=True)
    # Training arguments — timestamped output dir so repeated runs never
    # overwrite each other.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = f"{REASON_CHECKPOINT_PATH}/training_{timestamp}"
    args = TrainingArguments(
        output_dir=output_dir,
        # Skip contrastive fine-tuning (body is frozen): the tuple is
        # presumably (embedding epochs, classifier epochs) — TODO confirm
        # against the setfit TrainingArguments docs.
        num_epochs=(0, 20),
        eval_strategy='epoch',
        # NOTE(review): eval_steps is likely ignored when eval_strategy is
        # 'epoch' — verify and drop if redundant.
        eval_steps=100,
        save_strategy='epoch',
        logging_steps=50,
        load_best_model_at_end=True,
        metric_for_best_model='accuracy',
    )
    trainer = Trainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        metric='accuracy',
        column_mapping={"text": "text", "label": "label"},
        args=args,
    )
    print("\nStarting reason classification training...")
    trainer.train()
    # Evaluate on the held-out test split.
    print("\nEvaluating reason classification model...")
    metrics = trainer.evaluate(test_dataset)
    print(f"Final evaluation metrics: {metrics}")
    # Save only the trained classifier head; the frozen body can be
    # re-instantiated from MODEL_NAME at inference time.
    model_save_path = f"{REASON_CHECKPOINT_PATH}/reason_classifier_head_{timestamp}.pt"
    torch.save(model.model_head.state_dict(), model_save_path)
    print(f"Reason classifier head saved to: {model_save_path}")
    return metrics
# Script entry point: run the full training pipeline when executed directly.
if __name__ == "__main__":
    main()