# SupportMind / src/train_router.py
# Uploaded by Asmitha-28 to the Hugging Face Hub via huggingface_hub (commit 5cd9136)
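# src/train_router.py
# Fine-tune DistilBERT for 8-class ticket routing
# SupportMind v1.0 — Asmitha
#
# Memory-optimized for machines with limited RAM:
# - max_length=64 (tickets are short; roughly 4x less activation memory than 256)
# - per-device batch_size=1 (minimal footprint)
# - gradient_accumulation_steps=16 (effective batch=16)
# - gradient checkpointing and the Adafactor optimizer
# - only the last 2 transformer layers + classifier head are trainable
# - fp16/device selection: see the TrainingArguments in main()
# - DataFrames and raw Datasets are freed before model loading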
import os
import sys
import gc
# Environment overrides, placed before the torch/transformers imports below,
# presumably so they take effect at import time:
#   USE_TF / USE_JAX   keep transformers on the PyTorch backend (importing
#                      TensorFlow can hit DLL errors under Application Control)
#   OMP / MKL threads  single-threaded math libraries to reduce memory pressure
#   TOKENIZERS_PARALLELISM  no tokenizer worker threads
os.environ.update({
    'USE_TF': '0',
    'USE_JAX': '0',
    'OMP_NUM_THREADS': '1',
    'MKL_NUM_THREADS': '1',
    'TOKENIZERS_PARALLELISM': 'false',
})
import pandas as pd
import torch
import logging
from transformers import (
DistilBertTokenizer,
DistilBertForSequenceClassification,
Trainer,
TrainingArguments,
TrainerCallback
)
from transformers.trainer_utils import get_last_checkpoint
import psutil
from datasets import Dataset
import numpy as np
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
)
logger = logging.getLogger(__name__)

# Project root: two directory levels above this file's absolute path.
_THIS_FILE = os.path.abspath(__file__)
BASE_DIR = os.path.dirname(os.path.dirname(_THIS_FILE))
DATA_DIR = os.path.join(BASE_DIR, 'data', 'processed')
MODEL_DIR = os.path.join(BASE_DIR, 'models', 'ticket_classifier')

# Tokenizer truncation/padding length. Support tickets are short, so 64
# tokens are enough to capture their intent while keeping memory low.
MAX_LENGTH = 64
class MemoryProfilerCallback(TrainerCallback):
    """Log process RSS memory and a progress summary every logging interval.

    Args:
        total_steps: Caller's estimate of the total number of optimizer steps.
            Used for the progress percentage only when the Trainer has not
            populated ``state.max_steps``.
    """

    def __init__(self, total_steps: int):
        # Handle on this Python process; `os` is the module-level import.
        self.process = psutil.Process(os.getpid())
        self.total_steps = total_steps

    def on_step_end(self, args, state, control, **kwargs):
        # Prefer the Trainer's resolved interval (newer transformers versions
        # store the absolute step count in state.logging_steps, even when
        # args.logging_steps was given as a ratio) over the raw argument.
        log_every = int(getattr(state, "logging_steps", 0) or args.logging_steps)
        # A non-positive interval would make the modulo divide by zero.
        if log_every <= 0 or state.global_step % log_every != 0:
            return
        # state.max_steps is the step count the Trainer actually computed;
        # the constructor's estimate is only a fallback.
        total = state.max_steps or self.total_steps
        mem_mb = self.process.memory_info().rss / (1024 * 1024)
        pct = (state.global_step / total) * 100 if total else 0
        epoch = state.epoch or 0.0
        logger.info(
            f"[{pct:5.1f}%] Step {state.global_step}/{total} "
            f"| Epoch {epoch:.2f} | RAM: {mem_mb:.0f} MB"
        )
def compute_metrics(eval_pred):
    """Compute classification accuracy for the Trainer's evaluation loop.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair or an ``EvalPrediction``.
            ``predictions`` may itself be a tuple when the model returns extra
            outputs; its first element is taken as the logits.

    Returns:
        ``{"accuracy": float}``; 0.0 for an empty evaluation set, so that
        best-model selection never compares against NaN.
    """
    # Index rather than unpack: an EvalPrediction may carry extra fields
    # (inputs/losses) that would break two-name unpacking.
    logits, labels = eval_pred[0], eval_pred[1]
    if isinstance(logits, tuple):
        logits = logits[0]
    labels = np.asarray(labels)
    if labels.size == 0:
        return {"accuracy": 0.0}
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": float(np.mean(predictions == labels))}
# Number of routing classes predicted by the classifier head.
NUM_LABELS = 8

# Parameter-name substrings that stay trainable after freezing: the last two
# of DistilBERT's six transformer layers plus the classification head.
# ('classifier' also matches 'pre_classifier'; both are listed for clarity.)
TRAINABLE_PARAM_KEYS = (
    'transformer.layer.4',
    'transformer.layer.5',
    'pre_classifier',
    'classifier',
)


def _has_valid_labels(df) -> bool:
    """Return True if ``df`` is non-empty and every label is an int in [0, NUM_LABELS)."""
    labels = df['label']
    return (
        not df.empty
        and pd.api.types.is_integer_dtype(labels)
        and bool(labels.between(0, NUM_LABELS - 1).all())
    )


def _load_tokenized_datasets(train_path, val_path):
    """Load the processed CSV splits and tokenize them.

    Exits the process with status 1 if either split is empty or has labels
    that are not integers in ``[0, NUM_LABELS)``. DataFrames and raw Datasets
    are deleted as soon as they are no longer needed to keep peak RAM low.

    Returns:
        ``(tokenized_train, tokenized_val, tokenizer)``.
    """
    logger.info("Loading processed datasets...")
    train_df = pd.read_csv(train_path)
    val_df = pd.read_csv(val_path)
    logger.info(f"Train: {len(train_df)} samples, Val: {len(val_df)} samples")
    logger.info(f"Label distribution:\n{train_df['label'].value_counts().to_string()}")

    # Fail fast: an out-of-range label would otherwise only surface as an
    # opaque index error inside the loss function mid-training.
    bad_splits = [
        split for split, frame in (('train', train_df), ('val', val_df))
        if not _has_valid_labels(frame)
    ]
    if bad_splits:
        logger.error(
            f"Invalid split(s) {bad_splits}: expected non-empty data with integer "
            f"labels in [0, {NUM_LABELS - 1}]. Re-run data/preprocess.py."
        )
        sys.exit(1)

    train_dataset = Dataset.from_pandas(train_df[['text', 'label']])
    val_dataset = Dataset.from_pandas(val_df[['text', 'label']])
    # Free DataFrame memory before tokenization.
    del train_df, val_df
    gc.collect()

    logger.info("Initializing Tokenizer...")
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

    def tokenize_function(examples):
        return tokenizer(
            examples['text'], padding='max_length', truncation=True, max_length=MAX_LENGTH
        )

    logger.info("Tokenizing datasets...")
    tokenized_train = train_dataset.map(tokenize_function, batched=True, batch_size=64)
    tokenized_val = val_dataset.map(tokenize_function, batched=True, batch_size=64)
    # Free the raw (untokenized) datasets.
    del train_dataset, val_dataset
    gc.collect()
    return tokenized_train, tokenized_val, tokenizer


def _compute_class_weights(labels):
    """Return a float32 tensor of NUM_LABELS 'balanced' class weights.

    Classes that never occur in ``labels`` get weight 1.0, because
    ``compute_class_weight`` only weights the classes it is given.
    """
    from sklearn.utils.class_weight import compute_class_weight

    label_array = np.asarray(list(labels))
    present = np.unique(label_array)
    weights = compute_class_weight(class_weight='balanced', classes=present, y=label_array)
    weight_by_class = dict(zip(present.tolist(), weights.tolist()))
    return torch.tensor(
        [weight_by_class.get(i, 1.0) for i in range(NUM_LABELS)], dtype=torch.float32
    )


def _load_partially_frozen_model(gradient_checkpointing):
    """Load DistilBERT for NUM_LABELS-way classification with most weights frozen.

    Only parameters whose names contain one of TRAINABLE_PARAM_KEYS keep
    ``requires_grad=True``; freezing the rest removes their gradients and
    optimizer state. The exact trainable/total parameter counts are logged.

    Args:
        gradient_checkpointing: whether the Trainer will enable activation
            checkpointing (see the note below on why this matters).
    """
    logger.info("Loading DistilBERT model...")
    model = DistilBertForSequenceClassification.from_pretrained(
        'distilbert-base-uncased',
        num_labels=NUM_LABELS,
    )
    total = sum(p.numel() for p in model.parameters())
    logger.info(f"Model loaded. Parameters: {total:,}")

    # Freeze the embeddings and lower transformer layers; keep the top of the
    # network trainable.
    for name, param in model.named_parameters():
        param.requires_grad = any(key in name for key in TRAINABLE_PARAM_KEYS)

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Trainable params: {trainable:,} / {total:,} ({trainable/total*100:.1f}%)")

    if gradient_checkpointing:
        # With the embeddings and lower layers frozen, the hidden states fed
        # into the checkpointed layers do not require grad. Under reentrant
        # activation checkpointing (the transformers default) the checkpointed
        # outputs then carry no grad_fn, so layers 4-5 would silently receive
        # no gradients and only the head would train. Making the embedding
        # output require grad restores gradient flow into those layers.
        model.enable_input_require_grads()
    return model


def _build_training_args(use_cuda):
    """Build TrainingArguments for the memory-constrained fine-tuning run.

    Per-device batch 1 with 16 accumulation steps gives an effective batch of
    16; gradient checkpointing and Adafactor trade compute for memory. Trains
    on the GPU with fp16 when ``use_cuda`` is True, otherwise on the CPU.
    """
    import inspect

    # transformers renamed `evaluation_strategy` to `eval_strategy` and later
    # dropped the old name; pass whichever the installed version accepts.
    accepted = inspect.signature(TrainingArguments.__init__).parameters
    eval_key = 'eval_strategy' if 'eval_strategy' in accepted else 'evaluation_strategy'

    return TrainingArguments(
        output_dir=os.path.join(BASE_DIR, 'results'),
        num_train_epochs=5,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=16,
        gradient_checkpointing=True,
        warmup_steps=50,
        weight_decay=0.01,
        learning_rate=3e-5,
        logging_dir=os.path.join(BASE_DIR, 'logs'),
        logging_steps=25,
        eval_steps=50,
        save_strategy="steps",
        save_steps=50,  # must be a multiple of eval_steps with load_best_model_at_end
        save_total_limit=3,  # keep the 3 most recent checkpoints (best is also retained)
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        fp16=use_cuda,
        use_cpu=not use_cuda,
        dataloader_num_workers=0,
        report_to="none",
        optim="adafactor",  # much less optimizer-state memory than AdamW
        **{eval_key: "steps"},
    )


def _save_artifacts(model, tokenizer, eval_results):
    """Write model + tokenizer to MODEL_DIR and eval metrics to results/training_results.json."""
    import json

    logger.info(f"Saving fine-tuned model to {MODEL_DIR}")
    os.makedirs(MODEL_DIR, exist_ok=True)
    model.save_pretrained(MODEL_DIR)
    tokenizer.save_pretrained(MODEL_DIR)

    results_path = os.path.join(BASE_DIR, 'results', 'training_results.json')
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    with open(results_path, 'w', encoding='utf-8') as f:
        # default=str stringifies any metric value json cannot encode natively.
        json.dump(eval_results, f, indent=2, default=str)
    logger.info(f"Results saved to {results_path}")


def main():
    """Fine-tune DistilBERT on the processed ticket splits and save the result.

    Reads ``train.csv`` / ``val.csv`` from DATA_DIR, trains with a
    class-weighted loss (resuming from the latest checkpoint under
    ``results/`` if one exists), then writes the model and tokenizer to
    MODEL_DIR and the final eval metrics to ``results/training_results.json``.
    Exits with status 1 if the input data is missing or invalid.
    """
    from torch.nn import CrossEntropyLoss

    train_path = os.path.join(DATA_DIR, 'train.csv')
    val_path = os.path.join(DATA_DIR, 'val.csv')
    for path in (train_path, val_path):
        if not os.path.exists(path):
            logger.error(f"Training data not found at {path}. Run data/preprocess.py first.")
            sys.exit(1)

    # ── Step 1: Load & tokenize data ──────────────────────
    tokenized_train, tokenized_val, tokenizer = _load_tokenized_datasets(train_path, val_path)

    use_cuda = torch.cuda.is_available()
    logger.info(f"Device: {'cuda' if use_cuda else 'cpu'} | FP16: {use_cuda}")

    # ── Step 2: Class weights for imbalanced data ─────────
    weights_tensor = _compute_class_weights(tokenized_train['label'])
    logger.info(f"Class weights: {weights_tensor.tolist()}")

    # ── Step 3: Training config + partially frozen model ──
    training_args = _build_training_args(use_cuda)
    model = _load_partially_frozen_model(training_args.gradient_checkpointing)
    gc.collect()

    # ── Step 4: Trainer with weighted loss ────────────────
    class WeightedTrainer(Trainer):
        """Trainer with class-weighted cross-entropy for imbalanced datasets."""

        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            labels = inputs.pop("labels")
            outputs = model(**inputs)
            logits = outputs.logits
            loss_fn = CrossEntropyLoss(weight=weights_tensor.to(logits.device))
            loss = loss_fn(logits, labels)
            return (loss, outputs) if return_outputs else loss

    effective_batch = (
        training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
    )
    num_epochs = int(training_args.num_train_epochs)
    # Estimated optimizer steps, derived from the args above; used only for
    # progress logging.
    total_steps = max(len(tokenized_train) // effective_batch, 1) * num_epochs

    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
        compute_metrics=compute_metrics,
        callbacks=[MemoryProfilerCallback(total_steps=total_steps)],
    )

    # ── Step 5: Training ──────────────────────────────────
    logger.info("=" * 60)
    logger.info(f"Starting DistilBERT fine-tuning ({num_epochs} epochs, weighted loss)...")
    logger.info(f" Effective batch size: {effective_batch}")
    logger.info(f" Max sequence length: {MAX_LENGTH}")
    logger.info(f" Training samples: {len(tokenized_train)}")
    logger.info("=" * 60)

    output_dir = training_args.output_dir
    # Treat a missing output directory as "no checkpoint" instead of letting
    # the checkpoint lookup fail on a nonexistent path.
    last_checkpoint = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None
    if last_checkpoint is not None:
        logger.info(f"Resuming training from checkpoint: {last_checkpoint}")
    else:
        logger.info("No checkpoint found. Starting from scratch.")
    trainer.train(resume_from_checkpoint=last_checkpoint)

    # ── Step 6: Evaluate (best model, since load_best_model_at_end=True) ─
    logger.info("Running final evaluation...")
    eval_results = trainer.evaluate()
    logger.info(f"Eval results: {eval_results}")

    # ── Step 7: Save ──────────────────────────────────────
    _save_artifacts(trainer.model, tokenizer, eval_results)

    logger.info("=" * 60)
    logger.info("Training complete! Model is ready for inference.")
    logger.info("=" * 60)
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()