File size: 2,637 Bytes
083b138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""Structured-feature embedding layer for transaction sequences.

Replaces the single text-token embedding table of public LFM2.5 with
per-feature value tables plus a feature-type table, summed. Each feature
has its own embedding table sized to that feature's vocabulary. A shared
feature-type table (15 rows) tells the model which feature a token represents.

Input:  (B, T, F) int tensor of token IDs, where T=64 transactions, F=15 features
Output: (B, T*F, D) float tensor of embeddings, where D=hidden_dim
"""

import torch
import torch.nn as nn

from src.data.schema import SchemaConfig


class StructuredEmbedding(nn.Module):
    """Per-feature value embeddings + feature-type embeddings, summed.

    Each feature in ``schema.features`` gets its own value table of shape
    ``(feature.vocab_size, hidden_dim)``. A single type table of shape
    ``(schema.num_features, hidden_dim)`` supplies one vector per feature
    position. For every token, its feature's value embedding and that
    feature's type embedding are summed. The (transaction, feature) grid is
    then flattened into one sequence.

    The value_tables are exposed as a ModuleList so the per-feature LM heads
    can tie weights to them. The type_table is NOT tied to anything.
    """

    def __init__(self, schema: SchemaConfig, hidden_dim: int) -> None:
        """Build one value table per schema feature plus a shared type table.

        Args:
            schema: Feature schema. This class reads ``num_features``,
                ``num_transactions`` and ``features``; each feature must
                expose ``vocab_size``.
            hidden_dim: Embedding width D shared by every table.

        Raises:
            ValueError: If ``len(schema.features)`` disagrees with
                ``schema.num_features``. The value tables are indexed by
                feature position, while the type table is sized by
                ``num_features``.
        """
        super().__init__()
        self.num_features = schema.num_features
        self.num_transactions = schema.num_transactions
        self.hidden_dim = hidden_dim

        vocab_sizes = [feature.vocab_size for feature in schema.features]
        if len(vocab_sizes) != self.num_features:
            raise ValueError(
                f"Schema lists {len(vocab_sizes)} features but num_features "
                f"is {self.num_features}"
            )

        # One table per feature, each sized to that feature's own vocabulary.
        self.value_tables = nn.ModuleList([
            nn.Embedding(vocab_size, hidden_dim) for vocab_size in vocab_sizes
        ])

        # One row per feature position, shared across batch and time.
        self.type_table = nn.Embedding(schema.num_features, hidden_dim)

        # Not read by forward(); kept as an attribute for other code.
        self._vocab_sizes = vocab_sizes

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Embed structured token IDs into a flat sequence.

        Args:
            token_ids: (B, T, F) integer tensor. F must equal
                ``num_features``. T is not checked against
                ``num_transactions`` here. Values in feature slot f are
                expected to lie in [0, vocab_size_f). They are not
                range-checked here; that is left to the ``nn.Embedding``
                lookup.

        Returns:
            (B, T*F, D) float tensor of summed value + type embeddings.
            Sequence position ``t * F + f`` holds transaction t, feature f.

        Raises:
            ValueError: If ``token_ids`` is not 3-D, or its last dimension
                differs from ``num_features``.
        """
        # Use explicit exceptions rather than `assert`: asserts are removed
        # under `python -O`, which would let a malformed input through.
        if token_ids.dim() != 3:
            raise ValueError(
                f"Expected token_ids of shape (B, T, F), got {tuple(token_ids.shape)}"
            )
        B, T, F = token_ids.shape
        if F != self.num_features:
            raise ValueError(f"Expected {self.num_features} features, got {F}")

        # Type vectors for feature positions 0..F-1: (F, D).
        type_emb = self.type_table(torch.arange(F, device=token_ids.device))

        # Look up each feature column in its own table: F tensors of (B, T, D).
        value_embs = [
            self.value_tables[f_idx](token_ids[:, :, f_idx]) for f_idx in range(F)
        ]

        # (B, T, F, D) + (F, D): the type vector broadcasts over batch and time.
        summed = torch.stack(value_embs, dim=2) + type_emb

        # Row-major flatten (B, T, F, D) -> (B, T*F, D).
        return summed.reshape(B, T * F, self.hidden_dim)