# Local_OCR_Demo / app_hf.py
# feat: Implement CUDA BF16 error handling with automatic fallback to CPU
# for model inference and generation. (commit 4f43939, author: DocUA)
import warnings
# Try to import spaces; if not available (local run), create a dummy decorator.
try:
    import spaces
except ImportError:
    class spaces:
        @staticmethod
        def GPU(*args, **kwargs):
            """No-op stand-in for ``spaces.GPU``.

            Must support both decorator forms used by HF Spaces code:
            bare ``@spaces.GPU`` and parameterized ``@spaces.GPU(duration=...)``
            (this file uses the latter on run_ocr, which the previous
            single-argument dummy could not handle: it raised TypeError on
            the ``duration`` keyword during local runs).
            """
            if len(args) == 1 and callable(args[0]) and not kwargs:
                # Bare form: @spaces.GPU — return the function unchanged.
                return args[0]
            # Called form: @spaces.GPU(duration=120) — return a pass-through decorator.
            return lambda func: func
import gradio as gr
from transformers import AutoModel, AutoTokenizer, AutoProcessor, AutoModelForImageTextToText
from transformers import logging as hf_logging
import torch
import os
from PIL import Image
import tempfile
import datetime
import fitz # PyMuPDF
import io
import gc
import threading
import contextlib
# --- Compatibility shims (best effort; any failure is deliberately ignored) ---
try:
    # Alias LlamaFlashAttention2 to LlamaAttention when the former is missing.
    # NOTE(review): presumably needed by remote model code that still imports
    # the removed class from newer transformers releases — confirm against the
    # DeepSeek-OCR repository code.
    from transformers.models.llama import modeling_llama as _modeling_llama
    if not hasattr(_modeling_llama, "LlamaFlashAttention2") and hasattr(_modeling_llama, "LlamaAttention"):
        _modeling_llama.LlamaFlashAttention2 = _modeling_llama.LlamaAttention
except Exception:
    pass
try:
    # Provide a fallback is_torch_fx_available() if this transformers build
    # lacks it; reports True when torch exposes the `fx` submodule.
    from transformers.utils import import_utils as _import_utils
    if not hasattr(_import_utils, "is_torch_fx_available"):
        def is_torch_fx_available():
            try:
                import torch as _torch
                return hasattr(_torch, "fx")
            except Exception:
                return False
        _import_utils.is_torch_fx_available = is_torch_fx_available
except Exception:
    pass
