| from __future__ import annotations
|
|
|
| from types import SimpleNamespace
|
|
|
| import pandas as pd
|
| import torch
|
| import torch.nn as nn
|
|
|
| from constants import NUMERICAL_FEATURES
|
|
|
|
|
| class RedditModel(nn.Module):
|
| """
|
| Supports both architectures used in this project:
|
| - Legacy single-encoder model (embedding/encoder)
|
| - New dual-encoder model (embedding_title+embedding_text, encoder_title+encoder_text)
|
| """
|
|
|
| def __init__(self, params: SimpleNamespace, target_mean: float | None = None):
|
| super().__init__()
|
| self.DEVICE = params.DEVICE
|
| self.is_dual_encoder = hasattr(params, "VOCAB_SIZE_TEXT") and hasattr(params, "VOCAB_SIZE_TITLE")
|
|
|
| if target_mean is not None:
|
| self.register_buffer("target_mean", torch.tensor(float(target_mean), dtype=torch.float32))
|
|
|
| if self.is_dual_encoder:
|
| self.embedding_text = nn.Embedding(
|
| num_embeddings=params.VOCAB_SIZE_TEXT,
|
| embedding_dim=params.D_MODEL,
|
| )
|
| self.embedding_title = nn.Embedding(
|
| num_embeddings=params.VOCAB_SIZE_TITLE,
|
| embedding_dim=params.D_MODEL,
|
| )
|
|
|
| self.encoder_text = nn.TransformerEncoder(
|
| encoder_layer=nn.TransformerEncoderLayer(
|
| d_model=params.D_MODEL,
|
| nhead=params.N_HEAD,
|
| dim_feedforward=params.DIM_FEEDFORWARD,
|
| batch_first=True,
|
| ),
|
| num_layers=params.NB_ENCODER_LAYERS,
|
| )
|
| self.encoder_title = nn.TransformerEncoder(
|
| encoder_layer=nn.TransformerEncoderLayer(
|
| d_model=params.D_MODEL,
|
| nhead=params.N_HEAD,
|
| dim_feedforward=params.DIM_FEEDFORWARD,
|
| batch_first=True,
|
| ),
|
| num_layers=params.NB_ENCODER_LAYERS,
|
| )
|
|
|
| self.attention_pooling_text = nn.Linear(params.D_MODEL, 1)
|
| else:
|
| self.embedding = nn.Embedding(
|
| num_embeddings=params.VOCAB_SIZE,
|
| embedding_dim=params.D_MODEL,
|
| )
|
| self.encoder = nn.TransformerEncoder(
|
| encoder_layer=nn.TransformerEncoderLayer(
|
| d_model=params.D_MODEL,
|
| nhead=params.N_HEAD,
|
| dim_feedforward=params.DIM_FEEDFORWARD,
|
| batch_first=True,
|
| ),
|
| num_layers=params.NB_ENCODER_LAYERS,
|
| )
|
|
|
| layers = []
|
| in_dim = params.D_MODEL * 2 + len(NUMERICAL_FEATURES)
|
| for _ in range(params.NB_HIDDEN_LAYERS):
|
| layers.append(nn.Linear(in_dim, params.HIDDEN_SIZE))
|
| layers.append(nn.ReLU())
|
| layers.append(nn.Dropout(params.DROPOUT_RATE))
|
| in_dim = params.HIDDEN_SIZE
|
|
|
| layers.append(nn.Linear(in_dim, 1))
|
| self.regression_head = nn.Sequential(*layers)
|
|
|
| @staticmethod
|
| def _mean_pool_non_pad(encoded: torch.Tensor, token_mask: torch.Tensor) -> torch.Tensor:
|
|
|
| mask = token_mask.unsqueeze(-1).float()
|
| return (encoded * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
|
|
|
| def forward(self, x: pd.DataFrame, numerical: torch.Tensor) -> torch.Tensor:
|
| title_tokens = torch.stack(x["title_tokens"].tolist()).long().to(self.DEVICE)
|
| title_mask = torch.stack(x["title_mask"].tolist()).bool().to(self.DEVICE)
|
|
|
| text_tokens = torch.stack(x["text_tokens"].tolist()).long().to(self.DEVICE)
|
| text_mask = torch.stack(x["text_mask"].tolist()).bool().to(self.DEVICE)
|
|
|
| if self.is_dual_encoder:
|
|
|
| title_vec = self.encoder_title(
|
| self.embedding_title(title_tokens),
|
| src_key_padding_mask=~title_mask,
|
| ).mean(dim=1)
|
|
|
| text_encoded = self.encoder_text(
|
| self.embedding_text(text_tokens),
|
| src_key_padding_mask=~text_mask,
|
| )
|
| scores = self.attention_pooling_text(text_encoded)
|
| weights = torch.softmax(scores, dim=1)
|
| text_vec = (text_encoded * weights).sum(dim=1)
|
| else:
|
| max_token_id = int(torch.maximum(title_tokens.max(), text_tokens.max()).item())
|
| vocab_size = int(self.embedding.num_embeddings)
|
| if max_token_id >= vocab_size:
|
| raise ValueError(
|
| f"Legacy checkpoint vocabulary mismatch: token id {max_token_id} "
|
| f"is outside embedding range [0, {vocab_size - 1}]."
|
| )
|
|
|
| title_encoded = self.encoder(
|
| self.embedding(title_tokens),
|
| src_key_padding_mask=~title_mask,
|
| )
|
| text_encoded = self.encoder(
|
| self.embedding(text_tokens),
|
| src_key_padding_mask=~text_mask,
|
| )
|
| title_vec = self._mean_pool_non_pad(title_encoded, title_mask)
|
| text_vec = self._mean_pool_non_pad(text_encoded, text_mask)
|
|
|
| combined = torch.cat([title_vec, text_vec, numerical], dim=-1)
|
| return self.regression_head(combined).squeeze(-1)
|
|
|
| def predict(self, x: pd.DataFrame, numerical: torch.Tensor) -> torch.Tensor:
|
| self.eval()
|
| with torch.no_grad():
|
| return self.forward(x, numerical)
|
|
|