"""Pooling strategies for Ogma."""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from .config import OgmaConfig, PoolingType
__all__ = [
"create_pooling",
"TaskTokenPooling",
"LatentAttentionPooling",
"MeanPooling",
]


def create_pooling(config: OgmaConfig) -> nn.Module:
    """Factory for pooling layers."""
    if config.pooling == PoolingType.TASK_TOKEN:
        return TaskTokenPooling()
    elif config.pooling == PoolingType.LATENT_ATTENTION:
        return LatentAttentionPooling(config.d_model)
    elif config.pooling == PoolingType.MEAN:
        return MeanPooling()
    raise ValueError(f"Unknown pooling type: {config.pooling}")
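

# Illustrative usage of the factory. The exact OgmaConfig constructor is
# defined in .config and not shown here; ``config``, ``hidden_states`` and
# ``attention_mask`` are placeholders for the model config, encoder outputs,
# and padding mask:
#
#     pooler = create_pooling(config)                     # nn.Module
#     embeddings = pooler(hidden_states, attention_mask)  # (B, S, D) -> (B, D)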


class TaskTokenPooling(nn.Module):
    """Use the output at position 0 (task token) as the sentence embedding."""

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Extract task token output.

        Args:
            x: (B, S, D) sequence outputs.
            attention_mask: unused, for interface compatibility.

        Returns:
            (B, D) pooled output.
        """
        return x[:, 0, :]


class LatentAttentionPooling(nn.Module):
    """Learned query vector attends over all token outputs."""

    def __init__(self, d_model: int) -> None:
        super().__init__()
        self.query = nn.Parameter(torch.randn(d_model))

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Attend over sequence with learned query.

        Args:
            x: (B, S, D) sequence outputs.
            attention_mask: (B, S) mask where 1=valid, 0=pad.

        Returns:
            (B, D) pooled output.
        """
        # Scaled dot product between the learned query and each token: (B, S)
        scores = torch.matmul(x, self.query) / (x.shape[-1] ** 0.5)
        if attention_mask is not None:
            # Padded positions get -inf so they receive zero attention weight.
            scores = scores.masked_fill(attention_mask == 0, float("-inf"))
        weights = F.softmax(scores, dim=-1)  # (B, S)
        return torch.bmm(weights.unsqueeze(1), x).squeeze(1)  # (B, D)


class MeanPooling(nn.Module):
    """Average all token outputs (excluding padding)."""

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Mean pool over valid tokens.

        Args:
            x: (B, S, D) sequence outputs.
            attention_mask: (B, S) mask where 1=valid, 0=pad.

        Returns:
            (B, D) pooled output.
        """
        if attention_mask is None:
            return x.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).float()  # (B, S, 1)
        # Sum over valid positions only; the clamp guards against division by
        # zero for fully padded rows.
        return (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
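

# Quick smoke test: run each pooling strategy on random tensors and check the
# (B, D) output shape. The batch size, sequence length, and width below are
# arbitrary illustrative values, not tied to any particular Ogma configuration;
# invoke as a module (``python -m ...``) so the relative import above resolves.
if __name__ == "__main__":
    B, S, D = 2, 7, 32
    x = torch.randn(B, S, D)
    attention_mask = torch.ones(B, S, dtype=torch.long)
    attention_mask[:, 5:] = 0  # treat the last two positions as padding

    for pooler in (TaskTokenPooling(), LatentAttentionPooling(D), MeanPooling()):
        pooled = pooler(x, attention_mask)
        assert pooled.shape == (B, D)
        print(f"{type(pooler).__name__}: {tuple(pooled.shape)}")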