Spaces:
Build error
Build error
| import gradio as gr | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from safetensors.torch import load_file | |
| import spaces | |
| from PIL import Image | |
| from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler | |
# Device the app reports at startup. It follows CUDA availability instead of
# being hard-coded to "cpu", so the log matches where the pipeline actually runs.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
# For reproducible outputs, pass
# generator=torch.Generator(device=device).manual_seed(seed) to the pipeline call.
# Negative prompt sent with every generation request.
negative_prompt = "deformed face, extra limbs, extra fingers, out of frame"

# Constants
# Base Stable Diffusion checkpoint identifier passed to from_pretrained.
base = 'stabilityai/stable-diffusion-2'

# UI display name -> LoRA location handed to pipe.unet.load_attn_procs.
_LORA_TABLE = (
    ("Mambo Museum", "adapters/mambo"),
    ("Street Art Museum", "adapters/street"),
    ("Colonial Museum", "adapters/colonial"),
    ("Oro Museum", "adapters/oro"),
)
checkpoints = dict(_LORA_TABLE)

# Cache marker read and written by generate_image to avoid redundant LoRA
# reloads; None until the first request.
loaded = None
# Load the base pipeline once at import time so every request reuses it.
# Half precision is used only on CUDA. diffusers warns that float16 pipelines
# cannot run on CPU (several fp16 kernels are missing there), so on CPU the
# weights are loaded as float32 instead.
if torch.cuda.is_available():
    pipe = StableDiffusionPipeline.from_pretrained(
        base, torch_dtype=torch.float16, variant="fp16"
    ).to("cuda")
    print('initialized cuda model')
else:
    # variant="fp16" keeps the same (smaller) weight files; torch_dtype
    # upcasts them to float32 on load so CPU inference can run.
    pipe = StableDiffusionPipeline.from_pretrained(
        base, torch_dtype=torch.float32, variant="fp16"
    )
# Function
def generate_image(prompt, ckpt):
    """Generate one image for ``prompt`` with the LoRA selected in the UI.

    Args:
        prompt: Text prompt from the textbox.
        ckpt: Display name of the LoRA; must be a key of ``checkpoints``.

    Returns:
        The first image of the pipeline output (``results.images[0]``).

    Raises:
        gr.Error: If ``ckpt`` is not a known LoRA name (e.g. the dropdown
            was cleared), so Gradio shows a readable message to the user.
    """
    global loaded
    print(prompt, ckpt)
    if ckpt not in checkpoints:
        raise gr.Error(f"Unknown LoRA model: {ckpt!r}")
    checkpoint = checkpoints[ckpt]
    num_inference_steps = 50
    # Reload the LoRA attention processors only when the selection changes.
    # The cache key must be the checkpoint path. Keying it on the (constant)
    # step count meant only the first LoRA ever selected was loaded.
    # NOTE(review): depending on the diffusers version, load_attn_procs may
    # add the new LoRA alongside previously loaded ones instead of replacing
    # them. Confirm that switching models does not blend adapters.
    if loaded != checkpoint:
        pipe.unet.load_attn_procs(checkpoint)
        loaded = checkpoint
    results = pipe(
        prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=7.5,
        negative_prompt=negative_prompt,
    )
    return results.images[0]
# Gradio Interface
# Prompt box, LoRA selector and run button share one row; the output image
# sits below. Pressing Enter in the textbox or clicking the button both call
# generate_image.
with gr.Blocks(css="style.css") as demo:
    gr.HTML("<h1><center>Mambo Museum Workshop ⚡</center></h1>")
    gr.HTML("<p><center>LoRAs to generate different aesthetics</center></p>")
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Enter your prompt (English)', scale=8)
            # Choices come from the checkpoints mapping so the dropdown can
            # never offer a name that generate_image cannot resolve.
            ckpt = gr.Dropdown(
                label='Select LoRA model',
                choices=list(checkpoints),
                value='Mambo Museum',
                interactive=True,
            )
            submit = gr.Button(scale=1, variant='primary')
    img = gr.Image(label='Generated Image')

    prompt.submit(
        fn=generate_image,
        inputs=[prompt, ckpt],
        outputs=img,
    )
    submit.click(
        fn=generate_image,
        inputs=[prompt, ckpt],
        outputs=img,
    )

demo.queue().launch()