import torch
import torch.nn as nn
from core.config import EMBEDDING_DIM, HIDDEN_DIM
from core.device import DEVICE
class SentenceEncoder(nn.Module):
    """Encode token-id sequences into fixed-size sentence embeddings.

    Tokens are embedded, mean-pooled over the non-padding positions
    indicated by the attention mask, projected to HIDDEN_DIM, and passed
    through a ReLU.
    """

    def __init__(self, vocab_size: int):
        super().__init__()
        # No per-layer .to(DEVICE) here: the self.to(DEVICE) below moves
        # every registered submodule, so the earlier per-layer call was
        # redundant.
        self.embedding_layer = nn.Embedding(vocab_size, EMBEDDING_DIM)
        self.projection = nn.Linear(EMBEDDING_DIM, HIDDEN_DIM)
        self.activation = nn.ReLU()
        self.to(DEVICE)

    def forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Return sentence embeddings of shape (B, HIDDEN_DIM).

        Args:
            input_ids: (B, T) integer token ids.
            attention_mask: (B, T); 1.0 for real tokens, 0.0 for padding.
        """
        input_ids = input_ids.long().to(DEVICE)
        attention_mask = attention_mask.float().to(DEVICE)
        embeddings = self.embedding_layer(input_ids)  # (B, T, EMBEDDING_DIM)
        mask = attention_mask.unsqueeze(-1)           # (B, T, 1) for broadcast
        masked_embeddings = embeddings * mask         # zero out padding positions
        sum_embeddings = masked_embeddings.sum(dim=1)  # (B, EMBEDDING_DIM)
        # clamp(min=1) guards against division by zero on all-padding rows.
        token_count = mask.sum(dim=1).clamp(min=1)     # (B, 1)
        pooled = sum_embeddings / token_count          # mean over real tokens
        sentence_embedding = self.activation(self.projection(pooled))
        return sentence_embedding