brocks1234 commited on
Commit
dd09c74
·
verified ·
1 Parent(s): 8fee4ab

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +14 -26
handler.py CHANGED
@@ -1,29 +1,15 @@
1
- import os
2
- # Force PyTorch to use its built-in stable attention and ignore custom kernels
3
- os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1"
4
-
5
  import torch
6
  from typing import Any, Dict, List
7
- from transformers import AutoTokenizer, AutoModel, AutoConfig
8
 
9
  class EndpointHandler:
10
  def __init__(self, path=""):
11
- self.model_id = "zhihan1996/DNABERT-2-117M"
 
12
 
 
13
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
14
-
15
- # 1. Load config and EXPLICITLY set the attn_implementation to 'eager'
16
- # 'eager' means 'plain PyTorch math' - no Triton, no Flash, just stability.
17
- config = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
18
- config.use_flash_attn = False
19
-
20
- # 2. Load Model with the 'eager' implementation if supported
21
- self.model = AutoModel.from_pretrained(
22
- self.model_id,
23
- trust_remote_code=True,
24
- config=config,
25
- attn_implementation="eager"
26
- )
27
 
28
  if torch.cuda.is_available():
29
  self.model = self.model.to("cuda")
@@ -34,21 +20,23 @@ class EndpointHandler:
34
  if isinstance(inputs, list):
35
  inputs = inputs[0]
36
 
 
37
  encoded_input = self.tokenizer(
38
  inputs,
39
- return_tensors='pt',
40
- padding=True,
41
  truncation=True,
42
- max_length=512
43
  )
44
 
45
  if torch.cuda.is_available():
46
  encoded_input = {k: v.to("cuda") for k, v in encoded_input.items()}
47
 
48
  with torch.no_grad():
49
- # 3. Use the inference mode context for extra stability
50
- with torch.inference_mode():
51
- outputs = self.model(**encoded_input)
52
 
53
- embeddings = outputs[0][0].mean(dim=0).cpu().numpy().tolist()
 
 
 
 
54
  return embeddings
 
 
 
 
 
1
  import torch
2
  from typing import Any, Dict, List
3
+ from transformers import AutoTokenizer, AutoModel
4
 
5
  class EndpointHandler:
6
  def __init__(self, path=""):
7
+ # We'll use the 'phulia' variant which is highly regarded for stability
8
+ self.model_id = "kuleshov-group/caduceus-phulia-16-soft"
9
 
10
+ # Load tokenizer and model
11
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
12
+ self.model = AutoModel.from_pretrained(self.model_id, trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  if torch.cuda.is_available():
15
  self.model = self.model.to("cuda")
 
20
  if isinstance(inputs, list):
21
  inputs = inputs[0]
22
 
23
+ # Caduceus often performs better without excessive padding
24
  encoded_input = self.tokenizer(
25
  inputs,
26
+ return_tensors='pt',
 
27
  truncation=True,
28
+ max_length=2048 # Caduceus handles long sequences better than BERT
29
  )
30
 
31
  if torch.cuda.is_available():
32
  encoded_input = {k: v.to("cuda") for k, v in encoded_input.items()}
33
 
34
  with torch.no_grad():
35
+ outputs = self.model(**encoded_input)
 
 
36
 
37
+ # Caduceus (Mamba) outputs hidden states.
38
+ # We take the mean across the sequence length (dim 1)
39
+ # to get a fixed-size vector for your LangGraph logic.
40
+ embeddings = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().tolist()
41
+
42
  return embeddings