""" training/train.py ================= Trains the stress detection model on the unified dataset with: - document-level stratified splitting (label + domain) - regularization (dropout, weight decay, label smoothing) - validation F1 tracking, early stopping, and threshold calibration - optional happy/neutral evaluation set to monitor false positives The CNN tokenizer used here is intentionally identical to ``_simple_tokenize`` in ``api/main.py`` (hash-based, vocab_size=10000) so that saved checkpoints work correctly at inference time without any vocabulary file. Usage (local or Google Colab) ------------------------------ # 1. Prepare data (only needed once) python data_preprocessing.py # 2. Train (CNN) python training/train.py # Optional flags python training/train.py --epochs 15 --batch-size 64 --lr 1e-3 \ --data data/processed/unified_stress.csv \ --output checkpoints/model.pt After training the checkpoint is automatically picked up by the API server (``uvicorn api.main:app``). """ from __future__ import annotations import argparse import hashlib import math import os import random import sys import time import numpy as np import pandas as pd import torch import torch.nn as nn from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score from torch.utils.data import DataLoader, Dataset # Allow running from the repo root as well as from inside training/ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from models.architecture import ( # noqa: E402 DeBERTaStressClassifier, MiniLMStressClassifier, OptimizedMultichannelCNN, ) from utils.sentiment import get_sentiment_score # noqa: E402 from utils.text_preprocessing import clean_text # noqa: E402 # --------------------------------------------------------------------------- # Constants — must stay in sync with api/main.py # --------------------------------------------------------------------------- VOCAB_SIZE = 10_000 # _DEFAULT_VOCAB_SIZE in api/main.py CHUNK_SIZE = 200 # _CHUNK_SIZE in api/main.py STRIDE = 50 # Stride 50 with CHUNK_SIZE=200 → 150-token overlap (75% of chunk) ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) THRESHOLD_MIN = 0.05 THRESHOLD_MAX = 0.95 THRESHOLD_STEPS = 19 MIN_RECALL_THRESHOLD = 0.6 TRANSFORMER_LR = 2e-5 # --------------------------------------------------------------------------- # Tokenizer (identical to _simple_tokenize in api/main.py) # --------------------------------------------------------------------------- def _tokenize(text: str) -> list[int]: """Hash each whitespace-delimited token into [1, VOCAB_SIZE-1]. Index 0 is reserved for padding. Uses ``hashlib.md5`` for a fully deterministic mapping that is stable across all platforms, Python processes, and interpreter restarts (unlike ``hash()`` which is randomised by ``PYTHONHASHSEED``). This guarantees that a model trained on Colab produces identical token IDs when served on Windows. """ tokens = text.lower().split() return [ int(hashlib.md5(t.encode("utf-8"), usedforsecurity=False).hexdigest(), 16) % (VOCAB_SIZE - 1) + 1 for t in tokens ] # --------------------------------------------------------------------------- # Dataset # --------------------------------------------------------------------------- def _load_csv(data_path: str) -> pd.DataFrame: df = pd.read_csv(data_path) if "text" not in df.columns or "label" not in df.columns: raise ValueError(f"{data_path} must have 'text' and 'label' columns.") df = df.dropna(subset=["text"]).reset_index(drop=True) df["text"] = df["text"].astype(str).str.strip() df = df[df["text"] != ""].reset_index(drop=True) # Apply the same preprocessing pipeline used at inference time so that # the model trains on clean tokens identical to what it will see in # production (URLs, HTML, emojis, repeated chars all normalised). df["text"] = df["text"].apply(clean_text) df = df[df["text"] != ""].reset_index(drop=True) df["label"] = df["label"].astype(int) if "domain" not in df.columns: df["domain"] = "unknown" return df def _describe_dataset(df: pd.DataFrame) -> None: print("\nDataset summary") print("-" * 50) print("Label distribution:") print(df["label"].value_counts().sort_index().to_string()) print("\nDomain distribution:") print(df["domain"].value_counts().to_string()) print("\nLabel by domain:") table = pd.crosstab(df["domain"], df["label"]) print(table.to_string()) missing = [] for domain in table.index: row = table.loc[domain] for label in (0, 1): if row.get(label, 0) == 0: missing.append((domain, label)) if missing: print( "\nWarning: Some domains contain only one label. " "Consider adding happy/neutral negatives to reduce false positives." ) def _stratified_split( df: pd.DataFrame, val_ratio: float, seed: int ) -> tuple[pd.DataFrame, pd.DataFrame]: rng = random.Random(seed) group_keys = (df["label"].astype(str) + "|" + df["domain"].astype(str)).tolist() groups: dict[str, list[int]] = {} for idx, key in enumerate(group_keys): groups.setdefault(key, []).append(idx) train_idx: list[int] = [] val_idx: list[int] = [] for key in sorted(groups.keys()): indices = groups[key] rng.shuffle(indices) n_val = int(round(len(indices) * val_ratio)) if len(indices) > 1: # Keep at least one train sample and one val sample. n_val = max(1, min(n_val, len(indices) - 1)) else: n_val = 0 if n_val: val_idx.extend(indices[:n_val]) train_idx.extend(indices[n_val:]) else: train_idx.extend(indices) if not val_idx: fallback = max(1, int(round(len(train_idx) * val_ratio))) val_idx = train_idx[:fallback] train_idx = train_idx[fallback:] train_df = df.iloc[train_idx].reset_index(drop=True) val_df = df.iloc[val_idx].reset_index(drop=True) return train_df, val_df class _StressChunkDataset(Dataset): """CNN dataset with sliding-window chunking. Optionally accepts per-sample ``rewards`` (non-negative floats) that are used by ``weighted_loss`` during RL-style fine-tuning via ``training/retrain.py``. When ``rewards`` is ``None`` every sample gets an implicit weight of ``1.0`` and the standard ``CrossEntropyLoss`` is applied unchanged. """ def __init__( self, texts: list[str], labels: list[int], features: np.ndarray | None = None, chunk_size: int = CHUNK_SIZE, stride: int = STRIDE, rewards: list[float] | None = None, ) -> None: self._chunks: list[torch.Tensor] = [] self._labels: list[int] = [] self._features: list[torch.Tensor] | None = None self._rewards: list[float] | None = None if rewards is not None: if len(rewards) != len(texts): raise ValueError("rewards must align with texts.") self._rewards = [] feature_rows = None if features is not None: if len(features) != len(texts): raise ValueError("Features must align with texts.") self._features = [] feature_rows = [torch.tensor(row, dtype=torch.float) for row in features] for idx, (text, label) in enumerate(zip(texts, labels)): token_ids = _tokenize(text) label = int(label) feature_tensor = feature_rows[idx] if feature_rows is not None else None reward_val = rewards[idx] if rewards is not None else None if len(token_ids) == 0: self._chunks.append(torch.zeros(chunk_size, dtype=torch.long)) self._labels.append(label) if self._features is not None: self._features.append(feature_tensor) if self._rewards is not None: self._rewards.append(reward_val) continue for start in range(0, len(token_ids), stride): end = start + chunk_size chunk = token_ids[start:end] if len(chunk) < chunk_size: chunk = chunk + [0] * (chunk_size - len(chunk)) self._chunks.append(torch.tensor(chunk, dtype=torch.long)) self._labels.append(label) if self._features is not None: self._features.append(feature_tensor) if self._rewards is not None: self._rewards.append(reward_val) if end >= len(token_ids): break def __len__(self) -> int: return len(self._chunks) def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: item = { "input_ids": self._chunks[idx], "label": torch.tensor(self._labels[idx], dtype=torch.long), } if self._features is not None: item["features"] = self._features[idx] if self._rewards is not None: item["reward"] = torch.tensor(self._rewards[idx], dtype=torch.float) return item class _TransformerDataset(Dataset): """Transformer dataset with tokenizer-based encoding.""" def __init__( self, texts: list[str], labels: list[int], tokenizer, max_length: int, ) -> None: encodings = tokenizer( texts, truncation=True, padding="max_length", max_length=max_length, ) self._encodings = encodings self._labels = labels self._texts = texts def __len__(self) -> int: return len(self._labels) def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: item = {k: torch.tensor(v[idx]) for k, v in self._encodings.items()} item["label"] = torch.tensor(self._labels[idx], dtype=torch.long) sentiment = get_sentiment_score(self._texts[idx]) item["sentiment"] = torch.tensor(sentiment, dtype=torch.float) return item # --------------------------------------------------------------------------- # Training helpers # --------------------------------------------------------------------------- def _accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float: preds = logits.argmax(dim=-1) return (preds == labels).float().mean().item() def weighted_loss( logits: torch.Tensor, labels: torch.Tensor, rewards: torch.Tensor, ) -> torch.Tensor: """Per-sample reward-weighted cross-entropy loss. Good predictions (high reward) are reinforced; bad predictions (which were corrected in the dataset) are penalised by the same magnitude, implementing the RL-style update described in the problem statement. Parameters ---------- logits : Tensor, shape ``(B, C)`` labels : Tensor, shape ``(B,)`` rewards : Tensor, shape ``(B,)`` Non-negative scalar weight for each sample (e.g. ``1.5`` for feedback-derived samples, ``1.0`` for baseline samples). Returns ------- Tensor Scalar mean weighted loss. """ per_sample = nn.functional.cross_entropy(logits, labels, reduction="none") return (per_sample * rewards).mean() class FocalLoss(nn.Module): """Focal loss (Lin et al. 2017) for binary stress classification. Reduces the loss contribution of easy, well-classified examples and focuses training on hard, ambiguous inputs — preventing a handful of common keyword patterns from dominating the gradient updates. Parameters ---------- gamma : float Focusing parameter. ``gamma=0`` recovers standard cross-entropy. ``gamma=2`` is the value recommended by Lin et al. weight : Tensor or None Per-class weight tensor forwarded to the underlying cross-entropy (handles label imbalance in the same way as ``nn.CrossEntropyLoss``). label_smoothing : float Label smoothing coefficient forwarded to cross-entropy. """ def __init__( self, gamma: float = 2.0, weight: torch.Tensor | None = None, label_smoothing: float = 0.0, ) -> None: super().__init__() self.gamma = gamma self._ce = nn.CrossEntropyLoss( weight=weight, reduction="none", label_smoothing=label_smoothing, ) def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: ce = self._ce(logits, targets) # (B,) per-sample CE loss pt = torch.exp(-ce) # p(correct class) return ((1.0 - pt) ** self.gamma * ce).mean() def _run_epoch( model: nn.Module, loader: DataLoader, criterion: nn.Module, optimizer: torch.optim.Optimizer | None, device: torch.device, is_train: bool, collect_probs: bool = False, warmup_scheduler=None, fp_penalty_weight: float = 0.0, ) -> tuple[float, float, np.ndarray | None, np.ndarray | None]: model.train(is_train) total_loss = 0.0 total_acc = 0.0 n_batches = 0 all_probs: list[torch.Tensor] = [] all_labels: list[torch.Tensor] = [] ctx = torch.enable_grad() if is_train else torch.no_grad() with ctx: for batch in loader: input_ids = batch["input_ids"].to(device) labels = batch["label"].to(device) attention_mask = batch.get("attention_mask") if attention_mask is not None: attention_mask = attention_mask.to(device) sentiment = batch.get("sentiment") if sentiment is not None: sentiment = sentiment.to(device) output = model( input_ids=input_ids, attention_mask=attention_mask, sentiment=sentiment, ) else: output = model(input_ids=input_ids, attention_mask=attention_mask) else: features = batch.get("features") if features is not None: features = features.to(device) output = model(input_ids, aux_features=features) else: output = model(input_ids) logits = output["logits"] # Use reward-weighted loss when the batch provides per-sample # rewards (i.e. during RL-style fine-tuning from feedback data). reward_weights = batch.get("reward") if reward_weights is not None: reward_weights = reward_weights.to(device) loss = weighted_loss(logits, labels, reward_weights) else: loss = criterion(logits, labels) # Soft FPR penalty: penalise the mean stress probability assigned # to negative samples in this batch. Unlike threshold calibration # (which only acts at evaluation time), this pushes the gradient # directly toward lower false-positive rates during training. if is_train and fp_penalty_weight > 0.0: neg_mask = (labels == 0).float() n_neg = neg_mask.sum() if n_neg > 0: stress_probs = torch.softmax(logits, dim=-1)[:, 1] fpr_penalty = (stress_probs * neg_mask).sum() / n_neg loss = loss + fp_penalty_weight * fpr_penalty if is_train and optimizer is not None: optimizer.zero_grad() loss.backward() nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() if warmup_scheduler is not None: warmup_scheduler.step() total_loss += loss.item() total_acc += _accuracy(logits, labels) n_batches += 1 if collect_probs: probs = torch.softmax(logits, dim=-1)[:, 1] all_probs.append(probs.detach().cpu()) all_labels.append(labels.detach().cpu()) if collect_probs and all_probs: probs_np = torch.cat(all_probs).numpy() labels_np = torch.cat(all_labels).numpy() else: probs_np = None labels_np = None return ( total_loss / max(n_batches, 1), total_acc / max(n_batches, 1), probs_np, labels_np, ) def _compute_metrics( labels: np.ndarray, probs: np.ndarray, threshold: float ) -> dict[str, float | list[list[int]]]: preds = (probs >= threshold).astype(int) precision = precision_score(labels, preds, zero_division=0) recall = recall_score(labels, preds, zero_division=0) f1 = f1_score(labels, preds, zero_division=0) cm = confusion_matrix(labels, preds, labels=[0, 1]) return { "precision": float(precision), "recall": float(recall), "f1": float(f1), "confusion_matrix": cm.tolist(), } def _find_best_threshold( labels: np.ndarray, probs: np.ndarray, max_fpr: float = 0.20, min_threshold: float = 0.50, ) -> tuple[float, dict[str, float | list[list[int]]]]: """Find the best decision threshold subject to FPR and minimum-threshold constraints. Why this matters ---------------- Optimising purely for F1 on a stress-heavy dataset causes the search to collapse to very low thresholds (~0.15) that label almost everything as stressed. Such a model has high recall but an FPR of 75–95%, making it clinically useless. Parameters ---------- labels : np.ndarray Ground-truth binary labels (0 = no stress, 1 = stress). probs : np.ndarray Model stress probabilities (softmax output for class 1). max_fpr : float Maximum false positive rate allowed. Thresholds that exceed this on the validation set are rejected regardless of their F1. min_threshold : float Hard lower bound on the selected threshold. The deployed model will never be more aggressive than this value. Returns ------- tuple ``(best_threshold, best_metrics)`` """ best_threshold = min_threshold best_metrics = _compute_metrics(labels, probs, best_threshold) best_f1 = best_metrics["f1"] for threshold in np.linspace(min_threshold, THRESHOLD_MAX, THRESHOLD_STEPS): t = float(threshold) metrics = _compute_metrics(labels, probs, t) cm = metrics["confusion_matrix"] tn, fp = cm[0][0], cm[0][1] fpr = fp / max(tn + fp, 1) if ( metrics["f1"] > best_f1 and metrics["recall"] >= MIN_RECALL_THRESHOLD and fpr <= max_fpr ): best_f1 = metrics["f1"] best_threshold = t best_metrics = metrics return best_threshold, best_metrics def _build_model_and_tokenizer( model_type: str, dropout: float, max_length: int, aux_dim: int = 0, ) -> tuple[nn.Module, object | None, str | None]: if model_type == "cnn": model = OptimizedMultichannelCNN( vocab_size=VOCAB_SIZE, embed_dim=128, num_filters=64, kernel_sizes=(2, 3, 5), num_classes=2, dropout=dropout, aux_dim=aux_dim, ) return model, None, None if model_type == "deberta": model = DeBERTaStressClassifier(dropout=dropout) elif model_type == "minilm": model = MiniLMStressClassifier(dropout=dropout) else: raise ValueError(f"Unknown model type: {model_type}") from transformers import AutoTokenizer model_name = model.MODEL_NAME tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer.model_max_length = max_length return model, tokenizer, model_name def _load_eval_set(path: str) -> pd.DataFrame: df = pd.read_csv(path) if "text" not in df.columns or "label" not in df.columns: raise ValueError(f"{path} must have 'text' and 'label' columns.") df = df.dropna(subset=["text"]).reset_index(drop=True) df["text"] = df["text"].astype(str).str.strip() df = df[df["text"] != ""].reset_index(drop=True) df["label"] = df["label"].astype(int) return df def _prepare_feature_frame( df: pd.DataFrame, feature_cols: list[str], means: pd.Series | None = None, stds: pd.Series | None = None, ) -> tuple[np.ndarray, pd.Series, pd.Series]: features = df[feature_cols].apply(pd.to_numeric, errors="coerce") if means is None: means = features.mean() if stds is None: stds = features.std().replace(0, 1) features = features.fillna(means) normalized = (features - means) / stds return normalized.to_numpy(dtype=np.float32), means, stds # --------------------------------------------------------------------------- # Main training loop # --------------------------------------------------------------------------- def train( data_path: str, output_path: str, epochs: int, batch_size: int, lr: float, val_ratio: float, seed: int, device_str: str, model_type: str, dropout: float, weight_decay: float, label_smoothing: float, class_weighted: bool, patience: int, eval_set_path: str | None, max_length: int, max_fpr: float = 0.20, min_threshold: float = 0.50, fp_penalty_weight: float = 0.2, ) -> None: torch.manual_seed(seed) device = torch.device(device_str if torch.cuda.is_available() or device_str == "cpu" else "cpu") print(f"Using device: {device}") # ── Data ── print(f"\nLoading dataset from: {data_path}") df = _load_csv(data_path) _describe_dataset(df) train_df, val_df = _stratified_split(df, val_ratio, seed) candidate_cols = [ col for col in df.columns if col not in {"text", "label", "domain"} ] if candidate_cols: raw_candidates = df[candidate_cols] numeric_candidates = raw_candidates.apply( pd.to_numeric, errors="coerce" ) coerced = numeric_candidates.isna() & raw_candidates.notna() if coerced.any().any(): bad_cols = [ col for col in candidate_cols if coerced[col].any() ] print( "Warning: Non-numeric values coerced to NaN in columns: " + ", ".join(bad_cols) ) coverage = numeric_candidates.notna().mean() feature_cols = [ col for col in numeric_candidates.columns if coverage[col] >= 0.5 ] for col in feature_cols: df[col] = numeric_candidates[col] else: feature_cols = [] if feature_cols: print(f"Numeric feature columns detected: {len(feature_cols):,}") elif model_type == "cnn": print("No numeric feature columns detected; training text-only CNN.") print( f"\nDocuments: {len(df):,} | Train: {len(train_df):,} | Val: {len(val_df):,}" ) aux_dim = len(feature_cols) if model_type == "cnn" and feature_cols else 0 model, tokenizer, model_name = _build_model_and_tokenizer( model_type=model_type, dropout=dropout, max_length=max_length, aux_dim=aux_dim, ) model = model.to(device) train_features = None val_features = None eval_features = None feature_means = None feature_stds = None if feature_cols and model_type == "cnn": train_features, feature_means, feature_stds = _prepare_feature_frame( train_df, feature_cols ) val_features, _, _ = _prepare_feature_frame( val_df, feature_cols, feature_means, feature_stds ) elif feature_cols and model_type != "cnn": print( "Note: numeric features are available but only the CNN uses them." ) if model_type == "cnn": train_ds = _StressChunkDataset( train_df["text"].tolist(), train_df["label"].tolist(), features=train_features, ) val_ds = _StressChunkDataset( val_df["text"].tolist(), val_df["label"].tolist(), features=val_features, ) else: train_ds = _TransformerDataset( train_df["text"].tolist(), train_df["label"].tolist(), tokenizer, max_length=max_length, ) val_ds = _TransformerDataset( val_df["text"].tolist(), val_df["label"].tolist(), tokenizer, max_length=max_length, ) print(f"Chunks: {len(train_ds):,} train | {len(val_ds):,} val") train_loader = DataLoader( train_ds, batch_size=batch_size, shuffle=True, num_workers=0 ) val_loader = DataLoader( val_ds, batch_size=batch_size, shuffle=False, num_workers=0 ) eval_loader = None if eval_set_path and os.path.isfile(eval_set_path): eval_df = _load_eval_set(eval_set_path) if model_type == "cnn": if feature_cols and feature_means is not None and feature_stds is not None: for col in feature_cols: if col not in eval_df.columns: eval_df[col] = np.nan eval_features, _, _ = _prepare_feature_frame( eval_df, feature_cols, feature_means, feature_stds ) eval_ds = _StressChunkDataset( eval_df["text"].tolist(), eval_df["label"].tolist(), features=eval_features, ) else: eval_ds = _TransformerDataset( eval_df["text"].tolist(), eval_df["label"].tolist(), tokenizer, max_length=max_length, ) eval_loader = DataLoader( eval_ds, batch_size=batch_size, shuffle=False, num_workers=0 ) print(f"Happy/neutral eval set: {len(eval_ds):,} samples") n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"Model parameters: {n_params:,}\n") class_weights = None if class_weighted: counts = train_df["label"].value_counts().reindex([0, 1], fill_value=0) total = counts.sum() weights = [ total / (2 * counts[0]) if counts[0] > 0 else 1.0, total / (2 * counts[1]) if counts[1] > 0 else 1.0, ] class_weights = torch.tensor(weights, dtype=torch.float, device=device) criterion = ( FocalLoss(gamma=2.0, weight=class_weights, label_smoothing=label_smoothing) if model_type == "cnn" else nn.CrossEntropyLoss(weight=class_weights, label_smoothing=label_smoothing) ) if model_type == "cnn": optimizer = torch.optim.Adam( model.parameters(), lr=lr, weight_decay=weight_decay ) # Cosine annealing with linear warmup: better convergence than # ReduceLROnPlateau for text CNNs — avoids premature learning-rate # collapses caused by noisy F1 on small validation sets. total_steps = len(train_loader) * epochs warmup_steps = min(int(0.1 * total_steps), 200) def _lr_lambda(current_step: int) -> float: """Linear warmup then cosine decay.""" if current_step < warmup_steps: return float(current_step) / max(warmup_steps, 1) progress = float(current_step - warmup_steps) / max( total_steps - warmup_steps, 1 ) return max(0.01, 0.5 * (1.0 + math.cos(math.pi * progress))) warmup_scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, _lr_lambda ) scheduler = None # per-step updates only else: from transformers import get_linear_schedule_with_warmup optimizer = torch.optim.AdamW( model.parameters(), lr=TRANSFORMER_LR, weight_decay=weight_decay ) total_steps = len(train_loader) * epochs warmup_steps = int(0.1 * total_steps) warmup_scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps, ) scheduler = None best_val_f1 = 0.0 best_threshold = min_threshold epochs_since_improve = 0 os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) _EPOCH_HEADER = ( f"{'Epoch':>6} {'Train Loss':>11} {'Train Acc':>10} " f"{'Val Loss':>9} {'Val F1':>8} {'Val Prec':>9} {'Val Rec':>8} " f"{'Val FPR':>8} {'Thresh':>7} {'Time':>6}" ) print(_EPOCH_HEADER) print("-" * len(_EPOCH_HEADER)) for epoch in range(1, epochs + 1): t0 = time.time() tr_loss, tr_acc, _, _ = _run_epoch( model, train_loader, criterion, optimizer, device, is_train=True, warmup_scheduler=warmup_scheduler, fp_penalty_weight=fp_penalty_weight, ) vl_loss, _, val_probs, val_labels = _run_epoch( model, val_loader, criterion, None, device, is_train=False, collect_probs=True, ) elapsed = time.time() - t0 if val_probs is None or val_labels is None: raise RuntimeError("Validation probabilities were not collected.") threshold, val_metrics = _find_best_threshold( val_labels, val_probs, max_fpr=max_fpr, min_threshold=min_threshold, ) val_f1 = val_metrics["f1"] val_precision = val_metrics["precision"] val_recall = val_metrics["recall"] val_cm = val_metrics["confusion_matrix"] val_tn, val_fp = val_cm[0][0], val_cm[0][1] val_fpr = val_fp / max(val_tn + val_fp, 1) if scheduler is not None: scheduler.step(val_f1) marker = " ←" if val_f1 > best_val_f1 else "" if val_f1 > best_val_f1: best_val_f1 = val_f1 best_threshold = threshold epochs_since_improve = 0 checkpoint = { "model_state_dict": model.state_dict(), "decision_threshold": float(best_threshold), "model_type": model_type, "dropout": float(dropout), } if model_type != "cnn" and model_name: checkpoint["model_name"] = model_name checkpoint["tokenizer_max_length"] = int(max_length) if model_type == "cnn": checkpoint["chunk_size"] = int(CHUNK_SIZE) checkpoint["stride"] = int(STRIDE) if feature_cols: checkpoint["feature_dim"] = len(feature_cols) checkpoint["feature_columns"] = feature_cols checkpoint["feature_means"] = feature_means.tolist() checkpoint["feature_stds"] = feature_stds.tolist() torch.save(checkpoint, output_path) else: epochs_since_improve += 1 print( f"{epoch:>6} {tr_loss:>11.4f} {tr_acc:>9.2%} " f"{vl_loss:>9.4f} {val_f1:>7.2%} {val_precision:>8.2%} " f"{val_recall:>7.2%} {val_fpr:>7.2%} {threshold:>7.2f} {elapsed:>5.1f}s{marker}" ) cm = val_metrics["confusion_matrix"] print( " Confusion matrix: " f"TN={cm[0][0]} FP={cm[0][1]} FN={cm[1][0]} TP={cm[1][1]}" ) if eval_loader is not None: _, _, eval_probs, eval_labels = _run_epoch( model, eval_loader, criterion, None, device, is_train=False, collect_probs=True, ) if eval_probs is not None and eval_labels is not None: eval_metrics = _compute_metrics( eval_labels, eval_probs, threshold ) eval_cm = eval_metrics["confusion_matrix"] tn, fp = eval_cm[0][0], eval_cm[0][1] fp_rate = fp / max(tn + fp, 1) print( " Happy/neutral eval — " f"FP rate: {fp_rate:.2%}, " f"F1: {eval_metrics['f1']:.2%}" ) if patience > 0 and epochs_since_improve >= patience: print( f"\nEarly stopping: no F1 improvement in {patience} epochs." ) break print(f"\nBest validation F1: {best_val_f1:.2%}") print(f"Best decision threshold: {best_threshold:.2f}") print(f"Checkpoint saved to: {output_path}") print( "\nTo use with the API, start the server normally:\n" " uvicorn api.main:app --host 0.0.0.0 --port 8000\n" "The checkpoint is loaded automatically on the first /analyze request." ) # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- def _parse_args() -> argparse.Namespace: p = argparse.ArgumentParser( description="Train stress detection models with calibrated thresholds.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) p.add_argument( "--data", default=os.path.join(ROOT_DIR, "data", "processed", "unified_stress.csv"), help="Path to the unified CSV produced by data_preprocessing.py", ) p.add_argument( "--output", default=os.path.join(ROOT_DIR, "checkpoints", "model.pt"), help="Destination path for the best checkpoint", ) p.add_argument( "--model", default="cnn", choices=["cnn", "deberta", "minilm"], help="Model backbone to train", ) p.add_argument("--epochs", type=int, default=10, help="Number of training epochs") p.add_argument("--batch-size", type=int, default=64, help="Mini-batch size") p.add_argument("--lr", type=float, default=1e-3, help="Initial learning rate") p.add_argument( "--weight-decay", type=float, default=1e-4, help="Weight decay for Adam/AdamW", ) p.add_argument( "--dropout", type=float, default=0.3, help="Dropout probability (CNN) / classifier dropout (transformers)", ) p.add_argument( "--label-smoothing", type=float, default=0.0, help="Label smoothing for cross-entropy loss", ) p.add_argument( "--class-weighted", action=argparse.BooleanOptionalAction, default=True, help=( "Use inverse-frequency class weights (enabled by default). " "Balances the gradient contribution of the minority class (no-stress) " "which is heavily outnumbered in typical datasets. " "Pass --no-class-weighted to disable." ), ) p.add_argument( "--patience", type=int, default=3, help="Early stopping patience (epochs without F1 improvement)", ) p.add_argument("--val-ratio", type=float, default=0.1, help="Fraction of data used for validation") p.add_argument("--seed", type=int, default=42, help="Random seed") p.add_argument( "--eval-set", default=os.path.join(ROOT_DIR, "data", "eval", "happy_neutral_eval.csv"), help="Optional happy/neutral eval CSV (text,label)", ) p.add_argument( "--max-length", type=int, default=256, help="Max sequence length for transformer models", ) p.add_argument( "--device", default="cuda", choices=["cuda", "cpu"], help="Device to train on (falls back to cpu if CUDA is unavailable)", ) p.add_argument( "--max-fpr", type=float, default=0.20, help=( "Maximum false positive rate allowed during threshold calibration. " "Thresholds that produce FPR > this value are rejected, preventing " "the model from collapsing to very low decision thresholds." ), ) p.add_argument( "--min-threshold", type=float, default=0.50, help=( "Hard lower bound for the decision threshold. " "Ensures the model never classifies more than 50%% of the " "probability space as stressed by default." ), ) p.add_argument( "--fp-penalty", type=float, default=0.2, help=( "Weight of the soft false-positive penalty term added to the " "training loss each batch. At each step the mean stress " "probability assigned to negative (label=0) samples in the " "batch is multiplied by this weight and added to the main loss, " "directly penalising false positives during gradient descent. " "Set to 0.0 to disable. Increase (e.g. 0.4) when the " "happy/neutral FP rate remains high after training." ), ) return p.parse_args() if __name__ == "__main__": args = _parse_args() if not os.path.isfile(args.data): print(f"ERROR: Dataset not found: {args.data}") print("Run python data_preprocessing.py first to create the unified CSV.") sys.exit(1) train( data_path=args.data, output_path=args.output, epochs=args.epochs, batch_size=args.batch_size, lr=args.lr, val_ratio=args.val_ratio, seed=args.seed, device_str=args.device, model_type=args.model, dropout=args.dropout, weight_decay=args.weight_decay, label_smoothing=args.label_smoothing, class_weighted=args.class_weighted, patience=args.patience, eval_set_path=args.eval_set, max_length=args.max_length, max_fpr=args.max_fpr, min_threshold=args.min_threshold, fp_penalty_weight=args.fp_penalty, )