taraky's picture
Upload folder using huggingface_hub
b7f3196 verified
"""
Healthcare Reason for Visit Classifier
This module implements a classifier for healthcare clinic queries
using real healthcare data from clinic appointment records.
Categories based on the actual data:
- ROUTINE_CARE: Routine care, maintenance visits
- PAIN_CONDITIONS: Various pain-related conditions
- INJURIES: Sprains, wounds, trauma-related visits
- SKIN_CONDITIONS: Skin-related conditions and issues
- STRUCTURAL_ISSUES: Structural problems and conditions
- PROCEDURES: Injections, surgical consults, postop care
"""
import os
import torch
import pandas as pd
import numpy as np
from typing import List, Dict, Tuple, Optional
from sentence_transformers import SentenceTransformer
from setfit import SetFitModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from datasets import Dataset
import json
from ..head import ClassifierHead
# Healthcare reason categories based on real data analysis.
# Maps integer label id -> category name. Ids are assigned by position in the
# tuple below, so reordering it changes which label means which category.
REASON_CATEGORIES = dict(enumerate((
    "ROUTINE_CARE",
    "PAIN_CONDITIONS",
    "INJURIES",
    "SKIN_CONDITIONS",
    "STRUCTURAL_ISSUES",
    "PROCEDURES",
)))

