rahul7star commited on
Commit
8310b47
·
verified ·
1 Parent(s): e68deed

Update app_lora1.py

Browse files
Files changed (1) hide show
  1. app_lora1.py +111 -132
app_lora1.py CHANGED
@@ -2,11 +2,9 @@ import spaces
2
  import os
3
  import io
4
  import torch
5
- from PIL import Image
6
-
7
  import gradio as gr
8
  import requests
9
- from diffusers import DiffusionPipeline
10
 
11
  # =========================================================
12
  # CONFIG
@@ -21,10 +19,10 @@ MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"
21
  os.makedirs(LOCAL_SCRIPTS_DIR, exist_ok=True)
22
 
23
  # =========================================================
24
- # GLOBAL STATE
25
  # =========================================================
26
- SCRIPT_CODE = {}
27
- PIPELINES = {}
28
  log_buffer = io.StringIO()
29
 
30
 
@@ -39,37 +37,49 @@ def log(msg):
39
  def pipeline_technology_info(pipe):
40
  tech = []
41
 
 
42
  if hasattr(pipe, "hf_device_map"):
43
  tech.append("Device map: enabled")
44
  else:
45
  tech.append(f"Device: {pipe.device}")
46
 
 
47
  if hasattr(pipe, "transformer"):
48
  try:
49
  tech.append(f"Transformer dtype: {pipe.transformer.dtype}")
50
  except Exception:
51
  pass
 
 
52
  if hasattr(pipe.transformer, "layerwise_casting"):
53
  lw = pipe.transformer.layerwise_casting
54
- tech.append(f"Layerwise casting: storage={lw.storage_dtype}, compute={lw.compute_dtype}")
 
 
55
 
 
56
  if hasattr(pipe, "vae"):
57
  try:
58
  tech.append(f"VAE dtype: {pipe.vae.dtype}")
59
  except Exception:
60
  pass
61
 
 
62
  if hasattr(pipe, "quantization_config"):
63
  tech.append(f"Quantization: {pipe.quantization_config}")
64
 
 
65
  if hasattr(pipe, "config"):
66
- attn = pipe.config.get("attn_implementation", None)
67
  if attn:
68
  tech.append(f"Attention: {attn}")
69
 
70
  return "\n".join(f"• {t}" for t in tech)
71
 
72
 
 
 
 
73
  def pipeline_debug_info(pipe):
