AbstractPhil committed on
Commit
eb8e393
·
verified ·
1 Parent(s): 5404f3e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +596 -130
app.py CHANGED
@@ -1,154 +1,620 @@
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
- import random
 
 
4
 
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
- import torch
 
 
 
 
 
8
 
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
 
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
 
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Run the module-level ``pipe`` once and return ``(image, seed)``.

    When ``randomize_seed`` is truthy the incoming ``seed`` is replaced by a
    random integer drawn from ``[0, MAX_SEED]``. The seed that was actually
    used is returned next to the first generated image.

    ``progress`` is not referenced in the body; it is only declared in the
    signature.
    """
    used_seed = random.randint(0, MAX_SEED) if randomize_seed else seed

    # A dedicated generator seeded with the chosen value drives the sampling.
    rng = torch.Generator().manual_seed(used_seed)

    pipe_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
        "width": width,
        "height": height,
        "generator": rng,
    }
    outputs = pipe(**pipe_kwargs)

    return outputs.images[0], used_seed
52
-
53
-
54
- examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
- ]
59
-
60
- css = """
61
- #col-container {
62
- margin: 0 auto;
63
- max-width: 640px;
64
- }
65
- """
66
 
67
- with gr.Blocks(css=css) as demo:
68
- with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
- with gr.Row():
72
- prompt = gr.Text(
73
- label="Prompt",
74
- show_label=False,
75
- max_lines=1,
76
- placeholder="Enter your prompt",
77
- container=False,
78
- )
79
 
80
- run_button = gr.Button("Run", scale=0, variant="primary")
 
 
81
 
82
- result = gr.Image(label="Result", show_label=False)
 
 
83
 
84
- with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
90
- )
91
 
92
- seed = gr.Slider(
93
- label="Seed",
94
- minimum=0,
95
- maximum=MAX_SEED,
96
- step=1,
97
- value=0,
98
- )
 
 
99
 
100
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
 
102
- with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
- )
110
 
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
- with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
- minimum=0.0,
123
- maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
126
- )
127
 
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
130
- minimum=1,
131
- maximum=50,
132
- step=1,
133
- value=2, # Replace with defaults that work for your model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
- gr.Examples(examples=examples, inputs=[prompt])
137
- gr.on(
138
- triggers=[run_button.click, prompt.submit],
139
- fn=infer,
140
- inputs=[
141
- prompt,
142
- negative_prompt,
143
- seed,
144
- randomize_seed,
145
- width,
146
- height,
147
- guidance_scale,
148
- num_inference_steps,
149
- ],
150
- outputs=[result, seed],
151
- )
152
 
153
  if __name__ == "__main__":
154
- demo.launch()
 
 
 
1
+ """
2
+ Lyra/Lune Flow-Matching Inference Space
3
+ Author: AbstractPhil
4
+ License: MIT
5
+
6
+ SD1.5-based flow matching with geometric crystalline architectures.
7
+ """
8
+
9
+ import os
10
+ import torch
11
  import gradio as gr
12
  import numpy as np
13
+ from PIL import Image
14
+ from typing import Optional, Dict
15
+ import spaces
16
 
17
+ from diffusers import (
18
+ UNet2DConditionModel,
19
+ AutoencoderKL,
20
+ DPMSolverMultistepScheduler,
21
+ EulerDiscreteScheduler
22
+ )
23
+ from transformers import CLIPTextModel, CLIPTokenizer
24
+ from huggingface_hub import hf_hub_download
25
 
 
 
26
 
27
+ # ============================================================================
28
+ # MODEL LOADING
29
+ # ============================================================================
 
30
 
31
class FlowMatchingPipeline:
    """Custom pipeline for flow-matching inference.

    Bundles a VAE, a CLIP text encoder with its tokenizer, a conditional UNet
    and a diffusers scheduler. Calling the instance turns one text prompt into
    one PIL image, either with a hand-written flow-matching update or with the
    scheduler's own ``step``.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler,
        device: str = "cuda"
    ):
        """Store the already-loaded components; nothing is moved or copied."""
        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.unet = unet
        self.scheduler = scheduler
        self.device = device

        # Latents are divided by this factor right before VAE decoding.
        self.vae_scale_factor = 0.18215

    def _embed_text(self, text: str):
        """Tokenize ``text`` (padded/truncated to the tokenizer's max length)
        and return the first output of the text encoder."""
        tokens = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = tokens.input_ids.to(self.device)

        with torch.no_grad():
            return self.text_encoder(input_ids)[0]

    def encode_prompt(self, prompt: str, negative_prompt: str = ""):
        """Encode text prompts to embeddings.

        Args:
            prompt: Positive prompt.
            negative_prompt: Negative prompt; may be empty.

        Returns:
            ``(prompt_embeds, negative_prompt_embeds)``. With an empty
            ``negative_prompt`` the second element is an all-zero tensor
            shaped like ``prompt_embeds``.
        """
        prompt_embeds = self._embed_text(prompt)

        if negative_prompt:
            negative_prompt_embeds = self._embed_text(negative_prompt)
        else:
            # NOTE(review): stock Stable Diffusion pipelines encode "" for the
            # unconditional branch instead of using zeros - confirm that zeros
            # are what the checkpoints served here expect.
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)

        return prompt_embeds, negative_prompt_embeds

    @torch.no_grad()
    def __call__(
        self,
        prompt: str,
        negative_prompt: str = "",
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 20,
        guidance_scale: float = 7.5,
        shift: float = 2.5,
        use_flow_matching: bool = True,
        prediction_type: str = "epsilon",
        seed: Optional[int] = None,
        progress_callback=None
    ):
        """Generate image using flow matching or standard diffusion.

        Args:
            prompt: Positive text prompt.
            negative_prompt: Negative text prompt (empty string for none).
            height: Output height in pixels; the latent uses ``height // 8``.
            width: Output width in pixels; the latent uses ``width // 8``.
            num_inference_steps: Number of steps handed to the scheduler.
            guidance_scale: Classifier-free guidance weight; guidance is only
                applied when this is greater than 1.0.
            shift: Shift applied to the per-step sigma in flow-matching mode;
                the latent-input scaling is skipped when it is not positive.
            use_flow_matching: Use the manual flow-matching update instead of
                ``scheduler.step``.
            prediction_type: ``"v_prediction"`` converts the model output
                before the flow-matching update; any other value uses the
                output as is. Ignored when ``use_flow_matching`` is false.
            seed: Seed for the initial latents, or ``None`` for unseeded noise.
            progress_callback: Optional callable invoked at the start of each
                step as ``progress_callback(step_index, total, description)``.

        Returns:
            The generated image as a ``PIL.Image.Image``.
        """
        generator = None
        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)

        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt, negative_prompt
        )

        # Guidance configuration is constant across steps, so build the
        # (possibly doubled) text conditioning once.
        use_cfg = guidance_scale > 1.0
        if use_cfg:
            text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeds = prompt_embeds

        # Initial noise: 4 latent channels at 1/8 of the pixel resolution.
        latent_channels = 4
        latents = torch.randn(
            (1, latent_channels, height // 8, width // 8),
            generator=generator,
            device=self.device,
            dtype=torch.float32
        )

        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = self.scheduler.timesteps

        if not use_flow_matching:
            # scheduler.step() assumes the sample starts at the scheduler's
            # own noise level; unit-variance noise alone is not enough for
            # sigma-parameterised schedulers such as EulerDiscreteScheduler.
            latents = latents * self.scheduler.init_noise_sigma

        for i, t in enumerate(timesteps):
            if progress_callback:
                progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}")

            # Duplicate the latents so a single UNet pass yields both the
            # unconditional and the conditional prediction.
            latent_model_input = torch.cat([latents] * 2) if use_cfg else latents

            if use_flow_matching:
                # Map the timestep to [0, 1] and apply the shift; the result
                # is reused by the update below.
                sigma = t.float() / 1000.0
                sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)

                if shift > 0:
                    scaling = torch.sqrt(1 + sigma_shifted ** 2)
                    latent_model_input = latent_model_input / scaling
            else:
                # Must precede every UNet call that feeds scheduler.step().
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )

            # One timestep entry per batch row.
            timestep = t.expand(latent_model_input.shape[0])

            noise_pred = self.unet(
                latent_model_input,
                timestep,
                encoder_hidden_states=text_embeds,
                return_dict=False
            )[0]

            if use_cfg:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            if use_flow_matching:
                if prediction_type == "v_prediction":
                    # Convert v-prediction to epsilon.
                    alpha_t = torch.sqrt(1 - sigma_shifted ** 2)
                    noise_pred = alpha_t * noise_pred + sigma_shifted * latents

                # Fixed-size explicit step over the whole trajectory.
                dt = -1.0 / num_inference_steps
                latents = latents + dt * noise_pred
            else:
                latents = self.scheduler.step(
                    noise_pred, t, latents, return_dict=False
                )[0]

        # Decode latents (the method already runs under torch.no_grad()).
        latents = latents / self.vae_scale_factor
        image = self.vae.decode(latents).sample

        # [-1, 1] -> [0, 1] -> uint8 HWC array -> PIL image.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        image = (image * 255).round().astype("uint8")

        return Image.fromarray(image[0])
