tommulder committed on
Commit
420a04f
·
1 Parent(s): 399c4d1

perf(attn): default to SDPA; gracefully fall back when flash_attn missing; use dtype arg

Browse files
Files changed (1) hide show
  1. src/kybtech_dots_ocr/model_loader.py +26 -11
src/kybtech_dots_ocr/model_loader.py CHANGED
@@ -23,7 +23,7 @@ REPO_ID = os.getenv("DOTS_OCR_REPO_ID", "rednote-hilab/dots.ocr")
23
  LOCAL_DIR = os.getenv("DOTS_OCR_LOCAL_DIR", "/data/models/dots-ocr")
24
  DEVICE_CONFIG = os.getenv("DOTS_OCR_DEVICE", "auto")
25
  MAX_NEW_TOKENS = int(os.getenv("DOTS_OCR_MAX_NEW_TOKENS", "2048"))
26
- USE_FLASH_ATTENTION = os.getenv("DOTS_OCR_FLASH_ATTENTION", "1") == "1"
27
  MIN_PIXELS = int(os.getenv("DOTS_OCR_MIN_PIXELS", "3136")) # 56x56
28
  MAX_PIXELS = int(os.getenv("DOTS_OCR_MAX_PIXELS", "11289600")) # 3360x3360
29
  CUSTOM_PROMPT = os.getenv("DOTS_OCR_PROMPT")
@@ -91,6 +91,22 @@ class DotsOCRModelLoader:
91
  except Exception as e:
92
  logger.error(f"Failed to download model: {e}")
93
  raise RuntimeError(f"Model download failed: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  def load_model(self) -> None:
96
  """Load the Dots.OCR model and processor."""
@@ -110,21 +126,20 @@ class DotsOCRModelLoader:
110
 
111
  # Load model with appropriate configuration
112
  model_kwargs = {
113
- "torch_dtype": self.dtype,
114
  "trust_remote_code": True,
115
  }
116
 
117
  # Add device-specific configurations
118
  if self.device == "cuda":
119
- # Use flash attention if available and requested
120
- if USE_FLASH_ATTENTION:
121
- try:
122
- model_kwargs["attn_implementation"] = "flash_attention_2"
123
- logger.info("Using flash attention 2")
124
- except Exception as e:
125
- logger.warning(f"Flash attention not available: {e}")
126
- logger.info("Falling back to standard attention")
127
-
128
  # Use device_map for automatic GPU memory management
129
  model_kwargs["device_map"] = "auto"
130
  else:
 
23
  LOCAL_DIR = os.getenv("DOTS_OCR_LOCAL_DIR", "/data/models/dots-ocr")
24
  DEVICE_CONFIG = os.getenv("DOTS_OCR_DEVICE", "auto")
25
  MAX_NEW_TOKENS = int(os.getenv("DOTS_OCR_MAX_NEW_TOKENS", "2048"))
26
+ USE_FLASH_ATTENTION = os.getenv("DOTS_OCR_FLASH_ATTENTION", "0") == "1"
27
  MIN_PIXELS = int(os.getenv("DOTS_OCR_MIN_PIXELS", "3136")) # 56x56
28
  MAX_PIXELS = int(os.getenv("DOTS_OCR_MAX_PIXELS", "11289600")) # 3360x3360
29
  CUSTOM_PROMPT = os.getenv("DOTS_OCR_PROMPT")
 
91
  except Exception as e:
92
  logger.error(f"Failed to download model: {e}")
93
  raise RuntimeError(f"Model download failed: {e}")
94
+
95
+ def _can_use_flash_attn(self) -> bool:
96
+ """Check whether FlashAttention2 can be enabled safely.
97
+
98
+ Returns True only if the package is importable and dtype is fp16/bf16.
99
+ """
100
+ if not USE_FLASH_ATTENTION:
101
+ return False
102
+ try:
103
+ # Import check avoids runtime error from Transformers if not installed
104
+ import flash_attn # type: ignore # noqa: F401
105
+ except Exception:
106
+ logger.warning("flash_attn package not installed; disabling FlashAttention2")
107
+ return False
108
+ # FlashAttention2 supports fp16/bf16 only (see HF docs)
109
+ return self.dtype in (torch.float16, torch.bfloat16)
110
 
111
  def load_model(self) -> None:
112
  """Load the Dots.OCR model and processor."""
 
126
 
127
  # Load model with appropriate configuration
128
  model_kwargs = {
129
+ "dtype": self.dtype, # torch_dtype is deprecated
130
  "trust_remote_code": True,
131
  }
132
 
133
  # Add device-specific configurations
134
  if self.device == "cuda":
135
+ # Prefer FlashAttention2 when truly available; otherwise SDPA
136
+ if self._can_use_flash_attn():
137
+ model_kwargs["attn_implementation"] = "flash_attention_2"
138
+ logger.info("Using flash attention 2")
139
+ else:
140
+ model_kwargs["attn_implementation"] = "sdpa"
141
+ logger.info("Using SDPA attention (flash-attn unavailable or disabled)")
142
+
 
143
  # Use device_map for automatic GPU memory management
144
  model_kwargs["device_map"] = "auto"
145
  else: