prithivMLmods commited on
Commit
5a34082
·
verified ·
1 Parent(s): f17e9df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +192 -75
app.py CHANGED
@@ -1,15 +1,43 @@
1
  import os
 
2
  import spaces
3
  import gradio as gr
4
  import numpy as np
5
  import torch
6
- import random
7
- from PIL import Image, ImageDraw
 
 
8
  from typing import Iterable
9
  from gradio.themes import Soft
10
  from gradio.themes.utils import colors, fonts, sizes
11
  from transformers import Sam3Processor, Sam3Model
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  colors.steel_blue = colors.Color(
14
  name="steel_blue",
15
  c50="#EBF3F8",
@@ -75,119 +103,208 @@ steel_blue_theme = SteelBlueTheme()
75
  device = "cuda" if torch.cuda.is_available() else "cpu"
76
  print(f"Using device: {device}")
77
 
 
 
 
 
 
78
  try:
79
  print("Loading SAM3 Model and Processor...")
80
- model = Sam3Model.from_pretrained("facebook/sam3").to(device)
81
- processor = Sam3Processor.from_pretrained("facebook/sam3")
82
- print("Model loaded successfully.")
83
-
84
  except Exception as e:
85
- print(f"Error loading model: {e}")
86
- print("Ensure you have the correct libraries installed and access to the model.")
87
- # Fallback/Placeholder for demonstration if model doesn't exist in environment yet
88
- model = None
89
- processor = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  @spaces.GPU
92
  def segment_image(input_image, text_prompt, threshold=0.5):
 
93
  if input_image is None:
94
  raise gr.Error("Please upload an image.")
95
  if not text_prompt:
96
- raise gr.Error("Please enter a text prompt (e.g., 'cat', 'face').")
97
-
98
- if model is None or processor is None:
99
- raise gr.Error("Model not loaded correctly.")
100
 
101
- # Convert image to RGB
102
  image_pil = input_image.convert("RGB")
 
103
 
104
- # Preprocess
105
- inputs = processor(images=image_pil, text=text_prompt, return_tensors="pt").to(device)
106
-
107
- # Inference
108
  with torch.no_grad():
109
- outputs = model(**inputs)
110
 
111
- # Post-process results
112
- results = processor.post_process_instance_segmentation(
113
  outputs,
114
  threshold=threshold,
115
  mask_threshold=0.5,
116
  target_sizes=inputs.get("original_sizes").tolist()
117
  )[0]
118
 
119
- masks = results['masks'] # Boolean tensor [N, H, W]
120
- scores = results['scores']
121
-
122
- # Prepare for Gradio AnnotatedImage
123
- # Gradio expects (image, [(mask, label), ...])
124
 
125
  annotations = []
126
- masks_np = masks.cpu().numpy()
127
- scores_np = scores.cpu().numpy()
128
-
129
- for i, mask in enumerate(masks_np):
130
- # mask is a boolean array (True/False).
131
- # AnnotatedImage handles the coloring automatically.
132
- # We just pass the mask and a label.
133
- score_val = scores_np[i]
134
- label = f"{text_prompt} ({score_val:.2f})"
135
  annotations.append((mask, label))
136
 
137
- # Return tuple format for AnnotatedImage
138
  return (image_pil, annotations)
139
 
140
- css="""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  #col-container {
142
  margin: 0 auto;
143
- max-width: 980px;
144
  }
145
- #main-title h1 {font-size: 2.1em !important;}
146
  """
147
 
148
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
149
  with gr.Column(elem_id="col-container"):
150
- gr.Markdown(
151
- "# **SAM3 Image Segmentation**",
152
- elem_id="main-title"
153
- )
154
 
155
- gr.Markdown("Segment objects in images using **SAM3** (Segment Anything Model 3) with text prompts.")
156
-
157
- with gr.Row():
158
- with gr.Column(scale=1):
159
- input_image = gr.Image(label="Input Image", type="pil", height=300)
160
- text_prompt = gr.Textbox(
161
- label="Text Prompt",
162
- placeholder="e.g., cat, ear, car wheel...",
 
 
 
 
 
 
 
 
 
 
 
163
  )
164
 
165
- run_button = gr.Button("Segment", variant="primary")
 
 
 
 
 
 
 
166
 
167
- with gr.Column(scale=1.5):
168
- output_image = gr.AnnotatedImage(label="Segmented Output", height=380)
 
169
 
170
  with gr.Row():
171
- threshold = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, value=0.4, step=0.05)
172
-
173
- gr.Examples(
174
- examples=[
175
- ["examples/player.jpg", "player in white", 0.5],
176
- ["examples/goldencat.webp", "black cat", 0.4],
177
- ["examples/taxi.jpg", "blue taxi", 0.5],
178
- ],
179
- inputs=[input_image, text_prompt, threshold],
180
- outputs=[output_image],
181
- fn=segment_image,
182
- cache_examples="lazy",
183
- label="Examples"
184
- )
185
 