74
  return f"""
75
  Pipeline Info
@@ -87,40 +97,7 @@ def latent_shape_info(height, width, pipe):
87
 
88
 
89
  # =========================================================
90
- # PIPELINE FEATURE REGISTRATION HELPER
91
- # =========================================================
92
- def register_pipeline_feature(pipe, text: str):
93
- if not hasattr(pipe, "_enabled_features"):
94
- pipe._enabled_features = []
95
- pipe._enabled_features.append(text)
96
-
97
-
98
- # =========================================================
99
- # WRAPPER TO LOG ANY METHOD CALL ON PIPE OR TRANSFORMER
100
- # =========================================================
101
- def log_pipe_calls(obj, obj_name="pipe"):
102
- for attr_name in dir(obj):
103
- if attr_name.startswith("_"):
104
- continue
105
- attr = getattr(obj, attr_name)
106
-
107
- # Skip non-callables or torch modules
108
- if not callable(attr) or isinstance(attr, torch.nn.Module):
109
- continue
110
-
111
- def make_wrapper(f, name):
112
- def wrapper(*args, **kwargs):
113
- log(f"• {obj_name}.{name} called with args={args}, kwargs={kwargs}")
114
- return f(*args, **kwargs)
115
- return wrapper
116
-
117
- setattr(obj, attr_name, make_wrapper(attr, attr_name))
118
- return obj
119
-
120
-
121
-
122
- # =========================================================
123
- # DOWNLOAD SCRIPTS
124
  # =========================================================
125
  def download_scripts():
126
  resp = requests.get(SCRIPTS_REPO_API)
@@ -143,55 +120,70 @@ SCRIPT_NAMES = download_scripts()
143
 
144
 
145
  # =========================================================
146
- # REGISTER SCRIPTS
147
  # =========================================================
148
  def register_scripts(selected_scripts):
149
  SCRIPT_CODE.clear()
 
150
  for name in selected_scripts:
151
  path = os.path.join(LOCAL_SCRIPTS_DIR, name)
152
  with open(path, "r") as f:
153
  SCRIPT_CODE[name] = f.read()
 
154
  return f"{len(SCRIPT_CODE)} script(s) registered ✅"
155
 
156
 
157
  # =========================================================
158
- # BUILD PIPELINE (GPU)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  # =========================================================
160
  def get_pipeline(script_name):
161
  if script_name in PIPELINES:
162
  return PIPELINES[script_name]
163
 
164
  log(f"🔧 Building pipeline from {script_name}")
 
165
 
166
- namespace = {
167
- "__file__": script_name,
168
- "__name__": "__main__",
169
- "torch": torch,
170
- "register_pipeline_feature": register_pipeline_feature,
171
- "log_pipe_calls": log_pipe_calls,
172
- }
173
 
174
- try:
175
- exec(SCRIPT_CODE[script_name], namespace)
176
- except Exception as e:
177
- log(f"❌ Script failed: {script_name}")
178
- raise RuntimeError(f"Pipeline build failed for {script_name}") from e
179
-
180
- if "pipe" not in namespace:
181
- raise RuntimeError(f"{script_name} did not define `pipe`.")
182
 
183
- pipe = namespace["pipe"]
 
 
 
 
 
184
 
185
- # Wrap transformer and pipe to log method calls (post-pretrained modifications)
186
- if hasattr(pipe, "transformer"):
187
- pipe.transformer = log_pipe_calls(pipe.transformer, "pipe.transformer")
188
- pipe = log_pipe_calls(pipe, "pipe")
189
 
190
-
 
 
191
 
192
  PIPELINES[script_name] = pipe
193
- log(f"✅ Pipeline ready: {script_name}")
194
-
195
  return pipe
196
 
197
 
@@ -216,20 +208,9 @@ def generate_image(
216
  raise RuntimeError("Pipeline not registered")
217
 
218
  pipe = get_pipeline(pipeline_name)
219
-
220
- if not hasattr(pipe, "hf_device_map"):
221
- pipe = pipe.to("cuda")
222
 
223
  log("=== PIPELINE TECHNOLOGY ===")
224
  log(pipeline_technology_info(pipe))
225
-
226
- log("=== PIPELINE FEATURES ===")
227
- if hasattr(pipe, "_enabled_features"):
228
- for f in pipe._enabled_features:
229
- log(f"✔ {f}")
230
- else:
231
- log("✔ No explicit pipeline features registered")
232
-
233
  log("=== NEW GENERATION REQUEST ===")
234
  log(f"Pipeline: {pipeline_name}")
235
  log(f"Prompt: {prompt}")
@@ -246,6 +227,7 @@ def generate_image(
246
  num_images = min(max(1, int(num_images)), 3)
247
  generator = torch.Generator("cuda").manual_seed(int(seed))
248
 
 
249
  result = pipe(
250
  prompt=prompt,
251
  height=int(height),
@@ -258,82 +240,79 @@ def generate_image(
258
  output_type="pil",
259
  )
260
 
261
- # Optional: scale down very large images for UI display
262
- max_display_size = 1024
263
- fixed_images = []
264
- for img in result.images:
265
- if isinstance(img, Image.Image):
266
- w, h = img.size
267
- scale = min(max_display_size / max(w, h), 1.0)
268
- if scale < 1.0:
269
- img = img.resize((int(w * scale), int(h * scale)), Image.BICUBIC)
270
- fixed_images.append(img)
271
-
272
  try:
273
  log(pipeline_debug_info(pipe))
274
  log(latent_shape_info(height, width, pipe))
275
  except Exception as e:
276
  log(f"Diagnostics error: {e}")
277
 
278
- log("Generation complete")
279
-
280
- return fixed_images, seed, log_buffer.getvalue()
281
 
282
 
283
  # =========================================================
284
- # GRADIO UI
285
  # =========================================================
286
- with gr.Blocks(title="Z-Image Turbo ZeroGPU") as demo:
287
- gr.Markdown("## Z-Image Turbo (Script-Driven · ZeroGPU Safe)")
288
-
289
- # ------------------ Scripts selection ------------------
290
- script_selector = gr.CheckboxGroup(
291
- choices=SCRIPT_NAMES,
292
- label="Select pipeline scripts",
293
- )
294
- register_btn = gr.Button("Register Scripts")
295
- status = gr.Textbox(label="Status", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
 
297
  register_btn.click(
298
  register_scripts,
299
  inputs=[script_selector],
300
- outputs=[status],
301
- )
302
-
303
- # ------------------ Pipeline dropdown ------------------
304
- pipeline_picker = gr.Dropdown(
305
- choices=[],
306
- label="Active Pipeline",
307
  )
308
 
309
  register_btn.click(
310
  lambda s: gr.update(choices=s, value=s[0] if s else None),
311
  inputs=[script_selector],
312
- outputs=[pipeline_picker],
313
  )
314
 
315
- gr.Markdown("---")
316
-
317
- # ------------------ Prompt + sliders ------------------
318
- prompt = gr.Textbox(label="Prompt", lines=3)
319
- height = gr.Slider(256, 1024, 512, step=64, label="Height")
320
- width = gr.Slider(256, 1024, 512, step=64, label="Width")
321
- steps = gr.Slider(1, 8, 4, step=1, label="Inference Steps")
322
- images = gr.Slider(1, 3, 1, step=1, label="Images")
323
- seed = gr.Number(value=0, label="Seed")
324
- random_seed = gr.Checkbox(value=True, label="Randomize Seed")
325
-
326
- run_btn = gr.Button("Generate")
327
-
328
- # ------------------ Outputs ------------------
329
- gallery = gr.Gallery(columns=3, height=512, object_fit="contain")
330
- used_seed = gr.Number(label="Used Seed")
331
- logs = gr.Textbox(lines=25, label="Logs")
332
-
333
- run_btn.click(
334
  generate_image,
335
- inputs=[prompt, height, width, steps, seed, random_seed, images, pipeline_picker],
336
- outputs=[gallery, used_seed, logs],
337
  )
338
 
339
  demo.queue()
 
2
  import os
3
  import io
4
  import torch
 
 
5
  import gradio as gr
6
  import requests
7
+ from diffusers import DiffusionPipeline, ZImagePipeline
8
 
9
  # =========================================================
10
  # CONFIG
 
19
  os.makedirs(LOCAL_SCRIPTS_DIR, exist_ok=True)
20
 
21
  # =========================================================
22
+ # GLOBAL STATE (CPU SAFE)
23
  # =========================================================
24
+ SCRIPT_CODE = {} # script_name -> code (CPU only)
25
+ PIPELINES = {} # script_name -> pipeline (GPU only, lazy)
26
  log_buffer = io.StringIO()
27
 
28
 
 
37
def pipeline_technology_info(pipe):
    """Summarize the runtime technology of a diffusers pipeline as bullet lines.

    Inspects optional attributes (device map, transformer, VAE, quantization,
    attention backend) defensively so any pipeline flavor can be described.

    Args:
        pipe: A diffusers pipeline-like object; all attributes are optional.

    Returns:
        str: One "• ..." line per detected feature, joined with newlines.
    """
    tech = []

    # Device map (accelerate-style sharding) vs. a single device.
    if hasattr(pipe, "hf_device_map"):
        tech.append("Device map: enabled")
    else:
        tech.append(f"Device: {pipe.device}")

    # Transformer dtype — guard the attribute itself: not every pipeline has a
    # transformer, and the original unguarded `pipe.transformer` access raised
    # AttributeError for UNet-based pipelines.
    transformer = getattr(pipe, "transformer", None)
    if transformer is not None:
        try:
            tech.append(f"Transformer dtype: {transformer.dtype}")
        except Exception:
            pass

        # Layerwise casting (Z-Image specific)
        if hasattr(transformer, "layerwise_casting"):
            lw = transformer.layerwise_casting
            tech.append(
                f"Layerwise casting: storage={lw.storage_dtype}, compute={lw.compute_dtype}"
            )

    # VAE dtype
    if hasattr(pipe, "vae"):
        try:
            tech.append(f"VAE dtype: {pipe.vae.dtype}")
        except Exception:
            pass

    # Quantization / GGUF
    if hasattr(pipe, "quantization_config"):
        tech.append(f"Quantization: {pipe.quantization_config}")

    # Attention backend
    if hasattr(pipe, "config"):
        attn = getattr(pipe.config, "attn_implementation", None)
        if attn:
            tech.append(f"Attention: {attn}")

    return "\n".join(f"• {t}" for t in tech)
78
 
79
 
80
+ # =========================================================
81
+ # LATENT INFO
82
+ # =========================================================
83
  def pipeline_debug_info(pipe):
84
  return f"""
