import torch
from huggingface_inference_toolkit.logging import logger
from transformers import AutoTokenizer, AutoModelForMaskedLM

class EndpointHandler:
    def __init__(self, path=""):
        logger.info("Initializing Nucleotide Transformer...")
        # Weights load from the Hub by model id; the local 'path' the Toolkit passes in is unused
        self.model_id = "InstaDeepAI/nucleotide-transformer-v2-50m-multi-species"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
        self.model = AutoModelForMaskedLM.from_pretrained(self.model_id, trust_remote_code=True)
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()
        logger.info(f"Model loaded on {self.device}")

    def _get_embedding(self, sequence):
        """Return the mean-pooled last hidden state, shape [1, hidden_size]."""
        inputs = self.tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
            # Mean-pool over the token dimension -> [1, 512] for this checkpoint
            return outputs.hidden_states[-1].mean(dim=1)

    def map_sensitivity(self, sequence, window_size, step):
        """Slide a window along the sequence, perturb each window, and score how far
        the perturbed embedding moves from the unperturbed baseline."""
        logger.info(f"Starting sensitivity map: window={window_size}, step={step}")
        baseline = self._get_embedding(sequence)

        results = []
        # Guard against step <= 0; the +1 bound includes the final full window
        for i in range(0, len(sequence) - window_size + 1, max(step, 1)):
            # Perturb by reversing the window segment: deterministic and cheap,
            # unlike random.shuffle, so test runs are reproducible
            window = sequence[i : i + window_size]
            mutated_seq = sequence[:i] + window[::-1] + sequence[i + window_size:]

            mutated_emb = self._get_embedding(mutated_seq)

            # L2 distance between baseline and perturbed embeddings
            dist = torch.norm(baseline - mutated_emb).item()
            results.append({"coord": i, "score": dist})

        return results

    def __call__(self, data):
        """
        The Toolkit calls this method.
        'data' is the dictionary parsed from your request payload.
        """
        logger.info(f"Payload received: {data}")

        inputs = data.get("inputs", "")
        method = data.get("method", "embed")
        # Coerce to int: JSON clients may send these as strings
        window_size = int(data.get("window_size", 50))
        step = int(data.get("step", 20))  # default matches the test payload

        if method == "sensitivity":
            return self.map_sensitivity(inputs, window_size, step)

        # Default: return the pooled embedding as a flat list of floats
        emb = self._get_embedding(inputs)
        return emb.cpu().tolist()[0]
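

# A minimal local smoke test, not part of the Toolkit contract: this block and the
# example sequence are illustrative assumptions added for documentation, not from
# the original handler. It exercises both the embedding and sensitivity paths.
if __name__ == "__main__":
    handler = EndpointHandler()

    seq = "ATGCATGCATGCATGCATGCATGCATGCATGCATGCATGCATGCATGCATGCATGCATGC"

    # Default path: a flat list of hidden_size floats (512 for this checkpoint)
    embedding = handler({"inputs": seq})
    print(f"embedding length: {len(embedding)}")

    # Sensitivity path: one {'coord', 'score'} dict per window position
    scores = handler({"inputs": seq, "method": "sensitivity", "window_size": 10, "step": 5})
    print(scores[:3])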