| """Embedding layers for SwipeTransformer.""" |
|
| import torch |
| import torch.nn as nn |
|
|
| class PathEmbedding(nn.Module): |
    """Projects (x, y, t) path coordinates to d_model-dimensional embeddings."""
|
| def __init__(self, d_model: int = 256): |
| super().__init__() |
| self.projection = nn.Linear(3, d_model) |
|
| def forward(self, path_coords: torch.Tensor) -> torch.Tensor: |
| """ |
| Args: |
| path_coords: [batch, seq_len, 3] - (x, y, t) coordinates |
| |
| Returns: |
| [batch, seq_len, d_model] embeddings |
| """ |
| return self.projection(path_coords) |
|
|
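# Usage sketch (shapes below are illustrative, not prescribed by the model
# config):
#
#     emb = PathEmbedding(d_model=256)
#     coords = torch.rand(4, 32, 3)   # 4 swipes, 32 (x, y, t) points each
#     out = emb(coords)               # -> [4, 32, 256]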
|
|
| class CharacterEmbedding(nn.Module): |
| """Embeds character tokens.""" |
|
| def __init__(self, vocab_size: int, d_model: int = 256, padding_idx: int = 0): |
| super().__init__() |
| self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx) |
|
| def forward(self, char_tokens: torch.Tensor) -> torch.Tensor: |
| """ |
| Args: |
| char_tokens: [batch, seq_len] character token IDs |
| |
| Returns: |
| [batch, seq_len, d_model] embeddings |
| """ |
| return self.embedding(char_tokens) |
|
|
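# Usage sketch: with padding_idx=0, the embedding row for token 0 is zero and
# receives no gradient updates, so PAD positions contribute a constant zero
# vector. The vocabulary size and token IDs below are illustrative:
#
#     emb = CharacterEmbedding(vocab_size=100, d_model=256)
#     tokens = torch.tensor([[5, 9, 0, 0]])   # two characters, two PADs
#     out = emb(tokens)                       # -> [1, 4, 256]; out[0, 2:] == 0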
|
|
| class PositionalEmbedding(nn.Module): |
| """Learned positional embeddings.""" |
|
| def __init__(self, max_position: int, d_model: int = 256): |
| super().__init__() |
| self.embedding = nn.Embedding(max_position, d_model) |
|
| def forward(self, positions: torch.Tensor) -> torch.Tensor: |
| """ |
| Args: |
| positions: [batch, seq_len] position indices |
| |
| Returns: |
| [batch, seq_len, d_model] positional embeddings |
| """ |
| return self.embedding(positions) |
|
|
| class TypeEmbedding(nn.Module): |
| """Token type embeddings to distinguish PATH (0) vs TEXT (1) tokens.""" |
|
| def __init__(self, d_model: int = 256): |
| super().__init__() |
        # Two token types: PATH (0) and TEXT (1)
| self.embedding = nn.Embedding(2, d_model) |
|
| def forward(self, token_types: torch.Tensor) -> torch.Tensor: |
| """ |
| Args: |
| token_types: [batch, seq_len] type indices (0 or 1) |
| |
| Returns: |
| [batch, seq_len, d_model] type embeddings |
| """ |
| return self.embedding(token_types) |
|
|
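# Usage sketch: MixedEmbedding (below) types [CLS], path points, and [SEP] as
# PATH (0) and character tokens as TEXT (1). An illustrative call:
#
#     emb = TypeEmbedding(d_model=256)
#     types = torch.tensor([[0, 0, 0, 0, 1, 1]])   # [CLS] + 2 path + [SEP], 2 chars
#     out = emb(types)                             # -> [1, 6, 256]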
|
|
| class MixedEmbedding(nn.Module): |
    """
    Combines path and character embeddings with positional and type information.
    Constructs the sequence [CLS] + path_tokens + [SEP] + char_tokens.
    """
|
| def __init__( |
| self, |
| vocab_size: int, |
| max_path_len: int, |
| max_char_len: int, |
| d_model: int = 256, |
| dropout: float = 0.1, |
| ): |
| super().__init__() |
| self.d_model = d_model |

        # Content embeddings for path points and character tokens
| self.path_embedding = PathEmbedding(d_model) |
| self.char_embedding = CharacterEmbedding(vocab_size, d_model, padding_idx=0) |

        # Positions cover the full sequence: [CLS] + path + [SEP] + chars
| max_seq_len = 1 + max_path_len + 1 + max_char_len |
| self.positional_embedding = PositionalEmbedding(max_seq_len, d_model) |

        # Token type embeddings distinguish PATH from TEXT positions
| self.type_embedding = TypeEmbedding(d_model) |

        # Normalize and regularize the summed embeddings
| self.layer_norm = nn.LayerNorm(d_model) |
| self.dropout = nn.Dropout(dropout) |
|
| def forward( |
| self, |
| path_coords: torch.Tensor, |
| char_tokens: torch.Tensor, |
| cls_token: torch.Tensor, |
| sep_token: torch.Tensor, |
| ) -> torch.Tensor: |
| """ |
| Create mixed sequence with embeddings. |
| |
| Args: |
| path_coords: [batch, path_len, 3] path coordinates |
| char_tokens: [batch, char_len] character token IDs |
| cls_token: [batch, 1] CLS token IDs |
| sep_token: [batch, 1] SEP token IDs |
| |
| Returns: |
| [batch, total_seq_len, d_model] embeddings where |
| total_seq_len = 1 + path_len + 1 + char_len |
| """ |
| batch_size = path_coords.shape[0] |
| path_len = path_coords.shape[1] |
| char_len = char_tokens.shape[1] |
| device = path_coords.device |

        # Embed [CLS] through the character embedding table
| cls_emb = self.char_embedding(cls_token) |

        # Project (x, y, t) coordinates to d_model
| path_emb = self.path_embedding(path_coords) |

        # Embed [SEP] through the character embedding table
| sep_emb = self.char_embedding(sep_token) |

        # Embed the character tokens
| char_emb = self.char_embedding(char_tokens) |

        # Assemble [CLS] + path + [SEP] + chars along the sequence dimension
        sequence = torch.cat([cls_emb, path_emb, sep_emb, char_emb], dim=1)
| seq_len = sequence.shape[1] |

        # Learned positional embeddings over the full sequence
| positions = ( |
| torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1) |
| ) |
| pos_emb = self.positional_embedding(positions) |

        # Type IDs: [CLS], path, and [SEP] are PATH (0); characters are TEXT (1)
| type_ids = torch.cat( |
| [ |
| torch.zeros( |
| batch_size, 1 + path_len + 1, dtype=torch.long, device=device |
| ), |
| torch.ones(batch_size, char_len, dtype=torch.long, device=device), |
| ], |
| dim=1, |
| ) |
| type_emb = self.type_embedding(type_ids) |

        # Sum content, positional, and type embeddings; normalize; apply dropout
| embeddings = sequence + pos_emb + type_emb |
| embeddings = self.layer_norm(embeddings) |
| embeddings = self.dropout(embeddings) |
|
| return embeddings |
|
|
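# Minimal smoke test. The vocabulary size, sequence lengths, and the CLS/SEP
# token IDs below are illustrative assumptions, not values prescribed by the
# SwipeTransformer training setup.
if __name__ == "__main__":
    batch_size, path_len, char_len = 4, 32, 16
    vocab_size = 100

    model = MixedEmbedding(
        vocab_size=vocab_size,
        max_path_len=path_len,
        max_char_len=char_len,
        d_model=256,
    )

    path_coords = torch.rand(batch_size, path_len, 3)  # (x, y, t) in [0, 1)
    char_tokens = torch.randint(1, vocab_size, (batch_size, char_len))
    cls_token = torch.full((batch_size, 1), 1, dtype=torch.long)  # assumed [CLS] ID
    sep_token = torch.full((batch_size, 1), 2, dtype=torch.long)  # assumed [SEP] ID

    out = model(path_coords, char_tokens, cls_token, sep_token)
    expected = (batch_size, 1 + path_len + 1 + char_len, 256)
    assert out.shape == expected, f"{tuple(out.shape)} != {expected}"
    print(f"MixedEmbedding output shape: {tuple(out.shape)}")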