DocUA committed on
Commit
4f43939
·
1 Parent(s): 002d07a

feat: Implement CUDA BF16 error handling with automatic fallback to CPU for model inference and generation.

Browse files
Files changed (1) hide show
  1. app_hf.py +73 -29
app_hf.py CHANGED
@@ -186,11 +186,16 @@ def run_ocr(input_image, input_file, model_choice, custom_prompt):
186
  else:
187
  return "Будь ласка, завантажте зображення або PDF файл."
188
 
 
 
 
 
189
  try:
190
  model, processor_or_tokenizer = manager.get_model(model_choice)
191
  # Move to GPU only inside the decorated function
192
  print(f"Moving {model_choice} to GPU...")
193
  model.to(device="cuda", dtype=torch.float16)
 
194
  except Exception as e:
195
  return f"Помилка завантаження чи переміщення моделі: {str(e)}\nЯкщо це MedGemma, переконайтеся, що ви надали HF_TOKEN."
196
 
@@ -200,11 +205,10 @@ def run_ocr(input_image, input_file, model_choice, custom_prompt):
200
  all_results = []
201
 
202
  try:
203
- _autocast_ctx = (
204
- torch.autocast(device_type="cuda", dtype=torch.float16)
205
- if torch.cuda.is_available()
206
- else contextlib.nullcontext()
207
- )
208
 
209
  for i, img in enumerate(images_to_process):
210
  img = img.convert("RGB")
@@ -216,18 +220,38 @@ def run_ocr(input_image, input_file, model_choice, custom_prompt):
216
  tmp_path = tmp.name
217
 
218
  try:
219
- with torch.no_grad(), _autocast_ctx:
220
- res = model.infer(
221
- processor_or_tokenizer,
222
- prompt=custom_prompt if custom_prompt else "<image>\nFree OCR. ",
223
- image_file=tmp_path,
224
- output_path=output_dir,
225
- base_size=1024,
226
- image_size=768,
227
- crop_mode=True,
228
- eval_mode=True
229
- )
230
- all_results.append(f"--- Page/Image {i+1} ---\n{res}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  finally:
232
  if os.path.exists(tmp_path):
233
  os.remove(tmp_path)
@@ -250,22 +274,42 @@ def run_ocr(input_image, input_file, model_choice, custom_prompt):
250
  tokenize=True,
251
  return_dict=True,
252
  return_tensors="pt"
253
- ).to("cuda") # Ensure inputs are on cuda
254
 
255
  if "attention_mask" not in inputs:
256
  inputs["attention_mask"] = torch.ones_like(inputs["input_ids"], dtype=torch.long)
257
 
