| """Two-Tower neural recommender with side features. |
| |
| Each tower takes (id_embedding, side_features) -> MLP -> output_dim vector. |
| Final score is the dot product of the user and item vectors, which preserves |
| the retrieval structure — `score_all_items` becomes a single matmul. |
| |
| Design notes: |
| - Feature tables (`user_features`, `item_features`) live on the module |
| as non-trainable `register_buffer`s, so `.to(device)` moves them and |
| `state_dict()` persists them alongside learnable params. |
| - No BatchNorm — during eval we call the item tower on ALL items at once, |
| where batch statistics would be meaningless. Dropout is enough. |
| - Dot-product (not concat + MLP) at the top: preserves fast retrieval. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import torch |
| from torch import Tensor, nn |
|
|
| from ..config import TwoTowerModelConfig |
| from .base import BaseRecommender |
|
|
# Supported MLP nonlinearities, keyed by the name passed as `activation`.
# Values are module *classes* (not instances): a fresh activation module is
# instantiated for every hidden layer of a tower.
_ACTIVATIONS: dict[str, type[nn.Module]] = {
    "relu": nn.ReLU,
    "gelu": nn.GELU,
    "tanh": nn.Tanh,
}
|
|
|
|
class _Tower(nn.Module):
    """Encode ids plus side features into a fixed-size vector.

    Each id is looked up in a learned embedding table, concatenated with the
    caller-supplied side features along the last dimension, and passed
    through an MLP. Every hidden layer is ``Linear -> activation`` followed by
    ``Dropout`` when ``dropout > 0``; the final layer is a plain ``Linear`` of
    width ``output_dim`` with no activation or dropout.
    """

    def __init__(
        self,
        *,
        vocab_size: int,
        id_embedding_dim: int,
        side_feat_dim: int,
        hidden_dims: tuple[int, ...],
        output_dim: int,
        dropout: float,
        activation: str,
        init_std: float,
    ) -> None:
        """Build the id embedding table and the MLP.

        Args:
            vocab_size: Number of rows in the id embedding table.
            id_embedding_dim: Width of each id embedding.
            side_feat_dim: Width of the side-feature vector that is
                concatenated to the id embedding.
            hidden_dims: Width of each hidden layer, in order. May be empty,
                in which case the MLP is a single linear projection.
            output_dim: Width of the vector returned by ``forward``.
            dropout: Dropout probability applied after each hidden
                activation; values ``<= 0`` add no dropout layers.
            activation: Name of the nonlinearity; a key of ``_ACTIVATIONS``.
            init_std: Standard deviation of the zero-mean normal
                initialisation of the id embedding weights.

        Raises:
            KeyError: If ``activation`` is not a key of ``_ACTIVATIONS``.
        """
        super().__init__()

        # Resolve the activation first so a bad config fails before the
        # (potentially large) embedding table is allocated. The exception
        # type stays KeyError, but the message now lists the valid names.
        try:
            act_cls = _ACTIVATIONS[activation]
        except KeyError:
            raise KeyError(
                f"unknown activation {activation!r}; "
                f"expected one of {sorted(_ACTIVATIONS)}"
            ) from None

        self.id_emb = nn.Embedding(vocab_size, id_embedding_dim)
        nn.init.normal_(self.id_emb.weight, mean=0.0, std=init_std)

        layers: list[nn.Module] = []
        in_dim = id_embedding_dim + side_feat_dim
        for width in hidden_dims:
            layers.append(nn.Linear(in_dim, width))
            layers.append(act_cls())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            in_dim = width
        # Output projection: no activation, so the output is unconstrained.
        layers.append(nn.Linear(in_dim, output_dim))
        self.mlp = nn.Sequential(*layers)

    def forward(self, ids: Tensor, side_features: Tensor) -> Tensor:
        """Return the tower output for ``ids`` and their side features.

        ``side_features`` must match the embedded ids in every dimension
        except the last, because the two are concatenated along ``dim=-1``.
        """
        id_vectors = self.id_emb(ids)
        tower_input = torch.cat([id_vectors, side_features], dim=-1)
        return self.mlp(tower_input)
|
|
|
|
class TwoTower(BaseRecommender):
    """Two-tower recommender that scores by user/item dot product.

    A user tower and an item tower each map ``(id, side features)`` to a
    vector of width ``cfg.output_dim``; the score of a pair is the dot
    product of the two vectors. The side-feature tables are stored on the
    module as buffers and are looked up by id inside the model, so callers
    pass only id tensors.
    """

    def __init__(
        self,
        *,
        num_users: int,
        num_items: int,
        user_feat_dim: int,
        item_feat_dim: int,
        cfg: TwoTowerModelConfig,
        user_features: torch.Tensor,
        item_features: torch.Tensor,
    ) -> None:
        """Build both towers and register the side-feature tables.

        Args:
            num_users: Number of users; also the user embedding vocab size.
            num_items: Number of items; also the item embedding vocab size.
            user_feat_dim: Width of each row of ``user_features``.
            item_feat_dim: Width of each row of ``item_features``.
            cfg: Model hyper-parameters. The fields read here are
                ``user_id_embedding_dim``, ``item_id_embedding_dim``,
                ``mlp_hidden_dims``, ``output_dim``, ``dropout``,
                ``activation`` and ``init_std``.
            user_features: Table of shape ``(num_users, user_feat_dim)``.
            item_features: Table of shape ``(num_items, item_feat_dim)``.

        Raises:
            ValueError: If either feature table does not have the expected
                shape.
        """
        super().__init__(num_users=num_users, num_items=num_items)

        self._check_feature_shape(
            "user_features", user_features, num_users, user_feat_dim
        )
        self._check_feature_shape(
            "item_features", item_features, num_items, item_feat_dim
        )

        # Everything except vocab size, id-embedding width and side-feature
        # width is shared between the two towers.
        shared_kwargs = dict(
            hidden_dims=cfg.mlp_hidden_dims,
            output_dim=cfg.output_dim,
            dropout=cfg.dropout,
            activation=cfg.activation,
            init_std=cfg.init_std,
        )
        self.user_tower = _Tower(
            vocab_size=num_users,
            id_embedding_dim=cfg.user_id_embedding_dim,
            side_feat_dim=user_feat_dim,
            **shared_kwargs,
        )
        self.item_tower = _Tower(
            vocab_size=num_items,
            id_embedding_dim=cfg.item_id_embedding_dim,
            side_feat_dim=item_feat_dim,
            **shared_kwargs,
        )

        # Buffers (not parameters): they follow `.to(device)` and are saved
        # in `state_dict()`. `detach()` guarantees the stored tables never
        # participate in autograd, even if the caller's tensors require grad
        # (`.float()` / `.contiguous()` return the input itself when it is
        # already float32 and contiguous).
        self.register_buffer(
            "user_features", user_features.detach().float().contiguous()
        )
        self.register_buffer(
            "item_features", item_features.detach().float().contiguous()
        )

    @staticmethod
    def _check_feature_shape(
        name: str, features: Tensor, num_rows: int, feat_dim: int
    ) -> None:
        """Raise ``ValueError`` unless ``features`` is ``(num_rows, feat_dim)``."""
        if features.shape != (num_rows, feat_dim):
            raise ValueError(
                f"{name} shape {tuple(features.shape)} != "
                f"({num_rows}, {feat_dim})"
            )

    def _user_repr(self, user_ids: Tensor) -> Tensor:
        """Return user-tower vectors for ``user_ids`` (features looked up by id)."""
        return self.user_tower(user_ids, self.user_features[user_ids])

    def _item_repr(self, item_ids: Tensor) -> Tensor:
        """Return item-tower vectors for ``item_ids`` (features looked up by id)."""
        return self.item_tower(item_ids, self.item_features[item_ids])

    def _all_item_reprs(self) -> Tensor:
        """Return item-tower vectors for every item id ``0 .. num_items - 1``.

        The whole feature table is passed directly instead of being indexed,
        since row ``k`` already belongs to item ``k``.
        """
        # `self.num_items` is not assigned in this class; presumably the base
        # class sets it from the constructor argument -- TODO confirm.
        all_ids = torch.arange(self.num_items, device=self.item_features.device)
        return self.item_tower(all_ids, self.item_features)

    def score(self, users: Tensor, items: Tensor) -> Tensor:
        """Return the dot-product score of each ``(user, item)`` pair.

        ``users`` and ``items`` are id tensors. Their tower outputs are
        multiplied elementwise and summed over the last dimension, so the two
        id tensors must have broadcast-compatible shapes.
        """
        user_vecs = self._user_repr(users)
        item_vecs = self._item_repr(items)
        return (user_vecs * item_vecs).sum(dim=-1)

    def score_all_items(self, users: Tensor) -> Tensor:
        """Return scores of ``users`` against every item via one matmul.

        For a 1-D ``users`` tensor the result has one row per user and one
        column per item id.
        """
        user_vecs = self._user_repr(users)
        item_vecs = self._all_item_reprs()
        return user_vecs @ item_vecs.T
|
|