| """Structured-feature embedding layer for transaction sequences. |
| |
| Replaces the single text-token embedding table of public LFM2.5 with |
| per-feature value tables plus a feature-type table, summed. Each feature |
| has its own embedding table sized to that feature's vocabulary. A shared |
| feature-type table (15 rows) tells the model which feature a token represents. |
| |
| Input: (B, T, F) int tensor of token IDs, where T=64 transactions, F=15 features |
| Output: (B, T*F, D) float tensor of embeddings, where D=hidden_dim |
| """ |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from src.data.schema import SchemaConfig |
|
|
|
|
class StructuredEmbedding(nn.Module):
    """Per-feature value embeddings plus feature-type embeddings, summed.

    Every feature in the schema gets its own ``nn.Embedding`` sized to that
    feature's vocabulary; a single shared type table contributes one vector
    per feature index. The value tables are exposed as a ``ModuleList`` so
    the per-feature LM heads can tie weights to them. The type table is NOT
    tied to anything.

    Attributes:
        num_features: ``schema.num_features``; the size F that ``forward``
            requires on the last axis of its input.
        num_transactions: ``schema.num_transactions`` (stored, not enforced
            by ``forward``).
        hidden_dim: Width D of every embedding vector.
        value_tables: One ``nn.Embedding(vocab_size, hidden_dim)`` per entry
            of ``schema.features``, in schema order.
        type_table: ``nn.Embedding(num_features, hidden_dim)`` indexed by
            feature position.
    """

    def __init__(self, schema: SchemaConfig, hidden_dim: int) -> None:
        """Build the value tables and the type table from ``schema``.

        Args:
            schema: Provides ``num_features``, ``num_transactions`` and an
                iterable ``features`` whose items expose ``vocab_size``.
            hidden_dim: Embedding width D shared by all tables.
        """
        super().__init__()
        self.num_features = schema.num_features
        self.num_transactions = schema.num_transactions
        self.hidden_dim = hidden_dim

        self.value_tables = nn.ModuleList([
            nn.Embedding(feature.vocab_size, hidden_dim)
            for feature in schema.features
        ])

        self.type_table = nn.Embedding(schema.num_features, hidden_dim)

        # Per-feature vocabulary sizes, kept as plain ints in schema order.
        self._vocab_sizes = [f.vocab_size for f in schema.features]

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Embed structured token IDs into a flat sequence.

        Args:
            token_ids: (B, T, F) int tensor with F == ``self.num_features``.
                Each value must be in [0, vocab_size) for its feature; the
                range itself is not checked here and is left to
                ``nn.Embedding``.

        Returns:
            (B, T*F, D) float tensor of summed value + type embeddings.
            The vector for transaction ``t`` and feature ``f`` sits at flat
            position ``t * F + f``.

        Raises:
            ValueError: If ``token_ids`` is not 3-dimensional.
            AssertionError: If the last axis does not have
                ``self.num_features`` entries.
        """
        if token_ids.dim() != 3:
            raise ValueError(
                "Expected token_ids of shape (B, T, F), got "
                f"{tuple(token_ids.shape)}"
            )
        batch, steps, feats = token_ids.shape

        # Explicit raise rather than an `assert` statement: asserts are
        # stripped under `python -O`, and a too-narrow input would then
        # silently yield a wrongly shaped output. The exception type is kept
        # so existing callers that catch AssertionError are unaffected.
        if feats != self.num_features:
            raise AssertionError(
                f"Expected {self.num_features} features, got {feats}"
            )

        # (F, D): one type vector per feature position.
        type_emb = self.type_table(
            torch.arange(feats, device=token_ids.device)
        )

        # Each feature column (B, T) is looked up in its own table -> (B, T, D).
        value_emb = [
            self.value_tables[idx](column)
            for idx, column in enumerate(token_ids.unbind(dim=2))
        ]

        # Stack to (B, T, F, D); the (F, D) type table broadcasts over B and T,
        # so a single add replaces one add per feature.
        stacked = torch.stack(value_emb, dim=2) + type_emb
        return stacked.reshape(batch, steps * feats, self.hidden_dim)
|
|