258
- with torch.no_grad(), _autocast_ctx:
259
- output = model.generate(
260
- **inputs,
261
- max_new_tokens=4096,
262
- do_sample=False,
263
- pad_token_id=processor_or_tokenizer.tokenizer.pad_token_id,
264
- )
265
-
266
- input_len = inputs["input_ids"].shape[-1]
267
- res = processor_or_tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
268
- all_results.append(f"--- Page/Image {i+1} ---\n{res}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
 
270
  except Exception as e:
271
  all_results.append(f"--- Page/Image {i+1} ---\nПомилка: {str(e)}")
 
186
  else:
187
  return "Будь ласка, завантажте зображення або PDF файл."
188
 
189
+ def _is_cuda_bf16_error(err):
190
+ msg = str(err)
191
+ return "CUBLAS_STATUS_INVALID_VALUE" in msg and "CUDA_R_16BF" in msg
192
+
193
  try:
194
  model, processor_or_tokenizer = manager.get_model(model_choice)
195
  # Move to GPU only inside the decorated function
196
  print(f"Moving {model_choice} to GPU...")
197
  model.to(device="cuda", dtype=torch.float16)
198
+ run_device = "cuda"
199
  except Exception as e:
200
  return f"Помилка завантаження чи переміщення моделі: {str(e)}\nЯкщо це MedGemma, переконайтеся, що ви надали HF_TOKEN."
201
 
 
205
  all_results = []
206
 
207
  try:
208
+ def _autocast_for(device_str):
209
+ if device_str == "cuda" and torch.cuda.is_available():
210
+ return torch.autocast(device_type="cuda", dtype=torch.float16)
211
+ return contextlib.nullcontext()
 
212
 
213
  for i, img in enumerate(images_to_process):
214
  img = img.convert("RGB")
 
220
  tmp_path = tmp.name
221
 
222
  try:
223
+ try:
224
+ with torch.no_grad(), _autocast_for(run_device):
225
+ res = model.infer(
226
+ processor_or_tokenizer,
227
+ prompt=custom_prompt if custom_prompt else "<image>\nFree OCR. ",
228
+ image_file=tmp_path,
229
+ output_path=output_dir,
230
+ base_size=1024,
231
+ image_size=768,
232
+ crop_mode=True,
233
+ eval_mode=True
234
+ )
235
+ all_results.append(f"--- Page/Image {i+1} ---\n{res}")
236
+ except Exception as e:
237
+ if run_device == "cuda" and _is_cuda_bf16_error(e):
238
+ print("CUDA BF16 error detected, retrying on CPU...")
239
+ model.to(device="cpu", dtype=torch.float32)
240
+ run_device = "cpu"
241
+ with torch.no_grad(), _autocast_for(run_device):
242
+ res = model.infer(
243
+ processor_or_tokenizer,
244
+ prompt=custom_prompt if custom_prompt else "<image>\nFree OCR. ",
245
+ image_file=tmp_path,
246
+ output_path=output_dir,
247
+ base_size=1024,
248
+ image_size=768,
249
+ crop_mode=True,
250
+ eval_mode=True
251
+ )
252
+ all_results.append(f"--- Page/Image {i+1} ---\n{res}")
253
+ else:
254
+ raise
255
  finally:
256
  if os.path.exists(tmp_path):
257
  os.remove(tmp_path)
 
274
  tokenize=True,
275
  return_dict=True,
276
  return_tensors="pt"
277
+ ).to(run_device)
278
 
279
  if "attention_mask" not in inputs:
280
  inputs["attention_mask"] = torch.ones_like(inputs["input_ids"], dtype=torch.long)
281
 
282
+ try:
283
+ with torch.no_grad(), _autocast_for(run_device):
284
+ output = model.generate(
285
+ **inputs,
286
+ max_new_tokens=4096,
287
+ do_sample=False,
288
+ pad_token_id=processor_or_tokenizer.tokenizer.pad_token_id,
289
+ )
290
+
291
+ input_len = inputs["input_ids"].shape[-1]
292
+ res = processor_or_tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
293
+ all_results.append(f"--- Page/Image {i+1} ---\n{res}")
294
+ except Exception as e:
295
+ if run_device == "cuda" and _is_cuda_bf16_error(e):
296
+ print("CUDA BF16 error detected, retrying on CPU...")
297
+ model.to(device="cpu", dtype=torch.float32)
298
+ run_device = "cpu"
299
+ inputs = inputs.to(run_device)
300
+ with torch.no_grad(), _autocast_for(run_device):
301
+ output = model.generate(
302
+ **inputs,
303
+ max_new_tokens=4096,
304
+ do_sample=False,
305
+ pad_token_id=processor_or_tokenizer.tokenizer.pad_token_id,
306
+ )
307
+
308
+ input_len = inputs["input_ids"].shape[-1]
309
+ res = processor_or_tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
310
+ all_results.append(f"--- Page/Image {i+1} ---\n{res}")
311
+ else:
312
+ raise
313
 
314
  except Exception as e:
315
  all_results.append(f"--- Page/Image {i+1} ---\nПомилка: {str(e)}")