# rc-docker/src/recsys/models/two_tower.py
# moecr7
# Dockerize rc-ranked: FastAPI service for HF Spaces
# 188f0cf
"""Two-Tower neural recommender with side features.
Each tower takes (id_embedding, side_features) -> MLP -> output_dim vector.
Final score is the dot product of the user and item vectors, which preserves
the retrieval structure — `score_all_items` becomes a single matmul.
Design notes:
- Feature tables (`user_features`, `item_features`) live on the module
as non-trainable `register_buffer`s, so `.to(device)` moves them and
`state_dict()` persists them alongside learnable params.
- No BatchNorm — during eval we call the item tower on ALL items at once,
where batch statistics would be meaningless. Dropout is enough.
- Dot-product (not concat + MLP) at the top: preserves fast retrieval.
"""
from __future__ import annotations
import torch
from torch import Tensor, nn
from ..config import TwoTowerModelConfig
from .base import BaseRecommender
# Maps an activation name to the nn.Module class implementing it. `_Tower`
# looks the class up by name and instantiates it once per hidden layer.
_ACTIVATIONS: dict[str, type[nn.Module]] = dict(
    relu=nn.ReLU,
    gelu=nn.GELU,
    tanh=nn.Tanh,
)
class _Tower(nn.Module):
    """One side of the two-tower model.

    Looks up a learned embedding for each id, concatenates it with the
    caller-supplied side features along the last dimension, and feeds the
    result through an MLP that ends in a linear projection to ``output_dim``.
    """

    def __init__(
        self,
        *,
        vocab_size: int,
        id_embedding_dim: int,
        side_feat_dim: int,
        hidden_dims: tuple[int, ...],
        output_dim: int,
        dropout: float,
        activation: str,
        init_std: float,
    ) -> None:
        """Create the id embedding table and the MLP.

        Args:
            vocab_size: Number of rows in the id embedding table.
            id_embedding_dim: Width of each id embedding.
            side_feat_dim: Width of the side-feature vector that is
                concatenated to the id embedding.
            hidden_dims: Widths of the hidden layers, in order. May be empty,
                in which case the MLP is a single linear layer.
            output_dim: Width of the tower's output vector.
            dropout: Dropout probability; a Dropout layer is added after each
                hidden activation only when this is greater than zero.
            activation: Key into ``_ACTIVATIONS`` selecting the nonlinearity.
            init_std: Standard deviation of the zero-mean normal used to
                initialise the id embedding weights.

        Raises:
            KeyError: If ``activation`` is not a key of ``_ACTIVATIONS``.
        """
        super().__init__()

        # The embedding is built (and re-initialised) before any Linear layer
        # so the order in which random numbers are drawn stays fixed.
        embedding = nn.Embedding(vocab_size, id_embedding_dim)
        nn.init.normal_(embedding.weight, mean=0.0, std=init_std)
        self.id_emb = embedding

        make_activation = _ACTIVATIONS[activation]

        stack: list[nn.Module] = []
        width = id_embedding_dim + side_feat_dim
        for hidden_width in hidden_dims:
            hidden_block: list[nn.Module] = [
                nn.Linear(width, hidden_width),
                make_activation(),
            ]
            if dropout > 0:
                hidden_block.append(nn.Dropout(dropout))
            stack.extend(hidden_block)
            width = hidden_width

        # Final projection has no activation: its output is the tower vector.
        stack.append(nn.Linear(width, output_dim))
        self.mlp = nn.Sequential(*stack)

    def forward(self, ids: Tensor, side_features: Tensor) -> Tensor:
        """Return the tower vector for each id / side-feature pair.

        Args:
            ids: Integer indices into the id embedding table.
            side_features: Side features aligned with ``ids``; joined to the
                id embeddings along the last dimension.

        Returns:
            The MLP output for the concatenated inputs.
        """
        id_vectors = self.id_emb(ids)
        tower_input = torch.cat((id_vectors, side_features), dim=-1)
        return self.mlp(tower_input)
