File size: 2,947 Bytes
34c3ed5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import gradio as gr
import torch
import os
import tempfile
from PIL import Image, ImageOps

from transformers import AutoProcessor, AutoModelForImageTextToText

# Hugging Face model id for the GLM OCR checkpoint.
MODEL_PATH = "zai-org/GLM-OCR"

# Load the processor and model once at import time; inference runs on CPU
# in float32 (device_map="cpu", torch_dtype=torch.float32 below).
# trust_remote_code=True: the repo ships custom modeling/processing code.
processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModelForImageTextToText.from_pretrained(
    pretrained_model_name_or_path=MODEL_PATH,
    torch_dtype=torch.float32,
    device_map="cpu",
    trust_remote_code=True,
)
model.eval()  # inference only — disable dropout etc.

# Maps the UI radio-button label to the text prompt sent to the model.
TASK_PROMPTS = {
    "Text": "Text Recognition:",
    "Formula": "Formula Recognition:",
    "Table": "Table Recognition:",
}


def process_image(image, task):
    """Run OCR on an uploaded image and return the recognized text.

    Parameters
    ----------
    image : PIL.Image.Image | None
        The uploaded image, or ``None`` when nothing has been uploaded yet.
    task : str
        A key of ``TASK_PROMPTS`` ("Text", "Formula", "Table"); unknown
        values fall back to the plain text-recognition prompt.

    Returns
    -------
    tuple[str, str]
        The stripped model output twice — once for the raw textbox and
        once for the rendered-markdown component.
    """
    if image is None:
        return "Please upload an image first.", "Please upload an image first."

    # Flatten alpha/palette modes to RGB and honor EXIF orientation so the
    # model sees the image the way the user does.
    if image.mode in ("RGBA", "LA", "P"):
        image = image.convert("RGB")
    image = ImageOps.exif_transpose(image)

    # The chat template references the image by file path, so persist the
    # PIL image to a temporary PNG for the duration of the call.
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
    try:
        image.save(tmp.name, "PNG")
        tmp.close()

        prompt = TASK_PROMPTS.get(task, "Text Recognition:")

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "url": tmp.name},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        inputs = processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to("cpu")

        # Some processors emit token_type_ids that generate() rejects.
        inputs.pop("token_type_ids", None)

        with torch.no_grad():
            generated_ids = model.generate(**inputs, max_new_tokens=4096)

        # Decode only the newly generated tokens, skipping the prompt echo.
        output_text = processor.decode(
            generated_ids[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )
    finally:
        # Bug fix: previously the temp file leaked whenever generation
        # raised, because cleanup only ran on the success path.
        tmp.close()  # idempotent — safe even if already closed above
        os.unlink(tmp.name)

    result = output_text.strip()
    return result, result


# UI definition. NOTE: in a gr.Blocks context the order in which components
# are created *is* the page layout — do not reorder these statements.
with gr.Blocks(
    theme="NoCrypt/miku",
    fill_height=True,
    css="footer {display: none !important}",  # hide the default Gradio footer
) as demo:

    # Left sidebar: image input and task selection.
    with gr.Sidebar(width=400):
        gr.Markdown("# GLM-OCR (CPU)")
        image_input = gr.Image(
            type="pil",  # hand process_image a PIL.Image, not a numpy array
            label="Upload Image",
            sources=["upload", "clipboard"],
            height=300,
        )
        task = gr.Radio(
            choices=list(TASK_PROMPTS.keys()),
            value="Text",
            label="Recognition Type",
        )
        btn = gr.Button("Perform OCR", variant="primary")

    # Main area: raw text output plus an optional markdown rendering.
    gr.Markdown("## Output")
    output_text = gr.Textbox(label="Raw Output", interactive=True, lines=22)

    with gr.Accordion("Rendered Markdown", open=False):
        output_md = gr.Markdown(label="Rendered Markdown")

    # Run OCR on click; process_image returns the same string for both outputs.
    btn.click(
        fn=process_image,
        inputs=[image_input, task],
        outputs=[output_text, output_md],
    )

    # Clear stale results whenever a new image is selected (or removed).
    image_input.change(
        fn=lambda: ("", ""),
        inputs=None,
        outputs=[output_text, output_md],
    )

# Script entry point: serve the demo behind a bounded request queue so
# concurrent users wait rather than overload the CPU-only model.
if __name__ == "__main__":
    app = demo.queue(max_size=50)
    app.launch(show_error=True, ssr_mode=False)