# Local_OCR_Demo / app.py
# Last commit 3537ca8 by DocUA:
#   fix: Add a fallback definition for `is_torch_fx_available` in `transformers.utils.import_utils`.
import gradio as gr
from transformers import AutoModel, AutoTokenizer, AutoProcessor, AutoModelForImageTextToText
import torch
import os
from PIL import Image
import tempfile
import datetime
import fitz # PyMuPDF
import io
import gc
import warnings
# Compatibility shim: when this transformers build has `LlamaAttention` but no
# `LlamaFlashAttention2`, alias the latter to the former. Presumably needed by
# trust_remote_code modeling code that still imports the old name - confirm.
# Best effort: any failure here is ignored so the app can still start.
try:
    from transformers.models.llama import modeling_llama as _modeling_llama

    if hasattr(_modeling_llama, "LlamaAttention") and not hasattr(
        _modeling_llama, "LlamaFlashAttention2"
    ):
        setattr(_modeling_llama, "LlamaFlashAttention2", _modeling_llama.LlamaAttention)
except Exception:
    pass
# Compatibility shim: if `transformers.utils.import_utils` lacks
# `is_torch_fx_available` (presumably removed in newer transformers but still
# imported by remote modeling code - confirm), install a fallback that reports
# whether `torch` can be imported and exposes a `fx` attribute.
# Best effort: any failure here is ignored so the app can still start.
try:
    from transformers.utils import import_utils as _import_utils

    if not hasattr(_import_utils, "is_torch_fx_available"):

        def is_torch_fx_available():
            """Return True if ``torch`` imports and has a ``fx`` attribute, else False."""
            try:
                import torch as _torch

                available = hasattr(_torch, "fx")
            except Exception:
                available = False
            return available

        setattr(_import_utils, "is_torch_fx_available", is_torch_fx_available)
except Exception:
    pass
# Suppress noisy third-party warnings. Each entry is a regular expression that
# must match the start of the warning message (see warnings.filterwarnings).
for _pattern in (
    "The parameters have been moved from the Blocks constructor to the launch()",
    "CUDA is not available or torch_xla is imported",
    "The following generation flags are not valid and may be ignored",
    "The attention mask and the pad token id were not set",
    "You are using a model of type .* to instantiate a model of type .*",
):
    warnings.filterwarnings("ignore", message=_pattern)
del _pattern
# --- Configuration ---
# Model identifiers passed to `from_pretrained` and offered in the UI selector.
DEEPSEEK_MODEL = "deepseek-ai/DeepSeek-OCR-2"
MEDGEMMA_MODEL = "google/medgemma-1.5-4b-it"
# --- Device Setup ---
# Pick the compute device and default dtype: Apple MPS first, then CUDA,
# otherwise CPU. `device` and `dtype` are read by the model-loading code.
if torch.backends.mps.is_available():
    print("Using MPS device")
    device = "mps"
    # Patch for DeepSeek custom code which uses .cuda(): redirect to MPS.
    torch.Tensor.cuda = lambda self, *args, **kwargs: self.to("mps")
    torch.nn.Module.cuda = lambda self, *args, **kwargs: self.to("mps")
    dtype = torch.float16
    # Patch to avoid BFloat16 vs Float16 mismatch in custom modeling code on MPS
    torch.bfloat16 = torch.float16
elif torch.cuda.is_available():
    # The DeepSeek custom code moves tensors with .cuda(), so on a CUDA machine
    # the weights must live on the GPU too (previously they stayed on the CPU,
    # causing a device mismatch). bfloat16 matches the dtype the custom code
    # uses (see the MPS bfloat16 patch above).
    print("Using CUDA device")
    device = "cuda"
    dtype = torch.bfloat16
else:
    # NOTE(review): .cuda() is not patched here, so the DeepSeek remote code
    # presumably cannot run on a CPU-only machine - confirm.
    device = "cpu"
    dtype = torch.float32
