prithivMLmods commited on
Commit
05bb57b
·
verified ·
1 Parent(s): 51342e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -57
app.py CHANGED
@@ -8,6 +8,7 @@ import torch
8
  import cv2
9
  import tempfile
10
  import shutil
 
11
  from PIL import Image
12
  from typing import Iterable
13
  from gradio.themes import Soft
@@ -16,46 +17,53 @@ from gradio.themes.utils import colors, fonts, sizes
16
  # ---------------------------------------------------------
17
  # 1. ENVIRONMENT SETUP & REPO CLONING
18
  # ---------------------------------------------------------
19
- # Define the repository path
20
  REPO_URL = "https://github.com/facebookresearch/sam-3d-body.git"
21
  REPO_DIR = "sam-3d-body"
22
 
23
  def setup_sam_3d_env():
24
- """Clones the repo and sets up paths."""
 
 
 
25
  # 1. Clone if not exists
26
  if not os.path.exists(REPO_DIR):
27
  print(f"Cloning SAM 3D Body repository from {REPO_URL}...")
28
  try:
29
  subprocess.run(["git", "clone", REPO_URL], check=True)
30
- # Install the package in editable mode to handle internal imports
31
- print("Installing sam-3d-body package...")
32
  subprocess.run([sys.executable, "-m", "pip", "install", "-e", REPO_DIR], check=True)
 
 
 
33
  except subprocess.CalledProcessError as e:
34
  print(f"Error during setup: {e}")
35
  return False
36
 
37
- # 2. Add paths to sys.path
38
  repo_abs_path = os.path.abspath(REPO_DIR)
39
  notebook_path = os.path.join(repo_abs_path, "notebook")
40
 
41
- # Add repo root (for sam_3d_body package)
42
  if repo_abs_path not in sys.path:
43
  sys.path.insert(0, repo_abs_path)
 
44
 
45
- # Add notebook folder (for utils.py)
46
  if notebook_path not in sys.path:
47
  sys.path.insert(0, notebook_path)
 
48
 
49
  return True
50
 
51
- # Run setup
52
  env_ready = setup_sam_3d_env()
53
 
54
  # ---------------------------------------------------------
55
  # 2. IMPORTS
56
  # ---------------------------------------------------------
57
 
58
- # Import SAM3 (Transformers)
59
  try:
60
  from transformers import Sam3Processor, Sam3Model
61
  SAM3_AVAILABLE = True
@@ -63,24 +71,24 @@ except ImportError:
63
  print("Warning: transformers library not found or outdated. SAM3 will be disabled.")
64
  SAM3_AVAILABLE = False
65
 
66
- # Import SAM 3D Body Utils
 
 
67
  SAM3D_AVAILABLE = False
 
68
  if env_ready:
69
  try:
70
- # Import specific functions from the notebook/utils.py
71
- # Note: We rely on the path insertion above to find 'utils'
72
- from utils import (
73
- setup_sam_3d_body,
74
- setup_visualizer,
75
- visualize_2d_results,
76
- visualize_3d_mesh,
77
- save_mesh_results
78
- )
79
  SAM3D_AVAILABLE = True
80
  print("SAM 3D Body utils imported successfully.")
81
  except ImportError as e:
82
  print(f"Error importing SAM 3D Body utils: {e}")
83
- print("Ensure requirements are installed (pytorch3d, opencv, etc.)")
 
 
84
 
85
  # ---------------------------------------------------------
86
  # 3. THEME DEFINITION
@@ -154,7 +162,7 @@ print(f"Using device: {device}")
154
  # 4. LOAD MODELS
155
  # ---------------------------------------------------------
156
 
157
- # --- Load SAM3 ---
158
  sam3_model = None
159
  sam3_processor = None
160
  if SAM3_AVAILABLE:
@@ -166,21 +174,27 @@ if SAM3_AVAILABLE:
166
  except Exception as e:
167
  print(f"Error loading SAM3: {e}")
168
 
169
- # --- Load SAM 3D Body ---
170
  sam3d_estimator = None
171
  sam3d_visualizer = None
172
 
173
  if SAM3D_AVAILABLE:
174
  try:
