prithivMLmods committed
Commit f534b3e · verified · 1 Parent(s): 9d04faa

Update app.py

Files changed (1):
  1. app.py +222 -359
app.py CHANGED
@@ -5,9 +5,6 @@ import json
 import time
 import asyncio
 from threading import Thread
-from pathlib import Path
-from io import BytesIO
-from typing import Optional, Tuple, Dict, Any, Iterable
 
 import gradio as gr
 import spaces
@@ -16,148 +13,24 @@ import numpy as np
 from PIL import Image
 import cv2
 import requests
-import fitz
 
 from transformers import (
-    Qwen3VLMoeForConditionalGeneration,
+    Qwen2VLForConditionalGeneration,
+    Qwen2_5_VLForConditionalGeneration,
     AutoProcessor,
     TextIteratorStreamer,
+    AutoModel,
+    AutoTokenizer,
 )
 from transformers.image_utils import load_image
 
-from gradio.themes import Soft
-from gradio.themes.utils import colors, fonts, sizes
-
-colors.thistle = colors.Color(
-    name="thistle",
-    c50="#F9F5F9", c100="#F0E8F1", c200="#E7DBE8", c300="#DECEE0",
-    c400="#D2BFD8", c500="#D8BFD8", c600="#B59CB7", c700="#927996",
-    c800="#6F5675", c900="#4C3454", c950="#291233",
-)
-
-colors.red_gray = colors.Color(
-    name="red_gray",
-    c50="#f7eded", c100="#f5dcdc", c200="#efb4b4", c300="#e78f8f",
-    c400="#d96a6a", c500="#c65353", c600="#b24444", c700="#8f3434",
-    c800="#732d2d", c900="#5f2626", c950="#4d2020",
-)
-
-class ThistleTheme(Soft):
-    def __init__(
-        self,
-        *,
-        primary_hue: colors.Color | str = colors.gray,
-        secondary_hue: colors.Color | str = colors.thistle,
-        neutral_hue: colors.Color | str = colors.slate,
-        text_size: sizes.Size | str = sizes.text_lg,
-        font: fonts.Font | str | Iterable[fonts.Font | str] = (
-            fonts.GoogleFont("Inconsolata"), "Arial", "sans-serif",
-        ),
-        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
-            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
-        ),
-    ):
-        super().__init__(
-            primary_hue=primary_hue,
-            secondary_hue=secondary_hue,
-            neutral_hue=neutral_hue,
-            text_size=text_size,
-            font=font,
-            font_mono=font_mono,
-        )
-        super().set(
-            background_fill_primary="*primary_50",
-            background_fill_primary_dark="*primary_900",
-            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
-            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
-            button_primary_text_color="black",
-            button_primary_text_color_hover="white",
-            button_primary_background_fill="linear-gradient(90deg, *secondary_400, *secondary_400)",
-            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_600)",
-            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)",
-            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)",
-            button_secondary_text_color="black",
-            button_secondary_text_color_hover="white",
-            button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
-            button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
-            button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
-            button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
-            button_cancel_background_fill=f"linear-gradient(90deg, {colors.red_gray.c400}, {colors.red_gray.c500})",
-            button_cancel_background_fill_dark=f"linear-gradient(90deg, {colors.red_gray.c700}, {colors.red_gray.c800})",
-            button_cancel_background_fill_hover=f"linear-gradient(90deg, {colors.red_gray.c500}, {colors.red_gray.c600})",
-            button_cancel_background_fill_hover_dark=f"linear-gradient(90deg, {colors.red_gray.c800}, {colors.red_gray.c900})",
-            button_cancel_text_color="white",
-            button_cancel_text_color_dark="white",
-            button_cancel_text_color_hover="white",
-            button_cancel_text_color_hover_dark="white",
-            slider_color="*secondary_300",
-            slider_color_dark="*secondary_600",
-            block_title_text_weight="600",
-            block_border_width="3px",
-            block_shadow="*shadow_drop_lg",
-            button_primary_shadow="*shadow_drop_lg",
-            button_large_padding="11px",
-            color_accent_soft="*primary_100",
-            block_label_background_fill="*primary_200",
-        )
-
-thistle_theme = ThistleTheme()
-
-css = """
-#main-title h1 {
-    font-size: 2.3em !important;
-}
-#output-title h2 {
-    font-size: 2.1em !important;
-}
-:root {
-    --color-grey-50: #f9fafb;
-    --banner-background: var(--secondary-400);
-    --banner-text-color: var(--primary-100);
-    --banner-background-dark: var(--secondary-800);
-    --banner-text-color-dark: var(--primary-100);
-    --banner-chrome-height: calc(16px + 43px);
-    --chat-chrome-height-wide-no-banner: 320px;
-    --chat-chrome-height-narrow-no-banner: 450px;
-    --chat-chrome-height-wide: calc(var(--chat-chrome-height-wide-no-banner) + var(--banner-chrome-height));
-    --chat-chrome-height-narrow: calc(var(--chat-chrome-height-narrow-no-banner) + var(--banner-chrome-height));
-}
-.banner-message { background-color: var(--banner-background); padding: 5px; margin: 0; border-radius: 5px; border: none; }
-.banner-message-text { font-size: 13px; font-weight: bolder; color: var(--banner-text-color) !important; }
-body.dark .banner-message { background-color: var(--banner-background-dark) !important; }
-body.dark .gradio-container .contain .banner-message .banner-message-text { color: var(--banner-text-color-dark) !important; }
-.toast-body { background-color: var(--color-grey-50); }
-.html-container:has(.css-styles) { padding: 0; margin: 0; }
-.css-styles { height: 0; }
-.model-message { text-align: end; }
-.model-dropdown-container { display: flex; align-items: center; gap: 10px; padding: 0; }
-.user-input-container .multimodal-textbox{ border: none !important; }
-.control-button { height: 51px; }
-button.cancel { border: var(--button-border-width) solid var(--button-cancel-border-color); background: var(--button-cancel-background-fill); color: var(--button-cancel-text-color); box-shadow: var(--button-cancel-shadow); }
-button.cancel:hover, .cancel[disabled] { background: var(--button-cancel-background-fill-hover); color: var(--button-cancel-text-color-hover); }
-.opt-out-message { top: 8px; }
-.opt-out-message .html-container, .opt-out-checkbox label { font-size: 14px !important; padding: 0 !important; margin: 0 !important; color: var(--neutral-400) !important; }
-div.block.chatbot { height: calc(100svh - var(--chat-chrome-height-wide)) !important; max-height: 900px !important; }
-div.no-padding { padding: 0 !important; }
-@media (max-width: 1280px) { div.block.chatbot { height: calc(100svh - var(--chat-chrome-height-wide)) !important; } }
-@media (max-width: 1024px) {
-    .responsive-row { flex-direction: column; }
-    .model-message { text-align: start; font-size: 10px !important; }
-    .model-dropdown-container { flex-direction: column; align-items: flex-start; }
-    div.block.chatbot { height: calc(100svh - var(--chat-chrome-height-narrow)) !important; }
-}
-@media (max-width: 400px) {
-    .responsive-row { flex-direction: column; }
-    .model-message { text-align: start; font-size: 10px !important; }
-    .model-dropdown-container { flex-direction: column; align-items: flex-start; }
-    div.block.chatbot { max-height: 360px !important; }
-}
-@media (max-height: 932px) { .chatbot { max-height: 500px !important; } }
-@media (max-height: 1280px) { div.block.chatbot { max-height: 800px !important; } }
-"""
-
+# Constants for text generation
 MAX_MAX_NEW_TOKENS = 4096
 DEFAULT_MAX_NEW_TOKENS = 2048
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+# Let the environment (e.g., Hugging Face Spaces) determine the device.
+# This avoids conflicts with the CUDA environment setup by the platform.
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
@@ -168,32 +41,83 @@ print("cuda device count:", torch.cuda.device_count())
 if torch.cuda.is_available():
     print("current device:", torch.cuda.current_device())
     print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
