# StressDetect / training /train.py
# Ace-119's picture
# Add training script with focal loss, FPR-constrained threshold calibration
# 453f48c
"""
training/train.py
=================
Trains the stress detection model on the unified dataset with:
- document-level stratified splitting (label + domain)
- regularization (dropout, weight decay, label smoothing)
- validation F1 tracking, early stopping, and threshold calibration
- optional happy/neutral evaluation set to monitor false positives
The CNN tokenizer used here is intentionally identical to ``_simple_tokenize``
in ``api/main.py`` (hash-based, vocab_size=10000) so that saved checkpoints
work correctly at inference time without any vocabulary file.
Usage (local or Google Colab)
------------------------------
# 1. Prepare data (only needed once)
python data_preprocessing.py
# 2. Train (CNN)
python training/train.py
# Optional flags
python training/train.py --epochs 15 --batch-size 64 --lr 1e-3 \
--data data/processed/unified_stress.csv \
--output checkpoints/model.pt
After training the checkpoint is automatically picked up by the API server
(``uvicorn api.main:app``).
"""
from __future__ import annotations
import argparse
import hashlib
import math
import os
import random
import sys
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score
from torch.utils.data import DataLoader, Dataset
# Allow running from the repo root as well as from inside training/
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.architecture import ( # noqa: E402
DeBERTaStressClassifier,
MiniLMStressClassifier,
OptimizedMultichannelCNN,
)
from utils.sentiment import get_sentiment_score # noqa: E402
from utils.text_preprocessing import clean_text # noqa: E402
# ---------------------------------------------------------------------------
# Constants — must stay in sync with api/main.py
# ---------------------------------------------------------------------------
VOCAB_SIZE = 10_000 # _DEFAULT_VOCAB_SIZE in api/main.py
CHUNK_SIZE = 200 # _CHUNK_SIZE in api/main.py
STRIDE = 50 # Stride 50 with CHUNK_SIZE=200 → 150-token overlap (75% of chunk)
# Repository root: the parent of the directory that contains this file.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Threshold-calibration grid. _find_best_threshold() scans THRESHOLD_STEPS
# evenly spaced values ending at THRESHOLD_MAX.
# NOTE(review): THRESHOLD_MIN is not referenced in the visible code; the scan
# starts at the ``min_threshold`` argument instead — confirm it is still needed.
THRESHOLD_MIN = 0.05
THRESHOLD_MAX = 0.95
THRESHOLD_STEPS = 19
# Minimum validation recall a candidate threshold must reach to be selected.
MIN_RECALL_THRESHOLD = 0.6
# Learning rate passed to AdamW for the transformer backbones in train().
TRANSFORMER_LR = 2e-5
# ---------------------------------------------------------------------------
# Tokenizer (identical to _simple_tokenize in api/main.py)
# ---------------------------------------------------------------------------
def _tokenize(text: str) -> list[int]:
    """Map every whitespace-separated token of *text* to a hashed vocabulary id.

    The text is lower-cased and split on whitespace; each token is hashed
    with MD5 and folded into ``[1, VOCAB_SIZE - 1]`` (id 0 stays free for
    padding).  MD5 is used purely as a stable hash — unlike the built-in
    ``hash()`` it is not affected by ``PYTHONHASHSEED`` — so the same token
    always yields the same id regardless of platform or process.
    """
    bucket_count = VOCAB_SIZE - 1
    token_ids: list[int] = []
    for word in text.lower().split():
        digest = hashlib.md5(word.encode("utf-8"), usedforsecurity=False).hexdigest()
        # Shift by one so that id 0 is never produced for a real token.
        token_ids.append(int(digest, 16) % bucket_count + 1)
    return token_ids
# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------
def _load_csv(data_path: str) -> pd.DataFrame:
    """Read the training CSV and return a cleaned frame.

    The returned frame has non-empty, ``clean_text``-processed ``text``,
    integer ``label`` and a ``domain`` column (filled with ``"unknown"`` when
    the file does not provide one).

    Raises
    ------
    ValueError
        If the file lacks a ``text`` or ``label`` column.
    """

    def _without_blank_text(frame: pd.DataFrame) -> pd.DataFrame:
        # Keep only rows whose text is a non-empty string; renumber the index.
        return frame[frame["text"] != ""].reset_index(drop=True)

    frame = pd.read_csv(data_path)
    if not {"text", "label"}.issubset(frame.columns):
        raise ValueError(f"{data_path} must have 'text' and 'label' columns.")

    frame = frame.dropna(subset=["text"]).reset_index(drop=True)
    frame["text"] = frame["text"].astype(str).str.strip()
    frame = _without_blank_text(frame)

    # Run the inference-time preprocessing so training sees the same tokens
    # as production (URLs, HTML, emojis, repeated chars all normalised).
    frame["text"] = frame["text"].apply(clean_text)
    frame = _without_blank_text(frame)

    frame["label"] = frame["label"].astype(int)
    if "domain" not in frame.columns:
        frame["domain"] = "unknown"
    return frame
def _describe_dataset(df: pd.DataFrame) -> None:
    """Print label and domain statistics for *df* to stdout.

    Emits the label distribution, the domain distribution and a
    domain-by-label crosstab, followed by a warning when at least one domain
    has no samples for label 0 or for label 1.
    """
    print("\nDataset summary")
    print("-" * 50)

    print("Label distribution:")
    label_counts = df["label"].value_counts().sort_index()
    print(label_counts.to_string())

    print("\nDomain distribution:")
    domain_counts = df["domain"].value_counts()
    print(domain_counts.to_string())

    print("\nLabel by domain:")
    by_domain = pd.crosstab(df["domain"], df["label"])
    print(by_domain.to_string())

    # A domain is "one-sided" when either binary label has a zero count.
    has_one_sided_domain = any(
        counts.get(label, 0) == 0
        for _, counts in by_domain.iterrows()
        for label in (0, 1)
    )
    if has_one_sided_domain:
        print(
            "\nWarning: Some domains contain only one label. "
            "Consider adding happy/neutral negatives to reduce false positives."
        )