201
 
 
 
202
 
203
def load_lune_checkpoint(repo_id: str, filename: str, device: str = "cuda"):
    """Load Lune checkpoint from .pt file.

    Downloads ``filename`` from the Hub model repo ``repo_id``, builds a UNet
    from the SD1.5 ``unet`` subfolder and overwrites its weights with the
    checkpoint's ``"student"`` state dict (a leading ``"unet."`` on any key is
    stripped first).

    Args:
        repo_id: Hub repository that hosts the checkpoint.
        filename: Checkpoint file name inside that repository.
        device: Device the returned UNet is moved to.

    Returns:
        The UNet with the student weights applied, on ``device``.

    Raises:
        KeyError: If the checkpoint has no ``"student"`` entry.
    """
    print(f"📥 Downloading checkpoint: {repo_id}/{filename}")

    checkpoint_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="model"
    )

    print(f"✓ Downloaded to: {checkpoint_path}")
    print("📦 Loading checkpoint...")

    # NOTE(review): torch.load unpickles the file, which can execute code
    # embedded in it. Only point this at trusted repositories, and consider
    # weights_only=True if the checkpoint contents allow it.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    if "student" not in checkpoint:
        raise KeyError(
            f"Checkpoint {repo_id}/{filename} has no 'student' entry; "
            f"found keys: {sorted(checkpoint.keys())}"
        )

    # Initialize UNet with SD1.5 config
    print("🏗️ Initializing SD1.5 UNet...")
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="unet",
        torch_dtype=torch.float32
    )

    # Strip "unet." prefix if present
    prefix = "unet."
    cleaned_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint["student"].items()
    }

    # strict=False tolerates key mismatches, so surface them: otherwise a
    # naming mismatch would leave the base weights in place without any hint.
    load_result = unet.load_state_dict(cleaned_dict, strict=False)
    if load_result.missing_keys:
        print(f"⚠️ {len(load_result.missing_keys)} UNet keys missing from checkpoint "
              f"(first: {load_result.missing_keys[0]})")
    if load_result.unexpected_keys:
        print(f"⚠️ {len(load_result.unexpected_keys)} checkpoint keys not used by the UNet "
              f"(first: {load_result.unexpected_keys[0]})")

    step = checkpoint.get("gstep", "unknown")
    print(f"✅ Loaded checkpoint from step {step}")

    return unet.to(device)
