XCLiu committed on
Commit
8c3d9a8
·
1 Parent(s): 7dab0a6

Delete sd_models.py

Browse files
Files changed (1) hide show
  1. sd_models.py +0 -239
sd_models.py DELETED
@@ -1,239 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
- # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- import argparse
17
- import logging
18
- import math
19
- import os
20
- import random
21
- from pathlib import Path
22
- from typing import Optional, Union, List, Callable
23
-
24
- import datasets
25
- import numpy as np
26
- import torch
27
- import torch.nn.functional as F
28
- import torch.utils.checkpoint
29
- import transformers
30
- from datasets import load_dataset
31
- from huggingface_hub import HfFolder, Repository, create_repo, whoami
32
- from packaging import version
33
- from torchvision import transforms
34
- from tqdm.auto import tqdm
35
- from transformers import CLIPTextModel, CLIPTokenizer
36
-
37
- import diffusers
38
- from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel#, StackUNet2DConditionModel
39
- from diffusers.optimization import get_scheduler
40
- from diffusers.training_utils import EMAModel
41
- from diffusers.utils import check_min_version, deprecate
42
- from diffusers.utils.import_utils import is_xformers_available
43
-
44
- import time
45
-
46
- from torch.distributions import Normal, Categorical
47
- from torch.distributions.multivariate_normal import MultivariateNormal
48
- from torch.distributions.mixture_same_family import MixtureSameFamily
49
-
50
- from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
51
- import torchvision
52
-
53
- import cv2
54
-
55
- def inference_latent(
56
- pipeline,
57
- prompt: Union[str, List[str]],
58
- height: Optional[int] = None,
59
- width: Optional[int] = None,
60
- num_inference_steps: int = 50,
61
- guidance_scale: float = 7.5,
62
- negative_prompt: Optional[Union[str, List[str]]] = None,
63
- num_images_per_prompt: Optional[int] = 1,
64
- eta: float = 0.0,
65
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
66
- latents: Optional[torch.FloatTensor] = None,
67
- output_type: Optional[str] = "pil",
68
- return_dict: bool = True,
69
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
70
- callback_steps: Optional[int] = 1,
71
- ):
72
-
73
- # 0. Default height and width to unet
74
- height = height or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
75
- width = width or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
76
-
77
- # 1. Check inputs. Raise error if not correct
78
- #pipeline.check_inputs(prompt, height, width, callback_steps)
79
-
80
- # 2. Define call parameters
81
- batch_size = 1 if isinstance(prompt, str) else len(prompt)
82
- device = pipeline._execution_device
83
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
84
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
85
- # corresponds to doing no classifier free guidance.
86
- do_classifier_free_guidance = guidance_scale > 1.0
87
-
88
- # 3. Encode input prompt
89
- #setup_seed(0)
90
- text_embeddings = pipeline._encode_prompt(
91
- prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
92
- )
93
-
94
- # 4. Prepare timesteps
95
- pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
96
- timesteps = pipeline.scheduler.timesteps
97
-
98
- # 5. Prepare latent variables
99
- num_channels_latents = pipeline.unet.in_channels
100
- latents = latents.reshape(1, num_channels_latents, 64, 64)
101
-
102
- # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
103
- extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
104
-
105
- # 7. Denoising loop
106
- num_warmup_steps = len(timesteps) - \
107
- num_inference_steps * pipeline.scheduler.order
108
-
109
- latents_cllt = [latents.detach().clone()]
110
- with torch.no_grad():
111
- for i, t in enumerate(timesteps):
112
- # expand the latents if we are doing classifier free guidance
113
- latent_model_input = torch.cat(
114
- [latents] * 2) if do_classifier_free_guidance else latents
115
- latent_model_input = pipeline.scheduler.scale_model_input(
116
- latent_model_input, t)
117
-
118
- noise_pred = pipeline.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
119
-
120
- # perform guidance
121
- if do_classifier_free_guidance:
122
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
123
- noise_pred = noise_pred_uncond + guidance_scale * \
124
- (noise_pred_text - noise_pred_uncond)
125
-
126
- # compute the previous noisy sample x_t -> x_t-1
127
- outputs = pipeline.scheduler.step(
128
- noise_pred, t, latents, **extra_step_kwargs)
129
-
130
- latents = outputs.prev_sample
131
-
132
-
133
- example = {
134
- 'latent': latents.detach().clone(),
135
- 'text_embeddings': text_embeddings.chunk(2)[1].detach() if do_classifier_free_guidance else text_embeddings.detach(),
136
- }
137
- return example
138
-
139
-
140
-
141
def setup_seed(seed):
    """Seed every RNG in use (torch CPU/CUDA, numpy, stdlib random) and force
    deterministic cuDNN kernels, for reproducible image generation.

    Args:
        seed: Integer seed applied to all random number generators.

    Note: ``torch.cuda.*`` calls are safe no-ops when CUDA is unavailable.
    """
    # The redundant function-local `import random` was removed; the module
    # already imports `random` at the top of the file.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Disable cuDNN autotuning and request deterministic algorithms so the
    # same seed reproduces the same images.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.cuda.empty_cache()