175
- print("Loading SAM 3D Body Estimator...")
176
- # Note: This might require huggingface_hub login if the repo is gated,
177
- # but facebook/sam-3d-body-dinov3 is usually public.
178
- sam3d_estimator = setup_sam_3d_body(hf_repo_id="facebook/sam-3d-body-dinov3")
179
- sam3d_visualizer = setup_visualizer()
180
- print("SAM 3D Body Loaded.")
 
 
 
181
  except Exception as e:
182
  print(f"Error loading SAM 3D Body model: {e}")
 
183
  SAM3D_AVAILABLE = False
 
 
184
 
185
  # ---------------------------------------------------------
186
  # 5. INFERENCE FUNCTIONS
@@ -188,6 +202,7 @@ if SAM3D_AVAILABLE:
188
 
189
  @spaces.GPU
190
  def segment_image(input_image, text_prompt, threshold=0.5):
 
191
  if input_image is None:
192
  raise gr.Error("Please upload an image.")
193
  if not text_prompt:
@@ -221,66 +236,80 @@ def segment_image(input_image, text_prompt, threshold=0.5):
221
 
222
  @spaces.GPU
223
  def process_3d_body(input_image):
 
224
  if input_image is None:
225
  raise gr.Error("Please upload an image.")
 
226
  if not SAM3D_AVAILABLE or sam3d_estimator is None:
227
- raise gr.Error("SAM 3D Body libraries or model not available (Check logs for import errors).")
228
 
229
- # Prepare Image
230
  img_np = np.array(input_image.convert("RGB"))
231
  img_cv2 = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
232
 
233
- # The utils/estimator usually requires a file path
234
  with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
235
  tmp_path = tmp_file.name
236
  cv2.imwrite(tmp_path, img_cv2)
237
 
238
  try:
239
- # Run Inference
240
  print(f"Processing 3D Body for {tmp_path}...")
 
 
 
241
  outputs = sam3d_estimator.process_one_image(tmp_path)
242
 
243
  if not outputs:
244
  return None, None, None, "No people detected."
245
 
246
- # 1. 2D Visuals
247
- vis_results_2d = visualize_2d_results(img_cv2, outputs, sam3d_visualizer)
248
- # Handle case if visualize_2d_results returns list of images (one per person)
249
- if isinstance(vis_results_2d, list) and len(vis_results_2d) > 0:
250
- # Just take the first one or combine them?
251
- # Usually it returns cropped visuals. Let's assume list of images.
 
252
  res_2d_rgb = cv2.cvtColor(vis_results_2d[0], cv2.COLOR_BGR2RGB)
253
  else:
254
  res_2d_rgb = img_np
255
 
256
- # 2. 3D Overlay Visuals
257
- mesh_results_img = visualize_3d_mesh(img_cv2, outputs, sam3d_estimator.faces)
258
- if isinstance(mesh_results_img, list) and len(mesh_results_img) > 0:
259
- res_3d_overlay_rgb = cv2.cvtColor(mesh_results_img[0], cv2.COLOR_BGR2RGB)
 
260
  else:
261
  res_3d_overlay_rgb = img_np
262
 
263
- # 3. Save PLY for Model3D
 
264
  output_dir = tempfile.mkdtemp()
265
  image_name = "gradio_mesh"
266
 
267
  # save_mesh_results returns list of paths to .ply files
268
- ply_files = save_mesh_results(img_cv2, outputs, sam3d_estimator.faces, output_dir, image_name)
 
 
 
 
 
 
269
 
270
  ply_path = None
271
  if ply_files and len(ply_files) > 0:
272
  ply_path = ply_files[0] # Return the first mesh found
273
 
274
- status = f"Detected {len(outputs)} person(s). Showing result for Person 0."
275
 
276
- return res_2d_rgb, res_3d_overlay_rgb, ply_path, status
277
 
278
  except Exception as e:
279
  import traceback
280
  traceback.print_exc()
281
- raise gr.Error(f"Inference failed: {e}")
282
 
283
  finally:
 
284
  if os.path.exists(tmp_path):
285
  os.remove(tmp_path)
286
 
@@ -294,6 +323,7 @@ css = """
294
  max-width: 1200px;
295
  }
296
  #main-title h1 {font-size: 2.1em !important; text-align: center;}
 
297
  """
298
 
299
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
@@ -301,7 +331,7 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
301
  gr.Markdown("# **SAM Integrated Vision Suite**", elem_id="main-title")
