# yaml-bert / yaml_bert/embedding.py
# feat(v9): sub-tokenization — [UNK] collisions fixed + namespace probe passes + apiVersion probe added
# Commit 3457b3c (verified)
from __future__ import annotations
import torch
import torch.nn as nn
from yaml_bert.config import TreePosVariant, YamlBertConfig
class YamlBertEmbedding(nn.Module):
    """v9 embedding layer: single subword table + tree positional encoding.

    Produces input vectors by summing, then layer-normalising:

    - Subword embedding (looked up by token_id; same table for KEY and VALUE
      positions — what they ARE is signalled separately via the node-type
      embedding).
    - Node-type embedding (always present).
    - Tree positional encoding, whose composition depends on
      ``config.tree_pos_variant``:

      * ``FULL``        -> depth + sibling embeddings
      * ``NO_SIBLING``  -> depth embedding only
      * ``NO_DEPTH``    -> sibling embedding only
      * ``SEQUENTIAL``  -> flat sequence-position embedding only
      * any other variant -> node-type embedding only

    Depth, sibling and sequence-position indices that exceed their table are
    clamped to the table's last row, so over-deep / over-wide / over-long
    inputs share the final embedding instead of raising an index error.
    Negative indices are not remapped.
    """

    def __init__(
        self,
        config: YamlBertConfig,
        subword_vocab_size: int,
    ) -> None:
        """Build the embedding tables selected by ``config.tree_pos_variant``.

        Args:
            config: Supplies ``d_model``, ``tree_pos_variant`` and the table
                sizes ``max_depth``, ``max_sibling`` and ``max_seq_len``.
            subword_vocab_size: Number of rows in the subword table.
        """
        super().__init__()
        d: int = config.d_model
        variant: TreePosVariant = config.tree_pos_variant
        self.variant: TreePosVariant = variant

        self.subword_embedding: nn.Embedding = nn.Embedding(subword_vocab_size, d)
        # Four node-type ids are supported; their meaning is defined by the
        # caller that produces ``node_types`` (not visible in this module).
        self.node_type_embedding: nn.Embedding = nn.Embedding(4, d)

        use_depth: bool = variant in (TreePosVariant.FULL, TreePosVariant.NO_SIBLING)
        use_sibling: bool = variant in (TreePosVariant.FULL, TreePosVariant.NO_DEPTH)
        use_seq_pos: bool = variant == TreePosVariant.SEQUENTIAL

        # Tables that the chosen variant does not use are stored as None so
        # forward() can skip them and no unused parameters are registered.
        self.depth_embedding: nn.Embedding | None = (
            nn.Embedding(config.max_depth, d) if use_depth else None
        )
        self.sibling_embedding: nn.Embedding | None = (
            nn.Embedding(config.max_sibling, d) if use_sibling else None
        )
        self.pos_embedding: nn.Embedding | None = (
            nn.Embedding(config.max_seq_len, d) if use_seq_pos else None
        )
        self.layer_norm: nn.LayerNorm = nn.LayerNorm(d)

    @staticmethod
    def _clamped_lookup(table: nn.Embedding, indices: torch.Tensor) -> torch.Tensor:
        """Look up ``indices`` in ``table``, saturating at the table's last row."""
        return table(indices.clamp(max=table.num_embeddings - 1))

    def forward(
        self,
        token_ids: torch.Tensor,
        node_types: torch.Tensor,
        depths: torch.Tensor,
        sibling_indices: torch.Tensor,
    ) -> torch.Tensor:
        """Embed a batch of tokens and add the configured positional signals.

        Args:
            token_ids: Subword ids. Must be 2-D ``(batch, seq_len)`` when the
                sequential-position table is active, since both dimensions
                are read to build the position ids.
            node_types: Node-type ids; presumably the same shape as
                ``token_ids`` so the embeddings can be summed.
            depths: Tree depth per token; ignored when no depth table exists.
            sibling_indices: Sibling index per token; ignored when no sibling
                table exists.

        Returns:
            Layer-normalised sum of the token and positional embeddings, with
            a trailing dimension of ``d_model``.
        """
        token_emb = self.subword_embedding(token_ids)
        tree_pos = self.node_type_embedding(node_types)

        if self.depth_embedding is not None:
            tree_pos = tree_pos + self._clamped_lookup(self.depth_embedding, depths)
        if self.sibling_embedding is not None:
            tree_pos = tree_pos + self._clamped_lookup(
                self.sibling_embedding, sibling_indices
            )
        if self.pos_embedding is not None:
            batch_size: int = token_ids.size(0)
            seq_len: int = token_ids.size(1)
            # Same 0..seq_len-1 position ids for every row of the batch.
            positions = (
                torch.arange(seq_len, device=token_ids.device)
                .unsqueeze(0)
                .expand(batch_size, seq_len)
            )
            tree_pos = tree_pos + self._clamped_lookup(self.pos_embedding, positions)

        return self.layer_norm(token_emb + tree_pos)