244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
def initialize_pipeline(model_choice: str, device: str = "cuda"):
    """Initialize the complete pipeline.

    Loads the VAE, CLIP text encoder and tokenizer, the UNet selected by
    ``model_choice`` and an Euler scheduler, and wraps them in a
    ``FlowMatchingPipeline``.

    Args:
        model_choice: ``"Flow-Lune (Latest)"`` or ``"SD1.5 Base"``.
        device: Device the VAE, text encoder and UNet are moved to.

    Returns:
        A ready-to-call ``FlowMatchingPipeline``.

    Raises:
        ValueError: If ``model_choice`` is not one of the known models.
    """
    lune_choice = "Flow-Lune (Latest)"
    base_choice = "SD1.5 Base"

    # Reject unknown choices before any weights are downloaded or moved to
    # the device, so a bad request fails immediately.
    if model_choice not in (lune_choice, base_choice):
        raise ValueError(f"Unknown model: {model_choice}")

    print(f"🚀 Initializing {model_choice} pipeline...")

    # Load base components
    print("Loading VAE...")
    vae = AutoencoderKL.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="vae",
        torch_dtype=torch.float32
    ).to(device)

    print("Loading text encoder...")
    text_encoder = CLIPTextModel.from_pretrained(
        "openai/clip-vit-large-patch14",
        torch_dtype=torch.float32
    ).to(device)

    tokenizer = CLIPTokenizer.from_pretrained(
        "openai/clip-vit-large-patch14"
    )

    # Load UNet based on model choice
    if model_choice == lune_choice:
        # The "latest" checkpoint is currently a fixed, known file name.
        repo_id = "AbstractPhil/sd15-flow-lune"
        filename = "sd15_flow_lune_e34_s34000.pt"
        unet = load_lune_checkpoint(repo_id, filename, device)
    else:
        print("Loading SD1.5 base UNet...")
        unet = UNet2DConditionModel.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            subfolder="unet",
            torch_dtype=torch.float32
        ).to(device)

    # Initialize scheduler
    scheduler = EulerDiscreteScheduler.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="scheduler"
    )

    print("✅ Pipeline initialized!")

    return FlowMatchingPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        device=device
    )
304
 
 
 
 
 
 
 
 
 
305
 
306
+ # ============================================================================
307
+ # GLOBAL STATE
308
+ # ============================================================================
309
 
310
+ # Initialize with None, will load on first inference
311
+ CURRENT_PIPELINE = None
312
+ CURRENT_MODEL = None
313
 
 
 
 
 
 
 
 
314
 
315
def get_pipeline(model_choice: str):
    """Return the cached pipeline for ``model_choice``, building it if needed.

    The module keeps a single pipeline in ``CURRENT_PIPELINE`` together with
    the model name it was built for in ``CURRENT_MODEL``. A request for a
    different model (or the very first request) replaces both.
    """
    global CURRENT_PIPELINE, CURRENT_MODEL

    # Fast path: a pipeline exists and was built for the requested model.
    if CURRENT_PIPELINE is not None and CURRENT_MODEL == model_choice:
        return CURRENT_PIPELINE

    CURRENT_PIPELINE = initialize_pipeline(model_choice, device="cuda")
    CURRENT_MODEL = model_choice
    return CURRENT_PIPELINE
324
 
 
325
 
326
+ # ============================================================================
327
+ # INFERENCE
328
+ # ============================================================================
 
 
 
 
 
329
 
330
def estimate_duration(num_steps: int, width: int, height: int) -> int:
    """Estimate GPU duration based on generation parameters.

    The estimate is ``num_steps`` times 0.3 seconds, scaled by the pixel count
    relative to a 512x512 image, plus a flat 15 seconds of model-loading
    overhead. The result is truncated to an ``int``.
    """
    seconds_per_step = 0.3
    loading_overhead = 15

    # Pixel count relative to the 512x512 reference resolution.
    relative_area = (width * height) / (512 * 512)

    sampling_seconds = num_steps * seconds_per_step * relative_area
    return int(sampling_seconds + loading_overhead)
