# recsys-ecommerce/src/models/two_tower.py
# dscsdvdfsvs's picture
# fix: upload src folder with model classes
# 80843b0 verified
import numpy as np
import pandas as pd
import pickle
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
# ── Dataset ──────────────────────────────────────────────────────────────────
class InteractionDataset(Dataset):
    """
    Feeds (user_features, item_features, label) triplets to the model.

    Positive label = 1 (observed interaction)
    Negative label = 0 (randomly sampled unobserved pair)

    ``neg_ratio`` negatives are drawn per positive (1:1 by default). All
    negatives are sampled once, in the constructor, so they stay fixed for
    the lifetime of the dataset.

    Args:
        train_df: interactions; the columns ``user_idx`` and ``item_idx``
            are read.
        user_feats_norm: one row per user; ``user_idx`` plus feature columns.
        item_feats_norm: one row per item; ``item_idx`` plus feature columns.
        neg_ratio: number of negatives sampled for each positive pair.
    """

    def __init__(self, train_df, user_feats_norm, item_feats_norm, neg_ratio=1):
        self.user_feat_cols = [c for c in user_feats_norm.columns if c != "user_idx"]
        self.item_feat_cols = [c for c in item_feats_norm.columns if c != "item_idx"]

        # idx -> {feature name: float32 value}; the dicts keep column order,
        # so list(row.values()) lines up with *_feat_cols.
        self.user_feat_map = (
            user_feats_norm.set_index("user_idx")[self.user_feat_cols]
            .astype(np.float32).to_dict("index")
        )
        self.item_feat_map = (
            item_feats_norm.set_index("item_idx")[self.item_feat_cols]
            .astype(np.float32).to_dict("index")
        )

        self.all_items = list(item_feats_norm["item_idx"].values)
        self.neg_ratio = neg_ratio

        # Build positive pairs
        self.positives = train_df[["user_idx", "item_idx"]].values.tolist()

        # Build user history for negative sampling
        self.user_history = (
            train_df.groupby("user_idx")["item_idx"].apply(set).to_dict()
        )

        # Pre-sample negatives for speed
        self.negatives = self._sample_negatives()
        self.pairs = (
            [(u, i, 1.0) for u, i in self.positives] +
            [(u, i, 0.0) for u, i in self.negatives]
        )

    def _sample_negatives(self):
        """Return ``(user_idx, item_idx)`` pairs the user has not interacted with.

        Items are drawn uniformly from ``all_items`` and re-drawn while they
        are in the user's history. A user whose history already covers every
        catalogue item has no valid negative, so that user is skipped;
        otherwise the rejection loop could never terminate.
        """
        # Build the array once: np.random.choice on a Python list converts
        # the whole list on every call.
        item_pool = np.asarray(self.all_items)
        catalogue = set(self.all_items)
        can_sample = {}  # user_idx -> True if at least one item is unseen

        negs = []
        for user_idx, _ in self.positives:
            seen = self.user_history.get(user_idx, set())
            if user_idx not in can_sample:
                can_sample[user_idx] = not catalogue.issubset(seen)
            if not can_sample[user_idx]:
                continue
            for _ in range(self.neg_ratio):
                neg = np.random.choice(item_pool)
                while neg in seen:
                    neg = np.random.choice(item_pool)
                negs.append((user_idx, neg))
        return negs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(user_features, item_features, label)`` as float32 tensors.

        A user or item with no row in its feature table gets an all-zero
        feature vector.
        """
        user_idx, item_idx, label = self.pairs[idx]

        u_row = self.user_feat_map.get(user_idx)
        i_row = self.item_feat_map.get(item_idx)
        u_feats = (
            list(u_row.values()) if u_row is not None
            else [0.0] * len(self.user_feat_cols)
        )
        i_feats = (
            list(i_row.values()) if i_row is not None
            else [0.0] * len(self.item_feat_cols)
        )
        return (
            torch.tensor(u_feats, dtype=torch.float32),
            torch.tensor(i_feats, dtype=torch.float32),
            torch.tensor(label, dtype=torch.float32),
        )
# ── Model ─────────────────────────────────────────────────────────────────────
class TwoTowerModel(nn.Module):
    """
    Two-Tower (Dual Encoder) architecture.

    User tower : user features → dense embedding
    Item tower : item features → dense embedding
    Score      : cosine similarity of the two embeddings

    Why two towers?
        At serving time, item embeddings are pre-computed offline.
        Only the user tower runs at request time → fast inference.

    Architecture of each tower:
        Input → Linear → BatchNorm → ReLU → Dropout
              → Linear → BatchNorm → ReLU → Dropout
              → Linear → L2-normalised embedding

    Because both embeddings are L2-normalised, the score returned by
    ``forward`` is their cosine similarity.
    """

    def __init__(self, user_dim: int, item_dim: int, embedding_dim: int = 64,
                 hidden_dim: int = 128, dropout: float = 0.2):
        super().__init__()
        self.embedding_dim = embedding_dim
        # User tower is built first, then the item tower, so parameter
        # initialisation consumes the RNG in a fixed order.
        self.user_tower = self._make_tower(user_dim, hidden_dim, embedding_dim, dropout)
        self.item_tower = self._make_tower(item_dim, hidden_dim, embedding_dim, dropout)

    @staticmethod
    def _make_tower(input_dim, hidden_dim, embedding_dim, dropout):
        """Build one encoder: two hidden blocks followed by a linear projection."""
        layers = []
        width = input_dim
        for _ in range(2):
            layers.extend([
                nn.Linear(width, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            width = hidden_dim
        layers.append(nn.Linear(hidden_dim, embedding_dim))
        return nn.Sequential(*layers)

    def forward(self, user_feats, item_feats):
        """Return one similarity score per (user, item) row of the batch."""
        # L2 normalise → cosine similarity = dot product
        user_vec = nn.functional.normalize(self.user_tower(user_feats), dim=1)
        item_vec = nn.functional.normalize(self.item_tower(item_feats), dim=1)
        return torch.sum(user_vec * item_vec, dim=1)

    def get_user_embedding(self, user_feats):
        """Return L2-normalised user embeddings, computed without gradients."""
        with torch.no_grad():
            raw = self.user_tower(user_feats)
            return nn.functional.normalize(raw, dim=1)

    def get_item_embedding(self, item_feats):
        """Return L2-normalised item embeddings, computed without gradients."""
        with torch.no_grad():
            raw = self.item_tower(item_feats)
            return nn.functional.normalize(raw, dim=1)
# ── Trainer ───────────────────────────────────────────────────────────────────
class TwoTowerRecommender:
    """
    Wraps TwoTowerModel with fit / recommend interface
    matching PopularityRecommender and ALSRecommender.

    After ``fit`` the instance holds the trained model, the user feature
    lookup, each user's training history and a cached matrix of item
    embeddings, so ``recommend`` only has to run the user tower.
    """

    def __init__(self, embedding_dim=64, hidden_dim=128,
                 dropout=0.2, lr=1e-3, epochs=10, batch_size=512):
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.lr = lr
        self.epochs = epochs
        self.batch_size = batch_size
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.is_fitted = False
        # Stored after fit for recommend()
        self.item_embeddings = None
        self.item_indices = None
        self.user_feat_cols = None
        self.item_feat_cols = None
        self.user_feat_map = None
        self.item_feats_norm = None
        self.user_history = {}

    def fit(self, train_df, user_feats_norm, item_feats_norm):
        """Train the two-tower model and cache all item embeddings.

        Args:
            train_df: interactions with ``user_idx`` / ``item_idx`` columns.
            user_feats_norm: user feature table keyed by ``user_idx``.
            item_feats_norm: item feature table keyed by ``item_idx``.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: if fewer than two training pairs are available.
        """
        print(f"Training on: {self.device}")
        dataset = InteractionDataset(train_df, user_feats_norm, item_feats_norm)
        if len(dataset) < 2:
            raise ValueError(
                "Need at least two training pairs: BatchNorm cannot be "
                "trained on a single sample"
            )
        # BatchNorm1d raises in training mode when a batch holds a single
        # sample, so drop the trailing batch only when it would be size 1.
        drop_last = len(dataset) % self.batch_size == 1
        loader = DataLoader(dataset, batch_size=self.batch_size,
                            shuffle=True, num_workers=0, drop_last=drop_last)

        self.user_feat_cols = dataset.user_feat_cols
        self.item_feat_cols = dataset.item_feat_cols
        self.user_feat_map = dataset.user_feat_map
        self.item_feats_norm = item_feats_norm
        self.user_history = dataset.user_history

        user_dim = len(self.user_feat_cols)
        item_dim = len(self.item_feat_cols)
        self.model = TwoTowerModel(
            user_dim=user_dim,
            item_dim=item_dim,
            embedding_dim=self.embedding_dim,
            hidden_dim=self.hidden_dim,
            dropout=self.dropout,
        ).to(self.device)

        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        criterion = nn.BCEWithLogitsLoss()
        # Halve the learning rate every 3 epochs.
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)

        self.model.train()
        for epoch in range(self.epochs):
            total_loss = 0.0
            for u_feats, i_feats, labels in loader:
                u_feats = u_feats.to(self.device)
                i_feats = i_feats.to(self.device)
                labels = labels.to(self.device)
                optimizer.zero_grad()
                scores = self.model(u_feats, i_feats)
                loss = criterion(scores, labels)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            scheduler.step()
            avg_loss = total_loss / len(loader)
            print(f" Epoch {epoch+1:02d}/{self.epochs} — loss: {avg_loss:.4f}")

        # Pre-compute all item embeddings for fast inference
        self._precompute_item_embeddings()
        self.is_fitted = True
        print("Training complete.")
        return self

    def _precompute_item_embeddings(self):
        """Compute and cache all item embeddings once after training."""
        self.model.eval()
        item_feat_matrix = (
            self.item_feats_norm
            .set_index("item_idx")[self.item_feat_cols]
            .astype(np.float32)
        )
        tensor = torch.tensor(
            item_feat_matrix.values, dtype=torch.float32
        ).to(self.device)
        with torch.no_grad():
            embs = self.model.get_item_embedding(tensor)
        self.item_embeddings = embs.cpu().numpy()
        # Row i of item_embeddings belongs to item_indices[i].
        self.item_indices = item_feat_matrix.index.values
        print(f"Pre-computed {len(self.item_indices):,} item embeddings")

    def recommend(self, user_idx: int, k: int = 10) -> np.ndarray:
        """Return up to ``k`` item indices for ``user_idx``, best first.

        Items in the user's training history are never returned, so the
        result is shorter than ``k`` when fewer than ``k`` unseen items
        exist. A user without stored features is scored with an all-zero
        feature vector.

        Raises:
            RuntimeError: if ``fit`` has not been called.
        """
        if not self.is_fitted:
            raise RuntimeError("Call fit() first")
        self.model.eval()

        u_row = self.user_feat_map.get(user_idx)
        u_values = (
            list(u_row.values()) if u_row is not None
            else [0.0] * len(self.user_feat_cols)
        )
        u_tensor = torch.tensor(
            u_values, dtype=torch.float32
        ).unsqueeze(0).to(self.device)
        with torch.no_grad():
            u_emb = self.model.get_user_embedding(u_tensor).cpu().numpy()

        # Dot product with all item embeddings (cosine sim since normalised)
        scores = (self.item_embeddings @ u_emb.T).ravel()

        # Filter seen items
        seen = self.user_history.get(user_idx, set())
        if seen:
            seen_mask = np.isin(self.item_indices, list(seen))
        else:
            seen_mask = np.zeros(len(self.item_indices), dtype=bool)
        scores[seen_mask] = -np.inf

        top_k_local = np.argsort(scores)[::-1][:k]
        # Seen items only reach the top k when k exceeds the number of
        # unseen items; remove them instead of recommending them.
        top_k_local = top_k_local[~seen_mask[top_k_local]]
        return self.item_indices[top_k_local]

    def recommend_batch(self, user_indices, k: int = 10) -> dict:
        """Map each user in ``user_indices`` to its ``recommend`` result."""
        return {u: self.recommend(u, k) for u in user_indices}

    @staticmethod
    def _weights_path(path):
        """Return the ``.pt`` file that sits next to the pickle at ``path``."""
        weights_path = path.with_suffix(".pt")
        if weights_path == path:
            raise ValueError(
                f"{path} would be used for both the pickle and the model "
                "weights; choose a path that does not end in '.pt'"
            )
        return weights_path

    def save(self, path):
        """Write the recommender to ``path`` and its weights to ``<path>.pt``.

        The model is saved as a state dict; every other attribute is pickled
        with the model detached.

        Raises:
            RuntimeError: if there is no trained model to save.
            ValueError: if ``path`` already ends in ``.pt``.
        """
        if self.model is None:
            raise RuntimeError("Call fit() first")
        path = Path(path)
        torch.save(self.model.state_dict(), self._weights_path(path))

        tmp_model = self.model
        self.model = None
        try:
            with open(path, "wb") as f:
                pickle.dump(self, f)
        finally:
            # Re-attach the model even if pickling fails.
            self.model = tmp_model
        print(f"Saved TwoTowerRecommender to {path}")

    @staticmethod
    def load(path, user_dim=None, item_dim=None):
        """Restore a recommender written by ``save``.

        Args:
            path: the path that was passed to ``save``.
            user_dim: user feature count; defaults to the number of stored
                user feature columns.
            item_dim: item feature count; defaults to the number of stored
                item feature columns.

        Returns:
            The restored recommender with its model in eval mode on the
            device available in the current process.
        """
        path = Path(path)
        # SECURITY: pickle.load can execute arbitrary code. Only load files
        # that come from a trusted source.
        with open(path, "rb") as f:
            obj = pickle.load(f)

        if user_dim is None:
            user_dim = len(obj.user_feat_cols)
        if item_dim is None:
            item_dim = len(obj.item_feat_cols)

        # The pickled device describes the machine that saved the file;
        # resolve it again so model and inputs share a device here.
        obj.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        obj.model = TwoTowerModel(
            user_dim=user_dim, item_dim=item_dim,
            embedding_dim=obj.embedding_dim, hidden_dim=obj.hidden_dim,
            dropout=obj.dropout,
        ).to(obj.device)
        state = torch.load(
            TwoTowerRecommender._weights_path(path), map_location=obj.device
        )
        obj.model.load_state_dict(state)
        obj.model.eval()
        return obj