Update app.py
Browse files
app.py
CHANGED
|
@@ -20,6 +20,7 @@ global text_encoder
|
|
| 20 |
global tokenizer
|
| 21 |
global noise_scheduler
|
| 22 |
device = "cuda:0"
|
|
|
|
| 23 |
|
| 24 |
models_path = snapshot_download(repo_id="Snapchat/w2w")
|
| 25 |
|
|
@@ -31,7 +32,6 @@ df = torch.load(f"{models_path}/identity_df.pt")
|
|
| 31 |
weight_dimensions = torch.load(f"{models_path}/weight_dimensions.pt")
|
| 32 |
|
| 33 |
unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
|
| 34 |
-
network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
|
| 35 |
global network
|
| 36 |
|
| 37 |
def sample_model():
|
|
|
|
| 20 |
global tokenizer
|
| 21 |
global noise_scheduler
|
| 22 |
device = "cuda:0"
|
| 23 |
+
generator = torch.Generator(device=device)
|
| 24 |
|
| 25 |
models_path = snapshot_download(repo_id="Snapchat/w2w")
|
| 26 |
|
|
|
|
| 32 |
weight_dimensions = torch.load(f"{models_path}/weight_dimensions.pt")
|
| 33 |
|
| 34 |
unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
|
|
|
|
| 35 |
global network
|
| 36 |
|
| 37 |
def sample_model():
|