File size: 2,103 Bytes
55e61c2
6cc9580
 
 
a97d94f
6cc9580
a97d94f
e4c36ac
 
 
 
6cc9580
4a9d074
2d1eccd
 
 
 
55e61c2
435643f
2d1eccd
 
e4c36ac
435643f
e4c36ac
55e61c2
2d1eccd
1b886a8
2d1eccd
 
7c1adba
 
 
 
 
 
 
 
 
 
590639f
 
 
6cc9580
590639f
 
6cc9580
55e61c2
9597fb4
 
c9dcd2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cc9580
435643f
9597fb4
c9dcd2c
 
6cc9580
c9dcd2c
 
 
 
 
 
 
6cc9580
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55e61c2
c3737ef
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# app.py
#
# Gradio demo: Stable Diffusion v1.5 with Custom Diffusion attention
# processors (fine-tuned for the <new1> token) loaded on top of the base UNet.

import os

import torch
import gradio as gr
from diffusers import DiffusionPipeline

print("Loading pipeline...")

# Prefer GPU when available. fp16 halves memory on CUDA; CPU needs fp32.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=dtype,  # reuse the dtype selected above (was a duplicated ternary)
    cache_dir="/tmp/huggingface",
    use_safetensors=True,
    safety_checker=None,  # NOTE(review): disables the NSFW filter — confirm intended
)

pipe.to(device)

# Attention slicing trades a little speed for a large VRAM reduction.
pipe.enable_attention_slicing()

# Sanity check that the fine-tuned weights file exists (getsize raises if not)
# before handing it to diffusers.
print(os.path.getsize("pytorch_custom_diffusion_weights.bin"))

# Load the Custom Diffusion attention processors into the UNet.
# weight_name tells diffusers which file inside the directory/path to use.
pipe.unet.load_attn_procs(
    "./pytorch_custom_diffusion_weights.bin",
    weight_name="pytorch_custom_diffusion_weights.bin",
)

print("Pipeline loaded")

def generate(prompt, steps, guidance):
    """Run the diffusion pipeline for *prompt* and return the first image.

    Parameters
    ----------
    prompt : str
        Text prompt; may include the fine-tuned ``<new1>`` token.
    steps : int | float
        Number of denoising steps. Gradio sliders deliver floats, so the
        value is cast to ``int`` before being passed to the pipeline.
    guidance : float
        Classifier-free guidance scale (cast to ``float`` for the same reason).

    Returns
    -------
    PIL.Image.Image
        The first generated image from the pipeline output.
    """
    print("Generating...")

    result = pipe(
        prompt,
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
    )

    # Removed leftover debug print that dumped the whole pipeline output
    # object ("RESULT: ...") to stdout on every request.
    return result.images[0]


# Build the UI components first, then wire them into the Interface.
prompt_input = gr.Textbox(
    label="Prompt",
    value="A <new1> reference. New Year image with a rabbit as the main element",
)
steps_input = gr.Slider(10, 320, value=100, label="Steps")
guidance_input = gr.Slider(1, 18, value=6, label="Guidance")

# One text prompt plus two numeric controls in, one image out.
demo = gr.Interface(
    fn=generate,
    inputs=[prompt_input, steps_input, guidance_input],
    outputs=gr.Image(),
    title="Fine-tuning style diffusion Demo",
)

# Bind to all interfaces on the conventional Gradio port.
demo.launch(server_name="0.0.0.0", server_port=7860)