akhaliq HF Staff committed on
Commit
51b079f
·
verified ·
1 Parent(s): 5179233

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -26
app.py CHANGED
@@ -7,16 +7,46 @@ from PIL import Image
7
  import io
8
  import tempfile
9
  from pathlib import Path
 
 
10
 
11
  # Add notebook directory to path for inference code
12
  NOTEBOOK_PATH = "notebook"
13
- if os.path.exists(NOTEBOOK_PATH):
14
- sys.path.append(NOTEBOOK_PATH)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  # Import inference code with error handling
17
  try:
18
  from inference import Inference, load_image, load_single_mask
19
  INFERENCE_AVAILABLE = True
 
20
  except ImportError as e:
21
  print(f"Warning: Could not import inference module: {e}")
22
  print("Running in demo mode with mock functionality")
@@ -66,13 +96,32 @@ def process_image_to_3d(image, mask=None, seed=42, model_tag="hf"):
66
  demo_content = create_demo_3d_output()
67
  return {
68
  "status": "demo",
69
- "message": "Demo mode - inference module not available",
70
  "file_content": demo_content,
71
  "filename": "demo_splat.ply"
72
  }
73
 
74
- # Initialize inference if not already done
75
- config_path = f"checkpoints/{model_tag}/pipeline.yaml"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  # Create temporary files for the uploaded image and mask
78
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as img_temp:
@@ -183,6 +232,18 @@ def create_interface():
183
  color: #721c24;
184
  border: 1px solid #f5c6cb;
185
  }
 
 
 
 
 
 
 
 
 
 
 
 
186
  """
187
 
188
  with gr.Blocks(css=css, title="Image to 3D Converter") as demo:
@@ -211,16 +272,25 @@ def create_interface():
211
  label="Input Image",
212
  type="numpy",
213
  image_mode="RGB",
214
- elem_classes=["upload-area"]
 
215
  )
216
 
 
 
 
 
 
 
 
217
  with gr.Row():
218
  mask_upload = gr.Image(
219
- label="Optional Mask",
220
  type="numpy",
221
  image_mode="L",
222
  image_edit=True,
223
- elem_classes=["upload-area"]
 
224
  )
225
 
226
  mask_status = gr.Textbox(
@@ -232,8 +302,9 @@ def create_interface():
232
 
233
  with gr.Column(scale=1):
234
  gr.HTML("""
235
- <div style="background: #f0f8ff; padding: 20px; border-radius: 10px; margin-bottom: 20px;">
236
  <h3>⚙️ Configuration</h3>
 
237
  </div>
238
  """)
239
 
@@ -244,16 +315,23 @@ def create_interface():
244
  value=42,
245
  step=1,
246
  label="Random Seed",
247
- info="Controls the randomness in generation"
248
  )
249
 
250
  model_tag = gr.Dropdown(
251
  choices=["hf"],
252
  value="hf",
253
- label="Model Tag",
254
- info="Select the model configuration"
255
  )
256
 
 
 
 
 
 
 
 
257
  run_button = gr.Button(
258
  "🚀 Generate 3D Model",
259
  variant="primary",
@@ -263,19 +341,21 @@ def create_interface():
263
  with gr.Row():
264
  with gr.Column():
265
  status_output = gr.Textbox(
266
- label="Status",
267
  max_lines=5,
268
  interactive=False,
269
- elem_classes=["status-message"]
 
270
  )
271
 
272
  with gr.Row():
273
  with gr.Column():
274
  output_file = gr.File(
275
- label="Download 3D Model",
276
  file_types=[".ply"],
277
  visible=False,
278
- elem_classes=["download-section"]
 
279
  )
280
 
281
  # Wire up the interface
@@ -291,24 +371,46 @@ def create_interface():
291
  outputs=[status_output, output_file, gr.File()]
292
  )
293
 
294
- # Examples section
295
  gr.HTML("""
296
- <div style="margin-top: 40px; text-align: center;">
297
  <h3>📖 How to Use</h3>
298
- <div style="display: flex; justify-content: space-around; margin-top: 20px; flex-wrap: wrap;">
299
- <div style="max-width: 300px; padding: 20px; background: #f8f9fa; border-radius: 10px; margin: 10px;">
300
  <h4>1. Upload Image</h4>
301
- <p>Choose a clear, well-lit image for best results</p>
302
  </div>
303
- <div style="max-width: 300px; padding: 20px; background: #f8f9fa; border-radius: 10px; margin: 10px;">
304
  <h4>2. Add Mask (Optional)</h4>
305
- <p>Upload a mask to focus on specific areas</p>
306
  </div>
307
- <div style="max-width: 300px; padding: 20px; background: #f8f9fa; border-radius: 10px; margin: 10px;">
308
  <h4>3. Generate</h4>