150
-
151
-
152
class SD_model():
    """Wrapper that loads a frozen Stable Diffusion pipeline and generates
    images from a seeded, reproducible initial latent.

    Attributes set by ``set_new_latent_and_generate_new_image``:
        latents, prompt, negative_prompt, guidance_scale,
        num_inference_steps, org_image.
    """

    def __init__(self, pretrained_model_name_or_path, device='cuda', weight_dtype=torch.float16):
        """Load tokenizer/encoder/VAE/UNet and assemble an inference pipeline.

        Args:
            pretrained_model_name_or_path: HF Hub id or local path of the
                Stable Diffusion checkpoint.
            device: Target device for all models (default 'cuda', matching
                the previously hard-coded behavior).
            weight_dtype: Inference dtype (default fp16 — the models are
                frozen and used for inference only, so full precision is
                not required).
        """
        self.pretrained_model_name_or_path = pretrained_model_name_or_path

        # Load tokenizer and models. (The unused DDPMScheduler load was
        # removed; the pipeline below loads its own scheduler config.)
        tokenizer = CLIPTokenizer.from_pretrained(
            self.pretrained_model_name_or_path, subfolder="tokenizer"
        )
        text_encoder = CLIPTextModel.from_pretrained(
            self.pretrained_model_name_or_path, subfolder="text_encoder"
        )
        vae = AutoencoderKL.from_pretrained(
            self.pretrained_model_name_or_path, subfolder="vae"
        )
        unet = UNet2DConditionModel.from_pretrained(
            self.pretrained_model_name_or_path, subfolder="unet"
        )

        # Inference only: eval mode, and freeze every module.
        unet.eval()
        vae.eval()
        text_encoder.eval()
        vae.requires_grad_(False)
        text_encoder.requires_grad_(False)
        unet.requires_grad_(False)

        self.weight_dtype = weight_dtype
        self.device = device

        # Move models to the target device in the inference dtype.
        text_encoder.to(device, dtype=weight_dtype)
        vae.to(device, dtype=weight_dtype)
        unet.to(device, dtype=weight_dtype)

        # Create the pipeline from the loaded (frozen) modules.
        pipeline = StableDiffusionPipeline.from_pretrained(
            self.pretrained_model_name_or_path,
            text_encoder=text_encoder,
            vae=vae,
            unet=unet,
            torch_dtype=weight_dtype,
        )
        pipeline = pipeline.to(device)

        # Swap in the faster DPM-Solver multistep scheduler for sampling.
        from diffusers import DPMSolverMultistepScheduler
        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
        self.pipeline = pipeline

    def set_new_latent_and_generate_new_image(self, seed=None, prompt=None, negative_prompt="", num_inference_steps=25, guidance_scale=5.0):
        """Seed all RNGs, draw a fresh initial latent, and generate an image.

        Args:
            seed: Required integer seed for reproducible latents.
            prompt: Required text prompt.
            negative_prompt: Negative prompt for classifier-free guidance.
            num_inference_steps: Scheduler steps for sampling.
            guidance_scale: Classifier-free guidance weight.

        Returns:
            The decoded image (as produced by ``pipeline.decode_latents``);
            also cached on ``self.org_image``.

        Raises:
            ValueError: if ``seed`` or ``prompt`` is missing.
        """
        # Raise instead of `assert False` so validation survives `python -O`.
        if seed is None:
            raise ValueError("Must have a pre-defined random seed")
        if prompt is None:
            raise ValueError("Must have a user-specified text prompt")

        setup_seed(seed)
        # Flat noise vector; inference_latent reshapes it to (1, 4, 64, 64).
        self.latents = torch.randn((1, 4*64*64), device=self.device).to(dtype=self.weight_dtype)
        self.prompt = prompt
        self.negative_prompt = negative_prompt
        self.guidance_scale = guidance_scale
        self.num_inference_steps = num_inference_steps

        output = inference_latent(
            self.pipeline,
            prompt=[prompt],
            negative_prompt=[negative_prompt],
            num_inference_steps=num_inference_steps,
            guidance_scale=self.guidance_scale,
            latents=self.latents.detach().clone(),
        )

        # Decode the final latent to an image and cache it for later reuse.
        image = self.pipeline.decode_latents(output['latent'])
        self.org_image = image
        return image
239
-