# Reddit / models / reddit_model.py
# Author: cyrilfrl
# Commit 44748ce (verified): "hope it works this time"
from __future__ import annotations
from types import SimpleNamespace
import pandas as pd
import torch
import torch.nn as nn
from constants import NUMERICAL_FEATURES
class RedditModel(nn.Module):
    """Regression model scoring a post from its title, body text and numeric features.

    Supports both architectures used in this project:

    - Legacy single-encoder model (``embedding`` / ``encoder``), shared by the
      title and the text.
    - Dual-encoder model (``embedding_title`` + ``embedding_text``,
      ``encoder_title`` + ``encoder_text``, plus ``attention_pooling_text``).

    The dual-encoder variant is built when ``params`` defines both
    ``VOCAB_SIZE_TEXT`` and ``VOCAB_SIZE_TITLE``; otherwise the legacy variant
    (which reads ``VOCAB_SIZE``) is built. Submodule attribute names determine
    the state-dict keys, so they must stay as they are for existing checkpoints.
    """

    def __init__(self, params: SimpleNamespace, target_mean: float | None = None):
        """Create the embeddings, Transformer encoders and regression head.

        Args:
            params: Hyper-parameter namespace. Reads ``DEVICE``, ``D_MODEL``,
                ``N_HEAD``, ``DIM_FEEDFORWARD``, ``NB_ENCODER_LAYERS``,
                ``NB_HIDDEN_LAYERS``, ``HIDDEN_SIZE``, ``DROPOUT_RATE`` and either
                ``VOCAB_SIZE_TEXT``/``VOCAB_SIZE_TITLE`` or ``VOCAB_SIZE``.
            target_mean: Optional scalar registered as a float32 ``target_mean``
                buffer (so it is part of the state dict). It is not read by
                :meth:`forward`.
        """
        super().__init__()
        self.DEVICE = params.DEVICE
        self.is_dual_encoder = hasattr(params, "VOCAB_SIZE_TEXT") and hasattr(params, "VOCAB_SIZE_TITLE")

        if target_mean is not None:
            self.register_buffer("target_mean", torch.tensor(float(target_mean), dtype=torch.float32))

        # Submodules are created in the original order so that weight
        # initialisation consumes the RNG identically and state-dict keys match.
        if self.is_dual_encoder:
            self.embedding_text = nn.Embedding(
                num_embeddings=params.VOCAB_SIZE_TEXT,
                embedding_dim=params.D_MODEL,
            )
            self.embedding_title = nn.Embedding(
                num_embeddings=params.VOCAB_SIZE_TITLE,
                embedding_dim=params.D_MODEL,
            )
            self.encoder_text = self._build_encoder(params)
            self.encoder_title = self._build_encoder(params)
            # Produces one attention score per text position (see forward).
            self.attention_pooling_text = nn.Linear(params.D_MODEL, 1)
        else:
            self.embedding = nn.Embedding(
                num_embeddings=params.VOCAB_SIZE,
                embedding_dim=params.D_MODEL,
            )
            self.encoder = self._build_encoder(params)

        # Head input = pooled title vector + pooled text vector + numeric features.
        layers: list[nn.Module] = []
        in_dim = params.D_MODEL * 2 + len(NUMERICAL_FEATURES)
        for _ in range(params.NB_HIDDEN_LAYERS):
            layers.append(nn.Linear(in_dim, params.HIDDEN_SIZE))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(params.DROPOUT_RATE))
            in_dim = params.HIDDEN_SIZE
        layers.append(nn.Linear(in_dim, 1))
        self.regression_head = nn.Sequential(*layers)

    @staticmethod
    def _build_encoder(params: SimpleNamespace) -> nn.TransformerEncoder:
        """Return a batch-first Transformer encoder stack configured from ``params``."""
        return nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(
                d_model=params.D_MODEL,
                nhead=params.N_HEAD,
                dim_feedforward=params.DIM_FEEDFORWARD,
                batch_first=True,
            ),
            num_layers=params.NB_ENCODER_LAYERS,
        )

    @staticmethod
    def _mean_pool_non_pad(encoded: torch.Tensor, token_mask: torch.Tensor) -> torch.Tensor:
        """Average ``encoded`` over dim 1, counting only positions where the mask is set.

        Args:
            encoded: Encoder output; presumably (batch, seq, dim) — TODO confirm.
            token_mask: 1/True for real tokens and 0/False for padding, with one
                entry per position of ``encoded`` along dims 0-1.

        Returns:
            The masked mean over dim 1. The denominator is clamped to 1e-9, so a
            row with no tokens yields zeros rather than NaN.
        """
        mask = token_mask.unsqueeze(-1).float()
        return (encoded * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

    def forward(self, x: pd.DataFrame, numerical: torch.Tensor) -> torch.Tensor:
        """Predict one regression value per row of ``x``.

        Args:
            x: DataFrame with columns ``title_tokens``, ``title_mask``,
                ``text_tokens`` and ``text_mask``. Each column's cells are
                combined with ``torch.stack``, so presumably every cell is a 1-D
                tensor of a common length per column (TODO confirm with caller).
                Masks follow the Hugging Face convention: 1 = token, 0 = padding.
            numerical: Numeric features concatenated after the two pooled
                vectors along the last dim, so that dim must have size
                ``len(NUMERICAL_FEATURES)``. The tensor is moved to the model's
                device and dtype before use.

        Returns:
            Predictions with the head's trailing singleton dimension squeezed.

        Raises:
            ValueError: legacy model only, when a token id is not smaller than
                the shared embedding's vocabulary size.
        """
        title_tokens = torch.stack(x["title_tokens"].tolist()).long().to(self.DEVICE)
        title_mask = torch.stack(x["title_mask"].tolist()).bool().to(self.DEVICE)
        text_tokens = torch.stack(x["text_tokens"].tolist()).long().to(self.DEVICE)
        text_mask = torch.stack(x["text_mask"].tolist()).bool().to(self.DEVICE)

        # Hugging Face attention masks are 1=token, 0=pad while PyTorch's
        # src_key_padding_mask expects True=pad, hence the ``~`` below.
        if self.is_dual_encoder:
            # NOTE(review): neither pooling below excludes padding: the title
            # mean averages over every position and the text softmax also gives
            # weight to padded positions (the legacy branch uses
            # _mean_pool_non_pad instead). Left unchanged because existing
            # dual-encoder checkpoints were presumably trained this way; moving
            # to masked pooling would require retraining.
            title_vec = self.encoder_title(
                self.embedding_title(title_tokens),
                src_key_padding_mask=~title_mask,
            ).mean(dim=1)
            text_encoded = self.encoder_text(
                self.embedding_text(text_tokens),
                src_key_padding_mask=~text_mask,
            )
            # One learned score per position, softmax along dim 1, then a
            # weighted sum of the encoder outputs along the same dim.
            scores = self.attention_pooling_text(text_encoded)
            weights = torch.softmax(scores, dim=1)
            text_vec = (text_encoded * weights).sum(dim=1)
        else:
            # Title and text share one embedding table; report an out-of-range
            # id explicitly instead of failing inside nn.Embedding.
            max_token_id = int(torch.maximum(title_tokens.max(), text_tokens.max()).item())
            vocab_size = int(self.embedding.num_embeddings)
            if max_token_id >= vocab_size:
                raise ValueError(
                    f"Legacy checkpoint vocabulary mismatch: token id {max_token_id} "
                    f"is outside embedding range [0, {vocab_size - 1}]."
                )
            title_encoded = self.encoder(
                self.embedding(title_tokens),
                src_key_padding_mask=~title_mask,
            )
            text_encoded = self.encoder(
                self.embedding(text_tokens),
                src_key_padding_mask=~text_mask,
            )
            title_vec = self._mean_pool_non_pad(title_encoded, title_mask)
            text_vec = self._mean_pool_non_pad(text_encoded, text_mask)

        # The token tensors were moved to self.DEVICE above; the numeric
        # features must follow, and must match the encoder dtype — otherwise a
        # float64 tensor (typical when built from pandas) is promoted by
        # torch.cat and then rejected by the float32 Linear layers.
        numerical = numerical.to(device=title_vec.device, dtype=title_vec.dtype)
        combined = torch.cat([title_vec, text_vec, numerical], dim=-1)
        return self.regression_head(combined).squeeze(-1)

    def predict(self, x: pd.DataFrame, numerical: torch.Tensor) -> torch.Tensor:
        """Run :meth:`forward` in eval mode without gradient tracking.

        The module is left in eval mode afterwards (``self.eval()`` is not undone).
        """
        self.eval()
        with torch.no_grad():
            return self.forward(x, numerical)