autodecoder
Browse files
app.py
CHANGED
|
@@ -135,21 +135,21 @@ clip = FrozenCLIPEmbedder()
|
|
| 135 |
clip.eval()
|
| 136 |
clip.to(device)
|
| 137 |
|
| 138 |
-
#
|
| 139 |
-
|
| 140 |
-
|
| 141 |
|
| 142 |
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
|
| 148 |
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
|
| 154 |
|
| 155 |
@spaces.GPU #[uncomment to use ZeroGPU]
|
|
|
|
| 135 |
clip.eval()
|
| 136 |
clip.to(device)
|
| 137 |
|
| 138 |
+
# Load autoencoder.
|
| 139 |
+
autoencoder = libs.autoencoder.get_model(**config_1.autoencoder)
|
| 140 |
+
autoencoder.to(device)
|
| 141 |
|
| 142 |
|
| 143 |
+
@torch.cuda.amp.autocast()
|
| 144 |
+
def encode(_batch: torch.Tensor) -> torch.Tensor:
|
| 145 |
+
"""Encode a batch of images using the autoencoder."""
|
| 146 |
+
return autoencoder.encode(_batch)
|
| 147 |
|
| 148 |
|
| 149 |
+
@torch.cuda.amp.autocast()
|
| 150 |
+
def decode(_batch: torch.Tensor) -> torch.Tensor:
|
| 151 |
+
"""Decode a batch of latent vectors using the autoencoder."""
|
| 152 |
+
return autoencoder.decode(_batch)
|
| 153 |
|
| 154 |
|
| 155 |
@spaces.GPU #[uncomment to use ZeroGPU]
|