302
 
303
  with gr.Tabs():
304
- # TAB 1: SEGMENTATION
305
  with gr.Tab("SAM3 Segmentation"):
306
  gr.Markdown("Segment objects using **SAM3** with text prompts.")
307
  with gr.Row():
@@ -315,7 +345,10 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
315
 
316
  t1_btn.click(segment_image, [t1_input, t1_prompt, t1_thresh], [t1_output])
317
 
318
- # TAB 2: 3D BODY
 
 
 
319
  with gr.Tab("SAM 3D Body"):
320
  gr.Markdown("Detect human bodies and reconstruct **3D Meshes**.")
321
 
@@ -328,12 +361,12 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
328
  with gr.Column(scale=2):
329
  with gr.Row():
330
  t2_vis_2d = gr.Image(label="2D Detection", type="numpy")
331
- t2_vis_overlay = gr.Image(label="Mesh Overlay", type="numpy")
332
 
333
  t2_model_3d = gr.Model3D(
334
  label="Interactive 3D Mesh",
335
  clear_color=[0.0, 0.0, 0.0, 0.0],
336
- camera_position=[0, 0, 2.5]
337
  )
338
 
339
  t2_btn.click(
@@ -341,12 +374,6 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
341
  inputs=[t2_input],
342
  outputs=[t2_vis_2d, t2_vis_overlay, t2_model_3d, t2_status]
343
  )
344
-
345
- gr.Examples(
346
- examples=[["examples/player.jpg"], ["examples/dancing.jpg"]],
347
- inputs=[t2_input],
348
- label="3D Body Examples"
349
- )
350
 
351
  if __name__ == "__main__":
352
  demo.launch(mcp_server=True, ssr_mode=False, show_error=True)
 
8
  import cv2
9
  import tempfile
10
  import shutil
11
+ import glob
12
  from PIL import Image
13
  from typing import Iterable
14
  from gradio.themes import Soft
 
17
  # ---------------------------------------------------------
18
  # 1. ENVIRONMENT SETUP & REPO CLONING
19
  # ---------------------------------------------------------
 
20
  REPO_URL = "https://github.com/facebookresearch/sam-3d-body.git"
21
  REPO_DIR = "sam-3d-body"
22
 
23
  def setup_sam_3d_env():
24
+ """
25
+ Clones the repo, installs dependencies, and fixes sys.path
26
+ so that 'utils', 'tools', and 'sam_3d_body' can be imported.
27
+ """
28
  # 1. Clone if not exists
29
  if not os.path.exists(REPO_DIR):
30
  print(f"Cloning SAM 3D Body repository from {REPO_URL}...")
31
  try:
32
  subprocess.run(["git", "clone", REPO_URL], check=True)
33
+ print("Installing sam-3d-body package in editable mode...")
34
+ # We install using pip to resolve internal package dependencies
35
  subprocess.run([sys.executable, "-m", "pip", "install", "-e", REPO_DIR], check=True)
36
+
37
+ # Install other requirements usually needed
38
+ subprocess.run([sys.executable, "-m", "pip", "install", "trimesh", "opencv-python", "matplotlib"], check=True)
39
  except subprocess.CalledProcessError as e:
40
  print(f"Error during setup: {e}")
41
  return False
42
 
43
+ # 2. Add Critical Paths to sys.path
44
  repo_abs_path = os.path.abspath(REPO_DIR)
45
  notebook_path = os.path.join(repo_abs_path, "notebook")
46
 
47
+ # CRITICAL: Add repo root first so 'import tools' and 'import sam_3d_body' work inside utils.py
48
  if repo_abs_path not in sys.path:
49
  sys.path.insert(0, repo_abs_path)
50
+ print(f"Added to sys.path: {repo_abs_path}")
51
 
52
+ # Add notebook folder so we can 'import utils'
53
  if notebook_path not in sys.path:
54
  sys.path.insert(0, notebook_path)
55
+ print(f"Added to sys.path: {notebook_path}")
56
 
57
  return True
58
 
59
+ # Run setup immediately
60
  env_ready = setup_sam_3d_env()
61
 
62
  # ---------------------------------------------------------
63
  # 2. IMPORTS
64
  # ---------------------------------------------------------
