brocks1234 commited on
Commit
364362c
·
verified ·
1 Parent(s): 773f0e8

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +25 -13
handler.py CHANGED
@@ -1,35 +1,47 @@
 
 
 
 
 
 
 
1
  from typing import Any, Dict, List
2
- from transformers import AutoTokenizer, AutoModel
3
  import torch
4
 
5
  class EndpointHandler:
6
  def __init__(self, path=""):
7
- # We point directly to the original weights
8
  self.model_id = "zhihan1996/DNABERT-2-117M"
9
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
10
-
11
- # --- FIX: Disable Flash Attention to avoid the Triton error ---
12
- from transformers import AutoConfig
13
  config = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
14
- config.use_flash_attn = False # This bypasses the broken line 114 code
15
- # --------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- self.model = AutoModel.from_pretrained(self.model_id, trust_remote_code=True)
18
  if torch.cuda.is_available():
19
  self.model = self.model.to("cuda")
20
 
21
  def __call__(self, data: Dict[str, Any]) -> List[float]:
22
  inputs = data.pop("inputs", data)
23
-
24
- # DNA Tokenization
25
  encoded_input = self.tokenizer(inputs, return_tensors='pt')
 
26
  if torch.cuda.is_available():
27
  encoded_input = {k: v.to("cuda") for k, v in encoded_input.items()}
28
 
29
  with torch.no_grad():
30
  outputs = self.model(**encoded_input)
31
 
32
- # Returns a 768-dimensional vector representing the DNA sequence
33
  embeddings = outputs[0][0].mean(dim=0).cpu().numpy().tolist()
34
-
35
  return embeddings
 
1
+ import sys
2
+ from unittest.mock import MagicMock
3
+
4
+ # 1. MOCK TRITON: This trick prevents the model from even TRYING to load the broken code
5
+ sys.modules["triton"] = MagicMock()
6
+ sys.modules["triton.language"] = MagicMock()
7
+
8
  from typing import Any, Dict, List
9
+ from transformers import AutoTokenizer, AutoModel, AutoConfig
10
  import torch
11
 
12
  class EndpointHandler:
13
  def __init__(self, path=""):
 
14
  self.model_id = "zhihan1996/DNABERT-2-117M"
15
+
16
+ # 2. FORCE CONFIG: Explicitly disable flash attention in multiple places
 
 
17
  config = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
18
+ config.use_flash_attn = False
19
+ if hasattr(config, "auto_map"):
20
+ # This ensures it doesn't try to use the custom 'Flash' modeling class
21
+ config.auto_map["AutoModel"] = "modeling_bert.BertModel"
22
+
23
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
24
+
25
+ # 3. LOAD MODEL: Pass the specific config and force trust_remote_code
26
+ self.model = AutoModel.from_pretrained(
27
+ self.model_id,
28
+ config=config,
29
+ trust_remote_code=True
30
+ )
31
 
 
32
  if torch.cuda.is_available():
33
  self.model = self.model.to("cuda")
34
 
35
  def __call__(self, data: Dict[str, Any]) -> List[float]:
36
  inputs = data.pop("inputs", data)
 
 
37
  encoded_input = self.tokenizer(inputs, return_tensors='pt')
38
+
39
  if torch.cuda.is_available():
40
  encoded_input = {k: v.to("cuda") for k, v in encoded_input.items()}
41
 
42
  with torch.no_grad():
43
  outputs = self.model(**encoded_input)
44
 
45
+ # Mean pooling to get sequence embedding
46
  embeddings = outputs[0][0].mean(dim=0).cpu().numpy().tolist()
 
47
  return embeddings