brocks1234 committed on
Commit
fd76be3
·
verified ·
1 Parent(s): e299aae

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +4 -6
handler.py CHANGED
@@ -3,12 +3,12 @@ import torch
3
  from typing import Any, Dict, List
4
  from transformers import AutoConfig, AutoTokenizer, AutoModelForMaskedLM
5
 
6
- # Force the trust flag at the environment level
7
  os.environ["HF_HUB_TRUST_REMOTE_CODE"] = "True"
8
 
9
  class EndpointHandler:
10
  def __init__(self, path=""):
11
- # We explicitly ignore 'path' and pull from the source
12
  self.model_id = "InstaDeepAI/nucleotide-transformer-v2-50m-multi-species"
13
 
14
  self.config = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
@@ -28,7 +28,7 @@ class EndpointHandler:
28
  if isinstance(inputs, list):
29
  inputs = inputs[0]
30
 
31
- # 12.2kb APRIL promoter chunking logic
32
  chunk_size = 1000
33
  stride = 500
34
  chunks = [inputs[i:i + chunk_size] for i in range(0, len(inputs), stride)]
@@ -44,6 +44,4 @@ class EndpointHandler:
44
  chunk_emb = torch.mean(outputs.hidden_states[-1], dim=1).squeeze()
45
  all_embeddings.append(chunk_emb)
46
 
47
- # Average the chunks for one representative vector
48
- final_embedding = torch.stack(all_embeddings).mean(dim=0).cpu().numpy().tolist()
49
- return final_embedding
 
3
  from typing import Any, Dict, List
4
  from transformers import AutoConfig, AutoTokenizer, AutoModelForMaskedLM
5
 
6
+ # Force the trust flag globally
7
  os.environ["HF_HUB_TRUST_REMOTE_CODE"] = "True"
8
 
9
  class EndpointHandler:
10
  def __init__(self, path=""):
11
+ # We ignore the local 'path' and pull fresh from the source
12
  self.model_id = "InstaDeepAI/nucleotide-transformer-v2-50m-multi-species"
13
 
14
  self.config = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
 
28
  if isinstance(inputs, list):
29
  inputs = inputs[0]
30
 
31
+ # 12.2kb APRIL promoter chunking
32
  chunk_size = 1000
33
  stride = 500
34
  chunks = [inputs[i:i + chunk_size] for i in range(0, len(inputs), stride)]
 
44
  chunk_emb = torch.mean(outputs.hidden_states[-1], dim=1).squeeze()
45
  all_embeddings.append(chunk_emb)
46
 
47
+ return torch.stack(all_embeddings).mean(dim=0).cpu().numpy().tolist()