"""Gradio app comparing DeepSeek-OCR-2 and MedGemma on images and PDF files."""

import datetime
import gc
import io
import os
import tempfile
import warnings

import fitz  # PyMuPDF
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, AutoModelForImageTextToText, AutoProcessor, AutoTokenizer

# --- transformers compatibility shims ---
# Best-effort aliases for names the DeepSeek remote modeling code presumably still
# imports but newer transformers releases may no longer define. Failures are ignored
# on purpose: a transformers layout change must not prevent the app from starting.
try:
    from transformers.models.llama import modeling_llama as _modeling_llama

    if not hasattr(_modeling_llama, "LlamaFlashAttention2") and hasattr(_modeling_llama, "LlamaAttention"):
        _modeling_llama.LlamaFlashAttention2 = _modeling_llama.LlamaAttention
except Exception:
    pass

try:
    from transformers.utils import import_utils as _import_utils

    if not hasattr(_import_utils, "is_torch_fx_available"):

        def is_torch_fx_available():
            """Fallback: report whether this torch build exposes ``torch.fx``."""
            try:
                return hasattr(torch, "fx")
            except Exception:
                return False

        _import_utils.is_torch_fx_available = is_torch_fx_available
except Exception:
    pass

# Suppress annoying warnings
warnings.filterwarnings("ignore", message="The parameters have been moved from the Blocks constructor to the launch()")
# NOTE(review): presumably emitted by the CUDA-oriented DeepSeek remote code when running on MPS/CPU.
warnings.filterwarnings("ignore", message="CUDA is not available or torch_xla is imported")
warnings.filterwarnings("ignore", message="The following generation flags are not valid and may be ignored")
warnings.filterwarnings("ignore", message="The attention mask and the pad token id were not set")
warnings.filterwarnings("ignore", message="You are using a model of type .* to instantiate a model of type .*")

# --- Configuration ---
DEEPSEEK_MODEL = 'deepseek-ai/DeepSeek-OCR-2'
MEDGEMMA_MODEL = 'google/medgemma-1.5-4b-it'

# DeepSeek-OCR prompts must contain this placeholder; it marks where the image
# features are inserted into the prompt.
DEEPSEEK_IMAGE_TOKEN = "<image>"
DEEPSEEK_DEFAULT_PROMPT = "<image>\nFree OCR. "
MEDGEMMA_DEFAULT_PROMPT = "extract all text from image"

OUTPUT_DIR = "outputs"
EXAMPLE_IMAGE = "sample_test.png"

# --- Device Setup ---
if torch.backends.mps.is_available():
    print("Using MPS device")
    device = "mps"
    # Patch for DeepSeek custom code which uses .cuda(): redirect every .cuda() call to MPS.
    torch.Tensor.cuda = lambda self, *args, **kwargs: self.to("mps")
    torch.nn.Module.cuda = lambda self, *args, **kwargs: self.to("mps")
    dtype = torch.float16
    # Patch to avoid BFloat16 vs Float16 mismatch in custom modeling code on MPS.
    # This is a process-wide hack: every later reference to torch.bfloat16 means float16.
    torch.bfloat16 = torch.float16
elif torch.cuda.is_available():
    # The DeepSeek remote code targets CUDA natively, so no patching is needed here.
    print("Using CUDA device")
    device = "cuda"
    dtype = torch.bfloat16
else:
    device = "cpu"
    dtype = torch.float32


def _empty_device_cache():
    """Release cached allocator memory on the active accelerator (no-op on CPU)."""
    if device == "mps":
        torch.mps.empty_cache()
    elif device == "cuda":
        torch.cuda.empty_cache()


