cp524 committed on
Commit
cbf192b
·
1 Parent(s): b113524

Init app.py

Browse files
Files changed (1) hide show
  1. app.py +278 -19
app.py CHANGED
@@ -1,23 +1,282 @@
1
- import os
2
- import sys
3
- import time
4
- import json
5
- from dataclasses import dataclass, asdict
6
- from typing import Optional, Dict, Any, List
7
- import torch
8
- from PIL import Image
 
 
 
 
 
 
 
9
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- from src.smc.transformer import Transformer2DModel
12
- from src.smc.pipeline import Pipeline
13
- from src.meissonic.scheduler import Scheduler
14
- from src.smc.scheduler import ReMDMScheduler, MeissonicScheduler
15
- from transformers import CLIPTextModelWithProjection, CLIPTokenizer
16
- from diffusers.models.autoencoders.vq_model import VQModel
17
- import src.smc.rewards as rewards
18
- from src.smc.resampling import resample
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- device = "cuda"
21
- dtype = torch.bfloat16
22
- model_path = "Collov-Labs/Monetico"
23
 
 
 
 
1
+ """
2
+ Gradio app to compare multiple inference methods for Monetico model.
3
+
4
+ This file wires your existing inference functions (infer_pretrained, infer_smc_grad)
5
+ into a single UI with one shared prompt and per-method collapsed setting panels.
6
+
7
+ Place this file at repository root (next to src/) and run:
8
+ python app.py
9
+
10
+ Notes:
11
+ - The code assumes your module that contains infer_pretrained and infer_smc_grad
12
+ is importable (e.g. package root with src/ on PYTHONPATH). Adjust imports if needed.
13
+ - Defaults provided are reasonable starting points; tweak as you like.
14
+ """
15
+
16
  import gradio as gr
17
+ import torch
18
+ from typing import List
19
+
20
+ # Import your inference functions and dataclasses
21
+ # Adjust the import path if your file is located elsewhere
22
+ from src.smc.inference import infer_pretrained, infer_smc_grad, PretrainedInferenceConfig, SMCGradInferenceConfig
23
+
24
+ # Global constants (adjust if needed)
25
MAX_SEED = 0xFFFFFFFF  # 2**32 - 1
MAX_IMAGE_SIZE = 1024  # upper bound of the resolution sliders
DEVICE = "cpu"  # device string for SMC-grad inference (see run_inference_all)


# Initial widget values (change to match your model constraints).
# Keys prefixed "pretrained_" / "smc_" feed the matching settings panel;
# "resolution" is shared by both panels.
DEFAULTS = dict(
    resolution=512,
    pretrained_steps=20,
    pretrained_CFG=7.5,
    pretrained_num_batches=1,

    smc_steps=20,
    smc_CFG=7.5,
    smc_num_batches=1,
    smc_num_particles=4,
    smc_ess_threshold=0.5,
    smc_partial_resampling=True,
    smc_resample_frequency=5,
    smc_kl_weight=0.1,
    smc_lambda_tempering=False,
    smc_lambda_one_at=0.5,
    smc_phi=1,
    smc_tau=0.1,
)

# Prompts offered as one-click examples; the first one pre-fills the prompt box.
examples = [
    "A dreamy Monet-style landscape with soft brush strokes",
    "Vibrant city street at dawn in impressionist style",
]
55
+
56
+
57
def _format_inference_output(out) -> str:
    """Return a short one-line summary of an inference result for the UI.

    Args:
        out: Result object exposing ``image_rewards`` and ``gpu_mem_used``
            attributes, or ``None`` when no inference result is available.

    Returns:
        ``"No output"`` when ``out`` is ``None``;
        ``"Could not parse inference output"`` when either attribute is
        missing; otherwise ``"Rewards: <rewards> | GPU mem (GB): <mem>"``,
        where ``<mem>`` has three decimals, or ``n/a`` when the value cannot
        be formatted as a number.
    """
    if out is None:
        return "No output"

    try:
        rewards = out.image_rewards
        mem = out.gpu_mem_used
    except AttributeError:
        return "Could not parse inference output"

    # A non-numeric memory reading (e.g. None) must not hide the rewards,
    # so only the memory field degrades to a placeholder.
    try:
        mem_text = f"{mem:.3f}"
    except (TypeError, ValueError):
        mem_text = "n/a"

    return f"Rewards: {rewards} | GPU mem (GB): {mem_text}"
