Fix decoding batch size bug
Files changed:
- app.py (+3 -15)
- src/smc/pipeline.py (+9 -7)
app.py
CHANGED

@@ -1,18 +1,4 @@
-"""
-Gradio app to compare multiple inference methods for Monetico model.
-
-This file wires your existing inference functions (infer_pretrained, infer_smc_grad)
-into a single UI with one shared prompt and per-method collapsed setting panels.
-
-Place this file at repository root (next to src/) and run:
-    python app.py
-
-Notes:
-- The code assumes your module that contains infer_pretrained and infer_smc_grad
-  is importable (e.g. package root with src/ on PYTHONPATH). Adjust imports if needed.
-- Defaults provided are reasonable starting points; tweak as you like.
-"""
-
+import traceback
 import gradio as gr
 
 # Import your inference functions and dataclasses

@@ -79,6 +65,7 @@ def run_inference_all(
         pretrained_output = infer_pretrained(pretrained_cfg, device=DEVICE)
         pretrained_images = pretrained_output.images
     except Exception as e:
+        traceback.print_exc()
         pretrained_images = []
         pretrained_output = None
         pretrained_error = f"Pretrained inference error: {e}"

@@ -107,6 +94,7 @@ def run_inference_all(
         smc_grad_output = infer_smc_grad(smc_grad_cfg, device=DEVICE)
         # The above line is defensive; simpler: pass smc_grad_device value used by gradio - will be provided.
     except Exception as e:
+        traceback.print_exc()
         smc_grad_images = []
         smc_grad_output = None
         smc_grad_error = f"SMC inference error: {e}"
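For context, a minimal sketch of the error-handling pattern this change adds to app.py (the safe_call wrapper and its names are hypothetical, not part of the repo): traceback.print_exc() writes the full stack trace to the Space logs, while the UI still gets a short, user-facing error string.

    import traceback

    def safe_call(fn, *args, **kwargs):
        # Run an inference callable; on failure, log the full traceback
        # and return (None, error_message) instead of raising.
        try:
            return fn(*args, **kwargs), None
        except Exception as e:
            traceback.print_exc()  # full stack trace to stderr / Space logs
            return None, f"Inference error: {e}"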
src/smc/pipeline.py
CHANGED
@@ -242,7 +242,7 @@ class Pipeline(
             sample = F.gumbel_softmax(tmp_logits, tau=tau, hard=True)
             if use_continuous_formulation:
                 sample = gamma * sample + (ONE - MASK) * latents_one_hot
-            sample = self._decode_one_hot_latents(sample,
+            sample = self._decode_one_hot_latents(sample, height, width, "pt")
             tmp_rewards[:, phi_i] = reward_fn(sample)
         tmp_rewards = logmeanexp(tmp_rewards * scale_cur, dim=-1) / scale_cur
 
@@ -332,7 +332,7 @@ class Pipeline(
         tmp_logp_x0 = self.model._subs_parameterization(tmp_logits, latents_batch)
         for phi_i in range(phi):
             sample = F.gumbel_softmax(tmp_logp_x0, tau=tau, hard=True).argmax(dim=-1)
-            sample = self._decode_latents(sample,
+            sample = self._decode_latents(sample, height, width, "pt")
             tmp_rewards[:, phi_i] = reward_fn(sample)
         tmp_rewards = logmeanexp(tmp_rewards * scale_cur, dim=-1) / scale_cur
 
@@ -427,7 +427,7 @@ class Pipeline(
         tmp_logp_x0 = self._subs_parameterization(tmp_logits, latents_batch)
         for phi_i in range(phi):
             sample = F.gumbel_softmax(tmp_logp_x0, tau=tau, hard=True).argmax(dim=-1)
-            sample = self._decode_latents(sample,
+            sample = self._decode_latents(sample, height, width, "pt")
             tmp_rewards[:, phi_i] = reward_fn(sample)
         tmp_rewards = logmeanexp(tmp_rewards * scale_cur, dim=-1) / scale_cur
 
@@ -517,7 +517,7 @@ class Pipeline(
         for j in range(0, total_particles, batch_p):
             latents_batch = latents[j:j+batch_p]
             with torch.no_grad():
-                sample = self._decode_latents(latents_batch,
+                sample = self._decode_latents(latents_batch, height, width, "pt")
             tmp_rewards = reward_fn(sample)
             rewards[j:j+batch_p] = tmp_rewards
             log_twist[j:j+batch_p] = tmp_rewards * scale_cur
 
@@ -563,7 +563,7 @@ class Pipeline(
         for j in range(0, batches, batch_p):
             latents_batch = latents[j:j+batch_p]
             outputs.extend(
-                self._decode_latents(latents_batch,
+                self._decode_latents(latents_batch, height, width, output_type)  # type: ignore
             )
         if output_type == "pt":
             outputs = torch.stack(outputs, dim=0)
 
@@ -602,7 +602,8 @@ class Pipeline(
         tmp_logits = torch.cat([tmp_logits, pad_logits], dim=-1)
         return tmp_logits
 
-    def _decode_latents(self, latents,
+    def _decode_latents(self, latents, height, width, output_type):
+        batch_size = latents.shape[0]
         if output_type == "latent":
             output = latents
         else:
 
@@ -624,7 +625,8 @@ class Pipeline(
             self.vqvae.half()
         return output
 
-    def _decode_one_hot_latents(self, latents_one_hot,
+    def _decode_one_hot_latents(self, latents_one_hot, height, width, output_type):
+        batch_size = latents_one_hot.shape[0]
         shape = (
             batch_size,
             height // self.vae_scale_factor,
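The substance of the fix, as visible in the diff: _decode_latents and _decode_one_hot_latents now take height, width, and output_type and derive the batch size from the input tensor itself, rather than receiving it as an argument. A minimal sketch of the idea, assuming latents is a flat (batch, tokens) grid; the vae_scale_factor default of 8 is made up, and the reshape stands in for the real VQ-VAE decode:

    import torch

    def decode_latents(latents: torch.Tensor, height: int, width: int,
                       vae_scale_factor: int = 8):
        # Derive the batch size from the tensor; it always matches the
        # actual input, unlike a separately threaded batch-size argument.
        batch_size = latents.shape[0]
        grid = (batch_size,
                height // vae_scale_factor,
                width // vae_scale_factor)
        return latents.view(grid)  # placeholder for the real decoder call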
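Why this matters at the call sites: the loops above decode in slices of batch_p, and the final slice of latents[j:j+batch_p] can be shorter than batch_p, so a fixed batch-size argument would presumably mis-shape that last batch. A small standalone illustration (all numbers made up):

    import torch

    latents = torch.randint(0, 1024, (10, 256))  # 10 particles, 16x16 token grid
    batch_p = 4
    for j in range(0, latents.shape[0], batch_p):
        chunk = latents[j:j+batch_p]
        # prints 4, 4, 2 -- only the tensor itself knows the true
        # batch size of the final slice
        print(chunk.shape[0])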