prithivMLmods committed on
Commit
1359413
·
verified ·
1 Parent(s): 87ca0be

update app

Browse files
Files changed (1) hide show
  1. app.py +39 -39
app.py CHANGED
@@ -363,44 +363,48 @@ def update_selection(evt: gr.SelectData, width, height):
363
 
364
  @spaces.GPU
365
  def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
366
- if selected_index is None:
367
- raise gr.Error("You must select a LoRA before proceeding.🧨")
368
-
369
- selected_lora = loras[selected_index]
370
- lora_path = selected_lora["repo"]
371
- trigger_word = selected_lora["trigger_word"]
372
 
373
- if(trigger_word):
374
- if "trigger_position" in selected_lora:
375
- if selected_lora["trigger_position"] == "prepend":
376
- prompt_mash = f"{trigger_word} {prompt}"
 
 
 
 
 
 
 
 
 
377
  else:
378
- prompt_mash = f"{prompt} {trigger_word}"
379
  else:
380
- prompt_mash = f"{trigger_word} {prompt}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
  else:
 
 
382
  prompt_mash = prompt
383
-
384
- # Unload previous LoRAs to start fresh
385
- with calculateDuration("Unloading LoRA"):
386
- pipe.unload_lora_weights()
387
 
388
- # LoRA weights flow
389
- with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
390
- weight_name = selected_lora.get("weights", None)
391
- try:
392
- pipe.load_lora_weights(
393
- lora_path,
394
- weight_name=weight_name,
395
- adapter_name="default",
396
- low_cpu_mem_usage=True
397
- )
398
- # Set adapter scale
399
- pipe.set_adapters(["default"], adapter_weights=[lora_scale])
400
- except Exception as e:
401
- print(f"Error loading LoRA: {e}")
402
- gr.Warning("Failed to load LoRA weights. Generating with base model.")
403
-
404
  with calculateDuration("Randomizing seed"):
405
  if randomize_seed:
406
  seed = random.randint(0, MAX_SEED)
@@ -412,10 +416,6 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
412
 
413
  with calculateDuration("Generating image"):
414
  # For Turbo models, guidance_scale is typically 0.0
415
- # The user interface passes cfg_scale, but we override or warn if needed.
416
- # However, for flexibility, if the user explicitly sets it, we might check,
417
- # but the reference strongly suggests 0.0 for Turbo.
418
-
419
  forced_guidance = 0.0 # Turbo mode
420
 
421
  final_image = pipe(
@@ -536,12 +536,12 @@ with gr.Blocks(delete_cache=(60, 60)) as demo:
536
  selected_index = gr.State(None)
537
  with gr.Row():
538
  with gr.Column(scale=3):
539
- prompt = gr.Textbox(label="Prompt", lines=1, placeholder="✦︎ Choose the LoRA and type the prompt ")
540
  with gr.Column(scale=1, elem_id="gen_column"):
541
  generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
542
  with gr.Row():
543
  with gr.Column():
544
- selected_info = gr.Markdown("")
545
  gallery = gr.Gallery(
546
  [(item["image"], item["title"]) for item in loras],
547
  label="Z-Image LoRAs",
@@ -556,7 +556,7 @@ with gr.Blocks(delete_cache=(60, 60)) as demo:
556
  custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
557
  with gr.Column():
558
  progress_bar = gr.Markdown(elem_id="progress",visible=False)
559
- result = gr.Image(label="Generated Image", format="png", height=600)
560
 
561
  with gr.Row():
562
  with gr.Accordion("Advanced Settings", open=False):
 
363
 
364
  @spaces.GPU
365
  def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
366
+ # Clean up previous LoRAs in both cases
367
+ with calculateDuration("Unloading LoRA"):
368
+ pipe.unload_lora_weights()
 
 
 
369
 
370
+ # Check if a LoRA is selected
371
+ if selected_index is not None and selected_index < len(loras):
372
+ selected_lora = loras[selected_index]
373
+ lora_path = selected_lora["repo"]
374
+ trigger_word = selected_lora["trigger_word"]
375
+
376
+ # Prepare Prompt with Trigger Word
377
+ if(trigger_word):
378
+ if "trigger_position" in selected_lora:
379
+ if selected_lora["trigger_position"] == "prepend":
380
+ prompt_mash = f"{trigger_word} {prompt}"
381
+ else:
382
+ prompt_mash = f"{prompt} {trigger_word}"
383
  else:
384
+ prompt_mash = f"{trigger_word} {prompt}"
385
  else:
386
+ prompt_mash = prompt
387
+
388
+ # Load LoRA
389
+ with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
390
+ weight_name = selected_lora.get("weights", None)
391
+ try:
392
+ pipe.load_lora_weights(
393
+ lora_path,
394
+ weight_name=weight_name,
395
+ adapter_name="default",
396
+ low_cpu_mem_usage=True
397
+ )
398
+ # Set adapter scale
399
+ pipe.set_adapters(["default"], adapter_weights=[lora_scale])
400
+ except Exception as e:
401
+ print(f"Error loading LoRA: {e}")
402
+ gr.Warning("Failed to load LoRA weights. Generating with base model.")
403
  else:
404
+ # Base Model Case
405
+ print("No LoRA selected. Running with Base Model.")
406
  prompt_mash = prompt
 
 
 
 
407
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408
  with calculateDuration("Randomizing seed"):
409
  if randomize_seed:
410
  seed = random.randint(0, MAX_SEED)
 
416
 
417
  with calculateDuration("Generating image"):
418
  # For Turbo models, guidance_scale is typically 0.0
 
 
 
 
419
  forced_guidance = 0.0 # Turbo mode
420
 
421
  final_image = pipe(
 
536
  selected_index = gr.State(None)
537
  with gr.Row():
538
  with gr.Column(scale=3):
539
+ prompt = gr.Textbox(label="Prompt", lines=1, placeholder="✦︎ Choose the LoRA and type the prompt (or leave LoRA unselected for Base Model)")
540
  with gr.Column(scale=1, elem_id="gen_column"):
541
  generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
542
  with gr.Row():
543
  with gr.Column():
544
+ selected_info = gr.Markdown("### No LoRA Selected (Base Model)")
545
  gallery = gr.Gallery(
546
  [(item["image"], item["title"]) for item in loras],
547
  label="Z-Image LoRAs",
 
556
  custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
557
  with gr.Column():
558
  progress_bar = gr.Markdown(elem_id="progress",visible=False)
559
+ result = gr.Image(label="Generated Image", format="png")
560
 
561
  with gr.Row():
562
  with gr.Accordion("Advanced Settings", open=False):