File size: 10,488 Bytes
2612bdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
# src/train_router.py
# Fine-tune DistilBERT for 8-class ticket routing
# SupportMind v1.0 — Asmitha
#
# Memory-optimized for machines with limited RAM:
#   - max_length=64 (tickets are short, saves ~4x memory vs 256)
#   - batch_size=1 (minimal footprint)
#   - gradient_accumulation_steps=16 (effective batch=16)
#   - fp16 disabled; training is pinned to CPU (use_cpu=True)
#   - Datasets cleared before model loading

import os
import sys
import gc

# Disable TensorFlow to prevent DLL loading errors under Application Control policies
os.environ['USE_TF'] = '0'
os.environ['USE_JAX'] = '0'
# Limit torch threads to reduce memory pressure
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import pandas as pd
import torch
import logging
from transformers import (
    DistilBertTokenizer, 
    DistilBertForSequenceClassification, 
    Trainer, 
    TrainingArguments,
    TrainerCallback
)
from transformers.trainer_utils import get_last_checkpoint
import psutil
from datasets import Dataset
import numpy as np

logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)

BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATA_DIR = os.path.join(BASE_DIR, 'data', 'processed')
MODEL_DIR = os.path.join(BASE_DIR, 'models', 'ticket_classifier')

# Shorter max_length β€” support tickets are typically short
# 64 tokens is enough to capture intent from these tickets
MAX_LENGTH = 64


class MemoryProfilerCallback(TrainerCallback):
    """Logs memory usage + progress summary every N steps."""
    def __init__(self, total_steps: int):
        import os
        self.process = psutil.Process(os.getpid())
        self.total_steps = total_steps
        
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % args.logging_steps == 0:
            mem_mb = self.process.memory_info().rss / (1024 * 1024)
            pct = (state.global_step / self.total_steps) * 100 if self.total_steps else 0
            logger.info(
                f"[{pct:5.1f}%] Step {state.global_step}/{self.total_steps} "
                f"| Epoch {state.epoch:.2f} | RAM: {mem_mb:.0f} MB"
            )


