# NOTE: removed scrape artifact ("Spaces: / Sleeping" Hugging Face status banner — not part of the module).
# EntityEmbeddingOrthogonal class adapted from LaM-SLidE
# https://github.com/ml-jku/LaM-SLidE/blob/main/src/modules/entity_embeddings.py
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.init as init
class EntityEmbeddingOrthogonal(nn.Module):
    """Entity-ID embedding table with orthogonally initialized rows.

    Each integer entity ID maps to one row of an ``nn.Embedding`` whose
    weight matrix is initialized with ``init.orthogonal_``; by default the
    table is frozen (``requires_grad=False``).
    """

    # Class-level annotations mirroring the instance attributes set in __init__.
    n_entity_embeddings: int
    embedding_dim: int
    max_norm: Optional[float] = None
    requires_grad: bool = False

    def __init__(
        self,
        n_entiy_embeddings,  # NOTE: historical typo kept — renaming would break keyword callers
        embedding_dim,
        max_norm: Optional[float] = None,
        requires_grad: bool = False,
    ):
        super().__init__()
        self.n_entity_embeddings = n_entiy_embeddings
        self.embedding_dim = embedding_dim
        self.max_norm = max_norm
        self.embedding = nn.Embedding(
            n_entiy_embeddings, embedding_dim, max_norm=max_norm
        )
        # Orthogonal init first, then freeze/unfreeze the table.
        init.orthogonal_(self.embedding.weight)
        self.embedding.weight.requires_grad = requires_grad

    def forward(self, entities):
        """Look up embeddings for a tensor of integer entity IDs."""
        return self.embedding(entities)
class EntityEmbeddingFactorized(nn.Module):
    """Square-root factorized entity embeddings for more frequent updates.

    Instead of N separate embeddings, an entity ID is decomposed as

        id = base * sqrt_n + offset   (base = id // sqrt_n, offset = id % sqrt_n)

    where ``base`` and ``offset`` each index a smaller embedding table of
    size ``ceil(sqrt(N))``. The final embedding is the sum (or
    concatenation) of the base and offset embeddings. Because many entity
    IDs share a base or offset row, each row is updated more frequently
    than a row of a plain N-entry table would be.

    Args:
        n_entiy_embeddings: Size of the identifier pool, N (e.g., 512).
            (The parameter name keeps its historical typo so existing
            keyword callers keep working.)
        embedding_dim: Dimension of the output embedding.
        max_norm: Optional max norm enforced by the sub-embedding tables.
        requires_grad: Whether the embedding tables are trainable.
        combine: How to merge the two sub-embeddings: ``'sum'`` or ``'concat'``.

    Raises:
        ValueError: If ``combine`` is not ``'sum'`` or ``'concat'``, or if
            ``embedding_dim`` is odd in ``'concat'`` mode.
    """

    def __init__(
        self,
        n_entiy_embeddings: int,
        embedding_dim: int,
        max_norm: Optional[float] = None,
        requires_grad: bool = True,
        combine: str = 'sum',
    ):
        super().__init__()
        # Fail fast on an unknown combine mode instead of silently summing.
        if combine not in ('sum', 'concat'):
            raise ValueError(f"combine must be 'sum' or 'concat', got {combine!r}")
        self.n_entity_embeddings = n_entiy_embeddings
        self.embedding_dim = embedding_dim
        self.max_norm = max_norm
        self.combine = combine

        # ceil(sqrt(N)) guarantees base = id // sqrt_n < sqrt_n for every
        # id in [0, N), so both sub-tables are always indexed in range.
        self.sqrt_n = math.ceil(math.sqrt(n_entiy_embeddings))

        # 'concat' splits the output dimension between the two tables;
        # 'sum' gives each table the full output dimension.
        if combine == 'concat':
            # Raise (not assert) so the check survives `python -O`.
            if embedding_dim % 2 != 0:
                raise ValueError("embedding_dim must be even for concat mode")
            sub_dim = embedding_dim // 2
        else:
            sub_dim = embedding_dim
        self.sub_dim = sub_dim

        # Base table (quotient part): indexed by id // sqrt_n.
        self.base_embedding = nn.Embedding(self.sqrt_n, sub_dim, max_norm=max_norm)
        init.orthogonal_(self.base_embedding.weight)
        self.base_embedding.weight.requires_grad = requires_grad

        # Offset table (remainder part): indexed by id % sqrt_n.
        self.offset_embedding = nn.Embedding(self.sqrt_n, sub_dim, max_norm=max_norm)
        init.orthogonal_(self.offset_embedding.weight)
        self.offset_embedding.weight.requires_grad = requires_grad

    def forward(self, entities):
        """Embed integer entity IDs.

        Args:
            entities: (batch, num_entities) tensor of entity IDs in
                ``[0, n_entity_embeddings)``.

        Returns:
            (batch, num_entities, embedding_dim) tensor: the sum or
            concatenation of the base and offset sub-embeddings.
            Note: the output is NOT normalized.
        """
        # Decompose entity IDs into quotient/remainder w.r.t. sqrt_n.
        base_ids = entities // self.sqrt_n   # (B, N)
        offset_ids = entities % self.sqrt_n  # (B, N)

        base_emb = self.base_embedding(base_ids)        # (B, N, sub_dim)
        offset_emb = self.offset_embedding(offset_ids)  # (B, N, sub_dim)

        if self.combine == 'concat':
            return torch.cat([base_emb, offset_emb], dim=-1)  # (B, N, embedding_dim)
        return base_emb + offset_emb  # (B, N, embedding_dim)