| import torch |
| import torch.nn as nn |
| from transformers import AutoModel, AutoConfig |
|
|
class CodeEmbedder(nn.Module):
    """
    A wrapper around a Transformer model (default: CodeBERT) to produce
    dense vector embeddings for code snippets using Mean Pooling.

    ``forward`` returns one L2-normalized vector per input sequence, obtained
    by averaging the encoder's final hidden states over non-padding tokens.
    """

    def __init__(self, model_name_or_path="microsoft/codebert-base", trust_remote_code=False):
        """
        Load the Transformer config and encoder weights.

        Args:
            model_name_or_path: Model id or local path passed to
                ``AutoConfig.from_pretrained`` and ``AutoModel.from_pretrained``.
            trust_remote_code: Forwarded unchanged to both loaders.
        """
        super().__init__()
        self.config = AutoConfig.from_pretrained(
            model_name_or_path, trust_remote_code=trust_remote_code
        )
        self.encoder = AutoModel.from_pretrained(
            model_name_or_path, config=self.config, trust_remote_code=trust_remote_code
        )

    def mean_pooling(self, token_embeddings, attention_mask):
        """
        Average the token embeddings, ignoring padding tokens.

        Args:
            token_embeddings: Tensor of shape (batch, seq_len, hidden).
            attention_mask: Tensor of shape (batch, seq_len), 1 for real tokens
                and 0 for padding (values are used directly as weights).

        Returns:
            Tensor of shape (batch, hidden). A row whose mask is all zeros
            yields a zero vector instead of NaN.
        """
        # (batch, seq_len, 1): broadcasts across the hidden dimension, so no
        # explicit expand/copy is needed. ``.float()`` promotes the product to
        # at least float32, keeping the sums stable for half-precision encoders.
        mask = attention_mask.unsqueeze(-1).float()
        summed = (token_embeddings * mask).sum(dim=1)
        # Clamp guards against division by zero for fully padded rows.
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    def forward(self, input_ids, attention_mask=None):
        """
        Encode a batch of token ids into L2-normalized embeddings.

        Args:
            input_ids: Tensor of token ids, shape (batch, seq_len).
            attention_mask: Optional tensor of shape (batch, seq_len) marking
                real tokens with 1 and padding with 0. When omitted, every
                position is treated as a real token.

        Returns:
            Tensor of shape (batch, hidden) with unit L2 norm per row
            (an all-zero pooled row stays zero).
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Encoders configured with ``return_dict=False`` return a plain tuple
        # rather than a ModelOutput; by Hugging Face convention its first
        # element is the last hidden state.
        if hasattr(outputs, "last_hidden_state"):
            last_hidden_state = outputs.last_hidden_state
        else:
            last_hidden_state = outputs[0]

        embeddings = self.mean_pooling(last_hidden_state, attention_mask)
        return torch.nn.functional.normalize(embeddings, p=2, dim=1)
|
|