twoimo committed on
Commit
6e0dd90
·
verified ·
1 Parent(s): e9e3585

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -65
app.py CHANGED
@@ -1,77 +1,86 @@
1
- import gradio as gr
2
  from transformers import AutoProcessor, AutoModelForImageTextToText
3
  import torch
4
  from PIL import Image
 
5
 
6
# --- One-time model setup (runs at import time) ------------------------------
# On failure both globals stay None; process_image checks for that and returns
# a user-facing error instead of crashing the app.
MODEL_PATH = "zai-org/GLM-OCR"
model = processor = None

try:
    print(f"Loading processor from {MODEL_PATH}...")
    processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)

    print(f"Loading model from {MODEL_PATH}...")
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_PATH,
        torch_dtype="auto",      # let the checkpoint pick its native dtype
        device_map="auto",       # place on GPU if present, else CPU
        trust_remote_code=True,
    )
    print("Model loaded successfully!")
except Exception as e:
    # Boundary handler: log and fall back to the "not loaded" state so the
    # UI can still render and report the problem.
    print(f"Error loading model: {e}")
    model = processor = None
26
 
27
def process_image(image):
    """Run GLM-OCR on *image* and return the recognized text.

    Accepts a file path, a PIL image, or a numpy array. Returns an error
    string (never raises) when the model is unavailable or inference fails.
    """
    if model is None or processor is None:
        return "Error: Model not loaded. Please refresh the page and try again."

    try:
        # Normalize the input to a PIL image. A value that is already a
        # PIL.Image is passed through unchanged (no forced RGB conversion).
        if isinstance(image, str):
            pil_image = Image.open(image).convert("RGB")
        elif isinstance(image, Image.Image):
            pil_image = image
        else:
            pil_image = Image.fromarray(image).convert("RGB")

        chat = [{
            "role": "user",
            "content": [
                {"type": "image", "image": pil_image},
                {"type": "text", "text": "Text Recognition:"}
            ],
        }]

        model_inputs = processor.apply_chat_template(
            chat,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)

        # Some processors emit token_type_ids that generate() rejects.
        model_inputs.pop("token_type_ids", None)

        with torch.no_grad():
            output_ids = model.generate(**model_inputs, max_new_tokens=2048)

        # Decode only the newly generated tokens, skipping the prompt prefix.
        prompt_len = model_inputs["input_ids"].shape[1]
        return processor.decode(output_ids[0][prompt_len:], skip_special_tokens=True)

    except Exception as e:
        return f"Error processing image: {str(e)}"
 
 
 
 
 
65
 
66
# Minimal Gradio UI: one image in, one text box out, wired to process_image.
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Textbox(label="Extracted Text"),
    title="GLM-OCR: Multimodal OCR Model",
    description="Upload an image to extract text using the GLM-OCR model.",
    # NOTE(review): allow_flagging was renamed/removed in Gradio >= 4;
    # confirm against the pinned gradio version in requirements.
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
  from transformers import AutoProcessor, AutoModelForImageTextToText
3
  import torch
4
  from PIL import Image
5
+ import io
6
 
7
# Page chrome. set_page_config must be the first Streamlit call in the script.
st.set_page_config(page_title="GLM-OCR", layout="centered")

st.title("🎯 GLM-OCR: Multimodal OCR Model")
st.markdown("Upload an image to extract text using the GLM-OCR model.")
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
# Load model once per process; st.cache_resource shares the loaded objects
# across user sessions and script reruns.
@st.cache_resource
def load_model():
    """Load the GLM-OCR processor and model.

    Returns:
        tuple: (processor, model) on success, or (None, None) on failure —
        the error is surfaced to the user via st.error rather than raised.
    """
    try:
        MODEL_PATH = "zai-org/GLM-OCR"
        processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
        # float16 halves GPU memory, but many fp16 ops are unsupported or very
        # slow on CPU — on CPU-only hosts fall back to the checkpoint's native
        # dtype ("auto"), matching the previous Gradio version's behavior.
        dtype = torch.float16 if torch.cuda.is_available() else "auto"
        model = AutoModelForImageTextToText.from_pretrained(
            MODEL_PATH,
            torch_dtype=dtype,
            device_map="auto",
            trust_remote_code=True,
        )
        return processor, model
    except Exception as e:
        # Boundary handler: report in the UI instead of crashing the run.
        st.error(f"Error loading model: {str(e)}")
        return None, None
28
+
29
# ---- Model bootstrap --------------------------------------------------------
with st.spinner("Loading GLM-OCR model... This may take a moment."):
    processor, model = load_model()

if processor is None or model is None:
    st.error("Failed to load the model. Please try refreshing the page.")
    st.stop()  # halt this script run; nothing below works without the model

# ---- Input ------------------------------------------------------------------
uploaded_file = st.file_uploader(
    "Choose an image",
    type=["jpg", "jpeg", "png", "bmp", "gif"],
)

if uploaded_file is not None:
    # Normalize to RGB so the processor never sees palette/alpha modes.
    image = Image.open(uploaded_file).convert("RGB")
    # use_container_width replaces the deprecated (now removed)
    # use_column_width kwarg; rendering is identical.
    st.image(image, caption="Uploaded Image", use_container_width=True)

    # ---- Inference ----------------------------------------------------------
    if st.button("Extract Text", type="primary"):
        with st.spinner("Processing image... Please wait."):
            try:
                # Chat-style prompt: the image followed by the OCR instruction.
                messages = [{
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": "Text Recognition:"}
                    ],
                }]

                inputs = processor.apply_chat_template(
                    messages, tokenize=True, add_generation_prompt=True,
                    return_dict=True, return_tensors="pt"
                ).to(model.device)

                # Some processors emit token_type_ids that generate() rejects.
                inputs.pop("token_type_ids", None)

                with torch.no_grad():
                    generated_ids = model.generate(**inputs, max_new_tokens=2048)

                # Decode only the newly generated tokens (skip the prompt).
                output_text = processor.decode(
                    generated_ids[0][inputs["input_ids"].shape[1]:],
                    skip_special_tokens=True,
                )

                st.success("Text extraction completed!")
                st.text_area("Extracted Text", value=output_text, height=300)

            except Exception as e:
                # Boundary handler: show the failure inline in the UI.
                st.error(f"Error processing image: {str(e)}")

st.markdown("---")
st.markdown("Powered by GLM-OCR from [ZAI](https://huggingface.co/zai-org)")