85
  Pipeline Info
 
97
 
98
 
99
  # =========================================================
100
+ # DOWNLOAD SCRIPTS (CPU ONLY)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  # =========================================================
102
  def download_scripts():
103
  resp = requests.get(SCRIPTS_REPO_API)
 
120
 
121
 
122
  # =========================================================
123
+ # REGISTER SCRIPTS (CPU ONLY)
124
  # =========================================================
125
def register_scripts(selected_scripts):
    """Load the selected pipeline scripts into SCRIPT_CODE (CPU-only step).

    Clears any previously registered scripts, then reads each selected file
    from LOCAL_SCRIPTS_DIR into the SCRIPT_CODE cache keyed by file name.

    Args:
        selected_scripts: Iterable of script file names to register.

    Returns:
        str: A short status message for the UI.
    """
    SCRIPT_CODE.clear()

    for name in selected_scripts:
        path = os.path.join(LOCAL_SCRIPTS_DIR, name)
        # Downloaded scripts may contain non-ASCII text (emoji, comments);
        # read explicitly as UTF-8 so behaviour does not depend on the
        # platform's default encoding.
        with open(path, "r", encoding="utf-8") as f:
            SCRIPT_CODE[name] = f.read()

    return f"{len(SCRIPT_CODE)} script(s) registered ✅"
