Brian9999 committed on
Commit
85b1487
·
1 Parent(s): 8e96394

Revert VAE encoding to original sequential loop

Browse files
Files changed (1) hide show
  1. gbuffer_utils.py +9 -9
gbuffer_utils.py CHANGED
@@ -21,15 +21,15 @@ class WanVideoUnit_GBufferEncoder(PipelineUnit):
21
  if gbuffer_videos is None:
22
  return {}
23
  pipe.load_models_to_device(self.onload_model_names)
24
- # Batch all modalities into a single vae.encode() call
25
- video_tensors = [pipe.preprocess_video(gv) for gv in gbuffer_videos]
26
- # vae.encode expects a list of (C,T,H,W) tensors; preprocess_video returns (1,C,T,H,W)
27
- video_tensors = [vt.squeeze(0) for vt in video_tensors]
28
- all_latents = pipe.vae.encode(
29
- video_tensors, device=pipe.device,
30
- tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
31
- ).to(dtype=pipe.torch_dtype, device=pipe.device) # [N, 16, T, H, W]
32
- gbuffer_latents = all_latents.reshape(1, -1, *all_latents.shape[2:]) # [1, N*16, T, H, W]
33
  if y is not None:
34
  gbuffer_latents = torch.cat([y, gbuffer_latents], dim=1)
35
  return {"y": gbuffer_latents}
 
21
  if gbuffer_videos is None:
22
  return {}
23
  pipe.load_models_to_device(self.onload_model_names)
24
+ all_latents = []
25
+ for gbuffer_video in gbuffer_videos:
26
+ video_tensor = pipe.preprocess_video(gbuffer_video)
27
+ latent = pipe.vae.encode(
28
+ video_tensor, device=pipe.device,
29
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
30
+ ).to(dtype=pipe.torch_dtype, device=pipe.device)
31
+ all_latents.append(latent)
32
+ gbuffer_latents = torch.cat(all_latents, dim=1) # [1, N*16, T, H, W]
33
  if y is not None:
34
  gbuffer_latents = torch.cat([y, gbuffer_latents], dim=1)
35
  return {"y": gbuffer_latents}