File size: 2,018 Bytes
"""Common interface for every recommender model.

The trainer and evaluator depend only on this class, so concrete models (MF,
TwoTower, future GNNs) plug in without any changes upstream.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
import torch
from torch import Tensor, nn
class BaseRecommender(nn.Module, ABC):
    """Abstract base class for recommender models.

    Subclasses implement `score` and `score_all_items`. `forward` is built
    on top of `score` and returns the positive and negative scores for a
    batch.

    Attributes:
        num_users: number of users, stored as an int (always >= 1).
        num_items: number of items, stored as an int (always >= 1).
    """

    num_users: int
    num_items: int

    def __init__(self, num_users: int, num_items: int) -> None:
        """Validate and store the user and item counts.

        Args:
            num_users: number of users; must be >= 1.
            num_items: number of items; must be >= 1.

        Raises:
            ValueError: if either count is smaller than 1.
        """
        super().__init__()
        if num_users < 1 or num_items < 1:
            raise ValueError("num_users and num_items must be >= 1")
        self.num_users = int(num_users)
        self.num_items = int(num_items)

    @abstractmethod
    def score(self, users: Tensor, items: Tensor) -> Tensor:
        """Score (user, item) pairs.

        Args:
            users: int64 tensor of shape [B] or broadcastable to items.
            items: int64 tensor of shape [B] or [B, K].

        Returns:
            float tensor with shape matching `items`.
        """

    @abstractmethod
    def score_all_items(self, users: Tensor) -> Tensor:
        """Score a batch of users against every item in the catalog.

        Args:
            users: int64 tensor of shape [B].

        Returns:
            float tensor of shape [B, num_items].
        """

    def forward(
        self, users: Tensor, pos_items: Tensor, neg_items: Tensor
    ) -> tuple[Tensor, Tensor]:
        """Shared forward used by BPR training.

        Args:
            users: [B] int64.
            pos_items: [B] int64.
            neg_items: [B, K] int64, or the same shape as `users` when
                there is exactly one negative per user.

        Returns:
            (pos_scores, neg_scores) as returned by `score`: pos_scores for
            `pos_items` and neg_scores for `neg_items`.
        """
        pos_scores = self.score(users, pos_items)

        if neg_items.shape == users.shape:
            # One negative per user: the tensors already line up pairwise.
            # Expanding here would fail, because expand_as cannot drop the
            # trailing dimension that unsqueeze adds.
            return pos_scores, self.score(users, neg_items)

        # Broadcast users along the K dim before scoring negatives.
        users_expanded = users.unsqueeze(-1).expand_as(neg_items)
        neg_scores = self.score(users_expanded, neg_items)
        return pos_scores, neg_scores
|