GGPENG commited on
Commit
7e1dd12
·
verified ·
1 Parent(s): 57fc54c

Upload 2 files

Browse files
Files changed (2) hide show
  1. Custom_Diffusion.py +243 -0
  2. requirements.txt +0 -0
Custom_Diffusion.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import safetensors #
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import transformers
10
+ from accelerate import Accelerator
11
+ from accelerate.logging import get_logger
12
+ from accelerate.utils import ProjectConfiguration, set_seed
13
+ from PIL import Image
14
+ from torch.utils.data import Dataset
15
+ from torchvision import transforms
16
+ from tqdm.auto import tqdm
17
+ from transformers import AutoTokenizer, CLIPTextModel
18
+ from safetensors.torch import load_file
19
+
20
+ import diffusers
21
+ # from diffusers.pipelines import BlipDiffusionPipeline
22
+ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DiffusionPipeline
23
+ from diffusers.loaders import AttnProcsLayers
24
+ from diffusers.models.attention_processor import CustomDiffusionAttnProcessor
25
+ from diffusers.optimization import get_scheduler
26
+ from diffusers.utils import load_image
27
+ import streamlit as st
28
+
29
+ import io
30
+
31
+ import streamlit as st # 用于创建交互式网页UI
32
+ import io # 处理文件流(后面用来生成下载按钮)
33
+
34
+ # 设置页面标题和布局
35
+ st.set_page_config(page_title="Fine-tuning style diffusion", layout="wide")
36
+
37
+ st.title("Fine-tuning style diffusion 推理 Demo")
38
+
39
+ st.write("支持 **A <new1> reference.(风格) + 文本*")
40
+
41
+ st.write("只是训练了一个提示词 'A <new1> reference.'")
42
+
43
+ st.write("即使用该提示词时以十二生肖为主要元素进行新年图片风格的生成,例如使用一下提示词")
44
+
45
+ st.write("A <new1> reference. New Year image with a rabbit as the main element, in a 2D or anime style, and a festive background")
46
+
47
+
48
+ device = "cuda" if torch.cuda.is_available() else "cpu"
49
+ dtype = torch.float16
50
+
51
+
52
+ # ==========================
53
+ # 模型加载(缓存)
54
+ # ==========================
55
+
56
+ @st.cache_resource
57
+ def load_models():
58
+
59
+ model_path = "./stable-diffusion-v1-5"
60
+
61
+ tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
62
+
63
+ text_encoder = CLIPTextModel.from_pretrained(
64
+ model_path,
65
+ subfolder="text_encoder",
66
+ torch_dtype=torch.float16
67
+ ).to(device)
68
+
69
+ vae = AutoencoderKL.from_pretrained(
70
+ model_path,
71
+ subfolder="vae",
72
+ torch_dtype=torch.float16
73
+ ).to(device)
74
+
75
+ unet = UNet2DConditionModel.from_pretrained(
76
+ model_path,
77
+ subfolder="unet",
78
+ torch_dtype=torch.float16
79
+ ).to(device)
80
+
81
+ attn_path = "output/pytorch_custom_diffusion_weights.bin"
82
+
83
+ state_dict = torch.load(attn_path, map_location="cpu")
84
+ unet.load_attn_procs(state_dict)
85
+
86
+ token_path = "output/learned_embeds.safetensors"
87
+
88
+
89
+ try:
90
+
91
+ new_embed = torch.load(token_path)
92
+
93
+ token_id = tokenizer.convert_tokens_to_ids("<new1>")
94
+
95
+ text_encoder.get_input_embeddings().weight.data[token_id] = new_embed
96
+
97
+ print("Loaded <new1> token embedding")
98
+
99
+ except:
100
+ print("No trained <new1> token found")
101
+
102
+ scheduler = DDPMScheduler.from_pretrained(
103
+ model_path,
104
+ subfolder="scheduler"
105
+ )
106
+
107
+ unet.enable_xformers_memory_efficient_attention()
108
+
109
+ return tokenizer, text_encoder, vae, unet, scheduler
110
+
111
+ tokenizer, text_encoder, vae, unet, scheduler = load_models()
112
+
113
+
114
+ prompt = st.text_input(
115
+ "Prompt",
116
+ "A <new1> reference."
117
+ )
118
+
119
+ # 调整参数
120
+ steps = st.slider("Steps", 10, 320, 100)
121
+
122
+ guidance = st.slider("Guidance", 1.0, 18.0, 6.0)
123
+
124
+
125
+ # ==========================
126
+ # 图像预处理
127
+ # ==========================
128
+
129
+ def preprocess(image):
130
+ # 调整图像,转换为tensor(张量)并归一化到[-1,1]
131
+ transform = transforms.Compose([
132
+ transforms.Resize((512,512)),
133
+ transforms.ToTensor(),
134
+ transforms.Normalize([0.5],[0.5])
135
+ ])
136
+ # 增加batch维度
137
+ return transform(image).unsqueeze(0)
138
+
139
+
140
+ # ==========================
141
+ # diffusion 推理
142
+ # ==========================
143
+
144
+ def generate(prompt):
145
+
146
+ with torch.no_grad():
147
+ # 文本向量化
148
+ text_input = tokenizer(
149
+ prompt,
150
+ padding="max_length",
151
+ max_length=tokenizer.model_max_length,
152
+ truncation=True,
153
+ return_tensors="pt"
154
+ ).to(device)
155
+
156
+ text_emb = text_encoder(text_input.input_ids)[0]
157
+
158
+ # 无条件 embedding;
159
+ uncond_input = tokenizer(
160
+ "",
161
+ padding="max_length",
162
+ max_length=tokenizer.model_max_length,
163
+ return_tensors="pt"
164
+ ).to(device)
165
+
166
+ uncond_emb = text_encoder(uncond_input.input_ids)[0]
167
+
168
+
169
+ text_emb = torch.cat([uncond_emb, text_emb], dim=0)
170
+
171
+ # 初始化噪声潜变量
172
+ latents = torch.randn(
173
+ (1,4,64,64),
174
+ device=device,
175
+ dtype=torch.float16
176
+ )
177
+
178
+ # 设置diffusion时间步
179
+ scheduler.set_timesteps(steps)
180
+
181
+ # ----------------
182
+ # diffusion loop
183
+ # ----------------
184
+ # 采用
185
+ for t in scheduler.timesteps:
186
+ # 为什么要拼接两份
187
+ latent_model_input = torch.cat([latents]*2)
188
+
189
+ noise_pred = unet(
190
+ latent_model_input,
191
+ t,
192
+ encoder_hidden_states=text_emb
193
+ ).sample
194
+
195
+
196
+ noise_uncond, noise_text = noise_pred.chunk(2)
197
+
198
+ noise_pred = noise_uncond + guidance * (
199
+ noise_text - noise_uncond
200
+ )
201
+ # 调度程序/潜在的
202
+ latents = scheduler.step(
203
+ noise_pred,
204
+ t,
205
+ latents
206
+ ).prev_sample
207
+
208
+
209
+ # ----------------
210
+ # decode image;解码图像
211
+ # ----------------
212
+ # 解码生成图像;将latent解码成[0,1]的RGB图像
213
+ latents = latents / vae.config.scaling_factor
214
+
215
+ image = vae.decode(latents).sample
216
+
217
+ image = (image/2 + 0.5).clamp(0,1)
218
+ # 转成numpy数组,再用PIL转成可展示的图像
219
+ image = image.cpu().permute(0,2,3,1).numpy()[0]
220
+
221
+ image = (image*255).astype(np.uint8)
222
+
223
+ return Image.fromarray(image)
224
+
225
+
226
+ if st.button("Generate"):
227
+
228
+ with st.spinner("Generating..."):
229
+
230
+ image = generate(prompt)
231
+
232
+ st.image(image,caption="Result",width=512)
233
+
234
+ buf = io.BytesIO()
235
+
236
+ image.save(buf,format="PNG")
237
+
238
+ st.download_button(
239
+ "Download",
240
+ buf.getvalue(),
241
+ "result.png"
242
+ )
243
+
requirements.txt ADDED
Binary file (6.14 kB). View file