saif0001 commited on
Commit
a2e157e
ยท
verified ยท
1 Parent(s): 8ffa9a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -198
app.py CHANGED
@@ -1,201 +1,126 @@
1
- import gradio as gr
2
- import spaces
3
- from transformers import AutoModel, AutoTokenizer
4
  import os
5
- import base64
6
- import io
7
- import uuid
8
- import time
9
- import shutil
10
  from pathlib import Path
 
11
 
12
- # Load tokenizer and model for CPU
13
- tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
14
- model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True)
15
- model = model.eval() # No need for .cuda() since it's on CPU
16
-
17
- UPLOAD_FOLDER = "./uploads"
18
- RESULTS_FOLDER = "./results"
19
-
20
- for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
21
- if not os.path.exists(folder):
22
- os.makedirs(folder)
23
-
24
- def image_to_base64(image):
25
- buffered = io.BytesIO()
26
- image.save(buffered, format="PNG")
27
- return base64.b64encode(buffered.getvalue()).decode()
28
-
29
- def run_GOT(image, got_mode, fine_grained_mode="", ocr_color="", ocr_box=""):
30
- unique_id = str(uuid.uuid4())
31
- image_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.png")
32
- result_path = os.path.join(RESULTS_FOLDER, f"{unique_id}.html")
33
-
34
- shutil.copy(image, image_path)
35
-
36
- try:
37
- if got_mode == "plain texts OCR":
38
- res = model.chat(tokenizer, image_path, ocr_type='ocr')
39
- return res, None
40
- elif got_mode == "format texts OCR":
41
- res = model.chat(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
42
- elif got_mode == "plain multi-crop OCR":
43
- res = model.chat_crop(tokenizer, image_path, ocr_type='ocr')
44
- return res, None
45
- elif got_mode == "format multi-crop OCR":
46
- res = model.chat_crop(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
47
- elif got_mode == "plain fine-grained OCR":
48
- res = model.chat(tokenizer, image_path, ocr_type='ocr', ocr_box=ocr_box, ocr_color=ocr_color)
49
- return res, None
50
- elif got_mode == "format fine-grained OCR":
51
- res = model.chat(tokenizer, image_path, ocr_type='format', ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)
52
-
53
- res_markdown = res
54
-
55
- if "format" in got_mode and os.path.exists(result_path):
56
- with open(result_path, 'r') as f:
57
- html_content = f.read()
58
- encoded_html = base64.b64encode(html_content.encode('utf-8')).decode('utf-8')
59
- iframe_src = f"data:text/html;base64,{encoded_html}"
60
- iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
61
- download_link = f'<a href="data:text/html;base64,{encoded_html}" download="result_{unique_id}.html">Download Full Result</a>'
62
- return res_markdown, f"{download_link}<br>{iframe}"
63
- else:
64
- return res_markdown, None
65
- except Exception as e:
66
- return f"Error: {str(e)}", None
67
- finally:
68
- if os.path.exists(image_path):
69
- os.remove(image_path)
70
-
71
- def task_update(task):
72
- if "fine-grained" in task:
73
- return [
74
- gr.update(visible=True),
75
- gr.update(visible=False),
76
- gr.update(visible=False),
77
- ]
78
- else:
79
- return [
80
- gr.update(visible=False),
81
- gr.update(visible=False),
82
- gr.update(visible=False),
83
- ]
84
-
85
- def fine_grained_update(task):
86
- if task == "box":
87
- return [
88
- gr.update(visible=False, value=""),
89
- gr.update(visible=True),
90
- ]
91
- elif task == 'color':
92
- return [
93
- gr.update(visible=True),
94
- gr.update(visible=False, value=""),
95
- ]
96
-
97
- def cleanup_old_files():
98
- current_time = time.time()
99
- for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
100
- for file_path in Path(folder).glob('*'):
101
- if current_time - file_path.stat().st_mtime > 3600: # 1 hour
102
- file_path.unlink()
103
-
104
- title_html = """
105
- <h2> <span class="gradient-text" id="text">General OCR Theory</span><span class="plain-text">: Towards OCR-2.0 via a Unified End-to-end Model</span></h2>
106
- <a href="https://huggingface.co/ucaslcl/GOT-OCR2_0">[๐Ÿ˜Š Hugging Face]</a>
107
- <a href="https://arxiv.org/abs/2409.01704">[๐Ÿ“œ Paper]</a>
108
- <a href="https://github.com/Ucas-HaoranWei/GOT-OCR2.0/">[๐ŸŒŸ GitHub]</a>
109
- """
110
-
111
- with gr.Blocks() as demo:
112
- gr.HTML(title_html)
113
- gr.Markdown("""\
114
- "๐Ÿ”ฅ๐Ÿ”ฅ๐Ÿ”ฅThis is the official online demo of GOT-OCR-2.0 model!!!"
115
-
116
- ### Demo Guidelines
117
- You need to upload your image below and choose one mode of GOT, then click "Submit" to run GOT model. More characters will result in longer wait times.
118
- - **plain texts OCR & format texts OCR**: The two modes are for the image-level OCR.
119
- - **plain multi-crop OCR & format multi-crop OCR**: For images with more complex content, you can achieve higher-quality results with these modes.
120
- - **plain fine-grained OCR & format fine-grained OCR**: In these modes, you can specify fine-grained regions on the input image for more flexible OCR. Fine-grained regions can be coordinates of the box, red color, blue color, or green color.
121
- """)
122
-
123
- with gr.Row():
124
- with gr.Column():
125
- image_input = gr.Image(type="filepath", label="upload your image")
126
- task_dropdown = gr.Dropdown(
127
- choices=[
128
- "plain texts OCR",
129
- "format texts OCR",
130
- "plain multi-crop OCR",
131
- "format multi-crop OCR",
132
- "plain fine-grained OCR",
133
- "format fine-grained OCR",
134
- ],
135
- label="Choose one mode of GOT",
136
- value="plain texts OCR"
137
- )
138
- fine_grained_dropdown = gr.Dropdown(
139
- choices=["box", "color"],
140
- label="fine-grained type",
141
- visible=False
142
- )
143
- color_dropdown = gr.Dropdown(
144
- choices=["red", "green", "blue"],
145
- label="color list",
146
- visible=False
147
- )
148
- box_input = gr.Textbox(
149
- label="input box: [x1,y1,x2,y2]",
150
- placeholder="e.g., [0,0,100,100]",
151
- visible=False
152
- )
153
- submit_button = gr.Button("Submit")
154
-
155
- with gr.Column():
156
- ocr_result = gr.Textbox(label="GOT output")
157
-
158
- with gr.Column():
159
- gr.Markdown("**If you choose the mode with format, the mathpix result will be automatically rendered as follows:**")
160
- html_result = gr.HTML(label="rendered html", show_label=True)
161
-
162
- # Removed examples section
163
- """
164
- gr.Examples(
165
- examples=[
166
- ["assets/coco.jpg", "plain texts OCR", "", "", ""],
167
- ["assets/en_30.png", "plain texts OCR", "", "", ""],
168
- ["assets/table.jpg", "format texts OCR", "", "", ""],
169
- ["assets/eq.jpg", "format texts OCR", "", "", ""],
170
- ["assets/exam.jpg", "format texts OCR", "", "", ""],
171
- ["assets/giga.jpg", "format multi-crop OCR", "", "", ""],
172
- ["assets/aff2.png", "plain fine-grained OCR", "box", "", "[409,763,756,891]"],
173
- ["assets/color.png", "plain fine-grained OCR", "color", "red", ""],
174
- ],
175
- inputs=[image_input, task_dropdown, fine_grained_dropdown, color_dropdown, box_input],
176
- outputs=[ocr_result, html_result],
177
- fn=run_GOT,
178
- label="examples",
179
- )
180
- """
181
-
182
- task_dropdown.change(
183
- task_update,
184
- inputs=[task_dropdown],
185
- outputs=[fine_grained_dropdown, color_dropdown, box_input]
186
- )
187
- fine_grained_dropdown.change(
188
- fine_grained_update,
189
- inputs=[fine_grained_dropdown],
190
- outputs=[color_dropdown, box_input]
191
- )
192
-
193
- submit_button.click(
194
- run_GOT,
195
- inputs=[image_input, task_dropdown, fine_grained_dropdown, color_dropdown, box_input],
196
- outputs=[ocr_result, html_result]
197
- )
198
-
199
- if __name__ == "__main__":
200
- cleanup_old_files()
201
- demo.launch()
 
 
 
 
1
  import os
2
+ import copy
3
+ import tempfile
4
+ import requests
5
+ import re
6
+ from argparse import ArgumentParser
7
  from pathlib import Path
8
+ from byaldi import RAGMultiModalModel
9
 
10
+ API_KEY = os.environ['API_KEY']
11
+ RAG = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2")
12
+
13
+ def _get_args():
14
+ parser = ArgumentParser()
15
+ parser.add_argument("--share", action="store_true", default=False)
16
+ args = parser.parse_args()
17
+ return args
18
+
19
+ def _parse_text(text):
20
+ lines = text.split("\n")
21
+ lines = [line for line in lines if line]
22
+ count = 0
23
+ for i, line in enumerate(lines):
24
+ if "```" in line:
25
+ count += 1
26
+ items = line.split("`")
27
+ lines[i] = f'<pre><code class="language-{items[-1]}">' if count % 2 == 1 else "<br></code></pre>"
28
+ elif count % 2 == 1:
29
+ lines[i] = "<br>" + re.sub(r'[<>\*_\-.\!\(\)\$]', lambda x: f'&{x.group(0)};', line.replace(" ", "&nbsp;"))
30
+ return "".join(lines)
31
+
32
+ def _remove_image_special(text):
33
+ text = text.replace('<ref>', '').replace('</ref>', '')
34
+ return re.sub(r'<box>.*?(</box>|$)', '', text)
35
+
36
+ def _launch_demo(args):
37
+ uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(Path(tempfile.gettempdir()) / "gradio")
38
+
39
+ def predict(_chatbot, task_history):
40
+ chat_query = _chatbot[-1][0]
41
+ query = task_history[-1][0]
42
+ if not chat_query:
43
+ _chatbot.pop()
44
+ task_history.pop()
45
+ return _chatbot
46
+
47
+ history_cp = copy.deepcopy(task_history)
48
+ messages = []
49
+ content = []
50
+
51
+ for q, a in history_cp:
52
+ content.append({'image': f'file://{q[0]}'})
53
+ messages.append({'role': 'user', 'content': content})
54
+ messages.append({'role': 'assistant', 'content': [{'text': a}]})
55
+ content = []
56
+ messages.pop()
57
+
58
+ responses = RAG.call(model='qwen-vl-max-0809', messages=messages, stream=True)
59
+ for response in responses:
60
+ response_content = response['output']['choices'][0]['message']['content']
61
+ response_text = ''.join(ele.get('text', ele.get('box', '')) for ele in response_content)
62
+ _chatbot[-1] = (_parse_text(chat_query), _remove_image_special(response_text))
63
+ yield _chatbot
64
+
65
+ response_text = response_content[0]['text']
66
+ _chatbot[-1] = (_parse_text(chat_query), response_text)
67
+ task_history[-1] = (query, _parse_text(response_text))
68
+
69
+ def regenerate(_chatbot, task_history):
70
+ if not task_history:
71
+ return _chatbot
72
+ item = task_history[-1]
73
+ if item[1] is None:
74
+ return _chatbot
75
+ task_history[-1] = (item[0], None)
76
+ chatbot_item = _chatbot.pop(-1)
77
+ _chatbot.append((chatbot_item[0], None) if chatbot_item[0] is not None else (_chatbot[-1][0], None))
78
+ return predict(_chatbot, task_history)
79
+
80
+ def add_text(history, task_history, text):
81
+ task_text = text
82
+ history = history if history is not None else []
83
+ task_history = task_history if task_history is not None else []
84
+ history.append((_parse_text(text), None))
85
+ task_history.append((task_text, None))
86
+ return history, task_history, ""
87
+
88
+ def add_file(history, task_history, file):
89
+ history = history if history is not None else []
90
+ task_history = task_history if task_history is not None else []
91
+ history.append(((file.name,), None))
92
+ task_history.append(((file.name,), None))
93
+ return history, task_history
94
+
95
+ def reset_state(task_history):
96
+ task_history.clear()
97
+ return []
98
+
99
+ with gr.Blocks() as demo:
100
+ gr.Markdown("""<p align="center"><img src="https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png" style="height: 80px"/><p>""")
101
+ gr.Markdown("<center><font size=8>Qwen2-VL-Max</center>")
102
+
103
+ chatbot = gr.Chatbot(label='Qwen2-VL-Max', height=500)
104
+ query = gr.Textbox(lines=2, label='Input')
105
+ task_history = gr.State([])
106
+
107
+ with gr.Row():
108
+ addfile_btn = gr.UploadButton("๐Ÿ“ Upload", file_types=["image"])
109
+ submit_btn = gr.Button("๐Ÿš€ Submit")
110
+ regen_btn = gr.Button("๐Ÿค”๏ธ Regenerate")
111
+ empty_bin = gr.Button("๐Ÿงน Clear History")
112
+
113
+ submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history]).then(predict, [chatbot, task_history], [chatbot], show_progress=True)
114
+ submit_btn.click(lambda: gr.update(value=""), [], [query])
115
+ empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
116
+ regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
117
+ addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)
118
+
119
+ demo.queue().launch(share=args.share)
120
+
121
+ def main():
122
+ args = _get_args()
123
+ _launch_demo(args)
124
+
125
+ if __name__ == '__main__':
126
+ main()