prithivMLmods committed on
Commit
ad59edd
·
verified ·
1 Parent(s): 734cdb7

update app [.]

Browse files
Files changed (1) hide show
  1. app.py +83 -100
app.py CHANGED
@@ -15,8 +15,8 @@ import gradio as gr
15
 
16
  from diffusers import (
17
  DiffusionPipeline,
18
- AutoPipelineForImage2Image,
19
- FlowMatchEulerDiscreteScheduler
20
  )
21
 
22
  from huggingface_hub import (
@@ -30,7 +30,6 @@ from typing import Iterable
30
  from gradio.themes import Soft
31
  from gradio.themes.utils import colors, fonts, sizes
32
 
33
- # --- THEME DEFINITION ---
34
  colors.steel_blue = colors.Color(
35
  name="steel_blue",
36
  c50="#EBF3F8",
@@ -99,7 +98,6 @@ class SteelBlueTheme(Soft):
99
 
100
  steel_blue_theme = SteelBlueTheme()
101
 
102
- # --- LORA DEFINITIONS ---
103
  loras = [
104
  {
105
  "image": "https://huggingface.co/Shakker-Labs/AWPortrait-Z/resolve/main/images/example.png",
@@ -117,23 +115,30 @@ loras = [
117
  },
118
  ]
119
 
120
- # --- MODEL LOADING ---
121
  dtype = torch.bfloat16
122
  device = "cuda" if torch.cuda.is_available() else "cpu"
123
  base_model = "Tongyi-MAI/Z-Image-Turbo"
124
 
125
- print(f"Loading {base_model}...")
126
 
127
- pipe = DiffusionPipeline.from_pretrained(
 
128
  base_model,
129
  torch_dtype=dtype,
130
- )
131
- pipe.to(device)
132
 
133
- # Initialize Image-to-Image pipeline sharing components with the main pipeline
134
- pipe_i2i = AutoPipelineForImage2Image.from_pipe(pipe)
 
 
 
 
 
 
 
135
 
136
- MAX_SEED = 2**32-1
137
 
138
  class calculateDuration:
139
  def __init__(self, activity_name=""):
@@ -174,42 +179,6 @@ def update_selection(evt: gr.SelectData, width, height):
174
  height,
175
  )
176
 
177
- @spaces.GPU
178
- def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress):
179
- pipe.to("cuda")
180
- generator = torch.Generator(device="cuda").manual_seed(seed)
181
-
182
- with calculateDuration("Generating image"):
183
- image = pipe(
184
- prompt=prompt_mash,
185
- num_inference_steps=steps,
186
- guidance_scale=cfg_scale,
187
- width=width,
188
- height=height,
189
- generator=generator,
190
- joint_attention_kwargs={"scale": lora_scale},
191
- output_type="pil",
192
- ).images[0]
193
- yield image
194
-
195
- def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, lora_scale, seed):
196
- generator = torch.Generator(device="cuda").manual_seed(seed)
197
- pipe_i2i.to("cuda")
198
- image_input = load_image(image_input_path)
199
- final_image = pipe_i2i(
200
- prompt=prompt_mash,
201
- image=image_input,
202
- strength=image_strength,
203
- num_inference_steps=steps,
204
- guidance_scale=cfg_scale,
205
- width=width,
206
- height=height,
207
- generator=generator,
208
- joint_attention_kwargs={"scale": lora_scale},
209
- output_type="pil",
210
- ).images[0]
211
- return final_image
212
-
213
  @spaces.GPU
214
  def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
215
  if selected_index is None:
@@ -230,70 +199,84 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
230
  else:
231
  prompt_mash = prompt
232
 
 
233
  with calculateDuration("Unloading LoRA"):
234
  pipe.unload_lora_weights()
235
- pipe_i2i.unload_lora_weights()
236
 
237
  # LoRA weights flow
238
  with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
239
- pipe_to_use = pipe_i2i if image_input is not None else pipe
240
  weight_name = selected_lora.get("weights", None)
