# Hugging Face Space: running on ZeroGPU ("Running on Zero").
import warnings

# Try to import spaces (present on HF Spaces with ZeroGPU). On local runs the
# package is missing, so fall back to a no-op stand-in. The real `spaces.GPU`
# supports both the bare form (`@spaces.GPU`) and the parameterized form
# (`@spaces.GPU(duration=...)`); the stand-in mirrors both.
try:
    import spaces
except ImportError:
    class spaces:
        """Minimal local drop-in for the `spaces` package (no-op GPU decorator)."""

        @staticmethod
        def GPU(func=None, **kwargs):
            # Bare form: @spaces.GPU -> `func` is the decorated callable.
            if callable(func):
                return func
            # Parameterized form: @spaces.GPU(duration=...) -> return a decorator.
            return lambda f: f
| import gradio as gr | |
| from transformers import AutoModel, AutoTokenizer, AutoProcessor, AutoModelForImageTextToText | |
| from transformers import logging as hf_logging | |
| import torch | |
| import os | |
| from PIL import Image | |
| import tempfile | |
| import datetime | |
| import fitz # PyMuPDF | |
| import io | |
| import gc | |
| import threading | |
| import contextlib | |
# Compatibility shim: some releases of transformers removed
# `LlamaFlashAttention2`, but remote model code may still reference it.
# Alias it to `LlamaAttention` when absent. Best-effort: any failure
# (e.g. transformers not installed / layout changed) is silently ignored.
try:
    from transformers.models.llama import modeling_llama as _llama_mod

    _missing_flash = not hasattr(_llama_mod, "LlamaFlashAttention2")
    if _missing_flash and hasattr(_llama_mod, "LlamaAttention"):
        _llama_mod.LlamaFlashAttention2 = _llama_mod.LlamaAttention
except Exception:
    pass
# Compatibility shim: some transformers code paths expect
# `is_torch_fx_available` in transformers.utils.import_utils; provide a
# simple fallback when it is missing. Best-effort: failures are ignored.
try:
    from transformers.utils import import_utils as _import_utils

    if not hasattr(_import_utils, "is_torch_fx_available"):
        def is_torch_fx_available():
            # True iff torch is importable and exposes the `fx` submodule.
            try:
                import torch as _torch
                return hasattr(_torch, "fx")
            except Exception:
                return False
        _import_utils.is_torch_fx_available = is_torch_fx_available
except Exception:
    pass
# Suppress known-noisy warnings from Gradio / transformers; `message` values
# are regex patterns matched against the start of the warning text.
_IGNORED_WARNING_PATTERNS = (
    "The parameters have been moved from the Blocks constructor to the launch()",
    "CUDA is not available or torch_xla is imported",
    "The following generation flags are not valid and may be ignored",
    "The attention mask and the pad token id were not set",
    "You are using a model of type .* to instantiate a model of type .*",
)
for _pattern in _IGNORED_WARNING_PATTERNS:
    warnings.filterwarnings("ignore", message=_pattern)
hf_logging.set_verbosity_error()
# --- Configuration ---
DEEPSEEK_MODEL = 'deepseek-ai/DeepSeek-OCR-2'
MEDGEMMA_MODEL = 'google/medgemma-1.5-4b-it'

# Prefer the persistent /data volume (HF Spaces) for the HF cache; otherwise
# use the standard per-user cache. setdefault keeps any value already set in
# the environment.
if os.path.isdir("/data"):
    _default_hf_home = "/data/.huggingface"
else:
    _default_hf_home = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
os.environ.setdefault("HF_HOME", _default_hf_home)

_hf_cache_dir = os.environ.get("HF_HUB_CACHE") or os.path.join(os.environ["HF_HOME"], "hub")
for _cache_var in ("HF_HUB_CACHE", "TRANSFORMERS_CACHE"):
    os.environ.setdefault(_cache_var, _hf_cache_dir)
| def _warmup_hf_cache(): | |
| try: | |
| from huggingface_hub import snapshot_download | |
| except Exception as e: | |
| print(f"Warmup cache failed: {e}") | |
| return | |
| for _repo_id in (DEEPSEEK_MODEL, MEDGEMMA_MODEL): | |
| try: | |
| snapshot_download(repo_id=_repo_id, cache_dir=_hf_cache_dir) | |
| except Exception as e: | |
| print(f"Warmup cache failed for {_repo_id}: {e}") | |
| threading.Thread(target=_warmup_hf_cache, daemon=True).start() | |
# --- Device Setup ---
# For HF Spaces with ZeroGPU, we'll use cuda if available
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32