65
 
66
+ # --- Import SAM3 (Segmentation) ---
67
  try:
68
  from transformers import Sam3Processor, Sam3Model
69
  SAM3_AVAILABLE = True
 
71
  print("Warning: transformers library not found or outdated. SAM3 will be disabled.")
72
  SAM3_AVAILABLE = False
73
 
74
+ # --- Import SAM 3D Body Utils ---
75
+ # We use a specific alias to avoid confusion with standard python utils
76
+ sam3d_utils = None
77
  SAM3D_AVAILABLE = False
78
+
79
  if env_ready:
80
  try:
81
+ # Now that sys.path is fixed, this import should work
82
+ # and utils.py will successfully find 'tools' and 'sam_3d_body'
83
+ import utils as sam3d_utils_module
84
+ sam3d_utils = sam3d_utils_module
 
 
 
 
 
85
  SAM3D_AVAILABLE = True
86
  print("SAM 3D Body utils imported successfully.")
87
  except ImportError as e:
88
  print(f"Error importing SAM 3D Body utils: {e}")
89
+ print("This usually happens if 'tools' or 'sam_3d_body' cannot be found by utils.py")
90
+ import traceback
91
+ traceback.print_exc()
92
 
93
  # ---------------------------------------------------------
94
  # 3. THEME DEFINITION
 
162
  # 4. LOAD MODELS
163
  # ---------------------------------------------------------
164
 
165
+ # --- 1. Load SAM3 ---
166
  sam3_model = None
167
  sam3_processor = None
168
  if SAM3_AVAILABLE:
 
174
  except Exception as e:
175
  print(f"Error loading SAM3: {e}")
176
 
177
+ # --- 2. Load SAM 3D Body ---
178
  sam3d_estimator = None
179
  sam3d_visualizer = None
180
 
181
  if SAM3D_AVAILABLE:
182
  try:
183
+ print("Loading SAM 3D Body Estimator (this may take a moment)...")
184
+ # Initialize estimator using the utility function from the repo
185
+ # Note: detector_name="vitdet" is default, requiring 'tools' import to work
186
+ sam3d_estimator = sam3d_utils.setup_sam_3d_body(
187
+ hf_repo_id="facebook/sam-3d-body-dinov3",
188
+ device=device
189
+ )
190
+ sam3d_visualizer = sam3d_utils.setup_visualizer()
191
+ print("SAM 3D Body Loaded Successfully.")
192
  except Exception as e:
193
  print(f"Error loading SAM 3D Body model: {e}")
194
+ # If it fails, we set the flag to False so the UI handles it gracefully
195
  SAM3D_AVAILABLE = False
196
+ import traceback
197
+ traceback.print_exc()
198
 
199
  # ---------------------------------------------------------
200
  # 5. INFERENCE FUNCTIONS
 
202
 
203
  @spaces.GPU
204
  def segment_image(input_image, text_prompt, threshold=0.5):
205
+ """Handler for Tab 1: Segmentation"""
206
  if input_image is None:
207
  raise gr.Error("Please upload an image.")
208
  if not text_prompt:
 
236
 
237
  @spaces.GPU
238
  def process_3d_body(input_image):
239
+ """Handler for Tab 2: 3D Body Reconstruction"""
240
  if input_image is None:
241
  raise gr.Error("Please upload an image.")
242
+
243
  if not SAM3D_AVAILABLE or sam3d_estimator is None:
244
+ raise gr.Error("SAM 3D Body libraries or model failed to load. Check console logs.")
245
 
246
+ # Convert PIL to CV2 BGR for the estimator
247
  img_np = np.array(input_image.convert("RGB"))
248
  img_cv2 = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
249
 
250
+ # The estimator.process_one_image expects a file path
251
  with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
252
  tmp_path = tmp_file.name
253
  cv2.imwrite(tmp_path, img_cv2)
254
 
255
  try:
 
256
  print(f"Processing 3D Body for {tmp_path}...")
257
+
258
+ # 1. Run Inference
259
+ # process_one_image is a method of the estimator class inside sam-3d-body
260
  outputs = sam3d_estimator.process_one_image(tmp_path)
261
 
262
  if not outputs:
263
  return None, None, None, "No people detected."
264
 
