brocks1234 commited on
Commit
40ac285
·
verified ·
1 Parent(s): 9b12332

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +35 -25
handler.py CHANGED
@@ -1,31 +1,51 @@
1
  import sys
2
- from typing import Any, Dict, List
 
 
 
 
 
 
3
  import torch
 
4
  from transformers import AutoTokenizer, AutoModel, AutoConfig
5
 
6
  class EndpointHandler:
7
  def __init__(self, path=""):
8
  self.model_id = "zhihan1996/DNABERT-2-117M"
9
 
10
- # 1. Load tokenizer and config
11
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
12
  config = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
13
 
14
- # 2. Force the config to disable flash attention
15
  config.use_flash_attn = False
 
 
 
 
 
 
 
 
 
 
16
 
17
- # 3. Load the model
18
- self.model = AutoModel.from_pretrained(self.model_id, trust_remote_code=True, config=config)
19
-
 
 
 
20
  if torch.cuda.is_available():
21
  self.model = self.model.to("cuda")
22
  self.model.eval()
23
 
24
  def __call__(self, data: Dict[str, Any]) -> List[float]:
25
- # Extract inputs
26
  inputs = data.pop("inputs", data)
27
-
28
- # Tokenize
 
29
  encoded_input = self.tokenizer(
30
  inputs,
31
  return_tensors='pt',
@@ -36,20 +56,10 @@ class EndpointHandler:
36
 
37
  if torch.cuda.is_available():
38
  encoded_input = {k: v.to("cuda") for k, v in encoded_input.items()}
39
-
40
- # THE "TOTAL ECLIPSE":
41
- # We temporarily move triton out of sys.modules so the model
42
- # thinks it's not installed and falls back to standard PyTorch math.
43
- real_triton = sys.modules.pop("triton", None)
44
-
45
- try:
46
- with torch.no_grad():
47
- outputs = self.model(**encoded_input)
48
- # Mean pooling for embedding
49
- embeddings = outputs[0][0].mean(dim=0).cpu().numpy().tolist()
50
- finally:
51
- # Restore triton so we don't break the rest of the HF environment
52
- if real_triton:
53
- sys.modules["triton"] = real_triton
54
-
55
  return embeddings
 
1
  import sys
2
+ from unittest.mock import MagicMock
3
+
4
+ # 1. GLOBAL BLACKOUT: Must be at the very top, before any other imports
5
+ # This makes Triton invisible to every script the model downloads.
6
+ sys.modules["triton"] = MagicMock()
7
+ sys.modules["triton.language"] = MagicMock()
8
+
9
  import torch
10
+ from typing import Any, Dict, List
11
  from transformers import AutoTokenizer, AutoModel, AutoConfig
12
 
13
  class EndpointHandler:
14
  def __init__(self, path=""):
15
  self.model_id = "zhihan1996/DNABERT-2-117M"
16
 
17
+ # 2. Config level: Explicitly set flash_attn to False in the config object
18
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
19
  config = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
20
 
21
+ # Some custom implementations check for 'use_flash_attn' or 'flash_attn'
22
  config.use_flash_attn = False
23
+ if hasattr(config, "auto_map"):
24
+ # Force it to use the standard modeling rather than the Triton-based one
25
+ config.auto_map["AutoModel"] = "modeling_bert.BertModel"
26
+
27
+ # 3. Load Model
28
+ self.model = AutoModel.from_pretrained(
29
+ self.model_id,
30
+ trust_remote_code=True,
31
+ config=config
32
+ )
33
 
34
+ # 4. Layer Level: Double-check the individual attention layers
35
+ # This is our last-resort safety net
36
+ for module in self.model.modules():
37
+ if hasattr(module, "use_flash_attn"):
38
+ module.use_flash_attn = False
39
+
40
  if torch.cuda.is_available():
41
  self.model = self.model.to("cuda")
42
  self.model.eval()
43
 
44
  def __call__(self, data: Dict[str, Any]) -> List[float]:
 
45
  inputs = data.pop("inputs", data)
46
+ if isinstance(inputs, list):
47
+ inputs = inputs[0]
48
+
49
  encoded_input = self.tokenizer(
50
  inputs,
51
  return_tensors='pt',
 
56
 
57
  if torch.cuda.is_available():
58
  encoded_input = {k: v.to("cuda") for k, v in encoded_input.items()}
59
+
60
+ with torch.no_grad():
61
+ outputs = self.model(**encoded_input)
62
+
63
+ # Mean pooling
64
+ embeddings = outputs[0][0].mean(dim=0).cpu().numpy().tolist()
 
 
 
 
 
 
 
 
 
 
65
  return embeddings