class ModelManager:
    """Keep at most one model (plus its tokenizer or processor) loaded at a time.

    Loading a different model first unloads the current one, so only a single
    set of weights is held in memory. Supported ids are ``DEEPSEEK_MODEL``
    (model + tokenizer) and ``MEDGEMMA_MODEL`` (model + processor).
    """

    def __init__(self):
        self.current_model_name = None  # id of the fully loaded model, or None
        self.model = None
        self.processor = None  # populated for MEDGEMMA_MODEL
        self.tokenizer = None  # populated for DEEPSEEK_MODEL

    def unload_current_model(self):
        """Drop every reference to the loaded model and release cached memory.

        All attributes are cleared unconditionally, so objects left behind by a
        load that failed half-way (e.g. tokenizer loaded, model not) are freed too.
        """
        if self.current_model_name is not None:
            print(f"Unloading {self.current_model_name}...")
        self.model = None
        self.processor = None
        self.tokenizer = None
        self.current_model_name = None
        # Collect garbage first so freed tensors return to the allocator cache,
        # then hand that cache back to the device.
        gc.collect()
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _preprocessor(self):
        """Return the processor when one is loaded, otherwise the tokenizer."""
        return self.processor if self.processor is not None else self.tokenizer

    def load_model(self, model_name):
        """Return ``(model, processor_or_tokenizer)`` for *model_name*.

        Reuses the loaded model when *model_name* matches it; otherwise unloads
        the current model and loads *model_name*.

        Raises:
            ValueError: if *model_name* is not a supported model id. The
                currently loaded model is kept in that case.
            Exception: anything raised while loading; partially loaded state
                is released before re-raising.
        """
        if model_name is not None and model_name == self.current_model_name:
            return self.model, self._preprocessor()

        loaders = {
            DEEPSEEK_MODEL: self._load_deepseek,
            MEDGEMMA_MODEL: self._load_medgemma,
        }
        loader = loaders.get(model_name)
        if loader is None:
            # Previously this fell through and returned None, which broke
            # tuple-unpacking callers with an opaque TypeError.
            raise ValueError(f"Unsupported model: {model_name!r}")

        self.unload_current_model()
        print(f"Loading {model_name}...")
        try:
            loader(model_name)
        except Exception:
            self.unload_current_model()
            raise
        self.current_model_name = model_name
        return self.model, self._preprocessor()

    def _load_deepseek(self, model_name):
        """Load the DeepSeek OCR model (remote code) and tokenizer onto ``device``."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            use_safetensors=True,
            torch_dtype=dtype,
        )
        self.model = model.to(device=device, dtype=dtype).eval()

    def _load_medgemma(self, model_name):
        """Load MedGemma and its processor; ensure the tokenizer has a pad id."""
        self.processor = AutoProcessor.from_pretrained(model_name)
        on_mps = device == "mps"
        model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=dtype if on_mps else torch.float32,
            # On MPS the model is moved explicitly below; elsewhere let
            # device_map place it.
            device_map=None if on_mps else "auto",
        )
        if on_mps:
            model = model.to("mps")
        self.model = model.eval()
        # Ensure pad_token_id is set
        tokenizer = self.processor.tokenizer
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id


manager = ModelManager()
def pdf_to_images(pdf_path, zoom=2.0):
    """Render every page of a PDF into a PIL image.

    Args:
        pdf_path: Path of the PDF file to render.
        zoom: Scale factor applied on both axes when rasterising; 2.0 matches
            the previous fixed behaviour, larger values render at higher
            resolution.

    Returns:
        A list with one PIL image per page, in page order.
    """
    matrix = fitz.Matrix(zoom, zoom)
    images = []
    # The context manager closes the document even if rendering a page raises.
    with fitz.open(pdf_path) as doc:
        for page in doc:
            png_bytes = page.get_pixmap(matrix=matrix).tobytes("png")
            images.append(Image.open(io.BytesIO(png_bytes)))
    return images
def _load_input_images(input_image, input_file):
    """Collect the images to OCR from the upload widgets.

    A file upload takes precedence over the image widget. Returns
    ``(images, error_message)``; ``error_message`` is None on success.
    """
    if input_file is not None:
        # Compatibility with different Gradio versions (object with .name vs string path)
        file_path = input_file.name if hasattr(input_file, "name") else input_file
        if file_path.lower().endswith(".pdf"):
            try:
                return pdf_to_images(file_path), None
            except Exception as e:
                return [], f"Помилка читання PDF: {str(e)}"
        try:
            return [Image.open(file_path)], None
        except Exception as e:
            return [], f"Помилка завантаження файлу: {str(e)}"
    if input_image is not None:
        return [input_image], None
    return [], "Будь ласка, завантажте зображення або PDF файл."


def _deepseek_prompt(custom_prompt):
    """Return the DeepSeek prompt, making sure it contains the ``<image>`` tag.

    Empty or whitespace-only prompts fall back to the default. The default
    prompt starts with ``<image>``; presumably ``infer()`` uses that tag to place
    the page image, so custom prompts lacking it (such as the UI placeholder
    text) get it prepended.
    """
    if not custom_prompt or not custom_prompt.strip():
        return "<image>\nFree OCR. "
    if "<image>" not in custom_prompt:
        return "<image>\n" + custom_prompt
    return custom_prompt


def _medgemma_prompt(custom_prompt):
    """Return the MedGemma prompt; empty or whitespace-only prompts use the default."""
    if not custom_prompt or not custom_prompt.strip():
        return "extract all text from image"
    return custom_prompt


def _ocr_deepseek(model, tokenizer, img, prompt, output_dir):
    """Run DeepSeek ``infer`` on one image, passed to it as a temporary PNG file."""
    # Close the temp file before writing it by path (re-opening an open
    # NamedTemporaryFile fails on Windows) and remove it even if saving fails.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        tmp_path = tmp.name
    try:
        img.save(tmp_path, format="PNG")
        with torch.no_grad():
            return model.infer(
                tokenizer,
                prompt=prompt,
                image_file=tmp_path,
                output_path=output_dir,
                base_size=1024,
                image_size=768,
                crop_mode=True,
                eval_mode=True,
            )
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def _ocr_medgemma(model, processor, img, prompt, max_new_tokens=4096):
    """Generate MedGemma's answer for one image and return only the new text."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=model.dtype)  # casts only floating tensors (pixel values)
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # generate() returns prompt + completion; decode only the completion.
    input_len = inputs["input_ids"].shape[-1]
    return processor.decode(output[0][input_len:], skip_special_tokens=True)