def _stratified_split(
    df: pd.DataFrame, val_ratio: float, seed: int
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Split *df* into train/validation frames, stratified by label and domain.

    Rows are bucketed by ``"<label>|<domain>"``.  Buckets are visited in
    sorted key order, shuffled with a ``random.Random(seed)`` generator, and
    roughly ``val_ratio`` of each bucket goes to validation.  Buckets with
    more than one row always contribute at least one row to each side;
    single-row buckets go entirely to training.  If no bucket produced a
    validation row, the leading rows of the training list are moved to
    validation instead.

    Returns
    -------
    tuple
        ``(train_df, val_df)``, both with a fresh 0..n-1 index.
    """
    rng = random.Random(seed)

    # Bucket row positions by their stratum key, preserving row order.
    buckets: dict[str, list[int]] = {}
    strata = zip(df["label"].astype(str), df["domain"].astype(str))
    for position, (label, domain) in enumerate(strata):
        buckets.setdefault(label + "|" + domain, []).append(position)

    train_rows: list[int] = []
    val_rows: list[int] = []
    for stratum in sorted(buckets):
        members = buckets[stratum]
        rng.shuffle(members)
        size = len(members)
        if size > 1:
            # Clamp so both sides receive at least one sample.
            take = max(1, min(int(round(size * val_ratio)), size - 1))
        else:
            take = 0
        val_rows.extend(members[:take])
        train_rows.extend(members[take:])

    if not val_rows:
        # Every stratum was a singleton: carve validation off the front.
        cut = max(1, int(round(len(train_rows) * val_ratio)))
        val_rows, train_rows = train_rows[:cut], train_rows[cut:]

    return (
        df.iloc[train_rows].reset_index(drop=True),
        df.iloc[val_rows].reset_index(drop=True),
    )
class _StressChunkDataset(Dataset):
    """CNN dataset that splits each document into overlapping token windows.

    Every text is tokenised with ``_tokenize`` and cut into windows of
    ``chunk_size`` ids taken every ``stride`` ids; short windows are
    right-padded with 0.  A text that yields no tokens contributes a single
    all-padding window.  Each window inherits its document's label and, when
    supplied, its feature row and reward.

    Parameters
    ----------
    texts, labels
        Parallel per-document sequences (must have equal length).
    features
        Optional per-document numeric feature rows; emitted under the
        ``"features"`` key of each item.
    chunk_size, stride
        Window length and step, both in tokens (must be positive).
    rewards
        Optional per-document weights; emitted under the ``"reward"`` key.
        ``_run_epoch`` switches to ``weighted_loss`` when that key is
        present and otherwise uses whatever criterion it was given.

    Raises
    ------
    ValueError
        If ``labels``, ``features`` or ``rewards`` do not align with
        ``texts``, or if ``chunk_size``/``stride`` is not positive.
    """

    def __init__(
        self,
        texts: list[str],
        labels: list[int],
        features: np.ndarray | None = None,
        chunk_size: int = CHUNK_SIZE,
        stride: int = STRIDE,
        rewards: list[float] | None = None,
    ) -> None:
        # zip() below would silently truncate on a length mismatch, so
        # validate labels the same way features and rewards are validated.
        if len(labels) != len(texts):
            raise ValueError("labels must align with texts.")
        if chunk_size <= 0 or stride <= 0:
            raise ValueError("chunk_size and stride must be positive.")

        self._chunks: list[torch.Tensor] = []
        self._labels: list[int] = []
        self._features: list[torch.Tensor] | None = None
        self._rewards: list[float] | None = None

        if rewards is not None:
            if len(rewards) != len(texts):
                raise ValueError("rewards must align with texts.")
            self._rewards = []

        feature_rows = None
        if features is not None:
            if len(features) != len(texts):
                raise ValueError("Features must align with texts.")
            self._features = []
            feature_rows = [torch.tensor(row, dtype=torch.float) for row in features]

        for idx, (text, label) in enumerate(zip(texts, labels)):
            token_ids = _tokenize(text)
            label = int(label)
            feature_tensor = feature_rows[idx] if feature_rows is not None else None
            reward_val = rewards[idx] if rewards is not None else None

            if not token_ids:
                # Keep the document represented by one all-padding window.
                self._append(
                    torch.zeros(chunk_size, dtype=torch.long),
                    label,
                    feature_tensor,
                    reward_val,
                )
                continue

            for start in range(0, len(token_ids), stride):
                end = start + chunk_size
                chunk = token_ids[start:end]
                if len(chunk) < chunk_size:
                    chunk = chunk + [0] * (chunk_size - len(chunk))
                self._append(
                    torch.tensor(chunk, dtype=torch.long),
                    label,
                    feature_tensor,
                    reward_val,
                )
                # This window already reached the end of the document;
                # later starts would only produce shorter suffixes of it.
                if end >= len(token_ids):
                    break

    def _append(
        self,
        chunk: torch.Tensor,
        label: int,
        feature_tensor: torch.Tensor | None,
        reward_val: float | None,
    ) -> None:
        """Store one window together with its document-level metadata."""
        self._chunks.append(chunk)
        self._labels.append(label)
        if self._features is not None:
            self._features.append(feature_tensor)
        if self._rewards is not None:
            self._rewards.append(reward_val)

    def __len__(self) -> int:
        """Return the number of windows (not documents)."""
        return len(self._chunks)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Return window *idx* as a dict with ``input_ids`` and ``label``.

        ``features`` and ``reward`` keys are included only when the dataset
        was built with the corresponding argument.
        """
        item = {
            "input_ids": self._chunks[idx],
            "label": torch.tensor(self._labels[idx], dtype=torch.long),
        }
        if self._features is not None:
            item["features"] = self._features[idx]
        if self._rewards is not None:
            item["reward"] = torch.tensor(self._rewards[idx], dtype=torch.float)
        return item
class _TransformerDataset(Dataset):
    """Dataset that feeds tokenizer encodings to the transformer models.

    All texts are encoded once at construction (truncated and padded to
    ``max_length``).  Each item carries the encoded fields, the ``label``
    and a ``sentiment`` score obtained from ``get_sentiment_score`` at
    access time.
    """

    def __init__(
        self,
        texts: list[str],
        labels: list[int],
        tokenizer,
        max_length: int,
    ) -> None:
        self._texts = texts
        self._labels = labels
        self._encodings = tokenizer(
            texts,
            truncation=True,
            padding="max_length",
            max_length=max_length,
        )

    def __len__(self) -> int:
        """Return the number of documents."""
        return len(self._labels)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Return the tensors for document *idx*."""
        sample: dict[str, torch.Tensor] = {}
        for field, values in self._encodings.items():
            sample[field] = torch.tensor(values[idx])
        sample["label"] = torch.tensor(self._labels[idx], dtype=torch.long)
        score = get_sentiment_score(self._texts[idx])
        sample["sentiment"] = torch.tensor(score, dtype=torch.float)
        return sample
# ---------------------------------------------------------------------------
# Training helpers
# ---------------------------------------------------------------------------
def _accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the fraction of rows whose arg-max class equals the label."""
    predicted = torch.argmax(logits, dim=-1)
    hits = torch.eq(predicted, labels).float()
    return hits.mean().item()
def weighted_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    rewards: torch.Tensor,
) -> torch.Tensor:
    """Cross-entropy in which every sample is scaled by its reward.

    The unreduced cross-entropy of each sample is multiplied by the matching
    entry of ``rewards`` and the products are averaged over the batch, so
    samples with a larger reward contribute proportionally more gradient.

    Parameters
    ----------
    logits : Tensor, shape ``(B, C)``
    labels : Tensor, shape ``(B,)``
    rewards : Tensor, shape ``(B,)``
        Non-negative per-sample weight (e.g. ``1.5`` for feedback-derived
        samples, ``1.0`` for baseline samples).

    Returns
    -------
    Tensor
        Scalar mean of the reward-scaled losses.
    """
    sample_losses = nn.functional.cross_entropy(logits, labels, reduction="none")
    scaled = sample_losses * rewards
    return scaled.mean()
