cp524 commited on
Commit
204cd3a
·
1 Parent(s): 547bd75

update app

Browse files
Files changed (1) hide show
  1. app.py +91 -147
app.py CHANGED
@@ -21,39 +21,13 @@ from typing import List
21
  # Adjust the import path if your file is located elsewhere
22
  from src.smc.inference import infer_pretrained, infer_smc_grad, PretrainedInferenceConfig, SMCGradInferenceConfig
23
 
24
- # Global constants (adjust if needed)
25
- MAX_SEED = 2 ** 32 - 1
26
- MAX_IMAGE_SIZE = 1024
27
- DEVICE = "cpu"
28
-
29
-
30
- # Sensible defaults (change to match your model constraints)
31
- DEFAULTS = {
32
- "resolution": 512,
33
- "pretrained_steps": 20,
34
- "pretrained_CFG": 7.5,
35
- "pretrained_num_batches": 1,
36
-
37
- "smc_steps": 20,
38
- "smc_CFG": 7.5,
39
- "smc_num_batches": 1,
40
- "smc_num_particles": 4,
41
- "smc_ess_threshold": 0.5,
42
- "smc_partial_resampling": True,
43
- "smc_resample_frequency": 5,
44
- "smc_kl_weight": 0.1,
45
- "smc_lambda_tempering": False,
46
- "smc_lambda_one_at": 0.5,
47
- "smc_phi": 1,
48
- "smc_tau": 0.1,
49
- }
50
 
51
  examples = [
52
  "A dreamy Monet-style landscape with soft brush strokes",
53
  "Vibrant city street at dawn in impressionist style",
54
  ]
55
 
56
-
57
  def _format_inference_output(out) -> str:
58
  """Return a short summary string for the UI"""
59
  if out is None:
@@ -71,34 +45,28 @@ def run_inference_all(
71
 
72
  # Pretrained method controls
73
  pretrained_negative_prompt,
74
- pretrained_resolution,
75
  pretrained_CFG,
76
  pretrained_steps,
77
- pretrained_num_batches,
78
- pretrained_device,
79
 
80
  # SMC-grad method controls
81
- smc_negative_prompt,
82
- smc_resolution,
83
- smc_CFG,
84
- smc_steps,
85
- smc_num_batches,
86
- smc_num_particles,
87
- smc_ess_threshold,
88
- smc_partial_resampling,
89
- smc_resample_frequency,
90
- smc_kl_weight,
91
- smc_lambda_tempering,
92
- smc_lambda_one_at,
93
- smc_use_continuous_formulation,
94
- smc_phi,
95
- smc_tau,
96
- smc_proposal_type,
97
  ):
98
  """Wrapper that runs both inference methods and returns UI-friendly outputs.
99
 
100
  Returns:
101
- pretrained_images, pretrained_info, smc_images, smc_info
102
  """
103
  # --- Pretrained ---
104
  pretrained_output = None
