XCLiu committed on
Commit
7dab0a6
·
1 Parent(s): b8d57c1

Delete rf_models.py

Browse files
Files changed (1) hide show
  1. rf_models.py +0 -249
rf_models.py DELETED
@@ -1,249 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
- # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
- # limitations under the License.
15
-
16
- import argparse
17
- import logging
18
- import math
19
- import os
20
- import random
21
- from pathlib import Path
22
- from typing import Optional, Union, List, Callable
23
-
24
- import datasets
25
- import numpy as np
26
- import torch
27
- import torch.nn.functional as F
28
- import torch.utils.checkpoint
29
- import transformers
30
- from datasets import load_dataset
31
- from huggingface_hub import HfFolder, Repository, create_repo, whoami
32
- from packaging import version
33
- from torchvision import transforms
34
- from tqdm.auto import tqdm
35
- from transformers import CLIPTextModel, CLIPTokenizer
36
-
37
- import diffusers
38
- from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
39
- from diffusers.optimization import get_scheduler
40
- from diffusers.training_utils import EMAModel
41
- from diffusers.utils import check_min_version, deprecate
42
- from diffusers.utils.import_utils import is_xformers_available
43
-
44
- import time
45
-
46
- from torch.distributions import Normal, Categorical
47
- from torch.distributions.multivariate_normal import MultivariateNormal
48
- from torch.distributions.mixture_same_family import MixtureSameFamily
49
-
50
- from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
51
- import torchvision
52
-
53
- import cv2
54
- import copy
55
-
56
@torch.no_grad()
def inference_latent_euler(
    pipeline,
    prompt: Union[str, List[str]],
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: Optional[int] = 1,
):
    """Sample latents from a rectified-flow model with a fixed-step Euler integrator.

    Unlike the stock StableDiffusionPipeline sampler, the UNet output is treated
    as a velocity field: latents advance by ``latents += dt * v`` with
    ``dt = 1 / num_inference_steps`` while t runs from 0 to 1, and the UNet is
    conditioned on ``(1 - t) * 1000`` to match the pipeline's timestep scale.

    Args:
        pipeline: A loaded ``StableDiffusionPipeline`` (used for its unet,
            prompt-encoding helpers and latent preparation).
        prompt: Text prompt or list of prompts to condition on.
        height / width: Output size in pixels; default to
            ``unet.config.sample_size * vae_scale_factor``.
        num_inference_steps: Number of Euler steps.
        guidance_scale: Classifier-free guidance weight ``w``; values > 1
            enable guidance (see https://arxiv.org/pdf/2205.11487.pdf eq. (2)).
        negative_prompt: Unconditional prompt(s) used for guidance.
        num_images_per_prompt: Images generated per prompt.
        eta: Forwarded to ``prepare_extra_step_kwargs`` for validation only;
            the Euler update itself does not use it.
        generator: Optional RNG(s) for latent initialization.
        latents: Optional pre-drawn initial latents; drawn by the pipeline if
            ``None``.
        output_type / return_dict: Accepted for pipeline API compatibility;
            this function always returns a plain dict of tensors.
        callback: Optional ``callback(step_index, timestep, latents)`` invoked
            every ``callback_steps`` steps.
        callback_steps: Interval, in steps, between callback invocations.

    Returns:
        dict with keys ``'latent'`` (final latents), ``'init_latent'`` (the
        initial noise) and ``'text_embeddings'`` (conditional embeddings only
        when guidance is active).
    """
    # 0. Default height and width to unet
    height = height or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
    width = width or pipeline.unet.config.sample_size * pipeline.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    pipeline.check_inputs(prompt, height, width, callback_steps)

    # 2. Define call parameters
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    device = pipeline._execution_device
    # `guidance_scale` is defined analogously to the guidance weight `w` of
    # equation (2) of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf .
    # `guidance_scale = 1` corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    t_s = time.time()
    text_embeddings = pipeline._encode_prompt(
        prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )
    t_e = time.time()
    print('Text Embedding Time:', t_e - t_s)

    # 5. Prepare latent variables
    num_channels_latents = pipeline.unet.in_channels
    latents = pipeline.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        text_embeddings.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. Unused by the Euler update, but kept so the
    # pipeline validates `generator`/`eta` exactly as before.
    extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising (Euler integration) loop
    dt = 1. / num_inference_steps
    init_latents = latents.detach().clone()

    for i in range(num_inference_steps):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat(
            [latents] * 2) if do_classifier_free_guidance else latents

        vec_t = torch.ones((latent_model_input.shape[0],), device=latents.device) * (i / num_inference_steps * 1.0)

        v_pred = pipeline.unet(
            latent_model_input, (1. - vec_t) * 1000., encoder_hidden_states=text_embeddings).sample

        # perform guidance
        if do_classifier_free_guidance:
            v_pred_uncond, v_pred_text = v_pred.chunk(2)
            v_pred = v_pred_uncond + guidance_scale * \
                (v_pred_text - v_pred_uncond)

        latents = latents + dt * v_pred

        # Fix: `callback` was accepted but never invoked; honor the standard
        # diffusers callback contract (step index, integer timestep, latents).
        if callback is not None and i % callback_steps == 0:
            callback(i, int((1.0 - i / num_inference_steps) * 1000), latents)

    example = {
        'latent': latents.detach(),
        'init_latent': init_latents.detach().clone(),
        'text_embeddings': text_embeddings.chunk(2)[1].detach() if do_classifier_free_guidance else text_embeddings.detach(),
    }

    return example