309
- <p>Click generate and wait for your 3D model</p>
310
  </div>
311
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  </div>
313
  """)
314
 
@@ -336,4 +438,4 @@ if __name__ == "__main__":
336
  show_error=True,
337
  debug=True,
338
  inbrowser=True
339
- )
 
7
  import io
8
  import tempfile
9
  from pathlib import Path
10
+ import subprocess
11
+ import shutil
12
 
13
  # Add notebook directory to path for inference code
14
  NOTEBOOK_PATH = "notebook"
15
+ REPO_URL = "https://github.com/facebookresearch/sam-3d-objects"
16
+
17
+ def ensure_repository():
18
+ """Ensure the SAM 3D objects repository is cloned"""
19
+ repo_dir = "sam-3d-objects"
20
+
21
+ if not os.path.exists(repo_dir):
22
+ print(f"Cloning repository: {REPO_URL}")
23
+ try:
24
+ subprocess.run(["git", "clone", REPO_URL, repo_dir], check=True)
25
+ print("Repository cloned successfully")
26
+ except subprocess.CalledProcessError as e:
27
+ print(f"Failed to clone repository: {e}")
28
+ return False
29
+ except FileNotFoundError:
30
+ print("Git not found. Please install git to clone the repository")
31
+ return False
32
+
33
+ # Add to path
34
+ if os.path.exists(repo_dir):
35
+ sys.path.append(repo_dir)
36
+ if os.path.exists(NOTEBOOK_PATH):
37
+ sys.path.append(NOTEBOOK_PATH)
38
+
39
+ return True
40
+
41
+ # Ensure repository is available
42
+ if not ensure_repository():
43
+ print("Warning: Could not clone repository. Running in limited mode.")
44
 
45
  # Import inference code with error handling
46
  try:
47
  from inference import Inference, load_image, load_single_mask
48
  INFERENCE_AVAILABLE = True
49
+ print("Inference module loaded successfully")
50
  except ImportError as e:
51
  print(f"Warning: Could not import inference module: {e}")
52
  print("Running in demo mode with mock functionality")
 
96
  demo_content = create_demo_3d_output()
97
  return {
98
  "status": "demo",
99
+ "message": "Demo mode - inference module not available. This is a sample 3D model file.",
100
  "file_content": demo_content,
101
  "filename": "demo_splat.ply"
102
  }
103
 
104
+ # Check if model directory exists
105
+ model_dir = f"checkpoints/{model_tag}"
106
+ if not os.path.exists(model_dir):
107
+ return {
108
+ "status": "error",
109
+ "message": f"Model checkpoint not found at {model_dir}. Please ensure the model is downloaded.",
110
+ "file_content": None,
111
+ "filename": None
112
+ }
113
+
114
+ config_path = f"{model_dir}/pipeline.yaml"
115
+ if not os.path.exists(config_path):
116
+ config_path = f"{model_dir}/config.yaml" # Try alternative config name
117
+
118
+ if not os.path.exists(config_path):
119
+ return {
120
+ "status": "error",
121
+ "message": f"Pipeline configuration not found. Expected at {config_path}",
122
+ "file_content": None,
123
+ "filename": None
124
+ }
125
 
126
  # Create temporary files for the uploaded image and mask
127
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as img_temp:
 
232
  color: #721c24;
233
  border: 1px solid #f5c6cb;
234
  }
235
+ .upload-area {
236
+ border: 2px dashed #4CAF50 !important;
237
+ border-radius: 10px !important;
238
+ padding: 10px !important;
239
+ }
240
+ .mask-status {
241
+ background-color: #e3f2fd;
242
+ border: 1px solid #2196F3;
243
+ padding: 5px;
244
+ border-radius: 5px;
245
+ font-weight: bold;
246
+ }
247
  """
248
 
249
  with gr.Blocks(css=css, title="Image to 3D Converter") as demo:
 
272
  label="Input Image",
273
  type="numpy",
274
  image_mode="RGB",
275
+ elem_classes=["upload-area"],
276
+ info="Upload a clear, well-lit image for best results"
277
  )
278
 
279
+ gr.HTML("""
280
+ <div class="upload-section" style="margin-top: 15px;">
281
+ <h3>🎭 Optional Mask</h3>
282
+ <p>Upload a mask to focus on specific areas</p>
283
+ </div>
284
+ """)
285
+
286
  with gr.Row():
287
  mask_upload = gr.Image(
288
+ label="Segmentation Mask",
289
  type="numpy",
290
  image_mode="L",
291
  image_edit=True,
292
+ elem_classes=["upload-area"],
293
+ info="Optional: Upload a binary mask to segment the object"
294
  )
295
 
