brocks1234 committed on
Commit
92674ca
·
verified ·
1 Parent(s): 36a90d6

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +73 -34
handler.py CHANGED
@@ -1,47 +1,86 @@
1
- import os
2
  import torch
3
- from typing import Any, Dict, List
4
- from transformers import AutoConfig, AutoTokenizer, AutoModelForMaskedLM
5
-
6
- # Force the trust flag globally
7
- os.environ["HF_HUB_TRUST_REMOTE_CODE"] = "True"
8
 
9
  class EndpointHandler:
10
  def __init__(self, path=""):
11
- # We ignore the local 'path' and pull fresh from the source
12
  self.model_id = "InstaDeepAI/nucleotide-transformer-v2-50m-multi-species"
13
-
14
- self.config = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
15
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
16
- self.model = AutoModelForMaskedLM.from_pretrained(
17
- self.model_id,
18
- config=self.config,
19
- trust_remote_code=True
20
- )
21
 
22
- if torch.cuda.is_available():
23
- self.model = self.model.to("cuda")
24
  self.model.eval()
25
 
26
- def __call__(self, data: Dict[str, Any]) -> List[float]:
27
- inputs = data.pop("inputs", data)
28
- if isinstance(inputs, list):
29
- inputs = inputs[0]
 
 
 
 
30
 
31
- # 12.2kb APRIL promoter chunking
32
- chunk_size = 1000
33
- stride = 500
34
- chunks = [inputs[i:i + chunk_size] for i in range(0, len(inputs), stride)]
35
 
36
- all_embeddings = []
37
- with torch.no_grad():
38
- for chunk in chunks:
39
- tokens = self.tokenizer(chunk, return_tensors='pt', padding=True, truncation=True, max_length=chunk_size)
40
- if torch.cuda.is_available():
41
- tokens = {k: v.to("cuda") for k, v in tokens.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- outputs = self.model(**tokens, output_hidden_states=True)
44
- chunk_emb = torch.mean(outputs.hidden_states[-1], dim=1).squeeze()
45
- all_embeddings.append(chunk_emb)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- return torch.stack(all_embeddings).mean(dim=0).cpu().numpy().tolist()
 
 
 
 
 
 
 
 
1
  import torch
2
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
3
+ import numpy as np
 
 
 
4
 
5
  class EndpointHandler:
6
  def __init__(self, path=""):
7
+ # Load model and tokenizer
8
  self.model_id = "InstaDeepAI/nucleotide-transformer-v2-50m-multi-species"
 
 
9
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
10
+ self.model = AutoModelForMaskedLM.from_pretrained(self.model_id, trust_remote_code=True)
 
 
 
 
11
 
12
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ self.model.to(self.device)
14
  self.model.eval()
15
 
16
+ def _get_embedding(self, sequence):
17
+ """Helper to get a single mean embedding."""
18
+ inputs = self.tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024).to(self.device)
19
+ with torch.no_grad():
20
+ outputs = self.model(**inputs, output_hidden_states=True)
21
+ # Use mean of the last hidden state
22
+ embeddings = outputs.hidden_states[-1].mean(dim=1)
23
+ return embeddings
24
 
25
+ def map_sensitivity(self, sequence, window_size=50, step=100):
26
+ """Generates a sensitivity map by perturbing segments of the sequence."""
27
+ # 1. Get Baseline
28
+ baseline_embedding = self._get_embedding(sequence)
29
 
30
+ # 2. Create variants
31
+ variants = []
32
+ indices = []
33
+ seq_list = list(sequence)
34
+
35
+ for i in range(0, len(sequence) - window_size, step):
36
+ # Create a "shuffled" variant of the window
37
+ variant_seq = seq_list.copy()
38
+ sub_seq = variant_seq[i : i + window_size]
39
+ import random
40
+ random.shuffle(sub_seq)
41
+ variant_seq[i : i + window_size] = sub_seq
42
+
43
+ variants.append("".join(variant_seq))
44
+ indices.append(i)
45
+
46
+ # 3. Batch Inference (Processing variants in chunks to fit in VRAM)
47
+ batch_size = 16
48
+ distances = []
49
+
50
+ for k in range(0, len(variants), batch_size):
51
+ batch_texts = variants[k : k + batch_size]
52
+ inputs = self.tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=1024).to(self.device)
53
+
54
+ with torch.no_grad():
55
+ outputs = self.model(**inputs, output_hidden_states=True)
56
+ batch_embeddings = outputs.hidden_states[-1].mean(dim=1)
57
 
58
+ # Calculate Euclidean Distance to baseline on GPU
59
+ # distance = sqrt(sum((a - b)^2))
60
+ diff = batch_embeddings - baseline_embedding
61
+ dist = torch.norm(diff, dim=1)
62
+ distances.extend(dist.cpu().tolist())
63
+
64
+ # 4. Return coordinates and their corresponding sensitivity scores
65
+ return [{"coord": idx, "score": score} for idx, score in zip(indices, distances)]
66
+
67
+ def __call__(self, data):
68
+ """
69
+ Args:
70
+ data (:obj:`dict`):
71
+ - "inputs": the DNA sequence
72
+ - "method": "embed" (default) or "sensitivity"
73
+ """
74
+ inputs = data.get("inputs", "")
75
+ method = data.get("method", "embed")
76
+
77
+ if not inputs:
78
+ return {"error": "No input sequence provided"}
79
 
80
+ if method == "sensitivity":
81
+ # Returns the map of high-leverage coordinates
82
+ return self.map_sensitivity(inputs)
83
+ else:
84
+ # Standard embedding behavior
85
+ embedding = self._get_embedding(inputs)
86
+ return embedding.cpu().tolist()[0]