File size: 3,107 Bytes
d820920
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""
Isengard - User Tower

Neural network that encodes a user's wine preferences from their reviewed wines.
Uses attention-weighted aggregation of wine embeddings based on user ratings.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

from .config import (
    EMBEDDING_DIM,
    USER_VECTOR_DIM,
    HIDDEN_DIM,
)


class UserTower(nn.Module):
    """Isengard: builds a user preference vector from their reviewed wines.

    The tower attends over a user's wine embeddings, weighting each wine by
    its (shifted, softmax-normalized) rating, pools the embeddings with those
    weights, projects the result through a two-layer MLP, and L2-normalizes
    the output onto the unit sphere.

    Architecture:
        1. Rating-weighted attention over wine embeddings
        2. MLP: 768 → 256 → 128
        3. L2 normalization to unit sphere

    Inputs to ``forward``:
        wine_embeddings: (batch, num_wines, embedding_dim)
        ratings:         (batch, num_wines) raw user ratings
        mask:            (batch, num_wines) optional; 1 = valid, 0 = padding

    Output:
        (batch, output_dim) unit-norm user embedding.
    """

    def __init__(
        self,
        embedding_dim: int = EMBEDDING_DIM,
        hidden_dim: int = HIDDEN_DIM,
        output_dim: int = USER_VECTOR_DIM,
    ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.output_dim = output_dim

        # Two-layer projection head: embedding_dim -> hidden_dim -> output_dim.
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

        # Light dropout between the two layers for regularization.
        self.dropout = nn.Dropout(0.1)

    def forward(
        self,
        wine_embeddings: torch.Tensor,
        ratings: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode one batch of users.

        Args:
            wine_embeddings: (batch, num_wines, embedding_dim)
            ratings: (batch, num_wines) raw ratings on the 1-5 scale
            mask: (batch, num_wines); 1 marks a real wine, 0 marks padding

        Returns:
            (batch, output_dim) L2-normalized user vectors.
        """
        # Center the 1-5 ratings around zero so higher-rated wines get larger
        # attention logits: 1 -> -0.6, 3 -> 0.2, 5 -> 1.0.
        scores = (ratings - 2.5) / 2.5
        attn = torch.softmax(scores, dim=-1)

        if mask is not None:
            # Zero out padded slots, then renormalize so the surviving weights
            # sum to 1 (the epsilon guards the all-masked degenerate case).
            attn = attn * mask
            attn = attn / (attn.sum(dim=-1, keepdim=True) + 1e-8)

        # Weighted pooling via batched matmul:
        # (batch, 1, num_wines) @ (batch, num_wines, d) -> (batch, 1, d).
        pooled = (attn.unsqueeze(1) @ wine_embeddings).squeeze(1)

        # Projection head with a ReLU bottleneck and dropout in between.
        hidden = self.dropout(torch.relu(self.fc1(pooled)))
        user_vector = self.fc2(hidden)

        # Unit-norm output so downstream similarity reduces to a dot product.
        return F.normalize(user_vector, p=2, dim=-1)