241
-
242
  try:
243
- pipe_to_use.load_lora_weights(
244
  lora_path,
245
  weight_name=weight_name,
 
246
  low_cpu_mem_usage=True
247
  )
 
 
248
  except Exception as e:
249
  print(f"Error loading LoRA: {e}")
250
- raise gr.Error(f"Failed to load LoRA: {str(e)}")
251
 
252
  with calculateDuration("Randomizing seed"):
253
  if randomize_seed:
254
  seed = random.randint(0, MAX_SEED)
255
-
256
- if(image_input is not None):
257
- final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed)
258
- yield final_image, seed, gr.update(visible=False)
259
- else:
260
- # Standard generation
261
- image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
 
263
- final_image = None
264
- # We process the generator (even if it yields once)
265
- for image in image_generator:
266
- final_image = image
267
- yield image, seed, gr.update(visible=False)
268
 
269
  def get_huggingface_safetensors(link):
270
- split_link = link.split("/")
271
- if(len(split_link) == 2):
272
- model_card = ModelCard.load(link)
273
- base_model_meta = model_card.data.get("base_model")
274
- print(f"Base model metadata: {base_model_meta}")
275
-
276
- # Note: We relax the check here slightly to allow models compatible with Turbo/Flux
277
- # or we just rely on try/catch during loading.
278
-
279
- image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
280
- trigger_word = model_card.data.get("instance_prompt", "")
281
- image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
282
- fs = HfFileSystem()
283
- safetensors_name = None
284
- try:
285
- list_of_files = fs.ls(link, detail=False)
286
- for file in list_of_files:
287
- if(file.endswith(".safetensors")):
288
- safetensors_name = file.split("/")[-1]
289
- if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
290
- image_elements = file.split("/")
291
- image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
292
- except Exception as e:
293
- print(e)
294
- gr.Warning(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
295
- raise Exception(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
296
- return split_link[1], link, safetensors_name, trigger_word, image_url
 
 
297
 
298
  def check_custom_model(link):
299
  if(link.startswith("https://")):
@@ -336,8 +319,8 @@ def add_custom_lora(custom_lora):
336
 
337
  return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
338
  except Exception as e:
339
- gr.Warning(f"Invalid LoRA: either you entered an invalid link, or incompatible LoRA")
340
- return gr.update(visible=True, value=f"Invalid LoRA: {str(e)}"), gr.update(visible=False), gr.update(), "", None, ""
341
  else:
342
  return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
343
 
@@ -365,7 +348,7 @@ css = '''
365
 
366
  with gr.Blocks(delete_cache=(60, 60)) as demo:
367
  title = gr.HTML(
368
- """<h1>Z-Image-Turbo LoRA Studio ⚡</h1>""",
369
  elem_id="title",
370
  )
371
  selected_index = gr.State(None)
@@ -379,14 +362,14 @@ with gr.Blocks(delete_cache=(60, 60)) as demo:
379
  selected_info = gr.Markdown("")
380
  gallery = gr.Gallery(
381
  [(item["image"], item["title"]) for item in loras],
382
- label="Z-Image-Turbo LoRAs",
383
  allow_preview=False,
384
  columns=3,
385
  elem_id="gallery",
386
  )
387
  with gr.Group():
388
  custom_lora = gr.Textbox(label="Enter Custom LoRA", placeholder="Shakker-Labs/AWPortrait-Z")
389
- gr.Markdown("[Check the list of Z-Image-Turbo LoRA's](https://huggingface.co/models?other=base_model:adapter:Tongyi-MAI/Z-Image-Turbo)", elem_id="lora_list")
390
  custom_lora_info = gr.HTML(visible=False)
391
  custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
392
  with gr.Column():
@@ -396,11 +379,11 @@ with gr.Blocks(delete_cache=(60, 60)) as demo:
396
  with gr.Row():
397
  with gr.Accordion("Advanced Settings", open=False):
398
  with gr.Row():
399
- input_image = gr.Image(label="Input image", type="filepath")
400
- image_strength = gr.Slider(label="Denoise Strength", info="Lower means more image influence", minimum=0.1, maximum=1.0, step=0.01, value=0.75)
401
  with gr.Column():
402
  with gr.Row():
403
- cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
404
  steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=9)
405
 
406
  with gr.Row():
 
15
 
16
  from diffusers import (
17
  DiffusionPipeline,
18
+ AutoencoderKL,
19
+ ZImagePipeline
20
  )
21
 
22
  from huggingface_hub import (
 
30
  from gradio.themes import Soft
31
  from gradio.themes.utils import colors, fonts, sizes
32
 
 
33
  colors.steel_blue = colors.Color(
34
  name="steel_blue",
35
  c50="#EBF3F8",
 
98
 
99
  steel_blue_theme = SteelBlueTheme()
100
 
 
101
  loras = [
102
  {
103
  "image": "https://huggingface.co/Shakker-Labs/AWPortrait-Z/resolve/main/images/example.png",
 
115
  },
116
  ]
117
 
 
118
  dtype = torch.bfloat16
119
  device = "cuda" if torch.cuda.is_available() else "cpu"
120
  base_model = "Tongyi-MAI/Z-Image-Turbo"
121
 
122
+ print(f"Loading {base_model} pipeline...")
123
 
124
+ # Initialize Pipeline
125
+ pipe = ZImagePipeline.from_pretrained(
126
  base_model,
127
  torch_dtype=dtype,
128
+ low_cpu_mem_usage=False,
129
+ ).to(device)
130
 
131
+ # ======== AoTI compilation + FA3 ========
132
+ # As per reference for optimization
133
+ try:
134
+ print("Applying AoTI compilation and FA3...")
135
+ pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
136
+ spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")
137
+ print("Optimization applied successfully.")
138
+ except Exception as e:
139
+ print(f"Optimization warning: {e}. Continuing with standard pipeline.")
140
 
141
+ MAX_SEED = np.iinfo(np.int32).max
142
 
143
  class calculateDuration:
144
  def __init__(self, activity_name=""):
 
179
  height,
180
  )
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  @spaces.GPU
183
  def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
184
  if selected_index is None:
 
199
  else:
200
  prompt_mash = prompt
201
 
202
+ # Unload previous LoRAs to start fresh
203
  with calculateDuration("Unloading LoRA"):
204
  pipe.unload_lora_weights()
 
205
 
206
  # LoRA weights flow
207
  with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
 
208
  weight_name = selected_lora.get("weights", None)
 
209
  try:
210
+ pipe.load_lora_weights(
211
  lora_path,
212
  weight_name=weight_name,
213
+ adapter_name="default",
214
  low_cpu_mem_usage=True
215
  )
216
+ # Set adapter scale
217
+ pipe.set_adapters(["default"], adapter_weights=[lora_scale])
218
  except Exception as e:
219
  print(f"Error loading LoRA: {e}")
220
+ gr.Warning("Failed to load LoRA weights. Generating with base model.")
221
 
222
  with calculateDuration("Randomizing seed"):
223
  if randomize_seed:
224
  seed = random.randint(0, MAX_SEED)
225
+
226
+ generator = torch.Generator(device=device).manual_seed(seed)
227
+
228
+ # Note: Z-Image-Turbo is strictly T2I in this reference implementation.
229
+ # Img2Img via image_input is disabled/ignored for this pipeline update.
230
+
231
+ with calculateDuration("Generating image"):
232
+ # For Turbo models, guidance_scale is typically 0.0
233
+ # The user interface passes cfg_scale, but we override or warn if needed.
234
+ # However, for flexibility, if the user explicitly sets it, we might check,
235
+ # but the reference strongly suggests 0.0 for Turbo.
236
+
237
+ forced_guidance = 0.0 # Turbo mode
238
+
239
+ final_image = pipe(
240
+ prompt=prompt_mash,
241
+ height=int(height),
242
+ width=int(width),
243
+ num_inference_steps=int(steps),
244
+ guidance_scale=forced_guidance,
245
+ generator=generator,
246
+ ).images[0]
247
 
248
+ yield final_image, seed, gr.update(visible=False)
 
 
 
 
249
 
250
  def get_huggingface_safetensors(link):
251
+ split_link = link.split("/")
252
+ if(len(split_link) == 2):
253
+ model_card = ModelCard.load(link)
254
+ base_model = model_card.data.get("base_model")
255
+ print(base_model)
256
+
257
+ # Relaxed check to allow Z-Image or Flux or others, assuming user knows what they are doing
258
+ # or specifically check for Z-Image-Turbo
259
+ if base_model not in ["Tongyi-MAI/Z-Image-Turbo", "black-forest-labs/FLUX.1-dev"]:
260
+ # Just a warning instead of error to allow experimentation
261
+ print("Warning: Base model might not match.")
262
+
263
+ image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
264
+ trigger_word = model_card.data.get("instance_prompt", "")
265
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
266
+ fs = HfFileSystem()
267
+ try:
268
+ list_of_files = fs.ls(link, detail=False)
269
+ for file in list_of_files:
270
+ if(file.endswith(".safetensors")):
271
+ safetensors_name = file.split("/")[-1]
272
+ if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
273
+ image_elements = file.split("/")
274
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
275
+ except Exception as e:
276
+ print(e)
277
+ gr.Warning(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
278
+ raise Exception(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
279
+ return split_link[1], link, safetensors_name, trigger_word, image_url
280
 
281
  def check_custom_model(link):
282
  if(link.startswith("https://")):
 
319
 
320
  return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
321
  except Exception as e:
322
+ gr.Warning(f"Invalid LoRA: either you entered an invalid link, or a non-supported LoRA")
323
+ return gr.update(visible=True, value=f"Invalid LoRA: either you entered an invalid link, a non-supported LoRA"), gr.update(visible=False), gr.update(), "", None, ""
324
  else:
325
  return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
326
 
 
348
 
349
  with gr.Blocks(delete_cache=(60, 60)) as demo:
350
  title = gr.HTML(
351
+ """<h1>Z-Image-Turbo LoRA DLC⚡</h1>""",
352
  elem_id="title",
353
  )
354
  selected_index = gr.State(None)
 
362
  selected_info = gr.Markdown("")
363
  gallery = gr.Gallery(
364
  [(item["image"], item["title"]) for item in loras],
365
+ label="Z-Image LoRAs",
366
  allow_preview=False,
367
  columns=3,
368
  elem_id="gallery",
369
  )
370
  with gr.Group():
371
  custom_lora = gr.Textbox(label="Enter Custom LoRA", placeholder="Shakker-Labs/AWPortrait-Z")
372
+ gr.Markdown("[Check the list of Z-Image LoRA's](https://huggingface.co/models?other=base_model:adapter:Tongyi-MAI/Z-Image-Turbo)", elem_id="lora_list")
373
  custom_lora_info = gr.HTML(visible=False)
374
  custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
375
  with gr.Column():
 
379
  with gr.Row():
380
  with gr.Accordion("Advanced Settings", open=False):
381
  with gr.Row():
382
+ input_image = gr.Image(label="Input image (Ignored for Z-Image-Turbo)", type="filepath", visible=False)
383
+ image_strength = gr.Slider(label="Denoise Strength", info="Ignored for Z-Image-Turbo", minimum=0.1, maximum=1.0, step=0.01, value=0.75, visible=False)
384
  with gr.Column():
385
  with gr.Row():
386
+ cfg_scale = gr.Slider(label="CFG Scale", info="Forced to 0.0 for Turbo", minimum=0, maximum=20, step=0.5, value=0.0, interactive=False)
387
  steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=9)
388
 
389
  with gr.Row():