jamepark3922 committed on
Commit
3e58654
·
0 Parent(s):

init: gui image demo

Browse files
.gitattributes ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
2
+ *.jpg filter=lfs diff=lfs merge=lfs -text
3
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
4
+ *.png filter=lfs diff=lfs merge=lfs -text
5
+ # example-images/* !text !filter !merge !diff
6
+ # example-videos/* !text !filter !merge !diff
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+
2
+ .gradio/
README.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Molmo-Point Demo
3
+ emoji: πŸ‘†
4
+ colorFrom: indigo
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 6.3.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ short_description: Molmo-Point - Image & Video Pointing & Tracking
12
+ ---
13
+
14
+ ## Acknowledgements
15
+
16
+ Parts of this demo were adapted from [Molmo2-HF-Demo](https://huggingface.co/spaces/prithivMLmods/Molmo2-HF-Demo) by prithivMLmods. Thank you for the great work!
app.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import math
3
+ import os
4
+ from collections import defaultdict
5
+
6
+ import numpy as np
7
+ import PIL
8
+ import torch
9
+ from PIL import Image, ImageDraw, ImageFile
10
+ from transformers import AutoModelForImageTextToText, AutoProcessor
11
+
12
+ import gradio as gr
13
+ import spaces
14
+ from molmo_utils import process_vision_info
15
+
16
+ from typing import Iterable
17
+ from gradio.themes import Soft
18
+ from gradio.themes.utils import colors, fonts, sizes
19
+
20
+ Image.MAX_IMAGE_PIXELS = None
21
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
22
+
23
+ # ── Constants ──────────────────────────────────────────────────────────────────
24
+
25
+ MODEL_ID = "allenai/MolmoPoint-Img-8B"
26
+ MAX_IMAGE_SIZE = 512
27
+ POINT_SIZE = 0.01
28
+ MAX_NEW_TOKENS = 2048
29
+
30
+ COLORS = [
31
+ "rgb(255, 100, 180)",
32
+ "rgb(100, 180, 255)",
33
+ "rgb(180, 255, 100)",
34
+ "rgb(255, 180, 100)",
35
+ "rgb(100, 255, 180)",
36
+ "rgb(180, 100, 255)",
37
+ "rgb(255, 255, 100)",
38
+ "rgb(100, 255, 255)",
39
+ "rgb(255, 120, 120)",
40
+ "rgb(120, 255, 255)",
41
+ "rgb(255, 255, 120)",
42
+ "rgb(255, 120, 255)",
43
+ ]
44
+
45
+ # ── Model loading ──────────────────────────────────────────────────────────────
46
+
47
+ print(f"Loading {MODEL_ID}...")
48
+ processor = AutoProcessor.from_pretrained(
49
+ MODEL_ID,
50
+ trust_remote_code=True,
51
+ padding_side="left",
52
+ )
53
+
54
+ model = AutoModelForImageTextToText.from_pretrained(
55
+ MODEL_ID,
56
+ trust_remote_code=True,
57
+ dtype="bfloat16",
58
+ device_map="auto",
59
+ )
60
+ print("Model loaded successfully.")
61
+
62
+ # ── Helper functions ───────────────────────────────────────────────────────────
63
+
64
+
65
+ def cast_float_bf16(t: torch.Tensor):
66
+ if torch.is_floating_point(t):
67
+ t = t.to(torch.bfloat16)
68
+ return t
69
+
70
+
71
+ def draw_points(image, points):
72
+ if isinstance(image, np.ndarray):
73
+ annotation = PIL.Image.fromarray(image)
74
+ else:
75
+ annotation = image.copy()
76
+ draw = ImageDraw.Draw(annotation)
77
+ w, h = annotation.size
78
+ size = max(5, int(max(w, h) * POINT_SIZE))
79
+ for i, (x, y) in enumerate(points):
80
+ color = COLORS[0]
81
+ draw.ellipse((x - size, y - size, x + size, y + size), fill=color, outline=None)
82
+ return annotation
83
+
84
+
85
+ def format_points_list(points):
86
+ """Format extracted points as a flat Python list string."""
87
+ if not points:
88
+ return "[]"
89
+ rows = []
90
+ for object_id, ix, x, y in points:
91
+ rows.append(f"[{int(object_id)}, {int(ix)}, {float(x):.1f}, {float(y):.1f}]")
92
+ return "[" + ", ".join(rows) + "]"
93
+
94
+
95
+ # ── Inference functions ────────────────────────────────────────────────────────
96
+
97
+
98
+ @spaces.GPU
99
+ def process_images(user_text, input_images, max_tokens):
100
+ if not input_images:
101
+ return "Please upload at least one image.", [], "[]"
102
+
103
+ pil_images = []
104
+ for img_path in input_images:
105
+ if isinstance(img_path, tuple):
106
+ img_path = img_path[0]
107
+ pil_images.append(Image.open(img_path).convert("RGB"))
108
+
109
+ # Build messages
110
+ content = [dict(type="text", text=user_text)]
111
+ for img in pil_images:
112
+ content.append(dict(type="image", image=img))
113
+ messages = [{"role": "user", "content": content}]
114
+
115
+ # Process inputs
116
+ images, _, _ = process_vision_info(messages)
117
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
118
+ print(f"Prompt: {text}")
119
+
120
+ inputs = processor(
121
+ images=images,
122
+ text=text,
123
+ padding=True,
124
+ return_tensors="pt",
125
+ return_pointing_metadata=True,
126
+ )
127
+ metadata = inputs.pop("metadata")
128
+ inputs = {k: cast_float_bf16(v.to(model.device)) for k, v in inputs.items()}
129
+
130
+ # Generate
131
+ with torch.inference_mode():
132
+ with torch.autocast("cuda", enabled=True, dtype=torch.bfloat16):
133
+ output = model.generate(
134
+ **inputs,
135
+ logits_processor=model.build_logit_processor_from_inputs(inputs),
136
+ max_new_tokens=int(max_tokens),
137
+ temperature=0
138
+ )
139
+
140
+ generated_tokens = output[0, inputs["input_ids"].size(1):]
141
+ generated_text = processor.decode(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)
142
+
143
+ # Extract points
144
+ points = model.extract_image_points(
145
+ generated_text,
146
+ metadata["token_pooling"],
147
+ metadata["subpatch_mapping"],
148
+ metadata["image_sizes"],
149
+ )
150
+
151
+ points_table = format_points_list(points)
152
+
153
+ print(f"Output text: {generated_text}")
154
+ print("Extracted points:", points_table)
155
+
156
+ if points:
157
+ group_by_index = defaultdict(list)
158
+ for object_id, ix, x, y in points:
159
+ group_by_index[ix].append((x, y))
160
+ annotated = []
161
+ for ix, pts in group_by_index.items():
162
+ annotated.append(draw_points(images[ix], pts))
163
+ return generated_text, annotated, points_table
164
+
165
+ return generated_text, pil_images, points_table
166
+
167
+
168
+ # ── Gradio UI ──────────────────────────────────────────────────────────────────
169
+
170
+ css = """
171
+ #col-container {
172
+ margin: 0 auto;
173
+ max-width: 960px;
174
+ }
175
+ #main-title h1 {font-size: 2.3em !important;}
176
+ #input_image image {
177
+ object-fit: contain !important;
178
+ }
179
+ .gallery-item img {
180
+ border: none !important;
181
+ outline: none !important;
182
+ }
183
+ """
184
+
185
+ with gr.Blocks() as demo:
186
+ gr.Markdown("# **MolmoPoint-Img-8B Demo (GUI-Specialized)**", elem_id="main-title")
187
+ gr.Markdown(
188
+ "Image pointing using the "
189
+ "[MolmoPoint-Img-8B](https://huggingface.co/allenai/MolmoPoint-Img-8B) model. Specialized for pointing in GUI images."
190
+ )
191
+
192
+ with gr.Row():
193
+ # ── LEFT COLUMN: Inputs ──
194
+ with gr.Column():
195
+ images_input = gr.Gallery(
196
+ label="Input Images", elem_id="input_image", type="filepath", height=MAX_IMAGE_SIZE,
197
+ )
198
+
199
+ input_text = gr.Textbox(placeholder="Enter the prompt", label="Input text")
200
+
201
+ max_tok_slider = gr.Slider(label="max_tokens", minimum=1, maximum=4096, step=1, value=MAX_NEW_TOKENS)
202
+
203
+ with gr.Row():
204
+ submit_button = gr.Button("Submit", variant="primary", scale=3)
205
+ clear_all_button = gr.ClearButton(
206
+ components=[images_input, input_text], value="Clear All", scale=1,
207
+ )
208
+
209
+ # ── RIGHT COLUMN: Outputs ──
210
+ with gr.Column():
211
+ with gr.Tabs():
212
+ with gr.TabItem("Output Text"):
213
+ output_text = gr.Textbox(placeholder="Output text", label="Output text", lines=10)
214
+ with gr.TabItem("Extracted Points"):
215
+ output_points = gr.Textbox(
216
+ label="Extracted Points ([[id, index, x, y]])", lines=15,
217
+ )
218
+
219
+ with gr.Group():
220
+ gr.Markdown("*Click a frame to zoom in. Press Esc to go back.*")
221
+ output_annotations_img = gr.Gallery(label="Annotated Images", height=MAX_IMAGE_SIZE)
222
+
223
+ # ── Examples ──
224
+ with gr.Group():
225
+ gr.Markdown("### Image Examples")
226
+ gr.Examples(
227
+ examples=[
228
+ [["example-images/boat1.jpeg", "example-images/boat2.jpeg"], "Point to the boats."],
229
+ [["example-images/messy1.jpg", "example-images/messy2.jpg", "example-images/messy3.jpg", "example-images/messy4.jpg"], "Point to the scissors."],
230
+ ],
231
+ inputs=[images_input, input_text],
232
+ label="Image Pointing Examples",
233
+ )
234
+
235
+ submit_button.click(
236
+ fn=process_images,
237
+ inputs=[input_text, images_input, max_tok_slider],
238
+ outputs=[output_text, output_annotations_img, output_points],
239
+ )
240
+
241
+ if __name__ == "__main__":
242
+ demo.launch(css=css, mcp_server=True, ssr_mode=False, show_error=True)
example-images/boat1.jpeg ADDED

Git LFS Details

  • SHA256: 1652fe0880f870989ac390c704ece359acef36ecde5b423fcc32f9181e7c374f
  • Pointer size: 132 Bytes
  • Size of remote file: 2.95 MB
example-images/boat2.jpeg ADDED

Git LFS Details

  • SHA256: 2a974235aed23bef201d15f32de63c25098c020d61e3625663fb515af8acbe3c
  • Pointer size: 132 Bytes
  • Size of remote file: 3.12 MB
example-images/messy1.jpg ADDED

Git LFS Details

  • SHA256: 0810f7ed899a9a90e49923241cf577f3712eac9e8e5d52360d54a9a0d11b079b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.1 MB
example-images/messy2.jpg ADDED

Git LFS Details

  • SHA256: 12b2e240b935d23644e311b7249410afb8025a842d38d06ebee014195d6da6a9
  • Pointer size: 131 Bytes
  • Size of remote file: 255 kB
example-images/messy3.jpg ADDED

Git LFS Details

  • SHA256: 7ef063cf698a947efd4224f4d95b36e23c7605a27b6a21bbd0d496cbd95afbfc
  • Pointer size: 131 Bytes
  • Size of remote file: 234 kB
example-images/messy4.jpg ADDED

Git LFS Details

  • SHA256: b3776aadcc39769a233bf185797b4984aa0fffdcef087fec32e4e1fb75712dec
  • Pointer size: 131 Bytes
  • Size of remote file: 164 kB
pre-requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ pip>=23.0.0
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/huggingface/transformers.git@v4.57.1
2
+ git+https://github.com/huggingface/accelerate.git
3
+ torch==2.8.0
4
+ torchvision
5
+ pillow
6
+ einops
7
+ decord2
8
+ molmo_utils
9
+ opencv-python
10
+ numpy
11
+ gradio
12
+ spaces
13
+ kernels
14
+ hf_xet