Spaces:
Paused
Paused
perf(attn): default to SDPA; fall back gracefully when flash_attn is missing; use dtype arg
Browse files
src/kybtech_dots_ocr/model_loader.py
CHANGED
|
@@ -23,7 +23,7 @@ REPO_ID = os.getenv("DOTS_OCR_REPO_ID", "rednote-hilab/dots.ocr")
|
|
| 23 |
LOCAL_DIR = os.getenv("DOTS_OCR_LOCAL_DIR", "/data/models/dots-ocr")
|
| 24 |
DEVICE_CONFIG = os.getenv("DOTS_OCR_DEVICE", "auto")
|
| 25 |
MAX_NEW_TOKENS = int(os.getenv("DOTS_OCR_MAX_NEW_TOKENS", "2048"))
|
| 26 |
-
USE_FLASH_ATTENTION = os.getenv("DOTS_OCR_FLASH_ATTENTION", "
|
| 27 |
MIN_PIXELS = int(os.getenv("DOTS_OCR_MIN_PIXELS", "3136")) # 56x56
|
| 28 |
MAX_PIXELS = int(os.getenv("DOTS_OCR_MAX_PIXELS", "11289600")) # 3360x3360
|
| 29 |
CUSTOM_PROMPT = os.getenv("DOTS_OCR_PROMPT")
|
|
@@ -91,6 +91,22 @@ class DotsOCRModelLoader:
|
|
| 91 |
except Exception as e:
|
| 92 |
logger.error(f"Failed to download model: {e}")
|
| 93 |
raise RuntimeError(f"Model download failed: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
def load_model(self) -> None:
|
| 96 |
"""Load the Dots.OCR model and processor."""
|
|
@@ -110,21 +126,20 @@ class DotsOCRModelLoader:
|
|
| 110 |
|
| 111 |
# Load model with appropriate configuration
|
| 112 |
model_kwargs = {
|
| 113 |
-
"
|
| 114 |
"trust_remote_code": True,
|
| 115 |
}
|
| 116 |
|
| 117 |
# Add device-specific configurations
|
| 118 |
if self.device == "cuda":
|
| 119 |
-
#
|
| 120 |
-
if
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
# Use device_map for automatic GPU memory management
|
| 129 |
model_kwargs["device_map"] = "auto"
|
| 130 |
else:
|
|
|
|
| 23 |
LOCAL_DIR = os.getenv("DOTS_OCR_LOCAL_DIR", "/data/models/dots-ocr")
|
| 24 |
DEVICE_CONFIG = os.getenv("DOTS_OCR_DEVICE", "auto")
|
| 25 |
MAX_NEW_TOKENS = int(os.getenv("DOTS_OCR_MAX_NEW_TOKENS", "2048"))
|
| 26 |
+
# Opt-in switch: true only when DOTS_OCR_FLASH_ATTENTION is exactly "1"; unset or any other value leaves it False.
USE_FLASH_ATTENTION = os.getenv("DOTS_OCR_FLASH_ATTENTION", "0") == "1"
|
| 27 |
MIN_PIXELS = int(os.getenv("DOTS_OCR_MIN_PIXELS", "3136")) # 56x56
|
| 28 |
MAX_PIXELS = int(os.getenv("DOTS_OCR_MAX_PIXELS", "11289600")) # 3360x3360
|
| 29 |
CUSTOM_PROMPT = os.getenv("DOTS_OCR_PROMPT")
|
|
|
|
| 91 |
except Exception as e:
|
| 92 |
logger.error(f"Failed to download model: {e}")
|
| 93 |
raise RuntimeError(f"Model download failed: {e}")
|
| 94 |
+
|
| 95 |
+
def _can_use_flash_attn(self) -> bool:
    """Check whether FlashAttention2 can be enabled safely.

    Returns:
        True only if the ``USE_FLASH_ATTENTION`` opt-in flag is set, the
        ``flash_attn`` package can be imported, and ``self.dtype`` is
        ``torch.float16`` or ``torch.bfloat16``. False otherwise; the reason
        is logged as a warning whenever the flag was set but the check failed.
    """
    if not USE_FLASH_ATTENTION:
        return False
    try:
        # Import check avoids runtime error from Transformers if not installed
        import flash_attn  # type: ignore # noqa: F401
    except Exception as exc:
        # Kept broad on purpose so that any import-time failure falls back
        # instead of aborting model load. The concrete error is logged
        # because the package may be present but unusable, which
        # "not installed" would misreport.
        logger.warning(
            "flash_attn could not be imported (%s: %s); disabling FlashAttention2",
            type(exc).__name__,
            exc,
        )
        return False
    # FlashAttention2 supports fp16/bf16 only (see HF docs)
    if self.dtype not in (torch.float16, torch.bfloat16):
        logger.warning(
            "dtype %s is not fp16/bf16; disabling FlashAttention2", self.dtype
        )
        return False
    return True
|
| 110 |
|
| 111 |
def load_model(self) -> None:
|
| 112 |
"""Load the Dots.OCR model and processor."""
|
|
|
|
| 126 |
|
| 127 |
# Load model with appropriate configuration
|
| 128 |
model_kwargs = {
|
| 129 |
+
"dtype": self.dtype, # torch_dtype is deprecated
|
| 130 |
"trust_remote_code": True,
|
| 131 |
}
|
| 132 |
|
| 133 |
# Add device-specific configurations
|
| 134 |
if self.device == "cuda":
|
| 135 |
+
# Prefer FlashAttention2 when truly available; otherwise SDPA
|
| 136 |
+
if self._can_use_flash_attn():
|
| 137 |
+
model_kwargs["attn_implementation"] = "flash_attention_2"
|
| 138 |
+
logger.info("Using flash attention 2")
|
| 139 |
+
else:
|
| 140 |
+
model_kwargs["attn_implementation"] = "sdpa"
|
| 141 |
+
logger.info("Using SDPA attention (flash-attn unavailable or disabled)")
|
| 142 |
+
|
|
|
|
| 143 |
# Use device_map for automatic GPU memory management
|
| 144 |
model_kwargs["device_map"] = "auto"
|
| 145 |
else:
|