# LexiMind/scripts/train_bert_baseline.py
# Author: OliverPerrin — "Ruff formatting fix" (commit 56554a7)
"""
BERT Baseline Training for LexiMind Comparison.
Fine-tunes bert-base-uncased on topic classification and emotion detection
to provide baselines for comparison with LexiMind (FLAN-T5-based).
Supports three training modes to disentangle architecture vs. MTL effects:
1. single-topic — BERT fine-tuned on topic classification only
2. single-emotion — BERT fine-tuned on emotion detection only
3. multitask — BERT fine-tuned on both tasks jointly
Uses the same datasets, splits, label encoders, and evaluation metrics as the
main LexiMind pipeline for fair comparison.
Usage:
python scripts/train_bert_baseline.py --mode single-topic
python scripts/train_bert_baseline.py --mode single-emotion
python scripts/train_bert_baseline.py --mode multitask
python scripts/train_bert_baseline.py --mode all # Run all three sequentially
Author: Oliver Perrin
Date: March 2026
"""
from __future__ import annotations
import argparse
import json
import random
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, classification_report, f1_score
from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer
from torch.cuda.amp import GradScaler, autocast
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
# Project imports
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from src.data.dataset import (
EmotionExample,
TopicExample,
load_emotion_jsonl,
load_topic_jsonl,
)
from src.training.metrics import (
bootstrap_confidence_interval,
multilabel_f1,
multilabel_macro_f1,
multilabel_micro_f1,
multilabel_per_class_metrics,
tune_per_class_thresholds,
)
# Configuration
@dataclass
class BertBaselineConfig:
    """Hyperparameters aligned with LexiMind's full.yaml where applicable.

    Defaults deliberately mirror the main LexiMind training configuration
    (optimizer settings, effective batch size, seed, task weights) so the
    BERT baselines are trained under matched conditions for fair comparison.
    """

    # Model
    model_name: str = "bert-base-uncased"
    max_length: int = 256  # Same as LexiMind classification max_len
    # Optimizer (matching LexiMind's full.yaml)
    lr: float = 3e-5
    weight_decay: float = 0.01
    betas: tuple[float, float] = (0.9, 0.98)
    eps: float = 1e-6
    # Training
    batch_size: int = 10  # Same as LexiMind
    gradient_accumulation_steps: int = 4  # Same effective batch = 40
    max_epochs: int = 8
    warmup_steps: int = 300
    gradient_clip_norm: float = 1.0
    early_stopping_patience: int = 3  # epochs without combined-metric improvement
    seed: int = 17  # Same as LexiMind
    # Task weights (for multi-task mode)
    topic_weight: float = 0.3  # Same as LexiMind
    emotion_weight: float = 1.0
    # Temperature sampling (for multi-task mode)
    task_sampling_alpha: float = 0.5
    # Frozen layers: freeze bottom 4 layers (matching LexiMind's encoder strategy)
    freeze_layers: int = 4
    # Precision
    use_amp: bool = True  # BFloat16 mixed precision (autocast dtype=torch.bfloat16)
    # Paths
    data_dir: Path = field(default_factory=lambda: PROJECT_ROOT / "data" / "processed")
    output_dir: Path = field(default_factory=lambda: PROJECT_ROOT / "outputs" / "bert_baseline")
    checkpoint_dir: Path = field(
        default_factory=lambda: PROJECT_ROOT / "checkpoints" / "bert_baseline"
    )
    # Emotion threshold: sigmoid cutoff for multi-label emotion predictions
    emotion_threshold: float = 0.3
# Datasets
class BertEmotionDataset(Dataset):
    """Multi-label emotion dataset yielding BERT-ready tensors.

    Each item is tokenized to a fixed length and paired with a multi-hot
    float vector produced by the fitted ``MultiLabelBinarizer``.
    """

    def __init__(
        self,
        examples: List[EmotionExample],
        tokenizer: AutoTokenizer,
        binarizer: MultiLabelBinarizer,
        max_length: int = 256,
    ):
        self.examples = examples
        self.tokenizer = tokenizer
        self.binarizer = binarizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        example = self.examples[idx]
        encoded = self.tokenizer(
            example.text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Multi-hot label vector for this example's emotion set.
        multi_hot = self.binarizer.transform([example.emotions])[0]
        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": torch.tensor(multi_hot, dtype=torch.float32),
        }
class BertTopicDataset(Dataset):
    """Single-label topic dataset yielding BERT-ready tensors.

    Each item is tokenized to a fixed length; the topic string is mapped to
    an integer class index via the fitted ``LabelEncoder``.
    """

    def __init__(
        self,
        examples: List[TopicExample],
        tokenizer: AutoTokenizer,
        encoder: LabelEncoder,
        max_length: int = 256,
    ):
        self.examples = examples
        self.tokenizer = tokenizer
        self.encoder = encoder
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        example = self.examples[idx]
        encoded = self.tokenizer(
            example.text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Integer class index for this example's topic.
        class_idx = self.encoder.transform([example.topic])[0]
        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": torch.tensor(class_idx, dtype=torch.long),
        }
# Model
class BertClassificationHead(nn.Module):
    """Pooling + classifier head over BERT token embeddings.

    Pooling options:
      * ``"attention"`` — learned attention pooling (a query vector scores
        tokens, softmax over the sequence), as used for the emotion head.
      * ``"mean"`` — masked mean over valid (non-padding) tokens.
      * anything else — the [CLS] token embedding.

    With ``hidden_dim`` set, the classifier is a 2-layer GELU MLP; otherwise
    a single linear projection.
    """

    def __init__(
        self,
        hidden_size: int,
        num_labels: int,
        pooling: str = "cls",  # "cls" or "attention"
        hidden_dim: Optional[int] = None,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.pooling = pooling
        self.dropout = nn.Dropout(dropout)
        if pooling == "attention":
            self.attn_query = nn.Linear(hidden_size, 1, bias=False)
        if hidden_dim is None:
            self.classifier = nn.Linear(hidden_size, num_labels)
        else:
            self.classifier = nn.Sequential(
                nn.Linear(hidden_size, hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, num_labels),
            )

    def _pool(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Collapse (B, L, H) token states into a (B, H) sequence vector."""
        if self.pooling == "attention":
            # Score every token, mask padding to -inf, softmax over the
            # sequence, then take the attention-weighted sum.
            raw_scores = self.attn_query(hidden_states)  # (B, L, 1)
            pad_mask = attention_mask.unsqueeze(-1).bool()
            raw_scores = raw_scores.masked_fill(~pad_mask, float("-inf"))
            attn = F.softmax(raw_scores, dim=1)
            return (attn * hidden_states).sum(dim=1)
        if self.pooling == "mean":
            # Average only over non-padding positions (clamp avoids /0).
            valid = attention_mask.unsqueeze(-1).float()
            summed = (hidden_states * valid).sum(dim=1)
            denom = valid.sum(dim=1).clamp(min=1e-9)
            return summed / denom
        # Default: the [CLS] token embedding.
        return hidden_states[:, 0, :]

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        pooled = self._pool(hidden_states, attention_mask)
        return self.classifier(self.dropout(pooled))
class BertBaseline(nn.Module):
    """BERT encoder with task-specific classification heads.

    Supports single-task ("emotion" or "topic") and multi-task setups; head
    architectures mirror LexiMind's (attention-pool MLP for emotion,
    mean-pool linear for topic).
    """

    def __init__(
        self,
        model_name: str = "bert-base-uncased",
        num_emotions: int = 28,
        num_topics: int = 7,
        tasks: Sequence[str] = ("emotion", "topic"),
        freeze_layers: int = 4,
    ):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        dim = self.bert.config.hidden_size  # 768 for bert-base
        self.tasks = list(tasks)
        self.heads = nn.ModuleDict()
        if "emotion" in tasks:
            # Attention pooling + 2-layer MLP (matching LexiMind's emotion head).
            self.heads["emotion"] = BertClassificationHead(
                hidden_size=dim,
                num_labels=num_emotions,
                pooling="attention",
                hidden_dim=dim // 2,  # 384 for bert-base, same ratio as LexiMind
                dropout=0.1,
            )
        if "topic" in tasks:
            # Mean pooling + single linear (matching LexiMind's topic head).
            self.heads["topic"] = BertClassificationHead(
                hidden_size=dim,
                num_labels=num_topics,
                pooling="mean",
                hidden_dim=None,
                dropout=0.1,
            )
        # Freeze embeddings + bottom N transformer layers (LexiMind strategy).
        self._freeze_layers(freeze_layers)

    def _freeze_layers(self, n: int) -> None:
        """Freeze embedding parameters plus the bottom ``n`` encoder layers."""
        for p in self.bert.embeddings.parameters():
            p.requires_grad = False
        depth = len(self.bert.encoder.layer)
        for layer in self.bert.encoder.layer[: min(n, depth)]:
            for p in layer.parameters():
                p.requires_grad = False
        frozen = sum(1 for p in self.bert.parameters() if not p.requires_grad)
        total = sum(1 for p in self.bert.parameters())
        print(f" Frozen {frozen}/{total} BERT parameters (bottom {n} layers + embeddings)")

    def forward(
        self,
        task: str,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        token_states = encoded.last_hidden_state  # (B, L, hidden_size)
        return self.heads[task](token_states, attention_mask)

    def param_count(self) -> Dict[str, int]:
        """Return parameter counts broken down by component."""
        counts: Dict[str, int] = {
            "bert_encoder": sum(p.numel() for p in self.bert.parameters()),
            "bert_trainable": sum(
                p.numel() for p in self.bert.parameters() if p.requires_grad
            ),
        }
        for name, head in self.heads.items():
            counts[f"head_{name}"] = sum(p.numel() for p in head.parameters())
        counts["total"] = sum(p.numel() for p in self.parameters())
        counts["trainable"] = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return counts
# Training
class BertTrainer:
    """Trainer supporting single-task and multi-task BERT training.

    Single-task mode iterates one DataLoader per epoch; multi-task mode draws
    batches via temperature-based task sampling. Uses gradient accumulation,
    bfloat16 autocast, linear-warmup + cosine-decay LR scheduling, best-model
    checkpointing, and early stopping on a combined validation metric.
    """

    def __init__(
        self,
        model: BertBaseline,
        config: BertBaselineConfig,
        train_loaders: Dict[str, DataLoader],
        val_loaders: Dict[str, DataLoader],
        device: torch.device,
        mode: str,
    ):
        self.model = model
        self.config = config
        self.train_loaders = train_loaders
        self.val_loaders = val_loaders
        self.device = device
        self.mode = mode
        # Optimizer over trainable (unfrozen) parameters only.
        self.optimizer = AdamW(
            [p for p in model.parameters() if p.requires_grad],
            lr=config.lr,
            weight_decay=config.weight_decay,
            betas=config.betas,
            eps=config.eps,
        )
        # Calculate total training steps.
        # NOTE(review): both branches compute the same sum; the multi-task
        # branch presumably once derived a temperature-weighted step count —
        # confirm the intended behavior.
        if len(train_loaders) > 1:
            # Multi-task: use temperature-sampled steps
            sizes = {k: len(v) for k, v in train_loaders.items()}
            total_batches = sum(sizes.values())
        else:
            total_batches = sum(len(v) for v in train_loaders.values())
        self.steps_per_epoch = total_batches // config.gradient_accumulation_steps
        self.total_steps = self.steps_per_epoch * config.max_epochs
        # LR scheduler: linear warmup + cosine decay (matching LexiMind)
        warmup_scheduler = LinearLR(
            self.optimizer,
            start_factor=1e-8 / config.lr,  # warmup starts at an effective LR of ~1e-8
            end_factor=1.0,
            total_iters=config.warmup_steps,
        )
        cosine_scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=max(self.total_steps - config.warmup_steps, 1),
            eta_min=config.lr * 0.1,  # Decay to 10% of peak (matching LexiMind)
        )
        self.scheduler = SequentialLR(
            self.optimizer,
            schedulers=[warmup_scheduler, cosine_scheduler],
            milestones=[config.warmup_steps],
        )
        # Mixed precision.
        # NOTE(review): GradScaler targets float16; with bfloat16 autocast the
        # loss scaling is effectively a no-op (kept for a uniform step path).
        # torch.cuda.amp is also deprecated in favor of torch.amp — consider
        # migrating.
        self.scaler = GradScaler(enabled=config.use_amp)
        # Loss functions: multi-label BCE for emotion, cross-entropy for topic.
        self.emotion_loss_fn = nn.BCEWithLogitsLoss()
        self.topic_loss_fn = nn.CrossEntropyLoss()
        # Tracking state for checkpointing / early stopping.
        self.global_step = 0
        self.best_metric = -float("inf")
        self.patience_counter = 0
        self.training_history: List[Dict[str, Any]] = []

    def _compute_loss(self, task: str, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Emotion is multi-label (sigmoid/BCE); topic is single-label (CE).
        if task == "emotion":
            return self.emotion_loss_fn(logits, labels)
        else:
            return self.topic_loss_fn(logits, labels)

    def _get_task_weight(self, task: str) -> float:
        # Loss weighting applies only in multi-task mode (topic down-weighted).
        if self.mode != "multitask":
            return 1.0
        if task == "topic":
            return self.config.topic_weight
        return self.config.emotion_weight

    def _make_multitask_iterator(self):
        """Temperature-based task sampling (matching LexiMind).

        Yields an endless stream of (task, batch) pairs; exhausted loaders are
        restarted, and the caller bounds iteration by a fixed batch count.
        """
        sizes = {k: len(v.dataset) for k, v in self.train_loaders.items()}
        alpha = self.config.task_sampling_alpha
        # Compute sampling probabilities proportional to size ** (1 / alpha).
        # NOTE(review): with alpha = 0.5 this is size**2, which *sharpens*
        # sampling toward the larger dataset; conventional temperature sampling
        # uses size**alpha (alpha < 1) to flatten. Confirm this matches the
        # LexiMind pipeline's convention before changing.
        raw = {k: s ** (1.0 / alpha) for k, s in sizes.items()}
        total = sum(raw.values())
        probs = {k: v / total for k, v in raw.items()}
        # Create iterators
        iters = {k: iter(v) for k, v in self.train_loaders.items()}
        tasks = list(probs.keys())
        weights = [probs[t] for t in tasks]
        while True:
            task = random.choices(tasks, weights=weights, k=1)[0]
            try:
                batch = next(iters[task])
            except StopIteration:
                # Restart an exhausted loader and keep sampling.
                iters[task] = iter(self.train_loaders[task])
                batch = next(iters[task])
            yield task, batch

    def train_epoch(self, epoch: int) -> Dict[str, float]:
        """Train one epoch; returns mean per-task training losses."""
        self.model.train()
        self.optimizer.zero_grad()
        epoch_losses: Dict[str, List[float]] = {t: [] for t in self.train_loaders}
        if len(self.train_loaders) > 1:
            # Multi-task: temperature sampling
            iterator = self._make_multitask_iterator()
            total_batches = sum(len(v) for v in self.train_loaders.values())
        else:
            # Single-task: iterate normally
            task_name = list(self.train_loaders.keys())[0]
            iterator = ((task_name, batch) for batch in self.train_loaders[task_name])
            total_batches = len(self.train_loaders[task_name])
        pbar = tqdm(total=total_batches, desc=f"Epoch {epoch + 1}/{self.config.max_epochs}")
        for step_in_epoch in range(total_batches):
            task, batch = next(iterator)
            input_ids = batch["input_ids"].to(self.device)
            attention_mask = batch["attention_mask"].to(self.device)
            labels = batch["labels"].to(self.device)
            # Forward pass with AMP (bfloat16 autocast)
            with autocast(dtype=torch.bfloat16, enabled=self.config.use_amp):
                logits = self.model(task, input_ids, attention_mask)
                loss = self._compute_loss(task, logits, labels)
                loss = loss * self._get_task_weight(task)
                # Divide so accumulated gradients average over micro-batches.
                loss = loss / self.config.gradient_accumulation_steps
            # Backward
            self.scaler.scale(loss).backward()
            # Record the un-divided (per-batch) loss for logging.
            epoch_losses[task].append(loss.item() * self.config.gradient_accumulation_steps)
            # Optimizer step (every N accumulation steps)
            if (step_in_epoch + 1) % self.config.gradient_accumulation_steps == 0:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.config.gradient_clip_norm
                )
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()
                self.scheduler.step()
                self.global_step += 1
            pbar.set_postfix(
                {
                    f"{task}_loss": f"{epoch_losses[task][-1]:.4f}",
                    "lr": f"{self.scheduler.get_last_lr()[0]:.2e}",
                }
            )
            pbar.update(1)
        pbar.close()
        # Aggregate mean loss per task (tasks never sampled this epoch are skipped).
        results = {}
        for task, losses in epoch_losses.items():
            if losses:
                results[f"train_{task}_loss"] = sum(losses) / len(losses)
        return results

    @torch.no_grad()
    def validate(self) -> Dict[str, Any]:
        """Run validation across all tasks.

        Returns per-task losses and metrics plus ``val_combined_metric`` (mean
        of emotion sample-F1 and topic accuracy, whichever are present). Keys
        prefixed with "_" hold raw tensors and are stripped before
        serialization / checkpointing.
        """
        self.model.eval()
        results: Dict[str, Any] = {}
        for task, loader in self.val_loaders.items():
            all_logits = []
            all_labels = []
            total_loss = 0.0
            n_batches = 0
            for batch in loader:
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["labels"].to(self.device)
                with autocast(dtype=torch.bfloat16, enabled=self.config.use_amp):
                    logits = self.model(task, input_ids, attention_mask)
                    loss = self._compute_loss(task, logits, labels)
                total_loss += loss.item()
                n_batches += 1
                all_logits.append(logits.float().cpu())
                all_labels.append(labels.float().cpu())
            all_logits_t = torch.cat(all_logits, dim=0)
            all_labels_t = torch.cat(all_labels, dim=0)
            results[f"val_{task}_loss"] = total_loss / max(n_batches, 1)
            if task == "emotion":
                # Multi-label: threshold sigmoid probabilities at the configured cutoff.
                preds = (torch.sigmoid(all_logits_t) > self.config.emotion_threshold).int()
                targets = all_labels_t.int()
                results["val_emotion_sample_f1"] = multilabel_f1(preds, targets)
                results["val_emotion_macro_f1"] = multilabel_macro_f1(preds, targets)
                results["val_emotion_micro_f1"] = multilabel_micro_f1(preds, targets)
                # Store raw logits for threshold tuning later
                results["_emotion_logits"] = all_logits_t
                results["_emotion_labels"] = all_labels_t
            elif task == "topic":
                preds = all_logits_t.argmax(dim=1).numpy()
                targets = all_labels_t.long().numpy()
                results["val_topic_accuracy"] = float(accuracy_score(targets, preds))
                results["val_topic_macro_f1"] = float(
                    f1_score(targets, preds, average="macro", zero_division=0)
                )
        # Combined metric for early stopping / checkpointing
        metric_parts = []
        if "val_emotion_sample_f1" in results:
            metric_parts.append(results["val_emotion_sample_f1"])
        if "val_topic_accuracy" in results:
            metric_parts.append(results["val_topic_accuracy"])
        results["val_combined_metric"] = sum(metric_parts) / max(len(metric_parts), 1)
        return results

    def save_checkpoint(self, path: Path, epoch: int, metrics: Dict[str, Any]) -> None:
        """Save model/optimizer/scheduler state plus scalar metrics to ``path``."""
        path.parent.mkdir(parents=True, exist_ok=True)
        # Filter out tensors from metrics ("_"-prefixed keys hold raw tensors).
        clean_metrics = {k: v for k, v in metrics.items() if not k.startswith("_")}
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "scheduler_state_dict": self.scheduler.state_dict(),
                "metrics": clean_metrics,
                "config": {
                    "mode": self.mode,
                    "tasks": self.model.tasks,
                    "model_name": self.config.model_name,
                },
            },
            path,
        )

    def train(self) -> Dict[str, Any]:
        """Full training loop with early stopping; returns a results summary."""
        print(f"\n{'=' * 60}")
        print(f"Training BERT Baseline — Mode: {self.mode}")
        print(f"{'=' * 60}")
        param_counts = self.model.param_count()
        print(f" Total parameters: {param_counts['total']:,}")
        print(f" Trainable parameters: {param_counts['trainable']:,}")
        for name, count in param_counts.items():
            if name.startswith("head_"):
                print(f" {name}: {count:,}")
        print(f" Steps/epoch: {self.steps_per_epoch}")
        print(f" Total steps: {self.total_steps}")
        print()
        all_results: Dict[str, Any] = {"mode": self.mode, "epochs": []}
        start_time = time.time()
        for epoch in range(self.config.max_epochs):
            epoch_start = time.time()
            # Train
            train_metrics = self.train_epoch(epoch)
            # Validate
            val_metrics = self.validate()
            epoch_time = time.time() - epoch_start
            # Log ("_"-prefixed raw tensors are excluded from the record).
            epoch_result = {
                "epoch": epoch + 1,
                "time_seconds": epoch_time,
                **train_metrics,
                **{k: v for k, v in val_metrics.items() if not k.startswith("_")},
            }
            all_results["epochs"].append(epoch_result)
            self.training_history.append(epoch_result)
            # Print summary
            print(f"\n Epoch {epoch + 1} ({epoch_time:.0f}s):")
            for k, v in sorted(epoch_result.items()):
                if k not in ("epoch", "time_seconds") and isinstance(v, float):
                    print(f" {k}: {v:.4f}")
            # Checkpointing: keep the best model by combined validation metric.
            combined = val_metrics["val_combined_metric"]
            if combined > self.best_metric:
                self.best_metric = combined
                self.patience_counter = 0
                self.save_checkpoint(
                    self.config.checkpoint_dir / self.mode / "best.pt",
                    epoch,
                    val_metrics,
                )
                print(f" New best model (combined metric: {combined:.4f})")
            else:
                self.patience_counter += 1
                print(
                    f" No improvement ({self.patience_counter}/{self.config.early_stopping_patience})"
                )
            # Always save epoch checkpoint
            self.save_checkpoint(
                self.config.checkpoint_dir / self.mode / f"epoch_{epoch + 1}.pt",
                epoch,
                val_metrics,
            )
            # Early stopping
            if self.patience_counter >= self.config.early_stopping_patience:
                print(f"\n Early stopping triggered at epoch {epoch + 1}")
                all_results["early_stopped"] = True
                all_results["best_epoch"] = epoch + 1 - self.config.early_stopping_patience
                break
        total_time = time.time() - start_time
        all_results["total_time_seconds"] = total_time
        all_results["total_time_human"] = f"{total_time / 3600:.1f}h"
        if "early_stopped" not in all_results:
            all_results["early_stopped"] = False
            # Best epoch = last epoch minus however many non-improving epochs.
            all_results["best_epoch"] = (
                epoch + 1 - self.patience_counter if self.patience_counter > 0 else epoch + 1
            )
        all_results["param_counts"] = param_counts
        print(f"\n Training complete in {total_time / 3600:.1f}h")
        print(f" Best combined metric: {self.best_metric:.4f}")
        return all_results
# Evaluation
def evaluate_bert_model(
    model: BertBaseline,
    val_loaders: Dict[str, DataLoader],
    device: torch.device,
    config: BertBaselineConfig,
    emotion_classes: Optional[List[str]] = None,
    topic_classes: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Full evaluation with the same metrics as LexiMind's evaluate.py.

    For emotion (multi-label): sample/macro/micro F1 at the default threshold,
    per-class metrics (when class names are given), per-class threshold
    tuning, and a bootstrap CI over per-sample F1. For topic (single-label):
    accuracy, macro F1, a per-class report, and a bootstrap CI over
    per-sample accuracy.
    """
    model.eval()
    results: Dict[str, Any] = {}
    with torch.no_grad():
        for task, loader in val_loaders.items():
            all_logits = []
            all_labels = []
            for batch in tqdm(loader, desc=f"Evaluating {task}"):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)
                with autocast(dtype=torch.bfloat16, enabled=config.use_amp):
                    logits = model(task, input_ids, attention_mask)
                all_logits.append(logits.float().cpu())
                all_labels.append(labels.float().cpu())
            all_logits_t = torch.cat(all_logits, dim=0)
            all_labels_t = torch.cat(all_labels, dim=0)
            if task == "emotion":
                # Default threshold
                preds_default = (torch.sigmoid(all_logits_t) > config.emotion_threshold).int()
                targets = all_labels_t.int()
                results["emotion"] = {
                    "default_threshold": config.emotion_threshold,
                    "sample_avg_f1": multilabel_f1(preds_default, targets),
                    "macro_f1": multilabel_macro_f1(preds_default, targets),
                    "micro_f1": multilabel_micro_f1(preds_default, targets),
                }
                # Per-class metrics
                if emotion_classes:
                    per_class = multilabel_per_class_metrics(
                        preds_default, targets, emotion_classes
                    )
                    results["emotion"]["per_class"] = per_class
                # Threshold tuning: per-class sigmoid cutoffs maximizing macro F1.
                best_thresholds, tuned_macro = tune_per_class_thresholds(all_logits_t, all_labels_t)
                tuned_preds = torch.zeros_like(all_logits_t)
                probs = torch.sigmoid(all_logits_t)
                for c in range(all_logits_t.shape[1]):
                    tuned_preds[:, c] = (probs[:, c] >= best_thresholds[c]).float()
                tuned_preds = tuned_preds.int()
                results["emotion"]["tuned_macro_f1"] = tuned_macro
                results["emotion"]["tuned_sample_avg_f1"] = multilabel_f1(tuned_preds, targets)
                results["emotion"]["tuned_micro_f1"] = multilabel_micro_f1(tuned_preds, targets)
                # Bootstrap CI on sample-avg F1.
                # NOTE(review): a sample with no gold AND no predicted labels
                # scores 0 here (some conventions score it 1) — confirm this
                # matches multilabel_f1's convention.
                per_sample_f1 = []
                for i in range(preds_default.shape[0]):
                    p = preds_default[i].float()
                    g = targets[i].float()
                    tp = (p * g).sum()
                    prec = tp / p.sum().clamp(min=1)
                    rec = tp / g.sum().clamp(min=1)
                    f = (2 * prec * rec) / (prec + rec).clamp(min=1e-8)
                    per_sample_f1.append(f.item())
                # mean_f1 is intentionally unused; only the CI bounds are reported.
                mean_f1, ci_low, ci_high = bootstrap_confidence_interval(per_sample_f1)
                results["emotion"]["sample_avg_f1_ci"] = [ci_low, ci_high]
            elif task == "topic":
                preds = all_logits_t.argmax(dim=1).numpy()
                targets = all_labels_t.long().numpy()
                acc = float(accuracy_score(targets, preds))
                macro_f1 = float(f1_score(targets, preds, average="macro", zero_division=0))
                results["topic"] = {
                    "accuracy": acc,
                    "macro_f1": macro_f1,
                }
                # Per-class metrics
                if topic_classes:
                    report = classification_report(
                        targets,
                        preds,
                        target_names=topic_classes,
                        output_dict=True,
                        zero_division=0,
                    )
                    results["topic"]["per_class"] = {
                        name: {
                            "precision": report[name]["precision"],
                            "recall": report[name]["recall"],
                            "f1": report[name]["f1-score"],
                            "support": report[name]["support"],
                        }
                        for name in topic_classes
                        if name in report
                    }
                # Bootstrap CI on accuracy (mean_acc unused; CI bounds reported).
                per_sample_correct = (preds == targets).astype(float).tolist()
                mean_acc, ci_low, ci_high = bootstrap_confidence_interval(per_sample_correct)
                results["topic"]["accuracy_ci"] = [ci_low, ci_high]
    return results