def compute_metrics(eval_pred):
    """Compute accuracy metric for evaluation."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy = (predictions == labels).astype(np.float32).mean()
    return {"accuracy": float(accuracy)}


def main():
    train_path = os.path.join(DATA_DIR, 'train.csv')
    val_path = os.path.join(DATA_DIR, 'val.csv')
    
    if not os.path.exists(train_path):
        logger.error(f"Training data not found at {train_path}. Run data/preprocess.py first.")
        sys.exit(1)

    # ── Step 1: Load & tokenize data ──────────────────────
    logger.info("Loading processed datasets...")
    train_df = pd.read_csv(train_path)
    val_df = pd.read_csv(val_path)
    logger.info(f"Train: {len(train_df)} samples, Val: {len(val_df)} samples")
    logger.info(f"Label distribution:\n{train_df['label'].value_counts().to_string()}")

    # Check device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_fp16 = device == "cuda"
    logger.info(f"Device: {device} | FP16: {use_fp16}")

    # Convert to HF Datasets  
    train_dataset = Dataset.from_pandas(train_df[['text', 'label']])
    val_dataset = Dataset.from_pandas(val_df[['text', 'label']])

    # Free DataFrame memory before tokenization
    del train_df, val_df
    gc.collect()

    logger.info("Initializing Tokenizer...")
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

    def tokenize_function(examples):
        return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=MAX_LENGTH)

    logger.info("Tokenizing datasets...")
    tokenized_train = train_dataset.map(tokenize_function, batched=True, batch_size=64)
    tokenized_val = val_dataset.map(tokenize_function, batched=True, batch_size=64)

    # Free raw datasets
    del train_dataset, val_dataset
    gc.collect()

    # ── Step 2: Compute class weights for imbalanced data ─
    from sklearn.utils.class_weight import compute_class_weight
    
    labels_array = tokenized_train['label']
    unique_labels = sorted(set(labels_array))
    class_weights = compute_class_weight(
        class_weight='balanced',
        classes=np.array(unique_labels),
        y=np.array(labels_array)
    )
    # Map to all 8 classes (some might be missing)
    weight_dict = {c: w for c, w in zip(unique_labels, class_weights)}
    weights_tensor = torch.tensor(
        [weight_dict.get(i, 1.0) for i in range(8)], dtype=torch.float32
    )
    logger.info(f"Class weights: {weights_tensor.tolist()}")

    # ── Step 3: Load model ────────────────────────────────
    logger.info("Loading DistilBERT model...")
    model = DistilBertForSequenceClassification.from_pretrained(
        'distilbert-base-uncased', 
        num_labels=8
    )
    param_count = sum(p.numel() for p in model.parameters())
    logger.info(f"Model loaded. Parameters: {param_count:,}")

    # ── Freeze base layers β€” only fine-tune last 2 transformer layers + head ─
    # Freezing layers[0-3] cuts trainable params from 67M to ~7M,
    # reducing peak RAM from ~3.5 GB to ~800 MB. Quality impact is minimal
    # because the ticket vocabulary is similar to DistilBERT pretraining data.
    for name, param in model.named_parameters():
        param.requires_grad = False  # freeze everything first

    # Unfreeze: last 2 transformer layers (layer 4 and 5 of 6)
    for name, param in model.named_parameters():
        if any(key in name for key in [
            'transformer.layer.4',
            'transformer.layer.5',
            'pre_classifier',
            'classifier',
        ]):
            param.requires_grad = True

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total     = sum(p.numel() for p in model.parameters())
    logger.info(f"Trainable params: {trainable:,} / {total:,} ({trainable/total*100:.1f}%)")

    # Force garbage collection after model load
    gc.collect()

    # ── Step 4: Custom Trainer with weighted loss ─────────
    from torch.nn import CrossEntropyLoss

    class WeightedTrainer(Trainer):
        """Trainer with class-weighted cross-entropy for imbalanced datasets."""
        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            labels = inputs.pop("labels")
            outputs = model(**inputs)
            logits = outputs.logits
            loss_fn = CrossEntropyLoss(weight=weights_tensor.to(logits.device))
            loss = loss_fn(logits, labels)
            return (loss, outputs) if return_outputs else loss

    # ── Step 5: Training ──────────────────────────────────
    # batch=1 with gradient_accumulation=16 gives effective batch=16
    # gradient_checkpointing trades compute for memory (critical on 5GB RAM)
    # Total steps = (train_samples / effective_batch) * epochs
    # 2800 / 16 * 5 = 875 steps
    total_steps = (len(tokenized_train) // 16) * 5

    training_args = TrainingArguments(
        output_dir=os.path.join(BASE_DIR, 'results'),
        num_train_epochs=5,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=16,
        gradient_checkpointing=True,
        warmup_steps=50,
        weight_decay=0.01,
        learning_rate=3e-5,
        logging_dir=os.path.join(BASE_DIR, 'logs'),
        logging_steps=25,          # Log every 25 steps (~2 min on CPU)
        evaluation_strategy="steps",
        eval_steps=50,             # Evaluate every 50 steps (~4 min)
        save_strategy="steps",
        save_steps=50,             # Must equal eval_steps when load_best_model_at_end=True
        save_total_limit=3,        # Keep 3 checkpoints (~75 steps of safety)
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        fp16=False,
        dataloader_num_workers=0,
        report_to="none",
        use_cpu=True,
        optim="adafactor",         # Much less memory than AdamW
    )

    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
        compute_metrics=compute_metrics,
        callbacks=[MemoryProfilerCallback(total_steps=total_steps)],
    )

    logger.info("=" * 60)
    logger.info("Starting DistilBERT fine-tuning (5 epochs, weighted loss)...")
    logger.info(f"  Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
    logger.info(f"  Max sequence length: {MAX_LENGTH}")
    logger.info(f"  Training samples: {len(tokenized_train)}")
    logger.info("=" * 60)
    
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is not None:
        logger.info(f"Resuming training from checkpoint: {last_checkpoint}")
    else:
        logger.info("No checkpoint found. Starting from scratch.")
        
    trainer.train(resume_from_checkpoint=last_checkpoint)

    # ── Step 6: Evaluate ──────────────────────────────────
    logger.info("Running final evaluation...")
    eval_results = trainer.evaluate()
    logger.info(f"Eval results: {eval_results}")

    # ── Step 7: Save ──────────────────────────────────────
    logger.info(f"Saving fine-tuned model to {MODEL_DIR}")
    os.makedirs(MODEL_DIR, exist_ok=True)
    model.save_pretrained(MODEL_DIR)
    tokenizer.save_pretrained(MODEL_DIR)
    
    # Save eval results
    import json
    results_path = os.path.join(BASE_DIR, 'results', 'training_results.json')
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    with open(results_path, 'w') as f:
        json.dump(eval_results, f, indent=2, default=str)
    logger.info(f"Results saved to {results_path}")

    logger.info("=" * 60)
    logger.info("Training complete! Model is ready for inference.")
    logger.info("=" * 60)


if __name__ == '__main__':
    main()