prithivMLmods commited on
Commit
d5ea556
·
verified ·
1 Parent(s): a94e539

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +142 -137
app.py CHANGED
@@ -8,86 +8,82 @@ import torch
8
  import cv2
9
  import tempfile
10
  import shutil
11
- import traceback
12
  from PIL import Image
13
  from typing import Iterable
14
  from gradio.themes import Soft
15
  from gradio.themes.utils import colors, fonts, sizes
16
- from transformers import Sam3Processor, Sam3Model
17
 
18
  # ---------------------------------------------------------
19
- # 0. HEADLESS RENDERING SETUP
20
  # ---------------------------------------------------------
21
- # Essential for running 3D visualizers (pyrender) in cloud environments
22
- os.environ["PYOPENGL_PLATFORM"] = "egl"
23
-
24
- # ---------------------------------------------------------
25
- # 1. SETUP & DYNAMIC IMPORTS
26
- # ---------------------------------------------------------
27
-
28
  REPO_URL = "https://github.com/facebookresearch/sam-3d-body.git"
29
  REPO_DIR = "sam-3d-body"
30
 
31
- def setup_environment():
32
- """Clones repo and sets up Python paths."""
33
- print("--- Checking Environment for SAM 3D Body ---")
34
-
35
- # 1. Clone Repository
36
  if not os.path.exists(REPO_DIR):
37
- print(f"Cloning {REPO_URL}...")
38
  try:
39
  subprocess.run(["git", "clone", REPO_URL], check=True)
40
- print("Repository cloned.")
41
- # Install in editable mode to ensure package discovery works
42
- subprocess.run([sys.executable, "-m", "pip", "install", "-e", f"./{REPO_DIR}"], check=True)
43
- except Exception as e:
44
- print(f"Git clone/Install failed: {e}")
 
45
 
46
  # 2. Add paths to sys.path
47
- repo_abs = os.path.abspath(REPO_DIR)
48
- notebook_abs = os.path.abspath(os.path.join(REPO_DIR, "notebook"))
49
 
50
- if repo_abs not in sys.path:
51
- sys.path.insert(0, repo_abs)
52
- if notebook_abs not in sys.path:
53
- sys.path.insert(0, notebook_abs)
 
 
 
54
 
55
- print(f"Python Paths: {sys.path[:2]}...")
56
 
57
- setup_environment()
 
58
 
59
- # Global variables for models and error tracking
60
- sam3d_estimator = None
61
- sam3d_visualizer = None
62
- sam3d_load_error = None
63
- SAM3D_AVAILABLE = False
64
 
 
65
  try:
66
- # Try importing from the utils file in the cloned repo
67
- # This expects 'notebook/utils.py' to exist
68
- from utils import (
69
- setup_sam_3d_body,
70
- setup_visualizer,
71
- visualize_2d_results,
72
- visualize_3d_mesh,
73
- save_mesh_results
74
- )
75
-
76
- print("Loading SAM 3D Body Estimator (this may take time)...")
77
- # Initialize the model immediately to catch errors early
78
- sam3d_estimator = setup_sam_3d_body(hf_repo_id="facebook/sam-3d-body-dinov3")
79
- sam3d_visualizer = setup_visualizer()
80
- SAM3D_AVAILABLE = True
81
- print("SAM 3D Body Model loaded successfully.")
82
-
83
- except Exception as e:
84
- # Capture the exact error (e.g., missing mmhuman3d)
85
- sam3d_load_error = f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()}"
86
- print(f"CRITICAL ERROR loading SAM 3D Body:\n{sam3d_load_error}")
87
- SAM3D_AVAILABLE = False
 
 
88
 
89
  # ---------------------------------------------------------
90
- # 2. THEME DEFINITION
91
  # ---------------------------------------------------------