+
 print("Using device:", device)
+# --- Model Loading ---
+
+# To address the warnings, we add `use_fast=False` to ensure we use the
+# processor version the model was originally saved with.
 
-MODEL_ID_Q3VL = "Qwen/Qwen3-VL-30B-A3B-Instruct"
-processor_q3vl = AutoProcessor.from_pretrained(MODEL_ID_Q3VL, trust_remote_code=True, use_fast=False)
-model_q3vl = Qwen3VLMoeForConditionalGeneration.from_pretrained(
-    MODEL_ID_Q3VL,
+# Load DREX-062225-exp
+MODEL_ID_X = "prithivMLmods/DREX-062225-exp"
+processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True, use_fast=False)
+model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_X,
     trust_remote_code=True,
-    dtype=torch.float16
+    torch_dtype=torch.float16
 ).to(device).eval()
 
-def extract_gif_frames(gif_path: str):
-    if not gif_path:
-        return []
-    with Image.open(gif_path) as gif:
-        total_frames = gif.n_frames
-        frame_indices = np.linspace(0, total_frames - 1, min(total_frames, 10), dtype=int)
-        frames = []
-        for i in frame_indices:
-            gif.seek(i)
-            frames.append(gif.convert("RGB").copy())
-    return frames
+# Load typhoon-ocr-3b
+MODEL_ID_T = "scb10x/typhoon-ocr-3b"
+processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True, use_fast=False)
+model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_T,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
+).to(device).eval()
+
+# Load olmOCR-7B-0225-preview
+MODEL_ID_O = "allenai/olmOCR-7B-0225-preview"
+processor_o = AutoProcessor.from_pretrained(MODEL_ID_O, trust_remote_code=True, use_fast=False)
+model_o = Qwen2VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_O,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
+).to(device).eval()
+
+# Load Lumian-VLR-7B-Thinking
+MODEL_ID_J = "prithivMLmods/Lumian-VLR-7B-Thinking"
+SUBFOLDER = "think-preview"
+processor_j = AutoProcessor.from_pretrained(MODEL_ID_J, trust_remote_code=True, subfolder=SUBFOLDER, use_fast=False)
+model_j = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_J,
+    trust_remote_code=True,
+    subfolder=SUBFOLDER,
+    torch_dtype=torch.float16
+).to(device).eval()
+
+# Load openbmb/MiniCPM-V-4
+MODEL_ID_V4 = 'openbmb/MiniCPM-V-4'
+model_v4 = AutoModel.from_pretrained(
+    MODEL_ID_V4,
+    trust_remote_code=True,
+    torch_dtype=torch.bfloat16,
+    # Using 'sdpa' can sometimes cause issues in certain environments,
+    # letting transformers choose the default is safer.
+    # attn_implementation='sdpa'
+).eval().to(device)
+tokenizer_v4 = AutoTokenizer.from_pretrained(MODEL_ID_V4, trust_remote_code=True, use_fast=False)
+
+# --- Refactored Model Dictionary ---
+# This simplifies model selection in the generation functions.
+MODELS = {
+    "DREX-062225-7B-exp": (processor_x, model_x),
+    "Typhoon-OCR-3B": (processor_t, model_t),
+    "olmOCR-7B-0225-preview": (processor_o, model_o),
+    "Lumian-VLR-7B-Thinking": (processor_j, model_j),
+}
+
 
 def downsample_video(video_path):
