soiz1's picture
Update app.py
872057a verified
import functools
import os

import gradio as gr
import requests
import torch
from diffusers import StableDiffusionPipeline
from safetensors.torch import load_file
def download_lora(url, save_path="lora.safetensors", timeout=60):
    """Download *url* to *save_path*, streaming it in 8192-byte chunks.

    Args:
        url: Address of the LoRA file to fetch.
        save_path: Local destination file; truncated/overwritten if it exists.
        timeout: Seconds passed to ``requests.get``. This bounds the connect
            and per-read waits, not the total download time.

    Returns:
        *save_path* when the server answers with HTTP 200, otherwise ``None``.

    Raises:
        requests.RequestException: connection errors, timeouts, invalid URLs
            (propagated unchanged).
        OSError: if *save_path* cannot be written.
    """
    # Using the response as a context manager releases the streamed
    # connection even when we return early or writing fails.
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            return None
        try:
            with open(save_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        except BaseException:
            # Don't leave a truncated weights file behind for a later load
            # to trip over; the error itself is re-raised unchanged.
            if os.path.exists(save_path):
                os.remove(save_path)
            raise
    return save_path
def load_lora(model, lora_path):
    """Read LoRA weights from a safetensors file and apply them to *model*.

    The file at *lora_path* is loaded into a state dict with ``load_file``.
    That dict is handed to ``model.load_lora_weights``, whose return value is
    ignored. The same *model* object is then returned to the caller.
    """
    model.load_lora_weights(load_file(lora_path))
    return model
@functools.lru_cache(maxsize=1)
def _get_pipeline(base_model):
    """Build the CPU Stable Diffusion pipeline for *base_model* once and reuse it."""
    return StableDiffusionPipeline.from_pretrained(base_model).to("cpu")


def generate_image(lora_url, prompt):
    """Generate one image for *prompt*, optionally applying a LoRA from *lora_url*.

    Args:
        lora_url: URL of a LoRA ``.safetensors`` file. ``None``, empty, or
            whitespace-only skips the LoRA and uses the base model alone.
        prompt: Text prompt passed to the pipeline.

    Returns:
        The first entry of the pipeline output's ``images`` list.

    Raises:
        Whatever ``download_lora`` raises for network/IO failures. A non-200
        response is not an error: generation falls back to the base model.
    """
    # NOTE(review): the runwayml/stable-diffusion-v1-5 repo may no longer be
    # available on the Hugging Face Hub; confirm, and switch to a mirror such
    # as stable-diffusion-v1-5/stable-diffusion-v1-5 if loading fails.
    base_model = "runwayml/stable-diffusion-v1-5"

    # Loading the pipeline is expensive, so it is cached across calls instead
    # of being rebuilt on every click.
    pipe = _get_pipeline(base_model)

    # The cached pipeline outlives a single request: drop any LoRA attached by
    # a previous call so it cannot leak into this one. Assumes requests are
    # handled one at a time (check the Gradio queue/concurrency settings);
    # concurrent calls would race on this shared LoRA state.
    pipe.unload_lora_weights()

    # An empty URL field would make requests raise on the missing scheme;
    # treat it as "no LoRA" instead.
    if lora_url and lora_url.strip():
        lora_path = download_lora(lora_url.strip())
        if lora_path:
            pipe = load_lora(pipe, lora_path)

    return pipe(prompt).images[0]
# Build the single-page UI: a title, two text inputs, a button, and an image
# output wired to generate_image.
demo = gr.Blocks()
with demo:
    gr.Markdown(value="# LoRA-Enhanced Image Generator")
    lora_url = gr.Textbox(label="LoRA Model URL")
    prompt = gr.Textbox(label="Prompt")
    generate_btn = gr.Button(value="Generate")
    output = gr.Image()
    # Clicking the button calls generate_image(lora_url, prompt) and puts the
    # returned image into the output component.
    generate_btn.click(
        fn=generate_image,
        inputs=[lora_url, prompt],
        outputs=[output],
    )
demo.launch()