class FocalLoss(nn.Module):
    """Focal loss (Lin et al. 2017) for binary stress classification.

    Reduces the loss contribution of easy, well-classified examples and
    focuses training on hard, ambiguous inputs — preventing a handful of
    common keyword patterns from dominating the gradient updates.

    The per-sample loss is ``(1 - p_t) ** gamma * CE`` where ``p_t`` is the
    softmax probability of the true class and ``CE`` is the (optionally
    class-weighted and label-smoothed) cross-entropy.  The result is the
    plain mean over the batch.

    Parameters
    ----------
    gamma : float
        Focusing parameter. ``gamma=0`` recovers standard cross-entropy.
        ``gamma=2`` is the value recommended by Lin et al.
    weight : Tensor or None
        Per-class weight tensor forwarded to the underlying cross-entropy
        (handles label imbalance in the same way as ``nn.CrossEntropyLoss``).
    label_smoothing : float
        Label smoothing coefficient forwarded to cross-entropy.
    """

    def __init__(
        self,
        gamma: float = 2.0,
        weight: torch.Tensor | None = None,
        label_smoothing: float = 0.0,
    ) -> None:
        super().__init__()
        self.gamma = gamma
        self._ce = nn.CrossEntropyLoss(
            weight=weight,
            reduction="none",
            label_smoothing=label_smoothing,
        )

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the scalar focal loss for a batch.

        Parameters
        ----------
        logits : Tensor, shape ``(B, C)``
        targets : Tensor, shape ``(B,)``
            Integer class indices.
        """
        ce = self._ce(logits, targets)  # (B,) weighted / smoothed CE
        # The focusing factor must use the true-class probability.  Deriving
        # it from ``ce`` would give ``p_t ** weight`` (and mix in the
        # smoothing term) whenever class weights or label smoothing are on,
        # so compute it from an unweighted, unsmoothed cross-entropy instead.
        log_pt = -nn.functional.cross_entropy(logits, targets, reduction="none")
        pt = torch.exp(log_pt)  # p(correct class)
        return ((1.0 - pt) ** self.gamma * ce).mean()
def _run_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer | None,
    device: torch.device,
    is_train: bool,
    collect_probs: bool = False,
    warmup_scheduler=None,
    fp_penalty_weight: float = 0.0,
) -> tuple[float, float, np.ndarray | None, np.ndarray | None]:
    """Run one pass over *loader*, optionally updating the model.

    Batches containing ``attention_mask`` are fed to the model with keyword
    arguments (plus ``sentiment`` when present); all other batches are fed
    positionally (plus ``aux_features`` when the batch has ``features``).
    Batches with a ``reward`` entry are scored with ``weighted_loss``,
    everything else with *criterion*.

    Parameters
    ----------
    optimizer
        Stepped after every batch when ``is_train`` is true and it is not
        ``None``; gradients are clipped to a norm of 1.0 first.
    collect_probs
        When true, class-1 softmax probabilities and labels are gathered.
    warmup_scheduler
        Optional scheduler stepped once per optimizer step.
    fp_penalty_weight
        During training, weight of the mean class-1 probability assigned to
        label-0 samples, added to the batch loss.

    Returns
    -------
    tuple
        ``(mean_loss, mean_accuracy, probs, labels)`` where the means are
        taken over batches and the last two entries are ``None`` unless
        ``collect_probs`` produced data.
    """
    model.train(is_train)

    loss_sum = 0.0
    acc_sum = 0.0
    batches_seen = 0
    prob_parts: list[torch.Tensor] = []
    label_parts: list[torch.Tensor] = []

    with (torch.enable_grad() if is_train else torch.no_grad()):
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            labels = batch["label"].to(device)

            mask = batch.get("attention_mask")
            if mask is None:
                # CNN path: positional ids, optional auxiliary features.
                aux = batch.get("features")
                extra = {} if aux is None else {"aux_features": aux.to(device)}
                output = model(input_ids, **extra)
            else:
                # Transformer path: keyword call, optional sentiment input.
                mask = mask.to(device)
                mood = batch.get("sentiment")
                extra = {} if mood is None else {"sentiment": mood.to(device)}
                output = model(input_ids=input_ids, attention_mask=mask, **extra)
            logits = output["logits"]

            # Per-sample rewards (RL-style fine-tuning from feedback data)
            # take precedence over the configured criterion.
            sample_rewards = batch.get("reward")
            if sample_rewards is None:
                loss = criterion(logits, labels)
            else:
                loss = weighted_loss(logits, labels, sample_rewards.to(device))

            # Soft FPR penalty: mean stress probability on negative samples.
            # Acts on the gradient during training, unlike threshold
            # calibration which only acts at evaluation time.
            if is_train and fp_penalty_weight > 0.0:
                is_negative = (labels == 0).float()
                negatives = is_negative.sum()
                if negatives > 0:
                    p_stress = torch.softmax(logits, dim=-1)[:, 1]
                    penalty = (p_stress * is_negative).sum() / negatives
                    loss = loss + fp_penalty_weight * penalty

            if is_train and optimizer is not None:
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                if warmup_scheduler is not None:
                    warmup_scheduler.step()

            loss_sum += loss.item()
            acc_sum += _accuracy(logits, labels)
            batches_seen += 1

            if collect_probs:
                p_class1 = torch.softmax(logits, dim=-1)[:, 1]
                prob_parts.append(p_class1.detach().cpu())
                label_parts.append(labels.detach().cpu())

    probs_np = None
    labels_np = None
    if collect_probs and prob_parts:
        probs_np = torch.cat(prob_parts).numpy()
        labels_np = torch.cat(label_parts).numpy()

    divisor = max(batches_seen, 1)  # avoid dividing by zero on empty loaders
    return loss_sum / divisor, acc_sum / divisor, probs_np, labels_np
def _compute_metrics(
    labels: np.ndarray, probs: np.ndarray, threshold: float
) -> dict[str, float | list[list[int]]]:
    """Score *probs* against *labels* at a given decision threshold.

    A sample is predicted positive when its probability is ``>= threshold``.
    Returns precision, recall and F1 (each 0 when undefined) plus the 2x2
    confusion matrix laid out as ``[[TN, FP], [FN, TP]]``.
    """
    predicted = (probs >= threshold).astype(int)
    scores: dict[str, float | list[list[int]]] = {
        "precision": float(precision_score(labels, predicted, zero_division=0)),
        "recall": float(recall_score(labels, predicted, zero_division=0)),
        "f1": float(f1_score(labels, predicted, zero_division=0)),
    }
    # Fix the label order so the matrix stays 2x2 even if a class is absent.
    matrix = confusion_matrix(labels, predicted, labels=[0, 1])
    scores["confusion_matrix"] = matrix.tolist()
    return scores
def _find_best_threshold(
    labels: np.ndarray,
    probs: np.ndarray,
    max_fpr: float = 0.20,
    min_threshold: float = 0.50,
) -> tuple[float, dict[str, float | list[list[int]]]]:
    """Find the best decision threshold subject to FPR and minimum-threshold constraints.

    Why this matters
    ----------------
    Optimising purely for F1 on a stress-heavy dataset causes the search to
    collapse to very low thresholds (~0.15) that label almost everything as
    stressed. Such a model has high recall but an FPR of 75–95%, making it
    clinically useless.

    Selection rule
    --------------
    A threshold is *feasible* when its validation recall is at least
    ``MIN_RECALL_THRESHOLD`` and its FPR is at most ``max_fpr``.  The
    feasible threshold with the highest F1 is returned.  ``min_threshold``
    is returned only as a fallback when no scanned threshold is feasible;
    an infeasible ``min_threshold`` is replaced by any feasible candidate,
    even one with a lower F1.

    Parameters
    ----------
    labels : np.ndarray
        Ground-truth binary labels (0 = no stress, 1 = stress).
    probs : np.ndarray
        Model stress probabilities (softmax output for class 1).
    max_fpr : float
        Maximum false positive rate allowed. Thresholds that exceed this on
        the validation set are rejected regardless of their F1.
    min_threshold : float
        Hard lower bound on the selected threshold. The deployed model will
        never be more aggressive than this value.

    Returns
    -------
    tuple
        ``(best_threshold, best_metrics)``
    """

    def _is_feasible(metrics: dict) -> bool:
        # Confusion matrix layout is [[TN, FP], [FN, TP]].
        tn, fp = metrics["confusion_matrix"][0]
        fpr = fp / max(tn + fp, 1)
        return metrics["recall"] >= MIN_RECALL_THRESHOLD and fpr <= max_fpr

    best_threshold = min_threshold
    best_metrics = _compute_metrics(labels, probs, best_threshold)
    best_f1 = best_metrics["f1"]
    # The fallback must be held to the same constraints as every candidate,
    # otherwise an FPR-violating baseline with a high F1 can never be beaten.
    best_is_feasible = _is_feasible(best_metrics)

    # Never scan below min_threshold, even if it exceeds THRESHOLD_MAX.
    upper = max(min_threshold, THRESHOLD_MAX)
    for threshold in np.linspace(min_threshold, upper, THRESHOLD_STEPS):
        t = float(threshold)
        metrics = _compute_metrics(labels, probs, t)
        if not _is_feasible(metrics):
            continue
        if not best_is_feasible or metrics["f1"] > best_f1:
            best_f1 = metrics["f1"]
            best_threshold = t
            best_metrics = metrics
            best_is_feasible = True
    return best_threshold, best_metrics
def _build_model_and_tokenizer(
    model_type: str,
    dropout: float,
    max_length: int,
    aux_dim: int = 0,
) -> tuple[nn.Module, object | None, str | None]:
    """Instantiate the requested backbone and, for transformers, its tokenizer.

    Returns
    -------
    tuple
        ``(model, tokenizer, model_name)``.  For ``"cnn"`` the last two are
        ``None``; for ``"deberta"`` / ``"minilm"`` the tokenizer is loaded
        from the model's ``MODEL_NAME`` and capped at ``max_length``.

    Raises
    ------
    ValueError
        If ``model_type`` is not one of the supported backbones.
    """
    if model_type == "cnn":
        cnn = OptimizedMultichannelCNN(
            vocab_size=VOCAB_SIZE,
            embed_dim=128,
            num_filters=64,
            kernel_sizes=(2, 3, 5),
            num_classes=2,
            dropout=dropout,
            aux_dim=aux_dim,
        )
        return cnn, None, None

    transformer_classes = {
        "deberta": DeBERTaStressClassifier,
        "minilm": MiniLMStressClassifier,
    }
    if model_type not in transformer_classes:
        raise ValueError(f"Unknown model type: {model_type}")
    model = transformer_classes[model_type](dropout=dropout)

    # Imported lazily so CNN-only runs do not need ``transformers``.
    from transformers import AutoTokenizer

    model_name = model.MODEL_NAME
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.model_max_length = max_length
    return model, tokenizer, model_name
def _load_eval_set(path: str) -> pd.DataFrame:
    """Read the optional happy/neutral evaluation CSV.

    Rows with missing or blank text are dropped, the text is passed through
    ``clean_text`` and ``label`` is cast to ``int``.

    Raises
    ------
    ValueError
        If the file lacks a ``text`` or ``label`` column.
    """
    df = pd.read_csv(path)
    if "text" not in df.columns or "label" not in df.columns:
        raise ValueError(f"{path} must have 'text' and 'label' columns.")
    df = df.dropna(subset=["text"]).reset_index(drop=True)
    df["text"] = df["text"].astype(str).str.strip()
    df = df[df["text"] != ""].reset_index(drop=True)
    # Apply the same preprocessing as the training data: false-positive
    # monitoring is only meaningful if the eval text is tokenised from the
    # same cleaned form the model was trained on.
    df["text"] = df["text"].apply(clean_text)
    df = df[df["text"] != ""].reset_index(drop=True)
    df["label"] = df["label"].astype(int)
    return df
def _prepare_feature_frame(
    df: pd.DataFrame,
    feature_cols: list[str],
    means: pd.Series | None = None,
    stds: pd.Series | None = None,
) -> tuple[np.ndarray, pd.Series, pd.Series]:
    """Coerce, impute and standardise the numeric feature columns of *df*.

    Non-numeric entries become NaN and are imputed with the column mean,
    then every column is z-scored.

    Parameters
    ----------
    df
        Frame containing at least ``feature_cols``.
    feature_cols
        Names of the columns to extract, in output order.
    means, stds
        Statistics to reuse (e.g. those of the training split).  When
        ``None`` they are computed from *df*.

    Returns
    -------
    tuple
        ``(features, means, stds)`` where ``features`` is a ``float32``
        array of shape ``(len(df), len(feature_cols))``.
    """
    features = df[feature_cols].apply(pd.to_numeric, errors="coerce")
    if means is None:
        # A column with no numeric value has a NaN mean; use 0 so imputation
        # still yields finite values instead of propagating NaN to the model.
        means = features.mean().fillna(0)
    if stds is None:
        # Constant columns have std 0 and single-row frames have std NaN
        # (sample std); both would break the division below, so use 1.
        stds = features.std().replace(0, 1).fillna(1)
    features = features.fillna(means)
    normalized = (features - means) / stds
    return normalized.to_numpy(dtype=np.float32), means, stds
# ---------------------------------------------------------------------------
# Main training loop
# ---------------------------------------------------------------------------
def train(
    data_path: str,
    output_path: str,
    epochs: int,
    batch_size: int,
    lr: float,
    val_ratio: float,
    seed: int,
    device_str: str,
    model_type: str,
    dropout: float,
    weight_decay: float,
    label_smoothing: float,
    class_weighted: bool,
    patience: int,
    eval_set_path: str | None,
    max_length: int,
    max_fpr: float = 0.20,
    min_threshold: float = 0.50,
    fp_penalty_weight: float = 0.2,
) -> None:
    """Train a stress classifier and save the best checkpoint.

    Loads the unified CSV, splits it by label and domain, trains the chosen
    backbone, calibrates the decision threshold on the validation split
    after every epoch and writes a checkpoint whenever the calibrated
    validation F1 improves.  Progress is printed to stdout.

    Parameters
    ----------
    data_path : str
        CSV with ``text`` and ``label`` columns (``domain`` optional).
        Other columns that are at least 50% numeric are used as auxiliary
        features by the CNN.
    output_path : str
        Destination of the checkpoint; parent directories are created.
    epochs, batch_size : int
        Training length and mini-batch size.
    lr : float
        Learning rate for the CNN optimizer.  Transformer models use
        ``TRANSFORMER_LR`` instead.
    val_ratio : float
        Fraction of each label/domain stratum held out for validation.
    seed : int
        Seed for torch and for the split.
    device_str : str
        Requested device; falls back to CPU when CUDA is unavailable.
    model_type : str
        ``"cnn"``, ``"deberta"`` or ``"minilm"``.
    dropout, weight_decay, label_smoothing : float
        Regularisation settings.
    class_weighted : bool
        Use inverse-frequency class weights in the loss.
    patience : int
        Stop after this many epochs without F1 improvement (0 disables).
    eval_set_path : str or None
        Optional extra CSV evaluated after every epoch; ignored when the
        file does not exist.
    max_length : int
        Tokenizer length for transformer models.
    max_fpr, min_threshold : float
        Constraints forwarded to ``_find_best_threshold``.
    fp_penalty_weight : float
        Weight of the soft false-positive penalty used during training.

    Raises
    ------
    RuntimeError
        If a validation pass yields no probabilities (empty validation set).
    """
    torch.manual_seed(seed)
    device = torch.device(
        device_str if torch.cuda.is_available() or device_str == "cpu" else "cpu"
    )
    print(f"Using device: {device}")

    # ── Data ──
    print(f"\nLoading dataset from: {data_path}")
    df = _load_csv(data_path)
    _describe_dataset(df)
    train_df, val_df = _stratified_split(df, val_ratio, seed)

    # Any extra column that is mostly numeric becomes an auxiliary feature.
    candidate_cols = [
        col for col in df.columns if col not in {"text", "label", "domain"}
    ]
    if candidate_cols:
        raw_candidates = df[candidate_cols]
        numeric_candidates = raw_candidates.apply(
            pd.to_numeric, errors="coerce"
        )
        coerced = numeric_candidates.isna() & raw_candidates.notna()
        if coerced.any().any():
            bad_cols = [
                col for col in candidate_cols if coerced[col].any()
            ]
            print(
                "Warning: Non-numeric values coerced to NaN in columns: "
                + ", ".join(bad_cols)
            )
        coverage = numeric_candidates.notna().mean()
        feature_cols = [
            col
            for col in numeric_candidates.columns
            if coverage[col] >= 0.5
        ]
        # The split frames were materialised above, so this only normalises
        # ``df`` itself; _prepare_feature_frame coerces the splits again.
        for col in feature_cols:
            df[col] = numeric_candidates[col]
    else:
        feature_cols = []

    if feature_cols:
        print(f"Numeric feature columns detected: {len(feature_cols):,}")
    elif model_type == "cnn":
        print("No numeric feature columns detected; training text-only CNN.")
    print(
        f"\nDocuments: {len(df):,} | Train: {len(train_df):,} | Val: {len(val_df):,}"
    )

    # ── Model ──
    aux_dim = len(feature_cols) if model_type == "cnn" and feature_cols else 0
    model, tokenizer, model_name = _build_model_and_tokenizer(
        model_type=model_type,
        dropout=dropout,
        max_length=max_length,
        aux_dim=aux_dim,
    )
    model = model.to(device)

    # ── Features (CNN only) ──
    train_features = None
    val_features = None
    eval_features = None
    feature_means = None
    feature_stds = None
    if feature_cols and model_type == "cnn":
        # Statistics come from the training split only and are reused for
        # the validation and eval sets.
        train_features, feature_means, feature_stds = _prepare_feature_frame(
            train_df, feature_cols
        )
        val_features, _, _ = _prepare_feature_frame(
            val_df, feature_cols, feature_means, feature_stds
        )
    elif feature_cols and model_type != "cnn":
        print(
            "Note: numeric features are available but only the CNN uses them."
        )

    # ── Datasets and loaders ──
    if model_type == "cnn":
        train_ds = _StressChunkDataset(
            train_df["text"].tolist(),
            train_df["label"].tolist(),
            features=train_features,
        )
        val_ds = _StressChunkDataset(
            val_df["text"].tolist(),
            val_df["label"].tolist(),
            features=val_features,
        )
    else:
        train_ds = _TransformerDataset(
            train_df["text"].tolist(),
            train_df["label"].tolist(),
            tokenizer,
            max_length=max_length,
        )
        val_ds = _TransformerDataset(
            val_df["text"].tolist(),
            val_df["label"].tolist(),
            tokenizer,
            max_length=max_length,
        )
    print(f"Chunks: {len(train_ds):,} train | {len(val_ds):,} val")
    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=0
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False, num_workers=0
    )

    eval_loader = None
    if eval_set_path and os.path.isfile(eval_set_path):
        eval_df = _load_eval_set(eval_set_path)
        if model_type == "cnn":
            if feature_cols and feature_means is not None and feature_stds is not None:
                # Missing feature columns are imputed with the train means.
                for col in feature_cols:
                    if col not in eval_df.columns:
                        eval_df[col] = np.nan
                eval_features, _, _ = _prepare_feature_frame(
                    eval_df, feature_cols, feature_means, feature_stds
                )
            eval_ds = _StressChunkDataset(
                eval_df["text"].tolist(),
                eval_df["label"].tolist(),
                features=eval_features,
            )
        else:
            eval_ds = _TransformerDataset(
                eval_df["text"].tolist(),
                eval_df["label"].tolist(),
                tokenizer,
                max_length=max_length,
            )
        eval_loader = DataLoader(
            eval_ds, batch_size=batch_size, shuffle=False, num_workers=0
        )
        print(f"Happy/neutral eval set: {len(eval_ds):,} samples")

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model parameters: {n_params:,}\n")

    # ── Loss ──
    class_weights = None
    if class_weighted:
        counts = train_df["label"].value_counts().reindex([0, 1], fill_value=0)
        total = counts.sum()
        # Inverse-frequency weights; an absent class keeps a neutral 1.0.
        weights = [
            total / (2 * counts[0]) if counts[0] > 0 else 1.0,
            total / (2 * counts[1]) if counts[1] > 0 else 1.0,
        ]
        class_weights = torch.tensor(weights, dtype=torch.float, device=device)
    criterion = (
        FocalLoss(gamma=2.0, weight=class_weights, label_smoothing=label_smoothing)
        if model_type == "cnn"
        else nn.CrossEntropyLoss(weight=class_weights, label_smoothing=label_smoothing)
    )

    # ── Optimizer and schedule ──
    if model_type == "cnn":
        optimizer = torch.optim.Adam(
            model.parameters(), lr=lr, weight_decay=weight_decay
        )
        # Cosine annealing with linear warmup: better convergence than
        # ReduceLROnPlateau for text CNNs — avoids premature learning-rate
        # collapses caused by noisy F1 on small validation sets.
        total_steps = len(train_loader) * epochs
        warmup_steps = min(int(0.1 * total_steps), 200)

        def _lr_lambda(current_step: int) -> float:
            """Linear warmup then cosine decay."""
            if current_step < warmup_steps:
                return float(current_step) / max(warmup_steps, 1)
            progress = float(current_step - warmup_steps) / max(
                total_steps - warmup_steps, 1
            )
            return max(0.01, 0.5 * (1.0 + math.cos(math.pi * progress)))

        warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, _lr_lambda
        )
        scheduler = None  # per-step updates only
    else:
        from transformers import get_linear_schedule_with_warmup

        optimizer = torch.optim.AdamW(
            model.parameters(), lr=TRANSFORMER_LR, weight_decay=weight_decay
        )
        total_steps = len(train_loader) * epochs
        warmup_steps = int(0.1 * total_steps)
        warmup_scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
        )
        scheduler = None

    # Start below any reachable F1 so the first epoch always writes a
    # checkpoint; with a 0.0 start a run whose F1 never left zero would
    # finish without a checkpoint while still reporting one as saved.
    best_val_f1 = float("-inf")
    best_threshold = min_threshold
    epochs_since_improve = 0
    checkpoint_saved = False
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    _EPOCH_HEADER = (
        f"{'Epoch':>6} {'Train Loss':>11} {'Train Acc':>10} "
        f"{'Val Loss':>9} {'Val F1':>8} {'Val Prec':>9} {'Val Rec':>8} "
        f"{'Val FPR':>8} {'Thresh':>7} {'Time':>6}"
    )
    print(_EPOCH_HEADER)
    print("-" * len(_EPOCH_HEADER))

    # ── Epoch loop ──
    for epoch in range(1, epochs + 1):
        t0 = time.time()
        tr_loss, tr_acc, _, _ = _run_epoch(
            model, train_loader, criterion, optimizer, device, is_train=True,
            warmup_scheduler=warmup_scheduler,
            fp_penalty_weight=fp_penalty_weight,
        )
        vl_loss, _, val_probs, val_labels = _run_epoch(
            model,
            val_loader,
            criterion,
            None,
            device,
            is_train=False,
            collect_probs=True,
        )
        elapsed = time.time() - t0
        if val_probs is None or val_labels is None:
            raise RuntimeError("Validation probabilities were not collected.")

        threshold, val_metrics = _find_best_threshold(
            val_labels, val_probs,
            max_fpr=max_fpr,
            min_threshold=min_threshold,
        )
        val_f1 = val_metrics["f1"]
        val_precision = val_metrics["precision"]
        val_recall = val_metrics["recall"]
        val_cm = val_metrics["confusion_matrix"]
        val_tn, val_fp = val_cm[0][0], val_cm[0][1]
        val_fpr = val_fp / max(val_tn + val_fp, 1)

        if scheduler is not None:
            scheduler.step(val_f1)

        marker = " ←" if val_f1 > best_val_f1 else ""
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_threshold = threshold
            epochs_since_improve = 0
            checkpoint = {
                "model_state_dict": model.state_dict(),
                "decision_threshold": float(best_threshold),
                "model_type": model_type,
                "dropout": float(dropout),
            }
            if model_type != "cnn" and model_name:
                checkpoint["model_name"] = model_name
                checkpoint["tokenizer_max_length"] = int(max_length)
            if model_type == "cnn":
                checkpoint["chunk_size"] = int(CHUNK_SIZE)
                checkpoint["stride"] = int(STRIDE)
                if feature_cols:
                    # Stored so inference can rebuild and normalise the
                    # auxiliary feature vector the same way.
                    checkpoint["feature_dim"] = len(feature_cols)
                    checkpoint["feature_columns"] = feature_cols
                    checkpoint["feature_means"] = feature_means.tolist()
                    checkpoint["feature_stds"] = feature_stds.tolist()
            torch.save(checkpoint, output_path)
            checkpoint_saved = True
        else:
            epochs_since_improve += 1

        print(
            f"{epoch:>6} {tr_loss:>11.4f} {tr_acc:>9.2%} "
            f"{vl_loss:>9.4f} {val_f1:>7.2%} {val_precision:>8.2%} "
            f"{val_recall:>7.2%} {val_fpr:>7.2%} {threshold:>7.2f} {elapsed:>5.1f}s{marker}"
        )
        cm = val_metrics["confusion_matrix"]
        print(
            " Confusion matrix: "
            f"TN={cm[0][0]} FP={cm[0][1]} FN={cm[1][0]} TP={cm[1][1]}"
        )

        if eval_loader is not None:
            _, _, eval_probs, eval_labels = _run_epoch(
                model,
                eval_loader,
                criterion,
                None,
                device,
                is_train=False,
                collect_probs=True,
            )
            if eval_probs is not None and eval_labels is not None:
                # Scored at this epoch's calibrated threshold.
                eval_metrics = _compute_metrics(
                    eval_labels, eval_probs, threshold
                )
                eval_cm = eval_metrics["confusion_matrix"]
                tn, fp = eval_cm[0][0], eval_cm[0][1]
                fp_rate = fp / max(tn + fp, 1)
                print(
                    " Happy/neutral eval — "
                    f"FP rate: {fp_rate:.2%}, "
                    f"F1: {eval_metrics['f1']:.2%}"
                )

        if patience > 0 and epochs_since_improve >= patience:
            print(
                f"\nEarly stopping: no F1 improvement in {patience} epochs."
            )
            break

    if not checkpoint_saved:
        # Only reachable when the loop body never ran (epochs <= 0).
        print("\nNo training epochs were run; no checkpoint was written.")
        return

    print(f"\nBest validation F1: {best_val_f1:.2%}")
    print(f"Best decision threshold: {best_threshold:.2f}")
    print(f"Checkpoint saved to: {output_path}")
    print(
        "\nTo use with the API, start the server normally:\n"
        " uvicorn api.main:app --host 0.0.0.0 --port 8000\n"
        "The checkpoint is loaded automatically on the first /analyze request."
    )
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    Defaults for ``--data``, ``--output`` and ``--eval-set`` are resolved
    relative to ``ROOT_DIR``.  Every option maps onto one parameter of
    ``train()``.
    """
    p = argparse.ArgumentParser(
        description="Train stress detection models with calibrated thresholds.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument(
        "--data",
        default=os.path.join(ROOT_DIR, "data", "processed", "unified_stress.csv"),
        help="Path to the unified CSV produced by data_preprocessing.py",
    )
    p.add_argument(
        "--output",
        default=os.path.join(ROOT_DIR, "checkpoints", "model.pt"),
        help="Destination path for the best checkpoint",
    )
    p.add_argument(
        "--model",
        default="cnn",
        choices=["cnn", "deberta", "minilm"],
        help="Model backbone to train",
    )
    p.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
    p.add_argument("--batch-size", type=int, default=64, help="Mini-batch size")
    # train() only passes this value to the CNN's Adam optimizer, where the
    # warmup schedule ramps up to it; transformer backbones are optimised
    # with the fixed TRANSFORMER_LR, so the help text says so explicitly.
    p.add_argument(
        "--lr",
        type=float,
        default=1e-3,
        help=(
            "Peak learning rate for the CNN (reached after warmup). "
            "Ignored by transformer models, which use a fixed rate"
        ),
    )
    p.add_argument(
        "--weight-decay",
        type=float,
        default=1e-4,
        help="Weight decay for Adam/AdamW",
    )
    p.add_argument(
        "--dropout",
        type=float,
        default=0.3,
        help="Dropout probability (CNN) / classifier dropout (transformers)",
    )
    p.add_argument(
        "--label-smoothing",
        type=float,
        default=0.0,
        help="Label smoothing for cross-entropy loss",
    )
    p.add_argument(
        "--class-weighted",
        action=argparse.BooleanOptionalAction,
        default=True,
        help=(
            "Use inverse-frequency class weights (enabled by default). "
            "Balances the gradient contribution of the minority class (no-stress) "
            "which is heavily outnumbered in typical datasets. "
            "Pass --no-class-weighted to disable."
        ),
    )
    p.add_argument(
        "--patience",
        type=int,
        default=3,
        help="Early stopping patience (epochs without F1 improvement)",
    )
    p.add_argument("--val-ratio", type=float, default=0.1, help="Fraction of data used for validation")
    p.add_argument("--seed", type=int, default=42, help="Random seed")
    p.add_argument(
        "--eval-set",
        default=os.path.join(ROOT_DIR, "data", "eval", "happy_neutral_eval.csv"),
        help="Optional happy/neutral eval CSV (text,label)",
    )
    p.add_argument(
        "--max-length",
        type=int,
        default=256,
        help="Max sequence length for transformer models",
    )
    p.add_argument(
        "--device",
        default="cuda",
        choices=["cuda", "cpu"],
        help="Device to train on (falls back to cpu if CUDA is unavailable)",
    )
    p.add_argument(
        "--max-fpr",
        type=float,
        default=0.20,
        help=(
            "Maximum false positive rate allowed during threshold calibration. "
            "Thresholds that produce FPR > this value are rejected, preventing "
            "the model from collapsing to very low decision thresholds."
        ),
    )
    p.add_argument(
        "--min-threshold",
        type=float,
        default=0.50,
        help=(
            "Hard lower bound for the decision threshold. "
            "Ensures the model never classifies more than 50%% of the "
            "probability space as stressed by default."
        ),
    )
    p.add_argument(
        "--fp-penalty",
        type=float,
        default=0.2,
        help=(
            "Weight of the soft false-positive penalty term added to the "
            "training loss each batch. At each step the mean stress "
            "probability assigned to negative (label=0) samples in the "
            "batch is multiplied by this weight and added to the main loss, "
            "directly penalising false positives during gradient descent. "
            "Set to 0.0 to disable. Increase (e.g. 0.4) when the "
            "happy/neutral FP rate remains high after training."
        ),
    )
    return p.parse_args()
if __name__ == "__main__":
    # Script entry point: parse the CLI, verify the dataset exists, train.
    cli = _parse_args()

    dataset_missing = not os.path.isfile(cli.data)
    if dataset_missing:
        print(f"ERROR: Dataset not found: {cli.data}")
        print("Run python data_preprocessing.py first to create the unified CSV.")
        sys.exit(1)

    # Map each CLI option onto the train() parameter it configures.
    train_options = {
        "data_path": cli.data,
        "output_path": cli.output,
        "epochs": cli.epochs,
        "batch_size": cli.batch_size,
        "lr": cli.lr,
        "val_ratio": cli.val_ratio,
        "seed": cli.seed,
        "device_str": cli.device,
        "model_type": cli.model,
        "dropout": cli.dropout,
        "weight_decay": cli.weight_decay,
        "label_smoothing": cli.label_smoothing,
        "class_weighted": cli.class_weighted,
        "patience": cli.patience,
        "eval_set_path": cli.eval_set,
        "max_length": cli.max_length,
        "max_fpr": cli.max_fpr,
        "min_threshold": cli.min_threshold,
        "fp_penalty_weight": cli.fp_penalty,
    }
    train(**train_options)