rahul7star committed on
Commit
3e2d9a0
·
verified ·
1 Parent(s): ba4f365

Update app_lora1.py

Browse files
Files changed (1) hide show
  1. app_lora1.py +173 -271
app_lora1.py CHANGED
@@ -1,316 +1,218 @@
 
 
 
 
 
1
  import spaces
2
  import gradio as gr
3
- import torch
4
- import os
5
- import traceback
6
- from diffusers import ZImagePipeline
7
- from huggingface_hub import list_repo_files
8
- from PIL import Image
9
-
10
 
11
- # ============================================================
12
  # CONFIG
13
- # ============================================================
14
-
 
 
 
 
15
  MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"
16
- DEFAULT_LORA_REPO = "rahul7star/ZImageLora"
17
 
18
- DTYPE = torch.bfloat16
19
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
- # ============================================================
22
- # GLOBAL STATE
23
- # ============================================================
24
 
25
- pipe = None
26
- CURRENT_LORA_REPO = None
27
- CURRENT_LORA_FILE = None
28
 
29
- # ============================================================
30
  # LOGGING
31
- # ============================================================
32
-
33
  def log(msg):
34
  print(msg)
35
- return msg
36
 
37
- # ============================================================
38
- # PIPELINE BUILD (ONCE)
39
- # ============================================================
40
 
41
- try:
42
- pipe = ZImagePipeline.from_pretrained(
43
- MODEL_ID,
44
- torch_dtype=DTYPE,
45
- )
 
 
 
 
46
 
47
-
48
 
49
- pipe.to(DEVICE)
50
- log("✅ Pipeline built successfully")
 
 
51
 
52
- except Exception as e:
53
- log("❌ Pipeline build failed")
54
- log(traceback.format_exc())
55
- pipe = None
56
 
 
 
 
 
 
 
57
 
 
 
 
 
 
 
 
 
 
58
 
59
- # ============================================================
60
- # HELPERS
61
- # ============================================================
62
 
63
- def list_loras_from_repo(repo_id: str):
64
- try:
65
- files = list_repo_files(repo_id)
66
- return [f for f in files if f.endswith(".safetensors")]
67
- except Exception as e:
68
- log(f"❌ Failed to list LoRAs: {e}")
69
- return []
70
 
71
- # ============================================================
72
- # IMAGE GENERATION (SAFE LORA LOGIC)
73
- # ============================================================
74
- @spaces.GPU()
75
- def generate_image(prompt, height, width, steps, seed, guidance_scale):
76
- LOGS = []
77
- print(prompt)
78
 
79
- if pipe is None:
80
- return None, [], "❌ Pipeline not initialized"
81
 
82
- generator = torch.Generator().manual_seed(int(seed))
83
- placeholder = Image.new("RGB", (width, height), (255, 255, 255))
84
- previews = []
 
 
85
 
86
- # ---- Always start clean ----
87
- try:
88
- pipe.unload_lora_weights()
89
- except Exception:
90
- pass
91
-
92
- # ---- Load LoRA for this run only ----
93
- if CURRENT_LORA_FILE:
94
- try:
95
- pipe.load_lora_weights(
96
- CURRENT_LORA_REPO,
97
- weight_name=CURRENT_LORA_FILE
98
- )
99
- LOGS.append(f"🧩 LoRA loaded: {CURRENT_LORA_FILE}")
100
- except Exception as e:
101
- LOGS.append(f"❌ LoRA load failed: {e}")
102
-
103
- # ---- Preview steps (lightweight) ----
104
- try:
105
- num_previews = min(5, steps)
106
- for i in range(num_previews):
107
- out = pipe(
108
- prompt=prompt,
109
- height=height // 4,
110
- width=width // 4,
111
- num_inference_steps=i + 1,
112
- guidance_scale=guidance_scale,
113
- generator=generator,
114
- )
115
- img = out.images[0].resize((width, height))
116
- previews.append(img)
117
- yield None, previews, "\n".join(LOGS)
118
- except Exception as e:
119
- LOGS.append(f"⚠️ Preview failed: {e}")
120
 
