ZeroGPU
demo.py CHANGED
@@ -306,13 +306,12 @@ def generate_latent_image(mask, class_selection, sampling_steps=50):
 
 
 @spaces.GPU
-def decode_images(latents, vae):
+def decode_images(latents):
     """Decode latent representations to pixel space using a VAE.
 
     Args:
         latents: A numpy array of shape [B, C, H, W] for single image
                  or [B, C, T, H, W] for sequences/animations
-        vae: The VAE model for decoding
 
     Returns:
         numpy array of decoded images in [B, H, W, 3] format for single image
@@ -321,6 +320,9 @@ def decode_images(latents, vae):
     if latents is None:
         return None
 
+    vae = vae.to(device, dtype=dtype)
+    vae.eval()
+
     # Convert to torch tensor if needed
     if not isinstance(latents, torch.Tensor):
         latents = torch.from_numpy(latents).to(device, dtype=dtype)
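With `vae` dropped from the signature, `decode_images` now reads the model from module state, the usual pattern for `@spaces.GPU` functions on ZeroGPU Spaces: the model is loaded once at import time and moved onto the GPU inside the decorated call. Note that because the added line rebinds the name `vae`, the function body also needs a `global vae` declaration (unless one already sits in context lines the hunk omits); otherwise Python raises `UnboundLocalError`. A minimal sketch of the assumed setup follows; the `AutoencoderKL` checkpoint is a placeholder, and only `vae`, `device`, `dtype`, and the decorator come from the diff:

import spaces
import torch
from diffusers import AutoencoderKL

# Module-level state shared with the GPU worker; loaded once at import.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if device.type == "cuda" else torch.float32
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # placeholder checkpoint

@spaces.GPU
def decode_images(latents):
    global vae  # needed because the body rebinds the name below
    if latents is None:
        return None
    vae = vae.to(device, dtype=dtype)
    vae.eval()
    ...  # tensor conversion and VAE decode as in the diff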
@@ -365,7 +367,6 @@ def decode_images(latents, vae):
 
 def decode_latent_to_pixel(latent_image):
     """Decode a single latent image to pixel space"""
-    global vae
     if latent_image is None:
         return None
 
@@ -373,7 +374,7 @@ def decode_latent_to_pixel(latent_image):
     if len(latent_image.shape) == 3:
         latent_image = latent_image[None, ...]
 
-    decoded_image = decode_images(latent_image, vae)
+    decoded_image = decode_images(latent_image)
     decoded_image = cv2.resize(
         decoded_image, (400, 400), interpolation=cv2.INTER_NEAREST
     )
@@ -493,7 +494,6 @@ def generate_animation(
 
 def decode_animation(latent_animation):
     """Decode a latent animation to pixel space"""
-    global vae
    if latent_animation is None:
         return None
 
@@ -506,9 +506,7 @@ def decode_animation(latent_animation):
     latent_animation = latent_animation[None, ...]  # Add batch dimension
 
     # Decode using VAE
-    decoded = decode_images(
-        latent_animation, vae
-    )  # Returns B x C x T x H x W numpy array
+    decoded = decode_images(latent_animation)  # Returns B x C x T x H x W numpy array
 
     # Remove batch dimension and transpose to T x H x W x C
     decoded = np.transpose(decoded[0], (1, 2, 3, 0))  # [T, H, W, C]
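After this change every caller shares the same simplified path: `decode_latent_to_pixel` and `decode_animation` no longer declare `global vae` or thread the model through, they just call `decode_images(...)`. A hypothetical call, assuming a 4-channel 64x64 latent (illustrative; the Space's real latent shape may differ):

import numpy as np

latent = np.random.randn(1, 4, 64, 64).astype(np.float32)  # [B, C, H, W], illustrative shape
pixels = decode_images(latent)  # [B, H, W, 3] per the docstring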