bennyguo commited on
Commit
db54c72
·
1 Parent(s): 0798bf7

Run full pipeline in a single GPU call

Browse files

Fold preprocessing into generate() so one button click triggers one
@spaces.GPU acquisition instead of two (one on image change, one on
generate). The preprocessed preview now updates after generation.

Files changed (1) hide show
  1. app.py +9 -21
app.py CHANGED
@@ -86,23 +86,17 @@ def _viewer_iframe(ply_path: Path) -> str:
86
  # ----------------------------------------------------------------------------
87
 
88
  @spaces.GPU
89
- def on_image_change(image):
90
- """Run preprocessing as soon as the input changes — gives the user instant
91
- feedback on the matte/crop without waiting for the full generation."""
92
- if image is None:
93
- return None
94
- return PIPE.preprocess_image(image)
95
-
96
-
97
- @spaces.GPU
98
- def generate(prepared, seed: int, steps: int, guidance_scale: float,
99
  num_gaussians: int, output_format: str,
100
  progress=gr.Progress(track_tqdm=True)):
101
- if prepared is None:
102
- raise gr.Error("Please upload an image and wait for preprocessing to finish.")
 
 
103
 
104
  progress(0, desc="Generating...")
105
  t0 = time.time()
 
106
  gen = torch.Generator(device=PIPE._device).manual_seed(int(seed))
107
  cond = PIPE.encode_image(prepared, generator=gen)
108
  out = PIPE.sample_latent(cond, steps=int(steps),
@@ -127,7 +121,7 @@ def generate(prepared, seed: int, steps: int, guidance_scale: float,
127
 
128
  info = (f"{gaussian.get_xyz.shape[0]:,} gaussians · "
129
  f"generation: {gen_dt:.1f}s · saved: {download_path.name}")
130
- return _viewer_iframe(ply_path), gr.update(value=str(download_path), interactive=True), info
131
 
132
 
133
  # ----------------------------------------------------------------------------
@@ -176,16 +170,10 @@ with gr.Blocks(title="TripoSplat") as demo:
176
  viewer_out = gr.HTML(value=PLACEHOLDER_HTML, label="Spark.js viewer")
177
  file_out = gr.DownloadButton(label="Download", value=None, interactive=False)
178
 
179
- image_in.change(
180
- fn=on_image_change,
181
- inputs=[image_in],
182
- outputs=[prepared_out],
183
- )
184
-
185
  run_btn.click(
186
  fn=generate,
187
- inputs=[prepared_out, seed_in, steps_in, cfg_in, num_g_in, fmt_in],
188
- outputs=[viewer_out, file_out, info_out],
189
  )
190
 
191
 
 
86
  # ----------------------------------------------------------------------------
87
 
88
  @spaces.GPU
89
+ def generate(image, seed: int, steps: int, guidance_scale: float,
 
 
 
 
 
 
 
 
 
90
  num_gaussians: int, output_format: str,
91
  progress=gr.Progress(track_tqdm=True)):
92
+ """Run the full pipeline (preprocess + encode + sample + decode) in a
93
+ single GPU acquisition."""
94
+ if image is None:
95
+ raise gr.Error("Please upload an image first.")
96
 
97
  progress(0, desc="Generating...")
98
  t0 = time.time()
99
+ prepared = PIPE.preprocess_image(image)
100
  gen = torch.Generator(device=PIPE._device).manual_seed(int(seed))
101
  cond = PIPE.encode_image(prepared, generator=gen)
102
  out = PIPE.sample_latent(cond, steps=int(steps),
 
121
 
122
  info = (f"{gaussian.get_xyz.shape[0]:,} gaussians · "
123
  f"generation: {gen_dt:.1f}s · saved: {download_path.name}")
124
+ return prepared, _viewer_iframe(ply_path), gr.update(value=str(download_path), interactive=True), info
125
 
126
 
127
  # ----------------------------------------------------------------------------
 
170
  viewer_out = gr.HTML(value=PLACEHOLDER_HTML, label="Spark.js viewer")
171
  file_out = gr.DownloadButton(label="Download", value=None, interactive=False)
172
 
 
 
 
 
 
 
173
  run_btn.click(
174
  fn=generate,
175
+ inputs=[image_in, seed_in, steps_in, cfg_in, num_g_in, fmt_in],
176
+ outputs=[prepared_out, viewer_out, file_out, info_out],
177
  )
178
 
179