"""Structured-feature embedding layer for transaction sequences. Replaces the single text-token embedding table of public LFM2.5 with per-feature value tables plus a feature-type table, summed. Each feature has its own embedding table sized to that feature's vocabulary. A shared feature-type table (15 rows) tells the model which feature a token represents. Input: (B, T, F) int tensor of token IDs, where T=64 transactions, F=15 features Output: (B, T*F, D) float tensor of embeddings, where D=hidden_dim """ import torch import torch.nn as nn from src.data.schema import SchemaConfig class StructuredEmbedding(nn.Module): """Per-feature value embeddings + feature-type embeddings, summed. The value_tables are exposed as a ModuleList so the per-feature LM heads can tie weights to them. The type_table is NOT tied to anything. """ def __init__(self, schema: SchemaConfig, hidden_dim: int) -> None: super().__init__() self.num_features = schema.num_features self.num_transactions = schema.num_transactions self.hidden_dim = hidden_dim self.value_tables = nn.ModuleList([ nn.Embedding(feature.vocab_size, hidden_dim) for feature in schema.features ]) self.type_table = nn.Embedding(schema.num_features, hidden_dim) self._vocab_sizes = [f.vocab_size for f in schema.features] def forward(self, token_ids: torch.Tensor) -> torch.Tensor: """Embed structured token IDs into a flat sequence. Args: token_ids: (B, T, F) int tensor. T=num_transactions, F=num_features. Each value must be in [0, vocab_size) for its feature. Returns: (B, T*F, D) float tensor of summed value + type embeddings. """ B, T, F = token_ids.shape assert F == self.num_features, ( f"Expected {self.num_features} features, got {F}" ) # Feature type indices: [0, 1, 2, ..., F-1], broadcast across batch and time type_indices = torch.arange(F, device=token_ids.device) # (F,) type_emb = self.type_table(type_indices) # (F, D) feature_embeddings = [] for f_idx in range(F): feat_tokens = token_ids[:, :, f_idx] # (B, T) val_emb = self.value_tables[f_idx](feat_tokens) # (B, T, D) feature_embeddings.append(val_emb + type_emb[f_idx]) # (B, T, D) + (D,) broadcast # (B, T, F, D) -> (B, T*F, D) stacked = torch.stack(feature_embeddings, dim=2) # (B, T, F, D) return stacked.reshape(B, T * F, self.hidden_dim) # (B, 960, D)