# score-ae / src/model/entity_embeddings.py
# (HuggingFace raw-page scaffolding removed: uploader line, "Upload 90 files",
#  commit b57c46e, raw/history/blame links, 4.23 kB size marker)
# EntityEmbeddingOrthogonal class adapted from LaM-SLidE
# https://github.com/ml-jku/LaM-SLidE/blob/main/src/modules/entity_embeddings.py
from typing import Optional
import math
import torch
import torch.nn as nn
import torch.nn.init as init
class EntityEmbeddingOrthogonal(nn.Module):
    """Entity-ID lookup table with orthogonal initialization, frozen by default.

    Wraps an ``nn.Embedding`` whose weight matrix is initialized with
    ``init.orthogonal_``; the weight's ``requires_grad`` flag is disabled
    unless explicitly enabled at construction time.

    Args:
        n_entiy_embeddings: Number of entity IDs (table size).
            NOTE(review): parameter name carries the upstream typo ("entiy");
            kept as-is so existing keyword callers continue to work.
        embedding_dim: Dimension of each embedding vector.
        max_norm: Optional max-norm renormalization, passed to ``nn.Embedding``.
        requires_grad: Whether the embedding table is trainable.
    """

    n_entity_embeddings: int
    embedding_dim: int
    max_norm: Optional[float] = None
    requires_grad: bool = False

    def __init__(
        self,
        n_entiy_embeddings,
        embedding_dim,
        max_norm: Optional[float] = None,
        requires_grad: bool = False,
    ):
        super().__init__()
        # Build and initialize the lookup table, then record the configuration.
        table = nn.Embedding(n_entiy_embeddings, embedding_dim, max_norm=max_norm)
        init.orthogonal_(table.weight)
        table.weight.requires_grad = requires_grad
        self.n_entity_embeddings = n_entiy_embeddings
        self.embedding_dim = embedding_dim
        self.max_norm = max_norm
        self.embedding = table

    def forward(self, entities):
        """Map integer entity IDs to their embedding vectors.

        Args:
            entities: integer tensor of entity IDs.

        Returns:
            Tensor of shape ``entities.shape + (embedding_dim,)``.
        """
        return self.embedding(entities)
class EntityEmbeddingFactorized(nn.Module):
    """
    Square-root factorized entity embeddings for more frequent updates.

    Instead of N separate embeddings, we decompose entity ID into:
        id = base * sqrt_n + offset
    where base and offset each come from a smaller embedding table of size
    ceil(sqrt(N)). The final embedding is the sum (or concatenation) of the
    base and offset embeddings. This ensures each embedding vector gets
    updated more frequently since multiple entity IDs share the same base
    or offset components.

    Args:
        n_entiy_embeddings: Size of the identifier pool (e.g., 512).
            NOTE(review): parameter name carries the original typo ("entiy");
            kept as-is so existing keyword callers continue to work.
        embedding_dim: Dimension of the output embedding.
        max_norm: Max norm for each sub-embedding vector.
        requires_grad: Whether embeddings are trainable.
        combine: How to combine base and offset embeddings ('sum' or 'concat').

    Raises:
        ValueError: If ``combine`` is not 'sum' or 'concat', or if
            ``embedding_dim`` is odd in 'concat' mode.
    """
    def __init__(
        self,
        n_entiy_embeddings: int,
        embedding_dim: int,
        max_norm: Optional[float] = None,
        requires_grad: bool = True,
        combine: str = 'sum',
    ):
        super().__init__()
        # Fail fast on an unknown mode instead of silently falling back to 'sum'.
        if combine not in ('sum', 'concat'):
            raise ValueError(f"combine must be 'sum' or 'concat', got {combine!r}")
        self.n_entity_embeddings = n_entiy_embeddings
        self.embedding_dim = embedding_dim
        self.max_norm = max_norm
        self.combine = combine
        # Factor size: ceil(sqrt(N)) guarantees that both id // sqrt_n and
        # id % sqrt_n fall in [0, sqrt_n) for every id in [0, N).
        self.sqrt_n = math.ceil(math.sqrt(n_entiy_embeddings))
        # For 'concat', each sub-embedding carries half the output dimension;
        # for 'sum', each sub-embedding spans the full dimension.
        if combine == 'concat':
            # Real exception (not assert) so the check survives `python -O`.
            if embedding_dim % 2 != 0:
                raise ValueError("embedding_dim must be even for concat mode")
            sub_dim = embedding_dim // 2
        else:
            sub_dim = embedding_dim
        self.sub_dim = sub_dim
        # Base embedding (quotient part): id // sqrt_n
        self.base_embedding = nn.Embedding(self.sqrt_n, sub_dim, max_norm=max_norm)
        init.orthogonal_(self.base_embedding.weight)
        self.base_embedding.weight.requires_grad = requires_grad
        # Offset embedding (remainder part): id % sqrt_n
        self.offset_embedding = nn.Embedding(self.sqrt_n, sub_dim, max_norm=max_norm)
        init.orthogonal_(self.offset_embedding.weight)
        self.offset_embedding.weight.requires_grad = requires_grad

    def forward(self, entities):
        """
        Args:
            entities: (batch, num_entities) integer tensor of entity IDs.

        Returns:
            embeddings: (batch, num_entities, embedding_dim) tensor — the raw
            sum (or concatenation) of the base and offset sub-embeddings.
            NOTE: the output is NOT renormalized to unit norm.
        """
        # Decompose entity IDs into quotient (base) and remainder (offset).
        base_ids = entities // self.sqrt_n  # (B, N)
        offset_ids = entities % self.sqrt_n  # (B, N)
        # Look up the two sub-embeddings.
        base_emb = self.base_embedding(base_ids)  # (B, N, sub_dim)
        offset_emb = self.offset_embedding(offset_ids)  # (B, N, sub_dim)
        # Combine: concatenate halves, or add full-dimension vectors.
        if self.combine == 'concat':
            return torch.cat([base_emb, offset_emb], dim=-1)  # (B, N, embedding_dim)
        return base_emb + offset_emb  # (B, N, embedding_dim)