Chaman1234's picture
Upload cat_v3/encoder.py with huggingface_hub
bc5e3b9 verified
Raw
History Blame Contribute Delete
2.11 kB
"""Tiny Encoder module for CAT V3."""
from __future__ import annotations
import torch
import torch.nn as nn
class TinyEncoder(nn.Module):
    """A self-contained lightweight transformer encoder for processing text queries.

    Token IDs are embedded, summed with a learned positional embedding, passed
    through a pre-norm ``nn.TransformerEncoder`` stack, and mean-pooled over the
    non-padded positions into a single vector per sequence.

    Args:
        vocab_size: Number of entries in the token embedding table.
        hidden_size: Width of the embeddings and of the transformer (``d_model``).
        nhead: Number of attention heads in each encoder layer.
        num_layers: Number of stacked encoder layers.
        max_seq_len: Longest sequence the positional embedding can cover.
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int = 128,
        nhead: int = 4,
        num_layers: int = 2,
        max_seq_len: int = 128,
    ) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        # Learned absolute positions; the leading dim of 1 broadcasts over the batch.
        self.pos_embedding = nn.Parameter(torch.randn(1, max_seq_len, hidden_size))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=nhead,
            dim_feedforward=hidden_size * 2,
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.hidden_size = hidden_size

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Encode token sequence to a dense vector.

        Args:
            input_ids: [batch_size, seq_len] tensor of token IDs
            attention_mask: [batch_size, seq_len] mask tensor (1 for active, 0 for pad)

        Returns:
            [batch_size, hidden_size] dense sentence representation. A sequence
            whose mask is entirely 0 yields an all-zero vector.

        Raises:
            ValueError: If ``seq_len`` is larger than the ``max_seq_len`` the
                positional embedding was built with.
        """
        seq_len = input_ids.size(1)
        max_seq_len = self.pos_embedding.size(1)
        if seq_len > max_seq_len:
            # Slicing would silently truncate the positions and the addition
            # below would then fail with an opaque broadcasting error.
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len {max_seq_len}"
            )

        # Add word embeddings and positional encodings
        x = self.embedding(input_ids) + self.pos_embedding[:, :seq_len]

        # PyTorch src_key_padding_mask requires True on padding positions
        padding_mask = attention_mask == 0

        # Run Transformer
        out = self.transformer(x, src_key_padding_mask=padding_mask)

        # Hard-zero padded positions before pooling. Multiplying by a 0 mask is
        # not enough: a fully padded sequence can come out of attention as NaN
        # (depending on the PyTorch version), and NaN * 0 is still NaN.
        out = out.masked_fill(padding_mask.unsqueeze(-1), 0.0)

        # Perform mean pooling over non-padded tokens
        weights = attention_mask.unsqueeze(-1).float()
        summed = torch.sum(out * weights, dim=1)
        # The clamp keeps the division finite when a sequence has no active tokens.
        total = torch.clamp(weights.sum(dim=1), min=1e-9)
        return summed / total