"""Pooling strategies for Ogma.""" from __future__ import annotations import torch import torch.nn as nn import torch.nn.functional as F from .config import OgmaConfig, PoolingType __all__ = [ "create_pooling", "TaskTokenPooling", "LatentAttentionPooling", "MeanPooling", ] def create_pooling(config: OgmaConfig) -> nn.Module: """Factory for pooling layers.""" if config.pooling == PoolingType.TASK_TOKEN: return TaskTokenPooling() elif config.pooling == PoolingType.LATENT_ATTENTION: return LatentAttentionPooling(config.d_model) elif config.pooling == PoolingType.MEAN: return MeanPooling() raise ValueError(f"Unknown pooling type: {config.pooling}") class TaskTokenPooling(nn.Module): """Use the output at position 0 (task token) as the sentence embedding.""" def forward( self, x: torch.Tensor, attention_mask: torch.Tensor | None = None, ) -> torch.Tensor: """Extract task token output. Args: x: (B, S, D) sequence outputs. attention_mask: unused, for interface compatibility. Returns: (B, D) pooled output. """ return x[:, 0, :] class LatentAttentionPooling(nn.Module): """Learned query vector attends over all token outputs.""" def __init__(self, d_model: int) -> None: super().__init__() self.query = nn.Parameter(torch.randn(d_model)) def forward( self, x: torch.Tensor, attention_mask: torch.Tensor | None = None, ) -> torch.Tensor: """Attend over sequence with learned query. Args: x: (B, S, D) sequence outputs. attention_mask: (B, S) mask where 1=valid, 0=pad. Returns: (B, D) pooled output. """ # (B, S) scores = torch.matmul(x, self.query) / (x.shape[-1] ** 0.5) if attention_mask is not None: scores = scores.masked_fill(attention_mask == 0, float("-inf")) weights = F.softmax(scores, dim=-1) # (B, S) return torch.bmm(weights.unsqueeze(1), x).squeeze(1) # (B, D) class MeanPooling(nn.Module): """Average all token outputs (excluding padding).""" def forward( self, x: torch.Tensor, attention_mask: torch.Tensor | None = None, ) -> torch.Tensor: """Mean pool over valid tokens. Args: x: (B, S, D) sequence outputs. attention_mask: (B, S) mask where 1=valid, 0=pad. Returns: (B, D) pooled output. """ if attention_mask is None: return x.mean(dim=1) mask = attention_mask.unsqueeze(-1).float() # (B, S, 1) return (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)