| from huggingface_inference_toolkit.logging import logger |
| import torch |
| from transformers import AutoTokenizer, AutoModelForMaskedLM |
|
|
class EndpointHandler:
    """Inference Toolkit handler for the InstaDeep Nucleotide Transformer.

    The payload's ``method`` field selects one of two behaviors:

    * ``"embed"`` (default): return the mean-pooled last-hidden-state
      embedding of the input DNA sequence as a plain Python list.
    * ``"sensitivity"``: slide a window across the sequence, reverse each
      window (an in-silico "mutation"), and score how far the mutated
      sequence's embedding drifts from the baseline (L2 distance).
    """

    def __init__(self, path=""):
        # `path` is part of the toolkit's handler contract but unused here:
        # the model is always pulled from the Hub by its fixed id.
        logger.info("Initializing Nucleotide Transformer...")
        self.model_id = "InstaDeepAI/nucleotide-transformer-v2-50m-multi-species"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
        self.model = AutoModelForMaskedLM.from_pretrained(self.model_id, trust_remote_code=True)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()
        logger.info("Model loaded on %s", self.device)

    def _get_embedding(self, sequence):
        """Return the mean-pooled last-hidden-state embedding for *sequence*.

        The sequence is truncated to 1024 tokens. Result shape is
        (1, hidden_dim): mean over the token axis, batch axis kept.
        """
        inputs = self.tokenizer(
            sequence, return_tensors="pt", truncation=True, max_length=1024
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)

        return outputs.hidden_states[-1].mean(dim=1)

    def map_sensitivity(self, sequence, window_size, step):
        """Score each window's importance by reversing it and measuring drift.

        Args:
            sequence: DNA string to scan.
            window_size: length of the window to reverse at each position.
            step: stride between window start positions (0 is treated as 1).

        Returns:
            list of {"coord": start_index, "score": L2 distance between the
            baseline embedding and the mutated sequence's embedding}.
        """
        logger.info("Starting sensitivity map: window=%s, step=%s", window_size, step)
        baseline = self._get_embedding(sequence)

        results = []
        # BUG FIX: the original stop bound was `len(sequence) - window_size`,
        # which (being exclusive) skipped the final window ending exactly at
        # the sequence end — and produced zero windows when the sequence length
        # equals window_size. `+ 1` includes that last valid start position.
        # `step or 1` guards against step == 0, which would make range() raise.
        for i in range(0, len(sequence) - window_size + 1, step or 1):
            window = sequence[i : i + window_size]
            # In-silico mutation: reverse the window in place.
            mutated_seq = sequence[:i] + window[::-1] + sequence[i + window_size:]

            mutated_emb = self._get_embedding(mutated_seq)

            # L2 distance between baseline and mutated embeddings; a
            # palindromic window naturally yields a score near 0.
            dist = torch.norm(baseline - mutated_emb).item()
            results.append({"coord": i, "score": dist})

        return results

    def __call__(self, data):
        """Toolkit entry point; `data` is the dictionary from the payload.

        Recognized keys: "inputs" (sequence string), "method"
        ("embed" | "sensitivity"), "window_size" (default 50),
        "step" (default 20).
        """
        logger.info("Payload received: %s", data)

        inputs = data.get("inputs", "")
        method = data.get("method", "embed")

        window_size = int(data.get("window_size", 50))
        step = int(data.get("step", 20))

        if method == "sensitivity":
            return self.map_sensitivity(inputs, window_size, step)

        # Default path: a single embedding as a JSON-serializable list
        # (strip the batch axis of size 1).
        emb = self._get_embedding(inputs)
        return emb.cpu().tolist()[0]