Fabrice-TIERCELIN committed on
Commit
0684df1
·
verified ·
1 Parent(s): 3049af2

Original code

Browse files
Files changed (1) hide show
  1. app.py +1 -279
app.py CHANGED
@@ -806,284 +806,6 @@ def worker_video(input_video, prompts, n_prompt, seed, batch, resolution, total_
806
  stream.output_queue.push(('end', None))
807
  return
808
 
809
- # 20250506 pftq: Modified worker to accept video input and clean frame count
810
- @spaces.GPU()
811
- @torch.no_grad()
812
- def worker_video_experimental(input_video, prompts, n_prompt, seed, batch, resolution, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, enable_preview, use_teacache, no_resize, mp4_crf, num_clean_frames, vae_batch):
813
- def encode_prompt(prompt, n_prompt):
814
- llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
815
-
816
- if cfg == 1:
817
- llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
818
- else:
819
- llama_vec_n, clip_l_pooler_n = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
820
-
821
- llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
822
- llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
823
-
824
- llama_vec = llama_vec.to(transformer.dtype)
825
- llama_vec_n = llama_vec_n.to(transformer.dtype)
826
- clip_l_pooler = clip_l_pooler.to(transformer.dtype)
827
- clip_l_pooler_n = clip_l_pooler_n.to(transformer.dtype)
828
- return [llama_vec, clip_l_pooler, llama_vec_n, clip_l_pooler_n, llama_attention_mask, llama_attention_mask_n]
829
-
830
- stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Starting ...'))))
831
-
832
- try:
833
- # Clean GPU
834
- if not high_vram:
835
- unload_complete_models(
836
- text_encoder, text_encoder_2, image_encoder, vae, transformer
837
- )
838
-
839
- # Text encoding
840
- stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Text encoding ...'))))
841
-
842
- if not high_vram:
843
- fake_diffusers_current_device(text_encoder, gpu) # since we only encode one text - that is one model move and one encode, offload is same time consumption since it is also one load and one encode.
844
- load_model_as_complete(text_encoder_2, target_device=gpu)
845
-
846
- prompt_parameters = []
847
-
848
- for prompt_part in prompts:
849
- prompt_parameters.append(encode_prompt(prompt_part, n_prompt))
850
-
851
- # 20250506 pftq: Processing input video instead of image
852
- stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Video processing ...'))))
853
-
854
- # 20250506 pftq: Encode video
855
- start_latent, input_image_np, video_latents, fps, height, width = video_encode(input_video, resolution, no_resize, vae, vae_batch_size=vae_batch, device=gpu)[:6]
856
- start_latent = start_latent.to(dtype=torch.float32).cpu()
857
- video_latents = video_latents.cpu()
858
-
859
- # CLIP Vision
860
- stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...'))))
861
-
862
- if not high_vram:
863
- load_model_as_complete(image_encoder, target_device=gpu)
864
-
865
- image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
866
- image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
867
-
868
- # Dtype
869
- image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype)
870
-
871
- total_latent_sections = (total_second_length * fps) / (latent_window_size * 4)
872
- total_latent_sections = int(max(round(total_latent_sections), 1))
873
-
874
- if enable_preview:
875
- def callback(d):
876
- preview = d['denoised']
877
- preview = vae_decode_fake(preview)
878
-
879
- preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
880
- preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c')
881
-
882
- if stream.input_queue.top() == 'end':
883
- stream.output_queue.push(('end', None))
884
- raise KeyboardInterrupt('User ends the task.')
885
-
886
- current_step = d['i'] + 1
887
- percentage = int(100.0 * current_step / steps)
888
- hint = f'Sampling {current_step}/{steps}'
889
- desc = f'Total frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, Video length: {max(0, (total_generated_latent_frames * 4 - 3) / fps) :.2f} seconds (FPS-{fps}), Resolution: {height}px * {width}px, Seed: {seed}, Video {idx+1} of {batch}. The video is generating part {section_index+1} of {total_latent_sections}...'
890
- stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, hint))))
891
- return
892
- else:
893
- def callback(d):
894
- return
895
-
896
- def compute_latent(history_latents, latent_window_size, num_clean_frames, start_latent):
897
- # 20250506 pftq: Use user-specified number of context frames, matching original allocation for num_clean_frames=2
898
- available_frames = history_latents.shape[2] # Number of latent frames
899
- max_pixel_frames = min(latent_window_size * 4 - 3, available_frames * 4) # Cap at available pixel frames
900
- adjusted_latent_frames = max(1, (max_pixel_frames + 3) // 4) # Convert back to latent frames
901
- # Adjust num_clean_frames to match original behavior: num_clean_frames=2 means 1 frame for clean_latents_1x
902
- effective_clean_frames = max(0, num_clean_frames - 1)
903
- effective_clean_frames = min(effective_clean_frames, available_frames - 2) if available_frames > 2 else 0 # 20250507 pftq: changed 1 to 2 for edge case for <=1 sec videos
904
- num_2x_frames = min(2, max(1, available_frames - effective_clean_frames - 1)) if available_frames > effective_clean_frames + 1 else 0 # 20250507 pftq: subtracted 1 for edge case for <=1 sec videos
905
- num_4x_frames = min(16, max(1, available_frames - effective_clean_frames - num_2x_frames)) if available_frames > effective_clean_frames + num_2x_frames else 0 # 20250507 pftq: Edge case for <=1 sec
906
-
907
- total_context_frames = num_4x_frames + num_2x_frames + effective_clean_frames
908
- total_context_frames = min(total_context_frames, available_frames) # 20250507 pftq: Edge case for <=1 sec videos
909
-
910
- indices = torch.arange(0, sum([1, num_4x_frames, num_2x_frames, effective_clean_frames, adjusted_latent_frames])).unsqueeze(0) # 20250507 pftq: latent_window_size to adjusted_latent_frames for edge case for <=1 sec videos
911
- clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = indices.split(
912
- [1, num_4x_frames, num_2x_frames, effective_clean_frames, adjusted_latent_frames], dim=1 # 20250507 pftq: latent_window_size to adjusted_latent_frames for edge case for <=1 sec videos
913
- )
914
- clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
915
-
916
- # 20250506 pftq: Split history_latents dynamically based on available frames
917
- fallback_frame_count = 2 # 20250507 pftq: Changed 0 to 2 Edge case for <=1 sec videos
918
- context_frames = clean_latents_4x = clean_latents_2x = clean_latents_1x = history_latents[:, :, :fallback_frame_count, :, :]
919
-
920
- if total_context_frames > 0:
921
- context_frames = history_latents[:, :, -total_context_frames:, :, :]
922
- split_sizes = [num_4x_frames, num_2x_frames, effective_clean_frames]
923
- split_sizes = [s for s in split_sizes if s > 0] # Remove zero sizes
924
- if split_sizes:
925
- splits = context_frames.split(split_sizes, dim=2)
926
- split_idx = 0
927
-
928
- if num_4x_frames > 0:
929
- clean_latents_4x = splits[split_idx]
930
- split_idx = 1
931
- if clean_latents_4x.shape[2] < 2: # 20250507 pftq: edge case for <=1 sec videos
932
- print("Edge case for <=1 sec videos 4x")
933
- clean_latents_4x = clean_latents_4x.expand(-1, -1, 2, -1, -1)
934
-
935
- if num_2x_frames > 0 and split_idx < len(splits):
936
- clean_latents_2x = splits[split_idx]
937
- if clean_latents_2x.shape[2] < 2: # 20250507 pftq: edge case for <=1 sec videos
938
- print("Edge case for <=1 sec videos 2x")
939
- clean_latents_2x = clean_latents_2x.expand(-1, -1, 2, -1, -1)
940
- split_idx += 1
941
- elif clean_latents_2x.shape[2] < 2: # 20250507 pftq: edge case for <=1 sec videos
942
- clean_latents_2x = clean_latents_4x
943
-
944
- if effective_clean_frames > 0 and split_idx < len(splits):
945
- clean_latents_1x = splits[split_idx]
946
-
947
- indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
948
- clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
949
- clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
950
-
951
- clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]):, :, :].split([16, 2, 1], dim=2)
952
- clean_latents = torch.cat([start_latent, clean_latents_1x], dim=2)
953
-
954
-
955
- # 20250507 pftq: Fix for <=1 sec videos.
956
- max_frames = min(latent_window_size * 4 - 3, history_latents.shape[2] * 4)
957
- return [max_frames, clean_latents, clean_latents_2x, clean_latents_4x, latent_indices, clean_latents, clean_latent_indices, clean_latent_2x_indices, clean_latent_4x_indices]
958
-
959
- for idx in range(batch):
960
- if batch > 1:
961
- print(f"Beginning video {idx+1} of {batch} with seed {seed} ")
962
-
963
- #job_id = generate_timestamp()
964
- job_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")+f"_framepackf1-videoinput_{width}-{total_second_length}sec_seed-{seed}_steps-{steps}_distilled-{gs}_cfg-{cfg}" # 20250506 pftq: easier to read timestamp and filename
965
-
966
- # Sampling
967
- stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Start sampling ...'))))
968
-
969
- rnd = torch.Generator("cpu").manual_seed(seed)
970
-
971
- # 20250506 pftq: Initialize history_latents with video latents
972
- history_latents = video_latents
973
- total_generated_latent_frames = history_latents.shape[2]
974
- history_latents = torch.cat([history_latents, start_latent], dim=2)
975
- total_generated_latent_frames = 1
976
- # 20250506 pftq: Initialize history_pixels to fix UnboundLocalError
977
- history_pixels = None
978
- previous_video = None
979
-
980
- for section_index in range(total_latent_sections):
981
- if stream.input_queue.top() == 'end':
982
- stream.output_queue.push(('end', None))
983
- return
984
-
985
- print(f'section_index = {section_index}, total_latent_sections = {total_latent_sections}')
986
-
987
- if len(prompt_parameters) > 0:
988
- [llama_vec, clip_l_pooler, llama_vec_n, clip_l_pooler_n, llama_attention_mask, llama_attention_mask_n] = prompt_parameters.pop(0)
989
-
990
- if not high_vram:
991
- unload_complete_models()
992
- move_model_to_device_with_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)
993
-
994
- if use_teacache:
995
- transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
996
- else:
997
- transformer.initialize_teacache(enable_teacache=False)
998
-
999
- [max_frames, clean_latents, clean_latents_2x, clean_latents_4x, latent_indices, clean_latents, clean_latent_indices, clean_latent_2x_indices, clean_latent_4x_indices] = compute_latent(history_latents, latent_window_size, num_clean_frames, start_latent)
1000
-
1001
- generated_latents = sample_hunyuan(
1002
- transformer=transformer,
1003
- sampler='unipc',
1004
- width=width,
1005
- height=height,
1006
- frames=max_frames,
1007
- real_guidance_scale=cfg,
1008
- distilled_guidance_scale=gs,
1009
- guidance_rescale=rs,
1010
- num_inference_steps=steps,
1011
- generator=rnd,
1012
- prompt_embeds=llama_vec,
1013
- prompt_embeds_mask=llama_attention_mask,
1014
- prompt_poolers=clip_l_pooler,
1015
- negative_prompt_embeds=llama_vec_n,
1016
- negative_prompt_embeds_mask=llama_attention_mask_n,
1017
- negative_prompt_poolers=clip_l_pooler_n,
1018
- device=gpu,
1019
- dtype=torch.bfloat16,
1020
- image_embeddings=image_encoder_last_hidden_state,
1021
- latent_indices=latent_indices,
1022
- clean_latents=clean_latents,
1023
- clean_latent_indices=clean_latent_indices,
1024
- clean_latents_2x=clean_latents_2x,
1025
- clean_latent_2x_indices=clean_latent_2x_indices,
1026
- clean_latents_4x=clean_latents_4x,
1027
- clean_latent_4x_indices=clean_latent_4x_indices,
1028
- callback=callback,
1029
- )
1030
-
1031
- total_generated_latent_frames += int(generated_latents.shape[2])
1032
- history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
1033
-
1034
- if not high_vram:
1035
- offload_model_from_device_for_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=8)
1036
- load_model_as_complete(vae, target_device=gpu)
1037
-
1038
- if history_pixels is None:
1039
- real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
1040
- history_pixels = vae_decode(real_history_latents, vae).cpu()
1041
- else:
1042
- section_latent_frames = latent_window_size * 2
1043
- overlapped_frames = min(latent_window_size * 4 - 3, history_pixels.shape[2])
1044
-
1045
- real_history_latents = history_latents[:, :, -min(total_generated_latent_frames, section_latent_frames):, :, :]
1046
- history_pixels = soft_append_bcthw(history_pixels, vae_decode(real_history_latents, vae).cpu(), overlapped_frames)
1047
-
1048
- if not high_vram:
1049
- unload_complete_models()
1050
-
1051
- if enable_preview or section_index == total_latent_sections - 1:
1052
- output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4')
1053
-
1054
- # 20250506 pftq: Use input video FPS for output
1055
- save_bcthw_as_mp4(history_pixels, output_filename, fps=fps, crf=mp4_crf)
1056
- print(f"Latest video saved: {output_filename}")
1057
- # 20250508 pftq: Save prompt to mp4 metadata comments
1058
- set_mp4_comments_imageio_ffmpeg(output_filename, f"Prompt: {prompts} | Negative Prompt: {n_prompt}");
1059
- print(f"Prompt saved to mp4 metadata comments: {output_filename}")
1060
-
1061
- # 20250506 pftq: Clean up previous partial files
1062
- if previous_video is not None and os.path.exists(previous_video):
1063
- try:
1064
- os.remove(previous_video)
1065
- print(f"Previous partial video deleted: {previous_video}")
1066
- except Exception as e:
1067
- print(f"Error deleting previous partial video {previous_video}: {e}")
1068
- previous_video = output_filename
1069
-
1070
- print(f'Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}')
1071
-
1072
- stream.output_queue.push(('file', output_filename))
1073
-
1074
- seed = (seed + 1) % np.iinfo(np.int32).max
1075
-
1076
- except:
1077
- traceback.print_exc()
1078
-
1079
- if not high_vram:
1080
- unload_complete_models(
1081
- text_encoder, text_encoder_2, image_encoder, vae, transformer
1082
- )
1083
-
1084
- stream.output_queue.push(('end', None))
1085
- return
1086
-
1087
  def get_duration(input_image, image_position, prompt, generation_mode, n_prompt, randomize_seed, seed, resolution, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, enable_preview, use_teacache, mp4_crf):
1088
  global total_second_length_debug_value
1089
 
@@ -1218,7 +940,7 @@ def process_video(input_video, prompt, n_prompt, randomize_seed, seed, batch, re
1218
  stream = AsyncStream()
1219
 
1220
  # 20250506 pftq: Pass num_clean_frames, vae_batch, etc
1221
- async_run(worker_video_experimental, input_video, prompts, n_prompt, seed, batch, resolution, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, enable_preview, use_teacache, no_resize, mp4_crf, num_clean_frames, vae_batch)
1222
 
1223
  output_filename = None
1224
 
 
806
  stream.output_queue.push(('end', None))
807
  return
808
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
809
  def get_duration(input_image, image_position, prompt, generation_mode, n_prompt, randomize_seed, seed, resolution, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, enable_preview, use_teacache, mp4_crf):
810
  global total_second_length_debug_value
811
 
 
940
  stream = AsyncStream()
941
 
942
  # 20250506 pftq: Pass num_clean_frames, vae_batch, etc
943
+ async_run(worker_video, input_video, prompts, n_prompt, seed, batch, resolution, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, enable_preview, use_teacache, no_resize, mp4_crf, num_clean_frames, vae_batch)
944
 
945
  output_filename = None
946