Update app.py
Browse files
app.py
CHANGED
|
@@ -27,15 +27,15 @@ device = "cuda:0"
|
|
| 27 |
generator = torch.Generator(device=device)
|
| 28 |
|
| 29 |
|
| 30 |
-
models_path = snapshot_download(repo_id="Snapchat/w2w
|
| 31 |
-
|
| 32 |
-
mean = torch.load(f"{models_path}/mean.pt").bfloat16().to(device)
|
| 33 |
-
std = torch.load(f"{models_path}/std.pt").bfloat16().to(device)
|
| 34 |
-
v = torch.load(f"{models_path}/V.pt").bfloat16().to(device)
|
| 35 |
-
proj = torch.load(f"{models_path}/proj_1000pc.pt").bfloat16().to(device)
|
| 36 |
-
df = torch.load(f"{models_path}/identity_df.pt")
|
| 37 |
-
weight_dimensions = torch.load(f"{models_path}/weight_dimensions.pt")
|
| 38 |
-
pinverse = torch.load(f"{models_path}/pinverse_1000pc.pt").bfloat16().to(device)
|
| 39 |
|
| 40 |
unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
|
| 41 |
global network
|
|
|
|
| 27 |
generator = torch.Generator(device=device)
|
| 28 |
|
| 29 |
|
| 30 |
+
models_path = snapshot_download(repo_id="Snapchat/w2w")
|
| 31 |
+
|
| 32 |
+
mean = torch.load(f"{models_path}/files/mean.pt").bfloat16().to(device)
|
| 33 |
+
std = torch.load(f"{models_path}/files/std.pt").bfloat16().to(device)
|
| 34 |
+
v = torch.load(f"{models_path}/files/V.pt").bfloat16().to(device)
|
| 35 |
+
proj = torch.load(f"{models_path}/files/proj_1000pc.pt").bfloat16().to(device)
|
| 36 |
+
df = torch.load(f"{models_path}/files/identity_df.pt")
|
| 37 |
+
weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
|
| 38 |
+
pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt").bfloat16().to(device)
|
| 39 |
|
| 40 |
unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
|
| 41 |
global network
|