Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from transformers import AutoModel, AutoConfig | |
class CodeEmbedder(nn.Module):
    """Wrap a Transformer encoder to embed code snippets via mean pooling.

    The encoder (default: CodeBERT) yields one hidden vector per token. These
    are averaged over the non-padding positions and, unless disabled, scaled
    to unit L2 norm so that dot products behave like cosine similarities.
    """

    def __init__(self, model_name_or_path="microsoft/codebert-base", trust_remote_code: bool = False):
        """Load the model configuration and the encoder.

        Args:
            model_name_or_path: Value forwarded to ``AutoConfig.from_pretrained``
                and ``AutoModel.from_pretrained``.
            trust_remote_code: Forwarded unchanged to both ``from_pretrained``
                calls.
        """
        super().__init__()
        self.config = AutoConfig.from_pretrained(
            model_name_or_path,
            trust_remote_code=trust_remote_code,
        )
        self.encoder = AutoModel.from_pretrained(
            model_name_or_path,
            config=self.config,
            trust_remote_code=trust_remote_code,
        )

    def mean_pooling(self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Average the token embeddings, ignoring padding tokens.

        Args:
            token_embeddings: Tensor of shape (batch_size, seq_len, hidden_dim).
            attention_mask: Tensor of shape (batch_size, seq_len); positions
                holding 0 contribute nothing to the average.

        Returns:
            Tensor of shape (batch_size, hidden_dim). A row whose mask is all
            zeros comes back as zeros rather than NaN.
        """
        # Keep the mask at (batch_size, seq_len, 1) and rely on broadcasting;
        # expanding it to the full hidden_dim before the float cast would
        # materialise a copy as large as the embeddings themselves.
        mask = attention_mask.unsqueeze(-1).float()

        # Sum embeddings over the sequence axis (padding positions are zeroed).
        summed = torch.sum(token_embeddings * mask, dim=1)

        # Number of non-padding tokens per row; the clamp prevents division
        # by zero when a row has no unmasked position.
        token_counts = torch.clamp(mask.sum(dim=1), min=1e-9)

        return summed / token_counts

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, *, normalize: bool = True) -> torch.Tensor:
        """Encode a batch and return one pooled embedding per sequence.

        Args:
            input_ids: Token ids passed to the encoder as ``input_ids``.
            attention_mask: Mask passed to the encoder and reused for pooling.
            normalize: When True (the default), L2-normalize each embedding
                along the hidden dimension. Pass False to get the raw
                mean-pooled vectors.

        Returns:
            Tensor of shape (batch_size, hidden_dim).
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # last_hidden_state: (batch_size, seq_len, hidden_dim)
        pooled = self.mean_pooling(outputs.last_hidden_state, attention_mask)

        if normalize:
            # Unit-length vectors make a plain dot product equal to cosine similarity.
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
        return pooled