# Suppress annoying warnings (Gradio deprecations and common HF generation
# warnings that would otherwise flood the Space logs).
warnings.filterwarnings("ignore", message="The parameters have been moved from the Blocks constructor to the launch()")
warnings.filterwarnings("ignore", message="CUDA is not available or torch_xla is imported")
warnings.filterwarnings("ignore", message="The following generation flags are not valid and may be ignored")
warnings.filterwarnings("ignore", message="The attention mask and the pad token id were not set")
warnings.filterwarnings("ignore", message="You are using a model of type .* to instantiate a model of type .*")
hf_logging.set_verbosity_error()
# --- Configuration ---
DEEPSEEK_MODEL = 'deepseek-ai/DeepSeek-OCR-2'
MEDGEMMA_MODEL = 'google/medgemma-1.5-4b-it'
# Prefer the persistent /data volume (present on paid HF Spaces) for the HF
# cache; otherwise fall back to the standard per-user cache location.
_default_hf_home = "/data/.huggingface" if os.path.isdir("/data") else os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
os.environ.setdefault("HF_HOME", _default_hf_home)
# Point both the current (HF_HUB_CACHE) and legacy (TRANSFORMERS_CACHE)
# variables at the same hub directory so every loader shares one cache.
_hf_cache_dir = os.environ.get("HF_HUB_CACHE") or os.path.join(os.environ["HF_HOME"], "hub")
os.environ.setdefault("HF_HUB_CACHE", _hf_cache_dir)
os.environ.setdefault("TRANSFORMERS_CACHE", _hf_cache_dir)
def _warmup_hf_cache():
try:
from huggingface_hub import snapshot_download
except Exception as e:
print(f"Warmup cache failed: {e}")
return
for _repo_id in (DEEPSEEK_MODEL, MEDGEMMA_MODEL):
try:
snapshot_download(repo_id=_repo_id, cache_dir=_hf_cache_dir)
except Exception as e:
print(f"Warmup cache failed for {_repo_id}: {e}")
threading.Thread(target=_warmup_hf_cache, daemon=True).start()
# --- Device Setup ---
# For HF Spaces with ZeroGPU, we'll use cuda if available
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
def _configure_cuda_precision():
if not torch.cuda.is_available():
return
# Avoid BF16 on GPUs that don't support it (sm80+).
try:
major, minor = torch.cuda.get_device_capability()
if (major, minor) < (8, 0):
torch.backends.cuda.matmul.allow_bf16 = False
except Exception:
torch.backends.cuda.matmul.allow_bf16 = False
_configure_cuda_precision()
class ModelManager:
    """Lazily loads and caches the OCR models on CPU.

    Models are kept on CPU here; run_ocr moves the selected model to GPU for
    the duration of a request and back afterwards (ZeroGPU pattern).
    """

    def __init__(self):
        # Both dicts are keyed by model name (repo id).
        self.models = {}
        self.processors = {}

    def get_model(self, model_name):
        """Return ``(model, tokenizer_or_processor)``, loading on first use.

        For DEEPSEEK_MODEL the second element is an AutoTokenizer; for
        MEDGEMMA_MODEL it is an AutoProcessor. Any other name is never cached,
        so the final lookup raises KeyError.
        """
        if model_name not in self.models:
            print(f"Loading {model_name} to CPU...")
            if model_name == DEEPSEEK_MODEL:
                tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, cache_dir=_hf_cache_dir)
                # Some checkpoints ship without a pad token; reuse EOS so that
                # padded generation does not fail.
                if getattr(tokenizer, "pad_token_id", None) is None and getattr(tokenizer, "eos_token_id", None) is not None:
                    tokenizer.pad_token_id = tokenizer.eos_token_id
                model = AutoModel.from_pretrained(
                    model_name,
                    trust_remote_code=True,
                    use_safetensors=True,
                    attn_implementation="eager",  # request the eager attention implementation
                    cache_dir=_hf_cache_dir,
                    torch_dtype=dtype
                )
                # Mirror the tokenizer's pad/eos ids onto the model and its
                # generation config so generate() does not have to guess them.
                if hasattr(model, "config") and getattr(model.config, "pad_token_id", None) is None and getattr(tokenizer, "pad_token_id", None) is not None:
                    model.config.pad_token_id = tokenizer.pad_token_id
                if hasattr(model, "generation_config"):
                    if getattr(model.generation_config, "pad_token_id", None) is None and getattr(tokenizer, "pad_token_id", None) is not None:
                        model.generation_config.pad_token_id = tokenizer.pad_token_id
                    if getattr(model.generation_config, "eos_token_id", None) is None and getattr(tokenizer, "eos_token_id", None) is not None:
                        model.generation_config.eos_token_id = tokenizer.eos_token_id
                model.eval()
                self.models[model_name] = model
                self.processors[model_name] = tokenizer
            elif model_name == MEDGEMMA_MODEL:
                processor = AutoProcessor.from_pretrained(model_name, cache_dir=_hf_cache_dir)
                model = AutoModelForImageTextToText.from_pretrained(
                    model_name,
                    trust_remote_code=True,
                    cache_dir=_hf_cache_dir,
                    torch_dtype=dtype
                )
                model.eval()
                # Ensure pad_token_id is set
                if processor.tokenizer.pad_token_id is None:
                    processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id
                if hasattr(model, "generation_config"):
                    if getattr(model.generation_config, "pad_token_id", None) is None:
                        model.generation_config.pad_token_id = processor.tokenizer.pad_token_id
                    if getattr(model.generation_config, "eos_token_id", None) is None:
                        model.generation_config.eos_token_id = processor.tokenizer.eos_token_id
                self.models[model_name] = model
                self.processors[model_name] = processor
        return self.models[model_name], self.processors[model_name]