def _configure_cuda_precision():
    """Disable BF16 matmul on GPUs without native BF16 support (pre-sm80).

    Best-effort and never raises: the flag assignment itself is guarded,
    since re-running a failing assignment inside the exception handler
    would otherwise crash module import.
    """
    if not torch.cuda.is_available():
        return
    try:
        needs_disable = torch.cuda.get_device_capability() < (8, 0)
    except Exception:
        # Capability unknown -> conservatively assume BF16 may be unsupported.
        needs_disable = True
    if needs_disable:
        # NOTE(review): `allow_bf16` is not a documented torch.backends flag —
        # confirm; the suppress() keeps an unknown attribute from crashing startup.
        with contextlib.suppress(Exception):
            torch.backends.cuda.matmul.allow_bf16 = False


_configure_cuda_precision()
class ModelManager:
    """Lazily loads and caches models plus their tokenizers/processors.

    Models are loaded onto the CPU; callers move them to the GPU only for
    the duration of a request (ZeroGPU pattern).
    """

    def __init__(self):
        # model name -> loaded model / tokenizer-or-processor
        self.models = {}
        self.processors = {}

    def _load_deepseek(self, model_name):
        """Load the DeepSeek OCR model and tokenizer, with pad/eos ids wired up."""
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, cache_dir=_hf_cache_dir)
        if getattr(tokenizer, "pad_token_id", None) is None and getattr(tokenizer, "eos_token_id", None) is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            use_safetensors=True,
            attn_implementation="eager",
            cache_dir=_hf_cache_dir,
            torch_dtype=dtype
        )
        # Propagate pad/eos ids onto the model + generation config so that
        # generate() does not warn or mis-pad.
        if hasattr(model, "config") and getattr(model.config, "pad_token_id", None) is None and getattr(tokenizer, "pad_token_id", None) is not None:
            model.config.pad_token_id = tokenizer.pad_token_id
        if hasattr(model, "generation_config"):
            if getattr(model.generation_config, "pad_token_id", None) is None and getattr(tokenizer, "pad_token_id", None) is not None:
                model.generation_config.pad_token_id = tokenizer.pad_token_id
            if getattr(model.generation_config, "eos_token_id", None) is None and getattr(tokenizer, "eos_token_id", None) is not None:
                model.generation_config.eos_token_id = tokenizer.eos_token_id
        model.eval()
        return model, tokenizer

    def _load_medgemma(self, model_name):
        """Load MedGemma with its multimodal processor, with pad/eos ids wired up."""
        processor = AutoProcessor.from_pretrained(model_name, cache_dir=_hf_cache_dir)
        model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            trust_remote_code=True,
            cache_dir=_hf_cache_dir,
            torch_dtype=dtype
        )
        model.eval()
        # Ensure pad_token_id is set
        if processor.tokenizer.pad_token_id is None:
            processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id
        if hasattr(model, "generation_config"):
            if getattr(model.generation_config, "pad_token_id", None) is None:
                model.generation_config.pad_token_id = processor.tokenizer.pad_token_id
            if getattr(model.generation_config, "eos_token_id", None) is None:
                model.generation_config.eos_token_id = processor.tokenizer.eos_token_id
        return model, processor

    def get_model(self, model_name):
        """Return (model, processor_or_tokenizer), loading and caching on first use."""
        if model_name not in self.models:
            print(f"Loading {model_name} to CPU...")
            loaded = None
            if model_name == DEEPSEEK_MODEL:
                loaded = self._load_deepseek(model_name)
            elif model_name == MEDGEMMA_MODEL:
                loaded = self._load_medgemma(model_name)
            if loaded is not None:
                self.models[model_name], self.processors[model_name] = loaded
        # Unknown names raise KeyError here, matching the uncached lookup.
        return self.models[model_name], self.processors[model_name]


manager = ModelManager()
def pdf_to_images(pdf_path):
    """Render every page of a PDF to a PIL image.

    Pages are rasterized at 2x zoom for better OCR quality. The document is
    always closed, even if a page fails to render (the original leaked the
    handle on error).
    """
    doc = fitz.open(pdf_path)
    images = []
    try:
        for page in doc:
            pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))
            img_data = pix.tobytes("png")
            images.append(Image.open(io.BytesIO(img_data)))
    finally:
        doc.close()
    return images