121
- # ---- Final image ----
122
- try:
123
- out = pipe(
124
- prompt=prompt,
125
- height=height,
126
- width=width,
127
- num_inference_steps=steps,
128
- guidance_scale=guidance_scale,
129
- generator=generator,
130
- )
131
- final_img = out.images[0]
132
- previews.append(final_img)
133
- LOGS.append("✅ Image generated")
134
 
135
- yield final_img, previews, "\n".join(LOGS)
 
136
 
137
- except Exception as e:
138
- LOGS.append(f"❌ Generation failed: {e}")
139
- yield placeholder, previews, "\n".join(LOGS)
140
-
141
- finally:
142
- # ---- CRITICAL: unload after run ----
143
- try:
144
- pipe.unload_lora_weights()
145
- LOGS.append("🧹 LoRA unloaded")
146
- except Exception:
147
- pass
148
-
149
- # ============================================================
150
- # GRADIO UI
151
- # ============================================================
152
- css = """
153
- .gradio-container {
154
- max-width: 100% !important;
155
- padding: 16px 32px !important;
156
- }
157
-
158
- .section {
159
- margin-bottom: 12px;
160
- }
161
-
162
- .generate-btn {
163
- background: linear-gradient(90deg, #4b6cb7, #182848) !important;
164
- color: white !important;
165
- font-weight: 600;
166
- height: 46px;
167
- border-radius: 10px;
168
- }
169
-
170
- .secondary-btn {
171
- height: 42px;
172
- border-radius: 10px;
173
- }
174
-
175
- textarea, input {
176
- border-radius: 10px !important;
177
- }
178
- """
179
 
180
- with gr.Blocks(
181
- title="Z-Image-Turbo (Runtime LoRA)",
182
- css=css,
183
- ) as demo:
184
 