296
  mask_status = gr.Textbox(
 
302
 
303
  with gr.Column(scale=1):
304
  gr.HTML("""
305
+ <div style="background: #f0f8ff; padding: 20px; border-radius: 10px; margin-bottom: 20px; border: 1px solid #2196F3;">
306
  <h3>⚙️ Configuration</h3>
307
+ <p>Fine-tune the 3D generation parameters</p>
308
  </div>
309
  """)
310
 
 
315
  value=42,
316
  step=1,
317
  label="Random Seed",
318
+ info="Controls the randomness in generation. Use different seeds for variations."
319
  )
320
 
321
  model_tag = gr.Dropdown(
322
  choices=["hf"],
323
  value="hf",
324
+ label="Model Configuration",
325
+ info="Select the model configuration for 3D generation"
326
  )
327
 
328
+ gr.HTML("""
329
+ <div style="margin-top: 20px; text-align: center;">
330
+ <p><strong>Generation Status:</strong></p>
331
+ <p style="color: #666; font-size: 0.9em;">This process may take several minutes depending on image complexity</p>
332
+ </div>
333
+ """)
334
+
335
  run_button = gr.Button(
336
  "🚀 Generate 3D Model",
337
  variant="primary",
 
341
  with gr.Row():
342
  with gr.Column():
343
  status_output = gr.Textbox(
344
+ label="Generation Status",
345
  max_lines=5,
346
  interactive=False,
347
+ elem_classes=["status-message"],
348
+ info="Real-time updates on the 3D generation process"
349
  )
350
 
351
  with gr.Row():
352
  with gr.Column():
353
  output_file = gr.File(
354
+ label="📥 Download 3D Model",
355
  file_types=[".ply"],
356
  visible=False,
357
+ elem_classes=["download-section"],
358
+ info="Your generated 3D model will appear here for download"
359
  )
360
 
361
  # Wire up the interface
 
371
  outputs=[status_output, output_file, gr.File()]
372
  )
373
 
374
+ # Examples and instructions section
375
  gr.HTML("""
376
+ <div style="margin-top: 40px; text-align: center; background: #f8f9fa; padding: 30px; border-radius: 15px;">
377
  <h3>📖 How to Use</h3>
378
+ <div style="display: flex; justify-content: space-around; margin-top: 20px; flex-wrap: wrap; gap: 20px;">
379
+ <div style="max-width: 280px; padding: 20px; background: white; border-radius: 10px; margin: 10px; box-shadow: 0 2px 10px rgba(0,0,0,0.1);">
380
  <h4>1. Upload Image</h4>
381
+ <p>Choose a clear, well-lit image for best results. The image should show the object you want to convert to 3D.</p>
382
  </div>
383
+ <div style="max-width: 280px; padding: 20px; background: white; border-radius: 10px; margin: 10px; box-shadow: 0 2px 10px rgba(0,0,0,0.1);">
384
  <h4>2. Add Mask (Optional)</h4>
385
+ <p>Upload a mask to focus on specific areas. This helps the model understand what part of the image to convert.</p>
386
  </div>
387
+ <div style="max-width: 280px; padding: 20px; background: white; border-radius: 10px; margin: 10px; box-shadow: 0 2px 10px rgba(0,0,0,0.1);">
388
  <h4>3. Generate</h4>
389
+ <p>Click generate and wait for your 3D model. The process typically takes 1-5 minutes.</p>
390
  </div>
391
  </div>
392
+
393
+ <div style="margin-top: 30px; padding: 20px; background: white; border-radius: 10px; text-align: left; max-width: 800px; margin-left: auto; margin-right: auto;">
394
+ <h4>💡 Tips for Better Results:</h4>
395
+ <ul style="text-align: left;">
396
+ <li>Use high-quality, well-lit images</li>
397
+ <li>Ensure the object is clearly visible and not occluded</li>
398
+ <li>Use masks to isolate specific objects in complex scenes</li>
399
+ <li>Try different random seeds for variations</li>
400
+ <li>Complex objects may take longer to process</li>
401
+ </ul>
402
+ </div>
403
+ </div>
404
+ """)
405
+
406
+ # System information
407
+ gr.HTML(f"""
408
+ <div style="margin-top: 30px; padding: 15px; background: #e8f5e8; border-radius: 10px; text-align: center;">
409
+ <p><strong>System Status:</strong></p>
410
+ <p style="font-size: 0.9em; color: #666;">
411
  Inference Module: {"✓ Available" if INFERENCE_AVAILABLE else "✗ Demo Mode"} |
412
  SAM 3D Repository: {"✓ Cloned" if os.path.exists("sam-3d-objects") else "✗ Not Available"}
413
+ </p>
414
  </div>
415
  """)
416
 
 
438
  show_error=True,
439
  debug=True,
440
  inbrowser=True
441
+ )