@@ -107,12 +75,10 @@ def run_inference_all(
107
  pretrained_cfg = PretrainedInferenceConfig(
108
  prompt=prompt,
109
  negative_prompt=pretrained_negative_prompt or "",
110
- resolution=int(pretrained_resolution),
111
  CFG=float(pretrained_CFG),
112
  steps=int(pretrained_steps),
113
- num_batches=int(pretrained_num_batches),
114
  )
115
- pretrained_output = infer_pretrained(pretrained_cfg, device=pretrained_device)
116
  pretrained_images = pretrained_output.images
117
  except Exception as e:
118
  pretrained_images = []
@@ -121,44 +87,41 @@ def run_inference_all(
121
  pretrained_images = [pretrained_error]
122
 
123
  # --- SMC-grad ---
124
- smc_output = None
125
- smc_images = []
126
  try:
127
- smc_cfg = SMCGradInferenceConfig(
128
  prompt=prompt,
129
- negative_prompt=smc_negative_prompt or "",
130
- ess_threshold=float(smc_ess_threshold),
131
- partial_resampling=bool(smc_partial_resampling),
132
- resample_frequency=int(smc_resample_frequency),
133
- resolution=int(smc_resolution),
134
- CFG=float(smc_CFG),
135
- steps=int(smc_steps),
136
- kl_weight=float(smc_kl_weight),
137
- lambda_tempering=bool(smc_lambda_tempering),
138
- lambda_one_at=float(smc_lambda_one_at),
139
- num_batches=int(smc_num_batches),
140
- num_particles=int(smc_num_particles),
141
- proposal_type=str(smc_proposal_type),
142
- use_continuous_formulation=bool(smc_use_continuous_formulation),
143
- phi=int(smc_phi),
144
- tau=float(smc_tau),
145
  )
146
- smc_output = infer_smc_grad(smc_cfg, device=DEVICE)
147
- # The above line is defensive; simpler: pass smc_device value used by gradio - will be provided.
148
  except Exception as e:
149
- smc_images = []
150
- smc_output = None
151
- smc_error = f"SMC inference error: {e}"
152
- smc_images = [smc_error]
153
 
154
  # If outputs are dataclasses with PIL images, gr.Gallery accepts lists of PIL images.
155
  pretrained_gallery = pretrained_images if isinstance(pretrained_images, list) else [pretrained_images]
156
- smc_gallery = smc_output.images if smc_output is not None else smc_images
157
 
158
  pretrained_info = _format_inference_output(pretrained_output)
159
- smc_info = _format_inference_output(smc_output)
160
 
161
- return pretrained_gallery, pretrained_info, smc_gallery, smc_info
162
 
163
 
164
  with gr.Blocks() as demo:
@@ -174,42 +137,35 @@ with gr.Blocks() as demo:
174
  with gr.Row():
175
  with gr.Column(scale=1, min_width=280):
176
  with gr.Accordion("Pretrained method — settings", open=False):
177
- pretrained_negative_prompt = gr.Textbox(label="Negative prompt", value="", lines=1)
178
- pretrained_resolution = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=DEFAULTS["resolution"], label="Resolution")
179
- pretrained_CFG = gr.Slider(0.0, 30.0, step=0.1, value=DEFAULTS["pretrained_CFG"], label="CFG")
180
- pretrained_steps = gr.Slider(1, 200, step=1, value=DEFAULTS["pretrained_steps"], label="Steps")
181
- pretrained_num_batches = gr.Slider(1, 8, step=1, value=DEFAULTS["pretrained_num_batches"], label="Batches")
182
- pretrained_device = gr.Dropdown(choices=["cpu", "cuda"], value=("cuda" if torch.cuda.is_available() else "cpu"), label="Device")
183
 
184
  with gr.Column(scale=2):
185
- pretrained_gallery = gr.Gallery(label="Pretrained outputs", show_label=True, elem_id="pretrained_gallery", height="auto", columns=4)
186
  pretrained_info = gr.Textbox(label="Pretrained info", interactive=False)
187
 
188
  # --- SMC-grad method row ---
189
  with gr.Row():
190
  with gr.Column(scale=1, min_width=280):
191
  with gr.Accordion("SMC-grad method — settings", open=False):
192
- smc_negative_prompt = gr.Textbox(label="Negative prompt", value="", lines=1)
193
- smc_resolution = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=DEFAULTS["resolution"], label="Resolution")
194
- smc_CFG = gr.Slider(0.0, 30.0, step=0.1, value=DEFAULTS["smc_CFG"], label="CFG")
195
- smc_steps = gr.Slider(1, 200, step=1, value=DEFAULTS["smc_steps"], label="Steps")
196
- smc_num_batches = gr.Slider(1, 8, step=1, value=DEFAULTS["smc_num_batches"], label="Batches")
197
- smc_num_particles = gr.Slider(1, 64, step=1, value=DEFAULTS["smc_num_particles"], label="Num particles")
198
- smc_ess_threshold = gr.Slider(0.0, 1.0, step=0.01, value=DEFAULTS["smc_ess_threshold"], label="ESS threshold")
199
- smc_partial_resampling = gr.Checkbox(label="Partial resampling", value=DEFAULTS["smc_partial_resampling"])
200
- smc_resample_frequency = gr.Slider(1, 50, step=1, value=DEFAULTS["smc_resample_frequency"], label="Resample frequency")
201
- smc_kl_weight = gr.Slider(0.0, 10.0, step=0.01, value=DEFAULTS["smc_kl_weight"], label="KL weight")
202
- smc_lambda_tempering = gr.Checkbox(label="Lambda tempering", value=DEFAULTS["smc_lambda_tempering"])
203
- smc_lambda_one_at = gr.Slider(0.0, 1.0, step=0.01, value=DEFAULTS["smc_lambda_one_at"], label="Lambda one at (fraction of steps)")
204
- smc_use_continuous_formulation = gr.Checkbox(label="Use continuous formulation", value=True)
205
- smc_phi = gr.Slider(1, 8, step=1, value=DEFAULTS["smc_phi"], label="Phi")
206
- smc_tau = gr.Slider(0.0, 1.0, step=0.001, value=DEFAULTS["smc_tau"], label="Tau")
207
- smc_proposal_type = gr.Dropdown(choices=["locally_optimal", "without_SMC", "other"], value="locally_optimal", label="Proposal type")
208
- smc_device = gr.Dropdown(choices=["cpu", "cuda"], value=("cuda" if torch.cuda.is_available() else "cpu"), label="Device")
209
 