67
+
68
+
69
def run_inference_all(
    prompt,

    # Pretrained method controls
    pretrained_negative_prompt,
    pretrained_resolution,
    pretrained_CFG,
    pretrained_steps,
    pretrained_num_batches,
    pretrained_device,

    # SMC-grad method controls
    smc_negative_prompt,
    smc_resolution,
    smc_CFG,
    smc_steps,
    smc_num_batches,
    smc_num_particles,
    smc_ess_threshold,
    smc_partial_resampling,
    smc_resample_frequency,
    smc_kl_weight,
    smc_lambda_tempering,
    smc_lambda_one_at,
    smc_use_continuous_formulation,
    smc_phi,
    smc_tau,
    smc_proposal_type,
    smc_device=None,
):
    """Run both inference methods on one prompt and return UI-friendly outputs.

    The two methods run one after the other and independently: a failure in
    one is reported in that method's info string and does not prevent the
    other from running.

    Args:
        prompt: Text prompt shared by both methods.
        pretrained_*: Widget values for the pretrained method. Numeric values
            are coerced with ``int``/``float`` before being placed in
            ``PretrainedInferenceConfig``.
        smc_*: Widget values for the SMC-grad method, coerced the same way
            before being placed in ``SMCGradInferenceConfig``.
        smc_device: Device string for SMC-grad inference. Falls back to the
            module-level ``DEVICE`` when ``None`` or empty, which keeps
            callers that pass only the original arguments working.

    Returns:
        A 4-tuple ``(pretrained_gallery, pretrained_info, smc_gallery,
        smc_info)``. Each gallery is a list taken from the result's
        ``images`` attribute (empty on failure); each info value is a summary
        string, or the error message when that method raised.
    """
    import logging

    logger = logging.getLogger(__name__)

    # --- Pretrained ---
    pretrained_gallery = []
    try:
        pretrained_cfg = PretrainedInferenceConfig(
            prompt=prompt,
            negative_prompt=pretrained_negative_prompt or "",
            resolution=int(pretrained_resolution),
            CFG=float(pretrained_CFG),
            steps=int(pretrained_steps),
            num_batches=int(pretrained_num_batches),
        )
        pretrained_output = infer_pretrained(pretrained_cfg, device=pretrained_device)
        pretrained_images = pretrained_output.images
    except Exception as e:  # UI boundary: report the failure instead of aborting
        logger.exception("Pretrained inference failed")
        # The message goes to the info textbox; a gallery only takes images.
        pretrained_info = f"Pretrained inference error: {e}"
    else:
        pretrained_gallery = (
            pretrained_images if isinstance(pretrained_images, list) else [pretrained_images]
        )
        pretrained_info = _format_inference_output(pretrained_output)

    # --- SMC-grad ---
    smc_gallery = []
    try:
        smc_cfg = SMCGradInferenceConfig(
            prompt=prompt,
            negative_prompt=smc_negative_prompt or "",
            ess_threshold=float(smc_ess_threshold),
            partial_resampling=bool(smc_partial_resampling),
            resample_frequency=int(smc_resample_frequency),
            resolution=int(smc_resolution),
            CFG=float(smc_CFG),
            steps=int(smc_steps),
            kl_weight=float(smc_kl_weight),
            lambda_tempering=bool(smc_lambda_tempering),
            lambda_one_at=float(smc_lambda_one_at),
            num_batches=int(smc_num_batches),
            num_particles=int(smc_num_particles),
            proposal_type=str(smc_proposal_type),
            use_continuous_formulation=bool(smc_use_continuous_formulation),
            phi=int(smc_phi),
            tau=float(smc_tau),
        )
        # Honour the device chosen in the UI; DEVICE is only the fallback.
        smc_output = infer_smc_grad(smc_cfg, device=smc_device or DEVICE)
        smc_images = smc_output.images
    except Exception as e:  # UI boundary: report the failure instead of aborting
        logger.exception("SMC inference failed")
        smc_info = f"SMC inference error: {e}"
    else:
        smc_gallery = smc_images if isinstance(smc_images, list) else [smc_images]
        smc_info = _format_inference_output(smc_output)

    return pretrained_gallery, pretrained_info, smc_gallery, smc_info


# Both device dropdowns start on CUDA when it is available, otherwise on CPU.
_default_device = "cuda" if torch.cuda.is_available() else "cpu"

