hiendang7613vulcanlabs commited on
Commit
ca978a4
·
1 Parent(s): 4664520

feat: two-stage accordion UI (Stage 1 multiview gen + Stage 2 3D recon)

Browse files
Files changed (2) hide show
  1. app.py +396 -78
  2. src/fal_multiview.py +112 -0
app.py CHANGED
@@ -1,6 +1,9 @@
1
  import os
 
2
  import tempfile
 
3
  from pathlib import Path
 
4
  import gradio as gr
5
 
6
  # Strip trailing whitespace/newlines from API keys (HF Secrets UI sometimes adds \n)
@@ -10,6 +13,7 @@ for _key in ("FAL_KEY", "HF_TOKEN"):
10
  os.environ[_key] = _val.strip()
11
 
12
  from src.checkpoints import ensure_sam3d_checkpoints, link_mv_sam3d_checkpoints
 
13
  from src.fal_sam3 import extract_main_object_alpha
14
  from src.mv_sam3d import run_mv_sam3d_inference
15
  from src.utils import (
@@ -23,6 +27,176 @@ from src.utils import (
23
 
24
  MV_ROOT = Path("/mv_sam3d")
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  def build_mv_input_from_uploads(
28
  files, mask_prompt: str, prompt: str | None, pick_mode: str, max_masks: int,
@@ -37,7 +211,6 @@ def build_mv_input_from_uploads(
37
  if not files:
38
  raise gr.Error("Upload at least 2 images (multi-view).")
39
 
40
- # Sort by filename (natural sort: 1,2,10)
41
  files_sorted = sorted(files, key=lambda f: natural_key(Path(f).name))
42
 
43
  workdir = Path(tempfile.mkdtemp(prefix="mv_input_"))
@@ -52,11 +225,9 @@ def build_mv_input_from_uploads(
52
  img = load_rgb_image(fp)
53
 
54
  if skip_sam3:
55
- # Simple white background removal (no API call)
56
  progress((i + 1) / max(n, 1), desc=f"Removing white bg {i+1}/{n}")
57
  alpha = remove_white_background_alpha(img, threshold=white_threshold)
58
  else:
59
- # fal SAM-3 main object alpha mask (L, 0..255)
60
  progress((i + 1) / max(n, 1), desc=f"SAM-3 masking view {i+1}/{n}")
61
  alpha = extract_main_object_alpha(
62
  image_path=fp,
@@ -65,20 +236,19 @@ def build_mv_input_from_uploads(
65
  max_masks=max_masks,
66
  )
67
 
68
- # Save RGB image to images/{i}.png
69
  rgb_path = images_dir / f"{i}.png"
70
  save_rgb_png(img, rgb_path)
71
 
72
- # Save RGBA mask image to <mask_prompt>/{i}.png (alpha channel stores mask)
73
  rgba = make_rgba_with_alpha(img, alpha)
74
  mask_path = masks_dir / f"{i}.png"
75
  save_rgba_png(rgba, mask_path)
76
 
77
- return input_dir, [str((images_dir / f"{i}.png")) for i in range(n)]
78
 
79
 
80
- def run_pipeline(
81
  files,
 
82
  mask_prompt,
83
  sam3_prompt,
84
  pick_mode,
@@ -92,15 +262,23 @@ def run_pipeline(
92
  da3_npz,
93
  progress=gr.Progress(),
94
  ):
95
- # 1) ensure checkpoints in /data and link into MV repo
 
 
 
 
 
 
 
 
96
  progress(0.02, desc="Ensuring SAM-3D checkpoints...")
97
  ensure_sam3d_checkpoints()
98
  link_mv_sam3d_checkpoints(MV_ROOT)
99
 
100
- # 2) build input folder using fal SAM-3 masks (or white bg removal)
101
  progress(0.05, desc="Preparing multi-view input...")
102
  input_dir, preview_imgs = build_mv_input_from_uploads(
103
- files=files,
104
  mask_prompt=mask_prompt,
105
  prompt=sam3_prompt,
106
  pick_mode=pick_mode,
@@ -123,7 +301,6 @@ def run_pipeline(
123
  da3_npz_path=(Path(da3_npz) if da3_npz else None),
124
  )
125
 
126
- # 4) outputs
127
  progress(1.0, desc="Done!")
128
  return (
129
  out.viewer_path,
@@ -135,93 +312,234 @@ def run_pipeline(
135
  )
136
 
137
 
138
- with gr.Blocks() as demo:
 
 
 
 
 
139
  gr.Markdown(
140
  """
141
- # MV-SAM3D + fal SAM-3 (Main Object Mask)
142
- Upload multi-view images fal SAM-3 extracts main object mask per view → MV-SAM3D outputs GLB + PLY.
 
143
 
144
- **Secrets required:** `FAL_KEY` and (maybe) `HF_TOKEN`.
145
  """
146
  )
147
 
148
- with gr.Row():
149
- files = gr.Files(label="Multi-view images (PNG/JPG)", file_types=["image"])
150
- da3_npz = gr.File(label="(Optional) DA3 output (.npz) for visibility/mixed", file_types=[".npz"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
- with gr.Row():
153
- mask_prompt = gr.Textbox(label="mask_prompt folder name", value="object")
154
- skip_sam3 = gr.Checkbox(
155
- label="Skip SAM-3 (remove white background instead)",
156
- value=False,
157
- info="Check to skip fal SAM-3 API and use simple white background removal",
 
 
 
 
 
 
 
158
  )
159
 
160
- with gr.Row():
161
- white_threshold = gr.Slider(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  label="White BG threshold (R,G,B ≥ threshold → background)",
163
  minimum=200, maximum=255, value=240, step=1,
164
  visible=False,
165
  )
166
 
167
- with gr.Row():
168
- sam3_prompt = gr.Textbox(label="SAM-3 prompt (optional, e.g. 'stuffed toy' / 'person')", value="")
169
- pick_mode = gr.Dropdown(label="Pick main object mode", choices=["largest", "best_score"], value="largest")
170
- max_masks = gr.Slider(label="SAM-3 max_masks", minimum=1, maximum=10, value=5, step=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
- # Toggle visibility: show white threshold when skipping SAM-3, hide SAM-3 options
173
- skip_sam3.change(
174
- fn=lambda s: (gr.update(visible=s), gr.update(visible=not s), gr.update(visible=not s), gr.update(visible=not s)),
175
- inputs=[skip_sam3],
176
- outputs=[white_threshold, sam3_prompt, pick_mode, max_masks],
177
- )
178
 
179
- gr.Markdown("### MV-SAM3D params")
180
 
181
- with gr.Row():
182
- image_names = gr.Textbox(label="image_names (comma-separated, optional)", placeholder="0,1,2,3,4,5,6,7")
 
 
 
 
183
 
184
- with gr.Row():
185
- stage1_weighting = gr.Checkbox(label="Stage 1 weighting", value=False)
186
- stage2_weighting = gr.Checkbox(label="Stage 2 weighting", value=False)
187
- stage2_weight_source = gr.Dropdown(
188
- label="Stage 2 weight source",
189
- choices=["entropy", "visibility", "mixed"],
190
- value="entropy",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  )
192
 
193
- run_btn = gr.Button("Run", variant="primary")
194
-
195
- with gr.Row():
196
- viewer = gr.Model3D(label="Preview (GLB)")
197
- preview_gallery = gr.Gallery(label="Prepared inputs (images/*.png)", columns=4, height=240)
198
-
199
- with gr.Row():
200
- glb_dl = gr.File(label="result.glb")
201
- ply_dl = gr.File(label="result.ply")
202
- npz_dl = gr.File(label="params.npz")
203
-
204
- log_box = gr.Textbox(label="Log tail", lines=18)
205
-
206
- run_btn.click(
207
- fn=run_pipeline,
208
- inputs=[
209
- files,
210
- mask_prompt,
211
- sam3_prompt,
212
- pick_mode,
213
- max_masks,
214
- skip_sam3,
215
- white_threshold,
216
- image_names,
217
- stage1_weighting,
218
- stage2_weighting,
219
- stage2_weight_source,
220
- da3_npz,
221
- ],
222
- outputs=[viewer, glb_dl, ply_dl, npz_dl, log_box, preview_gallery],
223
- )
224
 
225
  if __name__ == "__main__":
226
  demo.launch()
227
-
 
1
  import os
2
+ import shutil
3
  import tempfile
4
+ from datetime import datetime
5
  from pathlib import Path
6
+
7
  import gradio as gr
8
 
9
  # Strip trailing whitespace/newlines from API keys (HF Secrets UI sometimes adds \n)
 
13
  os.environ[_key] = _val.strip()
14
 
15
  from src.checkpoints import ensure_sam3d_checkpoints, link_mv_sam3d_checkpoints
16
+ from src.fal_multiview import generate_multiview_parallel, DEFAULT_ANGLES
17
  from src.fal_sam3 import extract_main_object_alpha
18
  from src.mv_sam3d import run_mv_sam3d_inference
19
  from src.utils import (
 
27
 
28
  MV_ROOT = Path("/mv_sam3d")
29
 
30
+ # ── Sample directories ──────────────────────────────────────────────────────
31
+ # Stage 1 samples: single images for multi-view generation
32
+ SAMPLES_S1_DIR = Path("/data/samples/stage1")
33
+ SAMPLES_S1_DIR.mkdir(parents=True, exist_ok=True)
34
+
35
+ # Stage 2 samples: multi-view image sets
36
+ SAMPLES_S2_DIR = Path("/data/samples/stage2")
37
+ SAMPLES_S2_DIR.mkdir(parents=True, exist_ok=True)
38
+
39
+ # Built-in Stage 2 samples from MV-SAM3D/data
40
+ BUILTIN_S2_DIR = MV_ROOT / "data"
41
+
42
+
43
+ # ── Helpers ─────────────────────────────────────────────────────────────────
44
+
45
def _list_stage1_samples() -> list[str]:
    """Return the names of image files stored in SAMPLES_S1_DIR (sorted)."""
    if not SAMPLES_S1_DIR.exists():
        return []
    image_exts = (".png", ".jpg", ".jpeg", ".webp")
    return [
        entry.name
        for entry in sorted(SAMPLES_S1_DIR.iterdir())
        if entry.is_file() and entry.suffix.lower() in image_exts
    ]
54
+
55
+
56
def _list_stage2_samples() -> list[str]:
    """Return the available Stage 2 sample-set names (built-in sets first)."""
    names: list[str] = []
    # Built-in sets shipped under MV-SAM3D/data: subfolders with an images/ dir.
    if BUILTIN_S2_DIR.exists():
        names.extend(
            f"[builtin] {entry.name}"
            for entry in sorted(BUILTIN_S2_DIR.iterdir())
            if entry.is_dir() and (entry / "images").is_dir()
        )
    # Sets saved by the user or produced by Stage 1.
    if SAMPLES_S2_DIR.exists():
        names.extend(
            entry.name
            for entry in sorted(SAMPLES_S2_DIR.iterdir())
            if entry.is_dir()
        )
    return names
70
+
71
+
72
def _load_stage2_sample(name: str) -> list[str]:
    """Resolve a Stage 2 sample-set name to a naturally-sorted list of image paths.

    Names prefixed with "[builtin] " refer to MV-SAM3D's bundled sets, whose
    images live in <set>/images/; all other names are user / Stage-1 sets stored
    as a flat folder of images under SAMPLES_S2_DIR.

    Returns an empty list when the resolved folder does not exist.
    """
    builtin_prefix = "[builtin] "
    if name.startswith(builtin_prefix):
        # removeprefix (not str.replace) so a set whose name happens to contain
        # "[builtin] " elsewhere in the string is not mangled.
        images_dir = BUILTIN_S2_DIR / name.removeprefix(builtin_prefix) / "images"
    else:
        images_dir = SAMPLES_S2_DIR / name  # flat folder of images
    if not images_dir.exists():
        return []
    exts = {".png", ".jpg", ".jpeg", ".webp"}
    return sorted(
        (str(p) for p in images_dir.iterdir() if p.suffix.lower() in exts),
        key=lambda f: natural_key(Path(f).name),
    )
88
+
89
+
90
+ # ── Stage 1: Multi-View Generation ─────────────────────────────────────────
91
+
92
def run_stage1(
    image_file,
    angles_str,
    vertical_angle,
    zoom,
    lora_scale,
    guidance_scale,
    num_steps,
    seed_val,
    progress=gr.Progress(),
):
    """Generate multi-view images from a single input image.

    Returns a list of PNG file paths (one per requested angle) suitable for
    the Stage 1 gallery. Raises gr.Error when no image is supplied or the
    angle list cannot be parsed.
    """
    if image_file is None:
        raise gr.Error("Please upload or select a source image.")

    # gr.Image(type="filepath") yields a str; file-like uploads expose .name.
    img_path = image_file if isinstance(image_file, str) else image_file.name

    # Parse the comma-separated horizontal angles.
    try:
        angles = [float(a.strip()) for a in angles_str.split(",") if a.strip()]
    except ValueError:
        raise gr.Error("Invalid angles format. Use comma-separated numbers like: 0,60,120,180,240,300")

    if not angles:
        angles = DEFAULT_ANGLES

    # Negative (-1 in the UI) means "random seed". Use an explicit None check,
    # not truthiness, so a deliberate seed of 0 is honored rather than dropped.
    seed = int(seed_val) if seed_val is not None and int(seed_val) >= 0 else None

    progress(0.05, desc="Uploading image & starting generation...")

    results = generate_multiview_parallel(
        image_path=img_path,
        horizontal_angles=angles,
        vertical_angle=vertical_angle,
        zoom=zoom,
        lora_scale=lora_scale,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_steps),
        seed=seed,
        on_progress=lambda frac, msg: progress(0.05 + frac * 0.9, desc=msg),
    )

    progress(0.98, desc="Saving generated views...")

    # Write each view to its own temp PNG for the gallery display.
    # mkstemp + os.close avoids leaking an open file descriptor per view
    # (NamedTemporaryFile kept the fd open while PIL re-opened the path,
    # which also fails on Windows).
    out_paths = []
    for angle, img in results:
        fd, tmp_path = tempfile.mkstemp(suffix=f"_{int(angle)}deg.png")
        os.close(fd)
        img.save(tmp_path, format="PNG")
        out_paths.append(tmp_path)

    progress(1.0, desc=f"Done! Generated {len(results)} views.")
    return out_paths
146
+
147
+
148
def send_to_stage2(gallery_images, progress=gr.Progress()):
    """Persist Stage 1 gallery images as a new timestamped Stage 2 sample set."""
    if not gallery_images:
        raise gr.Error("No generated views to send. Run Stage 1 first.")

    # One folder per export, named by wall-clock timestamp.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    name = f"multiview_{stamp}"
    target_dir = SAMPLES_S2_DIR / name
    target_dir.mkdir(parents=True, exist_ok=True)

    for idx, entry in enumerate(gallery_images):
        # Gallery entries can be (filepath, caption) pairs or bare filepaths
        # depending on the gradio version — handle both.
        source = entry[0] if isinstance(entry, (list, tuple)) else entry
        shutil.copy2(source, target_dir / f"{idx}.png")

    # Refresh the Stage 2 dropdown and preselect the freshly saved set.
    choices = _list_stage2_samples()
    return (
        gr.update(choices=choices, value=name),
        f"✅ Saved {len(gallery_images)} views as sample set '{name}'",
    )
170
+
171
+
172
def save_stage1_sample(image_file):
    """Copy the uploaded image into SAMPLES_S1_DIR and refresh the dropdown."""
    if image_file is None:
        return gr.update()
    source = image_file if isinstance(image_file, str) else image_file.name
    # Keep the original filename; an existing sample of the same name is replaced.
    shutil.copy2(source, SAMPLES_S1_DIR / Path(source).name)
    return gr.update(choices=_list_stage1_samples())
181
+
182
+
183
def load_stage1_sample(sample_name):
    """Resolve a Stage 1 sample name to its file path, or None when absent."""
    if not sample_name:
        return None
    candidate = SAMPLES_S1_DIR / sample_name
    if candidate.exists():
        return str(candidate)
    return None
189
+
190
+
191
+ # ── Stage 2: 3D Reconstruction ─────────────────────────────────────────────
192
+
193
def load_s2_sample(sample_name):
    """Load a Stage 2 sample set's image paths for the file-upload widget."""
    if not sample_name:
        return None
    paths = _load_stage2_sample(sample_name)
    # Empty result is reported as None so the Files component stays cleared.
    return paths or None
199
+
200
 
201
  def build_mv_input_from_uploads(
202
  files, mask_prompt: str, prompt: str | None, pick_mode: str, max_masks: int,
 
211
  if not files:
212
  raise gr.Error("Upload at least 2 images (multi-view).")
213
 
 
214
  files_sorted = sorted(files, key=lambda f: natural_key(Path(f).name))
215
 
216
  workdir = Path(tempfile.mkdtemp(prefix="mv_input_"))
 
225
  img = load_rgb_image(fp)
226
 
227
  if skip_sam3:
 
228
  progress((i + 1) / max(n, 1), desc=f"Removing white bg {i+1}/{n}")
229
  alpha = remove_white_background_alpha(img, threshold=white_threshold)
230
  else:
 
231
  progress((i + 1) / max(n, 1), desc=f"SAM-3 masking view {i+1}/{n}")
232
  alpha = extract_main_object_alpha(
233
  image_path=fp,
 
236
  max_masks=max_masks,
237
  )
238
 
 
239
  rgb_path = images_dir / f"{i}.png"
240
  save_rgb_png(img, rgb_path)
241
 
 
242
  rgba = make_rgba_with_alpha(img, alpha)
243
  mask_path = masks_dir / f"{i}.png"
244
  save_rgba_png(rgba, mask_path)
245
 
246
+ return input_dir, [str(images_dir / f"{i}.png") for i in range(n)]
247
 
248
 
249
+ def run_stage2_pipeline(
250
  files,
251
+ s2_sample_name,
252
  mask_prompt,
253
  sam3_prompt,
254
  pick_mode,
 
262
  da3_npz,
263
  progress=gr.Progress(),
264
  ):
265
+ # Resolve files: from upload or from sample
266
+ actual_files = files
267
+ if (not actual_files or len(actual_files) == 0) and s2_sample_name:
268
+ actual_files = _load_stage2_sample(s2_sample_name)
269
+
270
+ if not actual_files or len(actual_files) < 2:
271
+ raise gr.Error("Provide at least 2 multi-view images (upload or select a sample set).")
272
+
273
+ # 1) ensure checkpoints
274
  progress(0.02, desc="Ensuring SAM-3D checkpoints...")
275
  ensure_sam3d_checkpoints()
276
  link_mv_sam3d_checkpoints(MV_ROOT)
277
 
278
+ # 2) build input folder
279
  progress(0.05, desc="Preparing multi-view input...")
280
  input_dir, preview_imgs = build_mv_input_from_uploads(
281
+ files=actual_files,
282
  mask_prompt=mask_prompt,
283
  prompt=sam3_prompt,
284
  pick_mode=pick_mode,
 
301
  da3_npz_path=(Path(da3_npz) if da3_npz else None),
302
  )
303
 
 
304
  progress(1.0, desc="Done!")
305
  return (
306
  out.viewer_path,
 
312
  )
313
 
314
 
315
+ # ══════════════════════════════════════════════════════════════════════════════
316
+ # Gradio UI
317
+ # ══════════════════════════════════════════════════════════════════════════════
318
+
319
+ with gr.Blocks(title="Multi-View 3D Reconstruction") as demo:
320
+
321
  gr.Markdown(
322
  """
323
+ # 🎯 Multi-View 3D Object Reconstruction
324
+ **Stage 1** — Generate multi-view images from a single photo (fal Qwen Multi-Angles)
325
+ **Stage 2** — Reconstruct 3D model from multi-view images (MV-SAM3D)
326
 
327
+ Each stage runs independently. Stage 1 results can be sent to Stage 2 as a sample set.
328
  """
329
  )
330
 
331
+ # ── STAGE 1 ─────────────────────────────────────────────────────────────
332
+ with gr.Accordion("🖼️ Stage 1: Single Image → Multi-View Generation", open=True):
333
+
334
+ with gr.Row():
335
+ with gr.Column(scale=2):
336
+ s1_image = gr.Image(
337
+ label="Source image (single object photo)",
338
+ type="filepath",
339
+ height=300,
340
+ )
341
+ with gr.Column(scale=1):
342
+ s1_samples_dd = gr.Dropdown(
343
+ label="Stage 1 Samples",
344
+ choices=_list_stage1_samples(),
345
+ value=None,
346
+ interactive=True,
347
+ info="Select a saved sample or upload a new image",
348
+ )
349
+ s1_save_btn = gr.Button("💾 Save current image as sample", size="sm")
350
+
351
+ # Advanced settings
352
+ with gr.Accordion("⚙️ Advanced Settings", open=False):
353
+ with gr.Row():
354
+ s1_angles = gr.Textbox(
355
+ label="Horizontal angles (comma-separated degrees)",
356
+ value="0, 60, 120, 180, 240, 300",
357
+ info="0°=front, 90°=right, 180°=back, 270°=left",
358
+ )
359
+ s1_vertical = gr.Slider(
360
+ label="Vertical angle",
361
+ minimum=-30, maximum=90, value=0, step=5,
362
+ info="-30°=low angle, 0°=eye level, 90°=bird's eye",
363
+ )
364
+ with gr.Row():
365
+ s1_zoom = gr.Slider(label="Zoom", minimum=0, maximum=10, value=5, step=0.5)
366
+ s1_lora = gr.Slider(label="LoRA scale", minimum=0, maximum=2, value=1.0, step=0.1)
367
+ with gr.Row():
368
+ s1_guidance = gr.Slider(label="Guidance scale", minimum=1, maximum=20, value=4.5, step=0.5)
369
+ s1_steps = gr.Slider(label="Inference steps", minimum=10, maximum=50, value=28, step=1)
370
+ s1_seed = gr.Number(label="Seed (-1 = random)", value=-1, precision=0)
371
+
372
+ s1_run_btn = gr.Button("🚀 Generate Multi-Views", variant="primary", size="lg")
373
+
374
+ s1_gallery = gr.Gallery(
375
+ label="Generated multi-view images",
376
+ columns=6,
377
+ height=280,
378
+ object_fit="contain",
379
+ )
380
+
381
+ with gr.Row():
382
+ s1_send_btn = gr.Button("📤 Send to Stage 2 Sample Sets", variant="secondary", size="lg")
383
+ s1_status = gr.Textbox(label="Status", interactive=False, scale=2)
384
+
385
+ # ── Stage 1 event handlers
386
+ s1_samples_dd.change(
387
+ fn=load_stage1_sample,
388
+ inputs=[s1_samples_dd],
389
+ outputs=[s1_image],
390
+ )
391
+
392
+ s1_save_btn.click(
393
+ fn=save_stage1_sample,
394
+ inputs=[s1_image],
395
+ outputs=[s1_samples_dd],
396
+ )
397
 
398
+ s1_run_btn.click(
399
+ fn=run_stage1,
400
+ inputs=[
401
+ s1_image,
402
+ s1_angles,
403
+ s1_vertical,
404
+ s1_zoom,
405
+ s1_lora,
406
+ s1_guidance,
407
+ s1_steps,
408
+ s1_seed,
409
+ ],
410
+ outputs=[s1_gallery],
411
  )
412
 
413
+ # ── STAGE 2 ─────────────────────────────────────────────────────────────
414
+ with gr.Accordion("🧊 Stage 2: Multi-View → 3D Reconstruction (MV-SAM3D)", open=True):
415
+
416
+ with gr.Row():
417
+ s2_samples_dd = gr.Dropdown(
418
+ label="Sample sets",
419
+ choices=_list_stage2_samples(),
420
+ value=None,
421
+ interactive=True,
422
+ info="Select a built-in or generated sample set",
423
+ scale=1,
424
+ )
425
+ s2_load_btn = gr.Button("📂 Load Sample", size="sm", scale=0)
426
+
427
+ with gr.Row():
428
+ s2_files = gr.Files(label="Multi-view images (PNG/JPG)", file_types=["image"])
429
+ s2_da3_npz = gr.File(
430
+ label="(Optional) DA3 output (.npz)",
431
+ file_types=[".npz"],
432
+ )
433
+
434
+ with gr.Row():
435
+ s2_mask_prompt = gr.Textbox(label="mask_prompt folder name", value="object")
436
+ s2_skip_sam3 = gr.Checkbox(
437
+ label="Skip SAM-3 (remove white background instead)",
438
+ value=False,
439
+ info="Use simple white background removal instead of fal SAM-3 API",
440
+ )
441
+
442
+ s2_white_threshold = gr.Slider(
443
  label="White BG threshold (R,G,B ≥ threshold → background)",
444
  minimum=200, maximum=255, value=240, step=1,
445
  visible=False,
446
  )
447
 
448
+ with gr.Row():
449
+ s2_sam3_prompt = gr.Textbox(
450
+ label="SAM-3 prompt (optional, e.g. 'stuffed toy')",
451
+ value="",
452
+ )
453
+ s2_pick_mode = gr.Dropdown(
454
+ label="Pick main object mode",
455
+ choices=["largest", "best_score"],
456
+ value="largest",
457
+ )
458
+ s2_max_masks = gr.Slider(
459
+ label="SAM-3 max_masks",
460
+ minimum=1, maximum=10, value=5, step=1,
461
+ )
462
+
463
+ # Toggle SAM-3 vs white BG options
464
+ s2_skip_sam3.change(
465
+ fn=lambda s: (
466
+ gr.update(visible=s),
467
+ gr.update(visible=not s),
468
+ gr.update(visible=not s),
469
+ gr.update(visible=not s),
470
+ ),
471
+ inputs=[s2_skip_sam3],
472
+ outputs=[s2_white_threshold, s2_sam3_prompt, s2_pick_mode, s2_max_masks],
473
+ )
474
+
475
+ with gr.Accordion("⚙️ MV-SAM3D Parameters", open=False):
476
+ s2_image_names = gr.Textbox(
477
+ label="image_names (comma-separated, optional)",
478
+ placeholder="0,1,2,3,4,5",
479
+ )
480
+ with gr.Row():
481
+ s2_stage1_w = gr.Checkbox(label="Stage 1 weighting", value=False)
482
+ s2_stage2_w = gr.Checkbox(label="Stage 2 weighting", value=False)
483
+ s2_w_source = gr.Dropdown(
484
+ label="Stage 2 weight source",
485
+ choices=["entropy", "visibility", "mixed"],
486
+ value="entropy",
487
+ )
488
+
489
+ s2_run_btn = gr.Button("🚀 Run 3D Reconstruction", variant="primary", size="lg")
490
+
491
+ with gr.Row():
492
+ s2_viewer = gr.Model3D(label="3D Preview (GLB)")
493
+ s2_gallery = gr.Gallery(
494
+ label="Prepared inputs (images/*.png)",
495
+ columns=4,
496
+ height=240,
497
+ )
498
+
499
+ with gr.Row():
500
+ s2_glb = gr.File(label="result.glb")
501
+ s2_ply = gr.File(label="result.ply")
502
+ s2_npz = gr.File(label="params.npz")
503
 
504
+ s2_log = gr.Textbox(label="Log tail", lines=15)
 
 
 
 
 
505
 
506
+ # ── Stage 2 event handlers
507
 
508
+ # Wire "Send to Stage 2" button from Stage 1
509
+ s1_send_btn.click(
510
+ fn=send_to_stage2,
511
+ inputs=[s1_gallery],
512
+ outputs=[s2_samples_dd, s1_status],
513
+ )
514
 
515
+ # Load sample set into file upload
516
+ s2_load_btn.click(
517
+ fn=load_s2_sample,
518
+ inputs=[s2_samples_dd],
519
+ outputs=[s2_files],
520
+ )
521
+
522
+ # Run 3D reconstruction
523
+ s2_run_btn.click(
524
+ fn=run_stage2_pipeline,
525
+ inputs=[
526
+ s2_files,
527
+ s2_samples_dd,
528
+ s2_mask_prompt,
529
+ s2_sam3_prompt,
530
+ s2_pick_mode,
531
+ s2_max_masks,
532
+ s2_skip_sam3,
533
+ s2_white_threshold,
534
+ s2_image_names,
535
+ s2_stage1_w,
536
+ s2_stage2_w,
537
+ s2_w_source,
538
+ s2_da3_npz,
539
+ ],
540
+ outputs=[s2_viewer, s2_glb, s2_ply, s2_npz, s2_log, s2_gallery],
541
  )
542
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
543
 
544
  if __name__ == "__main__":
545
  demo.launch()
 
src/fal_multiview.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Stage 1: Generate multi-view images from a single image using
3
+ fal-ai/qwen-image-edit-2511-multiple-angles (parallel calls).
4
+ """
5
+
6
+ import io
7
+ from concurrent.futures import ThreadPoolExecutor, as_completed
8
+ from pathlib import Path
9
+ from typing import Optional
10
+
11
+ import requests
12
+ from PIL import Image
13
+ import fal_client
14
+
15
+ DEFAULT_ANGLES = [0, 60, 120, 180, 240, 300]
16
+
17
+ FAL_ENDPOINT = "fal-ai/qwen-image-edit-2511-multiple-angles"
18
+
19
+
20
def _download_image(url: str, timeout: int = 120) -> Image.Image:
    """Fetch *url* and decode the response body as an RGB PIL image."""
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    buf = io.BytesIO(resp.content)
    return Image.open(buf).convert("RGB")
24
+
25
+
26
def _generate_single_view(
    image_url: str,
    horizontal_angle: float,
    vertical_angle: float = 0.0,
    zoom: float = 5.0,
    lora_scale: float = 1.0,
    guidance_scale: float = 4.5,
    num_inference_steps: int = 28,
    seed: Optional[int] = None,
) -> Image.Image:
    """Request one view at *horizontal_angle* from the fal endpoint.

    Blocks until the job finishes and returns the generated PIL image.
    Raises RuntimeError when the API responds without any image.
    """
    payload = {
        "image_urls": [image_url],
        "horizontal_angle": horizontal_angle,
        "vertical_angle": vertical_angle,
        "zoom": zoom,
        "lora_scale": lora_scale,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
        "output_format": "png",
        "num_images": 1,
        "enable_safety_checker": False,
    }
    # Only send a seed when the caller pinned one; omitting it lets fal randomize.
    if seed is not None:
        payload["seed"] = seed

    response = fal_client.subscribe(FAL_ENDPOINT, arguments=payload)

    generated = response.get("images", [])
    if not generated:
        raise RuntimeError(f"fal multiview returned no images for angle {horizontal_angle}°")
    return _download_image(generated[0]["url"])
58
+
59
+
60
def generate_multiview_parallel(
    image_path: str | Path,
    horizontal_angles: list[float] | None = None,
    vertical_angle: float = 0.0,
    zoom: float = 5.0,
    lora_scale: float = 1.0,
    guidance_scale: float = 4.5,
    num_inference_steps: int = 28,
    seed: Optional[int] = None,
    max_workers: int = 6,
    on_progress=None,
) -> list[tuple[float, Image.Image]]:
    """
    Generate multi-view images in parallel (one fal call per angle).

    The source image is uploaded to fal once and its URL shared by all workers.
    ``horizontal_angles`` falls back to DEFAULT_ANGLES when None *or* empty —
    an empty list previously crashed with ThreadPoolExecutor(max_workers=0)
    and would have divided by zero in the progress callback.

    ``on_progress``, when given, is called as ``on_progress(fraction, message)``
    after each view completes. The first worker exception, if any, is re-raised.

    Returns list of (horizontal_angle, PIL.Image) sorted by angle.
    """
    # Copy the defaults so the module-level list can never be mutated through
    # this local name.
    if not horizontal_angles:
        horizontal_angles = list(DEFAULT_ANGLES)

    # Upload source image once
    image_url = fal_client.upload_file(str(image_path))

    results: dict[float, Image.Image] = {}
    total = len(horizontal_angles)

    with ThreadPoolExecutor(max_workers=min(max_workers, total)) as pool:
        future_to_angle = {
            pool.submit(
                _generate_single_view,
                image_url=image_url,
                horizontal_angle=angle,
                vertical_angle=vertical_angle,
                zoom=zoom,
                lora_scale=lora_scale,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                seed=seed,
            ): angle
            for angle in horizontal_angles
        }

        done_count = 0
        for future in as_completed(future_to_angle):
            angle = future_to_angle[future]
            done_count += 1
            img = future.result()  # re-raises any worker exception
            results[angle] = img
            if on_progress:
                on_progress(done_count / total, f"Generated view {done_count}/{total} ({angle}°)")

    # Deterministic ordering regardless of completion order
    return [(a, results[a]) for a in sorted(results)]
112
+