Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from diffusers import StableDiffusionPipeline | |
| from safetensors.torch import load_file | |
| import requests | |
| import os | |
def download_lora(url, save_path="lora.safetensors", timeout=30):
    """Download a LoRA weights file from *url* and write it to *save_path*.

    Args:
        url: HTTP(S) address of the file. A blank or ``None`` value means
            there is nothing to download.
        save_path: Destination path; an existing file is overwritten.
        timeout: Seconds ``requests`` waits for the server before giving up,
            so a stalled host cannot block the caller forever.

    Returns:
        ``save_path`` when the server answered with status 200, otherwise
        ``None`` (blank URL or any other status code).

    Raises:
        requests.RequestException: On network failures or a malformed URL.
    """
    if not url or not url.strip():
        return None

    # The context manager releases the streamed connection on every path,
    # including the non-200 early return and write errors.
    with requests.get(url.strip(), stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            return None
        with open(save_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
    return save_path
def load_lora(model, lora_path):
    """Read the safetensors file at *lora_path* and apply it to *model*.

    The tensors are loaded with ``load_file`` and handed to the model's
    ``load_lora_weights`` method. The same *model* object is returned so the
    call can be used in an assignment.
    """
    model.load_lora_weights(load_file(lora_path))
    return model
def generate_image(lora_url, prompt):
    """Generate one image for *prompt*, optionally applying a LoRA.

    Args:
        lora_url: URL of a LoRA safetensors file. May be blank, in which
            case the base model is used unmodified.
        prompt: Text prompt passed to the pipeline.

    Returns:
        The first image produced by the pipeline.
    """
    base_model = "runwayml/stable-diffusion-v1-5"
    # NOTE: the pipeline is rebuilt on every call, so each request starts
    # from clean base weights with no LoRA left over from an earlier request.
    pipe = StableDiffusionPipeline.from_pretrained(base_model).to("cpu")

    # The URL textbox can be left empty; skip the download instead of
    # handing an invalid URL to the HTTP client.
    lora_url = (lora_url or "").strip()
    lora_path = download_lora(lora_url) if lora_url else None
    if lora_path:
        pipe = load_lora(pipe, lora_path)

    return pipe(prompt).images[0]
# Gradio UI: two text inputs, a trigger button, and an image output.
with gr.Blocks() as demo:
    gr.Markdown(value="# LoRA-Enhanced Image Generator")

    lora_url_box = gr.Textbox(label="LoRA Model URL")
    prompt_box = gr.Textbox(label="Prompt")
    run_button = gr.Button(value="Generate")
    result_image = gr.Image()

    # Clicking the button sends both textbox values to generate_image and
    # shows the returned image in the output component.
    run_button.click(
        fn=generate_image,
        inputs=[lora_url_box, prompt_box],
        outputs=result_image,
    )

demo.launch()