# src/train_router.py
# Fine-tune DistilBERT for 8-class ticket routing
# SupportMind v1.0 - Asmitha
#
# Memory-optimized for machines with limited RAM:
#   - max_length=64 (tickets are short, saves ~4x memory vs 256)
#   - per_device_train_batch_size=1 (minimal footprint)
#   - gradient_accumulation_steps=16 (effective batch=16)
#   - fp16=True if CUDA is available
#   - DataFrames and intermediate datasets freed before model loading
import os
import sys
import gc

# Disable TensorFlow to prevent DLL loading errors under Application Control policies
os.environ['USE_TF'] = '0'
os.environ['USE_JAX'] = '0'
# Limit torch threads to reduce memory pressure
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import pandas as pd
import torch
import logging
from transformers import (
    DistilBertTokenizer,
    DistilBertForSequenceClassification,
    Trainer,
    TrainingArguments,
    TrainerCallback
)
from transformers.trainer_utils import get_last_checkpoint
import psutil
from datasets import Dataset
import numpy as np

logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)

BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATA_DIR = os.path.join(BASE_DIR, 'data', 'processed')
MODEL_DIR = os.path.join(BASE_DIR, 'models', 'ticket_classifier')

# Shorter max_length - support tickets are typically short,
# and 64 tokens is enough to capture intent from these tickets
MAX_LENGTH = 64

class MemoryProfilerCallback(TrainerCallback):
    """Logs memory usage + progress summary every N steps."""

    def __init__(self, total_steps: int):
        self.process = psutil.Process(os.getpid())
        self.total_steps = total_steps

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % args.logging_steps == 0:
            mem_mb = self.process.memory_info().rss / (1024 * 1024)
            pct = (state.global_step / self.total_steps) * 100 if self.total_steps else 0
            logger.info(
                f"[{pct:5.1f}%] Step {state.global_step}/{self.total_steps} "
                f"| Epoch {state.epoch:.2f} | RAM: {mem_mb:.0f} MB"
            )

def compute_metrics(eval_pred):
    """Compute accuracy metric for evaluation."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy = (predictions == labels).astype(np.float32).mean()
    return {"accuracy": float(accuracy)}

def main():
    train_path = os.path.join(DATA_DIR, 'train.csv')
    val_path = os.path.join(DATA_DIR, 'val.csv')
    if not os.path.exists(train_path):
        logger.error(f"Training data not found at {train_path}. Run data/preprocess.py first.")
        sys.exit(1)

    # ── Step 1: Load & tokenize data ──────────────────────
    logger.info("Loading processed datasets...")
    train_df = pd.read_csv(train_path)
    val_df = pd.read_csv(val_path)
    logger.info(f"Train: {len(train_df)} samples, Val: {len(val_df)} samples")
    logger.info(f"Label distribution:\n{train_df['label'].value_counts().to_string()}")

    # Check device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_fp16 = device == "cuda"
    logger.info(f"Device: {device} | FP16: {use_fp16}")

    # Convert to HF Datasets
    train_dataset = Dataset.from_pandas(train_df[['text', 'label']])
    val_dataset = Dataset.from_pandas(val_df[['text', 'label']])

    # Free DataFrame memory before tokenization
    del train_df, val_df
    gc.collect()

    logger.info("Initializing Tokenizer...")
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

    def tokenize_function(examples):
        return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=MAX_LENGTH)

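    # Note: padding='max_length' produces fixed 64-token rows, so the Trainer's default
    # data collator can batch them directly without dynamic padding.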
| logger.info("Tokenizing datasets...") | |
| tokenized_train = train_dataset.map(tokenize_function, batched=True, batch_size=64) | |
| tokenized_val = val_dataset.map(tokenize_function, batched=True, batch_size=64) | |
| # Free raw datasets | |
| del train_dataset, val_dataset | |
| gc.collect() | |
| # ββ Step 2: Compute class weights for imbalanced data β | |
| from sklearn.utils.class_weight import compute_class_weight | |
| labels_array = tokenized_train['label'] | |
| unique_labels = sorted(set(labels_array)) | |
| class_weights = compute_class_weight( | |
| class_weight='balanced', | |
| classes=np.array(unique_labels), | |
| y=np.array(labels_array) | |
| ) | |
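    # 'balanced' assigns each class the weight n_samples / (n_classes * class_count),
    # so rarer ticket categories contribute proportionally more to the loss.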
    # Map to all 8 classes (some might be missing from the training split)
    weight_dict = {c: w for c, w in zip(unique_labels, class_weights)}
    weights_tensor = torch.tensor(
        [weight_dict.get(i, 1.0) for i in range(8)], dtype=torch.float32
    )
    logger.info(f"Class weights: {weights_tensor.tolist()}")

    # ── Step 3: Load model ────────────────────────────────
    logger.info("Loading DistilBERT model...")
    model = DistilBertForSequenceClassification.from_pretrained(
        'distilbert-base-uncased',
        num_labels=8
    )
    param_count = sum(p.numel() for p in model.parameters())
    logger.info(f"Model loaded. Parameters: {param_count:,}")

    # ── Freeze base layers - only fine-tune last 2 transformer layers + head ──
    # Freezing the embeddings and layers 0-3 cuts trainable params from ~67M to ~15M,
    # reducing peak RAM from ~3.5 GB to ~800 MB. Quality impact is minimal
    # because the ticket vocabulary is similar to DistilBERT's pretraining data.
    for name, param in model.named_parameters():
        param.requires_grad = False  # freeze everything first
    # Unfreeze: last 2 transformer layers (layers 4 and 5 of 6) plus the classification head
    for name, param in model.named_parameters():
        if any(key in name for key in [
            'transformer.layer.4',
            'transformer.layer.5',
            'pre_classifier',
            'classifier',
        ]):
            param.requires_grad = True

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    logger.info(f"Trainable params: {trainable:,} / {total:,} ({trainable/total*100:.1f}%)")

    # Force garbage collection after model load
    gc.collect()

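    # If gradient checkpointing later warns that none of its inputs require grad (a known
    # interaction with frozen embeddings), model.enable_input_require_grads() is the usual
    # transformers workaround; it is left commented out here as an optional tweak.
    # model.enable_input_require_grads()
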
    # ── Step 4: Custom Trainer with weighted loss ─────────
    from torch.nn import CrossEntropyLoss

    class WeightedTrainer(Trainer):
        """Trainer with class-weighted cross-entropy for imbalanced datasets."""
        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            labels = inputs.pop("labels")
            outputs = model(**inputs)
            logits = outputs.logits
            loss_fn = CrossEntropyLoss(weight=weights_tensor.to(logits.device))
            loss = loss_fn(logits, labels)
            return (loss, outputs) if return_outputs else loss

    # ── Step 5: Training ──────────────────────────────────
    # batch=1 with gradient_accumulation=16 gives effective batch=16
    # gradient_checkpointing trades compute for memory (critical on 5GB RAM)
    # Total steps = (train_samples / effective_batch) * epochs
    #   2800 / 16 * 5 = 875 steps
    total_steps = (len(tokenized_train) // 16) * 5

    training_args = TrainingArguments(
        output_dir=os.path.join(BASE_DIR, 'results'),
        num_train_epochs=5,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=16,
        gradient_checkpointing=True,
        warmup_steps=50,
        weight_decay=0.01,
        learning_rate=3e-5,
        logging_dir=os.path.join(BASE_DIR, 'logs'),
        logging_steps=25,            # Log every 25 steps (~2 min on CPU)
        evaluation_strategy="steps",
        eval_steps=50,               # Evaluate every 50 steps (~4 min)
        save_strategy="steps",
        save_steps=50,               # Must be a multiple of eval_steps when load_best_model_at_end=True
        save_total_limit=3,          # Keep the 3 most recent checkpoints (150 steps of history)
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        fp16=use_fp16,               # Mixed precision only when CUDA is available
        dataloader_num_workers=0,
        report_to="none",
        use_cpu=(device == "cpu"),
        optim="adafactor",           # Much less memory than AdamW
    )

    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
        compute_metrics=compute_metrics,
        callbacks=[MemoryProfilerCallback(total_steps=total_steps)],
    )

    logger.info("=" * 60)
    logger.info("Starting DistilBERT fine-tuning (5 epochs, weighted loss)...")
    logger.info(f"  Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
    logger.info(f"  Max sequence length: {MAX_LENGTH}")
    logger.info(f"  Training samples: {len(tokenized_train)}")
    logger.info("=" * 60)

    # Resume from the most recent checkpoint if one exists (output_dir may not exist on a first run)
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is not None:
        logger.info(f"Resuming training from checkpoint: {last_checkpoint}")
    else:
        logger.info("No checkpoint found. Starting from scratch.")

    trainer.train(resume_from_checkpoint=last_checkpoint)

    # ── Step 6: Evaluate ──────────────────────────────────
    logger.info("Running final evaluation...")
    eval_results = trainer.evaluate()
    logger.info(f"Eval results: {eval_results}")

    # ── Step 7: Save ──────────────────────────────────────
    logger.info(f"Saving fine-tuned model to {MODEL_DIR}")
    os.makedirs(MODEL_DIR, exist_ok=True)
    model.save_pretrained(MODEL_DIR)
    tokenizer.save_pretrained(MODEL_DIR)

    # Save eval results
    import json
    results_path = os.path.join(BASE_DIR, 'results', 'training_results.json')
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    with open(results_path, 'w') as f:
        json.dump(eval_results, f, indent=2, default=str)
    logger.info(f"Results saved to {results_path}")

    logger.info("=" * 60)
    logger.info("Training complete! Model is ready for inference.")
    logger.info("=" * 60)


if __name__ == '__main__':
    main()
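
# Usage sketch (not part of the training pipeline): how the artifacts saved above would
# typically be loaded for routing a new ticket. The example ticket text is a placeholder,
# and the meaning of label ids 0-7 is assumed to come from data/preprocess.py.
#
#   from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
#   import torch
#
#   tokenizer = DistilBertTokenizer.from_pretrained(MODEL_DIR)
#   model = DistilBertForSequenceClassification.from_pretrained(MODEL_DIR)
#   model.eval()
#   inputs = tokenizer("VPN drops every 10 minutes", return_tensors="pt",
#                      truncation=True, max_length=MAX_LENGTH)
#   with torch.no_grad():
#       predicted_label = model(**inputs).logits.argmax(dim=-1).item()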