File size: 4,231 Bytes
b57c46e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# EntityEmbeddingOrthogonal class adapted from LaM-SLidE
# https://github.com/ml-jku/LaM-SLidE/blob/main/src/modules/entity_embeddings.py

from typing import Optional
import math

import torch
import torch.nn as nn
import torch.nn.init as init


class EntityEmbeddingOrthogonal(nn.Module):
    """Entity embedding table with an orthogonally initialized weight matrix.

    Wraps a single ``nn.Embedding`` whose weight is initialized with
    ``torch.nn.init.orthogonal_``. By default the table is frozen
    (``requires_grad=False``), so the embeddings act as fixed entity
    identifiers rather than learned parameters.
    """

    n_entity_embeddings: int
    embedding_dim: int
    max_norm: Optional[float] = None
    requires_grad: bool = False

    def __init__(
        self,
        n_entiy_embeddings,
        embedding_dim,
        max_norm: Optional[float] = None,
        requires_grad: bool = False,
    ):
        """Build the embedding table.

        Args:
            n_entiy_embeddings: Number of rows in the table.
                NOTE(review): parameter name keeps the upstream spelling
                ("entiy") so keyword callers remain compatible.
            embedding_dim: Dimension of each embedding vector.
            max_norm: Optional max norm forwarded to ``nn.Embedding``.
            requires_grad: Whether the table is trainable (default frozen).
        """
        super().__init__()

        self.n_entity_embeddings = n_entiy_embeddings
        self.embedding_dim = embedding_dim
        self.max_norm = max_norm

        table = nn.Embedding(n_entiy_embeddings, embedding_dim, max_norm=max_norm)
        init.orthogonal_(table.weight)
        table.weight.requires_grad = requires_grad
        self.embedding = table

    def forward(self, entities):
        """Look up embedding vectors for integer entity IDs."""
        return self.embedding(entities)


class EntityEmbeddingFactorized(nn.Module):
    """
    Square-root factorized entity embeddings for more frequent updates.

    Instead of N separate embeddings, we decompose entity ID into:

        id = base * sqrt_n + offset

    where base and offset each come from a smaller embedding table of size
    ceil(sqrt(N)). The final embedding is the sum (or concatenation) of the
    base and offset embeddings.

    This ensures each embedding vector gets updated more frequently since
    multiple entity IDs share the same base or offset components.

    Args:
        n_entiy_embeddings: Size of the identifier pool (e.g., 512).
            NOTE(review): parameter name keeps the upstream spelling
            ("entiy") so keyword callers remain compatible.
        embedding_dim: Dimension of the output embedding.
        max_norm: Max norm forwarded to the underlying ``nn.Embedding``s.
        requires_grad: Whether embeddings are trainable.
        combine: How to combine base and offset embeddings ('sum' or 'concat').

    Raises:
        ValueError: If ``combine`` is not ``'sum'`` or ``'concat'``, or if
            ``combine == 'concat'`` and ``embedding_dim`` is odd.
    """

    def __init__(
        self,
        n_entiy_embeddings: int,
        embedding_dim: int,
        max_norm: Optional[float] = None,
        requires_grad: bool = True,
        combine: str = 'sum',
    ):
        super().__init__()
        # Fail fast on an unknown combine mode instead of silently falling
        # back to 'sum' in forward().
        if combine not in ('sum', 'concat'):
            raise ValueError(f"combine must be 'sum' or 'concat', got {combine!r}")

        self.n_entity_embeddings = n_entiy_embeddings
        self.embedding_dim = embedding_dim
        self.max_norm = max_norm
        self.combine = combine

        # Smallest s with s * s >= N, so both base (id // s) and
        # offset (id % s) indices fit into a table of s rows.
        self.sqrt_n = math.ceil(math.sqrt(n_entiy_embeddings))

        # For 'concat', each sub-embedding is half the dimension.
        # For 'sum', each sub-embedding is the full dimension.
        if combine == 'concat':
            # raise (not assert): assert is stripped under `python -O`
            if embedding_dim % 2 != 0:
                raise ValueError("embedding_dim must be even for concat mode")
            sub_dim = embedding_dim // 2
        else:
            sub_dim = embedding_dim

        self.sub_dim = sub_dim

        # Base embedding (quotient part): id // sqrt_n
        self.base_embedding = nn.Embedding(self.sqrt_n, sub_dim, max_norm=max_norm)
        init.orthogonal_(self.base_embedding.weight)
        self.base_embedding.weight.requires_grad = requires_grad

        # Offset embedding (remainder part): id % sqrt_n
        self.offset_embedding = nn.Embedding(self.sqrt_n, sub_dim, max_norm=max_norm)
        init.orthogonal_(self.offset_embedding.weight)
        self.offset_embedding.weight.requires_grad = requires_grad

    def forward(self, entities):
        """
        Args:
            entities: Integer tensor of entity IDs, e.g. (batch, num_entities).

        Returns:
            Tensor of shape ``entities.shape + (embedding_dim,)`` with the
            combined base/offset embeddings. The output is not re-normalized;
            any per-vector norm constraint comes only from ``max_norm`` on
            the sub-tables.
        """
        # Decompose entity IDs into quotient/remainder parts.
        base_ids = entities // self.sqrt_n
        offset_ids = entities % self.sqrt_n

        base_emb = self.base_embedding(base_ids)      # (..., sub_dim)
        offset_emb = self.offset_embedding(offset_ids)  # (..., sub_dim)

        if self.combine == 'concat':
            return torch.cat([base_emb, offset_emb], dim=-1)
        return base_emb + offset_emb