210
  with gr.Column(scale=2):
211
- smc_gallery = gr.Gallery(label="SMC-grad outputs", show_label=True, elem_id="smc_gallery", height="auto", columns=4)
212
- smc_info = gr.Textbox(label="SMC-grad info", interactive=False)
213
 
214
  # Wire up the run button and prompt submit to the same runner
215
  run_button.click(
@@ -218,30 +174,24 @@ with gr.Blocks() as demo:
218
  prompt,
219
 
220
  pretrained_negative_prompt,
221
- pretrained_resolution,
222
  pretrained_CFG,
223
  pretrained_steps,
224
- pretrained_num_batches,
225
- pretrained_device,
226
-
227
- smc_negative_prompt,
228
- smc_resolution,
229
- smc_CFG,
230
- smc_steps,
231
- smc_num_batches,
232
- smc_num_particles,
233
- smc_ess_threshold,
234
- smc_partial_resampling,
235
- smc_resample_frequency,
236
- smc_kl_weight,
237
- smc_lambda_tempering,
238
- smc_lambda_one_at,
239
- smc_use_continuous_formulation,
240
- smc_phi,
241
- smc_tau,
242
- smc_proposal_type,
243
  ],
244
- outputs=[pretrained_gallery, pretrained_info, smc_gallery, smc_info],
245
  )
246
 
247
  # Also allow pressing Enter in the prompt to trigger
@@ -251,32 +201,26 @@ with gr.Blocks() as demo:
251
  prompt,
252
 
253
  pretrained_negative_prompt,
254
- pretrained_resolution,
255
  pretrained_CFG,
256
  pretrained_steps,
257
- pretrained_num_batches,
258
- pretrained_device,
259
-
260
- smc_negative_prompt,
261
- smc_resolution,
262
- smc_CFG,
263
- smc_steps,
264
- smc_num_batches,
265
- smc_num_particles,
266
- smc_ess_threshold,
267
- smc_partial_resampling,
268
- smc_resample_frequency,
269
- smc_kl_weight,
270
- smc_lambda_tempering,
271
- smc_lambda_one_at,
272
- smc_use_continuous_formulation,
273
- smc_phi,
274
- smc_tau,
275
- smc_proposal_type,
276
  ],
277
- outputs=[pretrained_gallery, pretrained_info, smc_gallery, smc_info],
278
  )
279
 
280
 
281
  if __name__ == "__main__":
282
- demo.launch()
 
21
  # Adjust the import path if your file is located elsewhere
22
  from src.smc.inference import infer_pretrained, infer_smc_grad, PretrainedInferenceConfig, SMCGradInferenceConfig
23
 
24
+ DEVICE = "cuda"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  examples = [
27
  "A dreamy Monet-style landscape with soft brush strokes",
28
  "Vibrant city street at dawn in impressionist style",
29
  ]
30
 
 
31
  def _format_inference_output(out) -> str:
32
  """Return a short summary string for the UI"""
33
  if out is None:
 
45
 
46
  # Pretrained method controls
47
  pretrained_negative_prompt,
 
48
  pretrained_CFG,
49
  pretrained_steps,
 
 
50
 
51
  # SMC-grad method controls
52
+ smc_grad_negative_prompt,
53
+ smc_grad_CFG,
54
+ smc_grad_steps,
55
+ smc_grad_num_particles,
56
+ smc_grad_ess_threshold,
57
+ smc_grad_partial_resampling,
58
+ smc_grad_resample_frequency,
59
+ smc_grad_kl_weight,
60
+ smc_grad_lambda_tempering,
61
+ smc_grad_lambda_one_at,
62
+ smc_grad_use_continuous_formulation,
63
+ smc_grad_phi,
64
+ smc_grad_tau,
 
 
 
65
  ):
66
  """Wrapper that runs both inference methods and returns UI-friendly outputs.
67
 
68
  Returns:
69
+ pretrained_images, pretrained_info, smc_grad_images, smc_grad_info
70
  """
71
  # --- Pretrained ---
72
  pretrained_output = None
 
75
  pretrained_cfg = PretrainedInferenceConfig(
76
  prompt=prompt,
77
  negative_prompt=pretrained_negative_prompt or "",
 
78
  CFG=float(pretrained_CFG),
79
  steps=int(pretrained_steps),
 
80
  )
81
+ pretrained_output = infer_pretrained(pretrained_cfg, device=DEVICE)
82
  pretrained_images = pretrained_output.images
83
  except Exception as e:
84
  pretrained_images = []
 
87
  pretrained_images = [pretrained_error]
88
 
89
  # --- SMC-grad ---
90
+ smc_grad_output = None
91
+ smc_grad_images = []
92
  try:
93
+ smc_grad_cfg = SMCGradInferenceConfig(
94
  prompt=prompt,
95
+ negative_prompt=smc_grad_negative_prompt or "",
96
+ ess_threshold=float(smc_grad_ess_threshold),
97
+ partial_resampling=bool(smc_grad_partial_resampling),
98
+ resample_frequency=int(smc_grad_resample_frequency),
99
+ CFG=float(smc_grad_CFG),
100
+ steps=int(smc_grad_steps),
101
+ kl_weight=float(smc_grad_kl_weight),
102
+ lambda_tempering=bool(smc_grad_lambda_tempering),
103
+ lambda_one_at=float(smc_grad_lambda_one_at),
104
+ num_particles=int(smc_grad_num_particles),
105
+ use_continuous_formulation=bool(smc_grad_use_continuous_formulation),
106
+ phi=int(smc_grad_phi),
107
+ tau=float(smc_grad_tau),
 
 
 
108
  )
109
+ smc_grad_output = infer_smc_grad(smc_grad_cfg, device=DEVICE)
110
+ # The above line is defensive; simpler: pass smc_grad_device value used by gradio - will be provided.
111
  except Exception as e:
112
+ smc_grad_images = []
113
+ smc_grad_output = None
114
+ smc_grad_error = f"SMC inference error: {e}"
115
+ smc_grad_images = [smc_grad_error]
116
 
117
  # If outputs are dataclasses with PIL images, gr.Gallery accepts lists of PIL images.
118
  pretrained_gallery = pretrained_images if isinstance(pretrained_images, list) else [pretrained_images]
119
+ smc_grad_gallery = smc_grad_output.images if smc_grad_output is not None else smc_grad_images
120
 
121
  pretrained_info = _format_inference_output(pretrained_output)
122
+ smc_grad_info = _format_inference_output(smc_grad_output)
123
 
124
+ return pretrained_gallery, pretrained_info, smc_grad_gallery, smc_grad_info
125
 
126
 
127
  with gr.Blocks() as demo:
 
137
  with gr.Row():
138
  with gr.Column(scale=1, min_width=280):
139
  with gr.Accordion("Pretrained method — settings", open=False):
140
+ pretrained_negative_prompt = gr.Textbox(label="Negative prompt", value=PretrainedInferenceConfig.negative_prompt, lines=1)
141
+ pretrained_CFG = gr.Slider(0.0, 30.0, step=0.1, value=PretrainedInferenceConfig.CFG, label="CFG")
142
+ pretrained_steps = gr.Slider(1, 200, step=1, value=PretrainedInferenceConfig.steps, label="Steps")
 
 
 
143
 
144
  with gr.Column(scale=2):
145
+ pretrained_gallery = gr.Gallery(label="Pretrained outputs", show_label=True, elem_id="pretrained_gallery", height="70vw", columns=4)
146
  pretrained_info = gr.Textbox(label="Pretrained info", interactive=False)
147
 
148
  # --- SMC-grad method row ---
149
  with gr.Row():
150
  with gr.Column(scale=1, min_width=280):
151
  with gr.Accordion("SMC-grad method — settings", open=False):
