|
|
"""
|
|
|
Healthcare Reason for Visit Classifier
|
|
|
|
|
|
This module implements a classifier for healthcare clinic queries
|
|
|
using real healthcare data from clinic appointment records.
|
|
|
|
|
|
Categories based on the actual data:
|
|
|
- ROUTINE_CARE: Routine care, maintenance visits
|
|
|
- PAIN_CONDITIONS: Various pain-related conditions
|
|
|
- INJURIES: Sprains, wounds, trauma-related visits
|
|
|
- SKIN_CONDITIONS: Skin-related conditions and issues
|
|
|
- STRUCTURAL_ISSUES: Structural problems and conditions
|
|
|
- PROCEDURES: Injections, surgical consults, postop care
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
import torch
|
|
|
import pandas as pd
|
|
|
import numpy as np
|
|
|
from typing import List, Dict, Tuple, Optional
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
from setfit import SetFitModel
|
|
|
from sklearn.model_selection import train_test_split
|
|
|
from sklearn.metrics import classification_report, confusion_matrix
|
|
|
from datasets import Dataset
|
|
|
import json
|
|
|
|
|
|
from ..head import ClassifierHead
|
|
|
|
|
|
|
|
|
# Category id -> category name, in id order. The ids are the integer labels
# the classifier is trained on and predicts.
REASON_CATEGORIES = dict(enumerate((
    "ROUTINE_CARE",
    "PAIN_CONDITIONS",
    "INJURIES",
    "SKIN_CONDITIONS",
    "STRUCTURAL_ISSUES",
    "PROCEDURES",
)))

# Human-readable description of each category, keyed by category name.
CATEGORY_DESCRIPTIONS = {
    "ROUTINE_CARE": "Routine healthcare, maintenance visits, general care",
    "PAIN_CONDITIONS": "Various pain-related conditions and discomfort",
    "INJURIES": "Sprains, wounds, trauma-related conditions",
    "SKIN_CONDITIONS": "Skin-related issues and conditions",
    "STRUCTURAL_ISSUES": "Structural problems and related conditions",
    "PROCEDURES": "Injections, surgical consultations, post-operative care",
}
|
|
|
|
|
|
class ReasonClassifier:
    """
    Healthcare Reason Classifier that uses real clinic data to classify
    patient queries into specific healthcare reason categories.

    Training labels are derived from the 'Reason For Visit' column of the
    clinic spreadsheet with keyword rules (see ``_map_reason_to_category``).
    A SetFit model with a frozen sentence-transformer body and a trainable
    ``ClassifierHead`` is then fit on those labels.
    """

    # Ordered (category_id, keywords) rules used to derive training labels.
    # The first rule with a keyword occurring as a substring of the
    # lower-cased reason wins, so earlier rules take precedence: for example,
    # "ingrown toenail pain" maps to PAIN_CONDITIONS, not SKIN_CONDITIONS.
    # NOTE(review): 'calluses' (ROUTINE_CARE) is checked before 'callus'
    # (SKIN_CONDITIONS), so the plural and singular map differently.
    # Confirm this is intended.
    _KEYWORD_RULES = (
        (0, ('routine', 'nail care', 'calluses')),
        (1, ('pain', 'ache', 'sore')),
        # 'injur' covers "injury", "injuries" and "injured".
        (2, ('sprain', 'wound', 'injur', 'trauma')),
        (3, ('ingrown', 'toenail', 'callus', 'skin')),
        (4, ('flat feet', 'plantar', 'fasciitis', 'achilles', 'tendon')),
        (5, ('injection', 'surgical', 'consult', 'postop', 'procedure')),
    )

    # Category id used when no keyword rule matches (PAIN_CONDITIONS).
    _DEFAULT_CATEGORY = 1

    def __init__(self, data_file: str = "data/reason_for_visit_data.xlsx"):
        """
        Load the clinic data and build an untrained model.

        Args:
            data_file: Path to the Excel file. It must provide the
                'Reason For Visit' and 'Appointment Type' columns.

        Raises:
            RuntimeError: If the data cannot be loaded or the model cannot
                be initialized.
        """
        self.model_name = "sentence-transformers/embeddinggemma-300m-medical"
        self.num_classes = len(REASON_CATEGORIES)
        self.categories = REASON_CATEGORIES
        self.data_file = data_file
        self.model = None
        self.device = self._get_device()

        self.healthcare_df = self._load_data()
        self._initialize_model()

    def _get_device(self):
        """Return the best available torch device: MPS, then CUDA, then CPU."""
        if torch.backends.mps.is_available():
            return torch.device("mps")
        elif torch.cuda.is_available():
            return torch.device("cuda")
        else:
            return torch.device("cpu")

    def _load_data(self) -> pd.DataFrame:
        """
        Load the clinic appointment records from ``self.data_file``.

        Returns:
            The raw DataFrame produced by ``pd.read_excel``.

        Raises:
            RuntimeError: If the file cannot be read or lacks the
                'Reason For Visit' column. The original error is chained.
        """
        try:
            df = pd.read_excel(self.data_file)
            print(f"Loaded {len(df)} healthcare records from {self.data_file}")
            print(f"Unique reasons: {df['Reason For Visit'].nunique()}")
            return df
        except Exception as e:
            print(f"Error loading data: {e}")
            raise RuntimeError(f"Failed to load healthcare data from {self.data_file}") from e

    def _initialize_model(self):
        """
        Build the SetFit model: a sentence-transformer body plus a ClassifierHead.

        The body is frozen, so only the head is trained. The model is moved
        to ``self.device``.

        Raises:
            RuntimeError: If any part of the construction fails. The original
                error is chained.
        """
        try:
            model_body = SentenceTransformer(
                self.model_name,
                prompts={
                    'classification': 'task: healthcare reason classification | query: ',
                    'retrieval (query)': 'task: search result | query: ',
                    'retrieval (document)': 'title: {title | "none"} | text: ',
                },
                default_prompt_name='classification',
            )

            # Size the head from the body's actual output width instead of
            # assuming it. Fall back to 768 if the model reports no dimension.
            embedding_dim = model_body.get_sentence_embedding_dimension() or 768
            model_head = ClassifierHead(self.num_classes, embedding_dim=embedding_dim)
            self.model = SetFitModel(model_body, model_head)
            self.model.freeze("body")
            self.model = self.model.to(self.device)

            print(f"Initialized ReasonClassifier on {self.device}")
        except Exception as e:
            print(f"Error initializing model: {e}")
            raise RuntimeError("Failed to initialize reason classifier") from e

    def _map_reason_to_category(self, reason: str) -> int:
        """
        Map a free-text reason for visit to a category id using keyword rules.

        Matching is a case-insensitive substring search over
        ``_KEYWORD_RULES``, taken in order. The first matching rule wins.

        Args:
            reason: The reason text. Non-string values are converted with
                ``str()`` first.

        Returns:
            A key of ``REASON_CATEGORIES``. Returns ``_DEFAULT_CATEGORY``
            when no rule matches.
        """
        reason_lower = str(reason).lower()
        for category_id, keywords in self._KEYWORD_RULES:
            if any(word in reason_lower for word in keywords):
                return category_id
        return self._DEFAULT_CATEGORY

    def create_real_dataset(self, random_state: Optional[int] = None) -> pd.DataFrame:
        """
        Create a shuffled training dataset from the loaded clinic data.

        Each row's text is the reason for visit. When an appointment type is
        present, ``" | <appointment type>"`` is appended. Rows with a missing
        reason are skipped.

        Args:
            random_state: Seed for the final shuffle. ``None`` (the default)
                gives a different order on every call.

        Returns:
            DataFrame with the columns 'text', 'label', 'category',
            'original_reason' and 'appointment_type'.
        """
        columns = ['text', 'label', 'category', 'original_reason', 'appointment_type']
        training_data = []

        reasons = self.healthcare_df['Reason For Visit']
        appointment_types = self.healthcare_df['Appointment Type']
        for reason, appointment_type in zip(reasons, appointment_types):
            # A missing reason carries no label signal and would break the
            # string handling below.
            if pd.isna(reason):
                continue

            category_id = self._map_reason_to_category(reason)

            # NOTE(review): the appointment type is appended at training time,
            # but ``predict`` receives bare patient queries. Confirm this
            # train/inference text mismatch is intended.
            enhanced_text = str(reason)
            if pd.notna(appointment_type):
                enhanced_text += f" | {appointment_type}"

            training_data.append({
                'text': enhanced_text,
                'label': category_id,
                'category': self.categories[category_id],
                'original_reason': reason,
                'appointment_type': appointment_type,
            })

        # Passing explicit columns keeps the schema even when no rows survive.
        df = pd.DataFrame(training_data, columns=columns)

        print("\nCategory distribution in training data:")
        label_counts = df['label'].value_counts()
        total = len(df)
        for cat_id, cat_name in self.categories.items():
            count = int(label_counts.get(cat_id, 0))
            percentage = (count / total) * 100 if total else 0.0
            print(f" {cat_name}: {count} samples ({percentage:.1f}%)")

        return df.sample(frac=1, random_state=random_state).reset_index(drop=True)

    def train(self, train_data: Optional[pd.DataFrame] = None, eval_data: Optional[pd.DataFrame] = None,
              epochs: int = 16, output_dir: str = "classifier/reason_checkpoints"):
        """
        Train the healthcare reason classifier head.

        Args:
            train_data: DataFrame with 'text' and 'label' columns. When
                omitted, it is built with ``create_real_dataset``.
            eval_data: Held-out DataFrame with the same columns. When omitted,
                20% of ``train_data`` is split off. The split is stratified by
                label when every label has at least two rows.
            epochs: Number of classifier epochs. The embedding body is trained
                for 0 epochs.
            output_dir: Directory for training checkpoints.

        Returns:
            The metrics returned by ``Trainer.evaluate`` on the eval set.
        """
        if train_data is None:
            train_data = self.create_real_dataset()

        if eval_data is None:
            # A stratified split raises ValueError when any label has fewer
            # than two rows. In that case, use a plain random split.
            label_counts = train_data['label'].value_counts()
            can_stratify = (not label_counts.empty) and label_counts.min() >= 2
            train_data, eval_data = train_test_split(
                train_data,
                test_size=0.2,
                stratify=train_data['label'] if can_stratify else None,
                random_state=42,
            )

        # Only the mapped columns are needed. Dropping the rest, and the
        # non-contiguous index left by the split, keeps the Arrow conversion
        # simple.
        train_dataset = Dataset.from_pandas(train_data[['text', 'label']], preserve_index=False)
        eval_dataset = Dataset.from_pandas(eval_data[['text', 'label']], preserve_index=False)

        from setfit import Trainer, TrainingArguments

        args = TrainingArguments(
            output_dir=output_dir,
            num_epochs=(0, epochs),
            eval_strategy='epoch',
            eval_steps=100,
            save_strategy='epoch',
            logging_steps=50,
        )

        trainer = Trainer(
            model=self.model,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            metric='accuracy',
            column_mapping={"text": "text", "label": "label"},
            args=args,
        )

        print("Starting training...")
        trainer.train()

        metrics = trainer.evaluate(eval_dataset)
        print(f"Training completed. Final metrics: {metrics}")

        return metrics

    def predict(self, queries: List[str]) -> List[Dict]:
        """
        Predict healthcare reason categories for a list of queries.

        All queries are scored with one batched ``predict`` call and one
        batched ``predict_proba`` call, rather than two model calls per query.

        Returns:
            List of dictionaries with 'query', 'category', 'confidence' and
            'probabilities', in the same order as ``queries``. An empty input
            returns an empty list.

        Raises:
            RuntimeError: If no model is loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not initialized. Train or load a model first.")

        queries = list(queries)
        if not queries:
            return []

        pred_labels = self.model.predict(queries)
        pred_probas = self.model.predict_proba(queries)

        predictions = []
        for query, pred_label, pred_proba in zip(queries, pred_labels, pred_probas):
            label_id = int(pred_label)
            predictions.append({
                'query': query,
                'category': self.categories[label_id],
                'confidence': float(pred_proba[label_id]),
                'probabilities': {self.categories[i]: float(prob)
                                  for i, prob in enumerate(pred_proba)},
            })

        return predictions

    def save_model(self, path: str):
        """
        Save the trained model and its category mapping into directory ``path``.

        Writes the SetFit model with ``save_pretrained`` plus a
        ``categories.json`` file. JSON object keys become strings;
        ``load_model`` converts them back to ints.

        Raises:
            RuntimeError: If no model is loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not initialized. Train or load a model first.")

        # Create the target directory itself, with its parents. For a bare
        # name such as "model", os.path.dirname(path) would be "", which
        # os.makedirs rejects.
        os.makedirs(path, exist_ok=True)
        self.model.save_pretrained(path)

        with open(os.path.join(path, 'categories.json'), 'w', encoding='utf-8') as f:
            json.dump(self.categories, f)

        print(f"Model saved to {path}")

    def load_model(self, path: str):
        """Load a model and its category mapping from a directory written by ``save_model``."""
        self.model = SetFitModel.from_pretrained(path)
        self.model = self.model.to(self.device)

        with open(os.path.join(path, 'categories.json'), 'r', encoding='utf-8') as f:
            # JSON keys are strings; restore the integer category ids.
            self.categories = {int(k): v for k, v in json.load(f).items()}

        print(f"Model loaded from {path}")

    def evaluate_on_test_set(self, test_data: pd.DataFrame) -> Dict:
        """
        Evaluate the model on a labelled test dataset.

        Args:
            test_data: DataFrame with 'text' and integer 'label' columns.

        Returns:
            Dict containing:
                - 'classification_report': sklearn's report as a dict.
                - 'confusion_matrix': nested lists, rows and columns ordered
                  as ``self.categories``.
                - 'accuracy': overall accuracy.
        """
        predictions = self.predict(test_data['text'].tolist())

        name_to_id = {name: cat_id for cat_id, name in self.categories.items()}
        label_ids = list(self.categories.keys())

        y_true = test_data['label'].tolist()
        y_pred = [name_to_id[p['category']] for p in predictions]

        # Passing ``labels`` keeps ``target_names`` aligned with the category
        # ids even when some categories are absent from this test set.
        # Without it, sklearn raises on the size mismatch.
        report = classification_report(y_true, y_pred,
                                       labels=label_ids,
                                       target_names=list(self.categories.values()),
                                       output_dict=True)

        # A fixed label order gives a fixed-shape matrix for every test set.
        cm = confusion_matrix(y_true, y_pred, labels=label_ids)

        return {
            'classification_report': report,
            'confusion_matrix': cm.tolist(),
            'accuracy': report['accuracy'],
        }

    def analyze_real_data(self):
        """Print summary statistics of the loaded data and the keyword category of its top reasons."""
        print("Real Data Analysis:")
        print("=" * 50)

        print(f"Total records: {len(self.healthcare_df)}")
        print(f"Unique reasons: {self.healthcare_df['Reason For Visit'].nunique()}")

        print("\nTop 15 reasons for visit:")
        top_reasons = self.healthcare_df['Reason For Visit'].value_counts().head(15)
        for reason, count in top_reasons.items():
            category_id = self._map_reason_to_category(reason)
            category_name = self.categories[category_id]
            print(f" {reason}: {count} ({category_name})")

        print("\nAppointment types:")
        print(self.healthcare_df['Appointment Type'].value_counts())
|
|
|
|
|
|
|
|
|
def main():
    """
    Example usage and training script for healthcare reason data.

    Loads the clinic data and prints an analysis of it. Then trains on the
    keyword-labelled dataset, saves the model to ``classifier/reason_model``
    and prints predictions for a few sample queries.
    """
    print("Initializing Healthcare Reason Classifier...")
    classifier = ReasonClassifier()

    classifier.analyze_real_data()

    print("\nCreating training dataset from real healthcare data...")
    dataset = classifier.create_real_dataset()
    print(f"Dataset created with {len(dataset)} real examples")

    print("\nTraining classifier...")
    # ``train`` prints its own final metrics, so its return value is not kept.
    classifier.train(dataset, epochs=20)

    model_path = "classifier/reason_model"
    classifier.save_model(model_path)

    test_queries = [
        "I have heel pain when I walk",
        "My toenail is ingrown and painful",
        "I need routine foot care",
        "I sprained my ankle playing sports",
        "I have flat feet and need evaluation",
        "I need a cortisone injection for my foot",
        "I have plantar fasciitis",
        "My foot wound is not healing",
    ]

    print("\nTesting predictions on healthcare reason queries:")
    predictions = classifier.predict(test_queries)

    for pred in predictions:
        print(f"Query: {pred['query']}")
        print(f"Category: {pred['category']} (confidence: {pred['confidence']:.3f})")
        print("---")


if __name__ == "__main__":
    main()