Spaces:
Runtime error
Runtime error
Commit
·
b609195
1
Parent(s):
2460e4b
Update app.py
Browse files
app.py
CHANGED
|
@@ -75,9 +75,14 @@ def interpolate(model, save_dir='./lerp/', frames=100, rows=8, cols=8):
|
|
| 75 |
|
| 76 |
|
| 77 |
def predict(model_name, choice, seed):
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
torch.manual_seed(seed)
|
| 82 |
|
| 83 |
if choice == 'interpolation':
|
|
|
|
| 75 |
|
| 76 |
|
| 77 |
def predict(model_name, choice, seed):
|
| 78 |
+
try:
|
| 79 |
+
model = Generator(3)
|
| 80 |
+
weights_path = hf_hub_download(f'huggingnft/{model_name}', 'pytorch_model.bin')
|
| 81 |
+
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
|
| 82 |
+
except:
|
| 83 |
+
model = Generator(4)
|
| 84 |
+
weights_path = hf_hub_download(f'huggingnft/{model_name}', 'pytorch_model.bin')
|
| 85 |
+
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
|
| 86 |
torch.manual_seed(seed)
|
| 87 |
|
| 88 |
if choice == 'interpolation':
|