class ModelManager:
    """Keep at most one model resident in memory and swap models on demand.

    Depending on the loaded model, exactly one of ``tokenizer`` (DeepSeek) or
    ``processor`` (MedGemma) is populated.
    """

    def __init__(self):
        self.current_model_name = None
        self.model = None
        self.processor = None
        self.tokenizer = None

    def unload_current_model(self):
        """Drop all references to the loaded model and free device memory."""
        if self.model is not None:
            print(f"Unloading {self.current_model_name}...")
        # Reset unconditionally so a half-finished load (e.g. tokenizer loaded,
        # model failed) is cleared as well.
        self.model = None
        self.processor = None
        self.tokenizer = None
        self.current_model_name = None
        # Collect first so tensors held in reference cycles are actually freed
        # before the allocator cache is emptied.
        gc.collect()
        _empty_device_cache()

    def _companion(self, model_name):
        """Return the tokenizer (DeepSeek) or processor (MedGemma) for *model_name*."""
        return self.tokenizer if model_name == DEEPSEEK_MODEL else self.processor

    def load_model(self, model_name):
        """Return ``(model, tokenizer_or_processor)`` for *model_name*, loading it if needed.

        Raises:
            ValueError: if *model_name* is not one of the supported models.
            Exception: whatever ``from_pretrained`` raises; partial state is cleared first.
        """
        if model_name not in (DEEPSEEK_MODEL, MEDGEMMA_MODEL):
            raise ValueError(f"Unsupported model: {model_name}")
        if self.current_model_name == model_name:
            return self.model, self._companion(model_name)

        self.unload_current_model()
        print(f"Loading {model_name}...")
        try:
            if model_name == DEEPSEEK_MODEL:
                self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
                self.model = AutoModel.from_pretrained(
                    model_name, trust_remote_code=True, use_safetensors=True, torch_dtype=dtype
                )
                self.model = self.model.to(device=device, dtype=dtype)
            else:
                self.processor = AutoProcessor.from_pretrained(model_name)
                # Single-device placement via .to() avoids the accelerate dependency
                # that device_map="auto" would require.
                self.model = AutoModelForImageTextToText.from_pretrained(
                    model_name, trust_remote_code=True, torch_dtype=dtype
                )
                self.model = self.model.to(device)
                # Ensure pad_token_id is set
                tokenizer = self.processor.tokenizer
                if tokenizer.pad_token_id is None:
                    tokenizer.pad_token_id = tokenizer.eos_token_id
            self.model.eval()
        except Exception:
            self.unload_current_model()
            raise

        self.current_model_name = model_name
        return self.model, self._companion(model_name)


manager = ModelManager()


def pdf_to_images(pdf_path, zoom=2):
    """Render every page of *pdf_path* to an RGB PIL image.

    Args:
        pdf_path: path of the PDF file to rasterise.
        zoom: scale factor applied on both axes (2 doubles the default render resolution).

    Returns:
        A list with one ``PIL.Image`` per page, in page order.
    """
    matrix = fitz.Matrix(zoom, zoom)
    pages = []
    # The context manager closes the document even if rendering a page fails.
    with fitz.open(pdf_path) as doc:
        for page in doc:
            pix = page.get_pixmap(matrix=matrix)
            # Default pixmaps are RGB without alpha, so the raw samples can be
            # wrapped directly instead of round-tripping through PNG.
            pages.append(Image.frombytes("RGB", (pix.width, pix.height), pix.samples))
    return pages


def _load_input_images(input_image, input_file):
    """Collect the images to process from the UI inputs.

    The uploaded file takes precedence over the image widget.

    Returns:
        ``(images, error_message)``. ``error_message`` is ``None`` on success.
    """
    if input_file is not None:
        # Compatibility with different Gradio versions (object with .name vs string path)
        file_path = input_file.name if hasattr(input_file, 'name') else input_file
        if file_path.lower().endswith(".pdf"):
            try:
                return pdf_to_images(file_path), None
            except Exception as e:
                return None, f"Помилка читання PDF: {str(e)}"
        try:
            with Image.open(file_path) as src:
                # convert() decodes fully, so the file handle can be closed right away.
                return [src.convert("RGB")], None
        except Exception as e:
            return None, f"Помилка завантаження файлу: {str(e)}"
    if input_image is not None:
        return [input_image], None
    return None, "Будь ласка, завантажте зображення або PDF файл."


def _deepseek_prompt(custom_prompt):
    """Return the DeepSeek prompt, guaranteeing it contains the image placeholder."""
    prompt = custom_prompt if custom_prompt and custom_prompt.strip() else DEEPSEEK_DEFAULT_PROMPT
    if DEEPSEEK_IMAGE_TOKEN not in prompt:
        prompt = f"{DEEPSEEK_IMAGE_TOKEN}\n{prompt}"
    return prompt


def _ocr_deepseek(model, tokenizer, img, custom_prompt, output_dir):
    """Run DeepSeek-OCR on one image; its ``infer`` API needs a file path, so use a temp PNG."""
    fd, tmp_path = tempfile.mkstemp(suffix=".png")
    try:
        with os.fdopen(fd, "wb") as tmp:
            img.save(tmp, format="PNG")
        with torch.no_grad():
            return model.infer(
                tokenizer,
                prompt=_deepseek_prompt(custom_prompt),
                image_file=tmp_path,
                output_path=output_dir,
                base_size=1024,
                image_size=768,
                crop_mode=True,
                eval_mode=True,
            )
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def _ocr_medgemma(model, processor, img, custom_prompt):
    """Run MedGemma on one image and return only the newly generated text."""
    prompt_text = custom_prompt if custom_prompt and custom_prompt.strip() else MEDGEMMA_DEFAULT_PROMPT
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": prompt_text},
            ],
        }
    ]
    # dtype= casts only floating tensors (pixel values) to the model's dtype;
    # integer ids are just moved. Without it, float32 pixels would meet float16 weights on MPS.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=model.dtype)
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=4096, do_sample=False)
    # Strip the echoed prompt tokens.
    input_len = inputs["input_ids"].shape[-1]
    return processor.decode(output[0][input_len:], skip_special_tokens=True)