265
+ # 2. 2D Keypoints Visualization
266
+ vis_results_2d = sam3d_utils.visualize_2d_results(img_cv2, outputs, sam3d_visualizer)
267
+ # Combine if multiple, or just take first for display simplicity.
268
+ # Usually vis_results_2d is a list of full images with drawings.
269
+ if vis_results_2d:
270
+ # For simplicity, if multiple people, the last one overrides or we assume 1 main person
271
+ # Ideally we'd grid them, but for Gradio output, let's take the first result's image
272
  res_2d_rgb = cv2.cvtColor(vis_results_2d[0], cv2.COLOR_BGR2RGB)
273
  else:
274
  res_2d_rgb = img_np
275
 
276
+ # 3. 3D Overlay Visualization
277
+ # visualize_3d_mesh returns a wide image (Original | Overlay | White | Side)
278
+ mesh_results_wide = sam3d_utils.visualize_3d_mesh(img_cv2, outputs, sam3d_estimator.faces)
279
+ if mesh_results_wide:
280
+ res_3d_overlay_rgb = cv2.cvtColor(mesh_results_wide[0], cv2.COLOR_BGR2RGB)
281
  else:
282
  res_3d_overlay_rgb = img_np
283
 
284
+ # 4. Save PLY for Model3D
285
+ # Create a unique directory for this run
286
  output_dir = tempfile.mkdtemp()
287
  image_name = "gradio_mesh"
288
 
289
  # save_mesh_results returns list of paths to .ply files
290
+ ply_files = sam3d_utils.save_mesh_results(
291
+ img_cv2,
292
+ outputs,
293
+ sam3d_estimator.faces,
294
+ output_dir,
295
+ image_name
296
+ )
297
 
298
  ply_path = None
299
  if ply_files and len(ply_files) > 0:
300
  ply_path = ply_files[0] # Return the first mesh found
301
 
302
+ status_msg = f"Detected {len(outputs)} person(s). Displaying Person 0."
303
 
304
+ return res_2d_rgb, res_3d_overlay_rgb, ply_path, status_msg
305
 
306
  except Exception as e:
307
  import traceback
308
  traceback.print_exc()
309
+ raise gr.Error(f"Inference failed: {str(e)}")
310
 
311
  finally:
312
+ # Cleanup input temp file
313
  if os.path.exists(tmp_path):
314
  os.remove(tmp_path)
315
 
 
323
  max-width: 1200px;
324
  }
325
  #main-title h1 {font-size: 2.1em !important; text-align: center;}
326
+ .gradio-container {min-height: 0px !important;}
327
  """
328
 
329
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
 
331
  gr.Markdown("# **SAM Integrated Vision Suite**", elem_id="main-title")
332
 
333
  with gr.Tabs():
334
+ # ================= TAB 1: SEGMENTATION =================
335
  with gr.Tab("SAM3 Segmentation"):
336
  gr.Markdown("Segment objects using **SAM3** with text prompts.")
337
  with gr.Row():
 
345
 
346
  t1_btn.click(segment_image, [t1_input, t1_prompt, t1_thresh], [t1_output])
347
 
348
+ # Optional examples if files exist
349
+ # gr.Examples(...)
350
+
351
+ # ================= TAB 2: 3D BODY =================
352
  with gr.Tab("SAM 3D Body"):
353
  gr.Markdown("Detect human bodies and reconstruct **3D Meshes**.")
354
 
 
361
  with gr.Column(scale=2):
362
  with gr.Row():
363
  t2_vis_2d = gr.Image(label="2D Detection", type="numpy")
364
+ t2_vis_overlay = gr.Image(label="3D Visualization (Original | Overlay | White | Side)", type="numpy")
365
 
366
  t2_model_3d = gr.Model3D(
367
  label="Interactive 3D Mesh",
368
  clear_color=[0.0, 0.0, 0.0, 0.0],
369
+ camera_position=[0, 0, 4.0]
370
  )
371
 
372
  t2_btn.click(
 
374
  inputs=[t2_input],
375
  outputs=[t2_vis_2d, t2_vis_overlay, t2_model_3d, t2_status]
376
  )
 
 
 
 
 
 
377
 
378
  if __name__ == "__main__":
379
  demo.launch(mcp_server=True, ssr_mode=False, show_error=True)