class TwoTower(BaseRecommender):
    """Two-tower recommender scoring by user-vector / item-vector dot product.

    Each tower combines an id embedding with a row of a fixed side-feature
    table. The feature tables are stored as buffers on the module.
    """

    def __init__(
        self,
        *,
        num_users: int,
        num_items: int,
        user_feat_dim: int,
        item_feat_dim: int,
        cfg: TwoTowerModelConfig,
        user_features: torch.Tensor,
        item_features: torch.Tensor,
    ) -> None:
        """Validate the feature tables and build both towers.

        Args:
            num_users: Number of users; size of the user id vocabulary.
            num_items: Number of items; size of the item id vocabulary.
            user_feat_dim: Width of each row of ``user_features``.
            item_feat_dim: Width of each row of ``item_features``.
            cfg: Model configuration supplying embedding widths, hidden
                widths, output width, dropout, activation and init std.
            user_features: Table of shape ``(num_users, user_feat_dim)``.
            item_features: Table of shape ``(num_items, item_feat_dim)``.

        Raises:
            ValueError: If either feature table does not have the expected
                shape.
        """
        super().__init__(num_users=num_users, num_items=num_items)

        # User table is checked first, then the item table.
        expectations = (
            ("user_features", user_features, num_users, user_feat_dim),
            ("item_features", item_features, num_items, item_feat_dim),
        )
        for label, table, rows, cols in expectations:
            if table.shape != (rows, cols):
                raise ValueError(
                    f"{label} shape {tuple(table.shape)} != ({rows}, {cols})"
                )

        # Settings that are identical for the two towers.
        shared_kwargs = {
            "hidden_dims": cfg.mlp_hidden_dims,
            "output_dim": cfg.output_dim,
            "dropout": cfg.dropout,
            "activation": cfg.activation,
            "init_std": cfg.init_std,
        }
        self.user_tower = _Tower(
            vocab_size=num_users,
            id_embedding_dim=cfg.user_id_embedding_dim,
            side_feat_dim=user_feat_dim,
            **shared_kwargs,
        )
        self.item_tower = _Tower(
            vocab_size=num_items,
            id_embedding_dim=cfg.item_id_embedding_dim,
            side_feat_dim=item_feat_dim,
            **shared_kwargs,
        )

        # Feature tables are buffers rather than parameters: they are not
        # trained, but they follow `.to(device)` and appear in `state_dict()`.
        user_table = user_features.float().contiguous()
        item_table = item_features.float().contiguous()
        self.register_buffer("user_features", user_table)
        self.register_buffer("item_features", item_table)

    # ---------- tower representations ----------

    def _user_repr(self, user_ids: Tensor) -> Tensor:
        """Return user-tower vectors for ``user_ids``."""
        side = self.user_features[user_ids]
        return self.user_tower(user_ids, side)

    def _item_repr(self, item_ids: Tensor) -> Tensor:
        """Return item-tower vectors for ``item_ids``."""
        side = self.item_features[item_ids]
        return self.item_tower(item_ids, side)

    def _all_item_reprs(self) -> Tensor:
        """Return item-tower vectors for every item, in id order."""
        table = self.item_features
        every_id = torch.arange(self.num_items, device=table.device)
        return self.item_tower(every_id, table)

    # ---------- BaseRecommender interface ----------

    def score(self, users: Tensor, items: Tensor) -> Tensor:
        """Score aligned (user, item) pairs.

        Returns the dot product of each user vector with the item vector at
        the same position, reduced over the last dimension.
        """
        user_vecs = self._user_repr(users)
        item_vecs = self._item_repr(items)
        elementwise = user_vecs * item_vecs
        return elementwise.sum(dim=-1)

    def score_all_items(self, users: Tensor) -> Tensor:
        """Score every item for each user in ``users``.

        Returns a ``[B, N]`` matrix: ``B`` user vectors against all ``N``
        item vectors, computed as a single matrix product.
        """
        user_vecs = self._user_repr(users)  # [B, D]
        item_vecs = self._all_item_reprs()  # [N, D]
        return torch.matmul(user_vecs, item_vecs.T)  # [B, N]