import gradio as gr
import torch
import os
import tempfile

from PIL import Image, ImageOps
from transformers import AutoProcessor, AutoModelForImageTextToText

MODEL_PATH = "zai-org/GLM-OCR"
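
# Load the processor and model once at startup; CPU-only, full precision.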
processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModelForImageTextToText.from_pretrained(
    pretrained_model_name_or_path=MODEL_PATH,
    torch_dtype=torch.float32,
    device_map="cpu",
    trust_remote_code=True,
)
model.eval()
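
# Prompt prefixes for each supported recognition task.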
TASK_PROMPTS = {
    "Text": "Text Recognition:",
    "Formula": "Formula Recognition:",
    "Table": "Table Recognition:",
}
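

# Run OCR on the uploaded image; the same recognized text is returned twice,
# once for the raw textbox and once for the rendered-Markdown view.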
def process_image(image, task):
    if image is None:
        return "Please upload an image first.", "Please upload an image first."
    if image.mode in ("RGBA", "LA", "P"):
        image = image.convert("RGB")
    image = ImageOps.exif_transpose(image)
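    # Save to a temporary PNG so the chat template can load the image by path.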
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
    image.save(tmp.name, "PNG")
    tmp.close()
    prompt = TASK_PROMPTS.get(task, "Text Recognition:")
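    # Single-turn chat message pairing the image with the selected task prompt.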
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "url": tmp.name},
                {"type": "text", "text": prompt},
            ],
        }
    ]
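    # Tokenize the multimodal chat turn into model-ready tensors.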
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to("cpu")
    inputs.pop("token_type_ids", None)
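    # Generate on CPU and decode only the tokens produced after the prompt.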
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=4096)
    output_text = processor.decode(
        generated_ids[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    )
    os.unlink(tmp.name)
    result = output_text.strip()
    return result, result
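

# Gradio UI: image upload and task selection in the sidebar, results in the main area.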
with gr.Blocks(
    theme="NoCrypt/miku",
    fill_height=True,
    css="footer {display: none !important}",
) as demo:
    with gr.Sidebar(width=400):
        gr.Markdown("# GLM-OCR (CPU)")
        image_input = gr.Image(
            type="pil",
            label="Upload Image",
            sources=["upload", "clipboard"],
            height=300,
        )
        task = gr.Radio(
            choices=list(TASK_PROMPTS.keys()),
            value="Text",
            label="Recognition Type",
        )
        btn = gr.Button("Perform OCR", variant="primary")
    gr.Markdown("## Output")
    output_text = gr.Textbox(label="Raw Output", interactive=True, lines=22)
    with gr.Accordion("Rendered Markdown", open=False):
        output_md = gr.Markdown(label="Rendered Markdown")
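
    # Run OCR on click; clear both outputs whenever a new image is selected.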
    btn.click(
        fn=process_image,
        inputs=[image_input, task],
        outputs=[output_text, output_md],
    )
    image_input.change(
        fn=lambda: ("", ""),
        inputs=None,
        outputs=[output_text, output_md],
    )

if __name__ == "__main__":
    demo.queue(max_size=50).launch(
        ssr_mode=False,
        show_error=True,
    )