@spaces.GPU
def run_ocr(input_image, input_file, model_choice, custom_prompt):
    """Run OCR over an uploaded image or PDF with the selected model.

    Args:
        input_image: PIL image from the image tab (or None).
        input_file: uploaded file (PDF or image) from the file tab (or None);
            may be a path string or an object with `.name` depending on the
            Gradio version.
        model_choice: DEEPSEEK_MODEL or MEDGEMMA_MODEL.
        custom_prompt: optional user prompt; a model-specific default is used
            when empty.

    Returns:
        The combined per-page OCR text, or a user-facing (Ukrainian) error
        message string.

    Decorated with @spaces.GPU so ZeroGPU allocates a GPU for the call; the
    model lives on CPU between requests and is moved back in `finally`.
    """
    # --- Collect input images (uploaded file takes precedence) ---
    images_to_process = []
    if input_file is not None:
        # Compatibility with different Gradio versions (object with .name vs string path)
        file_path = input_file.name if hasattr(input_file, 'name') else input_file
        if file_path.lower().endswith(".pdf"):
            try:
                images_to_process = pdf_to_images(file_path)
            except Exception as e:
                return f"Помилка читання PDF: {str(e)}"
        else:
            try:
                images_to_process = [Image.open(file_path)]
            except Exception as e:
                return f"Помилка завантаження файлу: {str(e)}"
    elif input_image is not None:
        images_to_process = [input_image]
    else:
        return "Будь ласка, завантажте зображення або PDF файл."

    def _is_cuda_bf16_error(err):
        # Signature of a cuBLAS BF16 matmul failure on GPUs without BF16 support.
        msg = str(err)
        return "CUBLAS_STATUS_INVALID_VALUE" in msg and "CUDA_R_16BF" in msg

    # --- Load model (cached on CPU) and move it to the compute device ---
    try:
        model, processor_or_tokenizer = manager.get_model(model_choice)
        # Move to GPU only inside the decorated function; fall back to
        # CPU/float32 when CUDA is unavailable (local runs) instead of the
        # previous hard-coded cuda move that always failed off-GPU.
        run_device = "cuda" if torch.cuda.is_available() else "cpu"
        run_dtype = torch.float16 if run_device == "cuda" else torch.float32
        print(f"Moving {model_choice} to {run_device}...")
        model.to(device=run_device, dtype=run_dtype)
    except Exception as e:
        return f"Помилка завантаження чи переміщення моделі: {str(e)}\nЯкщо це MedGemma, переконайтеся, що ви надали HF_TOKEN."

    output_dir = 'outputs'
    os.makedirs(output_dir, exist_ok=True)
    all_results = []
    try:
        def _autocast_for(device_str):
            # fp16 autocast on GPU; a no-op context on CPU.
            if device_str == "cuda" and torch.cuda.is_available():
                return torch.autocast(device_type="cuda", dtype=torch.float16)
            return contextlib.nullcontext()

        def _fallback_to_cpu():
            # On a BF16-related cuBLAS failure, demote to CPU/fp32 and retry.
            nonlocal run_device
            print("CUDA BF16 error detected, retrying on CPU...")
            model.to(device="cpu", dtype=torch.float32)
            run_device = "cpu"

        def _deepseek_infer(tmp_path):
            # DeepSeek's remote-code infer() consumes an image *file* path.
            with torch.no_grad(), _autocast_for(run_device):
                return model.infer(
                    processor_or_tokenizer,
                    prompt=custom_prompt if custom_prompt else "<image>\nFree OCR. ",
                    image_file=tmp_path,
                    output_path=output_dir,
                    base_size=1024,
                    image_size=768,
                    crop_mode=True,
                    eval_mode=True
                )

        def _medgemma_generate(inputs):
            # Greedy decode, then strip the prompt tokens from the output.
            with torch.no_grad(), _autocast_for(run_device):
                output = model.generate(
                    **inputs,
                    max_new_tokens=4096,
                    do_sample=False,
                    pad_token_id=processor_or_tokenizer.tokenizer.pad_token_id,
                )
            input_len = inputs["input_ids"].shape[-1]
            return processor_or_tokenizer.decode(output[0][input_len:], skip_special_tokens=True)

        for i, img in enumerate(images_to_process):
            img = img.convert("RGB")
            try:
                print(f"Processing page/image {i+1} with {model_choice}...")
                if model_choice == DEEPSEEK_MODEL:
                    # Spill the page to a temporary PNG for infer().
                    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
                        img.save(tmp.name)
                        tmp_path = tmp.name
                    try:
                        try:
                            res = _deepseek_infer(tmp_path)
                        except Exception as e:
                            if run_device == "cuda" and _is_cuda_bf16_error(e):
                                _fallback_to_cpu()
                                res = _deepseek_infer(tmp_path)
                            else:
                                raise
                        all_results.append(f"--- Page/Image {i+1} ---\n{res}")
                    finally:
                        if os.path.exists(tmp_path):
                            os.remove(tmp_path)
                elif model_choice == MEDGEMMA_MODEL:
                    prompt_text = custom_prompt if custom_prompt else "extract all text from image"
                    messages = [
                        {
                            "role": "user",
                            "content": [
                                {"type": "image", "image": img},
                                {"type": "text", "text": prompt_text}
                            ]
                        }
                    ]
                    inputs = processor_or_tokenizer.apply_chat_template(
                        messages,
                        add_generation_prompt=True,
                        tokenize=True,
                        return_dict=True,
                        return_tensors="pt"
                    ).to(run_device)
                    if "attention_mask" not in inputs:
                        inputs["attention_mask"] = torch.ones_like(inputs["input_ids"], dtype=torch.long)
                    try:
                        res = _medgemma_generate(inputs)
                    except Exception as e:
                        if run_device == "cuda" and _is_cuda_bf16_error(e):
                            _fallback_to_cpu()
                            inputs = inputs.to(run_device)
                            res = _medgemma_generate(inputs)
                        else:
                            raise
                    all_results.append(f"--- Page/Image {i+1} ---\n{res}")
            except Exception as e:
                # Keep going on per-page failures; report the error inline.
                all_results.append(f"--- Page/Image {i+1} ---\nПомилка: {str(e)}")
    finally:
        # Move back to CPU and clean up to free ZeroGPU resources
        print(f"Moving {model_choice} back to CPU...")
        model.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
    return "\n\n".join(all_results)