def run_ocr(input_image, input_file, model_choice, custom_prompt):
    """OCR the uploaded image or every page of the uploaded PDF.

    Args:
        input_image: PIL image from the image widget, or None.
        input_file: Uploaded file (object with ``.name`` or a path string), or None.
        model_choice: ``DEEPSEEK_MODEL`` or ``MEDGEMMA_MODEL``.
        custom_prompt: User prompt; empty means the model's default prompt.

    Returns:
        Per-page results joined by blank lines, each prefixed with a
        "--- Page/Image N ---" header. Per-page failures are reported inline;
        input or model-loading failures return a single message instead.
    """
    images_to_process, error = _load_input_images(input_image, input_file)
    if error is not None:
        return error

    try:
        model, processor_or_tokenizer = manager.load_model(model_choice)
    except Exception as e:
        return f"Помилка завантаження моделі: {str(e)}"

    output_dir = "outputs"
    os.makedirs(output_dir, exist_ok=True)

    all_results = []
    for i, img in enumerate(images_to_process, start=1):
        header = f"--- Page/Image {i} ---"
        try:
            print(f"Processing page/image {i} with {model_choice}...")
            rgb = img.convert("RGB")
            if model_choice == DEEPSEEK_MODEL:
                res = _ocr_deepseek(
                    model, processor_or_tokenizer, rgb, _deepseek_prompt(custom_prompt), output_dir
                )
            elif model_choice == MEDGEMMA_MODEL:
                res = _ocr_medgemma(
                    model, processor_or_tokenizer, rgb, _medgemma_prompt(custom_prompt)
                )
            else:
                raise ValueError(f"Unsupported model: {model_choice!r}")
            all_results.append(f"{header}\n{res}")
        except Exception as e:
            all_results.append(f"{header}\nПомилка: {str(e)}")
        finally:
            # Release cached device memory between pages.
            if torch.backends.mps.is_available():
                torch.mps.empty_cache()
    return "\n\n".join(all_results)
