# EntityEmbeddingOrthogonal class adapted from LaM-SLidE
# https://github.com/ml-jku/LaM-SLidE/blob/main/src/modules/entity_embeddings.py
from typing import Optional

import math

import torch
import torch.nn as nn
import torch.nn.init as init


class EntityEmbeddingOrthogonal(nn.Module):
    """Entity-ID embedding table initialized with orthogonal rows.

    Rows of the embedding matrix are orthonormal after init (for
    ``n_entity_embeddings <= embedding_dim``), giving maximally separated,
    fixed identifier vectors. By default the table is frozen
    (``requires_grad=False``).
    """

    n_entity_embeddings: int
    embedding_dim: int
    max_norm: Optional[float] = None
    requires_grad: bool = False

    def __init__(
        self,
        n_entiy_embeddings,  # NOTE: misspelling kept for backward compat with existing callers
        embedding_dim,
        max_norm: Optional[float] = None,
        requires_grad: bool = False,
    ):
        """
        Args:
            n_entiy_embeddings: Number of entity identifiers (table rows).
            embedding_dim: Dimension of each embedding vector.
            max_norm: If set, embeddings are renormalized to this max norm
                on lookup (passed through to ``nn.Embedding``).
            requires_grad: Whether the embedding table is trainable.
                Defaults to False (frozen identifiers).
        """
        super().__init__()
        self.n_entity_embeddings = n_entiy_embeddings
        self.embedding_dim = embedding_dim
        self.max_norm = max_norm

        self.embedding = nn.Embedding(n_entiy_embeddings, embedding_dim, max_norm=max_norm)
        # Orthogonal rows -> mutually decorrelated identifier vectors.
        init.orthogonal_(self.embedding.weight)
        self.embedding.weight.requires_grad = requires_grad

    def forward(self, entities):
        """Look up embeddings for integer entity IDs.

        Args:
            entities: LongTensor of entity IDs, any shape.

        Returns:
            Tensor of shape ``entities.shape + (embedding_dim,)``.
        """
        return self.embedding(entities)


class EntityEmbeddingFactorized(nn.Module):
    """
    Square-root factorized entity embeddings for more frequent updates.

    Instead of N separate embeddings, we decompose entity ID into:
        id = base * sqrt_n + offset
    where base and offset each come from a smaller embedding table of size
    sqrt(N). The final embedding is the sum (or concatenation) of base and
    offset embeddings.

    This ensures each embedding vector gets updated more frequently since
    multiple entity IDs share the same base or offset components.

    Args:
        n_entity_embeddings: Size of the identifier pool (e.g., 512)
        embedding_dim: Dimension of the output embedding
        max_norm: Max norm for embedding vectors
        requires_grad: Whether embeddings are trainable
        combine: How to combine base and offset embeddings ('sum' or 'concat')

    Raises:
        ValueError: If ``combine`` is not ``'sum'`` or ``'concat'``.
    """

    def __init__(
        self,
        n_entiy_embeddings: int,  # NOTE: misspelling kept for backward compat with existing callers
        embedding_dim: int,
        max_norm: Optional[float] = None,
        requires_grad: bool = True,
        combine: str = 'sum',
    ):
        super().__init__()
        # Fail fast on an unknown combine mode instead of silently
        # falling back to 'sum' (the original behavior for bad values).
        if combine not in ('sum', 'concat'):
            raise ValueError(f"combine must be 'sum' or 'concat', got {combine!r}")

        self.n_entity_embeddings = n_entiy_embeddings
        self.embedding_dim = embedding_dim
        self.max_norm = max_norm
        self.combine = combine

        # Compute sqrt factorization size. ceil guarantees
        # sqrt_n * sqrt_n >= N, so every ID in [0, N) decomposes into
        # base, offset each in [0, sqrt_n).
        self.sqrt_n = math.ceil(math.sqrt(n_entiy_embeddings))

        # For 'concat', each sub-embedding is half the dimension
        # For 'sum', each sub-embedding is full dimension
        if combine == 'concat':
            assert embedding_dim % 2 == 0, "embedding_dim must be even for concat mode"
            sub_dim = embedding_dim // 2
        else:
            sub_dim = embedding_dim
        self.sub_dim = sub_dim

        # Base embedding (quotient part): id // sqrt_n
        self.base_embedding = nn.Embedding(self.sqrt_n, sub_dim, max_norm=max_norm)
        init.orthogonal_(self.base_embedding.weight)
        self.base_embedding.weight.requires_grad = requires_grad

        # Offset embedding (remainder part): id % sqrt_n
        self.offset_embedding = nn.Embedding(self.sqrt_n, sub_dim, max_norm=max_norm)
        init.orthogonal_(self.offset_embedding.weight)
        self.offset_embedding.weight.requires_grad = requires_grad

    def forward(self, entities):
        """
        Args:
            entities: (batch, num_entities) tensor of entity IDs

        Returns:
            embeddings: (batch, num_entities, embedding_dim) — sum or
            concatenation of the base and offset sub-embeddings (no
            additional normalization is applied).
        """
        # Decompose entity IDs into base and offset
        base_ids = entities // self.sqrt_n   # (B, N)
        offset_ids = entities % self.sqrt_n  # (B, N)

        # Look up embeddings
        base_emb = self.base_embedding(base_ids)      # (B, N, sub_dim)
        offset_emb = self.offset_embedding(offset_ids)  # (B, N, sub_dim)

        # Combine embeddings
        if self.combine == 'concat':
            combined = torch.cat([base_emb, offset_emb], dim=-1)  # (B, N, embedding_dim)
        else:  # sum
            combined = base_emb + offset_emb  # (B, N, embedding_dim)

        return combined