Spaces:
Running
Running
Lorenzo Adacher
commited on
Update app.py
Browse files
app.py
CHANGED
|
@@ -22,7 +22,7 @@ class SpriteGenerator(nn.Module):
|
|
| 22 |
nn.Linear(latent_dim, latent_dim)
|
| 23 |
)
|
| 24 |
|
| 25 |
-
# Generator modificato per corrispondere ai pesi salvati
|
| 26 |
self.generator = nn.Sequential(
|
| 27 |
# Input: latent_dim x 1 x 1
|
| 28 |
nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False), # -> 512 x 4 x 4
|
|
@@ -49,12 +49,16 @@ class SpriteGenerator(nn.Module):
|
|
| 49 |
nn.BatchNorm2d(16),
|
| 50 |
nn.ReLU(True),
|
| 51 |
|
| 52 |
-
# Layer finale modificato per corrispondere ai pesi
|
| 53 |
nn.ConvTranspose2d(16, 16, 4, 2, 1, bias=False), # -> 16 x 256 x 256
|
| 54 |
nn.BatchNorm2d(16),
|
| 55 |
nn.ReLU(True),
|
| 56 |
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
)
|
| 59 |
|
| 60 |
# Frame interpolator
|
|
@@ -100,7 +104,7 @@ class SpriteGenerator(nn.Module):
|
|
| 100 |
sprites = torch.stack(all_frames, dim=1)
|
| 101 |
|
| 102 |
return sprites
|
| 103 |
-
|
| 104 |
# Costanti
|
| 105 |
MODEL_ID = "Lod34/Animator2D-v2"
|
| 106 |
CACHE_DIR = "model_cache"
|
|
|
|
| 22 |
nn.Linear(latent_dim, latent_dim)
|
| 23 |
)
|
| 24 |
|
| 25 |
+
# Generator modificato per corrispondere esattamente ai pesi salvati
|
| 26 |
self.generator = nn.Sequential(
|
| 27 |
# Input: latent_dim x 1 x 1
|
| 28 |
nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False), # -> 512 x 4 x 4
|
|
|
|
| 49 |
nn.BatchNorm2d(16),
|
| 50 |
nn.ReLU(True),
|
| 51 |
|
|
|
|
| 52 |
nn.ConvTranspose2d(16, 16, 4, 2, 1, bias=False), # -> 16 x 256 x 256
|
| 53 |
nn.BatchNorm2d(16),
|
| 54 |
nn.ReLU(True),
|
| 55 |
|
| 56 |
+
# Layer finale modificato per corrispondere esattamente ai pesi
|
| 57 |
+
nn.ConvTranspose2d(16, 3, 4, 2, 1, bias=False), # Output layer per RGB
|
| 58 |
+
nn.BatchNorm2d(3), # Aggiunto BatchNorm
|
| 59 |
+
nn.ReLU(True), # Aggiunto ReLU
|
| 60 |
+
|
| 61 |
+
nn.Conv2d(3, 3, 3, 1, 1) # Layer di output finale
|
| 62 |
)
|
| 63 |
|
| 64 |
# Frame interpolator
|
|
|
|
| 104 |
sprites = torch.stack(all_frames, dim=1)
|
| 105 |
|
| 106 |
return sprites
|
| 107 |
+
|
| 108 |
# Costanti
|
| 109 |
MODEL_ID = "Lod34/Animator2D-v2"
|
| 110 |
CACHE_DIR = "model_cache"
|