with gr.Blocks() as demo:
    gr.Markdown("# Monetico — Multi-method Inference Playground")

    with gr.Row():
        prompt = gr.Textbox(label="Prompt", placeholder="Enter prompt here", value=examples[0], lines=1)
        run_button = gr.Button("Run", variant="primary")

    gr.Examples(examples=examples, inputs=prompt)

    # --- Pretrained method row ---
    with gr.Row():
        with gr.Column(scale=1, min_width=280):
            with gr.Accordion("Pretrained method — settings", open=False):
                pretrained_negative_prompt = gr.Textbox(label="Negative prompt", value="", lines=1)
                pretrained_resolution = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=DEFAULTS["resolution"], label="Resolution")
                pretrained_CFG = gr.Slider(0.0, 30.0, step=0.1, value=DEFAULTS["pretrained_CFG"], label="CFG")
                pretrained_steps = gr.Slider(1, 200, step=1, value=DEFAULTS["pretrained_steps"], label="Steps")
                pretrained_num_batches = gr.Slider(1, 8, step=1, value=DEFAULTS["pretrained_num_batches"], label="Batches")
                pretrained_device = gr.Dropdown(choices=["cpu", "cuda"], value=_default_device, label="Device")

        with gr.Column(scale=2):
            pretrained_gallery = gr.Gallery(label="Pretrained outputs", show_label=True, elem_id="pretrained_gallery", height="auto", columns=4)
            pretrained_info = gr.Textbox(label="Pretrained info", interactive=False)

    # --- SMC-grad method row ---
    with gr.Row():
        with gr.Column(scale=1, min_width=280):
            with gr.Accordion("SMC-grad method — settings", open=False):
                smc_negative_prompt = gr.Textbox(label="Negative prompt", value="", lines=1)
                smc_resolution = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=DEFAULTS["resolution"], label="Resolution")
                smc_CFG = gr.Slider(0.0, 30.0, step=0.1, value=DEFAULTS["smc_CFG"], label="CFG")
                smc_steps = gr.Slider(1, 200, step=1, value=DEFAULTS["smc_steps"], label="Steps")
                smc_num_batches = gr.Slider(1, 8, step=1, value=DEFAULTS["smc_num_batches"], label="Batches")
                smc_num_particles = gr.Slider(1, 64, step=1, value=DEFAULTS["smc_num_particles"], label="Num particles")
                smc_ess_threshold = gr.Slider(0.0, 1.0, step=0.01, value=DEFAULTS["smc_ess_threshold"], label="ESS threshold")
                smc_partial_resampling = gr.Checkbox(label="Partial resampling", value=DEFAULTS["smc_partial_resampling"])
                smc_resample_frequency = gr.Slider(1, 50, step=1, value=DEFAULTS["smc_resample_frequency"], label="Resample frequency")
                smc_kl_weight = gr.Slider(0.0, 10.0, step=0.01, value=DEFAULTS["smc_kl_weight"], label="KL weight")
                smc_lambda_tempering = gr.Checkbox(label="Lambda tempering", value=DEFAULTS["smc_lambda_tempering"])
                smc_lambda_one_at = gr.Slider(0.0, 1.0, step=0.01, value=DEFAULTS["smc_lambda_one_at"], label="Lambda one at (fraction of steps)")
                smc_use_continuous_formulation = gr.Checkbox(label="Use continuous formulation", value=True)
                smc_phi = gr.Slider(1, 8, step=1, value=DEFAULTS["smc_phi"], label="Phi")
                smc_tau = gr.Slider(0.0, 1.0, step=0.001, value=DEFAULTS["smc_tau"], label="Tau")
                smc_proposal_type = gr.Dropdown(choices=["locally_optimal", "without_SMC", "other"], value="locally_optimal", label="Proposal type")
                smc_device = gr.Dropdown(choices=["cpu", "cuda"], value=_default_device, label="Device")

        with gr.Column(scale=2):
            smc_gallery = gr.Gallery(label="SMC-grad outputs", show_label=True, elem_id="smc_gallery", height="auto", columns=4)
            smc_info = gr.Textbox(label="SMC-grad info", interactive=False)

    # One shared wiring for both triggers; the order must match the
    # positional parameters of run_inference_all.
    run_inputs = [
        prompt,

        pretrained_negative_prompt,
        pretrained_resolution,
        pretrained_CFG,
        pretrained_steps,
        pretrained_num_batches,
        pretrained_device,

        smc_negative_prompt,
        smc_resolution,
        smc_CFG,
        smc_steps,
        smc_num_batches,
        smc_num_particles,
        smc_ess_threshold,
        smc_partial_resampling,
        smc_resample_frequency,
        smc_kl_weight,
        smc_lambda_tempering,
        smc_lambda_one_at,
        smc_use_continuous_formulation,
        smc_phi,
        smc_tau,
        smc_proposal_type,
        smc_device,
    ]
    run_outputs = [pretrained_gallery, pretrained_info, smc_gallery, smc_info]

    # The Run button and pressing Enter in the prompt trigger the same runner.
    run_button.click(fn=run_inference_all, inputs=run_inputs, outputs=run_outputs)
    prompt.submit(fn=run_inference_all, inputs=run_inputs, outputs=run_outputs)


if __name__ == "__main__":
    demo.launch()