Spaces:
Running
on
Zero
| import gradio as gr | |
| from transformers import AutoModel, AutoTokenizer, AutoProcessor, AutoModelForImageTextToText | |
| import torch | |
| import os | |
| from PIL import Image | |
| import tempfile | |
| import datetime | |
| import fitz # PyMuPDF | |
| import io | |
| import gc | |
| import warnings | |
# Best-effort compatibility shims for the installed `transformers` version.
# Each shim runs independently and any exception it raises is ignored, so a
# missing or differently laid-out transformers package never stops the app.


def _alias_llama_flash_attention():
    """Alias ``LlamaFlashAttention2`` to ``LlamaAttention`` when only the latter exists.

    NOTE(review): presumably needed because the DeepSeek remote code imports
    ``LlamaFlashAttention2`` — confirm against the model repo.
    """
    from transformers.models.llama import modeling_llama

    has_flash_attention = hasattr(modeling_llama, "LlamaFlashAttention2")
    if not has_flash_attention and hasattr(modeling_llama, "LlamaAttention"):
        modeling_llama.LlamaFlashAttention2 = modeling_llama.LlamaAttention


def _restore_is_torch_fx_available():
    """Define ``import_utils.is_torch_fx_available`` when the installed transformers lacks it."""
    from transformers.utils import import_utils

    if hasattr(import_utils, "is_torch_fx_available"):
        return

    def is_torch_fx_available():
        # Report whether torch exposes the `fx` submodule; any failure means "no".
        try:
            import torch as _torch
            found = hasattr(_torch, "fx")
        except Exception:
            found = False
        return found

    import_utils.is_torch_fx_available = is_torch_fx_available


for _shim in (_alias_llama_flash_attention, _restore_is_torch_fx_available):
    try:
        _shim()
    except Exception:
        pass  # deliberate: shims are optional
# Silence known, noisy warnings so the console only shows useful output.
# Each entry is a regex that `warnings` matches against the start of the message.
_SUPPRESSED_WARNING_MESSAGES = (
    "The parameters have been moved from the Blocks constructor to the launch()",
    "CUDA is not available or torch_xla is imported",
    "The following generation flags are not valid and may be ignored",
    "The attention mask and the pad token id were not set",
    "You are using a model of type .* to instantiate a model of type .*",
)
for _message in _SUPPRESSED_WARNING_MESSAGES:
    warnings.filterwarnings("ignore", message=_message)
# --- Configuration ---
# Model repository ids passed to `from_pretrained` and offered in the model dropdown.
DEEPSEEK_MODEL = "deepseek-ai/DeepSeek-OCR-2"  # OCR model (custom remote code)
MEDGEMMA_MODEL = "google/medgemma-1.5-4b-it"   # image-text-to-text model
# --- Device Setup ---
# Choose the accelerator in order of preference: Apple MPS, then CUDA, then CPU.
# `device` and `dtype` are read by ModelManager when loading models.
if torch.backends.mps.is_available():
    print("Using MPS device")
    device = "mps"
    # Patch for DeepSeek custom code which uses .cuda(): redirect those calls to MPS.
    torch.Tensor.cuda = lambda self, *args, **kwargs: self.to("mps")
    torch.nn.Module.cuda = lambda self, *args, **kwargs: self.to("mps")
    dtype = torch.float16
    # HACK: avoid BFloat16 vs Float16 mismatch in the custom modeling code on MPS by
    # aliasing bfloat16 to float16 process-wide. Every later use of torch.bfloat16
    # (in any library) silently becomes float16.
    torch.bfloat16 = torch.float16
elif torch.cuda.is_available():
    # The DeepSeek custom code moves its inputs with .cuda(), so on a CUDA host the
    # model must live on CUDA too; leaving it on CPU puts model and inputs on
    # different devices.
    print("Using CUDA device")
    device = "cuda"
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    print("Using CPU device")
    device = "cpu"
    dtype = torch.float32
    # NOTE(review): .cuda() calls from the DeepSeek custom code are not redirected
    # here; confirm whether DeepSeek-OCR-2 can run on a CPU-only host at all.