186
- run_button.click(
187
- fn=segment_image,
188
- inputs=[input_image, text_prompt, threshold],
189
- outputs=[output_image]
190
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
  if __name__ == "__main__":
193
  demo.launch(mcp_server=True, ssr_mode=False, show_error=True)
 
1
  import os
2
+ import sys
3
  import spaces
4
  import gradio as gr
5
  import numpy as np
6
  import torch
7
+ import cv2
8
+ import tempfile
9
+ import shutil
10
+ from PIL import Image
11
  from typing import Iterable
12
  from gradio.themes import Soft
13
  from gradio.themes.utils import colors, fonts, sizes
14
  from transformers import Sam3Processor, Sam3Model
15
 
16
+ # ---------------------------------------------------------
17
+ # 1. SETUP PATHS & CUSTOM IMPORTS
18
+ # ---------------------------------------------------------
19
+ # Attempt to import the specific utils provided in your snippet
20
+ try:
21
+ # Adjust path to find utils.py (assuming it's in parent dir based on your snippet)
22
+ parent_dir = os.path.dirname(os.getcwd())
23
+ if parent_dir not in sys.path:
24
+ sys.path.insert(0, parent_dir)
25
+
26
+ from utils import (
27
+ setup_sam_3d_body,
28
+ setup_visualizer,
29
+ visualize_2d_results,
30
+ visualize_3d_mesh,
31
+ save_mesh_results
32
+ )
33
+ SAM3D_AVAILABLE = True
34
+ except ImportError as e:
35
+ print(f"Warning: SAM 3D Body utils not found ({e}). The 3D Body tab will use placeholder logic.")
36
+ SAM3D_AVAILABLE = False
37
+
38
+ # ---------------------------------------------------------
39
+ # 2. THEME DEFINITION
40
+ # ---------------------------------------------------------
41
  colors.steel_blue = colors.Color(
42
  name="steel_blue",
43
  c50="#EBF3F8",
 
103
  device = "cuda" if torch.cuda.is_available() else "cpu"
104
  print(f"Using device: {device}")
105
 
106
+ # ---------------------------------------------------------
107
+ # 3. MODEL LOADING
108
+ # ---------------------------------------------------------
109
+
110
+ # --- Load SAM3 (Segmentation) ---
111
  try:
112
  print("Loading SAM3 Model and Processor...")
113
+ sam3_model = Sam3Model.from_pretrained("facebook/sam3").to(device)
114
+ sam3_processor = Sam3Processor.from_pretrained("facebook/sam3")
115
+ print("SAM3 Model loaded successfully.")
 
116
  except Exception as e:
117
+ print(f"Error loading SAM3 model: {e}")
118
+ sam3_model = None
119
+ sam3_processor = None
120
+
121
+ # --- Load SAM 3D Body ---
122
+ sam3d_estimator = None
123
+ sam3d_visualizer = None
124
+
125
+ if SAM3D_AVAILABLE:
126
+ try:
127
+ print("Loading SAM 3D Body Estimator...")
128
+ sam3d_estimator = setup_sam_3d_body(hf_repo_id="facebook/sam-3d-body-dinov3")
129
+ sam3d_visualizer = setup_visualizer()
130
+ print("SAM 3D Body Model loaded successfully.")
131
+ except Exception as e:
132
+ print(f"Error loading SAM 3D Body model: {e}")
133
+
134
+ # ---------------------------------------------------------
135
+ # 4. INFERENCE FUNCTIONS
136
+ # ---------------------------------------------------------
137
 
138
  @spaces.GPU
139
  def segment_image(input_image, text_prompt, threshold=0.5):
140
+ """Function for Tab 1: SAM3 Segmentation"""
141
  if input_image is None:
142
  raise gr.Error("Please upload an image.")
143
  if not text_prompt:
144
+ raise gr.Error("Please enter a text prompt.")
145
+ if sam3_model is None or sam3_processor is None:
146
+ raise gr.Error("SAM3 Model not loaded correctly.")
 
147
 
 
148
  image_pil = input_image.convert("RGB")
149
+ inputs = sam3_processor(images=image_pil, text=text_prompt, return_tensors="pt").to(device)
150
 
 
 
 
 
151
  with torch.no_grad():
152
+ outputs = sam3_model(**inputs)
153
 
154
+ results = sam3_processor.post_process_instance_segmentation(
 
155
  outputs,
156
  threshold=threshold,
157
  mask_threshold=0.5,
158
  target_sizes=inputs.get("original_sizes").tolist()
159
  )[0]
160
 
161
+ masks = results['masks'].cpu().numpy()
162
+ scores = results['scores'].cpu().numpy()
 
 
 
163
 
164
  annotations = []
165
+ for i, mask in enumerate(masks):
166
+ label = f"{text_prompt} ({scores[i]:.2f})"
 
 
 
 
 
 
 
167
  annotations.append((mask, label))
168
 
 
169
  return (image_pil, annotations)
170
 
171
+ @spaces.GPU
172
+ def process_3d_body(input_image):
173
+ """Function for Tab 2: SAM 3D Body"""
174
+ if input_image is None:
175
+ raise gr.Error("Please upload an image.")
176
+ if not SAM3D_AVAILABLE or sam3d_estimator is None:
177
+ raise gr.Error("SAM 3D Body libraries or model not available.")
178
+
179
+ # Convert PIL to CV2 BGR
180
+ img_np = np.array(input_image.convert("RGB"))
181
+ img_cv2 = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
182
+
183
+ # Save temp image for the process_one_image function if it requires a path
184
+ # (Checking the snippet provided: outputs = estimator.process_one_image(image_path))
185
+ # We need a physical path.
186
+ with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
187
+ tmp_path = tmp_file.name
188
+ cv2.imwrite(tmp_path, img_cv2)
189
+
190
+ try:
191
+ # Run Inference
192
+ outputs = sam3d_estimator.process_one_image(tmp_path)
193
+
194
+ if not outputs:
195
+ return None, None, None, "No people detected."
196
+
197
+ # 1. Generate 2D Visualization
198
+ vis_results_2d = visualize_2d_results(img_cv2, outputs, sam3d_visualizer)
199
+ # Taking the first result if multiple people, or combine them
200
+ # Converting the first result to RGB for display
201
+ res_2d_rgb = cv2.cvtColor(vis_results_2d[0], cv2.COLOR_BGR2RGB) if vis_results_2d else img_np
202
+
203
+ # 2. Generate 3D Visualization (Overlay Image)
204
+ mesh_results_img = visualize_3d_mesh(img_cv2, outputs, sam3d_estimator.faces)
205
+ res_3d_overlay_rgb = cv2.cvtColor(mesh_results_img[0], cv2.COLOR_BGR2RGB) if mesh_results_img else img_np
206
+
207
+ # 3. Save PLY Mesh to temp directory for Gradio Model3D
208
+ # Create a unique temp dir
209
+ output_dir = tempfile.mkdtemp()
210
+ image_name = "person_mesh"
211
+
212
+ # This function saves .ply files
213
+ ply_files = save_mesh_results(img_cv2, outputs, sam3d_estimator.faces, output_dir, image_name)
214
+
215
+ ply_path = None
216
+ if ply_files:
217
+ ply_path = ply_files[0] # Return the first person's mesh
218
+
219
+ status = f"Detected {len(outputs)} person(s)."
220
+
221
+ return res_2d_rgb, res_3d_overlay_rgb, ply_path, status
222
+
223
+ finally:
224
+ # Cleanup input temp file
225
+ if os.path.exists(tmp_path):
226
+ os.remove(tmp_path)
227
+
228
+
229
+ # ---------------------------------------------------------
230
+ # 5. GRADIO UI LAYOUT
231
+ # ---------------------------------------------------------
232
+
233
+ css = """
234
  #col-container {
235
  margin: 0 auto;
236
+ max-width: 1200px;
237
  }
238
+ #main-title h1 {font-size: 2.1em !important; text-align: center;}
239
  """
240
 
241
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
242
  with gr.Column(elem_id="col-container"):
243
+ gr.Markdown("# **SAM Integrated Vision Suite**", elem_id="main-title")
 
 
 
244
 
245
+ with gr.Tabs():
246
+ # ================= TAB 1: SEGMENTATION =================
247
+ with gr.Tab("SAM3 Segmentation"):
248
+ gr.Markdown("Segment objects using **SAM3** with text prompts.")
249
+
250
+ with gr.Row():
251
+ with gr.Column(scale=1):
252
+ t1_input_image = gr.Image(label="Input Image", type="pil", height=350)
253
+ t1_text_prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., cat, ear, car wheel...")
254
+ t1_threshold = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, value=0.4, step=0.05)
255
+ t1_run_btn = gr.Button("Segment Image", variant="primary")
256
+
257
+ with gr.Column(scale=1.5):
258
+ t1_output_image = gr.AnnotatedImage(label="Segmented Output", height=450)
259
+
260
+ t1_run_btn.click(
261
+ fn=segment_image,
262
+ inputs=[t1_input_image, t1_text_prompt, t1_threshold],
263
+ outputs=[t1_output_image]
264
  )
265
 
266
+ gr.Examples(
267
+ examples=[
268
+ ["examples/player.jpg", "player", 0.5],
269
+ ["examples/goldencat.webp", "cat", 0.4],
270
+ ],
271
+ inputs=[t1_input_image, t1_text_prompt, t1_threshold],
272
+ label="Segmentation Examples"
273
+ )
274
 
275
+ # ================= TAB 2: 3D BODY =================
276
+ with gr.Tab("SAM 3D Body"):
277
+ gr.Markdown("Detect human bodies and reconstruct **3D Meshes**.")
278
 
279
  with gr.Row():
280
+ with gr.Column(scale=1):
281
+ t2_input_image = gr.Image(label="Input Image", type="pil", height=350)
282
+ t2_run_btn = gr.Button("Generate 3D Body", variant="primary")
283
+ t2_status = gr.Textbox(label="Status", interactive=False)
 
 
 
 
 
 
 
 
 
 
284
 
285
+ with gr.Column(scale=2):
286
+ with gr.Row():
287
+ t2_output_2d = gr.Image(label="2D Keypoints", type="numpy")
288
+ t2_output_overlay = gr.Image(label="Mesh Overlay", type="numpy")
289
+
290
+ t2_output_3d = gr.Model3D(
291
+ label="Interactive 3D Mesh",
292
+ clear_color=[0.0, 0.0, 0.0, 0.0],
293
+ camera_position=[0, 0, 3]
294
+ )
295
+
296
+ t2_run_btn.click(
297
+ fn=process_3d_body,
298
+ inputs=[t2_input_image],
299
+ outputs=[t2_output_2d, t2_output_overlay, t2_output_3d, t2_status]
300
+ )
301
+
302
+ # Assuming examples exist in the folder
303
+ gr.Examples(
304
+ examples=[["examples/player.jpg"], ["examples/dancing.jpg"]],
305
+ inputs=[t2_input_image],
306
+ label="3D Body Examples"
307
+ )
308
 
309
  if __name__ == "__main__":
310
  demo.launch(mcp_server=True, ssr_mode=False, show_error=True)