def save_result_to_file(text):
    """Write OCR output to ``outputs/ocr_result_<YYYYmmdd_HHMMSS>.txt``.

    Returns the absolute path of the written file, or None when there is
    nothing worth saving: empty text, the upload reminder ("Будь ласка ...")
    or a top-level error message ("Помилка ...").
    """
    skip_prefixes = ("Будь ласка", "Помилка")
    if not text or text.startswith(skip_prefixes):
        return None
    out_dir = "outputs"
    os.makedirs(out_dir, exist_ok=True)
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    target = os.path.abspath(os.path.join(out_dir, f"ocr_result_{stamp}.txt"))
    with open(target, "w", encoding="utf-8") as fh:
        fh.write(text)
    return target
# Extra CSS for the Gradio page: centred header block, muted centred footer.
custom_css = (
    "\n"
    ".header { text-align: center; margin-bottom: 30px; }\n"
    ".header h1 { font-size: 2.5rem; }\n"
    ".footer { text-align: center; margin-top: 50px; font-size: 0.9rem; color: #718096; }\n"
)
# Gradio UI: inputs and controls on the left, OCR output and export on the right.
with gr.Blocks(title="OCR Comparison: DeepSeek vs MedGemma", css=custom_css) as demo:
    with gr.Column():
        gr.Markdown("# 🔍 OCR & Medical Document Analysis", elem_classes="header")
        gr.Markdown("Порівняння DeepSeek-OCR-2 та MedGemma-1.5-4B", elem_classes="header")

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Tab("Зображення"):
                input_img = gr.Image(type="pil", label="Перетягніть зображення")
            with gr.Tab("PDF / Файли"):
                input_file = gr.File(label="Завантажте PDF або інший файл")
            model_selector = gr.Dropdown(
                choices=[DEEPSEEK_MODEL, MEDGEMMA_MODEL],
                value=DEEPSEEK_MODEL,
                label="Оберіть модель",
            )
            with gr.Accordion("Налаштування", open=False):
                prompt_input = gr.Textbox(
                    value="",
                    label="Користувацький промпт (залиште порожнім для дефолтного)",
                    placeholder="Наприклад: Extract all text from image",
                )
            with gr.Row():
                clear_btn = gr.Button("Очистити", variant="secondary")
                ocr_btn = gr.Button("Запустити аналіз", variant="primary")

        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="Результат",
                lines=20,
            )
            with gr.Row():
                save_btn = gr.Button("Зберегти у файл 💾")
                download_file = gr.File(label="Завантажити результат")

    gr.Markdown("---")
    # Only register the example when its image is present (path is relative to
    # the working directory); presumably Gradio resolves example files at build
    # time, so a missing sample would break start-up - confirm.
    if os.path.exists("sample_test.png"):
        gr.Examples(
            examples=[["sample_test.png", None, DEEPSEEK_MODEL, ""]],
            inputs=[input_img, input_file, model_selector, prompt_input],
        )

    # Event handlers
    ocr_btn.click(
        fn=run_ocr,
        inputs=[input_img, input_file, model_selector, prompt_input],
        outputs=output_text,
    )
    save_btn.click(
        fn=save_result_to_file,
        inputs=output_text,
        outputs=download_file,
    )

    def clear_all():
        """Reset the inputs, prompt, result text and the stale download link."""
        return None, None, "", "", None

    clear_btn.click(
        fn=clear_all,
        inputs=None,
        outputs=[input_img, input_file, output_text, prompt_input, download_file],
    )
if __name__ == "__main__":
    # Bind to all network interfaces (not just localhost); no public share link.
    demo.launch(share=False, server_name="0.0.0.0")