def run_ocr(input_image, input_file, model_choice, custom_prompt):
    """Run the selected model over an uploaded image or every page of a PDF.

    Args:
        input_image: PIL image from the image tab, or ``None``.
        input_file: uploaded file (Gradio object with ``.name`` or a path string), or ``None``.
        model_choice: ``DEEPSEEK_MODEL`` or ``MEDGEMMA_MODEL``.
        custom_prompt: optional prompt; blank means use the model's default.

    Returns:
        The per-page results joined by blank lines, or a user-facing error message.
    """
    images_to_process, error = _load_input_images(input_image, input_file)
    if error is not None:
        return error
    if not images_to_process:
        return "Помилка: PDF не містить сторінок."

    try:
        model, processor_or_tokenizer = manager.load_model(model_choice)
    except Exception as e:
        return f"Помилка завантаження моделі: {str(e)}"

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    all_results = []
    for i, img in enumerate(images_to_process, start=1):
        header = f"--- Page/Image {i} ---"
        try:
            print(f"Processing page/image {i} with {model_choice}...")
            img = img.convert("RGB")
            if model_choice == DEEPSEEK_MODEL:
                res = _ocr_deepseek(model, processor_or_tokenizer, img, custom_prompt, OUTPUT_DIR)
            else:
                res = _ocr_medgemma(model, processor_or_tokenizer, img, custom_prompt)
            all_results.append(f"{header}\n{res}")
        except Exception as e:
            # One bad page must not abort the whole document.
            all_results.append(f"{header}\nПомилка: {str(e)}")
        _empty_device_cache()

    return "\n\n".join(all_results)


def save_result_to_file(text):
    """Write *text* to a timestamped file in the output dir and return its absolute path.

    Returns ``None`` (nothing to download) for empty text or top-level error/prompt messages.
    """
    if not text or text.startswith(("Будь ласка", "Помилка")):
        return None
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"ocr_result_{timestamp}.txt"
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    filepath = os.path.abspath(os.path.join(OUTPUT_DIR, filename))
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(text)
    return filepath


custom_css = """
.header { text-align: center; margin-bottom: 30px; }
.header h1 { font-size: 2.5rem; }
.footer { text-align: center; margin-top: 50px; font-size: 0.9rem; color: #718096; }
"""

with gr.Blocks(title="OCR Comparison: DeepSeek vs MedGemma", css=custom_css) as demo:
    with gr.Column():
        gr.Markdown("# 🔍 OCR & Medical Document Analysis", elem_classes="header")
        gr.Markdown("Порівняння DeepSeek-OCR-2 та MedGemma-1.5-4B", elem_classes="header")

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Tab("Зображення"):
                input_img = gr.Image(type="pil", label="Перетягніть зображення")
            with gr.Tab("PDF / Файли"):
                input_file = gr.File(label="Завантажте PDF або інший файл")

            model_selector = gr.Dropdown(
                choices=[DEEPSEEK_MODEL, MEDGEMMA_MODEL],
                value=DEEPSEEK_MODEL,
                label="Оберіть модель",
            )

            with gr.Accordion("Налаштування", open=False):
                prompt_input = gr.Textbox(
                    value="",
                    label="Користувацький промпт (залиште порожнім для дефолтного)",
                    placeholder="Наприклад: Extract all text from image",
                )

            with gr.Row():
                clear_btn = gr.Button("Очистити", variant="secondary")
                ocr_btn = gr.Button("Запустити аналіз", variant="primary")

        with gr.Column(scale=1):
            output_text = gr.Textbox(label="Результат", lines=20)
            with gr.Row():
                save_btn = gr.Button("Зберегти у файл 💾")
            download_file = gr.File(label="Завантажити результат")

    gr.Markdown("---")
    # Only offer the example when the sample file is actually present.
    if os.path.exists(EXAMPLE_IMAGE):
        gr.Examples(
            examples=[[EXAMPLE_IMAGE, None, DEEPSEEK_MODEL, ""]],
            inputs=[input_img, input_file, model_selector, prompt_input],
        )

    # Event handlers
    ocr_btn.click(
        fn=run_ocr,
        inputs=[input_img, input_file, model_selector, prompt_input],
        outputs=output_text,
    )

    save_btn.click(fn=save_result_to_file, inputs=output_text, outputs=download_file)

    def clear_all():
        """Reset image, file, result text and prompt."""
        return None, None, "", ""

    clear_btn.click(
        fn=clear_all,
        inputs=None,
        outputs=[input_img, input_file, output_text, prompt_input],
    )

if __name__ == "__main__":
    # NOTE: 0.0.0.0 exposes the app on every network interface of the host.
    demo.launch(server_name="0.0.0.0", share=False)