"""Two-Tower neural recommender with side features. Each tower takes (id_embedding, side_features) -> MLP -> output_dim vector. Final score is the dot product of the user and item vectors, which preserves the retrieval structure — `score_all_items` becomes a single matmul. Design notes: - Feature tables (`user_features`, `item_features`) live on the module as non-trainable `register_buffer`s, so `.to(device)` moves them and `state_dict()` persists them alongside learnable params. - No BatchNorm — during eval we call the item tower on ALL items at once, where batch statistics would be meaningless. Dropout is enough. - Dot-product (not concat + MLP) at the top: preserves fast retrieval. """ from __future__ import annotations import torch from torch import Tensor, nn from ..config import TwoTowerModelConfig from .base import BaseRecommender _ACTIVATIONS: dict[str, type[nn.Module]] = { "relu": nn.ReLU, "gelu": nn.GELU, "tanh": nn.Tanh, } class _Tower(nn.Module): """Id-embedding + side-feature concatenation -> MLP -> output_dim vector.""" def __init__( self, *, vocab_size: int, id_embedding_dim: int, side_feat_dim: int, hidden_dims: tuple[int, ...], output_dim: int, dropout: float, activation: str, init_std: float, ) -> None: super().__init__() self.id_emb = nn.Embedding(vocab_size, id_embedding_dim) nn.init.normal_(self.id_emb.weight, mean=0.0, std=init_std) act_cls = _ACTIVATIONS[activation] layers: list[nn.Module] = [] in_dim = id_embedding_dim + side_feat_dim for h in hidden_dims: layers.append(nn.Linear(in_dim, h)) layers.append(act_cls()) if dropout > 0: layers.append(nn.Dropout(dropout)) in_dim = h layers.append(nn.Linear(in_dim, output_dim)) self.mlp = nn.Sequential(*layers) def forward(self, ids: Tensor, side_features: Tensor) -> Tensor: return self.mlp(torch.cat([self.id_emb(ids), side_features], dim=-1)) class TwoTower(BaseRecommender): def __init__( self, *, num_users: int, num_items: int, user_feat_dim: int, item_feat_dim: int, cfg: TwoTowerModelConfig, user_features: torch.Tensor, item_features: torch.Tensor, ) -> None: super().__init__(num_users=num_users, num_items=num_items) if user_features.shape != (num_users, user_feat_dim): raise ValueError( f"user_features shape {tuple(user_features.shape)} != " f"({num_users}, {user_feat_dim})" ) if item_features.shape != (num_items, item_feat_dim): raise ValueError( f"item_features shape {tuple(item_features.shape)} != " f"({num_items}, {item_feat_dim})" ) self.user_tower = _Tower( vocab_size=num_users, id_embedding_dim=cfg.user_id_embedding_dim, side_feat_dim=user_feat_dim, hidden_dims=cfg.mlp_hidden_dims, output_dim=cfg.output_dim, dropout=cfg.dropout, activation=cfg.activation, init_std=cfg.init_std, ) self.item_tower = _Tower( vocab_size=num_items, id_embedding_dim=cfg.item_id_embedding_dim, side_feat_dim=item_feat_dim, hidden_dims=cfg.mlp_hidden_dims, output_dim=cfg.output_dim, dropout=cfg.dropout, activation=cfg.activation, init_std=cfg.init_std, ) # Non-trainable feature lookup tables. Registered as buffers so # `.to(device)` moves them and `state_dict()` saves them. self.register_buffer("user_features", user_features.float().contiguous()) self.register_buffer("item_features", item_features.float().contiguous()) # ---------- tower reprs ---------- def _user_repr(self, user_ids: Tensor) -> Tensor: return self.user_tower(user_ids, self.user_features[user_ids]) def _item_repr(self, item_ids: Tensor) -> Tensor: return self.item_tower(item_ids, self.item_features[item_ids]) def _all_item_reprs(self) -> Tensor: all_ids = torch.arange(self.num_items, device=self.item_features.device) return self.item_tower(all_ids, self.item_features) # ---------- BaseRecommender interface ---------- def score(self, users: Tensor, items: Tensor) -> Tensor: u = self._user_repr(users) i = self._item_repr(items) return (u * i).sum(dim=-1) def score_all_items(self, users: Tensor) -> Tensor: u = self._user_repr(users) # [B, D] i = self._all_item_reprs() # [N, D] return u @ i.T # [B, N]