from huggingface_inference_toolkit.logging import logger
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
class EndpointHandler:
    """Inference Toolkit handler for the Nucleotide Transformer model.

    Two payload methods are supported, selected by the ``method`` field:

    * ``"embed"`` (default) — return a mean-pooled embedding of the input
      DNA sequence as a flat Python list.
    * ``"sensitivity"`` — slide a window over the sequence, reverse each
      window, and score how far the mutated embedding moves from the
      baseline (L2 distance).
    """

    def __init__(self, path: str = ""):
        # `path` is part of the Toolkit handler contract but unused here:
        # the model id is pinned below rather than read from the local path.
        logger.info("Initializing Nucleotide Transformer...")
        self.model_id = "InstaDeepAI/nucleotide-transformer-v2-50m-multi-species"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
        self.model = AutoModelForMaskedLM.from_pretrained(self.model_id, trust_remote_code=True)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()
        logger.info(f"Model loaded on {self.device}")

    def _get_embedding(self, sequence: str) -> torch.Tensor:
        """Return the mean-pooled last-hidden-state embedding of `sequence`.

        Shape is [1, hidden_size]; input is truncated to 1024 tokens.
        """
        inputs = self.tokenizer(
            sequence, return_tensors="pt", truncation=True, max_length=1024
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        # Mean over the token axis -> one vector per input sequence.
        return outputs.hidden_states[-1].mean(dim=1)

    def map_sensitivity(self, sequence: str, window_size: int, step: int) -> list:
        """Score embedding sensitivity of each window position in `sequence`.

        For every window start coordinate, the window segment is reversed
        (deterministic, unlike random shuffling, so runs are reproducible)
        and the L2 distance between the mutated and baseline embeddings is
        recorded. Note a palindromic window yields a zero score by
        construction.

        Returns a list of ``{"coord": start_index, "score": distance}``.
        """
        logger.info(f"Starting sensitivity map: window={window_size}, step={step}")
        baseline = self._get_embedding(sequence)
        stride = step if step else 1  # guard against step == 0
        results = []
        # BUGFIX: `+ 1` so the final window — the one ending exactly at the
        # sequence end — is included; the previous bound silently skipped it.
        for i in range(0, len(sequence) - window_size + 1, stride):
            window = sequence[i : i + window_size]
            mutated_seq = sequence[:i] + window[::-1] + sequence[i + window_size:]
            mutated_emb = self._get_embedding(mutated_seq)
            # L2 distance between baseline and mutated embeddings.
            dist = torch.norm(baseline - mutated_emb).item()
            results.append({"coord": i, "score": dist})
        return results

    def __call__(self, data: dict):
        """Toolkit entry point; `data` is the request payload as a dict.

        Payload fields: ``inputs`` (sequence string), ``method``
        ("embed" | "sensitivity"), and for sensitivity runs
        ``window_size`` (default 50) and ``step`` (default 20).
        """
        logger.info(f"Payload received: {data}")
        inputs = data.get("inputs", "")
        method = data.get("method", "embed")
        # Payloads arrive as JSON, so numeric fields may come in as strings;
        # coerce explicitly.
        window_size = int(data.get("window_size", 50))
        step = int(data.get("step", 20))  # default 20 to match the test payload
        if method == "sensitivity":
            return self.map_sensitivity(inputs, window_size, step)
        # Default: return the embedding as a plain Python list.
        emb = self._get_embedding(inputs)
        return emb.cpu().tolist()[0]