brocks1234 commited on
Commit
8fee4ab
·
verified ·
1 Parent(s): 40ac285

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +13 -24
handler.py CHANGED
@@ -1,10 +1,6 @@
1
- import sys
2
- from unittest.mock import MagicMock
3
-
4
- # 1. GLOBAL BLACKOUT: Must be at the very top, before any other imports
5
- # This makes Triton invisible to every script the model downloads.
6
- sys.modules["triton"] = MagicMock()
7
- sys.modules["triton.language"] = MagicMock()
8
 
9
  import torch
10
  from typing import Any, Dict, List
@@ -14,29 +10,21 @@ class EndpointHandler:
14
  def __init__(self, path=""):
15
  self.model_id = "zhihan1996/DNABERT-2-117M"
16
 
17
- # 2. Config level: Explicitly set flash_attn to False in the config object
18
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
19
- config = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
20
 
21
- # Some custom implementations check for 'use_flash_attn' or 'flash_attn'
 
 
22
  config.use_flash_attn = False
23
- if hasattr(config, "auto_map"):
24
- # Force it to use the standard modeling rather than the Triton-based one
25
- config.auto_map["AutoModel"] = "modeling_bert.BertModel"
26
-
27
- # 3. Load Model
28
  self.model = AutoModel.from_pretrained(
29
  self.model_id,
30
  trust_remote_code=True,
31
- config=config
 
32
  )
33
 
34
- # 4. Layer Level: Double-check the individual attention layers
35
- # This is our last-resort safety net
36
- for module in self.model.modules():
37
- if hasattr(module, "use_flash_attn"):
38
- module.use_flash_attn = False
39
-
40
  if torch.cuda.is_available():
41
  self.model = self.model.to("cuda")
42
  self.model.eval()
@@ -58,8 +46,9 @@ class EndpointHandler:
58
  encoded_input = {k: v.to("cuda") for k, v in encoded_input.items()}
59
 
60
  with torch.no_grad():
61
- outputs = self.model(**encoded_input)
 
 
62
 
63
- # Mean pooling
64
  embeddings = outputs[0][0].mean(dim=0).cpu().numpy().tolist()
65
  return embeddings
 
1
+ import os
2
+ # Force PyTorch to use its built-in stable attention and ignore custom kernels
3
+ os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1"
 
 
 
 
4
 
5
  import torch
6
  from typing import Any, Dict, List
 
10
  def __init__(self, path=""):
11
  self.model_id = "zhihan1996/DNABERT-2-117M"
12
 
 
13
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
 
14
 
15
+ # 1. Load config and EXPLICITLY set the attn_implementation to 'eager'
16
+ # 'eager' means 'plain PyTorch math' - no Triton, no Flash, just stability.
17
+ config = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
18
  config.use_flash_attn = False
19
+
20
+ # 2. Load Model with the 'eager' implementation if supported
 
 
 
21
  self.model = AutoModel.from_pretrained(
22
  self.model_id,
23
  trust_remote_code=True,
24
+ config=config,
25
+ attn_implementation="eager"
26
  )
27
 
 
 
 
 
 
 
28
  if torch.cuda.is_available():
29
  self.model = self.model.to("cuda")
30
  self.model.eval()
 
46
  encoded_input = {k: v.to("cuda") for k, v in encoded_input.items()}
47
 
48
  with torch.no_grad():
49
+ # 3. Use the inference mode context for extra stability
50
+ with torch.inference_mode():
51
+ outputs = self.model(**encoded_input)
52
 
 
53
  embeddings = outputs[0][0].mean(dim=0).cpu().numpy().tolist()
54
  return embeddings