# NOTE: the original capture carried a Hugging Face Spaces page status header
# ("Spaces: Running") above this file; it is a scrape artifact, not source code,
# and is preserved here only as a comment so the file stays valid Python.
| """ | |
| BERT Baseline Training for LexiMind Comparison. | |
| Fine-tunes bert-base-uncased on topic classification and emotion detection | |
| to provide baselines for comparison with LexiMind (FLAN-T5-based). | |
| Supports three training modes to disentangle architecture vs. MTL effects: | |
| 1. single-topic — BERT fine-tuned on topic classification only | |
| 2. single-emotion — BERT fine-tuned on emotion detection only | |
| 3. multitask — BERT fine-tuned on both tasks jointly | |
| Uses the same datasets, splits, label encoders, and evaluation metrics as the | |
| main LexiMind pipeline for fair comparison. | |
| Usage: | |
| python scripts/train_bert_baseline.py --mode single-topic | |
| python scripts/train_bert_baseline.py --mode single-emotion | |
| python scripts/train_bert_baseline.py --mode multitask | |
| python scripts/train_bert_baseline.py --mode all # Run all three sequentially | |
| Author: Oliver Perrin | |
| Date: March 2026 | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import random | |
| import sys | |
| import time | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Sequence | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from sklearn.metrics import accuracy_score, classification_report, f1_score | |
| from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer | |
| from torch.cuda.amp import GradScaler, autocast | |
| from torch.optim import AdamW | |
| from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR | |
| from torch.utils.data import DataLoader, Dataset | |
| from tqdm import tqdm | |
| from transformers import AutoModel, AutoTokenizer | |
| # Project imports | |
| PROJECT_ROOT = Path(__file__).resolve().parents[1] | |
| if str(PROJECT_ROOT) not in sys.path: | |
| sys.path.insert(0, str(PROJECT_ROOT)) | |
| from src.data.dataset import ( | |
| EmotionExample, | |
| TopicExample, | |
| load_emotion_jsonl, | |
| load_topic_jsonl, | |
| ) | |
| from src.training.metrics import ( | |
| bootstrap_confidence_interval, | |
| multilabel_f1, | |
| multilabel_macro_f1, | |
| multilabel_micro_f1, | |
| multilabel_per_class_metrics, | |
| tune_per_class_thresholds, | |
| ) | |
| # Configuration | |
@dataclass
class BertBaselineConfig:
    """Hyperparameters aligned with LexiMind's full.yaml where applicable.

    Fix: the class declares ``field(default_factory=...)`` defaults, which only
    take effect under the ``@dataclass`` decorator; without it the three Path
    attributes would be bare ``dataclasses.field`` sentinel objects at runtime
    and every ``config.data_dir / ...`` expression would fail. The decorator
    also makes the annotated scalars proper per-instance defaults.
    """

    # Model
    model_name: str = "bert-base-uncased"
    max_length: int = 256  # Same as LexiMind classification max_len
    # Optimizer (matching LexiMind's full.yaml)
    lr: float = 3e-5
    weight_decay: float = 0.01
    betas: tuple[float, float] = (0.9, 0.98)
    eps: float = 1e-6
    # Training
    batch_size: int = 10  # Same as LexiMind
    gradient_accumulation_steps: int = 4  # Same effective batch = 40
    max_epochs: int = 8
    warmup_steps: int = 300
    gradient_clip_norm: float = 1.0
    early_stopping_patience: int = 3
    seed: int = 17  # Same as LexiMind
    # Task weights (for multi-task mode)
    topic_weight: float = 0.3  # Same as LexiMind
    emotion_weight: float = 1.0
    # Temperature sampling (for multi-task mode)
    task_sampling_alpha: float = 0.5
    # Frozen layers: freeze bottom 4 layers (matching LexiMind's encoder strategy)
    freeze_layers: int = 4
    # Precision
    use_amp: bool = True  # BFloat16 mixed precision
    # Paths (factories are evaluated lazily at instantiation time)
    data_dir: Path = field(default_factory=lambda: PROJECT_ROOT / "data" / "processed")
    output_dir: Path = field(default_factory=lambda: PROJECT_ROOT / "outputs" / "bert_baseline")
    checkpoint_dir: Path = field(
        default_factory=lambda: PROJECT_ROOT / "checkpoints" / "bert_baseline"
    )
    # Emotion threshold
    emotion_threshold: float = 0.3
| # Datasets | |
class BertEmotionDataset(Dataset):
    """Emotion-detection dataset that tokenizes on the fly for BERT.

    Each item is a dict with fixed-length ``input_ids`` / ``attention_mask``
    tensors plus a float32 multi-hot label vector produced by the fitted
    MultiLabelBinarizer.
    """

    def __init__(
        self,
        examples: List[EmotionExample],
        tokenizer: AutoTokenizer,
        binarizer: MultiLabelBinarizer,
        max_length: int = 256,
    ):
        self.examples = examples
        self.tokenizer = tokenizer
        self.binarizer = binarizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        sample = self.examples[idx]
        # Pad/truncate to a fixed length so batches stack without a collator.
        enc = self.tokenizer(
            sample.text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        multi_hot = self.binarizer.transform([sample.emotions])[0]
        item = {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "labels": torch.tensor(multi_hot, dtype=torch.float32),
        }
        return item
class BertTopicDataset(Dataset):
    """Topic-classification dataset that tokenizes on the fly for BERT.

    Each item is a dict with fixed-length ``input_ids`` / ``attention_mask``
    tensors plus a single integer class index produced by the fitted
    LabelEncoder.
    """

    def __init__(
        self,
        examples: List[TopicExample],
        tokenizer: AutoTokenizer,
        encoder: LabelEncoder,
        max_length: int = 256,
    ):
        self.examples = examples
        self.tokenizer = tokenizer
        self.encoder = encoder
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        sample = self.examples[idx]
        # Pad/truncate to a fixed length so batches stack without a collator.
        enc = self.tokenizer(
            sample.text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        class_idx = self.encoder.transform([sample.topic])[0]
        item = {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "labels": torch.tensor(class_idx, dtype=torch.long),
        }
        return item
| # Model | |
class BertClassificationHead(nn.Module):
    """Pooling + projection head mapping BERT token states to class logits.

    Pooling strategies:
      - ``"attention"``: learned single-query attention pooling
        (mirrors LexiMind's emotion head)
      - ``"mean"``: mask-aware mean over valid tokens
        (mirrors LexiMind's topic head)
      - anything else: the [CLS] token at position 0

    A 2-layer MLP classifier is used when ``hidden_dim`` is given, otherwise a
    single linear projection.
    """

    def __init__(
        self,
        hidden_size: int,
        num_labels: int,
        pooling: str = "cls",  # "cls" or "attention"
        hidden_dim: Optional[int] = None,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.pooling = pooling
        self.dropout = nn.Dropout(dropout)
        if pooling == "attention":
            # One learnable query scores every token position.
            self.attn_query = nn.Linear(hidden_size, 1, bias=False)
        if hidden_dim is None:
            self.classifier = nn.Linear(hidden_size, num_labels)
        else:
            self.classifier = nn.Sequential(
                nn.Linear(hidden_size, hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, num_labels),
            )

    def _pool(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Collapse (B, L, H) token states into a (B, H) sequence vector."""
        if self.pooling == "attention":
            # Learned attention pooling (same mechanism as LexiMind).
            scores = self.attn_query(hidden_states)  # (B, L, 1)
            valid = attention_mask.unsqueeze(-1).bool()
            # Padding positions get -inf so softmax assigns them zero weight.
            scores = scores.masked_fill(~valid, float("-inf"))
            attn = F.softmax(scores, dim=1)
            return (attn * hidden_states).sum(dim=1)
        if self.pooling == "mean":
            # Mean over valid tokens only; clamp guards against empty masks.
            mask = attention_mask.unsqueeze(-1).float()
            token_sum = (hidden_states * mask).sum(dim=1)
            denom = mask.sum(dim=1).clamp(min=1e-9)
            return token_sum / denom
        # Default: the [CLS] token.
        return hidden_states[:, 0, :]

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        pooled = self.dropout(self._pool(hidden_states, attention_mask))
        return self.classifier(pooled)
class BertBaseline(nn.Module):
    """BERT encoder shared across task-specific classification heads.

    Depending on ``tasks`` this builds an emotion head (attention pooling +
    2-layer MLP), a topic head (mean pooling + linear), or both, on top of a
    pretrained encoder whose embeddings and bottom layers are frozen —
    mirroring LexiMind's encoder fine-tuning strategy.
    """

    def __init__(
        self,
        model_name: str = "bert-base-uncased",
        num_emotions: int = 28,
        num_topics: int = 7,
        tasks: Sequence[str] = ("emotion", "topic"),
        freeze_layers: int = 4,
    ):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        width = self.bert.config.hidden_size  # 768 for bert-base
        self.tasks = list(tasks)
        self.heads = nn.ModuleDict()
        if "emotion" in tasks:
            # Attention pooling + 2-layer MLP, mirroring LexiMind's emotion head.
            self.heads["emotion"] = BertClassificationHead(
                hidden_size=width,
                num_labels=num_emotions,
                pooling="attention",
                hidden_dim=width // 2,  # 384, same ratio as LexiMind
                dropout=0.1,
            )
        if "topic" in tasks:
            # Mean pooling + single linear, mirroring LexiMind's topic head.
            self.heads["topic"] = BertClassificationHead(
                hidden_size=width,
                num_labels=num_topics,
                pooling="mean",
                hidden_dim=None,
                dropout=0.1,
            )
        # Freeze bottom N encoder layers (matching LexiMind's strategy).
        self._freeze_layers(freeze_layers)

    def _freeze_layers(self, n: int) -> None:
        """Disable gradients for the embeddings and the bottom ``n`` encoder layers."""
        for param in self.bert.embeddings.parameters():
            param.requires_grad = False
        cutoff = min(n, len(self.bert.encoder.layer))
        for layer in self.bert.encoder.layer[:cutoff]:
            for param in layer.parameters():
                param.requires_grad = False
        # NOTE(review): these counts are parameter *tensors*, not element totals.
        frozen = sum(1 for p in self.bert.parameters() if not p.requires_grad)
        total = sum(1 for p in self.bert.parameters())
        print(f" Frozen {frozen}/{total} BERT parameters (bottom {n} layers + embeddings)")

    def forward(
        self,
        task: str,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Encode the batch and apply the head for ``task``; returns raw logits."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden = encoded.last_hidden_state  # (B, L, 768)
        return self.heads[task](hidden, attention_mask)

    def param_count(self) -> Dict[str, int]:
        """Parameter-element counts broken down by component."""

        def _numel(module: nn.Module, trainable_only: bool = False) -> int:
            # One-line helper: total (or trainable-only) element count.
            if trainable_only:
                return sum(p.numel() for p in module.parameters() if p.requires_grad)
            return sum(p.numel() for p in module.parameters())

        counts: Dict[str, int] = {
            "bert_encoder": _numel(self.bert),
            "bert_trainable": _numel(self.bert, trainable_only=True),
        }
        for name, head in self.heads.items():
            counts[f"head_{name}"] = _numel(head)
        counts["total"] = _numel(self)
        counts["trainable"] = _numel(self, trainable_only=True)
        return counts
| # Training | |
class BertTrainer:
    """Trainer supporting single-task and multi-task BERT training.

    Single-task mode iterates its lone loader in order; multi-task mode draws
    batches with temperature-based task sampling. Gradient accumulation,
    linear-warmup + cosine LR scheduling, bfloat16 autocast, checkpointing,
    and early stopping follow the main LexiMind loop so comparisons stay fair.

    Fixes vs. the original:
      * ``__init__``: the multitask and single-task branches for the per-epoch
        batch count computed the identical value; the dead branch was collapsed.
      * ``validate``: wrapped in ``torch.no_grad()`` — the original built
        autograd graphs for forward passes whose gradients were never used.
    """

    def __init__(
        self,
        model: BertBaseline,
        config: BertBaselineConfig,
        train_loaders: Dict[str, DataLoader],
        val_loaders: Dict[str, DataLoader],
        device: torch.device,
        mode: str,
    ):
        self.model = model
        self.config = config
        self.train_loaders = train_loaders
        self.val_loaders = val_loaders
        self.device = device
        self.mode = mode
        # Optimizer over trainable parameters only (frozen layers excluded).
        self.optimizer = AdamW(
            [p for p in model.parameters() if p.requires_grad],
            lr=config.lr,
            weight_decay=config.weight_decay,
            betas=config.betas,
            eps=config.eps,
        )
        # One epoch visits sum(len(loader)) batches in every mode: the
        # multi-task sampler draws exactly that many batches per epoch, so the
        # original two branches computed the same value and were merged here.
        total_batches = sum(len(loader) for loader in train_loaders.values())
        self.steps_per_epoch = total_batches // config.gradient_accumulation_steps
        self.total_steps = self.steps_per_epoch * config.max_epochs
        # LR scheduler: linear warmup + cosine decay (matching LexiMind)
        warmup_scheduler = LinearLR(
            self.optimizer,
            start_factor=1e-8 / config.lr,
            end_factor=1.0,
            total_iters=config.warmup_steps,
        )
        cosine_scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=max(self.total_steps - config.warmup_steps, 1),
            eta_min=config.lr * 0.1,  # Decay to 10% of peak (matching LexiMind)
        )
        self.scheduler = SequentialLR(
            self.optimizer,
            schedulers=[warmup_scheduler, cosine_scheduler],
            milestones=[config.warmup_steps],
        )
        # Mixed precision. NOTE(review): GradScaler targets float16; with
        # bfloat16 autocast it is effectively a no-op, kept here for parity.
        self.scaler = GradScaler(enabled=config.use_amp)
        # Loss functions
        self.emotion_loss_fn = nn.BCEWithLogitsLoss()  # multi-label
        self.topic_loss_fn = nn.CrossEntropyLoss()  # single-label
        # Tracking state for checkpointing / early stopping
        self.global_step = 0
        self.best_metric = -float("inf")
        self.patience_counter = 0
        self.training_history: List[Dict[str, Any]] = []

    def _compute_loss(self, task: str, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """BCE-with-logits for multi-label emotion, cross-entropy for topic."""
        if task == "emotion":
            return self.emotion_loss_fn(logits, labels)
        else:
            return self.topic_loss_fn(logits, labels)

    def _get_task_weight(self, task: str) -> float:
        """Loss weight for *task*; only multitask mode down-weights topic."""
        if self.mode != "multitask":
            return 1.0
        if task == "topic":
            return self.config.topic_weight
        return self.config.emotion_weight

    def _make_multitask_iterator(self):
        """Infinite (task, batch) generator with size-based task sampling.

        NOTE(review): the exponent is 1/alpha, so alpha=0.5 squares dataset
        sizes and *amplifies* the size imbalance; the common temperature-
        sampling convention (p_i ∝ n_i**alpha, alpha < 1) flattens it instead.
        Kept as-is to match LexiMind — confirm the intended convention.
        """
        sizes = {k: len(v.dataset) for k, v in self.train_loaders.items()}
        alpha = self.config.task_sampling_alpha
        # Compute sampling probabilities
        raw = {k: s ** (1.0 / alpha) for k, s in sizes.items()}
        total = sum(raw.values())
        probs = {k: v / total for k, v in raw.items()}
        # Create iterators
        iters = {k: iter(v) for k, v in self.train_loaders.items()}
        tasks = list(probs.keys())
        weights = [probs[t] for t in tasks]
        while True:
            task = random.choices(tasks, weights=weights, k=1)[0]
            try:
                batch = next(iters[task])
            except StopIteration:
                # Restart an exhausted loader so sampling can continue.
                iters[task] = iter(self.train_loaders[task])
                batch = next(iters[task])
            yield task, batch

    def train_epoch(self, epoch: int) -> Dict[str, float]:
        """Train one epoch; returns the mean train loss per task."""
        self.model.train()
        self.optimizer.zero_grad()
        epoch_losses: Dict[str, List[float]] = {t: [] for t in self.train_loaders}
        if len(self.train_loaders) > 1:
            # Multi-task: temperature sampling
            iterator = self._make_multitask_iterator()
            total_batches = sum(len(v) for v in self.train_loaders.values())
        else:
            # Single-task: iterate normally
            task_name = list(self.train_loaders.keys())[0]
            iterator = ((task_name, batch) for batch in self.train_loaders[task_name])
            total_batches = len(self.train_loaders[task_name])
        pbar = tqdm(total=total_batches, desc=f"Epoch {epoch + 1}/{self.config.max_epochs}")
        for step_in_epoch in range(total_batches):
            task, batch = next(iterator)
            input_ids = batch["input_ids"].to(self.device)
            attention_mask = batch["attention_mask"].to(self.device)
            labels = batch["labels"].to(self.device)
            # Forward pass with AMP
            with autocast(dtype=torch.bfloat16, enabled=self.config.use_amp):
                logits = self.model(task, input_ids, attention_mask)
                loss = self._compute_loss(task, logits, labels)
                loss = loss * self._get_task_weight(task)
                # Scale down so accumulated gradients average over micro-batches.
                loss = loss / self.config.gradient_accumulation_steps
            # Backward
            self.scaler.scale(loss).backward()
            # Undo the accumulation division so logged losses are comparable.
            epoch_losses[task].append(loss.item() * self.config.gradient_accumulation_steps)
            # Optimizer step (every N accumulation steps)
            if (step_in_epoch + 1) % self.config.gradient_accumulation_steps == 0:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.config.gradient_clip_norm
                )
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()
                self.scheduler.step()
                self.global_step += 1
            pbar.set_postfix(
                {
                    f"{task}_loss": f"{epoch_losses[task][-1]:.4f}",
                    "lr": f"{self.scheduler.get_last_lr()[0]:.2e}",
                }
            )
            pbar.update(1)
        pbar.close()
        # Aggregate per-task mean losses
        results = {}
        for task, losses in epoch_losses.items():
            if losses:
                results[f"train_{task}_loss"] = sum(losses) / len(losses)
        return results

    def validate(self) -> Dict[str, Any]:
        """Run validation across all tasks.

        Fix: the batch loop now runs under ``torch.no_grad()`` — the original
        tracked gradients for forward passes it never backpropagated, wasting
        memory during validation.
        """
        self.model.eval()
        results: Dict[str, Any] = {}
        with torch.no_grad():
            for task, loader in self.val_loaders.items():
                all_logits = []
                all_labels = []
                total_loss = 0.0
                n_batches = 0
                for batch in loader:
                    input_ids = batch["input_ids"].to(self.device)
                    attention_mask = batch["attention_mask"].to(self.device)
                    labels = batch["labels"].to(self.device)
                    with autocast(dtype=torch.bfloat16, enabled=self.config.use_amp):
                        logits = self.model(task, input_ids, attention_mask)
                        loss = self._compute_loss(task, logits, labels)
                    total_loss += loss.item()
                    n_batches += 1
                    # Accumulate on CPU in float32 for stable metric math.
                    all_logits.append(logits.float().cpu())
                    all_labels.append(labels.float().cpu())
                all_logits_t = torch.cat(all_logits, dim=0)
                all_labels_t = torch.cat(all_labels, dim=0)
                results[f"val_{task}_loss"] = total_loss / max(n_batches, 1)
                if task == "emotion":
                    preds = (torch.sigmoid(all_logits_t) > self.config.emotion_threshold).int()
                    targets = all_labels_t.int()
                    results["val_emotion_sample_f1"] = multilabel_f1(preds, targets)
                    results["val_emotion_macro_f1"] = multilabel_macro_f1(preds, targets)
                    results["val_emotion_micro_f1"] = multilabel_micro_f1(preds, targets)
                    # Store raw logits for threshold tuning later
                    results["_emotion_logits"] = all_logits_t
                    results["_emotion_labels"] = all_labels_t
                elif task == "topic":
                    preds = all_logits_t.argmax(dim=1).numpy()
                    targets = all_labels_t.long().numpy()
                    results["val_topic_accuracy"] = float(accuracy_score(targets, preds))
                    results["val_topic_macro_f1"] = float(
                        f1_score(targets, preds, average="macro", zero_division=0)
                    )
        # Combined metric for early stopping / checkpointing
        metric_parts = []
        if "val_emotion_sample_f1" in results:
            metric_parts.append(results["val_emotion_sample_f1"])
        if "val_topic_accuracy" in results:
            metric_parts.append(results["val_topic_accuracy"])
        results["val_combined_metric"] = sum(metric_parts) / max(len(metric_parts), 1)
        return results

    def save_checkpoint(self, path: Path, epoch: int, metrics: Dict[str, Any]) -> None:
        """Save model + optimizer + scheduler state and scalar metrics to *path*."""
        path.parent.mkdir(parents=True, exist_ok=True)
        # Drop underscore-prefixed entries (raw tensors cached by validate()).
        clean_metrics = {k: v for k, v in metrics.items() if not k.startswith("_")}
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "scheduler_state_dict": self.scheduler.state_dict(),
                "metrics": clean_metrics,
                "config": {
                    "mode": self.mode,
                    "tasks": self.model.tasks,
                    "model_name": self.config.model_name,
                },
            },
            path,
        )

    def train(self) -> Dict[str, Any]:
        """Full training loop with per-epoch validation, checkpointing, and
        early stopping; returns the aggregated training history."""
        print(f"\n{'=' * 60}")
        print(f"Training BERT Baseline — Mode: {self.mode}")
        print(f"{'=' * 60}")
        param_counts = self.model.param_count()
        print(f" Total parameters: {param_counts['total']:,}")
        print(f" Trainable parameters: {param_counts['trainable']:,}")
        for name, count in param_counts.items():
            if name.startswith("head_"):
                print(f" {name}: {count:,}")
        print(f" Steps/epoch: {self.steps_per_epoch}")
        print(f" Total steps: {self.total_steps}")
        print()
        all_results: Dict[str, Any] = {"mode": self.mode, "epochs": []}
        start_time = time.time()
        for epoch in range(self.config.max_epochs):
            epoch_start = time.time()
            # Train
            train_metrics = self.train_epoch(epoch)
            # Validate
            val_metrics = self.validate()
            epoch_time = time.time() - epoch_start
            # Log (exclude underscore-prefixed raw tensors)
            epoch_result = {
                "epoch": epoch + 1,
                "time_seconds": epoch_time,
                **train_metrics,
                **{k: v for k, v in val_metrics.items() if not k.startswith("_")},
            }
            all_results["epochs"].append(epoch_result)
            self.training_history.append(epoch_result)
            # Print summary
            print(f"\n Epoch {epoch + 1} ({epoch_time:.0f}s):")
            for k, v in sorted(epoch_result.items()):
                if k not in ("epoch", "time_seconds") and isinstance(v, float):
                    print(f" {k}: {v:.4f}")
            # Checkpointing on the combined validation metric
            combined = val_metrics["val_combined_metric"]
            if combined > self.best_metric:
                self.best_metric = combined
                self.patience_counter = 0
                self.save_checkpoint(
                    self.config.checkpoint_dir / self.mode / "best.pt",
                    epoch,
                    val_metrics,
                )
                print(f" New best model (combined metric: {combined:.4f})")
            else:
                self.patience_counter += 1
                print(
                    f" No improvement ({self.patience_counter}/{self.config.early_stopping_patience})"
                )
            # Always save epoch checkpoint
            self.save_checkpoint(
                self.config.checkpoint_dir / self.mode / f"epoch_{epoch + 1}.pt",
                epoch,
                val_metrics,
            )
            # Early stopping
            if self.patience_counter >= self.config.early_stopping_patience:
                print(f"\n Early stopping triggered at epoch {epoch + 1}")
                all_results["early_stopped"] = True
                all_results["best_epoch"] = epoch + 1 - self.config.early_stopping_patience
                break
        total_time = time.time() - start_time
        all_results["total_time_seconds"] = total_time
        all_results["total_time_human"] = f"{total_time / 3600:.1f}h"
        if "early_stopped" not in all_results:
            all_results["early_stopped"] = False
            all_results["best_epoch"] = (
                epoch + 1 - self.patience_counter if self.patience_counter > 0 else epoch + 1
            )
        all_results["param_counts"] = param_counts
        print(f"\n Training complete in {total_time / 3600:.1f}h")
        print(f" Best combined metric: {self.best_metric:.4f}")
        return all_results
