ZeroGPU
Browse files
demo.py
CHANGED
|
@@ -498,7 +498,7 @@ def generate_animation(
|
|
| 498 |
|
| 499 |
print("Animation generated")
|
| 500 |
|
| 501 |
-
return synthetic_video.detach() # B x C x T x H x W
|
| 502 |
|
| 503 |
|
| 504 |
@spaces.GPU
|
|
@@ -510,7 +510,8 @@ def decode_animation(latent_animation):
|
|
| 510 |
|
| 511 |
# Convert to torch tensor if needed
|
| 512 |
if not isinstance(latent_animation, torch.Tensor):
|
| 513 |
-
latent_animation = torch.from_numpy(latent_animation)
|
|
|
|
| 514 |
|
| 515 |
# Ensure shape is B x C x T x H x W
|
| 516 |
if len(latent_animation.shape) == 4: # [T, C, H, W]
|
|
|
|
| 498 |
|
| 499 |
print("Animation generated")
|
| 500 |
|
| 501 |
+
return synthetic_video.detach().cpu() # B x C x T x H x W
|
| 502 |
|
| 503 |
|
| 504 |
@spaces.GPU
|
|
|
|
| 510 |
|
| 511 |
# Convert to torch tensor if needed
|
| 512 |
if not isinstance(latent_animation, torch.Tensor):
|
| 513 |
+
latent_animation = torch.from_numpy(latent_animation)
|
| 514 |
+
latent_animation = latent_animation.to(device, dtype=dtype)
|
| 515 |
|
| 516 |
# Ensure shape is B x C x T x H x W
|
| 517 |
if len(latent_animation.shape) == 4: # [T, C, H, W]
|