92
  colors.steel_blue = colors.Color(
93
  name="steel_blue",
@@ -155,31 +151,49 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
155
  print(f"Using device: {device}")
156
 
157
  # ---------------------------------------------------------
158
- # 3. MODEL LOADING (SAM3)
159
  # ---------------------------------------------------------
160
- try:
161
- print("Loading SAM3 Model...")
162
- sam3_model = Sam3Model.from_pretrained("facebook/sam3").to(device)
163
- sam3_processor = Sam3Processor.from_pretrained("facebook/sam3")
164
- print("SAM3 Model loaded successfully.")
165
- except Exception as e:
166
- print(f"Error loading SAM3 model: {e}")
167
- sam3_model = None
168
- sam3_processor = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
  # ---------------------------------------------------------
171
- # 4. INFERENCE FUNCTIONS
172
  # ---------------------------------------------------------
173
 
174
  @spaces.GPU
175
  def segment_image(input_image, text_prompt, threshold=0.5):
176
- """Function for Tab 1: SAM3 Segmentation"""
177
  if input_image is None:
178
  raise gr.Error("Please upload an image.")
179
  if not text_prompt:
180
  raise gr.Error("Please enter a text prompt.")
181
- if sam3_model is None or sam3_processor is None:
182
- raise gr.Error("SAM3 Model not loaded correctly.")
183
 
184
  image_pil = input_image.convert("RGB")
185
  inputs = sam3_processor(images=image_pil, text=text_prompt, return_tensors="pt").to(device)
@@ -204,62 +218,74 @@ def segment_image(input_image, text_prompt, threshold=0.5):
204
 
205
  return (image_pil, annotations)
206
 
 
207
  @spaces.GPU
208
  def process_3d_body(input_image):
209
- """Function for Tab 2: SAM 3D Body"""
210
  if input_image is None:
211
  raise gr.Error("Please upload an image.")
212
-
213
- # Check if initialization failed
214
  if not SAM3D_AVAILABLE or sam3d_estimator is None:
215
- # Raise the specific error captured during startup
216
- error_msg = sam3d_load_error if sam3d_load_error else "Unknown initialization error."
217
- raise gr.Error(f"Model Setup Failed. Logs:\n{error_msg}")
218
 
219
- # Convert PIL to CV2 BGR
220
  img_np = np.array(input_image.convert("RGB"))
221
  img_cv2 = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
222
 
223
- # Helper requires a physical file path
224
  with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
225
  tmp_path = tmp_file.name
226
  cv2.imwrite(tmp_path, img_cv2)
227
 
228
  try:
229
- print(f"Running inference on {tmp_path}...")
 
230
  outputs = sam3d_estimator.process_one_image(tmp_path)
231
 
232
  if not outputs:
233
  return None, None, None, "No people detected."
234
 
235
- # 1. 2D Vis
236
  vis_results_2d = visualize_2d_results(img_cv2, outputs, sam3d_visualizer)
237
- res_2d_rgb = cv2.cvtColor(vis_results_2d[0], cv2.COLOR_BGR2RGB) if vis_results_2d else img_np
238
-
239
- # 2. 3D Overlay
 
 
 
 
 
 
240
  mesh_results_img = visualize_3d_mesh(img_cv2, outputs, sam3d_estimator.faces)
241
- res_3d_overlay_rgb = cv2.cvtColor(mesh_results_img[0], cv2.COLOR_BGR2RGB) if mesh_results_img else img_np
 
 
 
242
 
243
- # 3. Save PLY
244
  output_dir = tempfile.mkdtemp()
245
- image_name = "gradio_mesh_result"
 
 
246
  ply_files = save_mesh_results(img_cv2, outputs, sam3d_estimator.faces, output_dir, image_name)
247
 
248
- ply_path = ply_files[0] if ply_files else None
 
 
249
 
250
- status = f"Success! Detected {len(outputs)} person(s)."
251
 
252
  return res_2d_rgb, res_3d_overlay_rgb, ply_path, status
253
 
254
  except Exception as e:
 
255
  traceback.print_exc()
256
- raise gr.Error(f"Inference Runtime Error: {str(e)}")
 
257
  finally:
258
  if os.path.exists(tmp_path):
259
  os.remove(tmp_path)
260
 
261
  # ---------------------------------------------------------
262
- # 5. GRADIO UI LAYOUT
263
  # ---------------------------------------------------------
264
 
265
  css = """
@@ -275,73 +301,52 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
275
  gr.Markdown("# **SAM Integrated Vision Suite**", elem_id="main-title")
276
 
277
  with gr.Tabs():
278
- # ================= TAB 1: SEGMENTATION =================
279
  with gr.Tab("SAM3 Segmentation"):
280
  gr.Markdown("Segment objects using **SAM3** with text prompts.")
281
-
282
  with gr.Row():
283
  with gr.Column(scale=1):
284
- t1_input_image = gr.Image(label="Input Image", type="pil", height=350)
285
- t1_text_prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., cat, ear, car wheel...")
286
- t1_threshold = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, value=0.4, step=0.05)
287
- t1_run_btn = gr.Button("Segment Image", variant="primary")
288
-
289
  with gr.Column(scale=1.5):
290
- t1_output_image = gr.AnnotatedImage(label="Segmented Output", height=450)
291
 
292
- t1_run_btn.click(
293
- fn=segment_image,
294
- inputs=[t1_input_image, t1_text_prompt, t1_threshold],
295
- outputs=[t1_output_image]
296
- )
297
 
298
- # ================= TAB 2: 3D BODY =================
299
  with gr.Tab("SAM 3D Body"):
300
  gr.Markdown("Detect human bodies and reconstruct **3D Meshes**.")
301
 
302
  with gr.Row():
303
  with gr.Column(scale=1):
304
- t2_input_image = gr.Image(label="Input Image", type="pil", height=350)
305
- t2_run_btn = gr.Button("Generate 3D Body", variant="primary")
306
  t2_status = gr.Textbox(label="Status", interactive=False)
307
-
308
- # Warning box if initialization failed
309
- if not SAM3D_AVAILABLE:
310
- gr.Markdown(
311
- "⚠️ **Warning: SAM 3D Body failed to load.**\n"
312
- f"Error: {sam3d_load_error}\n"
313
- "Please check `mmhuman3d` and `mmcv` dependencies.",
314
- elem_classes=["error-box"]
315
- )
316
-
317
  with gr.Column(scale=2):
318
  with gr.Row():
319
- t2_output_2d = gr.Image(label="2D Keypoints", type="numpy")
320
- t2_output_overlay = gr.Image(label="Mesh Overlay", type="numpy")
321
 
322
- t2_output_3d = gr.Model3D(
323
- label="Interactive 3D Mesh (PLY)",
324
  clear_color=[0.0, 0.0, 0.0, 0.0],
325
- camera_position=[0, 0, 3]
326
  )
327
 
328
- t2_run_btn.click(
329
- fn=process_3d_body,
330
- inputs=[t2_input_image],
331
- outputs=[t2_output_2d, t2_output_overlay, t2_output_3d, t2_status]
332
  )
333
 
334
- # Dynamic examples
335
- ex_files = []
336
- if os.path.exists("examples/player.jpg"): ex_files.append(["examples/player.jpg"])
337
- if os.path.exists("examples/dancing.jpg"): ex_files.append(["examples/dancing.jpg"])
338
-
339
- if ex_files:
340
- gr.Examples(
341
- examples=ex_files,
342
- inputs=[t2_input_image],
343
- label="3D Body Examples"
344
- )
345
 
346
  if __name__ == "__main__":
347
  demo.launch(mcp_server=True, ssr_mode=False, show_error=True)
 
8
  import cv2
9
  import tempfile
10
  import shutil
 
11
  from PIL import Image
12
  from typing import Iterable
13
  from gradio.themes import Soft
14
  from gradio.themes.utils import colors, fonts, sizes
 
15
 
16
  # ---------------------------------------------------------
17
+ # 1. ENVIRONMENT SETUP & REPO CLONING
18
  # ---------------------------------------------------------
19
+ # Define the repository path
 
 
 
 
 
 
20
  REPO_URL = "https://github.com/facebookresearch/sam-3d-body.git"
21
  REPO_DIR = "sam-3d-body"
22
 
23
+ def setup_sam_3d_env():
24
+ """Clones the repo and sets up paths."""
25
+ # 1. Clone if not exists
 
 
26
  if not os.path.exists(REPO_DIR):
27
+ print(f"Cloning SAM 3D Body repository from {REPO_URL}...")
28
  try:
29
  subprocess.run(["git", "clone", REPO_URL], check=True)
30
+ # Install the package in editable mode to handle internal imports
31
+ print("Installing sam-3d-body package...")
32
+ subprocess.run([sys.executable, "-m", "pip", "install", "-e", REPO_DIR], check=True)
33
+ except subprocess.CalledProcessError as e:
34
+ print(f"Error during setup: {e}")
35
+ return False
36
 
37
  # 2. Add paths to sys.path
38
+ repo_abs_path = os.path.abspath(REPO_DIR)
39
+ notebook_path = os.path.join(repo_abs_path, "notebook")
40
 
41
+ # Add repo root (for sam_3d_body package)
42
+ if repo_abs_path not in sys.path:
43
+ sys.path.insert(0, repo_abs_path)
44
+
45
+ # Add notebook folder (for utils.py)
46
+ if notebook_path not in sys.path:
47
+ sys.path.insert(0, notebook_path)
48
 
49
+ return True
50
 
51
+ # Run setup
52
+ env_ready = setup_sam_3d_env()
53
 
54
+ # ---------------------------------------------------------
55
+ # 2. IMPORTS
56
+ # ---------------------------------------------------------
 
 
57
 
58
+ # Import SAM3 (Transformers)
59
  try:
60
+ from transformers import Sam3Processor, Sam3Model
61
+ SAM3_AVAILABLE = True
62
+ except ImportError:
63
+ print("Warning: transformers library not found or outdated. SAM3 will be disabled.")
64
+ SAM3_AVAILABLE = False
65
+
66
+ # Import SAM 3D Body Utils
67
+ SAM3D_AVAILABLE = False
68
+ if env_ready:
69
+ try:
70
+ # Import specific functions from the notebook/utils.py
71
+ # Note: We rely on the path insertion above to find 'utils'
72
+ from utils import (
73
+ setup_sam_3d_body,
74
+ setup_visualizer,
75
+ visualize_2d_results,
76
+ visualize_3d_mesh,
77
+ save_mesh_results
78
+ )
79
+ SAM3D_AVAILABLE = True
80
+ print("SAM 3D Body utils imported successfully.")
81
+ except ImportError as e:
82
+ print(f"Error importing SAM 3D Body utils: {e}")
83
+ print("Ensure requirements are installed (pytorch3d, opencv, etc.)")
84
 
85
  # ---------------------------------------------------------
86
+ # 3. THEME DEFINITION
87
  # ---------------------------------------------------------
88
  colors.steel_blue = colors.Color(
89
  name="steel_blue",
 
151
  print(f"Using device: {device}")
152
 
153
  # ---------------------------------------------------------
154
+ # 4. LOAD MODELS
155
  # ---------------------------------------------------------
156
+
157
+ # --- Load SAM3 ---
158
+ sam3_model = None
159
+ sam3_processor = None
160
+ if SAM3_AVAILABLE:
161
+ try:
162
+ print("Loading SAM3 Model...")
163
+ sam3_model = Sam3Model.from_pretrained("facebook/sam3").to(device)
164
+ sam3_processor = Sam3Processor.from_pretrained("facebook/sam3")
165
+ print("SAM3 Loaded.")
166
+ except Exception as e:
167
+ print(f"Error loading SAM3: {e}")
168
+
169
+ # --- Load SAM 3D Body ---
170
+ sam3d_estimator = None
171
+ sam3d_visualizer = None
172
+
173
+ if SAM3D_AVAILABLE:
174
+ try:
175
+ print("Loading SAM 3D Body Estimator...")
176
+ # Note: This might require huggingface_hub login if the repo is gated,
177
+ # but facebook/sam-3d-body-dinov3 is usually public.
178
+ sam3d_estimator = setup_sam_3d_body(hf_repo_id="facebook/sam-3d-body-dinov3")
179
+ sam3d_visualizer = setup_visualizer()
180
+ print("SAM 3D Body Loaded.")
181
+ except Exception as e:
182
+ print(f"Error loading SAM 3D Body model: {e}")
183
+ SAM3D_AVAILABLE = False
184
 
185
  # ---------------------------------------------------------
186
+ # 5. INFERENCE FUNCTIONS
187
  # ---------------------------------------------------------
188
 
189
  @spaces.GPU
190
  def segment_image(input_image, text_prompt, threshold=0.5):
 
191
  if input_image is None:
192
  raise gr.Error("Please upload an image.")
193
  if not text_prompt:
194
  raise gr.Error("Please enter a text prompt.")
195
+ if sam3_model is None:
196
+ raise gr.Error("SAM3 Model is not loaded.")
197
 
198
  image_pil = input_image.convert("RGB")
199
  inputs = sam3_processor(images=image_pil, text=text_prompt, return_tensors="pt").to(device)
 
218
 
219
  return (image_pil, annotations)
220
 
221
+
222
  @spaces.GPU
223
  def process_3d_body(input_image):
 
224
  if input_image is None:
225
  raise gr.Error("Please upload an image.")
 
 
226
  if not SAM3D_AVAILABLE or sam3d_estimator is None:
227
+ raise gr.Error("SAM 3D Body libraries or model not available (Check logs for import errors).")
 
 
228
 
229
+ # Prepare Image
230
  img_np = np.array(input_image.convert("RGB"))
231
  img_cv2 = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
232
 
233
+ # The utils/estimator usually requires a file path
234
  with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
235
  tmp_path = tmp_file.name
236
  cv2.imwrite(tmp_path, img_cv2)
237
 
238
  try:
239
+ # Run Inference
240
+ print(f"Processing 3D Body for {tmp_path}...")
241
  outputs = sam3d_estimator.process_one_image(tmp_path)
242
 
243
  if not outputs:
244
  return None, None, None, "No people detected."
245
 
246
+ # 1. 2D Visuals
247
  vis_results_2d = visualize_2d_results(img_cv2, outputs, sam3d_visualizer)
248
+ # Handle case if visualize_2d_results returns list of images (one per person)
249
+ if isinstance(vis_results_2d, list) and len(vis_results_2d) > 0:
250
+ # Just take the first one or combine them?
251
+ # Usually it returns cropped visuals. Let's assume list of images.
252
+ res_2d_rgb = cv2.cvtColor(vis_results_2d[0], cv2.COLOR_BGR2RGB)
253
+ else:
254
+ res_2d_rgb = img_np
255
+
256
+ # 2. 3D Overlay Visuals
257
  mesh_results_img = visualize_3d_mesh(img_cv2, outputs, sam3d_estimator.faces)
258
+ if isinstance(mesh_results_img, list) and len(mesh_results_img) > 0:
259
+ res_3d_overlay_rgb = cv2.cvtColor(mesh_results_img[0], cv2.COLOR_BGR2RGB)
260
+ else:
261
+ res_3d_overlay_rgb = img_np
262
 
263
+ # 3. Save PLY for Model3D
264
  output_dir = tempfile.mkdtemp()
265
+ image_name = "gradio_mesh"
266
+
267
+ # save_mesh_results returns list of paths to .ply files
268
  ply_files = save_mesh_results(img_cv2, outputs, sam3d_estimator.faces, output_dir, image_name)
269
 
270
+ ply_path = None
271
+ if ply_files and len(ply_files) > 0:
272
+ ply_path = ply_files[0] # Return the first mesh found
273
 
274
+ status = f"Detected {len(outputs)} person(s). Showing result for Person 0."
275
 
276
  return res_2d_rgb, res_3d_overlay_rgb, ply_path, status
277
 
278
  except Exception as e:
279
+ import traceback
280
  traceback.print_exc()
281
+ raise gr.Error(f"Inference failed: {e}")
282
+
283
  finally:
284
  if os.path.exists(tmp_path):
285
  os.remove(tmp_path)
286
 
287
  # ---------------------------------------------------------
288
+ # 6. GUI
289
  # ---------------------------------------------------------
290
 
291
  css = """
 
301
  gr.Markdown("# **SAM Integrated Vision Suite**", elem_id="main-title")
302
 
303
  with gr.Tabs():
304
+ # TAB 1: SEGMENTATION
305
  with gr.Tab("SAM3 Segmentation"):
306
  gr.Markdown("Segment objects using **SAM3** with text prompts.")
 
307
  with gr.Row():
308
  with gr.Column(scale=1):
309
+ t1_input = gr.Image(label="Input Image", type="pil", height=350)
310
+ t1_prompt = gr.Textbox(label="Text Prompt", placeholder="e.g., cat, face...")
311
+ t1_thresh = gr.Slider(0.0, 1.0, 0.4, step=0.05, label="Threshold")
312
+ t1_btn = gr.Button("Segment", variant="primary")
 
313
  with gr.Column(scale=1.5):
314
+ t1_output = gr.AnnotatedImage(label="Segmented Output", height=450)
315
 
316
+ t1_btn.click(segment_image, [t1_input, t1_prompt, t1_thresh], [t1_output])
 
 
 
 
317
 
318
+ # TAB 2: 3D BODY
319
  with gr.Tab("SAM 3D Body"):
320
  gr.Markdown("Detect human bodies and reconstruct **3D Meshes**.")
321
 
322
  with gr.Row():
323
  with gr.Column(scale=1):
324
+ t2_input = gr.Image(label="Input Image", type="pil", height=350)
325
+ t2_btn = gr.Button("Generate 3D Body", variant="primary")
326
  t2_status = gr.Textbox(label="Status", interactive=False)
327
+
 
 
 
 
 
 
 
 
 
328
  with gr.Column(scale=2):
329
  with gr.Row():
330
+ t2_vis_2d = gr.Image(label="2D Detection", type="numpy")
331
+ t2_vis_overlay = gr.Image(label="Mesh Overlay", type="numpy")
332
 
333
+ t2_model_3d = gr.Model3D(
334
+ label="Interactive 3D Mesh",
335
  clear_color=[0.0, 0.0, 0.0, 0.0],
336
+ camera_position=[0, 0, 2.5]
337
  )
338
 
339
+ t2_btn.click(
340
+ process_3d_body,
341
+ inputs=[t2_input],
342
+ outputs=[t2_vis_2d, t2_vis_overlay, t2_model_3d, t2_status]
343
  )
344
 
345
+ gr.Examples(
346
+ examples=[["examples/player.jpg"], ["examples/dancing.jpg"]],
347
+ inputs=[t2_input],
348
+ label="3D Body Examples"
349
+ )
 
 
 
 
 
 
350
 
351
  if __name__ == "__main__":
352
  demo.launch(mcp_server=True, ssr_mode=False, show_error=True)