185
- gr.Markdown(
186
- """
187
- # 🎨 Z-Image-Turbo LORA
188
- **Runtime LoRA · Safe Mode · Full-Width UI**
189
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  )
191
 
192
- # ======================================================
193
- # MAIN LAYOUT
194
- # ======================================================
195
- with gr.Row():
 
 
 
196
 
197
- # ================= LEFT PANEL =================
198
- with gr.Column(scale=5):
199
-
200
- # -------- Prompt --------
201
- prompt = gr.Textbox(
202
- label="Prompt",
203
- value="boat in ocean",
204
- lines=4,
205
- placeholder="Describe the image you want to generate…",
206
- )
207
-
208
- # -------- LoRA Controls (NEXT TO PROMPT) --------
209
- gr.Markdown("### 🧩 LoRA Controls")
210
-
211
- lora_repo = gr.Textbox(
212
- label="LoRA Repository",
213
- value=DEFAULT_LORA_REPO,
214
- lines=2,
215
- placeholder="username/repo (e.g. rahul7star/ZImageLora)",
216
- )
217
-
218
- lora_dropdown = gr.Dropdown(
219
- label="LoRA File",
220
- choices=[],
221
- interactive=True,
222
- )
223
-
224
- with gr.Row():
225
- refresh_btn = gr.Button("🔄 Refresh LoRA List", elem_classes="secondary-btn")
226
- clear_lora_btn = gr.Button("❌ Clear LoRA", elem_classes="secondary-btn")
227
-
228
- # -------- Generation Controls --------
229
- gr.Markdown("### ⚙️ Generation Settings")
230
-
231
- with gr.Row():
232
- width = gr.Slider(256, 2048, value=1024, step=8, label="Width")
233
- height = gr.Slider(256, 2048, value=1024, step=8, label="Height")
234
-
235
- with gr.Row():
236
- steps = gr.Slider(1, 50, value=20, step=1, label="Steps")
237
- guidance = gr.Slider(0, 10, value=0.0, step=0.5, label="Guidance")
238
- seed = gr.Number(value=42, label="Seed", precision=0)
239
-
240
- run_btn = gr.Button("🚀 Generate Image", elem_classes="generate-btn")
241
-
242
- logs_box = gr.Textbox(
243
- label="Logs",
244
- lines=10,
245
- interactive=False,
246
- )
247
-
248
- # ================= RIGHT PANEL =================
249
- with gr.Column(scale=7):
250
-
251
- final_image = gr.Image(
252
- label="Final Image",
253
- height=520,
254
- )
255
-
256
- gallery = gr.Gallery(
257
- label="Generation Steps",
258
- columns=4,
259
- height=260,
260
- )
261
-
262
- # ======================================================
263
- # CALLBACKS
264
- # ======================================================
265
- def refresh_loras(repo):
266
- files = list_loras_from_repo(repo)
267
- return gr.update(
268
- choices=files,
269
- value=files[0] if files else None,
270
  )
271
 
272
- refresh_btn.click(
273
- refresh_loras,
274
- inputs=[lora_repo],
275
- outputs=[lora_dropdown],
 
 
 
276
  )
277
 
278
- def select_lora(lora_file, repo):
279
- global CURRENT_LORA_FILE, CURRENT_LORA_REPO
280
- CURRENT_LORA_FILE = lora_file
281
- CURRENT_LORA_REPO = repo
282
- return f"🧩 Selected LoRA: {lora_file}"
283
 
284
- lora_dropdown.change(
285
- select_lora,
286
- inputs=[lora_dropdown, lora_repo],
287
- outputs=[logs_box],
288
- )
289
 
290
- def clear_lora():
291
- global CURRENT_LORA_FILE, CURRENT_LORA_REPO
292
- CURRENT_LORA_FILE = None
293
- CURRENT_LORA_REPO = None
294
- try:
295
- pipe.unload_lora_weights()
296
- except Exception:
297
- pass
298
- return (
299
- gr.update(value=None),
300
- "🧹 LoRA cleared — base model will be used."
301
- )
302
 
303
- clear_lora_btn.click(
304
- clear_lora,
305
- outputs=[lora_dropdown, logs_box],
306
- )
307
 
308
- run_btn.click(
309
- generate_image,
310
- inputs=[prompt, height, width, steps, seed, guidance],
311
- outputs=[final_image, gallery, logs_box],
312
- )
313
 
 
 
 
314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
 
 
316
  demo.launch()
 
1
+ import os
2
+ import io
3
+ import sys
4
+ import json
5
+ import torch
6
  import spaces
7
  import gradio as gr
8
+ import requests
9
+ from diffusers import DiffusionPipeline
 
 
 
 
 
10
 
11
# =========================
# CONFIG
# =========================
# GitHub contents-API endpoint listing the Z-Image recipe scripts.
SCRIPTS_REPO_API = (
    "https://api.github.com/repos/asomoza/diffusers-recipes/contents/"
    "models/z-image/scripts"
)
LOCAL_SCRIPTS_DIR = "z_image_scripts"
MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"

os.makedirs(LOCAL_SCRIPTS_DIR, exist_ok=True)

pipe = None  # lazy-loaded pipeline (populated by build_pipeline)
log_buffer = io.StringIO()  # per-request log transcript shown in the UI


# =========================
# LOGGING
# =========================
def log(msg):
    """Print *msg* and append it to the in-memory log transcript.

    Accepts any object, mirroring ``print``. The previous
    ``log_buffer.write(msg + "\\n")`` raised TypeError for non-string
    arguments, making console and UI logs diverge; formatting through an
    f-string stringifies the value first.
    """
    print(msg)
    log_buffer.write(f"{msg}\n")
33
 
 
 
 
34
 
35
def pipeline_debug_info(pipe):
    """Return a multi-line, human-readable summary of the pipeline's parts.

    Reports device, UNet dtype when a ``unet`` component exists ("n/a"
    otherwise — Z-Image uses a transformer), and the class names of the
    transformer and VAE components.
    """
    unet_dtype = pipe.unet.dtype if hasattr(pipe, "unet") else "n/a"
    report = [
        "",
        "Pipeline Info",
        "-------------",
        f"Device: {pipe.device}",
        f"DType: {unet_dtype}",
        f"Transformer: {type(pipe.transformer).__name__}",
        f"VAE: {type(pipe.vae).__name__}",
        "",
    ]
    return "\n".join(report)
44
 
 
45
 
46
def latent_shape_info(height, width, pipe):
    """Return a note with the latent-grid size the VAE maps this image size to."""
    scale = pipe.vae_scale_factor
    latent_h, latent_w = height // scale, width // scale
    return f"Expected latent size: ({latent_h}, {latent_w})"
50
 
 
 
 
 
51
 
52
# =========================
# SCRIPT DOWNLOADER
# =========================
def download_scripts():
    """Fetch the Z-Image recipe scripts from GitHub into LOCAL_SCRIPTS_DIR.

    Lists the repo directory via the GitHub contents API, downloads any
    ``*.py`` file not already present locally, and returns the sorted list
    of script names.

    Raises:
        requests.HTTPError: if the listing or a file download fails.
        requests.Timeout: if GitHub does not respond in time.
    """
    # Timeout so a stalled connection cannot hang app startup forever.
    resp = requests.get(SCRIPTS_REPO_API, timeout=30)
    resp.raise_for_status()

    scripts = []
    for item in resp.json():
        if not item["name"].endswith(".py"):
            continue
        scripts.append(item["name"])
        path = os.path.join(LOCAL_SCRIPTS_DIR, item["name"])
        if not os.path.exists(path):
            # Check this response too: without raise_for_status() a failed
            # request would silently write an error page into the .py file.
            file_resp = requests.get(item["download_url"], timeout=60)
            file_resp.raise_for_status()
            with open(path, "w") as f:
                f.write(file_resp.text)

    return sorted(scripts)
 
 
70
 
 
 
 
 
 
 
 
71
 
72
# Discover the available recipe scripts once at import time. A GitHub
# outage or rate limit must not kill the whole Space, so fall back to an
# empty list and let the UI launch without script choices.
try:
    SCRIPT_NAMES = download_scripts()
except Exception as e:
    print(f"Script download failed: {e}")
    SCRIPT_NAMES = []
 
 
 
 
 
 
73
 
 
 
74
 
75
# =========================
# PIPELINE BUILDER
# =========================
def build_pipeline(selected_scripts):
    """Load the Z-Image Turbo pipeline onto CUDA and store it in `pipe`.

    Args:
        selected_scripts: script names chosen in the UI checkbox group.
            Currently only logged — the downloaded scripts are not executed
            here. TODO(review): confirm intended use against the recipe repo.

    Returns:
        A status string shown in the UI "Status" textbox.
    """
    global pipe

    log("Building Z-Image Turbo pipeline...")
    log(f"Selected scripts: {selected_scripts}")

    # Reuse an already-built pipeline instead of paying the multi-GB
    # download/load cost again on every button click.
    if pipe is not None:
        log("Pipeline already built — reusing it.")
        return "Pipeline ready ✅"

    try:
        pipe = DiffusionPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=False,
            attn_implementation="kernels-community/vllm-flash-attn3",
        )
        pipe.to("cuda")
    except Exception as e:
        # Surface the failure in the status box instead of an unhandled
        # Gradio error, and leave `pipe` unset so a retry is possible.
        pipe = None
        log(f"Pipeline build failed: {e}")
        return f"❌ Pipeline build failed: {e}"

    log("Pipeline loaded and moved to CUDA")

    return "Pipeline ready ✅"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
 
 
 
 
96
 
97
# =========================
# IMAGE GENERATION
# =========================
@spaces.GPU
def generate_image(
    prompt,
    height,
    width,
    num_inference_steps,
    seed,
    randomize_seed,
    num_images,
):
    """Run the Z-Image Turbo pipeline and return (images, used_seed, logs).

    Executes on ZeroGPU hardware via the @spaces.GPU decorator, so CUDA is
    available inside this function. Requires build_pipeline() to have been
    called first (the "Build Pipeline" button in the UI).

    Returns:
        list of PIL images (output_type="pil"), the seed actually used
        (after optional randomization), and the accumulated log transcript
        for this request.

    Raises:
        RuntimeError: if the pipeline has not been built yet.
    """
    global pipe

    # Reset the shared module-level buffer so the UI shows only this
    # request's logs, not the whole session history.
    log_buffer.truncate(0)
    log_buffer.seek(0)

    if pipe is None:
        raise RuntimeError("Pipeline not built yet")

    log("=== NEW GENERATION REQUEST ===")
    log(f"Prompt: {prompt}")
    log(f"Height: {height}, Width: {width}")
    log(f"Inference Steps: {num_inference_steps}")
    log(f"Num Images: {num_images}")

    if randomize_seed:
        # Fresh 32-bit seed; the value is reported back via the seed output.
        seed = torch.randint(0, 2**32 - 1, (1,)).item()
        log(f"Randomized Seed → {seed}")
    else:
        log(f"Seed: {seed}")

    # Clamp batch size to 1..3 (matches the UI "Images" slider range).
    num_images = min(max(1, int(num_images)), 3)

    log(pipeline_debug_info(pipe))

    # CUDA-device generator seeded for reproducible sampling.
    generator = torch.Generator("cuda").manual_seed(int(seed))

    log("Running pipeline forward()...")
    result = pipe(
        prompt=prompt,
        height=int(height),
        width=int(width),
        num_inference_steps=int(num_inference_steps),
        guidance_scale=0.0,  # NOTE(review): CFG disabled — presumably because Turbo is distilled; confirm
        generator=generator,
        max_sequence_length=1024,
        num_images_per_prompt=num_images,
        output_type="pil",
    )

    # Best-effort latent diagnostics; config attribute names can vary
    # between diffusers versions, hence the broad except that only logs.
    try:
        log(f"VAE latent channels: {pipe.vae.config.latent_channels}")
        log(f"VAE scaling factor: {pipe.vae.config.scaling_factor}")
        log(f"Transformer latent size: {pipe.transformer.config.sample_size}")
        log(latent_shape_info(height, width, pipe))
    except Exception as e:
        log(f"Latent diagnostics error: {e}")

    log("Pipeline finished.")
    return result.images, seed, log_buffer.getvalue()
159
+
160
+
161
# =========================
# GRADIO UI
# =========================
with gr.Blocks(title="Z-Image Turbo (ZeroGPU)") as demo:
    gr.Markdown("## Z-Image Turbo ZeroGPU Space")

    # --- Pipeline setup: choose recipe scripts, then build the pipeline ---
    with gr.Row():
        script_selector = gr.CheckboxGroup(
            choices=SCRIPT_NAMES,
            label="Select Z-Image Scripts",
        )

    build_btn = gr.Button("Build Pipeline")
    build_status = gr.Textbox(label="Status", interactive=False)

    build_btn.click(
        fn=build_pipeline,
        inputs=[script_selector],
        outputs=[build_status],
    )

    gr.Markdown("---")

    # --- Generation controls ---
    prompt = gr.Textbox(label="Prompt", lines=3)
    with gr.Row():
        height = gr.Slider(256, 1024, value=512, step=64, label="Height")
        width = gr.Slider(256, 1024, value=512, step=64, label="Width")

    with gr.Row():
        # Low step ceiling — presumably because the Turbo checkpoint is a
        # few-step model; confirm against the model card.
        steps = gr.Slider(1, 8, value=4, step=1, label="Inference Steps")
        images = gr.Slider(1, 3, value=1, step=1, label="Images")

    with gr.Row():
        seed = gr.Number(value=0, label="Seed")
        random_seed = gr.Checkbox(label="Randomize Seed", value=True)

    run_btn = gr.Button("Generate")

    # --- Outputs: image gallery, the seed actually used, and the logs ---
    gallery = gr.Gallery(label="Results", columns=3)
    seed_out = gr.Number(label="Used Seed")
    logs = gr.Textbox(label="Logs", lines=12)

    # Input order must match generate_image's positional parameters.
    run_btn.click(
        fn=generate_image,
        inputs=[
            prompt,
            height,
            width,
            steps,
            seed,
            random_seed,
            images,
        ],
        outputs=[gallery, seed_out, logs],
    )

# NOTE(review): queue() enables request queuing; ZeroGPU Spaces generally
# require it for @spaces.GPU scheduling — confirm against Spaces docs.
demo.queue()
demo.launch()