134
 
135
 
136
  # =========================================================
137
+ # EXTRACT LINES AFTER FROM_PRETRAINED
138
+ # =========================================================
139
def extract_pipe_lines(script_code: str):
    """Return the ``pipe = ZImagePipeline.from_pretrained(...)`` line plus every
    later line that mentions ``pipe``, in source order.

    Lines before the ``from_pretrained`` assignment are ignored. If no such
    assignment exists in *script_code*, an empty list is returned.

    Args:
        script_code: Full text of a pipeline script.

    Returns:
        list[str]: The selected source lines, un-stripped.
    """
    marker = "pipe = ZImagePipeline.from_pretrained"
    collected = []
    seen_marker = False

    for raw in script_code.splitlines():
        text = raw.strip()
        if seen_marker:
            # After the marker, keep only pipe-related statements.
            if "pipe" in text:
                collected.append(raw)
        elif text.startswith(marker):
            seen_marker = True
            collected.append(raw)

    return collected
153
+
154
+
155
+ # =========================================================
156
+ # GPU-ONLY PIPELINE BUILDER
157
  # =========================================================
158
def get_pipeline(script_name):
    """Lazily build, cache, and return the pipeline for a registered script.

    Only the ``pipe = ZImagePipeline.from_pretrained(...)`` line and the
    pipe-related lines after it are executed, one statement at a time, inside
    a minimal namespace. The built pipeline is moved to CUDA (unless it uses
    an accelerate device map) and memoized in PIPELINES.

    Args:
        script_name: Key into SCRIPT_CODE identifying a registered script.

    Returns:
        The ready pipeline object.

    Raises:
        RuntimeError: If the executed lines never bind a ``pipe`` variable.
    """
    if script_name in PIPELINES:
        return PIPELINES[script_name]

    log(f"🔧 Building pipeline from {script_name}")
    source = SCRIPT_CODE[script_name]

    # Minimal namespace for exec — only torch and the pipeline class.
    # NOTE(review): this exec-utes text downloaded from a remote repo; the
    # repo is assumed trusted. Do not point SCRIPTS_REPO_API at untrusted code.
    exec_env = {"torch": torch, "ZImagePipeline": ZImagePipeline}

    pipe = None
    for raw in extract_pipe_lines(source):
        statement = raw.strip()
        if not statement:
            continue
        log(f"• {statement}")
        exec(raw, exec_env)
        pipe = exec_env.get("pipe", pipe)

    if pipe is None:
        raise RuntimeError(f"{script_name} did not define `pipe`.")

    # ZeroGPU safe: only relocate when no device map shards the model already.
    if not hasattr(pipe, "hf_device_map"):
        pipe = pipe.to("cuda")

    PIPELINES[script_name] = pipe
    log("✅ Pipeline ready")
    return pipe