343
+
344
+
345
# The duration callable picks num_steps (index 3), width (5) and height (6)
# out of the positional arguments, matching the parameter order below.
# NOTE(review): assumes spaces.GPU calls it with the handler's positional
# arguments - confirm against the `spaces` package.
@spaces.GPU(duration=lambda *args: estimate_duration(args[3], args[5], args[6]))
def generate_image(
    prompt: str,
    negative_prompt: str,
    model_choice: str,
    num_steps: int,
    cfg_scale: float,
    width: int,
    height: int,
    shift: float,
    use_flow_matching: bool,
    prediction_type: str,
    seed: int,
    randomize_seed: bool,
    progress=gr.Progress()
):
    """Generate image with ZeroGPU support.

    Args:
        prompt: Positive text prompt.
        negative_prompt: Negative text prompt.
        model_choice: Model name understood by ``get_pipeline``.
        num_steps: Number of inference steps.
        cfg_scale: Guidance scale forwarded to the pipeline.
        width: Image width in pixels.
        height: Image height in pixels.
        shift: Flow-matching shift forwarded to the pipeline.
        use_flow_matching: Forwarded to the pipeline.
        prediction_type: Forwarded to the pipeline.
        seed: Seed to use when ``randomize_seed`` is false.
        randomize_seed: Draw a fresh random seed instead of using ``seed``.
        progress: Gradio progress tracker used for status updates.

    Returns:
        ``(image, seed)`` - the generated image and the seed actually used,
        as a plain Python ``int``.

    Raises:
        Exception: Whatever pipeline creation or generation raises is logged
            and re-raised unchanged.
    """
    # UI widgets may hand over floats; latent shapes, step counts and
    # torch.Generator.manual_seed all need real integers.
    num_steps = int(num_steps)
    width = int(width)
    height = int(height)

    if randomize_seed:
        # An explicit 64-bit dtype keeps the 2**32 bound valid on platforms
        # whose default NumPy integer is 32-bit; int() turns the NumPy scalar
        # into a plain Python int. The range matches the seed slider.
        seed = int(np.random.randint(0, 2**32, dtype=np.int64))
    else:
        seed = int(seed)

    # Progress tracking
    def progress_callback(step, total, desc):
        progress((step + 1) / total, desc=desc)

    try:
        pipeline = get_pipeline(model_choice)

        progress(0.05, desc="Starting generation...")

        image = pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_inference_steps=num_steps,
            guidance_scale=cfg_scale,
            shift=shift,
            use_flow_matching=use_flow_matching,
            prediction_type=prediction_type,
            seed=seed,
            progress_callback=progress_callback
        )

        progress(1.0, desc="Complete!")

        return image, seed

    except Exception as e:
        print(f"❌ Generation failed: {e}")
        # Bare raise re-raises the active exception with its original traceback.
        raise
399
 
 
 
 
 
 
 
 
 
400
 
