File size: 7,727 Bytes
b7f3196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
"""

Training script for Healthcare Reason Classification.

This script trains a classifier for healthcare visit reasons using real
healthcare data. It is a separate system from the medical/insurance
classifier.

"""

from sentence_transformers import SentenceTransformer
from setfit import SetFitModel, Trainer, TrainingArguments
import sys
from pathlib import Path

# Add project root to path for imports.
# parents[2] is the directory two levels above the folder that contains this
# file. NOTE(review): assumed to be the repository root holding the
# `classifier` package imported below — confirm against the repo layout.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Only prepend when absent so the entry is not duplicated in sys.path.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from classifier.head import ClassifierHead
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from datasets import Dataset
import torch
from pathlib import Path
from datetime import datetime

# Configuration for the visit-reason classifier.
# Label ids are the 0-based positions of the names below, so the order of this
# tuple is significant: map_reason_to_category returns these ids.
REASON_CATEGORIES = dict(enumerate((
    "ROUTINE_CARE",
    "PAIN_CONDITIONS",
    "INJURIES",
    "SKIN_CONDITIONS",
    "STRUCTURAL_ISSUES",
    "PROCEDURES",
)))

# Directory that receives training runs and saved classifier heads.
REASON_CHECKPOINT_PATH = "classifier/reason_checkpoints"
# Excel workbook read by get_reason_dataset.
HEALTHCARE_DATA_PATH = "data/reason_for_visit_data.xlsx"
# Model identifier handed to SentenceTransformer in get_reason_model.
MODEL_NAME = "sentence-transformers/embeddinggemma-300m-medical"

def get_device():
    """Pick the torch device used for training/inference.

    Preference order is MPS, then CUDA, then CPU; the first backend that
    reports itself available wins.
    """
    if torch.backends.mps.is_available():
        backend = "mps"
    elif torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = "cpu"
    return torch.device(backend)

def get_reason_model(num_classes: int):
    """Build the reason-classification model and move it to the best device.

    The model is a ``SetFitModel`` made of a ``SentenceTransformer`` body
    (loaded from ``MODEL_NAME`` with the 'classification' prompt as default)
    and a ``ClassifierHead`` with ``num_classes`` outputs. The body is frozen
    after construction.

    Args:
        num_classes: Number of classes passed to ``ClassifierHead``.

    Returns:
        The model after ``.to(device)``, where the device comes from
        ``get_device()``.

    Raises:
        RuntimeError: If building the body, the head or the SetFit wrapper
            fails. The original exception is chained as ``__cause__`` so the
            real reason is not lost.
    """
    try:
        model_body = SentenceTransformer(
            MODEL_NAME,
            prompts={
                'classification': 'task: healthcare reason classification | query: ',
                'retrieval (query)': 'task: search result | query: ',
                'retrieval (document)': 'title: {title | "none"} | text: ',
            },
            default_prompt_name='classification',
        )
        model_head = ClassifierHead(num_classes)
        model = SetFitModel(model_body, model_head)
        # Freeze weights of embedding model
        model.freeze("body")

    except Exception as e:
        print(f"Error loading model {MODEL_NAME}: {e}")
        # Chain explicitly so the traceback shows the underlying failure.
        raise RuntimeError("Failed to load the embedding model.") from e

    device = get_device()
    print(f"Using device: {device}")
    return model.to(device)

def get_reason_dataset(data_path=None) -> pd.DataFrame:
    """Load the healthcare reason dataset from an Excel file.

    Args:
        data_path: Optional path to the workbook. When omitted (``None``) the
            module-level ``HEALTHCARE_DATA_PATH`` is used, which is what the
            function always did before this parameter existed.

    Returns:
        The DataFrame produced by ``pandas.read_excel``.

    Raises:
        RuntimeError: If the file does not exist or cannot be read. The
            underlying exception (for example ``FileNotFoundError``) is
            chained as ``__cause__``.
    """
    source = HEALTHCARE_DATA_PATH if data_path is None else data_path
    try:
        if not os.path.exists(source):
            raise FileNotFoundError(f"Healthcare data file not found: {source}")

        print(f"Loading healthcare data from {source}...")
        df = pd.read_excel(source)
        print(f"Loaded {len(df)} healthcare records")
        return df

    except Exception as e:
        print(f"Error loading healthcare dataset: {e}")
        # RuntimeError (not bare Exception) with an explicit cause keeps the
        # original failure type inspectable by callers.
        raise RuntimeError(f"Failed to load healthcare data: {e}") from e

# Ordered (category id, keywords) rules. Rules are tried top to bottom and the
# first one with a keyword occurring anywhere in the lower-cased reason wins.
_REASON_KEYWORD_RULES = (
    # ROUTINE_CARE (routine care, maintenance visits)
    (0, ('routine', 'nail care', 'calluses', 'maintenance')),
    # PAIN_CONDITIONS (various pain-related conditions)
    (1, ('pain', 'ache', 'sore', 'hurt')),
    # INJURIES (sprains, wounds, trauma)
    (2, ('sprain', 'wound', 'injury', 'trauma', 'cut', 'bruise')),
    # SKIN_CONDITIONS (skin-related issues)
    (3, ('ingrown', 'toenail', 'callus', 'corn', 'skin')),
    # STRUCTURAL_ISSUES (structural problems)
    (4, ('flat feet', 'plantar', 'fasciitis', 'achilles', 'tendon', 'arch')),
    # PROCEDURES (injections, surgical consultations)
    (5, ('injection', 'surgical', 'consult', 'postop', 'surgery', 'procedure')),
)

