from __future__ import annotations

from types import SimpleNamespace

import pandas as pd
import torch
import torch.nn as nn

from constants import NUMERICAL_FEATURES


class RedditModel(nn.Module):
    """Transformer-based regressor over a post's title, body text and numerical features.

    Supports both architectures used in this project:

    - Legacy single-encoder model (``embedding`` / ``encoder``), where the title
      and the text share one embedding table and one encoder.
    - New dual-encoder model (``embedding_title`` + ``embedding_text``,
      ``encoder_title`` + ``encoder_text``), selected when ``params`` carries
      both ``VOCAB_SIZE_TEXT`` and ``VOCAB_SIZE_TITLE``.

    In both cases the pooled title vector, the pooled text vector and the
    numerical features are concatenated and fed to an MLP regression head.

    NOTE(review): the dual-encoder pooling now ignores padded positions (the
    title is mean-pooled over real tokens only and the text attention scores
    are masked before the softmax). Dual-encoder checkpoints trained with the
    previous padding-inclusive pooling may produce slightly different outputs
    and should be re-validated.
    """

    def __init__(self, params: SimpleNamespace, target_mean: float | None = None):
        """Build the encoder(s) and the regression head.

        Args:
            params: Hyper-parameter namespace. Reads ``DEVICE``, ``D_MODEL``,
                ``N_HEAD``, ``DIM_FEEDFORWARD``, ``NB_ENCODER_LAYERS``,
                ``NB_HIDDEN_LAYERS``, ``HIDDEN_SIZE`` and ``DROPOUT_RATE``, plus
                either ``VOCAB_SIZE`` (legacy) or ``VOCAB_SIZE_TEXT`` and
                ``VOCAB_SIZE_TITLE`` (dual encoder).
            target_mean: Optional scalar stored as the ``target_mean`` buffer.
                When ``None`` no buffer is registered. The buffer is not used
                by ``forward``.
        """
        super().__init__()
        self.DEVICE = params.DEVICE
        self.is_dual_encoder = hasattr(params, "VOCAB_SIZE_TEXT") and hasattr(
            params, "VOCAB_SIZE_TITLE"
        )

        if target_mean is not None:
            self.register_buffer(
                "target_mean",
                torch.tensor(float(target_mean), dtype=torch.float32),
            )

        # Submodule names and construction order are kept stable so that
        # state_dict keys and seeded initialisation stay unchanged.
        if self.is_dual_encoder:
            self.embedding_text = nn.Embedding(
                num_embeddings=params.VOCAB_SIZE_TEXT,
                embedding_dim=params.D_MODEL,
            )
            self.embedding_title = nn.Embedding(
                num_embeddings=params.VOCAB_SIZE_TITLE,
                embedding_dim=params.D_MODEL,
            )
            self.encoder_text = self._build_encoder(params)
            self.encoder_title = self._build_encoder(params)
            self.attention_pooling_text = nn.Linear(params.D_MODEL, 1)
        else:
            self.embedding = nn.Embedding(
                num_embeddings=params.VOCAB_SIZE,
                embedding_dim=params.D_MODEL,
            )
            self.encoder = self._build_encoder(params)

        # MLP head over [title_vec ; text_vec ; numerical features].
        layers: list[nn.Module] = []
        in_dim = params.D_MODEL * 2 + len(NUMERICAL_FEATURES)
        for _ in range(params.NB_HIDDEN_LAYERS):
            layers.append(nn.Linear(in_dim, params.HIDDEN_SIZE))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(params.DROPOUT_RATE))
            in_dim = params.HIDDEN_SIZE
        layers.append(nn.Linear(in_dim, 1))
        self.regression_head = nn.Sequential(*layers)

    @staticmethod
    def _build_encoder(params: SimpleNamespace) -> nn.TransformerEncoder:
        """Return a batch-first Transformer encoder configured from ``params``."""
        return nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(
                d_model=params.D_MODEL,
                nhead=params.N_HEAD,
                dim_feedforward=params.DIM_FEEDFORWARD,
                batch_first=True,
            ),
            num_layers=params.NB_ENCODER_LAYERS,
        )

    @staticmethod
    def _mean_pool_non_pad(encoded: torch.Tensor, token_mask: torch.Tensor) -> torch.Tensor:
        """Average ``encoded`` over dim 1, counting only positions where the mask is set.

        Args:
            encoded: Encoder output; dim 1 is the sequence dimension.
            token_mask: Mask with one entry per position, 1 for tokens and 0
                for padding.

        Returns:
            The masked mean over dim 1. Rows with no real token yield zeros,
            because the denominator is clamped instead of dividing by zero.
        """
        # token_mask is 1 for tokens and 0 for padding.
        mask = token_mask.unsqueeze(-1).float()
        return (encoded * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

    def forward(self, x: pd.DataFrame, numerical: torch.Tensor) -> torch.Tensor:
        """Predict one scalar per row of ``x``.

        Args:
            x: Frame with the columns ``title_tokens``, ``title_mask``,
                ``text_tokens`` and ``text_mask``. Every cell must hold a
                tensor, and the tensors of one column must be stackable
                (same shape within the batch).
            numerical: Numerical features, one row per row of ``x``; it is
                concatenated after the title and text vectors on the last dim.

        Returns:
            The regression head output with its trailing dimension squeezed.

        Raises:
            ValueError: In the legacy path, when a token id does not fit in
                the embedding table.
        """
        title_tokens = torch.stack(x["title_tokens"].tolist()).long().to(self.DEVICE)
        title_mask = torch.stack(x["title_mask"].tolist()).bool().to(self.DEVICE)
        text_tokens = torch.stack(x["text_tokens"].tolist()).long().to(self.DEVICE)
        text_mask = torch.stack(x["text_mask"].tolist()).bool().to(self.DEVICE)

        if self.is_dual_encoder:
            # Hugging Face attention masks are 1=token, 0=pad while PyTorch expects True=pad.
            title_encoded = self.encoder_title(
                self.embedding_title(title_tokens),
                src_key_padding_mask=~title_mask,
            )
            # Pool over real tokens only, exactly like the legacy path, so the
            # title vector does not depend on how much padding the batch has.
            title_vec = self._mean_pool_non_pad(title_encoded, title_mask)

            text_encoded = self.encoder_text(
                self.embedding_text(text_tokens),
                src_key_padding_mask=~text_mask,
            )
            scores = self.attention_pooling_text(text_encoded)
            # Push padded positions to the lowest representable score so they
            # get (numerically) zero attention weight. A finite fill value
            # keeps the softmax itself NaN-free even for an all-padding row.
            scores = scores.masked_fill(
                ~text_mask.unsqueeze(-1), torch.finfo(scores.dtype).min
            )
            weights = torch.softmax(scores, dim=1)
            text_vec = (text_encoded * weights).sum(dim=1)
        else:
            max_token_id = int(torch.maximum(title_tokens.max(), text_tokens.max()).item())
            vocab_size = int(self.embedding.num_embeddings)
            if max_token_id >= vocab_size:
                raise ValueError(
                    f"Legacy checkpoint vocabulary mismatch: token id {max_token_id} "
                    f"is outside embedding range [0, {vocab_size - 1}]."
                )
            title_encoded = self.encoder(
                self.embedding(title_tokens),
                src_key_padding_mask=~title_mask,
            )
            text_encoded = self.encoder(
                self.embedding(text_tokens),
                src_key_padding_mask=~text_mask,
            )
            title_vec = self._mean_pool_non_pad(title_encoded, title_mask)
            text_vec = self._mean_pool_non_pad(text_encoded, text_mask)

        # Align the numerical features with the text vectors; this is a no-op
        # when the caller already supplies a matching device and dtype.
        numerical = numerical.to(device=title_vec.device, dtype=title_vec.dtype)
        combined = torch.cat([title_vec, text_vec, numerical], dim=-1)
        return self.regression_head(combined).squeeze(-1)

    def predict(self, x: pd.DataFrame, numerical: torch.Tensor) -> torch.Tensor:
        """Run ``forward`` without gradient tracking.

        Switches the module to evaluation mode first; the module is left in
        evaluation mode afterwards, so callers that resume training must call
        ``train()`` again.
        """
        self.eval()
        with torch.no_grad():
            return self.forward(x, numerical)