brocks1234 commited on
Commit
1214747
·
verified ·
1 Parent(s): 364362c

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +27 -20
handler.py CHANGED
@@ -1,40 +1,45 @@
1
- import sys
2
- from unittest.mock import MagicMock
3
-
4
- # 1. MOCK TRITON: This trick prevents the model from even TRYING to load the broken code
5
- sys.modules["triton"] = MagicMock()
6
- sys.modules["triton.language"] = MagicMock()
7
-
8
- from typing import Any, Dict, List
9
- from transformers import AutoTokenizer, AutoModel, AutoConfig
10
  import torch
 
 
11
 
12
  class EndpointHandler:
13
  def __init__(self, path=""):
14
  self.model_id = "zhihan1996/DNABERT-2-117M"
15
 
16
- # 2. FORCE CONFIG: Explicitly disable flash attention in multiple places
17
- config = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
18
- config.use_flash_attn = False
19
- if hasattr(config, "auto_map"):
20
- # This ensures it doesn't try to use the custom 'Flash' modeling class
21
- config.auto_map["AutoModel"] = "modeling_bert.BertModel"
22
 
 
23
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
24
 
25
- # 3. LOAD MODEL: Pass the specific config and force trust_remote_code
26
- self.model = AutoModel.from_pretrained(
 
27
  self.model_id,
28
  config=config,
29
- trust_remote_code=True
 
30
  )
31
 
32
  if torch.cuda.is_available():
33
  self.model = self.model.to("cuda")
 
34
 
35
  def __call__(self, data: Dict[str, Any]) -> List[float]:
 
36
  inputs = data.pop("inputs", data)
37
- encoded_input = self.tokenizer(inputs, return_tensors='pt')
 
 
 
 
 
 
 
 
 
 
38
 
39
  if torch.cuda.is_available():
40
  encoded_input = {k: v.to("cuda") for k, v in encoded_input.items()}
@@ -42,6 +47,8 @@ class EndpointHandler:
42
  with torch.no_grad():
43
  outputs = self.model(**encoded_input)
44
 
45
- # Mean pooling to get sequence embedding
 
46
  embeddings = outputs[0][0].mean(dim=0).cpu().numpy().tolist()
 
47
  return embeddings
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ from typing import Any, Dict, List
3
+ from transformers import AutoTokenizer, BertModel, BertConfig
4
 
5
  class EndpointHandler:
6
  def __init__(self, path=""):
7
  self.model_id = "zhihan1996/DNABERT-2-117M"
8
 
9
+ # 1. Use a standard BERT config instead of the custom DNABERT one
10
+ # This prevents the 'flash_attn_triton.py' from ever being triggered
11
+ config = BertConfig.from_pretrained(self.model_id)
 
 
 
12
 
13
+ # 2. Load the tokenizer normally
14
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
15
 
16
+ # 3. Load as a standard BertModel
17
+ # We use 'trust_remote_code=False' here to force standard layers
18
+ self.model = BertModel.from_pretrained(
19
  self.model_id,
20
  config=config,
21
+ trust_remote_code=False,
22
+ ignore_mismatched_sizes=True
23
  )
24
 
25
  if torch.cuda.is_available():
26
  self.model = self.model.to("cuda")
27
+ self.model.eval()
28
 
29
  def __call__(self, data: Dict[str, Any]) -> List[float]:
30
+ # Handle input strings or dictionaries
31
  inputs = data.pop("inputs", data)
32
+ if isinstance(inputs, list):
33
+ inputs = inputs[0]
34
+
35
+ # Standard tokenization
36
+ encoded_input = self.tokenizer(
37
+ inputs,
38
+ return_tensors='pt',
39
+ padding=True,
40
+ truncation=True,
41
+ max_length=512
42
+ )
43
 
44
  if torch.cuda.is_available():
45
  encoded_input = {k: v.to("cuda") for k, v in encoded_input.items()}
 
47
  with torch.no_grad():
48
  outputs = self.model(**encoded_input)
49
 
50
+ # Get the hidden states and perform mean pooling
51
+ # index 0 is the last_hidden_state
52
  embeddings = outputs[0][0].mean(dim=0).cpu().numpy().tolist()
53
+
54
  return embeddings