class ModelManager:
    """Keep at most one model (DeepSeek-OCR-2 or MedGemma) loaded and swap on demand.

    Loading a different model unloads the current one first, so only one set of
    weights is held at a time. ``tokenizer`` is populated for DeepSeek-OCR-2 and
    ``processor`` for MedGemma; the other one stays ``None``.
    """

    def __init__(self):
        self.current_model_name = None  # repo id of the loaded model, or None
        self.model = None
        self.processor = None  # MedGemma only
        self.tokenizer = None  # DeepSeek-OCR-2 only

    def unload_current_model(self):
        """Drop the loaded model, tokenizer and processor, then free cached memory.

        All attributes are reset unconditionally, not only when ``model`` is set.
        This means a tokenizer/processor left behind by a failed load can never be
        returned later for a different model.
        """
        if self.current_model_name is not None:
            print(f"Unloading {self.current_model_name}...")
        self.model = None
        self.processor = None
        self.tokenizer = None
        self.current_model_name = None
        # Collect first so tensors kept alive only by reference cycles are freed
        # before the allocator caches are emptied.
        gc.collect()
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def load_model(self, model_name):
        """Return ``(model, tokenizer_or_processor)`` for ``model_name``, loading it if needed.

        Args:
            model_name: ``DEEPSEEK_MODEL`` or ``MEDGEMMA_MODEL``.

        Returns:
            ``(model, tokenizer)`` for DeepSeek-OCR-2 or ``(model, processor)`` for
            MedGemma. Repeated calls with the loaded name reuse the cached objects.

        Raises:
            ValueError: if ``model_name`` is not a supported model. In that case
                nothing is unloaded.
            Any exception raised while loading is re-raised after the partially
                loaded state has been discarded.
        """
        if model_name not in (DEEPSEEK_MODEL, MEDGEMMA_MODEL):
            raise ValueError(f"Unsupported model: {model_name!r}")
        if self.current_model_name == model_name:
            return self.model, self._companion()
        self.unload_current_model()
        print(f"Loading {model_name}...")
        try:
            if model_name == DEEPSEEK_MODEL:
                self._load_deepseek(model_name)
            else:
                self._load_medgemma(model_name)
        except Exception:
            # Don't leave a half-initialised tokenizer/processor/model behind.
            self.unload_current_model()
            raise
        self.current_model_name = model_name
        return self.model, self._companion()

    def _companion(self):
        """Return the tokenizer (DeepSeek) or processor (MedGemma) of the loaded model."""
        if self.current_model_name == DEEPSEEK_MODEL:
            return self.tokenizer
        return self.processor

    def _load_deepseek(self, model_name):
        """Load the DeepSeek-OCR-2 tokenizer and remote-code model onto ``device``/``dtype``."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            use_safetensors=True,
            torch_dtype=dtype,
        )
        self.model = model.to(device=device, dtype=dtype).eval()

    def _load_medgemma(self, model_name):
        """Load the MedGemma processor and model (``dtype`` on MPS, float32 elsewhere)."""
        self.processor = AutoProcessor.from_pretrained(model_name)
        on_mps = device == "mps"
        model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=dtype if on_mps else torch.float32,
            # device_map="auto" off MPS; on MPS the model is moved explicitly below.
            device_map=None if on_mps else "auto",
        )
        if on_mps:
            model = model.to("mps")
        self.model = model.eval()
        # Ensure pad_token_id is set, falling back to the EOS token.
        tokenizer = self.processor.tokenizer
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id


manager = ModelManager()
def pdf_to_images(pdf_path, zoom=2):
    """Render every page of a PDF into a PIL image.

    Args:
        pdf_path: path to the PDF file.
        zoom: scale factor for both axes, passed to ``fitz.Matrix``. The default
            of 2 matches the previously hard-coded value.

    Returns:
        List of PIL images in page order; empty for a PDF without pages.
    """
    matrix = fitz.Matrix(zoom, zoom)
    images = []
    # The context manager closes the document even if rendering a page raises.
    with fitz.open(pdf_path) as doc:
        for page in doc:
            png_bytes = page.get_pixmap(matrix=matrix).tobytes("png")
            images.append(Image.open(io.BytesIO(png_bytes)))
    return images
def _collect_images(input_image, input_file):
    """Return ``(images, error)`` for the uploaded inputs; ``error`` is None on success.

    A file upload takes precedence over the image widget. ``error`` is a
    user-facing message returned verbatim by ``run_ocr``.
    """
    if input_file is not None:
        # Compatibility with different Gradio versions (object with .name vs string path)
        file_path = input_file.name if hasattr(input_file, "name") else input_file
        if file_path.lower().endswith(".pdf"):
            try:
                return pdf_to_images(file_path), None
            except Exception as e:
                return [], f"Помилка читання PDF: {str(e)}"
        try:
            return [Image.open(file_path)], None
        except Exception as e:
            return [], f"Помилка завантаження файлу: {str(e)}"
    if input_image is not None:
        return [input_image], None
    return [], "Будь ласка, завантажте зображення або PDF файл."


def _deepseek_prompt(custom_prompt):
    """Build the DeepSeek-OCR-2 prompt, guaranteeing the ``<image>`` placeholder is present."""
    if not custom_prompt or not custom_prompt.strip():
        return "<image>\nFree OCR. "
    # The default prompt places the image via "<image>"; without it the page
    # would not be part of the prompt.
    if "<image>" not in custom_prompt:
        return "<image>\n" + custom_prompt
    return custom_prompt


def _ocr_deepseek(model, tokenizer, img, custom_prompt, output_dir):
    """Run DeepSeek-OCR-2 on one RGB image and return its ``infer`` result.

    ``infer`` takes an image path, so the image is written to a temporary PNG.
    That file is always removed afterwards, even if saving or inference fails.
    """
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    tmp_path = tmp.name
    try:
        # Write through the open handle and close it before infer() reopens the
        # path; reopening a still-open temp file is not portable.
        with tmp:
            img.save(tmp, format="PNG")
        with torch.no_grad():
            return model.infer(
                tokenizer,
                prompt=_deepseek_prompt(custom_prompt),
                image_file=tmp_path,
                output_path=output_dir,
                base_size=1024,
                image_size=768,
                crop_mode=True,
                eval_mode=True,
            )
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def _ocr_medgemma(model, processor, img, custom_prompt):
    """Run MedGemma on one RGB image via its chat template and return the decoded text."""
    if custom_prompt and custom_prompt.strip():
        prompt_text = custom_prompt
    else:
        prompt_text = "extract all text from image"
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": prompt_text},
            ],
        }
    ]
    # Cast floating-point inputs (pixel values) to the model's dtype as well as its
    # device, so a float16 model does not receive float32 image tensors.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=model.dtype)
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=4096, do_sample=False)
    # Decode only the newly generated tokens, not the prompt.
    input_len = inputs["input_ids"].shape[-1]
    return processor.decode(output[0][input_len:], skip_special_tokens=True)


def run_ocr(input_image, input_file, model_choice, custom_prompt):
    """Run the selected model on an uploaded image or on every page of a PDF.

    Args:
        input_image: PIL image from the image tab, or None.
        input_file: uploaded file (path string or object with ``.name``), or None.
            Takes precedence over ``input_image``.
        model_choice: ``DEEPSEEK_MODEL`` or ``MEDGEMMA_MODEL``.
        custom_prompt: optional prompt; blank or whitespace means the model's default.

    Returns:
        One "--- Page/Image N ---" section per page or image. A failure on one page
        is reported in that page's section without aborting the rest. Missing input,
        unreadable files and model-loading failures are returned as a single message.
    """
    images_to_process, error = _collect_images(input_image, input_file)
    if error is not None:
        return error
    if not images_to_process:
        # e.g. a PDF with zero pages; otherwise the result would be an empty string.
        return "Помилка: файл не містить сторінок для обробки."

    try:
        model, processor_or_tokenizer = manager.load_model(model_choice)
    except Exception as e:
        return f"Помилка завантаження моделі: {str(e)}"

    output_dir = 'outputs'
    os.makedirs(output_dir, exist_ok=True)

    all_results = []
    for page_no, img in enumerate(images_to_process, start=1):
        header = f"--- Page/Image {page_no} ---"
        try:
            rgb = img.convert("RGB")
            print(f"Processing page/image {page_no} with {model_choice}...")
            if model_choice == DEEPSEEK_MODEL:
                res = _ocr_deepseek(model, processor_or_tokenizer, rgb, custom_prompt, output_dir)
            elif model_choice == MEDGEMMA_MODEL:
                res = _ocr_medgemma(model, processor_or_tokenizer, rgb, custom_prompt)
            else:
                raise ValueError(f"Unsupported model: {model_choice!r}")
            all_results.append(f"{header}\n{res}")
        except Exception as e:
            all_results.append(f"{header}\nПомилка: {str(e)}")
        # Release cached accelerator memory between pages.
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return "\n\n".join(all_results)
def save_result_to_file(text):
    """Write OCR output to a new, uniquely named UTF-8 ``.txt`` file under ``outputs/``.

    Args:
        text: the content of the result textbox.

    Returns:
        Absolute path of the created file, or ``None`` when there is nothing worth
        saving: empty or whitespace-only text, the upload prompt, or a top-level
        error message.
    """
    if not text or not text.strip() or text.startswith(("Будь ласка", "Помилка")):
        return None
    os.makedirs("outputs", exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    # mkstemp always creates a fresh file, so two saves within the same second
    # (e.g. concurrent users of one app) cannot overwrite each other's results.
    fd, path = tempfile.mkstemp(prefix=f"ocr_result_{timestamp}_", suffix=".txt", dir="outputs")
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        f.write(text)
    return os.path.abspath(path)
# Page stylesheet passed to gr.Blocks: centred header block and a muted, centred footer.
custom_css = (
    "\n"
    ".header { text-align: center; margin-bottom: 30px; }\n"
    ".header h1 { font-size: 2.5rem; }\n"
    ".footer { text-align: center; margin-top: 50px; font-size: 0.9rem; color: #718096; }\n"
)
with gr.Blocks(title="OCR Comparison: DeepSeek vs MedGemma", css=custom_css) as demo:
    # Header
    with gr.Column():
        gr.Markdown("# 🔍 OCR & Medical Document Analysis", elem_classes="header")
        gr.Markdown("Порівняння DeepSeek-OCR-2 та MedGemma-1.5-4B", elem_classes="header")

    with gr.Row():
        # Left column: inputs, model choice, prompt and actions.
        with gr.Column(scale=1):
            with gr.Tab("Зображення"):
                input_img = gr.Image(type="pil", label="Перетягніть зображення")
            with gr.Tab("PDF / Файли"):
                input_file = gr.File(label="Завантажте PDF або інший файл")
            model_selector = gr.Dropdown(
                choices=[DEEPSEEK_MODEL, MEDGEMMA_MODEL],
                value=DEEPSEEK_MODEL,
                label="Оберіть модель",
            )
            with gr.Accordion("Налаштування", open=False):
                prompt_input = gr.Textbox(
                    value="",
                    label="Користувацький промпт (залиште порожнім для дефолтного)",
                    placeholder="Наприклад: Extract all text from image",
                )
            with gr.Row():
                clear_btn = gr.Button("Очистити", variant="secondary")
                ocr_btn = gr.Button("Запустити аналіз", variant="primary")

        # Right column: result text and export.
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="Результат",
                lines=20,
            )
            with gr.Row():
                save_btn = gr.Button("Зберегти у файл 💾")
                download_file = gr.File(label="Завантажити результат")

    gr.Markdown("---")
    gr.Examples(
        examples=[["sample_test.png", None, DEEPSEEK_MODEL, ""]],
        inputs=[input_img, input_file, model_selector, prompt_input],
    )

    # Event handlers
    ocr_btn.click(
        fn=run_ocr,
        inputs=[input_img, input_file, model_selector, prompt_input],
        outputs=output_text,
    )
    save_btn.click(
        fn=save_result_to_file,
        inputs=output_text,
        outputs=download_file,
    )

    def clear_all():
        """Reset both inputs, the result text, the prompt and the exported file link."""
        return None, None, "", "", None

    # The exported file is cleared too, so no stale download of a previous result remains.
    clear_btn.click(
        fn=clear_all,
        inputs=None,
        outputs=[input_img, input_file, output_text, prompt_input, download_file],
    )
if __name__ == "__main__":
    # Listen on all network interfaces; no public share link is created.
    launch_options = {"server_name": "0.0.0.0", "share": False}
    demo.launch(**launch_options)