prithivMLmods committed on
Commit
fb3d061
Β·
verified Β·
1 Parent(s): 636fdab

update app [..]

Browse files
Files changed (1) hide show
  1. app.py +398 -55
app.py CHANGED
@@ -1,67 +1,410 @@
1
  import gradio as gr
2
- import requests
3
- import spaces
4
- from PIL import Image
5
- from transformers import AutoProcessor, AutoModelForSeq2SeqLM
6
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
 
 
 
 
 
 
 
8
 
9
- # Constants
10
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
11
 
12
- # --- Model Loading ---
13
- # Loading all three T5Gemma-2 variants
14
- print("Loading google/t5gemma-2-4b-4b...")
15
- MODEL_ID_4B = "google/t5gemma-2-4b-4b"
16
- processor_4b = AutoProcessor.from_pretrained(MODEL_ID_4B, trust_remote_code=True)
17
- model_4b = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID_4B, torch_dtype=torch.float16, trust_remote_code=True).to(device).eval()
18
 
19
- print("Loading google/t5gemma-2-1b-1b...")
20
- MODEL_ID_1B = "google/t5gemma-2-1b-1b"
21
- processor_1b = AutoProcessor.from_pretrained(MODEL_ID_1B, trust_remote_code=True)
22
- model_1b = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID_1B, torch_dtype=torch.float16, trust_remote_code=True).to(device).eval()
 
 
 
23
 
24
- print("Loading google/t5gemma-2-270m-270m...")
25
- MODEL_ID_270M = "google/t5gemma-2-270m-270m"
26
- processor_270m = AutoProcessor.from_pretrained(MODEL_ID_270M, trust_remote_code=True)
27
- model_270m = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID_270M, torch_dtype=torch.float16, trust_remote_code=True).to(device).eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  @spaces.GPU
30
- def process_image(model_choice, image):
31
- # Determine which model and processor to use based on the model_choice
32
- if model_choice == "t5gemma-2-4b-4b":
33
- processor = processor_4b
34
- model = model_4b
35
- elif model_choice == "t5gemma-2-1b-1b":
36
- processor = processor_1b
37
- model = model_1b
38
- elif model_choice == "t5gemma-2-270m-270m":
39
- processor = processor_270m
40
- model = model_270m
41
-
42
- # Define the prompt
43
- prompt = "<start_of_image> in this image, there is"
44
-
45
- # Process the image and generate the description
46
- model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
47
- generation = model.generate(**model_inputs, max_new_tokens=20, do_sample=False)
48
- return processor.decode(generation[0])
49
 