188
 
189
 
 
208
  raise RuntimeError("Pipeline not registered")
209
 
210
  pipe = get_pipeline(pipeline_name)
 
 
 
211
 
212
  log("=== PIPELINE TECHNOLOGY ===")
213
  log(pipeline_technology_info(pipe))
 
 
 
 
 
 
 
 
214
  log("=== NEW GENERATION REQUEST ===")
215
  log(f"Pipeline: {pipeline_name}")
216
  log(f"Prompt: {prompt}")
 
227
  num_images = min(max(1, int(num_images)), 3)
228
  generator = torch.Generator("cuda").manual_seed(int(seed))
229
 
230
+ # Run pipeline
231
  result = pipe(
232
  prompt=prompt,
233
  height=int(height),
 
240
  output_type="pil",
241
  )
242
 
 
 
 
 
 
 
 
 
 
 
 
243
  try:
244
  log(pipeline_debug_info(pipe))
245
  log(latent_shape_info(height, width, pipe))
246
  except Exception as e:
247
  log(f"Diagnostics error: {e}")
248
 
249
+ log("Generation complete")
250
+ return result.images, seed, log_buffer.getvalue()
 
251
 
252
 
253
# =========================================================
# GRADIO UI (original layout)
# =========================================================
with gr.Blocks(title="Z-Image-Turbo Multi Image Demo") as demo:
    gr.Markdown("# 🎨 Z-Image-Turbo Multi Image ")

    with gr.Row():
        # Left column: script selection + generation controls.
        with gr.Column(scale=1):
            script_selector = gr.CheckboxGroup(
                choices=SCRIPT_NAMES,
                label="Select pipeline scripts"
            )
            register_btn = gr.Button("Register Scripts")
            status = gr.Textbox(label="Status", interactive=False)

            prompt = gr.Textbox(label="Prompt", lines=4)

            with gr.Row():
                height = gr.Slider(512, 2048, 1024, step=64, label="Height")
                width = gr.Slider(512, 2048, 1024, step=64, label="Width")

            num_images = gr.Slider(1, 3, 2, step=1, label="Number of Images")

            num_inference_steps = gr.Slider(
                1, 20, 9, step=1, label="Inference Steps",
                info="9 steps = 8 DiT forward passes"
            )

            with gr.Row():
                seed = gr.Number(label="Seed", value=42, precision=0)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)

            generate_btn = gr.Button("🚀 Generate", variant="primary")

        # Right column: active pipeline picker + outputs.
        with gr.Column(scale=1):
            pipeline_picker = gr.Dropdown(
                choices=[],
                label="Active Pipeline",
            )
            # FIX: Gallery.style() was removed in Gradio 4 — pass layout
            # options (columns/height) to the constructor instead.
            output_images = gr.Gallery(
                label="Generated Images",
                elem_classes=["gr-gallery-image"],
                type="pil",
                columns=2,
                height=512,
            )
            used_seed = gr.Number(label="Seed Used", interactive=False)
            debug_log = gr.Textbox(
                label="Debug Log Output",
                lines=25,
                interactive=False
            )

    # Register the chosen scripts (CPU-only) and report status.
    register_btn.click(
        register_scripts,
        inputs=[script_selector],
        outputs=[status]
    )

    # Mirror the registered scripts into the active-pipeline dropdown,
    # defaulting to the first selection.
    register_btn.click(
        lambda s: gr.update(choices=s, value=s[0] if s else None),
        inputs=[script_selector],
        outputs=[pipeline_picker]
    )

    generate_btn.click(
        generate_image,
        inputs=[prompt, height, width, num_inference_steps, seed, randomize_seed, num_images, pipeline_picker],
        outputs=[output_images, used_seed, debug_log]
    )

demo.queue()