ocr-test / app.py
HarryLovesCode's picture
Add updated requirements to fix model load
85d8a1c
import gradio as gr
from PIL import Image
import numpy as np
# Transformers imports are deferred to avoid requiring heavy packages when
# NO_MODEL_LOAD is set. The module-level imports happen only if we actually
# need to load the model. This makes tests and CI simpler.
import tempfile
import os
import shutil
# Load the tokenizer and model at import time unless the NO_MODEL_LOAD
# environment variable is set to a non-empty value, which lets tests and CI
# import this module without the heavy dependencies or network access.
#
# Both names start as placeholders so that any failure on the load path
# (missing `transformers`, tokenizer download, model download, `eval()`)
# degrades the same way: a warning is printed and `model` stays None, which
# `process_image` reports to the user instead of crashing the app.
tokenizer = None
model = None
if not os.environ.get('NO_MODEL_LOAD'):
    try:
        # Imported here rather than at the top of the file so `transformers`
        # is only required when the model is actually loaded.
        from transformers import AutoModel, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
        model = AutoModel.from_pretrained(
            'ucaslcl/GOT-OCR2_0',
            trust_remote_code=True,
            use_safetensors=True,
            low_cpu_mem_usage=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        model = model.eval()
    except Exception as e:
        # Top-level boundary: report the failure and keep the placeholder so
        # the UI can still start.
        print(f"Warning: Failed to load model: {e}")
        model = None
def process_image(image) -> str:
    """Run OCR on an uploaded image and return the result as text.

    The image is written to a temporary JPEG file because the model is
    invoked with a file path: `model.chat(tokenizer, path, ocr_type='ocr')`.
    The temporary file is removed before returning.

    Args:
        image: A PIL image, a numpy array (converted with `Image.fromarray`),
            or None when nothing was uploaded.

    Returns:
        The model output converted with `str()`, or a human-readable message
        when no image was given, the model is unavailable, or an exception
        was raised while converting, saving, or running the model.
    """
    if image is None:
        return "No image provided."
    # Check availability first so no temp file is written when there is
    # nothing to run it through.
    if model is None or not hasattr(model, 'chat'):
        return "Model not available or does not implement `chat`."

    tmpfile = None
    try:
        # Conversion lives inside the try so an array PIL cannot interpret
        # yields an error string rather than an unhandled exception.
        pil_img = Image.fromarray(image) if isinstance(image, np.ndarray) else image

        # The JPEG writer rejects modes with alpha or a palette (RGBA, LA, P);
        # normalise anything that is not plain RGB or greyscale.
        if pil_img.mode not in ('RGB', 'L'):
            pil_img = pil_img.convert('RGB')

        # Reserve a path and close the handle before PIL writes to it: an
        # open handle would leak if `save` raised and would block removal on
        # platforms that lock open files.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp:
            tmpfile = tmp.name
        pil_img.save(tmpfile, format='JPEG')

        res = model.chat(tokenizer, tmpfile, ocr_type='ocr')
        try:
            return str(res)
        except Exception:
            # `str(res)` just failed, so fall back to `repr` instead of
            # formatting with `str` a second time.
            return f"Model returned an object of type {type(res)}: {res!r}"
    except Exception as e:
        return f"Error processing image: {repr(e)}"
    finally:
        # Best-effort cleanup; a file that is already gone or cannot be
        # removed must not mask the result being returned.
        if tmpfile is not None:
            try:
                os.remove(tmpfile)
            except OSError:
                pass
def _launch_demo():
    """Assemble the Gradio Blocks UI for the OCR demo and return it.

    The interface has a heading, a row holding the image uploader and the
    output text box, and a 'Process' button. Clicking the button passes the
    uploaded image to `process_image` and shows the returned string in the
    text box. Despite its name, this function does not call `launch()`; that
    is left to the caller.
    """
    demo = gr.Blocks(title="OCR Processing Demo")
    with demo:
        gr.Markdown("## OCR Processing Demo\nUpload an image and press **Process** to run the OCR model.")

        with gr.Row():
            uploader = gr.Image(type='pil', label='Upload Image')
            result_box = gr.Textbox(label='Detected text / model output', lines=8)

        run_button = gr.Button('Process')
        # Wire the button: uploaded image in, OCR text out.
        run_button.click(fn=process_image, inputs=uploader, outputs=result_box)

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, keep it bound to the module-level
    # name `demo`, then start serving it.
    demo = _launch_demo()
    demo.launch()