import torch
import torch.nn as nn

from core.config import EMBEDDING_DIM, HIDDEN_DIM
from core.device import DEVICE


class SentenceEncoder(nn.Module):
    """Encodes token IDs into a fixed-size sentence embedding via masked mean pooling."""

    def __init__(self, vocab_size: int):
        super().__init__()
        self.embedding_layer = nn.Embedding(vocab_size, EMBEDDING_DIM)
        self.projection = nn.Linear(EMBEDDING_DIM, HIDDEN_DIM)
        self.activation = nn.ReLU()
        # A single module-level .to(DEVICE) moves every submodule; per-layer
        # .to(DEVICE) calls would be redundant.
        self.to(DEVICE)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        input_ids = input_ids.long().to(DEVICE)
        attention_mask = attention_mask.float().to(DEVICE)

        embeddings = self.embedding_layer(input_ids)  # (B, T, EMBEDDING_DIM)

        # Broadcast the mask over the embedding dimension so padding tokens
        # contribute nothing to the pooled representation.
        mask = attention_mask.unsqueeze(-1)  # (B, T, 1)
        masked_embeddings = embeddings * mask

        # Mean-pool over the time dimension, counting only real tokens;
        # clamp guards against division by zero for fully padded rows.
        sum_embeddings = masked_embeddings.sum(dim=1)  # (B, EMBEDDING_DIM)
        token_count = mask.sum(dim=1).clamp(min=1)     # (B, 1)
        pooled = sum_embeddings / token_count

        return self.activation(self.projection(pooled))  # (B, HIDDEN_DIM)
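

# Usage sketch: builds an encoder and runs a forward pass on a toy batch.
# The vocab size, sequence length, token IDs, and the 0-as-padding convention
# below are illustrative assumptions, not values fixed by the module above.
if __name__ == "__main__":
    encoder = SentenceEncoder(vocab_size=10_000)

    # Two sequences of length 5; the second row has two padding positions,
    # which the attention mask zeroes out of the mean pool.
    input_ids = torch.tensor([
        [12, 845, 3, 99, 7],
        [54, 2, 301, 0, 0],
    ])
    attention_mask = torch.tensor([
        [1, 1, 1, 1, 1],
        [1, 1, 1, 0, 0],
    ])

    sentence_embeddings = encoder(input_ids, attention_mask)
    print(sentence_embeddings.shape)  # torch.Size([2, HIDDEN_DIM])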