prithivMLmods committed on
Commit
3396a9a
·
verified ·
1 Parent(s): da84e63

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -147
app.py CHANGED
@@ -3,6 +3,7 @@ import random
3
  import uuid
4
  import json
5
  import time
 
6
  from threading import Thread
7
  from typing import Iterable
8
 
@@ -10,22 +11,30 @@ import gradio as gr
10
  import spaces
11
  import torch
12
  import numpy as np
13
- from PIL import Image
14
  import cv2
 
15
 
16
  from transformers import (
17
- Qwen2_5_VLForConditionalGeneration,
18
- AutoModelForCausalLM, # Added for PaddleOCR-VL
19
  AutoProcessor,
20
  TextIteratorStreamer,
21
  )
22
  from transformers.image_utils import load_image
 
 
 
23
  from gradio.themes import Soft
24
  from gradio.themes.utils import colors, fonts, sizes
25
 
 
 
 
 
 
 
26
  # --- Theme and CSS Definition ---
27
 
28
- # Define the SteelBlue color palette
29
  colors.steel_blue = colors.Color(
30
  name="steel_blue",
31
  c50="#EBF3F8",
@@ -73,14 +82,8 @@ class SteelBlueTheme(Soft):
73
  button_primary_text_color_hover="white",
74
  button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
75
  button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
76
- button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)",
77
- button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)",
78
- button_secondary_text_color="black",
79
- button_secondary_text_color_hover="white",
80
- button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
81
- button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
82
- button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
83
- button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
84
  slider_color="*secondary_500",
85
  slider_color_dark="*secondary_600",
86
  block_title_text_weight="600",
@@ -92,7 +95,6 @@ class SteelBlueTheme(Soft):
92
  block_label_background_fill="*primary_200",
93
  )
94
 
95
- # Instantiate the new theme
96
  steel_blue_theme = SteelBlueTheme()
97
 
