Fabrice-TIERCELIN commited on
Commit
7be63fd
·
verified ·
1 Parent(s): 8adc98a

Add worker_start_end()

Browse files
Files changed (1) hide show
  1. app.py +269 -0
app.py CHANGED
@@ -585,6 +585,275 @@ def worker(input_image, end_image, image_position, prompts, n_prompt, seed, reso
585
  stream.output_queue.push(('end', None))
586
  return
587
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
588
  # 20250506 pftq: Modified worker to accept video input and clean frame count
589
  @torch.no_grad()
590
  def worker_video(input_video, prompts, n_prompt, seed, batch, resolution, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, enable_preview, use_teacache, no_resize, mp4_crf, num_clean_frames, vae_batch):
 
585
  stream.output_queue.push(('end', None))
586
  return
587
 
588
@torch.no_grad()
def worker_start_end(input_image, end_image, image_position, prompts, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, enable_preview, use_teacache, mp4_crf, fps_number):
    """Generate a video that starts on `input_image` and (optionally) ends on `end_image`.

    The function runs the FramePack-style backward sampling loop: sections are
    generated from the end of the video toward the start, each section conditioned
    on the start-frame latent and on previously generated latents. Progress,
    preview frames and finished MP4 paths are pushed to the module-level `stream`
    output queue; results are written under `outputs_folder`.

    Parameters:
        input_image:  HxWxC uint8 numpy array, the mandatory start frame.
        end_image:    optional HxWxC uint8 numpy array; when given, its latent and
                      CLIP-Vision embedding steer the first (chronologically last)
                      section toward this frame.
        image_position: accepted for signature compatibility with the other
                      workers; not used by this function.
        prompts:      list of prompt strings; at most one per latent section, the
                      last entries being consumed first (sections are generated
                      back-to-front).
        n_prompt:     negative prompt (ignored when cfg == 1).
        seed:         RNG seed for the CPU generator driving sampling.
        total_second_length, latent_window_size, steps, cfg, gs, rs:
                      sampling hyper-parameters forwarded to `sample_hunyuan`.
        gpu_memory_preservation: GB of VRAM to keep free when moving the
                      transformer in low-VRAM mode.
        enable_preview: when True, a denoised preview image is pushed every step
                      and every section is saved as an intermediate MP4.
        use_teacache: enable the transformer's teacache acceleration.
        mp4_crf:      CRF quality setting for the saved MP4.
        fps_number:   output frame rate; also determines how many latent
                      sections cover `total_second_length` seconds.

    Returns:
        None. All results are communicated through `stream.output_queue`
        ('progress', 'file' and final 'end' messages).
    """

    def encode_prompt(prompt, n_prompt):
        # Encode one positive/negative prompt pair into transformer-ready tensors.
        llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)

        if cfg == 1:
            # Real CFG disabled: the negative branch is never weighted, so zeros suffice.
            llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
        else:
            llama_vec_n, clip_l_pooler_n = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)

        llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
        llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)

        llama_vec = llama_vec.to(transformer.dtype)
        llama_vec_n = llama_vec_n.to(transformer.dtype)
        clip_l_pooler = clip_l_pooler.to(transformer.dtype)
        clip_l_pooler_n = clip_l_pooler_n.to(transformer.dtype)
        return [llama_vec, clip_l_pooler, llama_vec_n, clip_l_pooler_n, llama_attention_mask, llama_attention_mask_n]

    # Each latent section decodes to latent_window_size * 4 frames.
    total_latent_sections = (total_second_length * fps_number) / (latent_window_size * 4)
    total_latent_sections = int(max(round(total_latent_sections), 1))

    job_id = generate_timestamp()

    stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Starting ...'))))

    try:
        # Clean GPU
        if not high_vram:
            unload_complete_models(
                text_encoder, text_encoder_2, image_encoder, vae, transformer
            )

        # Text encoding

        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Text encoding ...'))))

        if not high_vram:
            fake_diffusers_current_device(text_encoder, gpu)  # since we only encode one text - that is one model move and one encode, offload is same time consumption since it is also one load and one encode.
            load_model_as_complete(text_encoder_2, target_device=gpu)

        # Pre-encode at most one prompt per section; the list is drained from the
        # back inside the sampling loop (sections are generated in reverse order).
        prompt_parameters = [encode_prompt(prompt_part, n_prompt) for prompt_part in prompts[:total_latent_sections]]

        # Clean GPU
        if not high_vram:
            unload_complete_models(
                text_encoder, text_encoder_2
            )

        # Processing input image (start frame)
        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Processing start frame ...'))))

        H, W, C = input_image.shape
        height, width = find_nearest_bucket(H, W, resolution=640)
        input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)

        Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f'{job_id}_start.png'))

        # Map uint8 [0, 255] to [-1, 1] and reshape to (B, C, T, H, W).
        input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1
        input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]

        # Processing end image (if provided)
        has_end_image = end_image is not None
        if has_end_image:
            stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Processing end frame ...'))))

            # The end frame is cropped to the bucket chosen for the start frame so
            # both latents share one resolution.
            end_image_np = resize_and_center_crop(end_image, target_width=width, target_height=height)

            Image.fromarray(end_image_np).save(os.path.join(outputs_folder, f'{job_id}_end.png'))

            end_image_pt = torch.from_numpy(end_image_np).float() / 127.5 - 1
            end_image_pt = end_image_pt.permute(2, 0, 1)[None, :, None]

        # VAE encoding
        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'VAE encoding ...'))))

        if not high_vram:
            load_model_as_complete(vae, target_device=gpu)

        start_latent = vae_encode(input_image_pt, vae)

        if has_end_image:
            end_latent = vae_encode(end_image_pt, vae)

        # CLIP Vision
        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...'))))

        if not high_vram:
            load_model_as_complete(image_encoder, target_device=gpu)

        image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
        image_encoder_last_hidden_state = image_encoder_output.last_hidden_state

        if has_end_image:
            end_image_encoder_output = hf_clip_vision_encode(end_image_np, feature_extractor, image_encoder)
            end_image_encoder_last_hidden_state = end_image_encoder_output.last_hidden_state
            # Combine both image embeddings (simple average) so the conditioning
            # reflects start and end frames equally.
            image_encoder_last_hidden_state = (image_encoder_last_hidden_state + end_image_encoder_last_hidden_state) / 2

        # Clean GPU
        if not high_vram:
            unload_complete_models(
                image_encoder
            )

        # Dtype
        image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype)

        # Sampling
        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Start sampling ...'))))

        rnd = torch.Generator("cpu").manual_seed(seed)
        num_frames = latent_window_size * 4 - 3

        # History holds the 1 + 2 + 16 clean-context latents consumed at 1x/2x/4x scales.
        history_latents = torch.zeros(size=(1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=torch.float32, device=cpu)
        start_latent = start_latent.to(history_latents)
        if has_end_image:
            end_latent = end_latent.to(history_latents)

        history_pixels = None
        total_generated_latent_frames = 0

        if total_latent_sections > 4:
            # In theory the latent_paddings should follow the above sequence, but it seems that duplicating some
            # items looks better than expanding it when total_latent_sections > 4
            # One can try to remove below trick and just
            # use `latent_paddings = list(reversed(range(total_latent_sections)))` to compare
            latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
        else:
            # Convert an iterator to a list
            latent_paddings = list(range(total_latent_sections - 1, -1, -1))

        if enable_preview:
            def callback(d):
                # Per-step callback: decode a cheap preview, honor user cancellation,
                # and report progress.
                preview = d['denoised']
                preview = vae_decode_fake(preview)

                preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
                preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c')

                if stream.input_queue.top() == 'end':
                    stream.output_queue.push(('end', None))
                    # Deliberately raised to abort sampling; swallowed by the
                    # enclosing bare `except` of the worker.
                    raise KeyboardInterrupt('User ends the task.')

                current_step = d['i'] + 1
                percentage = int(100.0 * current_step / steps)
                hint = f'Sampling {current_step}/{steps}'
                # Bug fix: the label previously hard-coded "(FPS-30)" even though the
                # duration is computed from the user-selected fps_number.
                desc = f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, Video length: {max(0, (total_generated_latent_frames * 4 - 3) / fps_number) :.2f} seconds (FPS-{fps_number}), Resolution: {height}px * {width}px. The video is being extended now ...'
                stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, hint))))
                return
        else:
            def callback(d):
                return

        for latent_padding in latent_paddings:
            is_last_section = latent_padding == 0
            is_first_section = latent_padding == latent_paddings[0]
            latent_padding_size = latent_padding * latent_window_size

            if stream.input_queue.top() == 'end':
                stream.output_queue.push(('end', None))
                return

            print(f'latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}, is_first_section = {is_first_section}')

            # Consume prompts from the back of the list: the first generated
            # section is the chronological end of the video.
            if prompt_parameters:
                [llama_vec, clip_l_pooler, llama_vec_n, clip_l_pooler_n, llama_attention_mask, llama_attention_mask_n] = prompt_parameters.pop()

            indices = torch.arange(1 + latent_padding_size + latent_window_size + 1 + 2 + 16).unsqueeze(0)
            clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
            clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)

            clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2)

            # Use end image latent for the first section if provided
            if has_end_image and is_first_section:
                clean_latents_post = end_latent

            clean_latents = torch.cat([start_latent, clean_latents_post], dim=2)

            if not high_vram:
                unload_complete_models()
                move_model_to_device_with_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)

            if use_teacache:
                transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
            else:
                transformer.initialize_teacache(enable_teacache=False)

            generated_latents = sample_hunyuan(
                transformer=transformer,
                sampler='unipc',
                width=width,
                height=height,
                frames=num_frames,
                real_guidance_scale=cfg,
                distilled_guidance_scale=gs,
                guidance_rescale=rs,
                # shift=3.0,
                num_inference_steps=steps,
                generator=rnd,
                prompt_embeds=llama_vec,
                prompt_embeds_mask=llama_attention_mask,
                prompt_poolers=clip_l_pooler,
                negative_prompt_embeds=llama_vec_n,
                negative_prompt_embeds_mask=llama_attention_mask_n,
                negative_prompt_poolers=clip_l_pooler_n,
                device=gpu,
                dtype=torch.bfloat16,
                image_embeddings=image_encoder_last_hidden_state,
                latent_indices=latent_indices,
                clean_latents=clean_latents,
                clean_latent_indices=clean_latent_indices,
                clean_latents_2x=clean_latents_2x,
                clean_latent_2x_indices=clean_latent_2x_indices,
                clean_latents_4x=clean_latents_4x,
                clean_latent_4x_indices=clean_latent_4x_indices,
                callback=callback,
            )

            if is_last_section:
                # Prepend the exact start-frame latent so the finished video opens
                # on the user-supplied image.
                generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)

            total_generated_latent_frames += int(generated_latents.shape[2])
            # New latents go in FRONT of the history (generation runs back-to-front).
            history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)

            if not high_vram:
                offload_model_from_device_for_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=8)
                load_model_as_complete(vae, target_device=gpu)

            if history_pixels is None:
                history_pixels = vae_decode(history_latents[:, :, :total_generated_latent_frames, :, :], vae).cpu()
            else:
                # Decode only the newest section and blend it into the existing
                # pixels over the overlapping frames to avoid seams.
                section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
                overlapped_frames = latent_window_size * 4 - 3

                current_pixels = vae_decode(history_latents[:, :, :min(total_generated_latent_frames, section_latent_frames)], vae).cpu()
                history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)

            if not high_vram:
                unload_complete_models(vae)

            if enable_preview or is_last_section:
                output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4')

                save_bcthw_as_mp4(history_pixels, output_filename, fps=fps_number, crf=mp4_crf)

                print(f'Decoded. Pixel shape {history_pixels.shape}')

                stream.output_queue.push(('file', output_filename))

            if is_last_section:
                break
    except:  # noqa: E722 -- intentionally bare: must also catch the KeyboardInterrupt raised by callback() on user cancel
        traceback.print_exc()

        if not high_vram:
            unload_complete_models(
                text_encoder, text_encoder_2, image_encoder, vae, transformer
            )

    stream.output_queue.push(('end', None))
    return
857
  # 20250506 pftq: Modified worker to accept video input and clean frame count
858
  @torch.no_grad()
859
  def worker_video(input_video, prompts, n_prompt, seed, batch, resolution, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, enable_preview, use_teacache, no_resize, mp4_crf, num_clean_frames, vae_batch):