401
+ # ============================================================================
402
+ # GRADIO UI
403
+ # ============================================================================
404
+
405
def create_demo():
    """Build and return the Gradio ``Blocks`` app.

    The layout is a two-column row: prompts, model selection, flow-matching
    and generation settings plus the generate button on the left; the output
    image, the seed that was used and usage tips on the right. Example rows
    and the button click both call ``generate_image`` with the same inputs.
    """

    def flow_lune_example(prompt_text, negative_text, steps, cfg, shift_value, seed_value):
        # Every example row shares the model, the 512x512 size, flow matching
        # enabled, epsilon prediction and a fixed (non-random) seed.
        return [
            prompt_text,
            negative_text,
            "Flow-Lune (Latest)",
            steps,
            cfg,
            512,
            512,
            shift_value,
            True,
            "epsilon",
            seed_value,
            False,
        ]

    example_rows = [
        flow_lune_example(
            "A serene mountain landscape at golden hour, crystal clear lake reflecting snow-capped peaks, photorealistic, 8k",
            "blurry, low quality",
            20, 7.5, 2.5, 42,
        ),
        flow_lune_example(
            "A futuristic cyberpunk city at night, neon lights, rain-slicked streets, highly detailed",
            "low quality, blurry",
            22, 8.0, 2.5, 123,
        ),
        flow_lune_example(
            "Portrait of a majestic lion, golden mane, dramatic lighting, wildlife photography",
            "cartoon, painting",
            18, 7.0, 2.0, 456,
        ),
    ]

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🌙 Lyra/Lune Flow-Matching Image Generation

        **Geometric crystalline diffusion with flow matching** by [AbstractPhil](https://huggingface.co/AbstractPhil)

        Generate images using SD1.5-based flow matching with pentachoron geometric structures.
        Achieves high quality with dramatically reduced step counts through geometric efficiency.
        """)

        with gr.Row():
            # ---- Left column: inputs -------------------------------------
            with gr.Column(scale=1):
                prompt = gr.TextArea(
                    label="Prompt",
                    placeholder="A beautiful landscape with mountains and a lake at sunset...",
                    lines=3
                )

                negative_prompt = gr.TextArea(
                    label="Negative Prompt",
                    placeholder="blurry, low quality, distorted...",
                    lines=2
                )

                model_choice = gr.Dropdown(
                    label="Model",
                    choices=[
                        "Flow-Lune (Latest)",
                        "SD1.5 Base"
                    ],
                    value="Flow-Lune (Latest)"
                )

                with gr.Accordion("Flow Matching Settings", open=True):
                    use_flow_matching = gr.Checkbox(
                        label="Enable Flow Matching",
                        value=True,
                        info="Use flow matching ODE integration"
                    )

                    shift = gr.Slider(
                        label="Shift",
                        minimum=0.0,
                        maximum=5.0,
                        value=2.5,
                        step=0.1,
                        info="Flow matching shift parameter (0=disabled, 1-3 typical)"
                    )

                    prediction_type = gr.Radio(
                        label="Prediction Type",
                        choices=["epsilon", "v_prediction"],
                        value="epsilon",
                        info="Type of model prediction"
                    )

                with gr.Accordion("Generation Settings", open=True):
                    num_steps = gr.Slider(
                        label="Steps",
                        minimum=1,
                        maximum=50,
                        value=20,
                        step=1,
                        info="Flow matching typically needs fewer steps (15-25)"
                    )

                    cfg_scale = gr.Slider(
                        label="CFG Scale",
                        minimum=1.0,
                        maximum=20.0,
                        value=7.5,
                        step=0.5
                    )

                    with gr.Row():
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=1024,
                            value=512,
                            step=64
                        )

                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=1024,
                            value=512,
                            step=64
                        )

                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=2**32 - 1,
                        value=42,
                        step=1
                    )

                    randomize_seed = gr.Checkbox(
                        label="Randomize Seed",
                        value=True
                    )

                generate_btn = gr.Button("🎨 Generate", variant="primary", size="lg")

            # ---- Right column: outputs and help --------------------------
            with gr.Column(scale=1):
                output_image = gr.Image(
                    label="Generated Image",
                    type="pil"
                )

                output_seed = gr.Number(
                    label="Used Seed",
                    precision=0
                )

                gr.Markdown("""
                ### Tips:
                - **Flow matching** works best with 15-25 steps (vs 50+ for standard diffusion)
                - **Shift** controls the flow trajectory (2.0-2.5 recommended for Lune)
                - Lower shift = more direct path, higher shift = more exploration
                - Try **v_prediction** mode if epsilon gives unstable results

                ### Model Info:
                - **Flow-Lune**: Trained with flow matching on 500k SD1.5 distillation pairs
                - **SD1.5 Base**: Standard Stable Diffusion 1.5 for comparison

                [📚 Learn more about geometric deep learning](https://github.com/AbstractEyes/lattice_vocabulary)
                """)

        # The component order mirrors generate_image's positional parameters
        # and is shared by the example rows and the button click.
        generation_inputs = [
            prompt, negative_prompt, model_choice, num_steps, cfg_scale,
            width, height, shift, use_flow_matching, prediction_type,
            seed, randomize_seed
        ]
        generation_outputs = [output_image, output_seed]

        gr.Examples(
            examples=example_rows,
            inputs=generation_inputs,
            outputs=generation_outputs,
            fn=generate_image,
            cache_examples=False
        )

        generate_btn.click(
            fn=generate_image,
            inputs=generation_inputs,
            outputs=generation_outputs
        )

    return demo
611
 
612
+
613
+ # ============================================================================
614
+ # LAUNCH
615
+ # ============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
616
 
617
if __name__ == "__main__":
    # Build the UI, cap the request queue at 20 entries, then start serving
    # with show_api disabled.
    demo = create_demo()
    demo.queue(max_size=20).launch(show_api=False)