lower ZeroGPU usage
Browse files
demo.py
CHANGED
|
@@ -242,8 +242,8 @@ def preprocess_mask(mask):
|
|
| 242 |
return np.array(mask_pil)
|
| 243 |
|
| 244 |
|
| 245 |
-
@spaces.GPU
|
| 246 |
-
@torch.no_grad()
|
| 247 |
def generate_latent_image(mask, class_selection, sampling_steps=50):
|
| 248 |
"""Generate a latent image based on mask, class selection, and sampling steps"""
|
| 249 |
|
|
@@ -306,8 +306,8 @@ def generate_latent_image(mask, class_selection, sampling_steps=50):
|
|
| 306 |
return latent_image # B x C x H x W
|
| 307 |
|
| 308 |
|
| 309 |
-
@spaces.GPU
|
| 310 |
-
@torch.no_grad()
|
| 311 |
def decode_images(latents):
|
| 312 |
"""Decode latent representations to pixel space using a VAE.
|
| 313 |
|
|
@@ -385,8 +385,8 @@ def decode_latent_to_pixel(latent_image):
|
|
| 385 |
return decoded_image
|
| 386 |
|
| 387 |
|
| 388 |
-
@spaces.GPU
|
| 389 |
-
@torch.no_grad()
|
| 390 |
def check_privacy(latent_image_numpy, class_selection):
|
| 391 |
"""Check if the latent image is too similar to database images"""
|
| 392 |
latent_image = torch.from_numpy(latent_image_numpy).to(device, dtype=dtype)
|
|
@@ -412,8 +412,8 @@ def check_privacy(latent_image_numpy, class_selection):
|
|
| 412 |
)
|
| 413 |
|
| 414 |
|
| 415 |
-
@spaces.GPU
|
| 416 |
-
@torch.no_grad()
|
| 417 |
def generate_animation(
|
| 418 |
latent_image, ejection_fraction, sampling_steps=50, cfg_scale=1.0
|
| 419 |
):
|
|
@@ -501,8 +501,8 @@ def generate_animation(
|
|
| 501 |
return synthetic_video.detach().cpu() # B x C x T x H x W
|
| 502 |
|
| 503 |
|
| 504 |
-
@spaces.GPU
|
| 505 |
-
@torch.no_grad()
|
| 506 |
def decode_animation(latent_animation):
|
| 507 |
"""Decode a latent animation to pixel space"""
|
| 508 |
if latent_animation is None:
|
|
@@ -577,8 +577,8 @@ def convert_latent_to_display(latent_image):
|
|
| 577 |
return display_image
|
| 578 |
|
| 579 |
|
| 580 |
-
@spaces.GPU
|
| 581 |
-
@torch.no_grad()
|
| 582 |
def latent_animation_to_grayscale(latent_animation):
|
| 583 |
"""Convert multi-channel latent animation to grayscale for display"""
|
| 584 |
if latent_animation is None:
|
|
|
|
| 242 |
return np.array(mask_pil)
|
| 243 |
|
| 244 |
|
| 245 |
+
@spaces.GPU(duration=3)
|
| 246 |
+
@torch.no_grad()
|
| 247 |
def generate_latent_image(mask, class_selection, sampling_steps=50):
|
| 248 |
"""Generate a latent image based on mask, class selection, and sampling steps"""
|
| 249 |
|
|
|
|
| 306 |
return latent_image # B x C x H x W
|
| 307 |
|
| 308 |
|
| 309 |
+
@spaces.GPU(duration=3)
|
| 310 |
+
@torch.no_grad()
|
| 311 |
def decode_images(latents):
|
| 312 |
"""Decode latent representations to pixel space using a VAE.
|
| 313 |
|
|
|
|
| 385 |
return decoded_image
|
| 386 |
|
| 387 |
|
| 388 |
+
@spaces.GPU(duration=3)
|
| 389 |
+
@torch.no_grad()
|
| 390 |
def check_privacy(latent_image_numpy, class_selection):
|
| 391 |
"""Check if the latent image is too similar to database images"""
|
| 392 |
latent_image = torch.from_numpy(latent_image_numpy).to(device, dtype=dtype)
|
|
|
|
| 412 |
)
|
| 413 |
|
| 414 |
|
| 415 |
+
@spaces.GPU(duration=3)
|
| 416 |
+
@torch.no_grad()
|
| 417 |
def generate_animation(
|
| 418 |
latent_image, ejection_fraction, sampling_steps=50, cfg_scale=1.0
|
| 419 |
):
|
|
|
|
| 501 |
return synthetic_video.detach().cpu() # B x C x T x H x W
|
| 502 |
|
| 503 |
|
| 504 |
+
@spaces.GPU(duration=3)
|
| 505 |
+
@torch.no_grad()
|
| 506 |
def decode_animation(latent_animation):
|
| 507 |
"""Decode a latent animation to pixel space"""
|
| 508 |
if latent_animation is None:
|
|
|
|
| 577 |
return display_image
|
| 578 |
|
| 579 |
|
| 580 |
+
@spaces.GPU(duration=3)
|
| 581 |
+
@torch.no_grad()
|
| 582 |
def latent_animation_to_grayscale(latent_animation):
|
| 583 |
"""Convert multi-channel latent animation to grayscale for display"""
|
| 584 |
if latent_animation is None:
|