# Main Pipeline
def set_seed(seed: int) -> None:
    """Seed Python, NumPy, and torch RNGs (incl. all CUDA devices) for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def load_data(config: BertBaselineConfig):
    """Load emotion/topic splits and fit label encoders on the training data.

    For each task, falls back from ``validation.jsonl`` to ``val.jsonl`` when
    the former is missing. Returns a dict holding the four example lists plus
    the fitted ``MultiLabelBinarizer`` (emotion) and ``LabelEncoder`` (topic),
    both fit on training data only — the same protocol as LexiMind.
    """
    data_dir = config.data_dir

    def _val_file(task: str) -> Path:
        # Prefer "validation.jsonl"; some exports name the split "val.jsonl".
        candidate = data_dir / task / "validation.jsonl"
        return candidate if candidate.exists() else data_dir / task / "val.jsonl"

    emo_train = load_emotion_jsonl(str(data_dir / "emotion" / "train.jsonl"))
    emo_val = load_emotion_jsonl(str(_val_file("emotion")))
    top_train = load_topic_jsonl(str(data_dir / "topic" / "train.jsonl"))
    top_val = load_topic_jsonl(str(_val_file("topic")))

    # Fit label encoders on training data (same as LexiMind).
    binarizer = MultiLabelBinarizer()
    binarizer.fit([ex.emotions for ex in emo_train])
    label_encoder = LabelEncoder()
    label_encoder.fit([ex.topic for ex in top_train])

    print(
        f" Emotion: {len(emo_train)} train, {len(emo_val)} val, {len(binarizer.classes_)} classes"
    )
    print(
        f" Topic: {len(top_train)} train, {len(top_val)} val, {len(label_encoder.classes_)} classes"
    )
    print(f" Emotion classes: {list(binarizer.classes_)[:5]}...")
    print(f" Topic classes: {list(label_encoder.classes_)}")

    return {
        "emotion_train": emo_train,
        "emotion_val": emo_val,
        "topic_train": top_train,
        "topic_val": top_val,
        "binarizer": binarizer,
        "label_encoder": label_encoder,
    }
def run_experiment(mode: str, config: BertBaselineConfig) -> Dict[str, Any]:
    """Run a single experiment (single-topic, single-emotion, or multitask).

    Loads data, builds loaders and a model for the mode's task set, trains
    with early stopping, reloads the best checkpoint, runs the full
    evaluation, and writes ``<mode>_results.json`` under ``config.output_dir``.
    """
    print(f"\n{'═' * 60}")
    print(f" BERT BASELINE EXPERIMENT: {mode.upper()}")
    print(f"{'═' * 60}")
    set_seed(config.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" Device: {device}")
    if torch.cuda.is_available():
        print(f" GPU: {torch.cuda.get_device_name()}")
        print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    # CUDA optimizations: cudnn autotune + TF32 matmuls.
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        if hasattr(torch.backends, "cuda"):
            torch.backends.cuda.matmul.allow_tf32 = True
    # Load tokenizer
    print(f"\n Loading tokenizer: {config.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    # Load data
    print(" Loading datasets...")
    data = load_data(config)
    # Determine tasks for this mode
    if mode == "single-topic":
        tasks = ["topic"]
    elif mode == "single-emotion":
        tasks = ["emotion"]
    else:
        tasks = ["emotion", "topic"]
    # Create datasets and loaders (validation uses 2x batch size — no grads).
    train_loaders: Dict[str, DataLoader] = {}
    val_loaders: Dict[str, DataLoader] = {}
    if "emotion" in tasks:
        emo_train_ds = BertEmotionDataset(
            data["emotion_train"], tokenizer, data["binarizer"], config.max_length
        )
        emo_val_ds = BertEmotionDataset(
            data["emotion_val"], tokenizer, data["binarizer"], config.max_length
        )
        train_loaders["emotion"] = DataLoader(
            emo_train_ds,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
            persistent_workers=True,
        )
        val_loaders["emotion"] = DataLoader(
            emo_val_ds,
            batch_size=config.batch_size * 2,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
        )
    if "topic" in tasks:
        top_train_ds = BertTopicDataset(
            data["topic_train"], tokenizer, data["label_encoder"], config.max_length
        )
        top_val_ds = BertTopicDataset(
            data["topic_val"], tokenizer, data["label_encoder"], config.max_length
        )
        train_loaders["topic"] = DataLoader(
            top_train_ds,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
            persistent_workers=True,
        )
        val_loaders["topic"] = DataLoader(
            top_val_ds,
            batch_size=config.batch_size * 2,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
        )
    # Create model
    print(f"\n Creating model with tasks: {tasks}")
    model = BertBaseline(
        model_name=config.model_name,
        num_emotions=len(data["binarizer"].classes_),
        num_topics=len(data["label_encoder"].classes_),
        tasks=tasks,
        freeze_layers=config.freeze_layers,
    ).to(device)
    # Train
    trainer = BertTrainer(model, config, train_loaders, val_loaders, device, mode)
    training_results = trainer.train()
    # Load best checkpoint for final evaluation
    best_path = config.checkpoint_dir / mode / "best.pt"
    if best_path.exists():
        print("\n Loading best checkpoint for final evaluation...")
        # weights_only=False because the checkpoint stores Python objects
        # (metrics/config); only load checkpoints written by this pipeline.
        checkpoint = torch.load(best_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])
    # Full evaluation
    print("\n Running final evaluation...")
    eval_results = evaluate_bert_model(
        model,
        val_loaders,
        device,
        config,
        emotion_classes=list(data["binarizer"].classes_) if "emotion" in tasks else None,
        topic_classes=list(data["label_encoder"].classes_) if "topic" in tasks else None,
    )
    # Combine results
    final_results = {
        "mode": mode,
        "model": config.model_name,
        "tasks": tasks,
        "training": training_results,
        "evaluation": eval_results,
    }
    # Save results
    output_path = config.output_dir / f"{mode}_results.json"
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Remove non-serializable fields (numpy scalars/arrays, "_"-prefixed keys).
    def make_serializable(obj):
        if isinstance(obj, dict):
            return {k: make_serializable(v) for k, v in obj.items() if not k.startswith("_")}
        if isinstance(obj, list):
            return [make_serializable(item) for item in obj]
        if isinstance(obj, (np.integer, np.int64)):
            return int(obj)
        if isinstance(obj, (np.floating, np.float64)):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return obj

    with open(output_path, "w") as f:
        json.dump(make_serializable(final_results), f, indent=2)
    print(f"\n Results saved to {output_path}")
    return final_results
def print_comparison_summary(all_results: Dict[str, Dict[str, Any]]) -> None:
    """Print a side-by-side comparison of all experiments.

    One column per trained mode plus a rightmost column of LexiMind reference
    numbers. Metrics missing from a result, or lacking a LexiMind reference,
    are shown as an em dash rather than a misleading 0.0000.

    Args:
        all_results: Mapping of mode name -> dict returned by
            ``run_experiment`` (expects "evaluation" and "training" keys).
    """
    print(f"\n{'═' * 70}")
    print(" BERT BASELINE COMPARISON SUMMARY")
    print(f"{'═' * 70}")
    # Header
    modes = list(all_results.keys())
    header = f"{'Metric':<30}" + "".join(f"{m:>16}" for m in modes) + f"{'LexiMind':>16}"
    print(f"\n {header}")
    print(f" {'─' * len(header)}")
    # LexiMind reference values. Keys must match f"{task}_{metric_name}" as
    # looked up below.
    # BUGFIX: the sample-avg key was "emotion_sample_f1", which never matched
    # the "emotion_sample_avg_f1" lookup, so the 0.3523 reference silently
    # printed as 0.0000.
    lexmind = {
        "topic_accuracy": 0.8571,
        "topic_macro_f1": 0.8539,
        "emotion_sample_avg_f1": 0.3523,
        "emotion_macro_f1": 0.1432,
        "emotion_micro_f1": 0.4430,
        "emotion_tuned_macro_f1": 0.2936,
    }

    def _print_rows(task: str, metrics: List[tuple]) -> None:
        """Print one formatted row per (metric_key, display_name) for *task*."""
        for metric_name, display_name in metrics:
            row = f" {display_name:<30}"
            for mode in modes:
                eval_data = all_results[mode].get("evaluation", {})
                val = eval_data.get(task, {}).get(metric_name, None)
                row += f"{val:>16.4f}" if val is not None else f"{'—':>16}"
            ref = lexmind.get(f"{task}_{metric_name}")
            # No reference value -> show an em dash, not 0.0000.
            row += f"{ref:>16.4f}" if ref is not None else f"{'—':>16}"
            print(row)

    # Topic metrics
    print(f"\n {'Topic Classification':}")
    _print_rows(
        "topic",
        [
            ("accuracy", "Accuracy"),
            ("macro_f1", "Macro F1"),
        ],
    )
    # Emotion metrics
    print(f"\n {'Emotion Detection':}")
    _print_rows(
        "emotion",
        [
            ("sample_avg_f1", "Sample-avg F1 (τ=0.3)"),
            ("macro_f1", "Macro F1 (τ=0.3)"),
            ("micro_f1", "Micro F1 (τ=0.3)"),
            ("tuned_macro_f1", "Tuned Macro F1"),
            ("tuned_sample_avg_f1", "Tuned Sample-avg F1"),
        ],
    )
    # Training time
    print(f"\n {'Training Time':}")
    row = f" {'Hours':<30}"
    for mode in modes:
        t = all_results[mode].get("training", {}).get("total_time_seconds", 0) / 3600
        row += f"{t:>15.1f}h"
    row += f"{'~9.0h':>16}"
    print(row)
    print(f"\n{'═' * 70}\n")
def main():
    """CLI entry point: parse arguments, run experiment(s), report results."""
    parser = argparse.ArgumentParser(description="BERT Baseline Training for LexiMind")
    parser.add_argument(
        "--mode",
        type=str,
        required=True,
        choices=["single-topic", "single-emotion", "multitask", "all"],
        help="Training mode",
    )
    parser.add_argument("--epochs", type=int, default=None, help="Override max epochs")
    parser.add_argument("--lr", type=float, default=None, help="Override learning rate")
    parser.add_argument("--batch-size", type=int, default=None, help="Override batch size")
    parser.add_argument(
        "--model", type=str, default="bert-base-uncased", help="HuggingFace model name"
    )
    args = parser.parse_args()

    # Apply CLI overrides on top of the defaults.
    config = BertBaselineConfig()
    config.model_name = args.model
    if args.epochs is not None:
        config.max_epochs = args.epochs
    if args.lr is not None:
        config.lr = args.lr
    if args.batch_size is not None:
        config.batch_size = args.batch_size

    modes = (
        ["single-topic", "single-emotion", "multitask"] if args.mode == "all" else [args.mode]
    )

    all_results: Dict[str, Dict[str, Any]] = {}
    for mode in modes:
        all_results[mode] = run_experiment(mode, config)
        # Release cached GPU memory between experiments.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # With multiple experiments, persist a combined results file.
    if len(all_results) > 1:
        combined_path = config.output_dir / "combined_results.json"

        def make_serializable(obj):
            """Recursively drop '_'-keys and convert numpy scalars/arrays."""
            if isinstance(obj, dict):
                return {k: make_serializable(v) for k, v in obj.items() if not k.startswith("_")}
            if isinstance(obj, list):
                return [make_serializable(item) for item in obj]
            if isinstance(obj, (np.integer, np.int64)):
                return int(obj)
            if isinstance(obj, (np.floating, np.float64)):
                return float(obj)
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            return obj

        with open(combined_path, "w") as f:
            json.dump(make_serializable(all_results), f, indent=2)
        print(f" Combined results saved to {combined_path}")

    # Summary prints for every run (single column when only one mode ran).
    print_comparison_summary(all_results)
if __name__ == "__main__":
main()