Fabrice-TIERCELIN committed on
Commit
f102c54
·
verified ·
1 Parent(s): 31f97bf

Only at the end

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -995,7 +995,6 @@ def worker_video(input_video, end_frame, end_stillness, prompts, n_prompt, seed,
995
  image_encoder=image_encoder, feature_extractor=feature_extractor, device=gpu
996
  )[:2]
997
  end_latent = end_latent.to(dtype=torch.float32, device=cpu)
998
- end_latent = end_latent.expand(-1, -1, 1 + end_stillness, -1, -1)
999
  else:
1000
  end_latent = end_clip_embedding = None
1001
 
@@ -1030,7 +1029,13 @@ def worker_video(input_video, end_frame, end_stillness, prompts, n_prompt, seed,
1030
  def callback(d):
1031
  return
1032
 
1033
- def compute_latent(history_latents, latent_window_size, latent_padding_size, num_clean_frames, start_latent, end_latent, end_stillness):
 
 
 
 
 
 
1034
  # 20250506 pftq: Use user-specified number of context frames, matching original allocation for num_clean_frames=2
1035
  available_frames = history_latents.shape[2] # Number of latent frames
1036
  max_pixel_frames = min(latent_window_size * 4 - 3, available_frames * 4) # Cap at available pixel frames
@@ -1045,9 +1050,9 @@ def worker_video(input_video, end_frame, end_stillness, prompts, n_prompt, seed,
1045
  total_context_frames = min(total_context_frames, available_frames) # 20250507 pftq: Edge case for <=1 sec videos
1046
 
1047
  post_frames = 100 # Single frame for end_latent, otherwise padding causes still image
1048
- indices = torch.arange(0, 1 + num_4x_frames + num_2x_frames + effective_clean_frames + adjusted_latent_frames + ((latent_padding_size + 1 + end_stillness) if end_latent is not None else 0)).unsqueeze(0) # 20250507 pftq: latent_window_size to adjusted_latent_frames for edge case for <=1 sec videos
1049
  clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices, blank_indices, clean_latent_indices_post = indices.split(
1050
- [1, num_4x_frames, num_2x_frames, effective_clean_frames, adjusted_latent_frames, latent_padding_size if end_latent is not None else 0, (1 + end_stillness) if end_latent is not None else 0], dim=1 # 20250507 pftq: latent_window_size to adjusted_latent_frames for edge case for <=1 sec videos
1051
  )
1052
  clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices, clean_latent_indices_post], dim=1)
1053
 
@@ -1083,7 +1088,7 @@ def worker_video(input_video, end_frame, end_stillness, prompts, n_prompt, seed,
1083
  clean_latents_1x = splits[split_idx]
1084
 
1085
  if end_latent is not None:
1086
- clean_latents = torch.cat([start_latent, clean_latents_1x, end_latent], dim=2)
1087
  else:
1088
  clean_latents = torch.cat([start_latent, clean_latents_1x], dim=2)
1089
 
@@ -1138,7 +1143,7 @@ def worker_video(input_video, end_frame, end_stillness, prompts, n_prompt, seed,
1138
  else:
1139
  transformer.initialize_teacache(enable_teacache=False)
1140
 
1141
- [max_frames, clean_latents, clean_latents_2x, clean_latents_4x, latent_indices, clean_latents, clean_latent_indices, clean_latent_2x_indices, clean_latent_4x_indices] = compute_latent(history_latents, latent_window_size, latent_padding_size, num_clean_frames, start_latent, end_latent, end_stillness)
1142
 
1143
  generated_latents = sample_hunyuan(
1144
  transformer=transformer,
 
995
  image_encoder=image_encoder, feature_extractor=feature_extractor, device=gpu
996
  )[:2]
997
  end_latent = end_latent.to(dtype=torch.float32, device=cpu)
 
998
  else:
999
  end_latent = end_clip_embedding = None
1000
 
 
1029
  def callback(d):
1030
  return
1031
 
1032
+ def compute_latent(history_latents, latent_window_size, latent_padding_size, num_clean_frames, start_latent, end_latent, end_stillness, is_end_of_video):
1033
+ if is_end_of_video:
1034
+ local_end_stillness = end_stillness
1035
+ local_end_latent = end_latent.expand(-1, -1, 1 + local_end_stillness, -1, -1)
1036
+ else:
1037
+ local_end_stillness = 0
1038
+ local_end_latent = end_latent
1039
  # 20250506 pftq: Use user-specified number of context frames, matching original allocation for num_clean_frames=2
1040
  available_frames = history_latents.shape[2] # Number of latent frames
1041
  max_pixel_frames = min(latent_window_size * 4 - 3, available_frames * 4) # Cap at available pixel frames
 
1050
  total_context_frames = min(total_context_frames, available_frames) # 20250507 pftq: Edge case for <=1 sec videos
1051
 
1052
  post_frames = 100 # Single frame for end_latent, otherwise padding causes still image
1053
+ indices = torch.arange(0, 1 + num_4x_frames + num_2x_frames + effective_clean_frames + adjusted_latent_frames + ((latent_padding_size + 1 + local_end_stillness) if end_latent is not None else 0)).unsqueeze(0) # 20250507 pftq: latent_window_size to adjusted_latent_frames for edge case for <=1 sec videos
1054
  clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices, blank_indices, clean_latent_indices_post = indices.split(
1055
+ [1, num_4x_frames, num_2x_frames, effective_clean_frames, adjusted_latent_frames, latent_padding_size if end_latent is not None else 0, (1 + local_end_stillness) if end_latent is not None else 0], dim=1 # 20250507 pftq: latent_window_size to adjusted_latent_frames for edge case for <=1 sec videos
1056
  )
1057
  clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices, clean_latent_indices_post], dim=1)
1058
 
 
1088
  clean_latents_1x = splits[split_idx]
1089
 
1090
  if end_latent is not None:
1091
+ clean_latents = torch.cat([start_latent, clean_latents_1x, local_end_latent], dim=2)
1092
  else:
1093
  clean_latents = torch.cat([start_latent, clean_latents_1x], dim=2)
1094
 
 
1143
  else:
1144
  transformer.initialize_teacache(enable_teacache=False)
1145
 
1146
+ [max_frames, clean_latents, clean_latents_2x, clean_latents_4x, latent_indices, clean_latents, clean_latent_indices, clean_latent_2x_indices, clean_latent_4x_indices] = compute_latent(history_latents, latent_window_size, latent_padding_size, num_clean_frames, start_latent, end_latent, end_stillness, is_end_of_video)
1147
 
1148
  generated_latents = sample_hunyuan(
1149
  transformer=transformer,