143
-
144
def setup_seed(seed):
    """Seed every RNG in use (torch CPU, torch CUDA, numpy, stdlib random)
    and force deterministic cuDNN behavior for reproducible sampling."""
    import random as _random

    # Same seeding order as before: torch CPU, all CUDA devices, numpy, stdlib.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, _random.seed):
        seed_fn(seed)

    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.cuda.empty_cache()
153
-
154
-
155
class RF_model():
    """Text-to-image wrapper around a rectified-flow fine-tuned SD v1.5.

    Loads the Stable Diffusion 1.5 tokenizer, text encoder and VAE from the
    Hugging Face hub, swaps in rectified-flow UNet weights from a local
    checkpoint, freezes everything, and runs fp16 inference on CUDA through
    the Euler sampler (``inference_latent_euler``).
    """

    def __init__(self, model_id):
        """Assemble the frozen fp16 StableDiffusionPipeline on CUDA.

        Args:
            model_id: Filesystem path of the UNet ``state_dict`` checkpoint
                (loaded with ``torch.load``).
        """
        pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
        self.pretrained_model_name_or_path = pretrained_model_name_or_path

        # Load tokenizer and sub-models. (The previous version also built a
        # DDPMScheduler here, but it was never used — the pipeline below
        # loads its own scheduler — so that download has been dropped.)
        tokenizer = CLIPTokenizer.from_pretrained(
            self.pretrained_model_name_or_path, subfolder="tokenizer"#, revision=args.revision
        )
        text_encoder = CLIPTextModel.from_pretrained(
            self.pretrained_model_name_or_path, subfolder="text_encoder"#, revision=args.revision
        )
        vae = AutoencoderKL.from_pretrained(
            self.pretrained_model_name_or_path, subfolder="vae"#, revision=args.revision
        )
        unet = UNet2DConditionModel.from_pretrained(
            self.pretrained_model_name_or_path, subfolder="unet"#, revision=args.non_ema_revision
        )

        print('Loading: Stacked U-Net 0.9B')
        # Re-create the UNet from its config only, then load the
        # rectified-flow checkpoint weights on CPU before moving to GPU.
        unet = UNet2DConditionModel.from_config(unet.config)
        unet.load_state_dict(torch.load(model_id, map_location='cpu'))

        unet.eval()
        vae.eval()
        text_encoder.eval()

        # Freeze all components — this wrapper is inference-only.
        vae.requires_grad_(False)
        text_encoder.requires_grad_(False)
        unet.requires_grad_(False)

        # All models are used for inference only, so half precision is safe
        # and roughly halves GPU memory.
        weight_dtype = torch.float16
        self.weight_dtype = weight_dtype
        device = 'cuda'
        self.device = device

        # Move text_encoder, vae and unet to the GPU and cast to weight_dtype.
        text_encoder.to(device, dtype=weight_dtype)
        vae.to(device, dtype=weight_dtype)
        unet.to(device, dtype=weight_dtype)

        # Create the pipeline using the prepared modules.
        pipeline = StableDiffusionPipeline.from_pretrained(
            self.pretrained_model_name_or_path,
            text_encoder=text_encoder,
            vae=vae,
            unet=unet,
            torch_dtype=weight_dtype,
        )
        self.pipeline = pipeline.to(device)

    def set_new_latent_and_generate_new_image(self, seed=None, prompt=None, negative_prompt="", num_inference_steps=50, guidance_scale=4.0, verbose=True):
        """Draw a fresh seeded latent and generate an image for ``prompt``.

        Args:
            seed: Required RNG seed; the initial latent (and hence the image)
                is fully determined by it.
            prompt: Required text prompt.
            negative_prompt: Unconditional prompt for classifier-free guidance.
            num_inference_steps: Euler steps for the sampler.
            guidance_scale: Classifier-free guidance weight.
            verbose: If True, print the prompt lists before sampling.

        Returns:
            The decoded image from ``pipeline.decode_latents``; also cached
            on ``self.org_image``.

        Raises:
            ValueError: If ``seed`` or ``prompt`` is not provided.
        """
        # Fix: was `assert False, ...` — asserts are stripped under
        # `python -O`, so validate required inputs with real exceptions.
        if seed is None:
            raise ValueError("Must have a pre-defined random seed")

        if prompt is None:
            raise ValueError("Must have a user-specified text prompt")

        setup_seed(seed)
        # Standard SD v1.5 latent shape: 1 x 4 channels x 64 x 64 (512px / 8).
        self.latents = torch.randn((1, 4, 64, 64), device=self.device).to(dtype=self.weight_dtype)
        self.prompt = prompt
        self.negative_prompt = negative_prompt
        self.guidance_scale = guidance_scale
        self.num_inference_steps = num_inference_steps

        prompts = [prompt]
        negative_prompts = [negative_prompt]
        if verbose:
            print(prompts)
            print(negative_prompts)

        output = inference_latent_euler(
            self.pipeline,
            prompt=prompts,
            negative_prompt=negative_prompts,
            num_inference_steps=num_inference_steps,
            guidance_scale=self.guidance_scale,
            latents=self.latents.detach().clone(),
        )

        t_s = time.time()
        image = self.pipeline.decode_latents(output['latent'])
        t_e = time.time()
        print('Decoding Time:', t_e - t_s)

        self.org_image = image

        return image
248
-
249
-