Update app.py
Browse files
app.py
CHANGED
|
@@ -125,7 +125,7 @@ class main():
|
|
| 125 |
print("")
|
| 126 |
|
| 127 |
|
| 128 |
-
self.
|
| 129 |
|
| 130 |
young = get_direction(df, "Young", pinverse, 1000, device)
|
| 131 |
young = debias(young, "Male", df, pinverse, device)
|
|
@@ -170,11 +170,7 @@ class main():
|
|
| 170 |
self.thick = thick
|
| 171 |
|
| 172 |
|
| 173 |
-
|
| 174 |
-
@spaces.GPU(duration=1000)
|
| 175 |
-
def sample_model(self):
|
| 176 |
-
self.unet, _, _, _, _ = load_models(self.device)
|
| 177 |
-
self.network = sample_weights(self.unet, self.proj, self.mean, self.std, self.v[:, :1000], self.device, factor = 1.00)
|
| 178 |
|
| 179 |
|
| 180 |
@torch.no_grad()
|
|
@@ -184,8 +180,19 @@ class main():
|
|
| 184 |
self.unet.to(device)
|
| 185 |
self.text_encoder.to(device)
|
| 186 |
self.vae.to(device)
|
| 187 |
-
self.
|
| 188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
|
| 190 |
|
| 191 |
|
|
@@ -213,18 +220,9 @@ class main():
|
|
| 213 |
for i,t in enumerate(tqdm.tqdm(self.noise_scheduler.timesteps)):
|
| 214 |
latent_model_input = torch.cat([latents] * 2)
|
| 215 |
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep=t)
|
| 216 |
-
with
|
| 217 |
-
print(latent_model_input.device)
|
| 218 |
-
print(self.unet.device)
|
| 219 |
-
print(self.text_encoder.device)
|
| 220 |
-
print(self.vae.device)
|
| 221 |
-
print(self.network.proj.device)
|
| 222 |
-
print(self.network.mean.device)
|
| 223 |
-
print(self.network.std.device)
|
| 224 |
-
print(self.network.v.device)
|
| 225 |
-
print(text_embeddings.device)
|
| 226 |
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
|
| 227 |
-
|
| 228 |
#guidance
|
| 229 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 230 |
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
|
@@ -315,16 +313,22 @@ class main():
|
|
| 315 |
|
| 316 |
return image
|
| 317 |
|
| 318 |
-
@
|
|
|
|
| 319 |
def sample_then_run(self):
|
| 320 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
prompt = "sks person"
|
| 322 |
negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
|
| 323 |
seed = 5
|
| 324 |
cfg = 3.0
|
| 325 |
steps = 25
|
| 326 |
-
image = self.inference( prompt, negative_prompt, cfg, steps, seed)
|
| 327 |
-
torch.save(self.
|
| 328 |
return image, "model.pt"
|
| 329 |
|
| 330 |
|
|
|
|
| 125 |
print("")
|
| 126 |
|
| 127 |
|
| 128 |
+
self.weights = None
|
| 129 |
|
| 130 |
young = get_direction(df, "Young", pinverse, 1000, device)
|
| 131 |
young = debias(young, "Male", df, pinverse, device)
|
|
|
|
| 170 |
self.thick = thick
|
| 171 |
|
| 172 |
|
| 173 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
|
| 176 |
@torch.no_grad()
|
|
|
|
| 180 |
self.unet.to(device)
|
| 181 |
self.text_encoder.to(device)
|
| 182 |
self.vae.to(device)
|
| 183 |
+
self.mean.to(device)
|
| 184 |
+
self.std.to(device)
|
| 185 |
+
self.v.to(device)
|
| 186 |
+
self.proj.to(device)
|
| 187 |
+
self.weights.to(device)
|
| 188 |
+
|
| 189 |
+
network = LoRAw2w( self.weights, self.mean, self.std, self.v,
|
| 190 |
+
self.unet,
|
| 191 |
+
rank=1,
|
| 192 |
+
multiplier=1.0,
|
| 193 |
+
alpha=27.0,
|
| 194 |
+
train_method="xattn-strict"
|
| 195 |
+
).to(device, torch.bfloat16)
|
| 196 |
|
| 197 |
|
| 198 |
|
|
|
|
| 220 |
for i,t in enumerate(tqdm.tqdm(self.noise_scheduler.timesteps)):
|
| 221 |
latent_model_input = torch.cat([latents] * 2)
|
| 222 |
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep=t)
|
| 223 |
+
with network:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
|
| 225 |
+
|
| 226 |
#guidance
|
| 227 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 228 |
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
|
|
|
| 313 |
|
| 314 |
return image
|
| 315 |
|
| 316 |
+
@torch.no_grad()
|
| 317 |
+
@spaces.GPU(duration=1000)
|
| 318 |
def sample_then_run(self):
|
| 319 |
+
self.unet = UNet2DConditionModel.from_pretrained(
|
| 320 |
+
pretrained_model_name_or_path, subfolder="unet", revision=revision
|
| 321 |
+
)
|
| 322 |
+
self.unet.to(self.device, dtype=torch.bfloat16)
|
| 323 |
+
self.weights = sample_weights(self.unet, self.proj, self.mean, self.std, self.v[:, :1000], self.device, factor = 1.00)
|
| 324 |
+
|
| 325 |
prompt = "sks person"
|
| 326 |
negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
|
| 327 |
seed = 5
|
| 328 |
cfg = 3.0
|
| 329 |
steps = 25
|
| 330 |
+
image = self.inference( weights, prompt, negative_prompt, cfg, steps, seed)
|
| 331 |
+
torch.save(self.weights.cpu().detach(), "model.pt" )
|
| 332 |
return image, "model.pt"
|
| 333 |
|
| 334 |
|