# Module-level singleton shared by all requests.
manager = ModelManager()
def pdf_to_images(pdf_path):
    """Render every page of a PDF to a PIL image at 2x zoom.

    Args:
        pdf_path: Path to the PDF file on disk.

    Returns:
        A list of PIL.Image objects, one per page, in page order.

    Raises:
        Whatever PyMuPDF raises for unreadable/corrupt files; the document
        handle is always closed (the original leaked it if rendering raised).
    """
    doc = fitz.open(pdf_path)
    try:
        images = []
        for page in doc:
            # 2x zoom matrix: higher resolution than the 72-dpi default,
            # which helps downstream OCR quality.
            pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))
            img_data = pix.tobytes("png")
            images.append(Image.open(io.BytesIO(img_data)))
    finally:
        doc.close()
    return images
@spaces.GPU(duration=120)
def run_ocr(input_image, input_file, model_choice, custom_prompt):
    """Run OCR over an uploaded image or PDF with the selected model.

    Args:
        input_image: PIL image from the image tab (may be None).
        input_file: Uploaded file (PDF or image) from the file tab (may be None);
            either a Gradio file object with ``.name`` or a plain path string.
        model_choice: DEEPSEEK_MODEL or MEDGEMMA_MODEL.
        custom_prompt: Optional prompt override; empty string uses the default.

    Returns:
        The per-page results joined with blank lines, or a Ukrainian error
        message string when input/model setup fails.
    """
    # --- Collect the PIL images to process (file input wins over image input) ---
    images_to_process = []
    if input_file is not None:
        # Compatibility with different Gradio versions (object with .name vs string path)
        file_path = input_file.name if hasattr(input_file, 'name') else input_file
        if file_path.lower().endswith(".pdf"):
            try:
                images_to_process = pdf_to_images(file_path)
            except Exception as e:
                return f"Помилка читання PDF: {str(e)}"
        else:
            try:
                images_to_process = [Image.open(file_path)]
            except Exception as e:
                return f"Помилка завантаження файлу: {str(e)}"
    elif input_image is not None:
        images_to_process = [input_image]
    else:
        return "Будь ласка, завантажте зображення або PDF файл."

    def _is_cuda_bf16_error(err):
        # Heuristic string match for the cuBLAS failure seen when a BF16 GEMM
        # is attempted on hardware that cannot execute it.
        msg = str(err)
        return "CUBLAS_STATUS_INVALID_VALUE" in msg and "CUDA_R_16BF" in msg

    try:
        model, processor_or_tokenizer = manager.get_model(model_choice)
        # Move to GPU only inside the decorated function
        print(f"Moving {model_choice} to GPU...")
        model.to(device="cuda", dtype=torch.float16)
        run_device = "cuda"
    except Exception as e:
        return f"Помилка завантаження чи переміщення моделі: {str(e)}\nЯкщо це MedGemma, переконайтеся, що ви надали HF_TOKEN."
    output_dir = 'outputs'
    os.makedirs(output_dir, exist_ok=True)
    all_results = []
    try:
        def _autocast_for(device_str):
            # FP16 autocast on CUDA; a no-op context manager on CPU, so the
            # same `with` statement works after a CPU fallback.
            if device_str == "cuda" and torch.cuda.is_available():
                return torch.autocast(device_type="cuda", dtype=torch.float16)
            return contextlib.nullcontext()
        for i, img in enumerate(images_to_process):
            img = img.convert("RGB")
            try:
                print(f"Processing page/image {i+1} with {model_choice}...")
                if model_choice == DEEPSEEK_MODEL:
                    # DeepSeek's remote infer() API takes a file path, so the
                    # page is written to a temporary PNG first.
                    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
                        img.save(tmp.name)
                        tmp_path = tmp.name
                    try:
                        try:
                            with torch.no_grad(), _autocast_for(run_device):
                                res = model.infer(
                                    processor_or_tokenizer,
                                    prompt=custom_prompt if custom_prompt else "<image>\nFree OCR. ",
                                    image_file=tmp_path,
                                    output_path=output_dir,
                                    base_size=1024,
                                    image_size=768,
                                    crop_mode=True,
                                    eval_mode=True
                                )
                            all_results.append(f"--- Page/Image {i+1} ---\n{res}")
                        except Exception as e:
                            # BF16-incapable GPU: move to CPU/FP32 once and
                            # retry; run_device stays "cpu" for later pages.
                            if run_device == "cuda" and _is_cuda_bf16_error(e):
                                print("CUDA BF16 error detected, retrying on CPU...")
                                model.to(device="cpu", dtype=torch.float32)
                                run_device = "cpu"
                                with torch.no_grad(), _autocast_for(run_device):
                                    res = model.infer(
                                        processor_or_tokenizer,
                                        prompt=custom_prompt if custom_prompt else "<image>\nFree OCR. ",
                                        image_file=tmp_path,
                                        output_path=output_dir,
                                        base_size=1024,
                                        image_size=768,
                                        crop_mode=True,
                                        eval_mode=True
                                    )
                                all_results.append(f"--- Page/Image {i+1} ---\n{res}")
                            else:
                                raise
                    finally:
                        # Always remove the temp PNG, even on failure.
                        if os.path.exists(tmp_path):
                            os.remove(tmp_path)
                elif model_choice == MEDGEMMA_MODEL:
                    prompt_text = custom_prompt if custom_prompt else "extract all text from image"
                    # Chat-template message with one image and one text part.
                    messages = [
                        {
                            "role": "user",
                            "content": [
                                {"type": "image", "image": img},
                                {"type": "text", "text": prompt_text}
                            ]
                        }
                    ]
                    inputs = processor_or_tokenizer.apply_chat_template(
                        messages,
                        add_generation_prompt=True,
                        tokenize=True,
                        return_dict=True,
                        return_tensors="pt"
                    ).to(run_device)
                    # Some processor versions omit the attention mask; supply
                    # an all-ones mask matching input_ids.
                    if "attention_mask" not in inputs:
                        inputs["attention_mask"] = torch.ones_like(inputs["input_ids"], dtype=torch.long)
                    try:
                        with torch.no_grad(), _autocast_for(run_device):
                            output = model.generate(
                                **inputs,
                                max_new_tokens=4096,
                                do_sample=False,
                                pad_token_id=processor_or_tokenizer.tokenizer.pad_token_id,
                            )
                        # Decode only the newly generated tokens (skip the prompt).
                        input_len = inputs["input_ids"].shape[-1]
                        res = processor_or_tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
                        all_results.append(f"--- Page/Image {i+1} ---\n{res}")
                    except Exception as e:
                        # Same BF16 fallback as the DeepSeek branch; inputs
                        # must also be moved to the CPU before retrying.
                        if run_device == "cuda" and _is_cuda_bf16_error(e):
                            print("CUDA BF16 error detected, retrying on CPU...")
                            model.to(device="cpu", dtype=torch.float32)
                            run_device = "cpu"
                            inputs = inputs.to(run_device)
                            with torch.no_grad(), _autocast_for(run_device):
                                output = model.generate(
                                    **inputs,
                                    max_new_tokens=4096,
                                    do_sample=False,
                                    pad_token_id=processor_or_tokenizer.tokenizer.pad_token_id,
                                )
                            input_len = inputs["input_ids"].shape[-1]
                            res = processor_or_tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
                            all_results.append(f"--- Page/Image {i+1} ---\n{res}")
                        else:
                            raise
            except Exception as e:
                # Per-page failures are recorded in the output instead of
                # aborting the whole batch.
                all_results.append(f"--- Page/Image {i+1} ---\nПомилка: {str(e)}")
    finally:
        # Move back to CPU and clean up to free ZeroGPU resources
        print(f"Moving {model_choice} back to CPU...")
        model.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
    return "\n\n".join(all_results)