50
- with gr.Blocks() as demo:
51
- gr.Markdown("# **t5gemma-2-Demo**")
52
- with gr.Row():
53
- model_choice = gr.Radio(
54
- choices=["t5gemma-2-4b-4b", "t5gemma-2-1b-1b", "t5gemma-2-270m-270m"],
55
- label="Select Model",
56
- value="t5gemma-2-4b-4b"
57
- )
58
- image_input = gr.Image(type="pil", label="Upload Image")
59
- text_output = gr.Textbox(label="Output")
60
- submit_btn = gr.Button("Submit")
61
- submit_btn.click(
62
- fn=process_image,
63
- inputs=[model_choice, image_input],
64
- outputs=text_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
- demo.launch()
 
 
1
  import gradio as gr
 
 
 
 
2
  import torch
3
+ from transformers import AutoProcessor, AutoModelForImageTextToText
4
+ from PIL import Image, ImageDraw
5
+ import numpy as np
6
+ import spaces
7
+ import cv2
8
+ import re
9
+ import os
10
+ from molmo_utils import process_vision_info
11
+
12
+ from typing import Iterable
13
+ from gradio.themes import Soft
14
+ from gradio.themes.utils import colors, fonts, sizes
15
+
16
# Register a custom "orange_red" palette on Gradio's color registry so it can
# be referenced as `colors.orange_red` by the theme class below.
# Shades run light (c50) to dark (c950) around the #FF4500 (orangered) anchor.
colors.orange_red = colors.Color(
    name="orange_red",
    c50="#FFF0E5",
    c100="#FFE0CC",
    c200="#FFC299",
    c300="#FFA366",
    c400="#FF8533",
    c500="#FF4500",
    c600="#E63E00",
    c700="#CC3700",
    c800="#B33000",
    c900="#992900",
    c950="#802200",
)
30
+
31
class OrangeRedTheme(Soft):
    """Soft-based Gradio theme with an orange-red accent.

    Uses gray as the primary hue, the custom `orange_red` palette (registered
    above) as the secondary hue, and slate neutrals. `set()` overrides apply
    gradient backgrounds/buttons and heavier block borders/shadows.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.orange_red,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # Fine-grained CSS-variable overrides on top of the Soft defaults.
        super().set(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_secondary_text_color="black",
            button_secondary_text_color_hover="white",
            button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
            button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
            button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
            button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
            slider_color="*secondary_500",
            slider_color_dark="*secondary_600",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
            button_large_padding="11px",
            color_accent_soft="*primary_100",
            block_label_background_fill="*primary_200",
        )
81
+
82
# Theme instance used by the demo UI at launch.
orange_red_theme = OrangeRedTheme()

# NOTE(review): the UI text below advertises Molmo2 and the pipeline imports
# molmo_utils, but this checkpoint is a Qwen3-VL derivative — confirm which
# model is actually intended here.
MODEL_ID = "prithivMLmods/Qwen3-VL-4B-Instruct-abliterated-v1"

print(f"Loading {MODEL_ID}...")
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    dtype="auto",
    device_map="auto"
)

model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    dtype="auto",
    device_map="auto"
)
print("Model loaded successfully.")

# Regexes for parsing pointing/tracking output:
# - COORD_REGEX captures the coords="..." payload of a <points .../> or
#   <tracks .../> tag emitted by the model.
# - FRAME_REGEX splits that payload into (frame_or_image_id, number-run) pairs.
# - POINTS_REGEX extracts (point_id, x, y) triples from a number run; x/y are
#   3-4 digit values scaled by 1000 (see _points_from_num_str).
COORD_REGEX = re.compile(rf"<(?:points|tracks).*? coords=\"([0-9\t:;, .]+)\"/?>")
FRAME_REGEX = re.compile(rf"(?:^|\t|:|,|;)([0-9\.]+) ([0-9\. ]+)")
POINTS_REGEX = re.compile(r"([0-9]+) ([0-9]{3,4}) ([0-9]{3,4})")
 
 
 
105
 
106
def _points_from_num_str(text, image_w, image_h):
    """Yield (point_id, x, y) tuples parsed from a numeric point string.

    Coordinates in *text* are scaled by 1000; they are mapped back to pixel
    space using the given image dimensions, and points that fall outside the
    image bounds are silently dropped.
    """
    for match in POINTS_REGEX.finditer(text):
        point_id = match.group(1)
        px = float(match.group(2)) / 1000 * image_w
        py = float(match.group(3)) / 1000 * image_h
        if 0 <= px <= image_w and 0 <= py <= image_h:
            yield point_id, px, py
113
 
114
def extract_multi_image_points(text, image_w, image_h, extract_ids=False):
    """Extract pointing coordinates for images.

    Args:
        text: Model output that may contain <points|tracks coords="..."/> tags.
        image_w, image_h: Either scalar dimensions (shared by all images) or
            parallel lists/tuples of per-image dimensions.
        extract_ids: If True, include the per-point id in each tuple.

    Returns:
        List of (image_id, x, y) tuples — or (image_id, point_id, x, y) when
        extract_ids is True. image_id is an int (1-based image index) when
        per-image dimensions were given, otherwise a float.
    """
    all_points = []
    # Handle list of dimensions for multi-image input.
    if isinstance(image_w, (list, tuple)) and isinstance(image_h, (list, tuple)):
        assert len(image_w) == len(image_h)
        diff_res = True
    else:
        diff_res = False

    for coord in COORD_REGEX.finditer(text):
        for point_grp in FRAME_REGEX.finditer(coord.group(1)):
            # For images, frame_id corresponds to the image index
            # (presumably 1-based in the model output — see bounds check below).
            frame_id = int(point_grp.group(1)) if diff_res else float(point_grp.group(1))

            if diff_res:
                # Convert 1-based index to 0-based and bounds-check; skip
                # groups that reference a non-existent image.
                idx = int(frame_id) - 1
                if 0 <= idx < len(image_w):
                    w, h = (image_w[idx], image_h[idx])
                else:
                    continue
            else:
                w, h = (image_w, image_h)

            for idx, x, y in _points_from_num_str(point_grp.group(2), w, h):
                if extract_ids:
                    all_points.append((frame_id, idx, x, y))
                else:
                    all_points.append((frame_id, x, y))
    return all_points
145
+
146
def extract_video_points(text, image_w, image_h, extract_ids=False):
    """Extract video pointing coordinates as (t, x, y) tuples.

    `t` is the frame id / timestamp emitted by the model; when `extract_ids`
    is set, each tuple becomes (t, point_id, x, y). Coordinates are rescaled
    from the model's 0-1000 space to (image_w, image_h).
    """
    results = []
    for coord_match in COORD_REGEX.finditer(text):
        payload = coord_match.group(1)
        for group_match in FRAME_REGEX.finditer(payload):
            # Usually a timestamp in seconds, sometimes a raw frame index.
            timestamp = float(group_match.group(1))
            for pid, px, py in _points_from_num_str(group_match.group(2), image_w, image_h):
                if extract_ids:
                    results.append((timestamp, pid, px, py))
                else:
                    results.append((timestamp, px, py))
    return results
159
+
160
def draw_points_on_images(images, points):
    """Return copies of *images* with point markers drawn on them.

    `points` holds (image_index_1_based, x, y) tuples; entries whose index
    falls outside the gallery are ignored. The originals are not mutated.
    """
    annotated = [image.copy() for image in images]

    for point in points:
        target = int(point[0]) - 1  # 1-based model index -> 0-based list index
        px, py = point[1], point[2]
        if not (0 <= target < len(annotated)):
            continue
        canvas = ImageDraw.Draw(annotated[target])
        radius = 10
        # Red outlined circle plus a small "target" label to its right.
        canvas.ellipse(
            (px - radius, py - radius, px + radius, py + radius),
            outline="red",
            width=3,
        )
        canvas.text((px + radius, py), "target", fill="red")

    return annotated
177
+
178
def draw_points_on_video(video_path, points, original_width, original_height):
    """Render point annotations onto a video and write it to disk.

    Args:
        video_path: Path to the source video file.
        points: (timestamp_seconds, x, y) triples in the coordinate space of
            (original_width, original_height).
        original_width, original_height: Dimensions the model's coordinates
            refer to (from the sampled-video metadata); used to rescale the
            points onto the actual file resolution.

    Returns:
        Path of the annotated output video ("annotated_video.mp4").
    """
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps or fps <= 0:
        # Some containers report 0 fps; fall back to a sane default so the
        # timestamp -> frame mapping and VideoWriter still work.
        fps = 30.0
    vid_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    vid_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Scale factor between the resolution the model saw (metadata) and the
    # actual file resolution; guard against zero to avoid ZeroDivisionError.
    scale_x = vid_w / original_width if original_width else 1.0
    scale_y = vid_h / original_height if original_height else 1.0

    # Bucket points by target frame index (frame = round(timestamp * fps))
    # for O(1) lookup while streaming frames.
    points_by_frame = {}
    for t, x, y in points:
        f_idx = int(round(t * fps))
        points_by_frame.setdefault(f_idx, []).append((x * scale_x, y * scale_y))

    output_path = "annotated_video.mp4"
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (vid_w, vid_h))

    current_frame = 0
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # Draw only on exact-match frames; markers do not persist across
            # neighbouring frames.
            for px, py in points_by_frame.get(current_frame, ()):
                cv2.circle(frame, (int(px), int(py)), 10, (0, 0, 255), -1)
                cv2.circle(frame, (int(px), int(py)), 12, (255, 255, 255), 2)
            out.write(frame)
            current_frame += 1
    finally:
        # Release capture/writer even if reading or writing raises.
        cap.release()
        out.release()
    return output_path
228
 
229
@spaces.GPU
def process_images(user_text, input_images):
    """Run image QA / pointing on one or more uploaded images.

    Args:
        user_text: The user prompt (e.g. "Point to the cats.").
        input_images: Items from a gr.Gallery with type="filepath" — plain
            path strings, or (path, caption) tuples on older Gradio versions.

    Returns:
        (generated_text, gallery) where gallery is the list of input images,
        annotated with point markers if the model emitted coordinates.
    """
    if not input_images:
        return "Please upload at least one image.", None

    # Normalise gallery items to RGB PIL images (tuple form carries a caption).
    pil_images = []
    for img_path in input_images:
        if isinstance(img_path, tuple):
            img_path = img_path[0]
        pil_images.append(Image.open(img_path).convert("RGB"))

    # Build a single-turn chat message: prompt text followed by the images.
    content = [dict(type="text", text=user_text)]
    for img in pil_images:
        content.append(dict(type="image", image=img))

    messages = [{"role": "user", "content": content}]

    # Tokenize via the model's chat template and move tensors to its device.
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Generate and decode only the tokens produced after the prompt.
    with torch.inference_mode():
        generated_ids = model.generate(**inputs, max_new_tokens=1024)

    generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
    generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    # If the reply contains point coordinates, draw them on copies of the
    # inputs; otherwise echo the originals back to the output gallery.
    widths = [img.width for img in pil_images]
    heights = [img.height for img in pil_images]

    points = extract_multi_image_points(generated_text, widths, heights)

    output_gallery = pil_images
    if points:
        output_gallery = draw_points_on_images(pil_images, points)

    return generated_text, output_gallery
279
+
280
@spaces.GPU
def process_video(user_text, video_path):
    """Run video QA / pointing / tracking on an uploaded video.

    Args:
        user_text: The user prompt (e.g. "Track the player").
        video_path: Filesystem path to the uploaded video.

    Returns:
        (generated_text, video) where video is an annotated copy when the
        model emitted (t, x, y) points, otherwise the original path.
    """
    if not video_path:
        return "Please upload a video.", None

    # Single-turn chat message: prompt text plus the video reference.
    # NOTE(review): comment in the original says the model expects a URL or a
    # readable path — confirm against molmo_utils.process_vision_info.
    messages = [
        {
            "role": "user",
            "content": [
                dict(type="text", text=user_text),
                dict(type="video", video=video_path),
            ],
        }
    ]

    # Sample the video frames and collect per-video metadata/tensors.
    _, videos, video_kwargs = process_vision_info(messages)
    videos, video_metadatas = zip(*videos)
    videos, video_metadatas = list(videos), list(video_metadatas)

    # Render the chat template to text (tokenization happens in processor()).
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Combine text + sampled video frames into model inputs.
    inputs = processor(
        videos=videos,
        video_metadata=video_metadatas,
        text=text,
        padding=True,
        return_tensors="pt",
        **video_kwargs,
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Generate and decode only the newly produced tokens.
    with torch.inference_mode():
        generated_ids = model.generate(**inputs, max_new_tokens=2048)

    generated_tokens = generated_ids[0, inputs['input_ids'].size(1):]
    generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    # Parse any (t, x, y) point/track coordinates relative to the metadata
    # resolution of the first (assumed only) video.
    vid_meta = video_metadatas[0]
    points = extract_video_points(generated_text, image_w=vid_meta["width"], image_h=vid_meta["height"])

    annotated_video_path = None
    if points:
        print(f"Found {len(points)} points/track-coords. Annotating video...")
        annotated_video_path = draw_points_on_video(
            video_path,
            points,
            original_width=vid_meta["width"],
            original_height=vid_meta["height"]
        )

    # Return original video if no points found, otherwise the annotated copy.
    out_vid = annotated_video_path if annotated_video_path else video_path

    return generated_text, out_vid
342
+
343
+ css="""
344
+ #col-container {
345
+ margin: 0 auto;
346
+ max-width: 960px;
347
+ }
348
+ #main-title h1 {font-size: 2.1em !important;}
349
+ """
350
+
351
+ with gr.Blocks() as demo:
352
+ gr.Markdown("# **Molmo2 HF DemoπŸ–₯️**", elem_id="main-title")
353
+ gr.Markdown("Perform multi-image QA, pointing, general video QA, and tracking using the [Molmo2](https://huggingface.co/allenai/Molmo2-8B) multimodal model.")
354
+
355
+ with gr.Tabs():
356
+ with gr.Tab("Images (QA & Pointing)"):
357
+ with gr.Row():
358
+ with gr.Column():
359
+ img_input = gr.Gallery(label="Input Images", type="filepath", height=400)
360
+ img_prompt = gr.Textbox(label="Prompt", placeholder="e.g. 'Describe this' or 'Point to the boats'")
361
+ img_btn = gr.Button("Run Image Analysis", variant="primary")
362
+
363
+ with gr.Column():
364
+ img_text_out = gr.Textbox(label="Generated Text", interactive=True, lines=5)
365
+ img_out = gr.Gallery(label="Annotated Images (Pointing if applicable)", height=378)
366
+
367
+ gr.Examples(
368
+ examples=[
369
+ [["example-images/compare1.jpg", "example-images/compare2.jpeg"], "Compare these two images."],
370
+ [["example-images/cat1.jpg", "example-images/cat2.jpg", "example-images/dog1.jpg"], "Point to the cats."],
371
+ [["example-images/candy.JPG"], "Point to all the candies."],
372
+ [["example-images/premium_photo-1691752881339-d78da354ee7e.jpg"], "Point to the girls."],
373
+ ],
374
+ inputs=[img_input, img_prompt],
375
+ label="Image Examples"
376
+ )
377
+ img_btn.click(
378
+ fn=process_images,
379
+ inputs=[img_prompt, img_input],
380
+ outputs=[img_text_out, img_out]
381
+ )
382
+
383
+ with gr.Tab("Video (QA, Pointing & Tracking)"):
384
+ gr.Markdown("**Note:** Video processing takes longer as frames are sampled.")
385
+ with gr.Row():
386
+ with gr.Column():
387
+ vid_input = gr.Video(label="Input Video", format="mp4", height=400)
388
+ vid_prompt = gr.Textbox(label="Prompt", placeholder="e.g. 'What is happening?' or 'Track the player'")
389
+ vid_btn = gr.Button("Run Video Analysis", variant="primary")
390
+
391
+ with gr.Column():
392
+ vid_text_out = gr.Textbox(label="Generated Text", interactive=True, lines=5)
393
+ vid_out = gr.Video(label="Output Video (Annotated if applicable)", height=378)
394
+
395
+ gr.Examples(
396
+ examples=[
397
+ ["example-videos/sample_video.mp4", "Track the football."],
398
+ ["example-videos/drink.mp4", "Explain the video."],
399
+ ],
400
+ inputs=[vid_input, vid_prompt],
401
+ label="Video Examples"
402
+ )
403
+ vid_btn.click(
404
+ fn=process_video,
405
+ inputs=[vid_prompt, vid_input],
406
+ outputs=[vid_text_out, vid_out]
407
+ )
408
 
409
+ if __name__ == "__main__":
410
+ demo.launch(theme=orange_red_theme, css=css, mcp_server=True, ssr_mode=False, show_error=True)