Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -172,21 +172,21 @@ ASPECT_RATIOS = {
|
|
| 172 |
def get_vae_cache_for_aspect_ratio(aspect_ratio, device, dtype):
|
| 173 |
"""
|
| 174 |
Create VAE cache with appropriate dimensions for the given aspect ratio.
|
|
|
|
| 175 |
"""
|
| 176 |
ar_config = ASPECT_RATIOS[aspect_ratio]
|
| 177 |
latent_h = ar_config["latent_h"]
|
| 178 |
latent_w = ar_config["latent_w"]
|
| 179 |
|
| 180 |
# Create new cache tensors with correct dimensions
|
| 181 |
-
#
|
| 182 |
cache = []
|
| 183 |
|
| 184 |
-
# The
|
| 185 |
-
|
| 186 |
-
cache.append(torch.zeros(1, 512, latent_h //
|
| 187 |
-
cache.append(torch.zeros(1,
|
| 188 |
-
cache.append(torch.zeros(1,
|
| 189 |
-
cache.append(torch.zeros(1, 128, latent_h, latent_w, device=device, dtype=dtype)) # 1x (same as latent)
|
| 190 |
|
| 191 |
return cache
|
| 192 |
|
|
@@ -381,8 +381,14 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, aspect_ratio="16
|
|
| 381 |
|
| 382 |
vae_cache, latents_cache = None, None
|
| 383 |
if not APP_STATE["current_use_taehv"] and not args.trt:
|
| 384 |
-
#
|
| 385 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
|
| 387 |
num_blocks = 7
|
| 388 |
current_start_frame = 0
|
|
|
|
| 172 |
def get_vae_cache_for_aspect_ratio(aspect_ratio, device, dtype):
    """
    Build a fresh VAE cache sized for the requested aspect ratio.

    Every cache entry is a zero-filled 5D tensor laid out as
    (batch, channels, time, height, width); the time axis is 1 at
    initialization.

    Args:
        aspect_ratio: key into the module-level ``ASPECT_RATIOS`` table,
            which supplies the ``"latent_h"`` and ``"latent_w"`` entries.
        device: torch device the cache tensors are allocated on.
        dtype: torch dtype of the cache tensors.

    Returns:
        A list of four zero tensors, one per decoder scale — 8x, 4x, 2x,
        and 1x relative to the latent resolution.
    """
    config = ASPECT_RATIOS[aspect_ratio]
    base_h = config["latent_h"]
    base_w = config["latent_w"]

    # (channel count, spatial divisor) for each decoder stage, ordered
    # from the most downsampled (8x) up to full latent resolution (1x).
    stage_shapes = [(512, 8), (512, 4), (256, 2), (128, 1)]

    return [
        torch.zeros(
            1, channels, 1, base_h // divisor, base_w // divisor,
            device=device, dtype=dtype,
        )
        for channels, divisor in stage_shapes
    ]
|
| 192 |
|
|
|
|
| 381 |
|
| 382 |
vae_cache, latents_cache = None, None
|
| 383 |
if not APP_STATE["current_use_taehv"] and not args.trt:
|
| 384 |
+
# For non-TRT and non-TAEHV, we need to handle aspect ratio properly
|
| 385 |
+
# Use the original ZERO_VAE_CACHE as a template but adjust dimensions
|
| 386 |
+
if aspect_ratio == "16:9":
|
| 387 |
+
# Use default cache for 16:9
|
| 388 |
+
vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]
|
| 389 |
+
else:
|
| 390 |
+
# Create custom cache for 9:16
|
| 391 |
+
vae_cache = get_vae_cache_for_aspect_ratio(aspect_ratio, gpu, torch.float16)
|
| 392 |
|
| 393 |
num_blocks = 7
|
| 394 |
current_start_frame = 0
|