Spaces: Runtime error
Error: Error processing document with GOT-OCR: Current CUDA Device does not support bfloat16. Please switch dtype to float16.
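This failure is expected on the T4 GPUs this Space runs on: bfloat16 needs Ampere-class hardware (compute capability 8.0 or newer), while the T4 is compute capability 7.5, so any bfloat16 kernel raises the error above. A minimal sketch of how the dtype could be chosen up front, assuming a PyTorch build that provides torch.cuda.is_bf16_supported(); the pick_half_dtype helper is illustrative and not part of the parser:

import torch

def pick_half_dtype() -> torch.dtype:
    # bfloat16 needs compute capability >= 8.0 (Ampere); on a T4 (7.5) fall back to float16.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16

dtype = pick_half_dtype()
print(f"Loading GOT-OCR with torch_dtype={dtype}")

The commit below takes the blunter route of hard-coding float16 everywhere and patching bfloat16 out at runtime.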
src/parsers/got_ocr_parser.py CHANGED (+141 -9)
@@ -11,6 +11,17 @@ import importlib
 os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0+PTX" # For T4 GPU compatibility
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
 os.environ["TORCH_AMP_AUTOCAST_DTYPE"] = "float16"
+os.environ["PYTORCH_DISPATCHER_DISABLE_TORCH_FUNCTION_AUTOGRAD_FALLBACK"] = "1" # Disable fallbacks that might use bfloat16
+
+# Add patch for bfloat16 at the module level
+if 'torch' in sys.modules:
+    torch_module = sys.modules['torch']
+    if hasattr(torch_module, 'bfloat16'):
+        # Create a reference to the original bfloat16 function
+        original_bfloat16 = torch_module.bfloat16
+        # Replace it with float16
+        torch_module.bfloat16 = torch_module.float16
+        logger.info("Patched torch.bfloat16 to use torch.float16 instead")
 
 from src.parsers.parser_interface import DocumentParser
 from src.parsers.parser_registry import ParserRegistry
@@ -117,6 +128,39 @@ class GotOcrParser(DocumentParser):
         torch.set_default_tensor_type(torch.FloatTensor)
         torch.set_default_dtype(torch.float16)
 
+        # Aggressively patch torch.autocast to always use float16
+        original_autocast = torch.amp.autocast if hasattr(torch.amp, 'autocast') else None
+
+        if original_autocast:
+            def patched_autocast(*args, **kwargs):
+                # Force dtype to float16
+                kwargs['dtype'] = torch.float16
+                return original_autocast(*args, **kwargs)
+
+            torch.amp.autocast = patched_autocast
+            logger.info("Patched torch.amp.autocast to always use float16")
+
+        # Patch tensor casting methods for bfloat16
+        if hasattr(torch, 'Tensor'):
+            if hasattr(torch.Tensor, 'to'):
+                original_to = torch.Tensor.to
+                def patched_to(self, *args, **kwargs):
+                    # If the first arg is a dtype and it's bfloat16, replace with float16
+                    if args and args[0] == torch.bfloat16:
+                        logger.warning("Intercepted attempt to cast tensor to bfloat16, using float16 instead")
+                        args = list(args)
+                        args[0] = torch.float16
+                        args = tuple(args)
+                    # If dtype is specified in kwargs and it's bfloat16, replace with float16
+                    if kwargs.get('dtype') == torch.bfloat16:
+                        logger.warning("Intercepted attempt to cast tensor to bfloat16, using float16 instead")
+                        kwargs['dtype'] = torch.float16
+                    return original_to(self, *args, **kwargs)
+
+                torch.Tensor.to = patched_to
+                logger.info("Patched torch.Tensor.to method to prevent bfloat16 usage")
+
+        # Load the model with explicit float16 dtype
         cls._model = AutoModel.from_pretrained(
             'stepfun-ai/GOT-OCR2_0',
             trust_remote_code=True,
@@ -127,6 +171,35 @@ class GotOcrParser(DocumentParser):
             torch_dtype=torch.float16 # Explicitly specify float16 dtype
         )
 
+        # Ensure all model parameters are float16
+        for param in cls._model.parameters():
+            param.data = param.data.to(torch.float16)
+
+        # Examine model internals to find any direct bfloat16 usage
+        def find_and_patch_bfloat16_attributes(module, path=""):
+            for name, child in module.named_children():
+                child_path = f"{path}.{name}" if path else name
+                # Check if any attribute contains "bfloat16" in its name
+                for attr_name in dir(child):
+                    if "bfloat16" in attr_name.lower():
+                        try:
+                            # Try to get the attribute
+                            attr_value = getattr(child, attr_name)
+                            logger.warning(f"Found potential bfloat16 usage at {child_path}.{attr_name}")
+                            # Try to replace with float16 equivalent if it exists
+                            float16_attr_name = attr_name.replace("bfloat16", "float16").replace("bf16", "fp16")
+                            if hasattr(child, float16_attr_name):
+                                logger.info(f"Replacing {attr_name} with {float16_attr_name}")
+                                setattr(child, attr_name, getattr(child, float16_attr_name))
+                        except Exception as e:
+                            logger.error(f"Error examining attribute {attr_name}: {e}")
+                # Recursively check child modules
+                find_and_patch_bfloat16_attributes(child, child_path)
+
+        # Apply the internal examination
+        logger.info("Examining model for potential bfloat16 usage...")
+        find_and_patch_bfloat16_attributes(cls._model)
+
         # Patch the model's chat method to use float16 instead of bfloat16
         logger.info("Patching model to use float16 instead of bfloat16")
         original_chat = cls._model.chat
@@ -144,6 +217,11 @@ class GotOcrParser(DocumentParser):
             """A patched version of chat method that forces float16 precision"""
             logger.info(f"Using patched chat method with float16, ocr_type={ocr_type}")
 
+            # Force any bfloat16 tensors to float16
+            if hasattr(torch, 'bfloat16') and torch.bfloat16 != torch.float16:
+                torch.bfloat16 = torch.float16
+                logger.info("Forcing torch.bfloat16 to be torch.float16 within chat method")
+
             # Set explicit autocast dtype
             if hasattr(torch.amp, 'autocast'):
                 with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
@@ -162,11 +240,16 @@ class GotOcrParser(DocumentParser):
                    except RuntimeError as e:
                        if "bfloat16" in str(e):
                            logger.error(f"BFloat16 error encountered despite patching: {e}")
+                           # More aggressive handling
+                           if hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
+                               logger.info("Attempting with torch.cuda.amp.autocast as last resort")
+                               with torch.cuda.amp.autocast(dtype=torch.float16):
+                                   return original_chat(tokenizer, str(image_path), ocr_type, **kwargs)
                            raise RuntimeError(f"GPU doesn't support bfloat16: {e}")
                        else:
                            raise
             else:
-                #
+                # If autocast is not available, try to manually ensure everything is float16
                 try:
                     # Direct call without 'self' as first arg
                     return original_chat(tokenizer, image_path, ocr_type, **kwargs)
@@ -183,11 +266,27 @@ class GotOcrParser(DocumentParser):
         import types
         cls._model.chat = types.MethodType(patched_chat, cls._model)
 
+        # Check if the model has a cast_to_bfloat16 method and override it
+        if hasattr(cls._model, 'cast_to_bfloat16'):
+            original_cast = cls._model.cast_to_bfloat16
+            def patched_cast(self, *args, **kwargs):
+                logger.info("Intercepted attempt to cast model to bfloat16, using float16 instead")
+                # If the model has a cast_to_float16 method, use that instead
+                if hasattr(self, 'cast_to_float16'):
+                    return self.cast_to_float16(*args, **kwargs)
+                # Otherwise, cast all parameters manually
+                for param in self.parameters():
+                    param.data = param.data.to(torch.float16)
+                return self
+
+            cls._model.cast_to_bfloat16 = types.MethodType(patched_cast, cls._model)
+            logger.info("Patched model.cast_to_bfloat16 method")
+
         # Set model to evaluation mode
         if device_map == 'cuda':
-            cls._model = cls._model.eval().cuda()
+            cls._model = cls._model.eval().cuda().half() # Explicitly cast to half precision (float16)
         else:
-            cls._model = cls._model.eval()
+            cls._model = cls._model.eval().half() # Explicitly cast to half precision (float16)
 
         # Reset default dtype to float32 after model loading
         torch.set_default_dtype(torch.float32)
@@ -377,22 +476,40 @@ class GotOcrParser(DocumentParser):
         try:
             # Use ocr_type as a positional argument based on the correct signature
             logger.info(f"Using OCR method: {ocr_type}")
-
-
-
-
-
+
+            # Temporarily force any PyTorch operations to use float16
+            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
+                result = self._model.chat(
+                    self._tokenizer,
+                    str(file_path),
+                    ocr_type  # Pass as positional arg, not keyword
+                )
         except RuntimeError as e:
             if "bfloat16" in str(e) or "BFloat16" in str(e):
                 logger.warning("Caught bfloat16 error, trying to force float16 with autocast")
                 # Try with explicit autocast
                 try:
+                    # More aggressive approach with multiple settings
+
+                    # Ensure bfloat16 is aliased to float16 globally
+                    if hasattr(torch, 'bfloat16') and torch.bfloat16 != torch.float16:
+                        logger.info("Forcing bfloat16 to be float16 in exception handler")
+                        torch.bfloat16 = torch.float16
+
+                    # Apply patch to the model's config if it exists
+                    if hasattr(self._model, 'config'):
+                        if hasattr(self._model.config, 'torch_dtype'):
+                            logger.info(f"Setting model config dtype from {self._model.config.torch_dtype} to float16")
+                            self._model.config.torch_dtype = torch.float16
+
+                    # Try with all possible autocast combinations
                     with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                         # Temporarily set default dtype
                         old_dtype = torch.get_default_dtype()
                         torch.set_default_dtype(torch.float16)
 
                         # Call with positional argument for ocr_type
+                        logger.info("Using fallback with autocast and default dtype set to float16")
                         result = self._model.chat(
                             self._tokenizer,
                             str(file_path),
@@ -403,7 +520,22 @@ class GotOcrParser(DocumentParser):
                         torch.set_default_dtype(old_dtype)
                 except Exception as inner_e:
                     logger.error(f"Error in fallback method: {str(inner_e)}")
-
+
+                    # Last resort: try using torchscript if available
+                    try:
+                        logger.info("Attempting final approach with model.half() and direct call")
+                        # Force model to half precision
+                        self._model = self._model.half()
+
+                        # Try direct call with the original method
+                        result = self._model.chat(
+                            self._tokenizer,
+                            str(file_path),
+                            ocr_type
+                        )
+                    except Exception as final_e:
+                        logger.error(f"All fallback approaches failed: {str(final_e)}")
+                        raise RuntimeError(f"Error processing with GOT-OCR using fallback: {str(final_e)}")
             else:
                 # Re-raise other errors
                 raise
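For reference, the core trick in this commit, rebinding model.chat through types.MethodType so every call runs inside a float16 autocast region, can be reproduced in isolation. The sketch below assumes a CUDA-capable environment, mirroring the patch, and uses a stand-in DummyModel rather than the real GOT-OCR model, so the class and its chat signature are illustrative only:

import types
import torch

class DummyModel:
    def chat(self, tokenizer, image_path, ocr_type, **kwargs):
        return f"ran {ocr_type} OCR on {image_path}"

model = DummyModel()
original_chat = model.chat  # bound method, so it already carries its own self

def patched_chat(self, tokenizer, image_path, ocr_type, **kwargs):
    # Run the original call under an autocast region pinned to float16,
    # so nothing inside it silently promotes to bfloat16.
    with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
        return original_chat(tokenizer, image_path, ocr_type, **kwargs)

# Rebind: the wrapper becomes the instance's chat method.
model.chat = types.MethodType(patched_chat, model)
print(model.chat(None, "page.png", "format"))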