"""Gradio demo that runs the GOT-OCR2_0 model on an uploaded image.

Set the environment variable NO_MODEL_LOAD (to any non-empty value) to skip
loading the tokenizer/model entirely; this keeps tests and CI light.
"""
import os
import shutil
import tempfile

import gradio as gr
import numpy as np
from PIL import Image

# Transformers imports are deferred (see below) so that heavy packages are only
# required when we actually load the model, not when NO_MODEL_LOAD is set.

_MODEL_ID = 'ucaslcl/GOT-OCR2_0'

# Both stay None when NO_MODEL_LOAD is set or when loading fails; callers must
# check for that (process_image does).
tokenizer = None
model = None

if not os.environ.get('NO_MODEL_LOAD'):
    # Import heavy transformer classes lazily, only when we really need them.
    from transformers import AutoModel, AutoTokenizer

    try:
        # The tokenizer download can fail for the same reasons as the model
        # (e.g. no network), so it is inside the same guarded region.
        tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, trust_remote_code=True)
        model = AutoModel.from_pretrained(
            _MODEL_ID,
            trust_remote_code=True,
            use_safetensors=True,
            low_cpu_mem_usage=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        model = model.eval()
    except Exception as e:
        # Keep a placeholder so the UI can still start and report the problem.
        print(f"Warning: Failed to load model: {e}")
        model = None


def process_image(image):
    """Run OCR on an image via ``model.chat(tokenizer, path, ocr_type='ocr')``.

    The image is written to a temporary JPEG file because ``model.chat`` expects
    a file path; the file is always removed afterwards.

    Args:
        image: A PIL image, a numpy array convertible with
            ``Image.fromarray``, or None.

    Returns:
        The model output converted to a string, or an informative message if
        no image was given, the model is unavailable, or processing failed.
        This function never raises for per-image errors; they are reported in
        the returned string.
    """
    if image is None:
        return "No image provided."

    # Check availability before doing any conversion or disk I/O.
    if model is None or not hasattr(model, 'chat'):
        return "Model not available or does not implement `chat`."

    tmpfile = None
    try:
        pil_img = Image.fromarray(image) if isinstance(image, np.ndarray) else image

        # JPEG cannot store alpha channels or palettes; RGBA/LA/P (e.g. PNGs
        # with transparency) would make save() raise, so normalise to RGB.
        if pil_img.mode not in ('RGB', 'L'):
            pil_img = pil_img.convert('RGB')

        # Create the file, then close our handle before PIL reopens the path:
        # reopening a still-open NamedTemporaryFile is not portable (Windows).
        with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp:
            tmpfile = tmp.name
        pil_img.save(tmpfile, format='JPEG')

        # Call the model.chat method using an image file path (as requested).
        res = model.chat(tokenizer, tmpfile, ocr_type='ocr')

        try:
            return str(res)
        except Exception:
            # str() itself failed, so don't interpolate res with str() again;
            # fall back to its repr instead.
            return f"Model returned an object of type {type(res)}: {res!r}"
    except Exception as e:
        return f"Error processing image: {repr(e)}"
    finally:
        # Best-effort cleanup of the temp file; a failure here must not mask
        # the result, but only filesystem errors are expected.
        if tmpfile:
            try:
                os.remove(tmpfile)
            except OSError:
                pass


def _launch_demo():
    """Build and return (but do not launch) the Gradio Blocks UI.

    The interface contains an image uploader, a 'Process' button, and a text
    output box that displays the result of ``process_image``. The caller is
    responsible for calling ``.launch()`` on the returned Blocks object.
    """
    with gr.Blocks(title="OCR Processing Demo") as demo:
        gr.Markdown("## OCR Processing Demo\nUpload an image and press **Process** to run the OCR model.")
        with gr.Row():
            image_input = gr.Image(type='pil', label='Upload Image')
            output_text = gr.Textbox(label='Detected text / model output', lines=8)
        process_btn = gr.Button('Process')
        process_btn.click(fn=process_image, inputs=image_input, outputs=output_text)
    return demo


if __name__ == "__main__":
    demo = _launch_demo()
    demo.launch()