# Human-readable description for each category name in REASON_CATEGORIES
# (same names, same order).
CATEGORY_DESCRIPTIONS = dict(
    ROUTINE_CARE="Routine healthcare, maintenance visits, general care",
    PAIN_CONDITIONS="Various pain-related conditions and discomfort",
    INJURIES="Sprains, wounds, trauma-related conditions",
    SKIN_CONDITIONS="Skin-related issues and conditions",
    STRUCTURAL_ISSUES="Structural problems and related conditions",
    PROCEDURES="Injections, surgical consultations, post-operative care",
)
class ReasonClassifier:
    """
    Healthcare Reason Classifier that uses real clinic data to classify
    patient queries into specific healthcare reason categories.

    On construction it reads the spreadsheet at ``data_file`` and builds a
    SetFit model whose SentenceTransformer body is frozen. Training labels
    are derived from the 'Reason For Visit' column by the keyword rules in
    ``_map_reason_to_category``.
    """

    # Ordered (keywords, label id) rules used by _map_reason_to_category.
    # The first rule with any keyword contained in the lower-cased reason wins,
    # so ORDER IS SIGNIFICANT. For example, "calluses" hits ROUTINE_CARE (0)
    # before the singular "callus" rule of SKIN_CONDITIONS (3) is reached, and
    # "painful ingrown toenail" is labelled PAIN_CONDITIONS (1).
    _KEYWORD_RULES = (
        (('routine', 'nail care', 'calluses'), 0),  # ROUTINE_CARE
        (('pain', 'ache', 'sore'), 1),  # PAIN_CONDITIONS
        (('sprain', 'wound', 'injury', 'trauma'), 2),  # INJURIES
        (('ingrown', 'toenail', 'callus', 'skin'), 3),  # SKIN_CONDITIONS
        (('flat feet', 'plantar', 'fasciitis', 'achilles', 'tendon'), 4),  # STRUCTURAL_ISSUES
        (('injection', 'surgical', 'consult', 'postop', 'procedure'), 5),  # PROCEDURES
    )
    # Label returned when no rule matches (PAIN_CONDITIONS).
    _DEFAULT_LABEL = 1

    def __init__(self, data_file: str = "data/reason_for_visit_data.xlsx"):
        """Load ``data_file`` and build the SetFit model.

        Raises:
            RuntimeError: if the data cannot be loaded or the model cannot
                be initialized.
        """
        self.model_name = "sentence-transformers/embeddinggemma-300m-medical"
        self.num_classes = len(REASON_CATEGORIES)
        self.categories = REASON_CATEGORIES
        self.data_file = data_file
        self.model = None
        self.device = self._get_device()
        # Load and process real data
        self.healthcare_df = self._load_data()
        self._initialize_model()

    def _get_device(self):
        """Return the best available device: MPS, then CUDA, then CPU."""
        if torch.backends.mps.is_available():
            return torch.device("mps")
        elif torch.cuda.is_available():
            return torch.device("cuda")
        else:
            return torch.device("cpu")

    def _load_data(self) -> pd.DataFrame:
        """Load the real healthcare dataset from ``self.data_file``.

        Raises:
            RuntimeError: if the file cannot be read or has no
                'Reason For Visit' column. The original error is chained.
        """
        try:
            df = pd.read_excel(self.data_file)
            print(f"Loaded {len(df)} healthcare records from {self.data_file}")
            print(f"Unique reasons: {df['Reason For Visit'].nunique()}")
            return df
        except Exception as e:
            print(f"Error loading data: {e}")
            raise RuntimeError(f"Failed to load healthcare data from {self.data_file}") from e

    def _initialize_model(self):
        """Build the SetFit model (SentenceTransformer body + ClassifierHead).

        The body is frozen, so training only updates the head.

        Raises:
            RuntimeError: if the model cannot be built. The original error is chained.
        """
        try:
            model_body = SentenceTransformer(
                self.model_name,
                prompts={
                    'classification': 'task: healthcare reason classification | query: ',
                    'retrieval (query)': 'task: search result | query: ',
                    'retrieval (document)': 'title: {title | "none"} | text: ',
                },
                default_prompt_name='classification',
            )
            # NOTE(review): 768 is assumed to match the body's embedding size.
            model_head = ClassifierHead(self.num_classes, embedding_dim=768)
            self.model = SetFitModel(model_body, model_head)
            self.model.freeze("body")  # Freeze embedding weights
            self.model = self.model.to(self.device)
            print(f"Initialized ReasonClassifier on {self.device}")
        except Exception as e:
            print(f"Error initializing model: {e}")
            raise RuntimeError("Failed to initialize reason classifier") from e

    def _map_reason_to_category(self, reason: str) -> int:
        """
        Map a real healthcare reason to a category id using keyword matching.

        Matching is a case-insensitive substring test against
        ``_KEYWORD_RULES``; the first matching rule wins. Returns
        ``_DEFAULT_LABEL`` (PAIN_CONDITIONS) when nothing matches. Non-string
        values, such as NaN from an empty spreadsheet cell, are coerced with
        ``str()`` instead of crashing.
        """
        reason_lower = str(reason).lower()
        for keywords, label in self._KEYWORD_RULES:
            if any(word in reason_lower for word in keywords):
                return label
        return self._DEFAULT_LABEL

    def create_real_dataset(self, random_state: Optional[int] = None) -> pd.DataFrame:
        """
        Create training dataset from real healthcare data.

        Each row of ``self.healthcare_df`` with a non-missing 'Reason For Visit'
        becomes one example. Its ``text`` is the reason, suffixed with
        ``" | <Appointment Type>"`` when that cell is present. Rows whose reason
        is missing are skipped. The category distribution is printed.

        Args:
            random_state: Seed for the final shuffle. ``None`` (the default)
                gives a different order on every call.

        Returns:
            Shuffled DataFrame with columns 'text', 'label', 'category',
            'original_reason' and 'appointment_type'.
        """
        columns = ['text', 'label', 'category', 'original_reason', 'appointment_type']
        training_data = []
        for _, row in self.healthcare_df.iterrows():
            reason = row['Reason For Visit']
            if pd.isna(reason):
                # An empty cell has no text to label or embed.
                continue
            appointment_type = row['Appointment Type']
            # Map reason to category
            category_id = self._map_reason_to_category(reason)
            # str() so numeric cells don't break the concatenation below.
            enhanced_text = str(reason)
            if pd.notna(appointment_type):
                enhanced_text += f" | {appointment_type}"
            training_data.append({
                'text': enhanced_text,
                'label': category_id,
                'category': self.categories[category_id],
                'original_reason': reason,
                'appointment_type': appointment_type
            })
        # Explicit columns keep the schema even when no rows survive.
        df = pd.DataFrame(training_data, columns=columns)

        # Show category distribution
        print("\nCategory distribution in training data:")
        total = len(df)
        for cat_id, cat_name in self.categories.items():
            count = int((df['label'] == cat_id).sum())
            percentage = (count / total) * 100 if total else 0.0
            print(f" {cat_name}: {count} samples ({percentage:.1f}%)")

        if df.empty:
            return df
        return df.sample(frac=1, random_state=random_state).reset_index(drop=True)  # Shuffle

    def train(self, train_data: Optional[pd.DataFrame] = None, eval_data: Optional[pd.DataFrame] = None,
              epochs: int = 16, output_dir: str = "classifier/reason_checkpoints"):
        """Train the healthcare reason classifier head.

        Args:
            train_data: DataFrame with 'text' and 'label' columns. Defaults to
                ``create_real_dataset()``.
            eval_data: Held-out DataFrame of the same shape. When omitted, 20%
                of ``train_data`` is split off (seed 42). The split is
                stratified by label when every label has at least two rows.
            epochs: Number of head-training epochs. The contrastive body phase
                is skipped.
            output_dir: Checkpoint directory passed to SetFit.

        Returns:
            The metrics dict from ``trainer.evaluate`` on the eval set.
        """
        if train_data is None:
            train_data = self.create_real_dataset()
        if eval_data is None:
            # train_test_split raises ValueError when stratifying on a class
            # with a single member, so fall back to a plain random split.
            label_counts = train_data['label'].value_counts()
            stratify = train_data['label'] if label_counts.min() >= 2 else None
            train_data, eval_data = train_test_split(train_data, test_size=0.2,
                                                     stratify=stratify,
                                                     random_state=42)
        # Pass only the columns the trainer consumes, and drop the (shuffled)
        # pandas index so it isn't carried into the Dataset as an extra column.
        train_dataset = Dataset.from_pandas(train_data[['text', 'label']], preserve_index=False)
        eval_dataset = Dataset.from_pandas(eval_data[['text', 'label']], preserve_index=False)

        from setfit import Trainer, TrainingArguments

        args = TrainingArguments(
            output_dir=output_dir,
            num_epochs=(0, epochs),  # Skip contrastive learning, only train head
            eval_strategy='epoch',
            eval_steps=100,
            save_strategy='epoch',
            logging_steps=50,
        )
        trainer = Trainer(
            model=self.model,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            metric='accuracy',
            column_mapping={"text": "text", "label": "label"},
            args=args,
        )
        print("Starting training...")
        trainer.train()
        # Evaluate
        metrics = trainer.evaluate(eval_dataset)
        print(f"Training completed. Final metrics: {metrics}")
        return metrics

    def predict(self, queries: List[str]) -> List[Dict]:
        """
        Predict healthcare reason categories for a list of queries.

        All queries go through one batched ``predict`` call and one batched
        ``predict_proba`` call, rather than two model calls per query.

        Returns:
            List of dictionaries with 'query', 'category', 'confidence', 'probabilities'

        Raises:
            RuntimeError: if no model is loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not initialized. Train or load a model first.")
        queries = list(queries)
        if not queries:
            return []

        pred_labels = self.model.predict(queries)
        pred_probas = self.model.predict_proba(queries)

        predictions = []
        for query, pred_label, pred_proba in zip(queries, pred_labels, pred_probas):
            label_id = int(pred_label)
            predictions.append({
                'query': query,
                'category': self.categories[label_id],
                'confidence': float(pred_proba[label_id]),
                'probabilities': {self.categories[i]: float(prob)
                                  for i, prob in enumerate(pred_proba)}
            })
        return predictions

    def save_model(self, path: str):
        """Save the trained model and its category mapping into directory ``path``."""
        # Create `path` itself (and any parents). The previous
        # os.makedirs(os.path.dirname(path)) failed for a bare directory name,
        # because dirname() is '' there.
        os.makedirs(path, exist_ok=True)
        self.model.save_pretrained(path)
        # Save category mapping (JSON turns the int keys into strings).
        with open(os.path.join(path, 'categories.json'), 'w', encoding='utf-8') as f:
            json.dump(self.categories, f)
        print(f"Model saved to {path}")

    def load_model(self, path: str):
        """Load a model saved by ``save_model`` and restore its category mapping."""
        self.model = SetFitModel.from_pretrained(path)
        self.model = self.model.to(self.device)
        # Load category mapping; JSON object keys come back as strings.
        with open(os.path.join(path, 'categories.json'), 'r', encoding='utf-8') as f:
            self.categories = {int(k): v for k, v in json.load(f).items()}
        print(f"Model loaded from {path}")

    def evaluate_on_test_set(self, test_data: pd.DataFrame) -> Dict:
        """Evaluate the model on a DataFrame with 'text' and 'label' columns.

        Returns:
            Dict with the sklearn classification report (as a dict), the
            confusion matrix (as a nested list, rows/cols ordered by category
            id) and the overall accuracy.
        """
        predictions = self.predict(test_data['text'].tolist())
        y_true = test_data['label'].tolist()
        # Invert the mapping once instead of a linear search per prediction.
        name_to_id = {name: cat_id for cat_id, name in self.categories.items()}
        y_pred = [name_to_id[p['category']] for p in predictions]

        # Pass the full label set explicitly. Otherwise sklearn infers labels
        # from the data and raises when a category is absent from the test set
        # (target_names would then have the wrong length).
        label_ids = list(self.categories.keys())

        # Classification report
        report = classification_report(y_true, y_pred,
                                       labels=label_ids,
                                       target_names=list(self.categories.values()),
                                       output_dict=True)
        # Confusion matrix
        cm = confusion_matrix(y_true, y_pred, labels=label_ids)

        if 'accuracy' in report:
            accuracy = report['accuracy']
        else:
            accuracy = (sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)) if y_true else 0.0

        return {
            'classification_report': report,
            'confusion_matrix': cm.tolist(),
            'accuracy': accuracy
        }

    def analyze_real_data(self):
        """Print summary statistics of the loaded data and how the top reasons map to categories."""
        print("Real Data Analysis:")
        print("=" * 50)
        print(f"Total records: {len(self.healthcare_df)}")
        print(f"Unique reasons: {self.healthcare_df['Reason For Visit'].nunique()}")
        print("\nTop 15 reasons for visit:")
        top_reasons = self.healthcare_df['Reason For Visit'].value_counts().head(15)
        for reason, count in top_reasons.items():
            category_id = self._map_reason_to_category(reason)
            category_name = self.categories[category_id]
            print(f" {reason}: {count} ({category_name})")
        print("\nAppointment types:")
        print(self.healthcare_df['Appointment Type'].value_counts())
def main():
    """Run the end-to-end example: analyze, train, save, then demo predictions.

    Builds a ReasonClassifier from the default data file and prints an
    analysis of the raw data. It then trains on the keyword-labelled dataset
    for 20 epochs, saves the result to ``classifier/reason_model``, and
    prints the predicted category and confidence for a few sample queries.
    """
    print("Initializing Healthcare Reason Classifier...")
    clf = ReasonClassifier()
    clf.analyze_real_data()

    print("\nCreating training dataset from real healthcare data...")
    training_df = clf.create_real_dataset()
    print(f"Dataset created with {len(training_df)} real examples")

    print("\nTraining classifier...")
    clf.train(training_df, epochs=20)

    save_dir = "classifier/reason_model"
    clf.save_model(save_dir)

    sample_queries = [
        "I have heel pain when I walk",
        "My toenail is ingrown and painful",
        "I need routine foot care",
        "I sprained my ankle playing sports",
        "I have flat feet and need evaluation",
        "I need a cortisone injection for my foot",
        "I have plantar fasciitis",
        "My foot wound is not healing",
    ]

    print("\nTesting predictions on healthcare reason queries:")
    for result in clf.predict(sample_queries):
        # One print per result, producing the same three lines as before.
        print(
            f"Query: {result['query']}\n"
            f"Category: {result['category']} (confidence: {result['confidence']:.3f})\n"
            "---"
        )


if __name__ == "__main__":
    main()