| |
|
|
| import os |
| import torch |
| from torch import nn, Tensor |
| from transformers import AutoModel, AutoConfig |
| from huggingface_hub import snapshot_download |
| from typing import Dict |
|
|
|
|
class BGEM3InferenceModel(nn.Module):
    """Inference-only wrapper that returns dense embeddings from a BGE-M3 checkpoint.

    The wrapped encoder is loaded with ``AutoModel`` and run without gradient
    tracking; the dense embedding is the hidden state at sequence position 0.
    """

    def __init__(
        self,
        model_name: str = "BAAI/bge-m3",
        colbert_dim: int = -1,
    ) -> None:
        """Load the encoder configuration and weights.

        Args:
            model_name: Either a Hugging Face Hub repo id or a path to an
                existing local directory holding the checkpoint. Hub repos
                are fetched with ``snapshot_download`` (weights and config
                only); a local directory is used as-is.
            colbert_dim: Accepted for interface compatibility. It is not used
                anywhere in this class.
        """
        super().__init__()

        if os.path.isdir(model_name):
            # A local checkpoint directory is not a valid Hub repo id, so
            # skip the download and load straight from disk.
            model_dir = model_name
        else:
            # Restrict the snapshot to the files that the two
            # ``from_pretrained`` calls below need.
            model_dir = snapshot_download(
                repo_id=model_name,
                allow_patterns=[
                    "pytorch_model.bin",
                    "config.json",
                ],
            )

        self.config = AutoConfig.from_pretrained(model_dir)
        self.model = AutoModel.from_pretrained(model_dir)

    def dense_embedding(self, last_hidden_state: Tensor) -> Tensor:
        """Return the hidden state at index 0 of the second axis.

        Args:
            last_hidden_state: Encoder output; assumed to be laid out as
                (batch, sequence, hidden) -- TODO confirm against the encoder.

        Returns:
            ``last_hidden_state[:, 0]``, i.e. one vector per batch row.
        """
        return last_hidden_state[:, 0]

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Dict[str, Tensor]:
        """Run the encoder and collect the dense embedding.

        Args:
            input_ids: Token ids forwarded to the wrapped model.
            attention_mask: Attention mask forwarded to the wrapped model.

        Returns:
            A dict with a single key, ``"dense_vecs"``, holding the result of
            :meth:`dense_embedding` applied to the encoder's last hidden state.
        """
        # The encoder pass is inference-only: no autograd graph is recorded.
        with torch.no_grad():
            last_hidden_state = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True,
            ).last_hidden_state

        return {"dense_vecs": self.dense_embedding(last_hidden_state)}
|
|