+    """
+    Downsamples the video to evenly spaced frames.
+    Each frame is returned as a PIL image along with its timestamp.
+    """
     vidcap = cv2.VideoCapture(video_path)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+    fps = vidcap.get(cv2.CAP_PROP_FPS)
     frames = []
+    # Use a maximum of 10 frames to avoid excessive memory usage
    frame_indices = np.linspace(0, total_frames - 1, min(total_frames, 10), dtype=int)
     for i in frame_indices:
         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
@@ -201,70 +125,54 @@ def downsample_video(video_path):
         if success:
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
             pil_image = Image.fromarray(image)
-            frames.append(pil_image)
+            timestamp = round(i / fps, 2)
+            frames.append((pil_image, timestamp))
     vidcap.release()
     return frames
 
-def convert_pdf_to_images(file_path: str, dpi: int = 200):
-    if not file_path:
-        return []
-    images = []
-    pdf_document = fitz.open(file_path)
-    zoom = dpi / 72.0
-    mat = fitz.Matrix(zoom, zoom)
-    for page_num in range(len(pdf_document)):
-        page = pdf_document.load_page(page_num)
-        pix = page.get_pixmap(matrix=mat)
-        img_data = pix.tobytes("png")
-        images.append(Image.open(BytesIO(img_data)))
-    pdf_document.close()
-    return images
-
-def get_initial_pdf_state() -> Dict[str, Any]:
-    return {"pages": [], "total_pages": 0, "current_page_index": 0}
-
-def load_and_preview_pdf(file_path: Optional[str]) -> Tuple[Optional[Image.Image], Dict[str, Any], str]:
-    state = get_initial_pdf_state()
-    if not file_path:
-        return None, state, '<div style="text-align:center;">No file loaded</div>'
-    try:
-        pages = convert_pdf_to_images(file_path)
-        if not pages:
-            return None, state, '<div style="text-align:center;">Could not load file</div>'
-        state["pages"] = pages
-        state["total_pages"] = len(pages)
-        page_info_html = f'<div style="text-align:center;">Page 1 / {state["total_pages"]}</div>'
-        return pages[0], state, page_info_html
-    except Exception as e:
-        return None, state, f'<div style="text-align:center;">Failed to load preview: {e}</div>'
-
-def navigate_pdf_page(direction: str, state: Dict[str, Any]):
-    if not state or not state["pages"]:
-        return None, state, '<div style="text-align:center;">No file loaded</div>'
-    current_index = state["current_page_index"]
-    total_pages = state["total_pages"]
-    if direction == "prev":
-        new_index = max(0, current_index - 1)
-    elif direction == "next":
-        new_index = min(total_pages - 1, current_index + 1)
-    else:
-        new_index = current_index
-    state["current_page_index"] = new_index
-    image_preview = state["pages"][new_index]
-    page_info_html = f'<div style="text-align:center;">Page {new_index + 1} / {total_pages}</div>'
-    return image_preview, state, page_info_html
-
 @spaces.GPU
