prithivMLmods committed on
Commit
cb9a829
·
verified ·
1 Parent(s): 32b9fa9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -72
app.py CHANGED
@@ -12,7 +12,6 @@ import spaces
12
  import torch
13
  import numpy as np
14
  from PIL import Image, ImageOps
15
- import cv2
16
  import requests
17
 
18
  from transformers import (
@@ -136,34 +135,14 @@ model_1_5b = AutoModelForImageTextToText.from_pretrained(
136
  attn_implementation="flash_attention_2"
137
  ).eval()
138
 
139
-
140
- def downsample_video(video_path):
141
- """Downsample a video to evenly spaced frames, returning PIL images with timestamps."""
142
- vidcap = cv2.VideoCapture(video_path)
143
- total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
144
- fps = vidcap.get(cv2.CAP_PROP_FPS)
145
- frames = []
146
- # Use a smaller number of frames for video to avoid overwhelming the model
147
- frame_indices = np.linspace(0, total_frames - 1, min(total_frames, 10), dtype=int)
148
- for i in frame_indices:
149
- vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
150
- success, image = vidcap.read()
151
- if success:
152
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
153
- pil_image = Image.fromarray(image)
154
- timestamp = round(i / fps, 2)
155
- frames.append((pil_image, timestamp))
156
- vidcap.release()
157
- return frames
158
-
159
  @spaces.GPU
160
- def generate(model_name: str, text: str, media_input, media_type: str,
161
- max_new_tokens: int = 1024,
162
- temperature: float = 0.6,
163
- top_p: float = 0.9,
164
- top_k: int = 50,
165
- repetition_penalty: float = 1.2):
166
- """Generic generation function for both image and video."""
167
  if model_name == "Nanonets-OCR2-3B":
168
  processor, model = processor_3b, model_3b
169
  elif model_name == "Nanonets-OCR2-1.5B-exp":
@@ -172,30 +151,20 @@ def generate(model_name: str, text: str, media_input, media_type: str,
172
  yield "Invalid model selected.", "Invalid model selected."
173
  return
174
 
175
- if media_input is None:
176
- yield f"Please upload an {media_type}.", f"Please upload an {media_type}."
177
  return
178
 
179
- if media_type == "image":
180
- images = [media_input]
181
- elif media_type == "video":
182
- frames = downsample_video(media_input)
183
- images = [frame for frame, _ in frames]
184
- else:
185
- yield "Invalid media type.", "Invalid media type."
186
- return
187
 
188
  messages = [
189
  {
190
  "role": "user",
191
- "content": [{"type": "image"} for _ in images] + [
192
- {"type": "text", "text": text}
193
- ]
194
  }
195
  ]
196
 
197
  prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
198
- # Since device_map="auto" is used, we don't need .to(device)
199
  inputs = processor(text=prompt, images=images, return_tensors="pt")
200
 
201
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
@@ -216,15 +185,7 @@ def generate(model_name: str, text: str, media_input, media_type: str,
216
  buffer += new_text.replace("<|im_end|>", "")
217
  yield buffer, buffer
218
 
219
- # Wrapper functions for Gradio clarity
220
- def generate_image(*args):
221
- yield from generate(*args[:3], media_input=args[2], media_type="image", *args[3:])
222
-
223
- def generate_video(*args):
224
- yield from generate(*args[:3], media_input=args[2], media_type="video", *args[3:])
225
-
226
-
227
- # Define examples for image and video inference
228
  image_examples = [
229
  ["Reconstruct the doc [table] as it is.", "images/0.png"],
230
  ["Describe the image!", "images/8.png"],
@@ -237,27 +198,17 @@ image_examples = [
237
  ["Convert formula to late.", "images/7.jpg"],
238
  ]
239
 
240
- video_examples = [
241
- ["Explain the video in detail.", "videos/1.mp4"],
242
- ["Explain the video in detail.", "videos/2.mp4"]
243
- ]
244
-
245
  # Create the Gradio Interface
246
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
247
  gr.Markdown("# **Multimodal OCR3**", elem_id="main-title")
248
  with gr.Row():
249
  with gr.Column(scale=2):
250
- with gr.Tabs():
251
- with gr.TabItem("Image Inference"):
252
- image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
253
- image_upload = gr.Image(type="pil", label="Upload Image", height=290)
254
- image_submit = gr.Button("Submit", variant="primary")
255
- gr.Examples(examples=image_examples, inputs=[image_query, image_upload])
256
- with gr.TabItem("Video Inference"):
257
- video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
258
- video_upload = gr.Video(label="Upload Video (<= 30s)", height=290)
259
- video_submit = gr.Button("Submit", variant="primary")
260
- gr.Examples(examples=video_examples, inputs=[video_query, video_upload])
261
  with gr.Accordion("Advanced options", open=False):
262
  max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
263
  temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
@@ -282,11 +233,6 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
282
  inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
283
  outputs=[raw_output, formatted_output]
284
  )
285
- video_submit.click(
286
- fn=generate_video,
287
- inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
288
- outputs=[raw_output, formatted_output]
289
- )
290
 
291
  if __name__ == "__main__":
292
  demo.queue(max_size=50).launch(ssr_mode=False, show_error=True)
 
12
  import torch
13
  import numpy as np
14
  from PIL import Image, ImageOps
 
15
  import requests
16
 
17
  from transformers import (
 
135
  attn_implementation="flash_attention_2"
136
  ).eval()
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  @spaces.GPU
139
+ def generate_image(model_name: str, text: str, image: Image.Image,
140
+ max_new_tokens: int = 1024,
141
+ temperature: float = 0.6,
142
+ top_p: float = 0.9,
143
+ top_k: int = 50,
144
+ repetition_penalty: float = 1.2):
145
+ """Generation function for image input."""
146
  if model_name == "Nanonets-OCR2-3B":
147
  processor, model = processor_3b, model_3b
148
  elif model_name == "Nanonets-OCR2-1.5B-exp":
 
151
  yield "Invalid model selected.", "Invalid model selected."
152
  return
153
 
154
+ if image is None:
155
+ yield "Please upload an image.", "Please upload an image."
156
  return
157
 
158
+ images = [image]
 
 
 
 
 
 
 
159
 
160
  messages = [
161
  {
162
  "role": "user",
163
+ "content": [{"type": "image"}] + [{"type": "text", "text": text}]
 
 
164
  }
165
  ]
166
 
167
  prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
 
168
  inputs = processor(text=prompt, images=images, return_tensors="pt")
169
 
170
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
 
185
  buffer += new_text.replace("<|im_end|>", "")
186
  yield buffer, buffer
187
 
188
+ # Define examples for image inference
 
 
 
 
 
 
 
 
189
  image_examples = [
190
  ["Reconstruct the doc [table] as it is.", "images/0.png"],
191
  ["Describe the image!", "images/8.png"],
 
198
  ["Convert formula to late.", "images/7.jpg"],
199
  ]
200
 
 
 
 
 
 
201
  # Create the Gradio Interface
202
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
203
  gr.Markdown("# **Multimodal OCR3**", elem_id="main-title")
204
  with gr.Row():
205
  with gr.Column(scale=2):
206
+ # Image Inference Components
207
+ image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
208
+ image_upload = gr.Image(type="pil", label="Upload Image", height=290)
209
+ image_submit = gr.Button("Submit", variant="primary")
210
+ gr.Examples(examples=image_examples, inputs=[image_query, image_upload])
211
+
 
 
 
 
 
212
  with gr.Accordion("Advanced options", open=False):
213
  max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
214
  temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
 
233
  inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
234
  outputs=[raw_output, formatted_output]
235
  )
 
 
 
 
 
236
 
237
  if __name__ == "__main__":
238
  demo.queue(max_size=50).launch(ssr_mode=False, show_error=True)