from huggingface_inference_toolkit.logging import logger
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM


class EndpointHandler:
    """Custom Hugging Face Inference Toolkit handler for the InstaDeep
    Nucleotide Transformer.

    Supports two request methods:
      - ``embed`` (default): return the mean-pooled last-hidden-state
        embedding of the input DNA sequence.
      - ``sensitivity``: slide a window over the sequence, perturb each
        window (by reversing it), and report the embedding distance from
        the unperturbed baseline at each window start coordinate.
    """

    def __init__(self, path=""):
        # NOTE(review): `path` (the local model directory the toolkit
        # passes in) is ignored and the hub ID is hard-coded — confirm
        # this is intentional; it forces a download on cold start.
        logger.info("Initializing Nucleotide Transformer...")
        self.model_id = "InstaDeepAI/nucleotide-transformer-v2-50m-multi-species"
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_id, trust_remote_code=True
        )
        self.model = AutoModelForMaskedLM.from_pretrained(
            self.model_id, trust_remote_code=True
        )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()
        logger.info("Model loaded on %s", self.device)

    def _get_embedding(self, sequence):
        """Return the mean-pooled last-hidden-state embedding of *sequence*.

        Tokenizes with truncation at 1024 tokens, runs the model without
        gradients, and averages the final hidden layer over the token
        dimension — shape ``[1, hidden_size]``.
        """
        inputs = self.tokenizer(
            sequence, return_tensors="pt", truncation=True, max_length=1024
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        # Mean over the sequence (token) axis of the last hidden state.
        return outputs.hidden_states[-1].mean(dim=1)

    def map_sensitivity(self, sequence, window_size, step):
        """Compute a per-window perturbation-sensitivity map.

        For each window start ``i`` (stride *step*, minimum 1), the window
        ``sequence[i:i+window_size]`` is reversed in place — a deterministic
        perturbation, unlike random shuffling — and the L2 distance between
        the perturbed and baseline embeddings is recorded.

        Returns a list of ``{"coord": i, "score": distance}`` dicts.
        """
        logger.info("Starting sensitivity map: window=%s, step=%s", window_size, step)
        baseline = self._get_embedding(sequence)

        results = []
        # `+ 1` so the final valid window start (len - window_size) is
        # included; the original `range(0, len - window_size, ...)` skipped
        # the last window and produced no windows when len == window_size.
        for i in range(0, len(sequence) - window_size + 1, step or 1):
            window = sequence[i : i + window_size]
            mutated_seq = sequence[:i] + window[::-1] + sequence[i + window_size :]
            mutated_emb = self._get_embedding(mutated_seq)
            # Euclidean distance between baseline and perturbed embeddings.
            dist = torch.norm(baseline - mutated_emb).item()
            results.append({"coord": i, "score": dist})
        return results

    def __call__(self, data):
        """Toolkit entry point; *data* is the parsed request payload.

        Expected keys: ``inputs`` (DNA string), optional ``method``
        (``"embed"`` or ``"sensitivity"``), optional ``window_size`` and
        ``step`` (coerced to int; defaults 50 and 20).
        """
        logger.info("Payload received: %s", data)

        inputs = data.get("inputs", "")
        method = data.get("method", "embed")
        # Coerce to int — JSON payloads may carry these as strings.
        window_size = int(data.get("window_size", 50))
        step = int(data.get("step", 20))

        if method == "sensitivity":
            return self.map_sensitivity(inputs, window_size, step)

        # Default: plain embedding, returned as a flat Python list.
        emb = self._get_embedding(inputs)
        return emb.cpu().tolist()[0]