cdotsanghvi's picture
add multi-head demo as 4th-6th tabs; restore Why Liquid + Integration
083b138
Raw
History Blame Contribute Delete
2.64 kB
"""Structured-feature embedding layer for transaction sequences.
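Replaces the single text-token embedding table of public LFM2.5 with
per-feature value tables plus a feature-type table, whose outputs are summed.
Each feature has its own embedding table sized to that feature's vocabulary.
A shared feature-type table (one row per feature; 15 rows for the current
schema) tells the model which feature a token represents.
Input: (B, T, F) int tensor of token IDs, where T is the number of transactions
(64 in the current schema) and F is the number of features (15).
Output: (B, T*F, D) float tensor of embeddings, where D=hidden_dim; tokens are
ordered transaction-major (all F features of transaction 0, then transaction 1, ...).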
"""
import torch
import torch.nn as nn
from src.data.schema import SchemaConfig
class StructuredEmbedding(nn.Module):
    """Per-feature value embeddings + feature-type embeddings, summed.

    The value_tables are exposed as a ModuleList so the per-feature LM heads
    can tie weights to them. The type_table is NOT tied to anything.

    Args:
        schema: Schema object read for ``num_features``, ``num_transactions``
            and a ``features`` sequence whose items expose ``vocab_size``.
        hidden_dim: Embedding width D shared by every table.
    """

    def __init__(self, schema: "SchemaConfig", hidden_dim: int) -> None:
        super().__init__()
        self.num_features = schema.num_features
        self.num_transactions = schema.num_transactions
        self.hidden_dim = hidden_dim
        # One value table per feature, each sized to that feature's vocabulary.
        self.value_tables = nn.ModuleList([
            nn.Embedding(feature.vocab_size, hidden_dim)
            for feature in schema.features
        ])
        # Shared feature-type table: row f is added to every token of feature f.
        self.type_table = nn.Embedding(schema.num_features, hidden_dim)
        self._vocab_sizes = [f.vocab_size for f in schema.features]

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Embed structured token IDs into a flat sequence.

        Args:
            token_ids: (B, T, F) int tensor. T=num_transactions, F=num_features.
                Each value must be in [0, vocab_size) for its feature.

        Returns:
            (B, T*F, D) float tensor of summed value + type embeddings, ordered
            transaction-major: position t*F + f holds feature f of transaction t.

        Raises:
            ValueError: if ``token_ids`` is not 3-D, or its last dimension does
                not equal ``num_features``.
        """
        # Explicit checks instead of `assert`, which is stripped under `python -O`
        # and would otherwise let a short feature axis embed silently.
        if token_ids.dim() != 3:
            raise ValueError(
                f"Expected token_ids of shape (B, T, F), got {tuple(token_ids.shape)}"
            )
        B, T, F = token_ids.shape
        if F != self.num_features:
            raise ValueError(f"Expected {self.num_features} features, got {F}")

        # Feature type indices: [0, 1, 2, ..., F-1], broadcast across batch and time
        type_indices = torch.arange(F, device=token_ids.device)  # (F,)
        type_emb = self.type_table(type_indices)  # (F, D)

        feature_embeddings = []
        for f_idx in range(F):
            feat_tokens = token_ids[:, :, f_idx]  # (B, T)
            val_emb = self.value_tables[f_idx](feat_tokens)  # (B, T, D)
            feature_embeddings.append(val_emb + type_emb[f_idx])  # (B, T, D) + (D,) broadcast

        # (B, T, F, D) -> (B, T*F, D)
        stacked = torch.stack(feature_embeddings, dim=2)  # (B, T, F, D)
        return stacked.reshape(B, T * F, self.hidden_dim)  # (B, T*F, D)