cp524 committed on
Commit 05f9f55 · 1 Parent(s): 971c192

Fix decoding batch size bug

Files changed (2):
  1. app.py +3 -15
  2. src/smc/pipeline.py +9 -7
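
The bug in a nutshell: the pipeline decodes particles in chunks of `batch_p`, but the decoder helpers took an explicit `batch_size` argument that every call site filled with `batch_p`. When the particle count is not a multiple of `batch_p`, the final chunk is smaller, so treating it as if it held `batch_p` items fails. The fix drops the argument and reads the size off the tensor. A minimal sketch of the failure mode (shapes and names here are illustrative, not the pipeline's actual internals):

```python
import torch

latents = torch.zeros(10, 64)   # 10 particles, e.g. 64 tokens each
batch_p = 4                     # decode in chunks of 4

for j in range(0, latents.shape[0], batch_p):
    latents_batch = latents[j:j + batch_p]
    # Before the fix: batch_p (always 4) was passed as the batch size,
    # so the final chunk of 2 particles was treated as if it held 4.
    # After the fix: the size is read off the tensor itself.
    batch_size = latents_batch.shape[0]
    print(batch_size)           # 4, 4, 2
```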
app.py CHANGED
@@ -1,18 +1,4 @@
- """
- Gradio app to compare multiple inference methods for Monetico model.
-
- This file wires your existing inference functions (infer_pretrained, infer_smc_grad)
- into a single UI with one shared prompt and per-method collapsed setting panels.
-
- Place this file at repository root (next to src/) and run:
-     python app.py
-
- Notes:
- - The code assumes your module that contains infer_pretrained and infer_smc_grad
-   is importable (e.g. package root with src/ on PYTHONPATH). Adjust imports if needed.
- - Defaults provided are reasonable starting points; tweak as you like.
- """
-
+ import traceback
  import gradio as gr

  # Import your inference functions and dataclasses
@@ -79,6 +65,7 @@ def run_inference_all(
          pretrained_output = infer_pretrained(pretrained_cfg, device=DEVICE)
          pretrained_images = pretrained_output.images
      except Exception as e:
+         traceback.print_exc()
          pretrained_images = []
          pretrained_output = None
          pretrained_error = f"Pretrained inference error: {e}"
@@ -107,6 +94,7 @@
          smc_grad_output = infer_smc_grad(smc_grad_cfg, device=DEVICE)
          # The above line is defensive; simpler: pass smc_grad_device value used by gradio - will be provided.
      except Exception as e:
+         traceback.print_exc()
          smc_grad_images = []
          smc_grad_output = None
          smc_grad_error = f"SMC inference error: {e}"
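
The app.py change is independent of the batch-size fix: it imports `traceback` and calls `traceback.print_exc()` in both `except` blocks, so the full stack trace reaches the server log while the UI keeps showing only the short per-method error string. A generic sketch of the pattern (the helper name is hypothetical, not from app.py):

```python
import traceback

def run_safely(fn, *args, **kwargs):
    """Run fn, logging the full trace but returning a short UI message."""
    try:
        return fn(*args, **kwargs), None
    except Exception as e:
        traceback.print_exc()                 # full trace to the console/log
        return None, f"inference error: {e}"  # short message for the UI
```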
src/smc/pipeline.py CHANGED
@@ -242,7 +242,7 @@ class Pipeline(
          sample = F.gumbel_softmax(tmp_logits, tau=tau, hard=True)
          if use_continuous_formulation:
              sample = gamma * sample + (ONE - MASK) * latents_one_hot
-         sample = self._decode_one_hot_latents(sample, batch_p, height, width, "pt")
+         sample = self._decode_one_hot_latents(sample, height, width, "pt")
          tmp_rewards[:, phi_i] = reward_fn(sample)
      tmp_rewards = logmeanexp(tmp_rewards * scale_cur, dim=-1) / scale_cur

@@ -332,7 +332,7 @@
      tmp_logp_x0 = self.model._subs_parameterization(tmp_logits, latents_batch)
      for phi_i in range(phi):
          sample = F.gumbel_softmax(tmp_logp_x0, tau=tau, hard=True).argmax(dim=-1)
-         sample = self._decode_latents(sample, batch_p, height, width, "pt")
+         sample = self._decode_latents(sample, height, width, "pt")
          tmp_rewards[:, phi_i] = reward_fn(sample)
      tmp_rewards = logmeanexp(tmp_rewards * scale_cur, dim=-1) / scale_cur

@@ -427,7 +427,7 @@
      tmp_logp_x0 = self._subs_parameterization(tmp_logits, latents_batch)
      for phi_i in range(phi):
          sample = F.gumbel_softmax(tmp_logp_x0, tau=tau, hard=True).argmax(dim=-1)
-         sample = self._decode_latents(sample, batch_p, height, width, "pt")
+         sample = self._decode_latents(sample, height, width, "pt")
          tmp_rewards[:, phi_i] = reward_fn(sample)
      tmp_rewards = logmeanexp(tmp_rewards * scale_cur, dim=-1) / scale_cur

@@ -517,7 +517,7 @@
      for j in range(0, total_particles, batch_p):
          latents_batch = latents[j:j+batch_p]
          with torch.no_grad():
-             sample = self._decode_latents(latents_batch, batch_p, height, width, "pt")
+             sample = self._decode_latents(latents_batch, height, width, "pt")
              tmp_rewards = reward_fn(sample)
          rewards[j:j+batch_p] = tmp_rewards
          log_twist[j:j+batch_p] = tmp_rewards * scale_cur

@@ -563,7 +563,7 @@
      for j in range(0, batches, batch_p):
          latents_batch = latents[j:j+batch_p]
          outputs.extend(
-             self._decode_latents(latents_batch, batch_p, height, width, output_type) # type: ignore
+             self._decode_latents(latents_batch, height, width, output_type) # type: ignore
          )
      if output_type == "pt":
          outputs = torch.stack(outputs, dim=0)

@@ -602,7 +602,8 @@
          tmp_logits = torch.cat([tmp_logits, pad_logits], dim=-1)
          return tmp_logits

-     def _decode_latents(self, latents, batch_size, height, width, output_type):
+     def _decode_latents(self, latents, height, width, output_type):
+         batch_size = latents.shape[0]
          if output_type == "latent":
              output = latents
          else:

@@ -624,7 +625,8 @@
          self.vqvae.half()
          return output

-     def _decode_one_hot_latents(self, latents_one_hot, batch_size, height, width, output_type):
+     def _decode_one_hot_latents(self, latents_one_hot, height, width, output_type):
+         batch_size = latents_one_hot.shape[0]
          shape = (
              batch_size,
              height // self.vae_scale_factor,