def save_result_to_file(text):
    """Persist OCR output to a timestamped .txt file under outputs/.

    Returns the absolute path to the written file, or None when the text is
    empty or is one of the app's Ukrainian placeholder/error messages.
    """
    # Skip empty results and the "please upload"/"error" messages produced by run_ocr.
    if not text or text.startswith(("Будь ласка", "Помилка")):
        return None
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    target = os.path.abspath(os.path.join("outputs", f"ocr_result_{stamp}.txt"))
    os.makedirs("outputs", exist_ok=True)
    with open(target, "w", encoding="utf-8") as fh:
        fh.write(text)
    return target
# CSS hooks for the header/footer markdown blocks below.
custom_css = """
.header { text-align: center; margin-bottom: 30px; }
.header h1 { font-size: 2.5rem; }
.footer { text-align: center; margin-top: 50px; font-size: 0.9rem; color: #718096; }
"""

# --- Gradio UI layout and event wiring ---
with gr.Blocks(title="OCR Comparison: DeepSeek vs MedGemma", css=custom_css) as demo:
    with gr.Column():
        gr.Markdown("# 🔍 OCR & Medical Document Analysis", elem_classes="header")
        gr.Markdown("Порівняння DeepSeek-OCR-2 та MedGemma-1.5-4B (HuggingFace ZeroGPU Edition)", elem_classes="header")
    with gr.Row():
        # Left column: input tabs, model choice, settings, action buttons.
        with gr.Column(scale=1):
            with gr.Tab("Зображення"):  # "Image" tab
                input_img = gr.Image(type="pil", label="Перетягніть зображення")
            with gr.Tab("PDF / Файли"):  # "PDF / Files" tab
                input_file = gr.File(label="Завантажте PDF або інший файл")
            model_selector = gr.Dropdown(
                choices=[DEEPSEEK_MODEL, MEDGEMMA_MODEL],
                value=DEEPSEEK_MODEL,
                label="Оберіть модель"  # "Choose a model"
            )
            with gr.Accordion("Налаштування", open=False):  # "Settings"
                prompt_input = gr.Textbox(
                    value="",
                    label="Користувацький промпт (залиште порожнім для дефолтного)",
                    placeholder="Наприклад: Extract all text from image"
                )
            with gr.Row():
                clear_btn = gr.Button("Очистити", variant="secondary")  # "Clear"
                ocr_btn = gr.Button("Запустити аналіз", variant="primary")  # "Run analysis"
        # Right column: result textbox and save/download controls.
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="Результат",  # "Result"
                lines=20
            )
            with gr.Row():
                save_btn = gr.Button("Зберегти у файл 💾")  # "Save to file"
                download_file = gr.File(label="Завантажити результат")  # "Download result"
    gr.Markdown("---")
    gr.Markdown("### Як використовувати:\n1. Завантажте зображення або PDF.\n2. Виберіть модель.\n3. Натисніть 'Запустити аналіз'.\n*Примітка: MedGemma потребує HF_TOKEN з доступом до моделі.*")
    # Event handlers
    ocr_btn.click(
        fn=run_ocr,
        inputs=[input_img, input_file, model_selector, prompt_input],
        outputs=output_text
    )
    save_btn.click(
        fn=save_result_to_file,
        inputs=output_text,
        outputs=download_file
    )

    def clear_all():
        # Reset both inputs, the result box, and the custom prompt.
        return None, None, "", ""

    clear_btn.click(
        fn=clear_all,
        inputs=None,
        outputs=[input_img, input_file, output_text, prompt_input]
    )

if __name__ == "__main__":
    # Enable request queuing before launching the app.
    demo.queue().launch()