98
  css = """
@@ -105,179 +107,186 @@ css = """
105
  """
106
 
107
  # Constants for text generation
108
- MAX_MAX_NEW_TOKENS = 4096
109
- DEFAULT_MAX_NEW_TOKENS = 1024
110
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
111
 
112
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
113
-
114
- print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
115
- print("torch.__version__ =", torch.__version__)
116
- print("torch.version.cuda =", torch.version.cuda)
117
- print("cuda available:", torch.cuda.is_available())
118
- print("cuda device count:", torch.cuda.device_count())
119
- if torch.cuda.is_available():
120
- print("current device:", torch.cuda.current_device())
121
- print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
122
-
123
- print("Using device:", device)
124
 
125
- # --- Model Loading ---
126
  # Load Nanonets-OCR2-3B
127
- MODEL_ID_V = "nanonets/Nanonets-OCR2-3B"
128
- processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
129
- model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
130
- MODEL_ID_V,
 
 
131
  trust_remote_code=True,
132
- torch_dtype=torch.float16
133
- ).to(device).eval()
134
 
135
- # Load PaddleOCR-VL
136
- # Using the corrected model path from your previous attempt
137
- MODEL_ID_P = "strangervisionhf/paddle"
138
- processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True)
139
- model_p = AutoModelForCausalLM.from_pretrained(
140
- MODEL_ID_P,
 
141
  trust_remote_code=True,
142
- torch_dtype=torch.float16,
143
- ).to(device).eval()
144
 
145
- # --- Task Prompts for PaddleOCR-VL ---
146
- PROMPTS = {
147
- "ocr": "OCR:",
148
- "table": "Table Recognition:",
149
- "chart": "Chart Recognition:",
150
- "formula": "Formula Recognition:",
151
- }
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  @spaces.GPU
154
- def generate_image(model_name: str, text: str, image: Image.Image,
155
- max_new_tokens: int, temperature: float, top_p: float,
156
- top_k: int, repetition_penalty: float):
157
- """
158
- Generates responses using the selected model for image input.
159
- Yields raw text and Markdown-formatted text.
160
- """
161
- if image is None:
162
- yield "Please upload an image.", "Please upload an image."
 
 
 
 
163
  return
164
 
165
- if model_name == "Nanonets-OCR2-3B":
166
- processor = processor_v
167
- model = model_v
168
 
169
- messages = [{
 
 
 
 
 
 
 
 
 
 
170
  "role": "user",
171
- "content": [
172
- {"type": "image"},
173
- {"type": "text", "text": text},
174
  ]
175
- }]
176
- prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
177
-
178
- inputs = processor(
179
- text=[prompt_full],
180
- images=[image],
181
- return_tensors="pt",
182
- padding=True).to(device)
183
-
184
- streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
185
- generation_kwargs = {
186
- **inputs,
187
- "streamer": streamer,
188
- "max_new_tokens": max_new_tokens,
189
- "do_sample": True,
190
- "temperature": temperature,
191
- "top_p": top_p,
192
- "top_k": top_k,
193
- "repetition_penalty": repetition_penalty,
194
  }
195
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
196
- thread.start()
197
- buffer = ""
198
- for new_text in streamer:
199
- buffer += new_text
200
- buffer = buffer.replace("<|im_end|>", "")
201
- time.sleep(0.01)
202
- yield buffer, buffer
203
 
204
- elif model_name == "PaddleOCR-VL":
205
- processor = processor_p
206
- model = model_p
 
 
 
 
 
 
 
 
 
207
 
208
- # --- CORRECTED LOGIC FOR PADDLEOCR-VL ---
209
- # It expects a simple string content, not a list of dicts.
210
- # The user's input `text` should be one of the specific prompts.
211
- messages = [{"role": "user", "content": text}]
212
- prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
213
 
214
- inputs = processor(text=[prompt_full], images=[image], return_tensors="pt").to(device)
 
 
215
 
216
- generation_kwargs = {
217
- **inputs,
218
- "max_new_tokens": max_new_tokens,
219
- "do_sample": False, # As per the reference script for best results
220
- "use_cache": True,
221
- }
222
-
223
- with torch.inference_mode():
224
- generated_ids = model.generate(**generation_kwargs)
225
 
226
- resp = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
227
- # Extract only the model's answer, excluding the prompt
228
- answer = resp.split(prompt_full)[-1].strip()
229
- yield answer, answer
230
-
231
- else:
232
- yield "Invalid model selected.", "Invalid model selected."
233
- return
234
 
235
- # Define examples for image inference, updated for both models
236
  image_examples = [
237
- ["OCR:", "images/ocr.png"],
238
- ["Table Recognition:", "images/4.png"],
239
- ["Extract the content of this invoice.", "images/0.png"]
 
 
 
 
 
 
240
  ]
241
 
 
 
 
 
242
 
243
  # Create the Gradio Interface
244
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
245
- gr.Markdown("# **Multimodal OCR**", elem_id="main-title")
246
  with gr.Row():
247
  with gr.Column(scale=2):
248
- image_query = gr.Textbox(label="Query Input", placeholder="Enter query. For PaddleOCR, use 'OCR:', 'Table Recognition:', etc.")
249
- image_upload = gr.Image(type="pil", label="Upload Image", height=290)
250
-
251
- image_submit = gr.Button("Submit", variant="primary")
252
- gr.Examples(
253
- examples=image_examples,
254
- inputs=[image_query, image_upload]
255
- )
256
-
 
 
257
  with gr.Accordion("Advanced options", open=False):
258
  max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
259
- temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7)
260
  top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
261
  top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
262
- repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1)
263
 
264
  with gr.Column(scale=3):
265
- gr.Markdown("## Output", elem_id="output-title")
266
- output = gr.Textbox(label="Raw Output", interactive=False, lines=11, show_copy_button=True)
267
- with gr.Accordion("(Result.md)", open=False):
268
- markdown_output = gr.Markdown(label="(Result.Md)")
269
-
270
- model_choice = gr.Radio(
271
- choices=["Nanonets-OCR2-3B", "PaddleOCR-VL"],
272
- label="Select Model",
273
- value="Nanonets-OCR2-3B"
274
- )
275
-
276
  image_submit.click(
277
  fn=generate_image,
278
  inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
279
- outputs=[output, markdown_output]
 
 
 
 
 
280
  )
281
 
282
  if __name__ == "__main__":
283
- demo.queue(max_size=50).launch(mcp_server=True, ssr_mode=False, show_error=True)
 
3
  import uuid
4
  import json
5
  import time
6
+ import asyncio
7
  from threading import Thread
8
  from typing import Iterable
9
 
 
11
  import spaces
12
  import torch
13
  import numpy as np
14
+ from PIL import Image, ImageOps
15
  import cv2
16
+ import requests
17
 
18
  from transformers import (
19
+ AutoTokenizer,
 
20
  AutoProcessor,
21
  TextIteratorStreamer,
22
  )
23
  from transformers.image_utils import load_image
24
+ # The custom model class is imported via trust_remote_code=True
25
+ from transformers import AutoModelForImageTextToText
26
+
27
  from gradio.themes import Soft
28
  from gradio.themes.utils import colors, fonts, sizes
29
 
30
+ from docling_core.types.doc import DoclingDocument, DocTagsDocument
31
+
32
+ import re
33
+ import ast
34
+ import html
35
+
36
  # --- Theme and CSS Definition ---
37
 
 
38
  colors.steel_blue = colors.Color(
39
  name="steel_blue",
40
  c50="#EBF3F8",
 
82
  button_primary_text_color_hover="white",
83
  button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
84
  button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
85
+ button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
86
+ button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
 
 
 
 
 
 
87
  slider_color="*secondary_500",
88
  slider_color_dark="*secondary_600",
89
  block_title_text_weight="600",
 
95
  block_label_background_fill="*primary_200",
96
  )
97
 
 
98
  steel_blue_theme = SteelBlueTheme()
99
 
100
  css = """
 
