from __future__ import annotations import torch import torch.nn as nn from yaml_bert.config import TreePosVariant, YamlBertConfig class YamlBertEmbedding(nn.Module): """v9 embedding layer: single subword table + tree positional encoding. Produces input vectors by summing: - Subword embedding (looked up by token_id; same table for KEY and VALUE positions — what they ARE is signalled separately via node_type_emb) - Tree positional encoding (composition depends on config.tree_pos_variant) """ def __init__( self, config: YamlBertConfig, subword_vocab_size: int, ) -> None: super().__init__() d: int = config.d_model variant: TreePosVariant = config.tree_pos_variant self.variant: TreePosVariant = variant self.subword_embedding: nn.Embedding = nn.Embedding(subword_vocab_size, d) self.node_type_embedding: nn.Embedding = nn.Embedding(4, d) use_depth: bool = variant in (TreePosVariant.FULL, TreePosVariant.NO_SIBLING) use_sibling: bool = variant in (TreePosVariant.FULL, TreePosVariant.NO_DEPTH) use_seq_pos: bool = variant == TreePosVariant.SEQUENTIAL self.depth_embedding: nn.Embedding | None = ( nn.Embedding(config.max_depth, d) if use_depth else None ) self.sibling_embedding: nn.Embedding | None = ( nn.Embedding(config.max_sibling, d) if use_sibling else None ) self.pos_embedding: nn.Embedding | None = ( nn.Embedding(config.max_seq_len, d) if use_seq_pos else None ) self.layer_norm: nn.LayerNorm = nn.LayerNorm(d) def forward( self, token_ids: torch.Tensor, node_types: torch.Tensor, depths: torch.Tensor, sibling_indices: torch.Tensor, ) -> torch.Tensor: token_emb = self.subword_embedding(token_ids) tree_pos = self.node_type_embedding(node_types) if self.depth_embedding is not None: tree_pos = tree_pos + self.depth_embedding(depths) if self.sibling_embedding is not None: tree_pos = tree_pos + self.sibling_embedding(sibling_indices) if self.pos_embedding is not None: seq_len: int = token_ids.size(1) max_pos: int = self.pos_embedding.num_embeddings positions = ( torch.arange(seq_len, device=token_ids.device) .clamp(max=max_pos - 1) .unsqueeze(0) .expand(token_ids.size(0), seq_len) ) tree_pos = tree_pos + self.pos_embedding(positions) return self.layer_norm(token_emb + tree_pos)