# NOTE(review): removed non-code hosting-UI capture residue ("Spaces: Running Running")
import torch
import torch.nn as nn

from core.config import EMBEDDING_DIM, HIDDEN_DIM
from core.device import DEVICE
| class SentenceEncoder(nn.Module): | |
| def __init__(self, vocab_size: int): | |
| super().__init__() | |
| self.embedding_layer = nn.Embedding( | |
| vocab_size, | |
| EMBEDDING_DIM | |
| ).to(DEVICE) | |
| self.projection = nn.Linear( | |
| EMBEDDING_DIM, | |
| HIDDEN_DIM | |
| ) | |
| self.activation = nn.ReLU() | |
| self.to(DEVICE) | |
| def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor): | |
| input_ids = input_ids.long().to(DEVICE) | |
| attention_mask = attention_mask.float().to(DEVICE) | |
| embeddings = self.embedding_layer(input_ids) # (B, T, EMBEDDING_DIM) | |
| # Correct broadcast mask | |
| mask = attention_mask.unsqueeze(-1) # (B, T, 1) | |
| masked_embeddings = embeddings * mask | |
| sum_embeddings = masked_embeddings.sum(dim=1) | |
| token_count = mask.sum(dim=1).clamp(min=1) | |
| pooled = sum_embeddings / token_count | |
| sentence_embedding = self.activation( | |
| self.projection(pooled) | |
| ) | |
| return sentence_embedding |