# Used when no rule matches: pain conditions (most common category).
_DEFAULT_REASON_CATEGORY = 1


def map_reason_to_category(reason: str) -> int:
    """Return the category id for a visit reason via keyword matching.

    Matching is case-insensitive substring search. Categories are checked in
    a fixed priority order, so a reason containing keywords from several
    categories gets the earliest one. Reasons matching nothing fall back to
    category 1.
    """
    text = reason.lower()
    for category_id, keywords in _REASON_KEYWORD_RULES:
        for keyword in keywords:
            if keyword in text:
                return category_id
    return _DEFAULT_REASON_CATEGORY

def preprocess_reason_data(df: pd.DataFrame) -> pd.DataFrame:
    """Turn raw visit records into labelled training rows.

    For each row the 'Reason For Visit' value is mapped to a category id with
    ``map_reason_to_category``. When a non-empty 'Appointment Type' is
    present it is appended to the text as ``"<reason> | <appointment type>"``.
    A per-category distribution is printed before returning.

    Rows whose reason is missing (``None``/NaN, e.g. an empty Excel cell) are
    skipped rather than crashing on ``.lower()``; other non-string reasons
    are converted with ``str``.

    Args:
        df: Raw records. Must contain a 'Reason For Visit' column;
            'Appointment Type' is optional.

    Returns:
        DataFrame with columns 'text', 'label', 'category' and
        'original_reason'. The columns exist even when no usable rows remain.
    """
    columns = ['text', 'label', 'category', 'original_reason']
    training_data = []
    skipped = 0

    for _, row in df.iterrows():
        reason = row['Reason For Visit']
        if pd.isna(reason):
            # Nothing to classify; skipping avoids an AttributeError below.
            skipped += 1
            continue
        reason = str(reason)
        appointment_type = row.get('Appointment Type', '')

        # Map reason to category using keyword matching
        category_id = map_reason_to_category(reason)

        # Create enhanced text with context
        enhanced_text = reason
        if pd.notna(appointment_type) and appointment_type:
            enhanced_text += f" | {appointment_type}"

        training_data.append({
            'text': enhanced_text,
            'label': category_id,
            'category': REASON_CATEGORIES[category_id],
            'original_reason': reason
        })

    if skipped:
        print(f"Skipped {skipped} records with no reason for visit")

    # Explicit columns keep the frame well-formed when training_data is empty.
    processed_df = pd.DataFrame(training_data, columns=columns)
    total = len(processed_df)
    label_counts = processed_df['label'].value_counts()

    # Show category distribution
    print("\nReason category distribution in training data:")
    for cat_id, cat_name in REASON_CATEGORIES.items():
        count = int(label_counts.get(cat_id, 0))
        # Guard against division by zero when there are no rows at all.
        percentage = (count / total) * 100 if total else 0.0
        print(f"  {cat_name}: {count} samples ({percentage:.1f}%)")

    return processed_df

def main():
    """Run the training pipeline end to end and return the evaluation metrics.

    Steps, in order: load and label the data, build the model, make a
    stratified 80/20 split, train with the SetFit ``Trainer``, evaluate on
    the held-out split, and save the classifier head's state dict under
    ``REASON_CHECKPOINT_PATH``.
    """
    print("Healthcare Reason Classification - Training Pipeline")
    print("=" * 60)

    # Raw records -> labelled training frame
    labelled = preprocess_reason_data(get_reason_dataset())

    # One output per known category
    model = get_reason_model(len(REASON_CATEGORIES))

    # Hold out 20% for evaluation, stratified by label, fixed seed
    train_frame, test_frame = train_test_split(
        labelled,
        test_size=0.2,
        stratify=labelled['label'],
        random_state=42,
    )

    print("\nData split:")
    print(f"  Training: {len(train_frame)} samples")
    print(f"  Testing: {len(test_frame)} samples")

    train_dataset = Dataset.from_pandas(train_frame)
    test_dataset = Dataset.from_pandas(test_frame)

    # The checkpoint directory must exist before anything is written to it
    checkpoint_root = Path(REASON_CHECKPOINT_PATH)
    checkpoint_root.mkdir(parents=True, exist_ok=True)

    # A single timestamp names both the run directory and the saved head
    run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = f"{REASON_CHECKPOINT_PATH}/training_{run_stamp}"

    training_args = TrainingArguments(
        output_dir=run_dir,
        # Skip contrastive fine-tuning (body is frozen)
        num_epochs=(0, 20),
        eval_strategy='epoch',
        eval_steps=100,
        save_strategy='epoch',
        logging_steps=50,
        load_best_model_at_end=True,
        metric_for_best_model='accuracy',
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        metric='accuracy',
        column_mapping={"text": "text", "label": "label"},
    )

    print("\nStarting reason classification training...")
    trainer.train()

    # Score the trained model on the held-out split
    print("\nEvaluating reason classification model...")
    metrics = trainer.evaluate(test_dataset)
    print(f"Final evaluation metrics: {metrics}")

    # Persist only the classifier head's weights
    head_path = f"{REASON_CHECKPOINT_PATH}/reason_classifier_head_{run_stamp}.pt"
    torch.save(model.model_head.state_dict(), head_path)
    print(f"Reason classifier head saved to: {head_path}")

    return metrics


if __name__ == "__main__":
    main()