107
  """
108
 
109
  # Constants for text generation
110
MAX_MAX_NEW_TOKENS = 5120
DEFAULT_MAX_NEW_TOKENS = 3072
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Check for CUDA availability
device = "cuda" if torch.cuda.is_available() else "cpu"


def _load_ocr_model(model_id):
    """Load the processor and the eval-mode model for ``model_id``.

    Both checkpoints in this file are loaded with the same options, so the
    call shape lives here once. Returns ``(processor, model)``.
    """
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
    )
    return processor, model.eval()


# Load Nanonets-OCR2-3B
MODEL_ID_3B = "nanonets/Nanonets-OCR2-3B"
processor_3b, model_3b = _load_ocr_model(MODEL_ID_3B)

# Load Nanonets-OCR2-1.5B-exp
MODEL_ID_1_5B = "nanonets/Nanonets-OCR2-1.5B-exp"
processor_1_5b, model_1_5b = _load_ocr_model(MODEL_ID_1_5B)
138
 
139
+
140
def downsample_video(video_path, max_frames=10):
    """Sample up to ``max_frames`` evenly spaced frames from a video.

    Args:
        video_path: Path of the video file handed to ``cv2.VideoCapture``.
        max_frames: Upper bound on the number of sampled frames. Defaults to
            10, the value that used to be hard-coded, to avoid overwhelming
            the model with frames.

    Returns:
        A list of ``(PIL image, timestamp)`` tuples in frame order. The
        timestamp is ``frame_index / fps`` in seconds, rounded to two
        decimals, or ``0.0`` when the capture reports no usable frame rate.
        Frames that fail to decode are skipped. The list is empty when the
        video cannot be opened or reports no frames.
    """
    frames = []
    vidcap = cv2.VideoCapture(video_path)
    try:
        if not vidcap.isOpened():
            return frames

        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        # A non-positive count would make np.linspace raise on a negative
        # ``num``; there is nothing to sample in that case anyway.
        if total_frames <= 0 or max_frames <= 0:
            return frames

        fps = vidcap.get(cv2.CAP_PROP_FPS)
        frame_indices = np.linspace(
            0, total_frames - 1, min(total_frames, max_frames), dtype=int
        )
        for index in frame_indices:
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(index))
            success, bgr_frame = vidcap.read()
            if not success:
                continue
            rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
            # Guard the division: a zero frame rate would otherwise produce
            # inf/nan timestamps.
            timestamp = round(float(index) / fps, 2) if fps > 0 else 0.0
            frames.append((Image.fromarray(rgb_frame), timestamp))
    finally:
        # Release the capture even if decoding or conversion raises.
        vidcap.release()
    return frames
158
 
159
  @spaces.GPU
160
def generate(model_name: str, text: str, media_input, media_type: str,
             max_new_tokens: int = 1024,
             temperature: float = 0.6,
             top_p: float = 0.9,
             top_k: int = 50,
             repetition_penalty: float = 1.2):
    """Stream a model response for an image or video query.

    Args:
        model_name: ``"Nanonets-OCR2-3B"`` or ``"Nanonets-OCR2-1.5B-exp"``.
        text: The user's query, sent as the text part of the chat message.
        media_input: The image itself when ``media_type`` is ``"image"``, or
            a video path (passed to ``downsample_video``) when it is
            ``"video"``.
        media_type: ``"image"`` or ``"video"``.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        top_k: Forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.

    Yields:
        ``(raw_text, formatted_text)`` pairs. Both elements are the same
        string: the response accumulated so far, or a single error message
        when the request cannot be served.
    """
    if model_name == "Nanonets-OCR2-3B":
        processor, model = processor_3b, model_3b
    elif model_name == "Nanonets-OCR2-1.5B-exp":
        processor, model = processor_1_5b, model_1_5b
    else:
        yield "Invalid model selected.", "Invalid model selected."
        return

    if media_input is None:
        article = "a" if media_type == "video" else "an"
        message = f"Please upload {article} {media_type}."
        yield message, message
        return

    if media_type == "image":
        images = [media_input]
    elif media_type == "video":
        images = [frame for frame, _ in downsample_video(media_input)]
        if not images:
            # Without this guard an unreadable video reaches the processor
            # with an empty image list.
            message = "Could not read any frames from the video."
            yield message, message
            return
    else:
        yield "Invalid media type.", "Invalid media type."
        return

    # One image placeholder per frame, followed by the text query.
    content = [{"type": "image"} for _ in images]
    content.append({"type": "text", "text": text})
    messages = [{"role": "user", "content": content}]

    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # device_map="auto" places the model's weights, not the inputs: the
    # input tensors still have to be moved to the model's device.
    inputs = processor(text=prompt, images=images, return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        # Sampling must be enabled explicitly, otherwise temperature, top_p
        # and top_k may be ignored.
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }
    # generate() blocks until completion, so it runs in a worker thread
    # while this generator drains the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        # Strip the marker from the whole buffer so it is removed even when
        # it arrives split across two streamed chunks.
        buffer = (buffer + new_text).replace("<|im_end|>", "")
        yield buffer, buffer

    thread.join()
 
218
 
219
+ # Wrapper functions for Gradio clarity
220
def generate_image(*args):
    """Run ``generate`` with ``media_type="image"``.

    ``args`` is the positional tuple from the Gradio click handler:
    ``(model_name, text, image, max_new_tokens, temperature, top_p, top_k,
    repetition_penalty)``.

    The media type is inserted positionally after the image. Passing it by
    keyword while also forwarding every value positionally made ``generate``
    raise ``TypeError`` (multiple values for ``media_input``).
    """
    model_name, text, media_input, *sampling_args = args
    yield from generate(model_name, text, media_input, "image", *sampling_args)
222
 
223
def generate_video(*args):
    """Run ``generate`` with ``media_type="video"``.

    ``args`` is the positional tuple from the Gradio click handler:
    ``(model_name, text, video, max_new_tokens, temperature, top_p, top_k,
    repetition_penalty)``.

    The media type is inserted positionally after the video. Passing it by
    keyword while also forwarding every value positionally made ``generate``
    raise ``TypeError`` (multiple values for ``media_input``).
    """
    model_name, text, media_input, *sampling_args = args
    yield from generate(model_name, text, media_input, "video", *sampling_args)
 
 
 
 
 
 
 
225
 
 
 
 
 
 
 
 
 
226
 
227
+ # Define examples for image and video inference
228
# Each row is [query text, media path], in the same order as the
# inputs=[query, upload] lists given to gr.Examples.
image_examples = [
    ["Reconstruct the doc [table] as it is.", "images/0.png"],
    ["Describe the image!", "images/8.png"],
    ["OCR the image", "images/2.jpg"],
    ["Convert this page to docling", "images/1.png"],
    ["Convert this page to docling", "images/3.png"],
    ["Convert chart to OTSL.", "images/4.png"],
    ["Convert code to text", "images/5.jpg"],
    ["Convert this table to OTSL.", "images/6.jpg"],
    # Was "Convert formula to late." -- a typo in the prompt.
    ["Convert formula to latex.", "images/7.jpg"],
]

video_examples = [
    ["Explain the video in detail.", "videos/1.mp4"],
    ["Explain the video in detail.", "videos/2.mp4"],
]
244
 
245
  # Create the Gradio Interface
246
# NOTE(review): the nesting below is reconstructed from widget order because
# the original indentation is not visible here -- confirm against app.py.
with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
    gr.Markdown("# **Multimodal OCR3**", elem_id="main-title")
    with gr.Row():
        # Input side: one tab per media type, then the sampling controls.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Image Inference"):
                    image_query = gr.Textbox(
                        label="Query Input",
                        placeholder="Enter your query here...",
                    )
                    image_upload = gr.Image(type="pil", label="Upload Image", height=290)
                    image_submit = gr.Button("Submit", variant="primary")
                    gr.Examples(
                        examples=image_examples,
                        inputs=[image_query, image_upload],
                    )
                with gr.TabItem("Video Inference"):
                    video_query = gr.Textbox(
                        label="Query Input",
                        placeholder="Enter your query here...",
                    )
                    video_upload = gr.Video(label="Upload Video (<= 30s)", height=290)
                    video_submit = gr.Button("Submit", variant="primary")
                    gr.Examples(
                        examples=video_examples,
                        inputs=[video_query, video_upload],
                    )
            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(
                    label="Max new tokens",
                    minimum=1,
                    maximum=MAX_MAX_NEW_TOKENS,
                    step=1,
                    value=DEFAULT_MAX_NEW_TOKENS,
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=4.0,
                    step=0.1,
                    value=0.6,
                )
                top_p = gr.Slider(
                    label="Top-p (nucleus sampling)",
                    minimum=0.05,
                    maximum=1.0,
                    step=0.05,
                    value=0.9,
                )
                top_k = gr.Slider(
                    label="Top-k",
                    minimum=1,
                    maximum=1000,
                    step=1,
                    value=50,
                )
                repetition_penalty = gr.Slider(
                    label="Repetition penalty",
                    minimum=1.0,
                    maximum=2.0,
                    step=0.05,
                    value=1.2,
                )

        # Output side: raw stream, rendered markdown, and the model selector.
        with gr.Column(scale=3):
            gr.Markdown("## Output", elem_id="output-title")
            raw_output = gr.Textbox(
                label="Raw Output Stream",
                interactive=False,
                lines=11,
                show_copy_button=True,
            )
            with gr.Accordion("(Result.md)", open=True):
                formatted_output = gr.Markdown(label="(Result.md)")

            model_choice = gr.Radio(
                choices=["Nanonets-OCR2-3B", "Nanonets-OCR2-1.5B-exp"],
                label="Select Model",
                value="Nanonets-OCR2-3B",
            )

    # Both tabs share the sampling controls and the output widgets; only the
    # query/upload pair and the handler differ.
    sampling_controls = [max_new_tokens, temperature, top_p, top_k, repetition_penalty]
    result_widgets = [raw_output, formatted_output]

    image_submit.click(
        fn=generate_image,
        inputs=[model_choice, image_query, image_upload, *sampling_controls],
        outputs=result_widgets,
    )
    video_submit.click(
        fn=generate_video,
        inputs=[model_choice, video_query, video_upload, *sampling_controls],
        outputs=result_widgets,
    )

if __name__ == "__main__":
    queued_demo = demo.queue(max_size=50)
    queued_demo.launch(ssr_mode=False, show_error=True)