# mvi-ai-engine/language/encoder.py
# Musombi's picture
# Update language/encoder.py
# 55bc0b0 verified
import torch
import torch.nn as nn
from core.config import EMBEDDING_DIM, HIDDEN_DIM
from core.device import DEVICE
class SentenceEncoder(nn.Module):
    """Encode token-id sequences into fixed-size sentence vectors.

    Pipeline: token embedding -> attention-masked mean pooling over the
    sequence axis -> linear projection -> ReLU.
    """

    def __init__(self, vocab_size: int) -> None:
        """Create the layers and move the whole module to ``DEVICE``.

        Args:
            vocab_size: Number of entries in the embedding table.
        """
        super().__init__()
        self.embedding_layer = nn.Embedding(vocab_size, EMBEDDING_DIM)
        self.projection = nn.Linear(EMBEDDING_DIM, HIDDEN_DIM)
        self.activation = nn.ReLU()
        # A single module-level move covers every registered submodule, so
        # the layers do not need their own ``.to(DEVICE)`` calls.
        self.to(DEVICE)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Return the masked-mean-pooled, projected sentence embedding.

        Args:
            input_ids: Token ids shaped ``(..., T)``; typically ``(B, T)``.
                Cast to ``long`` before the embedding lookup.
            attention_mask: Mask broadcastable against ``input_ids``; a
                position's embedding is multiplied by its mask value, so
                zeros drop the token from the mean.

        Returns:
            Tensor shaped ``(..., out_features)`` where ``out_features`` is
            the projection layer's output size. A sequence whose mask is all
            zeros pools to a zero vector (the token count is clamped to 1),
            so its output is ``ReLU(projection bias)``.
        """
        # Follow the module's *current* device instead of a global constant,
        # so the encoder keeps working after a later ``.to(...)`` call.
        device = self.embedding_layer.weight.device

        input_ids = input_ids.long().to(device)
        embeddings = self.embedding_layer(input_ids)  # (..., T, EMBEDDING_DIM)

        # Count tokens in float32 so the count stays exact even when the
        # model itself runs in a low-precision dtype.
        mask = attention_mask.to(device).float().unsqueeze(-1)  # (..., T, 1)
        token_count = mask.sum(dim=-2).clamp(min=1)  # avoid division by zero

        # Match the embedding dtype so the pooled tensor keeps the dtype the
        # projection weights expect.
        mask = mask.to(embeddings.dtype)
        token_count = token_count.to(embeddings.dtype)

        # Reduce over the sequence axis counted from the end; for (B, T)
        # input this is dim=1, and it also handles unbatched (T,) input.
        sum_embeddings = (embeddings * mask).sum(dim=-2)
        pooled = sum_embeddings / token_count

        return self.activation(self.projection(pooled))