152
+ smc_grad_negative_prompt = gr.Textbox(label="Negative prompt", value=SMCGradInferenceConfig.negative_prompt, lines=1)
153
+ smc_grad_CFG = gr.Slider(0.0, 30.0, step=0.1, value=SMCGradInferenceConfig.CFG, label="CFG")
154
+ smc_grad_steps = gr.Slider(1, 200, step=1, value=SMCGradInferenceConfig.steps, label="Steps")
155
+ smc_grad_num_particles = gr.Slider(1, 64, step=1, value=SMCGradInferenceConfig.num_particles, label="SMC Num particles")
156
+ smc_grad_ess_threshold = gr.Slider(0.0, 1.0, step=0.01, value=SMCGradInferenceConfig.ess_threshold, label="ESS threshold")
157
+ smc_grad_partial_resampling = gr.Checkbox(label="Partial resampling", value=SMCGradInferenceConfig.partial_resampling)
158
+ smc_grad_resample_frequency = gr.Slider(1, 50, step=1, value=SMCGradInferenceConfig.resample_frequency, label="Resample frequency")
159
+ smc_grad_kl_weight = gr.Slider(0.0, 10.0, step=0.01, value=SMCGradInferenceConfig.kl_weight, label="KL weight")
160
+ smc_grad_lambda_tempering = gr.Checkbox(label="Lambda tempering", value=SMCGradInferenceConfig.lambda_tempering)
161
+ smc_grad_lambda_one_at = gr.Slider(0.0, 1.0, step=0.01, value=SMCGradInferenceConfig.lambda_one_at, label="Lambda one at (fraction of steps)")
162
+ smc_grad_use_continuous_formulation = gr.Checkbox(label="Use continuous formulation", value=SMCGradInferenceConfig.use_continuous_formulation)
163
+ smc_grad_phi = gr.Slider(1, 8, step=1, value=SMCGradInferenceConfig.phi, label="Phi")
164
+ smc_grad_tau = gr.Slider(0.0, 1.0, step=0.001, value=SMCGradInferenceConfig.tau, label="Tau")
 
 
 
 
165
 
166
  with gr.Column(scale=2):
167
+ smc_grad_gallery = gr.Gallery(label="SMC-grad outputs", show_label=True, elem_id="smc_grad_gallery", height="70vw", columns=4)
168
+ smc_grad_info = gr.Textbox(label="SMC-grad info", interactive=False)
169
 
170
  # Wire up the run button and prompt submit to the same runner
171
  run_button.click(
 
174
  prompt,
175
 
176
  pretrained_negative_prompt,
 
177
  pretrained_CFG,
178
  pretrained_steps,
179
+
180
+ smc_grad_negative_prompt,
181
+ smc_grad_CFG,
182
+ smc_grad_steps,
183
+ smc_grad_num_particles,
184
+ smc_grad_ess_threshold,
185
+ smc_grad_partial_resampling,
186
+ smc_grad_resample_frequency,
187
+ smc_grad_kl_weight,
188
+ smc_grad_lambda_tempering,
189
+ smc_grad_lambda_one_at,
190
+ smc_grad_use_continuous_formulation,
191
+ smc_grad_phi,
192
+ smc_grad_tau,
 
 
 
 
 
193
  ],
194
+ outputs=[pretrained_gallery, pretrained_info, smc_grad_gallery, smc_grad_info],
195
  )
196
 
197
  # Also allow pressing Enter in the prompt to trigger
 
201
  prompt,
202
 
203
  pretrained_negative_prompt,
 
204
  pretrained_CFG,
205
  pretrained_steps,
206
+
207
+ smc_grad_negative_prompt,
208
+ smc_grad_CFG,
209
+ smc_grad_steps,
210
+ smc_grad_num_particles,
211
+ smc_grad_ess_threshold,
212
+ smc_grad_partial_resampling,
213
+ smc_grad_resample_frequency,
214
+ smc_grad_kl_weight,
215
+ smc_grad_lambda_tempering,
216
+ smc_grad_lambda_one_at,
217
+ smc_grad_use_continuous_formulation,
218
+ smc_grad_phi,
219
+ smc_grad_tau,
 
 
 
 
 
220
  ],
221
+ outputs=[pretrained_gallery, pretrained_info, smc_grad_gallery, smc_grad_info],
222
  )
223
 
224
 
225
  if __name__ == "__main__":
226
+ demo.launch(share=True)