update app
Browse files
app.py
CHANGED
|
@@ -363,44 +363,48 @@ def update_selection(evt: gr.SelectData, width, height):
|
|
| 363 |
|
| 364 |
@spaces.GPU
|
| 365 |
def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
selected_lora = loras[selected_index]
|
| 370 |
-
lora_path = selected_lora["repo"]
|
| 371 |
-
trigger_word = selected_lora["trigger_word"]
|
| 372 |
|
| 373 |
-
if
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
else:
|
| 378 |
-
prompt_mash = f"{
|
| 379 |
else:
|
| 380 |
-
prompt_mash =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
else:
|
|
|
|
|
|
|
| 382 |
prompt_mash = prompt
|
| 383 |
-
|
| 384 |
-
# Unload previous LoRAs to start fresh
|
| 385 |
-
with calculateDuration("Unloading LoRA"):
|
| 386 |
-
pipe.unload_lora_weights()
|
| 387 |
|
| 388 |
-
# LoRA weights flow
|
| 389 |
-
with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
|
| 390 |
-
weight_name = selected_lora.get("weights", None)
|
| 391 |
-
try:
|
| 392 |
-
pipe.load_lora_weights(
|
| 393 |
-
lora_path,
|
| 394 |
-
weight_name=weight_name,
|
| 395 |
-
adapter_name="default",
|
| 396 |
-
low_cpu_mem_usage=True
|
| 397 |
-
)
|
| 398 |
-
# Set adapter scale
|
| 399 |
-
pipe.set_adapters(["default"], adapter_weights=[lora_scale])
|
| 400 |
-
except Exception as e:
|
| 401 |
-
print(f"Error loading LoRA: {e}")
|
| 402 |
-
gr.Warning("Failed to load LoRA weights. Generating with base model.")
|
| 403 |
-
|
| 404 |
with calculateDuration("Randomizing seed"):
|
| 405 |
if randomize_seed:
|
| 406 |
seed = random.randint(0, MAX_SEED)
|
|
@@ -412,10 +416,6 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
|
|
| 412 |
|
| 413 |
with calculateDuration("Generating image"):
|
| 414 |
# For Turbo models, guidance_scale is typically 0.0
|
| 415 |
-
# The user interface passes cfg_scale, but we override or warn if needed.
|
| 416 |
-
# However, for flexibility, if the user explicitly sets it, we might check,
|
| 417 |
-
# but the reference strongly suggests 0.0 for Turbo.
|
| 418 |
-
|
| 419 |
forced_guidance = 0.0 # Turbo mode
|
| 420 |
|
| 421 |
final_image = pipe(
|
|
@@ -536,12 +536,12 @@ with gr.Blocks(delete_cache=(60, 60)) as demo:
|
|
| 536 |
selected_index = gr.State(None)
|
| 537 |
with gr.Row():
|
| 538 |
with gr.Column(scale=3):
|
| 539 |
-
prompt = gr.Textbox(label="Prompt", lines=1, placeholder="✦︎ Choose the LoRA and type the prompt ")
|
| 540 |
with gr.Column(scale=1, elem_id="gen_column"):
|
| 541 |
generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
|
| 542 |
with gr.Row():
|
| 543 |
with gr.Column():
|
| 544 |
-
selected_info = gr.Markdown("")
|
| 545 |
gallery = gr.Gallery(
|
| 546 |
[(item["image"], item["title"]) for item in loras],
|
| 547 |
label="Z-Image LoRAs",
|
|
@@ -556,7 +556,7 @@ with gr.Blocks(delete_cache=(60, 60)) as demo:
|
|
| 556 |
custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
|
| 557 |
with gr.Column():
|
| 558 |
progress_bar = gr.Markdown(elem_id="progress",visible=False)
|
| 559 |
-
result = gr.Image(label="Generated Image", format="png"
|
| 560 |
|
| 561 |
with gr.Row():
|
| 562 |
with gr.Accordion("Advanced Settings", open=False):
|
|
|
|
| 363 |
|
| 364 |
@spaces.GPU
|
| 365 |
def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
|
| 366 |
+
# Clean up previous LoRAs in both cases
|
| 367 |
+
with calculateDuration("Unloading LoRA"):
|
| 368 |
+
pipe.unload_lora_weights()
|
|
|
|
|
|
|
|
|
|
| 369 |
|
| 370 |
+
# Check if a LoRA is selected
|
| 371 |
+
if selected_index is not None and selected_index < len(loras):
|
| 372 |
+
selected_lora = loras[selected_index]
|
| 373 |
+
lora_path = selected_lora["repo"]
|
| 374 |
+
trigger_word = selected_lora["trigger_word"]
|
| 375 |
+
|
| 376 |
+
# Prepare Prompt with Trigger Word
|
| 377 |
+
if(trigger_word):
|
| 378 |
+
if "trigger_position" in selected_lora:
|
| 379 |
+
if selected_lora["trigger_position"] == "prepend":
|
| 380 |
+
prompt_mash = f"{trigger_word} {prompt}"
|
| 381 |
+
else:
|
| 382 |
+
prompt_mash = f"{prompt} {trigger_word}"
|
| 383 |
else:
|
| 384 |
+
prompt_mash = f"{trigger_word} {prompt}"
|
| 385 |
else:
|
| 386 |
+
prompt_mash = prompt
|
| 387 |
+
|
| 388 |
+
# Load LoRA
|
| 389 |
+
with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
|
| 390 |
+
weight_name = selected_lora.get("weights", None)
|
| 391 |
+
try:
|
| 392 |
+
pipe.load_lora_weights(
|
| 393 |
+
lora_path,
|
| 394 |
+
weight_name=weight_name,
|
| 395 |
+
adapter_name="default",
|
| 396 |
+
low_cpu_mem_usage=True
|
| 397 |
+
)
|
| 398 |
+
# Set adapter scale
|
| 399 |
+
pipe.set_adapters(["default"], adapter_weights=[lora_scale])
|
| 400 |
+
except Exception as e:
|
| 401 |
+
print(f"Error loading LoRA: {e}")
|
| 402 |
+
gr.Warning("Failed to load LoRA weights. Generating with base model.")
|
| 403 |
else:
|
| 404 |
+
# Base Model Case
|
| 405 |
+
print("No LoRA selected. Running with Base Model.")
|
| 406 |
prompt_mash = prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
with calculateDuration("Randomizing seed"):
|
| 409 |
if randomize_seed:
|
| 410 |
seed = random.randint(0, MAX_SEED)
|
|
|
|
| 416 |
|
| 417 |
with calculateDuration("Generating image"):
|
| 418 |
# For Turbo models, guidance_scale is typically 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
forced_guidance = 0.0 # Turbo mode
|
| 420 |
|
| 421 |
final_image = pipe(
|
|
|
|
| 536 |
selected_index = gr.State(None)
|
| 537 |
with gr.Row():
|
| 538 |
with gr.Column(scale=3):
|
| 539 |
+
prompt = gr.Textbox(label="Prompt", lines=1, placeholder="✦︎ Choose the LoRA and type the prompt (or leave LoRA unselected for Base Model)")
|
| 540 |
with gr.Column(scale=1, elem_id="gen_column"):
|
| 541 |
generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
|
| 542 |
with gr.Row():
|
| 543 |
with gr.Column():
|
| 544 |
+
selected_info = gr.Markdown("### No LoRA Selected (Base Model)")
|
| 545 |
gallery = gr.Gallery(
|
| 546 |
[(item["image"], item["title"]) for item in loras],
|
| 547 |
label="Z-Image LoRAs",
|
|
|
|
| 556 |
custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
|
| 557 |
with gr.Column():
|
| 558 |
progress_bar = gr.Markdown(elem_id="progress",visible=False)
|
| 559 |
+
result = gr.Image(label="Generated Image", format="png")
|
| 560 |
|
| 561 |
with gr.Row():
|
| 562 |
with gr.Accordion("Advanced Settings", open=False):
|