from typing import Any, Dict, List

import torch
from transformers import AutoModel, AutoTokenizer


class EndpointHandler:
    """Inference-endpoint handler that embeds sequences with DNABERT-2.

    The tokenizer and model are loaded once when the handler is constructed.
    Each call tokenizes the request's ``inputs``, runs the model without
    gradient tracking, and returns the token-mean of the model's first output
    for the first encoded sequence as a flat list of floats.
    """

    def __init__(self, path=""):
        """Load the DNABERT-2 tokenizer and model onto the selected device.

        Args:
            path: Accepted for the endpoint-handler interface but not used;
                the weights are always fetched from ``self.model_id``.
                NOTE(review): if the endpoint repository ships its own
                weights, loading from ``path`` may be intended — confirm.
        """
        self.model_id = "zhihan1996/DNABERT-2-117M"
        # Decide the device once so the model (here) and the inputs
        # (in __call__) are always placed on the same device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # trust_remote_code=True allows transformers to execute the custom
        # tokenizer/model code published in the model repository.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(self.model_id, trust_remote_code=True)
        self.model = self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> List[float]:
        """Return a mean-pooled embedding for the request's input sequence.

        Args:
            data: Request payload. The text to embed is read from
                ``data["inputs"]``; if that key is absent, the whole payload
                is passed to the tokenizer (same fallback as before). The
                payload is left unmodified.

        Returns:
            The mean over the token axis of ``outputs[0][0]`` (the model's
            first output, first sequence), converted to a list of Python
            floats.
        """
        # .get instead of .pop: reading the request must not mutate the
        # caller's dict.
        inputs = data.get("inputs", data)

        encoded_input = self.tokenizer(inputs, return_tensors="pt")
        encoded_input = {name: tensor.to(self.device) for name, tensor in encoded_input.items()}

        with torch.no_grad():
            outputs = self.model(**encoded_input)

        # outputs[0] is presumably the last hidden state with shape
        # (batch, seq_len, hidden) — TODO confirm against the remote model code.
        # [0] selects the first sequence; mean(dim=0) averages over its tokens.
        # NOTE(review): only the first sequence is embedded. A list of several
        # sequences is silently truncated to its first entry, and sequences of
        # unequal length would additionally need padding=True in the tokenizer.
        embedding = outputs[0][0].mean(dim=0)
        return embedding.cpu().numpy().tolist()