""" models/architecture.py ====================== Phase 2 → Phase 5: Model Architecture Tier 1 — **OptimizedMultichannelCNN** (PyTorch from scratch) Parallel 1D convolution channels (kernel sizes 2, 3, 5) with **min_len trimming** to prevent shape mismatch on concatenation, followed by **multi-head self-attention** (default 4 heads) and a classification head. Tier 2 — Transformer wrappers Lightweight wrappers around HuggingFace models: - ``DeBERTaStressClassifier`` (DeBERTa-v3-Small) - ``MiniLMStressClassifier`` (MiniLM-L6-v2) Calibration ----------- - ``TemperatureScaling`` — post-hoc probability calibration (Guo et al. 2017). Wraps any classifier and divides logits by a learned scalar ``T`` before softmax, reducing overconfidence. Attention --------- - ``MultiHeadSelfAttention`` (new default) — scaled dot-product attention split across ``num_heads`` independent subspaces then projected back. Produces richer features and more interpretable per-token importance weights compared to single-head dot-product attention. - ``DotProductSelfAttention`` kept for backward compatibility (single head). Design Guardrails ----------------- - Conv1D outputs are trimmed to ``min_len`` before concatenation — this is the CRITICAL guard against tensor shape mismatch. - Self-attention returns attention weights alongside the pooled vector for downstream explainability / heatmap rendering. """ from __future__ import annotations import hashlib from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F # --------------------------------------------------------------------------- # Stop-word dampening # --------------------------------------------------------------------------- # Common English stop words that carry little semantic signal for stress # detection but tend to dominate attention weights (e.g. 'I', 'the', 'a'). # Reducing their embedding magnitude before the conv layers prevents the # attention mechanism from over-emphasising them, which was observed in # heatmap analysis. # --------------------------------------------------------------------------- _STOP_WORDS: frozenset[str] = frozenset({ "i", "me", "my", "myself", "we", "our", "ours", "ourselves", "you", "your", "yours", "yourself", "yourselves", "he", "him", "his", "himself", "she", "her", "hers", "herself", "it", "its", "itself", "they", "them", "their", "theirs", "themselves", "what", "which", "who", "whom", "this", "that", "these", "those", "am", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had", "having", "do", "does", "did", "doing", "a", "an", "the", "and", "but", "if", "or", "because", "as", "until", "while", "of", "at", "by", "for", "with", "about", "against", "between", "through", "during", "before", "after", "above", "below", "to", "from", "up", "down", "in", "out", "on", "off", "over", "under", "again", "further", "then", "once", "here", "there", "when", "where", "why", "how", "all", "both", "each", "few", "more", "most", "other", "some", "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "s", "t", "can", "will", "just", "don", "should", "now", "d", "ll", "m", "o", "re", "ve", "y", }) def _compute_stop_word_ids(vocab_size: int) -> set[int]: """Return hash-based token IDs for :data:`_STOP_WORDS`. Uses the same ``md5`` hashing scheme as the project tokenizers in ``api/main.py`` and ``training/train.py`` so that the IDs match at both training and inference time. """ ids: set[int] = set() for word in _STOP_WORDS: token_id = ( int( hashlib.md5( word.encode("utf-8"), usedforsecurity=False ).hexdigest(), 16, ) % (vocab_size - 1) + 1 ) ids.add(token_id) return ids # --------------------------------------------------------------------------- # Tier 1: OptimizedMultichannelCNN # --------------------------------------------------------------------------- class DotProductSelfAttention(nn.Module): """Simple scaled dot-product self-attention over a sequence. Input shape : ``(batch, seq_len, hidden)`` Output shape: ``(batch, hidden)`` (attended pool) + ``(batch, seq_len)`` Kept for backward compatibility. New code should prefer :class:`MultiHeadSelfAttention`. """ def __init__(self, hidden_dim: int) -> None: super().__init__() self.query = nn.Linear(hidden_dim, hidden_dim) self.key = nn.Linear(hidden_dim, hidden_dim) self.value = nn.Linear(hidden_dim, hidden_dim) self.scale = hidden_dim ** 0.5 def forward( self, x: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: """ Parameters ---------- x : Tensor, shape ``(B, L, H)`` Returns ------- pooled : Tensor, shape ``(B, H)`` weights : Tensor, shape ``(B, L)`` — attention weights (for heatmaps) """ q = self.query(x) # (B, L, H) k = self.key(x) # (B, L, H) v = self.value(x) # (B, L, H) scores = torch.bmm(q, k.transpose(1, 2)) / self.scale # (B, L, L) attn = F.softmax(scores, dim=-1) # (B, L, L) context = torch.bmm(attn, v) # (B, L, H) # Pool: mean of attention-weighted values pooled = context.mean(dim=1) # (B, H) # Per-token importance: mean attention received from all queries weights = attn.mean(dim=1) # (B, L) return pooled, weights class MultiHeadSelfAttention(nn.Module): """Multi-head scaled dot-product self-attention (Vaswani et al. 2017). Splits the hidden dimension into ``num_heads`` independent subspaces, computes scaled dot-product attention within each head, then concatenates and projects the results. This produces richer feature representations than single-head attention and yields more interpretable per-token importance weights for heatmap rendering. Input shape : ``(batch, seq_len, hidden)`` Output shape: ``(batch, hidden)`` (attended pool) + ``(batch, seq_len)`` Parameters ---------- hidden_dim : int Total hidden dimension. Must be divisible by ``num_heads``. num_heads : int Number of parallel attention heads. Default: 4. dropout : float Dropout applied to attention weights during training. """ def __init__( self, hidden_dim: int, num_heads: int = 4, dropout: float = 0.1 ) -> None: super().__init__() if hidden_dim % num_heads != 0: raise ValueError( f"hidden_dim ({hidden_dim}) must be divisible by " f"num_heads ({num_heads})." ) self.num_heads = num_heads self.d_k = hidden_dim // num_heads self.scale = self.d_k ** 0.5 self.query = nn.Linear(hidden_dim, hidden_dim) self.key = nn.Linear(hidden_dim, hidden_dim) self.value = nn.Linear(hidden_dim, hidden_dim) self.out_proj = nn.Linear(hidden_dim, hidden_dim) self.attn_dropout = nn.Dropout(dropout) def forward( self, x: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: """ Parameters ---------- x : Tensor, shape ``(B, L, H)`` Returns ------- pooled : Tensor, shape ``(B, H)`` weights : Tensor, shape ``(B, L)`` — per-token importance (for heatmaps) """ B, L, H = x.shape # Project and reshape to (B, num_heads, L, d_k) q = self.query(x).view(B, L, self.num_heads, self.d_k).transpose(1, 2) k = self.key(x).view(B, L, self.num_heads, self.d_k).transpose(1, 2) v = self.value(x).view(B, L, self.num_heads, self.d_k).transpose(1, 2) # Scaled dot-product attention: (B, num_heads, L, L) scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale attn = F.softmax(scores, dim=-1) # (B, num_heads, L, L) attn = self.attn_dropout(attn) # Context: (B, num_heads, L, d_k) context = torch.matmul(attn, v) # Merge heads: (B, L, H) context = context.transpose(1, 2).contiguous().view(B, L, H) # Output projection out = self.out_proj(context) # (B, L, H) # Mean-pool across the sequence pooled = out.mean(dim=1) # (B, H) # Per-token importance: average attention weight received across # all heads and all query positions weights = attn.mean(dim=1).mean(dim=1) # (B, L) return pooled, weights # --------------------------------------------------------------------------- # Temperature Scaling — post-hoc probability calibration # --------------------------------------------------------------------------- class TemperatureScaling(nn.Module): """Post-hoc calibration via temperature scaling (Guo et al. 2017). Divides logits by a single learnable scalar temperature ``T > 0`` before softmax. * ``T > 1`` → probabilities are smoothed toward 0.5 (reduces overconfidence). * ``T < 1`` → probabilities become more extreme. * ``T = 1`` → no effect (identity). The temperature is calibrated on a held-out validation set by minimising NLL loss. During inference with an uncalibrated model, keep ``T = 1.0``. Parameters ---------- temperature : float Initial temperature. Defaults to 1.0 (no calibration). Example ------- >>> ts = TemperatureScaling(temperature=1.5) >>> scaled_logits = ts(logits) # use before softmax >>> ts.calibrate(val_logits, val_labels) # fit T on a held-out set """ def __init__(self, temperature: float = 1.0) -> None: super().__init__() self.temperature = nn.Parameter( torch.ones(1) * max(temperature, 1e-6) ) def forward(self, logits: torch.Tensor) -> torch.Tensor: """Return temperature-scaled logits.""" return logits / self.temperature.clamp(min=1e-6) def calibrate( self, logits: torch.Tensor, labels: torch.Tensor ) -> None: """Fit the temperature on a held-out (logits, labels) set. Uses L-BFGS to minimise NLL. ``logits`` and ``labels`` should be collected on the validation set *before* calling this method. Parameters ---------- logits : Tensor, shape ``(N, C)`` Raw (uncalibrated) model logits. labels : Tensor, shape ``(N,)`` Ground-truth class indices. """ from torch.optim import LBFGS nll = nn.CrossEntropyLoss() optimizer = LBFGS([self.temperature], lr=0.01, max_iter=50) def _eval() -> torch.Tensor: optimizer.zero_grad() loss = nll(self.forward(logits), labels) loss.backward() return loss optimizer.step(_eval) class OptimizedMultichannelCNN(nn.Module): """Multi-channel 1D CNN with multi-head self-attention for stress detection. Architecture ------------ 1. Embedding layer (with stop-word dampening) 2. Three parallel Conv1D branches (kernel sizes 2, 3, 5) 3. **min_len trimming** — outputs are trimmed to the shortest length before concatenation to prevent tensor shape mismatches. 4. Multi-head self-attention (default 4 heads) 5. Classification head (FC → Dropout → FC) Parameters ---------- vocab_size : int Size of the token vocabulary. embed_dim : int Embedding dimension. num_filters : int Number of filters per Conv1D branch. kernel_sizes : tuple[int, ...] Kernel sizes for the parallel Conv1D branches. num_classes : int Number of output classes (default 2: stress / no-stress). dropout : float Dropout probability. aux_dim : int Optional numeric feature dimension appended to pooled CNN features. num_attention_heads : int Number of attention heads. Must divide ``num_filters * len(kernel_sizes)`` evenly. Set to 1 to use single-head dot-product attention (legacy). """ def __init__( self, vocab_size: int, embed_dim: int = 128, num_filters: int = 64, kernel_sizes: tuple[int, ...] = (2, 3, 5), num_classes: int = 2, dropout: float = 0.3, aux_dim: int = 0, stop_word_dampening: float = 0.3, num_attention_heads: int = 4, ) -> None: super().__init__() self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0) # ── Stop-word dampening ── # Build a per-token-ID lookup: 1.0 for stop words, 0.0 otherwise. # During forward() each embedding is scaled by # 1.0 - is_stop * (1.0 - stop_word_dampening) # so content words keep their full magnitude while stop-word # embeddings are reduced to ``stop_word_dampening`` of their # original magnitude. self.stop_word_dampening = stop_word_dampening stop_ids = _compute_stop_word_ids(vocab_size) stop_mask = torch.zeros(vocab_size, dtype=torch.float) for sid in stop_ids: stop_mask[sid] = 1.0 # persistent=False → not part of state_dict, avoids checkpoint compat issues self.register_buffer("_stop_word_lookup", stop_mask, persistent=False) # Parallel Conv1D branches self.convs = nn.ModuleList( [ nn.Conv1d(embed_dim, num_filters, kernel_size=ks, padding=0) for ks in kernel_sizes ] ) total_filters = num_filters * len(kernel_sizes) # ── Attention ── # Use multi-head attention when possible; fall back to single-head # dot-product attention if total_filters is not divisible by num_heads. if num_attention_heads > 1 and total_filters % num_attention_heads == 0: self.attention: nn.Module = MultiHeadSelfAttention( total_filters, num_heads=num_attention_heads, dropout=dropout ) else: self.attention = DotProductSelfAttention(total_filters) self.aux_dim = aux_dim aux_hidden = min(aux_dim, total_filters // 2) if aux_dim > 0 else 0 self.aux_projection = ( nn.Sequential( nn.Linear(aux_dim, aux_hidden), nn.ReLU(), nn.Dropout(dropout), ) if aux_dim > 0 else None ) combined_dim = total_filters + aux_hidden # Classification head self.classifier = nn.Sequential( nn.Linear(combined_dim, combined_dim // 2), nn.ReLU(), nn.Dropout(dropout), nn.Linear(combined_dim // 2, num_classes), ) self.dropout = nn.Dropout(dropout) def forward( self, input_ids: torch.Tensor, aux_features: Optional[torch.Tensor] = None, ) -> dict[str, torch.Tensor]: """ Parameters ---------- input_ids : Tensor, shape ``(B, L)`` Token indices. Returns ------- dict with keys: ``logits`` : Tensor, shape ``(B, num_classes)`` ``attention_weights`` : Tensor, shape ``(B, seq_len')`` """ # Embedding: (B, L) → (B, L, E) x = self.embedding(input_ids) x = self.dropout(x) # ── Stop-word dampening ── # Reduce embedding magnitudes for stop-word tokens so that the # subsequent conv + attention layers do not over-emphasise them. is_stop = self._stop_word_lookup[input_ids] # (B, L), 0 or 1 dampening = 1.0 - is_stop * (1.0 - self.stop_word_dampening) # (B, L) x = x * dampening.unsqueeze(-1) # (B, L, E) # Conv1D expects (B, C, L) — transpose x_t = x.transpose(1, 2) # (B, E, L) # Apply parallel convolutions + ReLU conv_outputs = [] for conv in self.convs: c = F.relu(conv(x_t)) # (B, F, L') conv_outputs.append(c) # ─── CRITICAL: Trim to min_len to prevent shape mismatch ─── min_len = min(c.size(2) for c in conv_outputs) conv_outputs = [c[:, :, :min_len] for c in conv_outputs] # Concatenate along the filter dimension: (B, F*3, min_len) merged = torch.cat(conv_outputs, dim=1) # Transpose back for attention: (B, min_len, F*3) merged = merged.transpose(1, 2) # Self-attention (multi-head or single-head) pooled, attn_weights = self.attention(merged) # (B, F*3), (B, min_len) if self.aux_projection is not None: if aux_features is None: aux_features = torch.zeros( pooled.size(0), self.aux_dim, device=pooled.device, ) aux_emb = self.aux_projection(aux_features) pooled = torch.cat([pooled, aux_emb], dim=1) # Classification logits = self.classifier(pooled) # (B, num_classes) return {"logits": logits, "attention_weights": attn_weights} # --------------------------------------------------------------------------- # Tier 2: Transformer wrappers # --------------------------------------------------------------------------- class DeBERTaStressClassifier(nn.Module): """Stress classifier wrapping ``microsoft/deberta-v3-small``. Uses the HuggingFace ``transformers`` library for the backbone and adds a simple classification head. """ MODEL_NAME = "microsoft/deberta-v3-small" def __init__(self, num_classes: int = 2, dropout: float = 0.1) -> None: super().__init__() from transformers import AutoModel self.backbone = AutoModel.from_pretrained(self.MODEL_NAME) hidden = self.backbone.config.hidden_size # +1 for optional sentiment feature self.classifier = nn.Sequential( nn.Dropout(dropout), nn.Linear(hidden + 1, num_classes), ) def forward( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, sentiment: Optional[torch.Tensor] = None, ) -> dict[str, torch.Tensor]: outputs = self.backbone( input_ids=input_ids, attention_mask=attention_mask ) # CLS token pooling pooled = outputs.last_hidden_state[:, 0, :] if sentiment is not None: sentiment = sentiment.unsqueeze(1) if sentiment.dim() == 1 else sentiment pooled = torch.cat([pooled, sentiment], dim=1) else: # Append neutral sentiment (0.5) when not provided neutral = torch.full( (pooled.size(0), 1), 0.5, device=pooled.device, dtype=pooled.dtype, ) pooled = torch.cat([pooled, neutral], dim=1) logits = self.classifier(pooled) return {"logits": logits} class MiniLMStressClassifier(nn.Module): """Stress classifier wrapping ``sentence-transformers/all-MiniLM-L6-v2``. Uses mean pooling over the last hidden state as the sentence representation. """ MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2" def __init__(self, num_classes: int = 2, dropout: float = 0.1) -> None: super().__init__() from transformers import AutoModel self.backbone = AutoModel.from_pretrained(self.MODEL_NAME) hidden = self.backbone.config.hidden_size # +1 for optional sentiment feature self.classifier = nn.Sequential( nn.Dropout(dropout), nn.Linear(hidden + 1, num_classes), ) def forward( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, sentiment: Optional[torch.Tensor] = None, ) -> dict[str, torch.Tensor]: outputs = self.backbone( input_ids=input_ids, attention_mask=attention_mask ) # Mean pooling hidden_states = outputs.last_hidden_state if attention_mask is not None: mask = attention_mask.unsqueeze(-1).float() pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp( min=1e-9 ) else: pooled = hidden_states.mean(dim=1) if sentiment is not None: sentiment = sentiment.unsqueeze(1) if sentiment.dim() == 1 else sentiment pooled = torch.cat([pooled, sentiment], dim=1) else: neutral = torch.full( (pooled.size(0), 1), 0.5, device=pooled.device, dtype=pooled.dtype, ) pooled = torch.cat([pooled, neutral], dim=1) logits = self.classifier(pooled) return {"logits": logits}