# app.py
import torch
import gradio as gr
from diffusers import DiffusionPipeline
import os

print("Loading pipeline...")

# Run on the GPU in half precision when one is available; fall back to
# full precision on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Fine-tuned artifacts, resolved relative to the current working directory.
WEIGHTS_FILE = "pytorch_custom_diffusion_weights.bin"
# NOTE(review): assumed file name of the learned "<new1>" embedding (the token
# used in the default prompt); confirm it matches what training produced.
MODIFIER_TOKEN_FILE = "<new1>.bin"

pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=dtype,  # reuse the precision chosen above instead of recomputing it
    cache_dir="/tmp/huggingface",
    use_safetensors=True,
    safety_checker=None,
)
pipe.to(device)
# NOTE(review): load_attn_procs below installs new UNet attention processors,
# which may replace the sliced ones enabled here -- verify slicing still applies.
pipe.enable_attention_slicing()

# Log the weights file size; this also fails fast with FileNotFoundError
# when the file is missing.
print(os.path.getsize(WEIGHTS_FILE))
pipe.unet.load_attn_procs(
    f"./{WEIGHTS_FILE}",
    weight_name=WEIGHTS_FILE,
)

# The default prompt refers to "<new1>". Without its learned embedding the
# tokenizer would split it into ordinary sub-word tokens, so load the
# embedding when the file is present. When it is absent this is skipped,
# which matches the previous behaviour.
if os.path.isfile(MODIFIER_TOKEN_FILE):
    pipe.load_textual_inversion(".", weight_name=MODIFIER_TOKEN_FILE)

print("Pipeline loaded")
def generate(prompt, steps, guidance):
    """Render one image for *prompt* with the module-level ``pipe``.

    Args:
        prompt: Text prompt passed straight to the pipeline.
        steps: Number of denoising steps. Gradio sliders may deliver this
            as a float, so it is cast to ``int`` before use.
        guidance: Guidance scale, cast to ``float`` before use.

    Returns:
        The first image of the pipeline output (``result.images[0]``).
    """
    print("Generating...")
    result = pipe(
        prompt,
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
    )
    # Debug dump of the full pipeline output object.
    print("RESULT:", result)
    return result.images[0]
# Gradio front end. The default prompt uses the "<new1>" token the
# pipeline was fine-tuned on.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(
            label="Prompt",
            value="A <new1> reference. New Year image with a rabbit as the main element",
        ),
        gr.Slider(10, 320, value=100, label="Steps"),
        gr.Slider(1, 18, value=6, label="Guidance"),
    ],
    outputs=gr.Image(),
    title="Fine-tuning style diffusion Demo",
)

# Only start the server when run as a script, so importing this module
# (e.g. from tooling or tests) does not block on launch().
if __name__ == "__main__":
    # Listen on all network interfaces, port 7860.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )