ali2367fdhfe's picture
Update app.py
7083db1 verified
Raw
History Blame Contribute Delete
3.09 kB
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
import os
print("--- FINAL VERSION: ADJUSTING LORA SCALE ---")

# --- Step 1: Constants identifying the base checkpoint and the LoRA file ---
BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5"
LORA_FILENAME = "MyStickmanProject-10.safetensors"
print(f"Base model ID: {BASE_MODEL_ID}")
print(f"LoRA file: {LORA_FILENAME}")


def _build_pipeline(model_id, lora_file):
    """Load the base pipeline, attach LoRA weights, and prepare it for CPU use.

    Each step prints a progress message before/after it runs, so startup
    logs show how far model setup got.

    Args:
        model_id: Identifier passed to ``StableDiffusionPipeline.from_pretrained``.
        lora_file: LoRA weights file passed to ``load_lora_weights``.

    Returns:
        The configured pipeline object.
    """
    # Step 2: full base-model pipeline, loaded from safetensors weights.
    print("Loading base model pipeline...")
    pipeline = StableDiffusionPipeline.from_pretrained(model_id, use_safetensors=True)
    print("Base model loaded successfully.")

    # Step 3: apply the LoRA weights on top of the base model.
    print("Loading LoRA weights...")
    pipeline.load_lora_weights(lora_file)
    print("LoRA weights loaded.")

    # Step 4: run on the CPU; attention slicing is diffusers' memory-saving option.
    print("Moving pipeline to CPU...")
    pipeline.to("cpu")
    print("Enabling attention slicing...")
    pipeline.enable_attention_slicing()
    return pipeline


# Module-level pipeline shared by the generation function below.
pipe = _build_pipeline(BASE_MODEL_ID, LORA_FILENAME)
print("--- MODEL SETUP COMPLETE ---")
# --- Step 5: Define the core generation function ---
def generate_image(prompt, negative_prompt, guidance_scale, num_steps, lora_scale):
    """Generate one image with the module-level ``pipe`` and the LoRA applied.

    Args:
        prompt: Text describing the desired image.
        negative_prompt: Text describing what the image should avoid.
        guidance_scale: Guidance strength forwarded to the pipeline; coerced
            to ``float``.
        num_steps: Number of inference steps; coerced to ``int``.
        lora_scale: LoRA strength multiplier (1.0 = full strength,
            0.7 = 70% strength); coerced to ``float``.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    print(f"Generating image with prompt: {prompt}")
    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=int(num_steps),
        guidance_scale=float(guidance_scale),
        # The LoRA strength is read from cross_attention_kwargs["scale"].
        # Coerce it like the other numeric UI inputs so the pipeline always
        # receives a float (a slider dragged to an endpoint may arrive as int).
        cross_attention_kwargs={"scale": float(lora_scale)},
    )
    image = result.images[0]
    print("Image generation complete.")
    return image
# --- Step 6: Build the Gradio user interface ---
# NOTE(review): indentation was lost in this copy of the file; this layout
# assumes the Generate button sits in the output column under the image.
with gr.Blocks() as iface:
    gr.Markdown("# My Custom Stickman LoRA Demo")
    gr.Markdown("Enter a prompt to generate an image in my unique style.")

    with gr.Row():
        # Left column (scale 70): every generation control.
        with gr.Column(scale=70):
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="A stickman wearing a fedora hat, holding a magnifying glass",
            )
            # NOTE(review): the "(...:1.4)" weighting syntax comes from other SD
            # front-ends; a plain diffusers pipeline likely treats it as literal
            # text rather than a weight — confirm whether that is intended.
            negative_prompt_input = gr.Textbox(
                label="Negative Prompt",
                value="(worst quality, low quality:1.4), blurry, noisy, grainy, 3d, realistic, photo",
            )
            # Controls how strongly the LoRA style is applied.
            lora_slider = gr.Slider(
                label="LoRA Strength (Style Intensity)",
                minimum=0,
                maximum=1.5,
                step=0.05,
                value=0.75,
            )
            with gr.Row():
                guidance_scale_slider = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=20,
                    step=0.5,
                    value=7.5,
                )
                steps_slider = gr.Slider(
                    label="Inference Steps",
                    minimum=10,
                    maximum=100,
                    step=1,
                    value=25,
                )

        # Right column (scale 30): result image and the trigger button.
        with gr.Column(scale=30):
            image_output = gr.Image(label="Generated Image")
            submit_button = gr.Button("Generate", variant="primary")

    # Order must match generate_image's positional parameters.
    generation_inputs = [
        prompt_input,
        negative_prompt_input,
        guidance_scale_slider,
        steps_slider,
        lora_slider,
    ]
    submit_button.click(generate_image, inputs=generation_inputs, outputs=image_output)

# --- Step 7: Start the web server ---
print("Launching Gradio app...")
iface.launch()