| # Evaluation | |
def evaluate_bert_model(
    model: BertBaseline,
    val_loaders: Dict[str, DataLoader],
    device: torch.device,
    config: BertBaselineConfig,
    emotion_classes: Optional[List[str]] = None,
    topic_classes: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Full evaluation with the same metrics as LexiMind's evaluate.py.

    Runs each validation loader once under ``torch.no_grad()`` and computes:
      - emotion: sample-avg / macro / micro F1 at the default threshold,
        optional per-class metrics, per-class threshold tuning, and a
        bootstrap CI over per-sample F1 scores.
      - topic: accuracy and macro F1, optional per-class precision/recall/F1,
        and a bootstrap CI over per-sample correctness.

    Args:
        model: Trained BertBaseline; its heads must cover every task key in
            ``val_loaders``.
        val_loaders: Mapping of task name ("emotion"/"topic") to its loader.
        device: Device for the forward passes.
        config: Supplies ``use_amp`` and the default ``emotion_threshold``.
        emotion_classes: Class names enabling per-class emotion metrics.
        topic_classes: Class names enabling per-class topic metrics.

    Returns:
        Nested dict of metrics keyed by task name.
    """
    model.eval()
    results: Dict[str, Any] = {}
    with torch.no_grad():
        for task, loader in val_loaders.items():
            all_logits = []
            all_labels = []
            for batch in tqdm(loader, desc=f"Evaluating {task}"):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)
                with autocast(dtype=torch.bfloat16, enabled=config.use_amp):
                    logits = model(task, input_ids, attention_mask)
                # Accumulate on CPU in float32 for stable metric computation.
                all_logits.append(logits.float().cpu())
                all_labels.append(labels.float().cpu())
            all_logits_t = torch.cat(all_logits, dim=0)
            all_labels_t = torch.cat(all_labels, dim=0)
            if task == "emotion":
                # Default threshold
                preds_default = (torch.sigmoid(all_logits_t) > config.emotion_threshold).int()
                targets = all_labels_t.int()
                results["emotion"] = {
                    "default_threshold": config.emotion_threshold,
                    "sample_avg_f1": multilabel_f1(preds_default, targets),
                    "macro_f1": multilabel_macro_f1(preds_default, targets),
                    "micro_f1": multilabel_micro_f1(preds_default, targets),
                }
                # Per-class metrics
                if emotion_classes:
                    per_class = multilabel_per_class_metrics(
                        preds_default, targets, emotion_classes
                    )
                    results["emotion"]["per_class"] = per_class
                # Threshold tuning: one decision threshold per emotion class.
                best_thresholds, tuned_macro = tune_per_class_thresholds(all_logits_t, all_labels_t)
                tuned_preds = torch.zeros_like(all_logits_t)
                probs = torch.sigmoid(all_logits_t)
                for c in range(all_logits_t.shape[1]):
                    tuned_preds[:, c] = (probs[:, c] >= best_thresholds[c]).float()
                tuned_preds = tuned_preds.int()
                results["emotion"]["tuned_macro_f1"] = tuned_macro
                results["emotion"]["tuned_sample_avg_f1"] = multilabel_f1(tuned_preds, targets)
                results["emotion"]["tuned_micro_f1"] = multilabel_micro_f1(tuned_preds, targets)
                # Bootstrap CI on sample-avg F1 (per-sample F1 scores resampled).
                per_sample_f1 = []
                for i in range(preds_default.shape[0]):
                    p = preds_default[i].float()
                    g = targets[i].float()
                    tp = (p * g).sum()
                    # clamp(min=1) avoids 0/0 when a sample has no predicted/gold labels.
                    prec = tp / p.sum().clamp(min=1)
                    rec = tp / g.sum().clamp(min=1)
                    f = (2 * prec * rec) / (prec + rec).clamp(min=1e-8)
                    per_sample_f1.append(f.item())
                # NOTE(review): mean_f1 is unused; only the CI bounds are reported.
                mean_f1, ci_low, ci_high = bootstrap_confidence_interval(per_sample_f1)
                results["emotion"]["sample_avg_f1_ci"] = [ci_low, ci_high]
            elif task == "topic":
                preds = all_logits_t.argmax(dim=1).numpy()
                targets = all_labels_t.long().numpy()
                acc = float(accuracy_score(targets, preds))
                macro_f1 = float(f1_score(targets, preds, average="macro", zero_division=0))
                results["topic"] = {
                    "accuracy": acc,
                    "macro_f1": macro_f1,
                }
                # Per-class metrics
                if topic_classes:
                    report = classification_report(
                        targets,
                        preds,
                        target_names=topic_classes,
                        output_dict=True,
                        zero_division=0,
                    )
                    results["topic"]["per_class"] = {
                        name: {
                            "precision": report[name]["precision"],
                            "recall": report[name]["recall"],
                            "f1": report[name]["f1-score"],
                            "support": report[name]["support"],
                        }
                        for name in topic_classes
                        if name in report
                    }
                # Bootstrap CI on accuracy (per-sample 0/1 correctness resampled).
                # NOTE(review): mean_acc is unused; only the CI bounds are reported.
                per_sample_correct = (preds == targets).astype(float).tolist()
                mean_acc, ci_low, ci_high = bootstrap_confidence_interval(per_sample_correct)
                results["topic"]["accuracy_ci"] = [ci_low, ci_high]
    return results
| # Main Pipeline | |
def set_seed(seed: int) -> None:
    """Seed every RNG this script uses (python, numpy, torch, CUDA) for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def load_data(config: BertBaselineConfig):
    """Load train/val splits for both tasks and fit label encoders on train data.

    Returns a dict holding the four example lists plus the fitted
    MultiLabelBinarizer (emotion) and LabelEncoder (topic).
    """
    base = config.data_dir

    def _val_path(task: str) -> Path:
        # Prefer "validation.jsonl"; fall back to the shorter "val.jsonl" name.
        candidate = base / task / "validation.jsonl"
        return candidate if candidate.exists() else base / task / "val.jsonl"

    emo_train = load_emotion_jsonl(str(base / "emotion" / "train.jsonl"))
    emo_val = load_emotion_jsonl(str(_val_path("emotion")))
    top_train = load_topic_jsonl(str(base / "topic" / "train.jsonl"))
    top_val = load_topic_jsonl(str(_val_path("topic")))

    # Fit label encoders on training data (same as LexiMind).
    binarizer = MultiLabelBinarizer()
    binarizer.fit([ex.emotions for ex in emo_train])
    label_encoder = LabelEncoder()
    label_encoder.fit([ex.topic for ex in top_train])

    print(
        f" Emotion: {len(emo_train)} train, {len(emo_val)} val, {len(binarizer.classes_)} classes"
    )
    print(
        f" Topic: {len(top_train)} train, {len(top_val)} val, {len(label_encoder.classes_)} classes"
    )
    print(f" Emotion classes: {list(binarizer.classes_)[:5]}...")
    print(f" Topic classes: {list(label_encoder.classes_)}")

    return {
        "emotion_train": emo_train,
        "emotion_val": emo_val,
        "topic_train": top_train,
        "topic_val": top_val,
        "binarizer": binarizer,
        "label_encoder": label_encoder,
    }
def run_experiment(mode: str, config: BertBaselineConfig) -> Dict[str, Any]:
    """Run a single experiment (single-topic, single-emotion, or multitask).

    Pipeline: seed RNGs, pick the device, build the tokenizer and the loaders
    for the tasks implied by *mode*, train a BertBaseline via BertTrainer,
    reload the best checkpoint, run the full evaluation suite, and write the
    combined results to ``<output_dir>/<mode>_results.json``.

    Args:
        mode: One of "single-topic", "single-emotion", or "multitask".
        config: Hyperparameters and data/output paths.

    Returns:
        Dict with mode, model name, tasks, training history, and evaluation
        metrics (the same structure that is serialized to JSON).
    """
    print(f"\n{'═' * 60}")
    print(f" BERT BASELINE EXPERIMENT: {mode.upper()}")
    print(f"{'═' * 60}")
    set_seed(config.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" Device: {device}")
    if torch.cuda.is_available():
        print(f" GPU: {torch.cuda.get_device_name()}")
        print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    # CUDA optimizations
    if torch.cuda.is_available():
        # Fixed input shapes (padding="max_length") make cudnn autotuning worthwhile.
        torch.backends.cudnn.benchmark = True
        if hasattr(torch.backends, "cuda"):
            # TF32 matmuls: faster on Ampere+ GPUs.
            torch.backends.cuda.matmul.allow_tf32 = True
    # Load tokenizer
    print(f"\n Loading tokenizer: {config.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    # Load data
    print(" Loading datasets...")
    data = load_data(config)
    # Determine tasks for this mode
    if mode == "single-topic":
        tasks = ["topic"]
    elif mode == "single-emotion":
        tasks = ["emotion"]
    else:
        tasks = ["emotion", "topic"]
    # Create datasets and loaders for each selected task
    train_loaders: Dict[str, DataLoader] = {}
    val_loaders: Dict[str, DataLoader] = {}
    if "emotion" in tasks:
        emo_train_ds = BertEmotionDataset(
            data["emotion_train"], tokenizer, data["binarizer"], config.max_length
        )
        emo_val_ds = BertEmotionDataset(
            data["emotion_val"], tokenizer, data["binarizer"], config.max_length
        )
        train_loaders["emotion"] = DataLoader(
            emo_train_ds,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
            persistent_workers=True,
        )
        # Validation doubles the batch size: no gradients, so more memory headroom.
        val_loaders["emotion"] = DataLoader(
            emo_val_ds,
            batch_size=config.batch_size * 2,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
        )
    if "topic" in tasks:
        top_train_ds = BertTopicDataset(
            data["topic_train"], tokenizer, data["label_encoder"], config.max_length
        )
        top_val_ds = BertTopicDataset(
            data["topic_val"], tokenizer, data["label_encoder"], config.max_length
        )
        train_loaders["topic"] = DataLoader(
            top_train_ds,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
            persistent_workers=True,
        )
        val_loaders["topic"] = DataLoader(
            top_val_ds,
            batch_size=config.batch_size * 2,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
        )
    # Create model sized from the fitted encoders
    print(f"\n Creating model with tasks: {tasks}")
    model = BertBaseline(
        model_name=config.model_name,
        num_emotions=len(data["binarizer"].classes_),
        num_topics=len(data["label_encoder"].classes_),
        tasks=tasks,
        freeze_layers=config.freeze_layers,
    ).to(device)
    # Train
    trainer = BertTrainer(model, config, train_loaders, val_loaders, device, mode)
    training_results = trainer.train()
    # Load best checkpoint for final evaluation
    best_path = config.checkpoint_dir / mode / "best.pt"
    if best_path.exists():
        print("\n Loading best checkpoint for final evaluation...")
        # weights_only=False: checkpoint also carries optimizer/scheduler state.
        checkpoint = torch.load(best_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])
    # Full evaluation
    print("\n Running final evaluation...")
    eval_results = evaluate_bert_model(
        model,
        val_loaders,
        device,
        config,
        emotion_classes=list(data["binarizer"].classes_) if "emotion" in tasks else None,
        topic_classes=list(data["label_encoder"].classes_) if "topic" in tasks else None,
    )
    # Combine results
    final_results = {
        "mode": mode,
        "model": config.model_name,
        "tasks": tasks,
        "training": training_results,
        "evaluation": eval_results,
    }
    # Save results
    output_path = config.output_dir / f"{mode}_results.json"
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Remove non-serializable fields: drop underscore-prefixed dict entries
    # (raw tensors) and coerce numpy scalars/arrays to plain Python types.
    def make_serializable(obj):
        if isinstance(obj, dict):
            return {k: make_serializable(v) for k, v in obj.items() if not k.startswith("_")}
        if isinstance(obj, list):
            return [make_serializable(item) for item in obj]
        if isinstance(obj, (np.integer, np.int64)):
            return int(obj)
        if isinstance(obj, (np.floating, np.float64)):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return obj

    with open(output_path, "w") as f:
        json.dump(make_serializable(final_results), f, indent=2)
    print(f"\n Results saved to {output_path}")
    return final_results
def print_comparison_summary(all_results: Dict[str, Dict[str, Any]]) -> None:
    """Print a side-by-side comparison of all experiments against LexiMind.

    Args:
        all_results: Mapping of experiment mode name -> results dict as
            produced by ``run_experiment`` (with "evaluation" and "training"
            sub-dicts). Missing metrics are rendered as an em dash.
    """
    print(f"\n{'═' * 70}")
    print(" BERT BASELINE COMPARISON SUMMARY")
    print(f"{'═' * 70}")

    modes = list(all_results.keys())
    header = f"{'Metric':<30}" + "".join(f"{m:>16}" for m in modes) + f"{'LexiMind':>16}"
    print(f"\n {header}")
    print(f" {'─' * len(header)}")

    # LexiMind reference values. Keys MUST match the f"{task}_{metric_name}"
    # lookup used below — previously "emotion_sample_f1" never matched
    # "sample_avg_f1", so the 0.3523 reference silently printed as 0.0000.
    leximind = {
        "topic_accuracy": 0.8571,
        "topic_macro_f1": 0.8539,
        "emotion_sample_avg_f1": 0.3523,
        "emotion_macro_f1": 0.1432,
        "emotion_micro_f1": 0.4430,
        "emotion_tuned_macro_f1": 0.2936,
    }

    def _print_row(display_name: str, values: List[Optional[float]], lm_val: Optional[float]) -> None:
        # One formatted table row: per-mode values, then the LexiMind column.
        # None (metric not present) renders as an em dash, not a fake 0.0000.
        row = f" {display_name:<30}"
        for val in values:
            row += f"{val:>16.4f}" if val is not None else f"{'—':>16}"
        row += f"{lm_val:>16.4f}" if lm_val is not None else f"{'—':>16}"
        print(row)

    print("\n Topic Classification")
    for metric_name, display_name in [
        ("accuracy", "Accuracy"),
        ("macro_f1", "Macro F1"),
    ]:
        vals = [
            all_results[m].get("evaluation", {}).get("topic", {}).get(metric_name)
            for m in modes
        ]
        _print_row(display_name, vals, leximind.get(f"topic_{metric_name}"))

    print("\n Emotion Detection")
    for metric_name, display_name in [
        ("sample_avg_f1", "Sample-avg F1 (τ=0.3)"),
        ("macro_f1", "Macro F1 (τ=0.3)"),
        ("micro_f1", "Micro F1 (τ=0.3)"),
        ("tuned_macro_f1", "Tuned Macro F1"),
        ("tuned_sample_avg_f1", "Tuned Sample-avg F1"),
    ]:
        vals = [
            all_results[m].get("evaluation", {}).get("emotion", {}).get(metric_name)
            for m in modes
        ]
        _print_row(display_name, vals, leximind.get(f"emotion_{metric_name}"))

    print("\n Training Time")
    row = f" {'Hours':<30}"
    for mode in modes:
        t = all_results[mode].get("training", {}).get("total_time_seconds", 0) / 3600
        row += f"{t:>15.1f}h"
    row += f"{'~9.0h':>16}"  # LexiMind reference wall-clock time
    print(row)
    print(f"\n{'═' * 70}\n")
def main():
    """CLI entry point: parse arguments and run the requested experiment(s)."""
    parser = argparse.ArgumentParser(description="BERT Baseline Training for LexiMind")
    parser.add_argument(
        "--mode",
        type=str,
        required=True,
        choices=["single-topic", "single-emotion", "multitask", "all"],
        help="Training mode",
    )
    parser.add_argument("--epochs", type=int, default=None, help="Override max epochs")
    parser.add_argument("--lr", type=float, default=None, help="Override learning rate")
    parser.add_argument("--batch-size", type=int, default=None, help="Override batch size")
    parser.add_argument(
        "--model", type=str, default="bert-base-uncased", help="HuggingFace model name"
    )
    args = parser.parse_args()

    config = BertBaselineConfig()
    config.model_name = args.model
    # Apply CLI overrides only when explicitly provided on the command line.
    if args.epochs is not None:
        config.max_epochs = args.epochs
    if args.lr is not None:
        config.lr = args.lr
    if args.batch_size is not None:
        config.batch_size = args.batch_size

    modes = (
        ["single-topic", "single-emotion", "multitask"]
        if args.mode == "all"
        else [args.mode]
    )

    all_results: Dict[str, Dict[str, Any]] = {}
    for mode in modes:
        all_results[mode] = run_experiment(mode, config)
        # Release cached GPU memory so sequential experiments don't interfere.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if len(all_results) > 1:
        combined_path = config.output_dir / "combined_results.json"

        def make_serializable(obj):
            # Recursively drop "_private" keys and coerce numpy scalars/arrays
            # to native Python types so json.dump can handle the results.
            if isinstance(obj, dict):
                return {
                    key: make_serializable(value)
                    for key, value in obj.items()
                    if not key.startswith("_")
                }
            if isinstance(obj, list):
                return [make_serializable(element) for element in obj]
            if isinstance(obj, (np.integer, np.int64)):
                return int(obj)
            if isinstance(obj, (np.floating, np.float64)):
                return float(obj)
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            return obj

        with open(combined_path, "w") as f:
            json.dump(make_serializable(all_results), f, indent=2)
        print(f" Combined results saved to {combined_path}")

    print_comparison_summary(all_results)
# Standard entry guard: run the CLI only when executed as a script, so that
# importing this module (e.g. for its model/dataset classes) has no side effects.
if __name__ == "__main__":
    main()