File size: 3,349 Bytes
cb6f32c
1d8fe68
 
 
 
 
 
 
 
cb6f32c
1d8fe68
 
 
 
 
 
 
 
 
85d8a1c
1d8fe68
 
 
85d8a1c
1d8fe68
cb6f32c
1d8fe68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gradio as gr
from PIL import Image
import numpy as np
# Transformers imports are deferred to avoid requiring heavy packages when
# NO_MODEL_LOAD is set. The module-level imports happen only if we actually
# need to load the model. This makes tests and CI simpler.
import tempfile
import os
import shutil

# Allow delaying heavy model load if the environment variable NO_MODEL_LOAD is set
# (useful for tests/CI where transformers and the model weights are unavailable).
if os.environ.get('NO_MODEL_LOAD'):
    tokenizer = None
    model = None
else:
    try:
        # Import heavy transformer classes lazily so the dependency is only
        # required when a real model load is requested. The import and the
        # tokenizer load are inside the try: a missing package, no network,
        # or a failed tokenizer download must degrade gracefully instead of
        # crashing at import time — process_image() reports
        # "Model not available" when model is None.
        from transformers import AutoModel, AutoTokenizer
        # trust_remote_code is required: GOT-OCR2 ships custom model code.
        tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
        model = AutoModel.from_pretrained(
            'ucaslcl/GOT-OCR2_0',
            trust_remote_code=True,
            use_safetensors=True,
            low_cpu_mem_usage=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        model = model.eval()
    except Exception as e:
        # Keep placeholders so the UI still starts and reports the problem
        # per-request rather than failing to boot.
        print(f"Warning: Failed to load model: {e}")
        tokenizer = None
        model = None


def process_image(image):
    """Run OCR on an uploaded image via ``model.chat(tokenizer, path, ocr_type='ocr')``.

    Parameters
    ----------
    image : numpy.ndarray | PIL.Image.Image | None
        The image supplied by the Gradio widget.

    Returns
    -------
    str
        The model output as a string, or an informative error message if no
        image was given, the model is unavailable, or processing fails.
    """
    if image is None:
        return "No image provided."

    # Gradio may deliver a numpy array depending on widget config; normalise
    # to a PIL image so it can be written to disk.
    if isinstance(image, np.ndarray):
        pil_img = Image.fromarray(image)
    else:
        pil_img = image

    tmpfile = None
    try:
        # model.chat expects a file path, so persist the image to a temp file.
        # Close the handle BEFORE saving: on Windows a NamedTemporaryFile
        # cannot be reopened by name while it is still open.
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
        tmpfile = tmp.name
        tmp.close()

        # JPEG has no alpha channel; RGBA/P-mode uploads would fail to save.
        if pil_img.mode not in ('RGB', 'L'):
            pil_img = pil_img.convert('RGB')
        pil_img.save(tmpfile, format='JPEG')

        if model is None or not hasattr(model, 'chat'):
            return "Model not available or does not implement `chat`."

        # Call the model.chat method using an image file path (as requested).
        res = model.chat(tokenizer, tmpfile, ocr_type='ocr')
        return str(res)
    except Exception as e:
        # UI boundary: surface the error to the user rather than crashing
        # the Gradio worker.
        return f"Error processing image: {repr(e)}"
    finally:
        # Clean up temp file; best-effort, a leftover temp file is harmless.
        if tmpfile and os.path.exists(tmpfile):
            try:
                os.remove(tmpfile)
            except Exception:
                pass


def _launch_demo():
    """Build and return the Gradio Blocks UI for the OCR demo.

    The interface consists of an image uploader, a 'Process' button and a
    text box that displays the OCR/chat result from the loaded model; the
    returned object is ready to ``launch()``.
    """
    with gr.Blocks(title="OCR Processing Demo") as ui:
        gr.Markdown(
            "## OCR Processing Demo\nUpload an image and press **Process** to run the OCR model."
        )
        with gr.Row():
            uploaded_image = gr.Image(type='pil', label='Upload Image')
            result_box = gr.Textbox(label='Detected text / model output', lines=8)
        run_button = gr.Button('Process')
        run_button.click(fn=process_image, inputs=uploaded_image, outputs=result_box)
    return ui


if __name__ == "__main__":
    demo = _launch_demo()
    demo.launch()