| |
|
|
| import os |
| import torch |
| from torch import nn, Tensor |
| from transformers import AutoModel, AutoConfig |
| from huggingface_hub import snapshot_download |
| from typing import Dict |
|
|
|
|
class BGEM3InferenceModel(nn.Module):
    """Inference-only wrapper producing dense, sparse and ColBERT outputs.

    The module bundles a Hugging Face backbone (``self.model``) with two
    linear heads whose weights are read from ``colbert_linear.pt`` and
    ``sparse_linear.pt`` in the checkpoint directory. Every computation that
    involves trainable weights runs under ``torch.no_grad()``.
    """

    # Files needed from the checkpoint; everything else in the repo is skipped.
    _CHECKPOINT_FILES = (
        "model.safetensors",
        "colbert_linear.pt",
        "sparse_linear.pt",
        "config.json",
    )

    def __init__(
        self,
        model_name: str = "BAAI/bge-m3",
        colbert_dim: int = -1,
    ) -> None:
        """Load the backbone and both projection heads.

        Args:
            model_name: Hub repository id, or the path of an existing local
                directory that already contains the checkpoint files.
            colbert_dim: Output width of the ColBERT head; ``-1`` means "use
                the backbone's hidden size". It has to match the shape stored
                in ``colbert_linear.pt`` or ``load_state_dict`` will raise.
        """
        super().__init__()

        # An existing directory is used as-is (same convention as
        # ``from_pretrained``); anything else is treated as a Hub repo id.
        if os.path.isdir(model_name):
            model_dir = model_name
        else:
            model_dir = snapshot_download(
                repo_id=model_name,
                allow_patterns=list(self._CHECKPOINT_FILES),
            )

        self.config = AutoConfig.from_pretrained(model_dir)
        self.model = AutoModel.from_pretrained(model_dir)

        hidden_size = self.model.config.hidden_size
        self.colbert_linear = torch.nn.Linear(
            in_features=hidden_size,
            out_features=hidden_size if colbert_dim == -1 else colbert_dim,
        )
        self.sparse_linear = torch.nn.Linear(in_features=hidden_size, out_features=1)

        self.colbert_linear.load_state_dict(
            self._load_head_state(model_dir, "colbert_linear.pt")
        )
        self.sparse_linear.load_state_dict(
            self._load_head_state(model_dir, "sparse_linear.pt")
        )

    @staticmethod
    def _load_head_state(model_dir: str, filename: str) -> Dict[str, Tensor]:
        """Read one head's state dict from ``model_dir`` onto the CPU."""
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered download cannot execute arbitrary code.
        return torch.load(
            os.path.join(model_dir, filename),
            map_location="cpu",
            weights_only=True,
        )

    def dense_embedding(self, last_hidden_state: Tensor) -> Tensor:
        """Return the hidden state at sequence position 0 for every batch item."""
        return last_hidden_state[:, 0]

    def sparse_embedding(self, last_hidden_state: Tensor) -> Tensor:
        """Return one non-negative weight per token (trailing dimension of 1).

        The attention mask is not applied here, so padded positions also
        receive a weight; callers have to drop them if that matters.
        """
        with torch.no_grad():
            return torch.relu(self.sparse_linear(last_hidden_state))

    def colbert_embedding(
        self, last_hidden_state: Tensor, attention_mask: Tensor
    ) -> Tensor:
        """Project every token except the first and zero out masked positions.

        Args:
            last_hidden_state: Backbone output, indexed as
                ``(batch, sequence, hidden)``.
            attention_mask: ``(batch, sequence)`` mask; positions holding 0
                produce all-zero vectors.

        Returns:
            Tensor with ``sequence - 1`` vectors per batch item.
        """
        with torch.no_grad():
            # Position 0 is skipped in both the states and the mask so the
            # two stay aligned.
            projected = self.colbert_linear(last_hidden_state[:, 1:])
            keep = attention_mask[:, 1:][:, :, None].float()
            return projected * keep

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Dict[str, Tensor]:
        """Run the backbone once and derive all three representations.

        Returns:
            A dict with the keys ``"dense_vecs"`` (L2-normalised),
            ``"sparse_vecs"`` (raw per-token weights) and ``"colbert_vecs"``
            (L2-normalised per token; masked positions stay all-zero).
        """
        with torch.no_grad():
            last_hidden_state = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True,
            ).last_hidden_state

        dense_vecs = self.dense_embedding(last_hidden_state)
        sparse_vecs = self.sparse_embedding(last_hidden_state)
        colbert_vecs = self.colbert_embedding(last_hidden_state, attention_mask)

        return {
            "dense_vecs": torch.nn.functional.normalize(dense_vecs, dim=-1),
            "sparse_vecs": sparse_vecs,
            "colbert_vecs": torch.nn.functional.normalize(colbert_vecs, dim=-1),
        }
|
|