def save_result_to_file(text):
    """Persist OCR output to a timestamped UTF-8 text file and return its path.

    Returns None for empty text and for the user-facing placeholder/error
    messages (those starting with "Будь ласка" or "Помилка").
    """
    if not text or text.startswith(("Будь ласка", "Помилка")):
        return None
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs("outputs", exist_ok=True)
    target = os.path.abspath(os.path.join("outputs", f"ocr_result_{stamp}.txt"))
    with open(target, "w", encoding="utf-8") as fh:
        fh.write(text)
    return target
# Minimal CSS overrides for the Gradio layout (centered header, muted footer).
custom_css = """
.header { text-align: center; margin-bottom: 30px; }
.header h1 { font-size: 2.5rem; }
.footer { text-align: center; margin-top: 50px; font-size: 0.9rem; color: #718096; }
"""
# --- Gradio UI ---
# Two-column layout: inputs (image/PDF tabs, model picker, prompt) on the
# left, the result textbox with save/download controls on the right.
with gr.Blocks(title="OCR Comparison: DeepSeek vs MedGemma", css=custom_css) as demo:
    with gr.Column():
        gr.Markdown("# 🔍 OCR & Medical Document Analysis", elem_classes="header")
        gr.Markdown("Порівняння DeepSeek-OCR-2 та MedGemma-1.5-4B (HuggingFace ZeroGPU Edition)", elem_classes="header")
        with gr.Row():
            with gr.Column(scale=1):
                # Input side: either a pasted/dragged image or an uploaded file.
                with gr.Tab("Зображення"):
                    input_img = gr.Image(type="pil", label="Перетягніть зображення")
                with gr.Tab("PDF / Файли"):
                    input_file = gr.File(label="Завантажте PDF або інший файл")
                model_selector = gr.Dropdown(
                    choices=[DEEPSEEK_MODEL, MEDGEMMA_MODEL],
                    value=DEEPSEEK_MODEL,
                    label="Оберіть модель"
                )
                with gr.Accordion("Налаштування", open=False):
                    # Empty prompt -> each model's built-in default is used.
                    prompt_input = gr.Textbox(
                        value="",
                        label="Користувацький промпт (залиште порожнім для дефолтного)",
                        placeholder="Наприклад: Extract all text from image"
                    )
                with gr.Row():
                    clear_btn = gr.Button("Очистити", variant="secondary")
                    ocr_btn = gr.Button("Запустити аналіз", variant="primary")
            with gr.Column(scale=1):
                # Output side: result text plus save-to-file / download widgets.
                output_text = gr.Textbox(
                    label="Результат",
                    lines=20
                )
                with gr.Row():
                    save_btn = gr.Button("Зберегти у файл 💾")
                    download_file = gr.File(label="Завантажити результат")
        gr.Markdown("---")
        gr.Markdown("### Як використовувати:\n1. Завантажте зображення або PDF.\n2. Виберіть модель.\n3. Натисніть 'Запустити аналіз'.\n*Примітка: MedGemma потребує HF_TOKEN з доступом до моделі.*")
    # Event handlers
    ocr_btn.click(
        fn=run_ocr,
        inputs=[input_img, input_file, model_selector, prompt_input],
        outputs=output_text
    )
    save_btn.click(
        fn=save_result_to_file,
        inputs=output_text,
        outputs=download_file
    )
    def clear_all():
        # Reset both inputs, the result box, and the custom prompt.
        return None, None, "", ""
    clear_btn.click(
        fn=clear_all,
        inputs=None,
        outputs=[input_img, input_file, output_text, prompt_input]
    )
if __name__ == "__main__":
    # Enable the request queue before launching; presumably needed so
    # long-running GPU jobs are serialized on ZeroGPU — confirm.
    demo.queue().launch()