-def generate_image(text: str, image: Image.Image, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
+def generate_image(model_name: str, text: str, image: Image.Image,
+                   max_new_tokens: int = 1024,
+                   temperature: float = 0.6,
+                   top_p: float = 0.9,
+                   top_k: int = 50,
+                   repetition_penalty: float = 1.2):
+    """
+    Generates responses using the selected model for image input.
+    """
     if image is None:
         yield "Please upload an image.", "Please upload an image."
         return
+
+    # Handle MiniCPM-V-4 separately due to its different API
+    if model_name == "openbmb/MiniCPM-V-4":
+        msgs = [{'role': 'user', 'content': [image, text]}]
+        try:
+            answer = model_v4.chat(
+                image=image.convert('RGB'), msgs=msgs, tokenizer=tokenizer_v4,
+                max_new_tokens=max_new_tokens, temperature=temperature,
+                top_p=top_p, repetition_penalty=repetition_penalty,
+            )
+            yield answer, answer
+        except Exception as e:
+            yield f"Error: {e}", f"Error: {e}"
+        return
+
+    # Use the dictionary for other models
+    if model_name not in MODELS:
+        yield "Invalid model selected.", "Invalid model selected."
+        return
+    processor, model = MODELS[model_name]
+
-    messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": text}]}]
-    prompt_full = processor_q3vl.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs = processor_q3vl(text=[prompt_full], images=[image], return_tensors="pt", padding=True).to(device)
-    streamer = TextIteratorStreamer(processor_q3vl, skip_prompt=True, skip_special_tokens=True)
+    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": text}]}]
+    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    inputs = processor(
+        text=[prompt_full], images=[image], return_tensors="pt", padding=True,
+        truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH
+    ).to(device)
+    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-    thread = Thread(target=model_q3vl.generate, kwargs=generation_kwargs)
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
     buffer = ""
     for new_text in streamer:
@@ -273,98 +181,67 @@ def generate_image(text: str, image: Image.Image, max_new_tokens: int = 1024, te
     yield buffer, buffer
 
 @spaces.GPU
-def generate_video(text: str, video_path: str, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
+def generate_video(model_name: str, text: str, video_path: str,
+                   max_new_tokens: int = 1024,
+                   temperature: float = 0.6,
+                   top_p: float = 0.9,
+                   top_k: int = 50,
+                   repetition_penalty: float = 1.2):
+    """
+    Generates responses using the selected model for video input.
+    """
     if video_path is None:
         yield "Please upload a video.", "Please upload a video."
         return
-    frames = downsample_video(video_path)
-    if not frames:
+
+    frames_with_ts = downsample_video(video_path)
+    if not frames_with_ts:
         yield "Could not process video.", "Could not process video."
         return
-    messages = [{"role": "user", "content": [{"type": "text", "text": text}]}]
-    for frame in frames:
-        messages[0]["content"].insert(0, {"type": "image"})
-    prompt_full = processor_q3vl.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs = processor_q3vl(text=[prompt_full], images=frames, return_tensors="pt", padding=True).to(device)
-    streamer = TextIteratorStreamer(processor_q3vl, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens, "do_sample": True, "temperature": temperature, "top_p": top_p, "top_k": top_k, "repetition_penalty": repetition_penalty}
-    thread = Thread(target=model_q3vl.generate, kwargs=generation_kwargs)
-    thread.start()
-    buffer = ""
-    for new_text in streamer:
-        buffer += new_text
-        buffer = buffer.replace("<|im_end|>", "")
-        time.sleep(0.01)
-        yield buffer, buffer
 
-@spaces.GPU
-def generate_pdf(text: str, state: Dict[str, Any], max_new_tokens: int = 2048, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
-    if not state or not state["pages"]:
-        yield "Please upload a PDF file first.", "Please upload a PDF file first."
+    # Handle MiniCPM-V-4 separately
+    if model_name == "openbmb/MiniCPM-V-4":
+        images = [frame for frame, ts in frames_with_ts]
+        # For video, the prompt includes the text and then all the image frames
+        content = [text] + images
+        msgs = [{'role': 'user', 'content': content}]
+        try:
+            # The .chat API still takes a single image argument, typically the first frame
+            answer = model_v4.chat(
+                image=images[0].convert('RGB'), msgs=msgs, tokenizer=tokenizer_v4,
+                max_new_tokens=max_new_tokens, temperature=temperature,
+                top_p=top_p, repetition_penalty=repetition_penalty,
+            )
+            yield answer, answer
+        except Exception as e:
+            yield f"Error: {e}", f"Error: {e}"
         return
-    page_images = state["pages"]
-    full_response = ""
-    for i, image in enumerate(page_images):
-        page_header = f"--- Page {i+1}/{len(page_images)} ---\n"
-        yield full_response + page_header, full_response + page_header
-        messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": text}]}]
-        prompt_full = processor_q3vl.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        inputs = processor_q3vl(text=[prompt_full], images=[image], return_tensors="pt", padding=True).to(device)
-        streamer = TextIteratorStreamer(processor_q3vl, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-        thread = Thread(target=model_q3vl.generate, kwargs=generation_kwargs)
-        thread.start()
-        page_buffer = ""
-        for new_text in streamer:
-            page_buffer += new_text
-            yield full_response + page_header + page_buffer, full_response + page_header + page_buffer
-            time.sleep(0.01)
-        full_response += page_header + page_buffer + "\n\n"
 
-@spaces.GPU
-def generate_caption(image: Image.Image, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
-    if image is None:
-        yield "Please upload an image to caption.", "Please upload an image to caption."
+    # Use the dictionary for other models
+    if model_name not in MODELS:
+        yield "Invalid model selected.", "Invalid model selected."
         return
-    system_prompt = (
-        "You are an AI assistant that rigorously follows this response protocol: For every input image, your primary "
-        "task is to write a precise caption that captures the essence of the image in clear, concise, and contextually "
-        "accurate language. Along with the caption, provide a structured set of attributes describing the visual "
-        "elements, including details such as objects, people, actions, colors, environment, mood, and other notable "
-        "characteristics. Ensure captions are precise, neutral, and descriptive, avoiding unnecessary elaboration or "
-        "subjective interpretation unless explicitly required. Do not reference the rules or instructions in the output; "
-        "only return the formatted caption, attributes, and class_name."
-    )
-    messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": system_prompt}]}]
-    prompt_full = processor_q3vl.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs = processor_q3vl(text=[prompt_full], images=[image], return_tensors="pt", padding=True).to(device)
-    streamer = TextIteratorStreamer(processor_q3vl, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-    thread = Thread(target=model_q3vl.generate, kwargs=generation_kwargs)
-    thread.start()
-    buffer = ""
-    for new_text in streamer:
-        buffer += new_text
-        time.sleep(0.01)
-        yield buffer, buffer
+    processor, model = MODELS[model_name]
 
-@spaces.GPU
-def generate_gif(text: str, gif_path: str, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
-    if gif_path is None:
-        yield "Please upload a GIF.", "Please upload a GIF."
-        return
-    frames = extract_gif_frames(gif_path)
-    if not frames:
-        yield "Could not process GIF.", "Could not process GIF."
-        return
+    # Prepare messages for Qwen-style models
     messages = [{"role": "user", "content": [{"type": "text", "text": text}]}]
-    for frame in frames:
-        messages[0]["content"].insert(0, {"type": "image"})
-    prompt_full = processor_q3vl.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs = processor_q3vl(text=[prompt_full], images=frames, return_tensors="pt", padding=True).to(device)
-    streamer = TextIteratorStreamer(processor_q3vl, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens, "do_sample": True, "temperature": temperature, "top_p": top_p, "top_k": top_k, "repetition_penalty": repetition_penalty}
-    thread = Thread(target=model_q3vl.generate, kwargs=generation_kwargs)
+    images_for_processor = []
+    for frame, timestamp in frames_with_ts:
+        messages[0]["content"].append({"type": "image", "image": frame})
+        images_for_processor.append(frame)
+
+    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    inputs = processor(
+        text=[prompt_full], images=images_for_processor, return_tensors="pt", padding=True,
+        truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH
+    ).to(device)
+    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+    generation_kwargs = {
+        **inputs, "streamer": streamer, "max_new_tokens": max_new_tokens,
+        "do_sample": True, "temperature": temperature, "top_p": top_p,
+        "top_k": top_k, "repetition_penalty": repetition_penalty,
+    }
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
     buffer = ""
     for new_text in streamer:
@@ -372,62 +249,46 @@ def generate_gif(text: str, gif_path: str, max_new_tokens: int = 1024, temperatu
         buffer = buffer.replace("<|im_end|>", "")
         time.sleep(0.01)
         yield buffer, buffer
-
-image_examples = [["Perform OCR on the image precisely and reconstruct it correctly...", "examples/images/1.jpg"],
-                  ["Caption the image. Describe the safety measures shown in the image. Conclude whether the situation is (safe or unsafe)...", "examples/images/2.jpg"],
-                  ["Solve the problem...", "examples/images/3.png"]]
-video_examples = [["Explain the Ad video in detail.", "examples/videos/1.mp4"],
-                  ["Explain the video in detail.", "examples/videos/2.mp4"]]
-pdf_examples = [["Extract the content precisely.", "examples/pdfs/doc1.pdf"],
-                ["Analyze and provide a short report.", "examples/pdfs/doc2.pdf"]]
-gif_examples = [["Describe this GIF.", "examples/gifs/1.gif"],
-                ["Describe this GIF.", "examples/gifs/2.gif"]]
-caption_examples = [["https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/candy.JPG"],
-                    ["examples/captions/2.png"], ["examples/captions/3.png"]]
-
-with gr.Blocks(theme=thistle_theme, css=css) as demo:
-    pdf_state = gr.State(value=get_initial_pdf_state())
-    gr.Markdown("# **Qwen-3VL:Multimodal**", elem_id="main-title")
+
+
+# Define examples for image and video inference
+image_examples = [
+    ["Describe the safety measures in the image. Conclude (Safe / Unsafe)..", "images/5.jpg"],
+    ["Convert this page to doc [markdown] precisely.", "images/3.png"],
+    ["Convert this page to doc [markdown] precisely.", "images/4.png"],
+    ["Explain the creativity in the image.", "images/6.jpg"],
+    ["Convert this page to doc [markdown] precisely.", "images/1.png"],
+    ["Convert chart to OTSL.", "images/2.png"]
+]
+
+video_examples = [
+    ["Explain the video in detail.", "videos/2.mp4"],
+    ["Explain the ad in detail.", "videos/1.mp4"]
+]
+
+css = """
+.submit-btn { background-color: #2980b9 !important; color: white !important; }
+.submit-btn:hover { background-color: #3498db !important; }
+.canvas-output { border: 2px solid #4682B4; border-radius: 10px; padding: 20px; }
+"""
+
+# Create the Gradio Interface
+with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
+    gr.Markdown("# **[Multimodal VLM Thinking](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
     with gr.Row():
-        with gr.Column(scale=2):
+        with gr.Column():
             with gr.Tabs():
                 with gr.TabItem("Image Inference"):
                     image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
                     image_upload = gr.Image(type="pil", label="Image", height=290)
-                    image_submit = gr.Button("Submit", variant="primary")
+                    image_submit = gr.Button("Submit", elem_classes="submit-btn")
                     gr.Examples(examples=image_examples, inputs=[image_query, image_upload])
-
                 with gr.TabItem("Video Inference"):
                     video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
                     video_upload = gr.Video(label="Video", height=290)
-                    video_submit = gr.Button("Submit", variant="primary")
+                    video_submit = gr.Button("Submit", elem_classes="submit-btn")
                     gr.Examples(examples=video_examples, inputs=[video_query, video_upload])
 
-                with gr.TabItem("PDF Inference"):
-                    with gr.Row():
-                        with gr.Column(scale=1):
-                            pdf_query = gr.Textbox(label="Query Input", placeholder="e.g., 'Summarize this document'")
-                            pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
-                            pdf_submit = gr.Button("Submit", variant="primary")
-                        with gr.Column(scale=1):
-                            pdf_preview_img = gr.Image(label="PDF Preview", height=290)
-                            with gr.Row():
-                                prev_page_btn = gr.Button("◀ Previous")
-                                page_info = gr.HTML('<div style="text-align:center;">No file loaded</div>')
-                                next_page_btn = gr.Button("Next ▶")
-                    gr.Examples(examples=pdf_examples, inputs=[pdf_query, pdf_upload])
-
-                with gr.TabItem("Gif Inference"):
-                    gif_query = gr.Textbox(label="Query Input", placeholder="e.g., 'What is happening in this gif?'")
-                    gif_upload = gr.Image(type="filepath", label="Upload GIF", height=290)
-                    gif_submit = gr.Button("Submit", variant="primary")
-                    gr.Examples(examples=gif_examples, inputs=[gif_query, gif_upload])
-
-                with gr.TabItem("Caption"):
-                    caption_image_upload = gr.Image(type="pil", label="Image to Caption", height=290)
-                    caption_submit = gr.Button("Generate Caption", variant="primary")
-                    gr.Examples(examples=caption_examples, inputs=[caption_image_upload])
-
             with gr.Accordion("Advanced options", open=False):
                 max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                 temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
@@ -435,31 +296,33 @@ with gr.Blocks(theme=thistle_theme, css=css) as demo:
                 top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
                 repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
 
-        with gr.Column(scale=3):
-            gr.Markdown("## Output", elem_id="output-title")
-            output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=14, show_copy_button=True)
-            with gr.Accordion("(Result.md)", open=False):
-                markdown_output = gr.Markdown(label="(Result.Md)")
-
-    image_submit.click(fn=generate_image,
-                       inputs=[image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-                       outputs=[output, markdown_output])
-    video_submit.click(fn=generate_video,
-                       inputs=[video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-                       outputs=[output, markdown_output])
-    pdf_submit.click(fn=generate_pdf,
-                     inputs=[pdf_query, pdf_state, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-                     outputs=[output, markdown_output])
-    gif_submit.click(fn=generate_gif,
-                     inputs=[gif_query, gif_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-                     outputs=[output, markdown_output])
-    caption_submit.click(fn=generate_caption,
-                         inputs=[caption_image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-                         outputs=[output, markdown_output])
-
-    pdf_upload.change(fn=load_and_preview_pdf, inputs=[pdf_upload], outputs=[pdf_preview_img, pdf_state, page_info])
-    prev_page_btn.click(fn=lambda s: navigate_pdf_page("prev", s), inputs=[pdf_state], outputs=[pdf_preview_img, pdf_state, page_info])
-    next_page_btn.click(fn=lambda s: navigate_pdf_page("next", s), inputs=[pdf_state], outputs=[pdf_preview_img, pdf_state, page_info])
+        with gr.Column():
+            with gr.Column(elem_classes="canvas-output"):
+                gr.Markdown("## Output")
+                output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=5, show_copy_button=True)
+                with gr.Accordion("(Result.md)", open=False):
+                    markdown_output = gr.Markdown(label="(Result.Md)")
+            model_choice = gr.Radio(
+                choices=["Lumian-VLR-7B-Thinking", "openbmb/MiniCPM-V-4", "Typhoon-OCR-3B", "DREX-062225-7B-exp", "olmOCR-7B-0225-preview"],
+                label="Select Model",
+                value="Lumian-VLR-7B-Thinking"
+            )
+            gr.Markdown("**Model Info 💻** | [Report Bug](https://huggingface.co/spaces/prithivMLmods/Multimodal-VLM-Thinking/discussions)")
+            gr.Markdown("> [MiniCPM-V 4.0](https://huggingface.co/openbmb/MiniCPM-V-4) is the latest efficient model in the MiniCPM-V series. The model is built based on SigLIP2-400M and MiniCPM4-3B with a total of 4.1B parameters. It inherits the strong single-image, multi-image and video understanding performance of MiniCPM-V 2.6 with largely improved efficiency. [Lumian-VLR-7B-Thinking](https://huggingface.co/prithivMLmods/Lumian-VLR-7B-Thinking) is a high-fidelity vision-language reasoning model built on Qwen2.5-VL-7B-Instruct, designed for fine-grained multimodal understanding, video reasoning, and document comprehension through explicit grounded reasoning.")
+            gr.Markdown("> [olmOCR-7B-0225-preview](https://huggingface.co/allenai/olmOCR-7B-0225-preview) is a 7B parameter open large model designed for OCR tasks with robust text extraction, especially in complex document layouts. [Typhoon-ocr-3b](https://huggingface.co/scb10x/typhoon-ocr-3b) is a 3B parameter OCR model optimized for efficient and accurate optical character recognition in challenging conditions.")
+            gr.Markdown("> [DREX-062225-exp](https://huggingface.co/prithivMLmods/DREX-062225-exp) is an experimental multimodal model emphasizing strong document reading and extraction capabilities combined with vision-language understanding to support detailed document parsing and reasoning tasks.")
+            gr.Markdown("> ⚠️ Note: Video inference performance can vary significantly between models.")
+
+    image_submit.click(
+        fn=generate_image,
+        inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+        outputs=[output, markdown_output]
+    )
+    video_submit.click(
+        fn=generate_video,
+        inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+        outputs=[output, markdown_output]
+    )
 
 if __name__ == "__main__":
     demo.queue(max_size=50).launch(mcp_server=True, ssr_mode=False, show_error=True)
 