Update app.py
Browse files
app.py
CHANGED
|
@@ -57,7 +57,6 @@ def load_models(device):
|
|
| 57 |
unet = UNet2DConditionModel.from_pretrained(
|
| 58 |
pretrained_model_name_or_path, subfolder="unet", revision=revision
|
| 59 |
)
|
| 60 |
-
|
| 61 |
unet.requires_grad_(False)
|
| 62 |
unet.to(device, dtype=weight_dtype)
|
| 63 |
vae.requires_grad_(False)
|
|
@@ -124,7 +123,7 @@ class main():
|
|
| 124 |
self.vae.to(device, dtype=weight_dtype)
|
| 125 |
self.text_encoder.to(device, dtype=weight_dtype)
|
| 126 |
print("")
|
| 127 |
-
|
| 128 |
|
| 129 |
self.network = None
|
| 130 |
|
|
@@ -171,7 +170,8 @@ class main():
|
|
| 171 |
self.thick = thick
|
| 172 |
|
| 173 |
|
| 174 |
-
|
|
|
|
| 175 |
def sample_model(self):
|
| 176 |
self.unet, _, _, _, _ = load_models(self.device)
|
| 177 |
self.network = sample_weights(self.unet, self.proj, self.mean, self.std, self.v[:, :1000], self.device, factor = 1.00)
|
|
@@ -181,6 +181,13 @@ class main():
|
|
| 181 |
@spaces.GPU
|
| 182 |
def inference(self, prompt, negative_prompt, guidance_scale, ddim_steps, seed):
|
| 183 |
device = self.device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
generator = torch.Generator(device=device).manual_seed(seed)
|
| 185 |
latents = torch.randn(
|
| 186 |
(1, self.unet.in_channels, 512 // 8, 512 // 8),
|
|
|
|
| 57 |
unet = UNet2DConditionModel.from_pretrained(
|
| 58 |
pretrained_model_name_or_path, subfolder="unet", revision=revision
|
| 59 |
)
|
|
|
|
| 60 |
unet.requires_grad_(False)
|
| 61 |
unet.to(device, dtype=weight_dtype)
|
| 62 |
vae.requires_grad_(False)
|
|
|
|
| 123 |
self.vae.to(device, dtype=weight_dtype)
|
| 124 |
self.text_encoder.to(device, dtype=weight_dtype)
|
| 125 |
print("")
|
| 126 |
+
|
| 127 |
|
| 128 |
self.network = None
|
| 129 |
|
|
|
|
| 170 |
self.thick = thick
|
| 171 |
|
| 172 |
|
| 173 |
+
@torch.no_grad()
|
| 174 |
+
@spaces.GPU
|
| 175 |
def sample_model(self):
|
| 176 |
self.unet, _, _, _, _ = load_models(self.device)
|
| 177 |
self.network = sample_weights(self.unet, self.proj, self.mean, self.std, self.v[:, :1000], self.device, factor = 1.00)
|
|
|
|
| 181 |
@spaces.GPU
|
| 182 |
def inference(self, prompt, negative_prompt, guidance_scale, ddim_steps, seed):
|
| 183 |
device = self.device
|
| 184 |
+
self.unet.to(device)
|
| 185 |
+
self.text_encoder.to(device)
|
| 186 |
+
self.vae.to(device)
|
| 187 |
+
self.tokenizer.to(device)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
|
| 191 |
generator = torch.Generator(device=device).manual_seed(seed)
|
| 192 |
latents = torch.randn(
|
| 193 |
(1, self.unet.in_channels, 512 // 8, 512 // 8),
|