diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..41f00ec385d56b350b4b23174d68b52581bb989d 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,3 +1,4 @@ +# Standard LFS file types for Hugging Face *.7z filter=lfs diff=lfs merge=lfs -text *.arrow filter=lfs diff=lfs merge=lfs -text *.bin filter=lfs diff=lfs merge=lfs -text @@ -33,3 +34,45 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +# Media files +*.png filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text +*.jpeg filter=lfs diff=lfs merge=lfs -text +*.gif filter=lfs diff=lfs merge=lfs -text +*.mp4 filter=lfs diff=lfs merge=lfs -text +*.avi filter=lfs diff=lfs merge=lfs -text +*.mov filter=lfs diff=lfs merge=lfs -text +*.mkv filter=lfs diff=lfs merge=lfs -text +*.webm filter=lfs diff=lfs merge=lfs -text +*.wav filter=lfs diff=lfs merge=lfs -text +*.mp3 filter=lfs diff=lfs merge=lfs -text +*.flac filter=lfs diff=lfs merge=lfs -text +# Project specific files +examples/7_result.mp4 filter=lfs diff=lfs merge=lfs -text +examples/8_video.mp4 filter=lfs diff=lfs merge=lfs -text +examples/1_result.mp4 filter=lfs diff=lfs merge=lfs -text +examples/2_video.mp4 filter=lfs diff=lfs merge=lfs -text +examples/6_result.mp4 filter=lfs diff=lfs merge=lfs -text +examples/7_video.mp4 filter=lfs diff=lfs merge=lfs -text +examples/8_result.mp4 filter=lfs diff=lfs merge=lfs -text +examples/3_video.mp4 filter=lfs diff=lfs merge=lfs -text +examples/5_video.mp4 filter=lfs diff=lfs merge=lfs -text +examples/4_result.mp4 filter=lfs diff=lfs merge=lfs -text +examples/5_result.mp4 filter=lfs diff=lfs merge=lfs -text +examples/1_video.mp4 filter=lfs diff=lfs merge=lfs -text +examples/3_result.mp4 filter=lfs diff=lfs merge=lfs -text +examples/6_video.mp4 filter=lfs diff=lfs merge=lfs -text +examples/2_result.mp4 filter=lfs diff=lfs merge=lfs -text +examples/4_video.mp4 filter=lfs diff=lfs merge=lfs -text +assets/MovieGenAudioBenchSfx/video_with_audio/0.mp4 filter=lfs diff=lfs merge=lfs -text +assets/MovieGenAudioBenchSfx/video_with_audio/4.mp4 filter=lfs diff=lfs merge=lfs -text +assets/MovieGenAudioBenchSfx/video_with_audio/6.mp4 filter=lfs diff=lfs merge=lfs -text +assets/MovieGenAudioBenchSfx/video_with_audio/7.mp4 filter=lfs diff=lfs merge=lfs -text +assets/MovieGenAudioBenchSfx/video_with_audio/1.mp4 filter=lfs diff=lfs merge=lfs -text +assets/MovieGenAudioBenchSfx/video_with_audio/2.mp4 filter=lfs diff=lfs merge=lfs -text +assets/MovieGenAudioBenchSfx/video_with_audio/3.mp4 filter=lfs diff=lfs merge=lfs -text +assets/MovieGenAudioBenchSfx/video_with_audio/5.mp4 filter=lfs diff=lfs merge=lfs -text +assets/MovieGenAudioBenchSfx/video_with_audio/8.mp4 filter=lfs diff=lfs merge=lfs -text +assets/MovieGenAudioBenchSfx/video_with_audio/9.mp4 filter=lfs diff=lfs merge=lfs -text +assets/data_pipeline.png filter=lfs diff=lfs merge=lfs -text +assets/model_arch.png filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index 5dea084d42bc69ae5f90b6561770329d558def99..d268e5421b4f12324c301d14278c95f28ce40449 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,96 @@ --- -title: Hunyuanvideo Foley -emoji: 😻 +title: HunyuanVideo-Foley +emoji: 🎵 colorFrom: blue -colorTo: gray +colorTo: purple sdk: gradio -sdk_version: 5.44.1 +sdk_version: 4.44.0 app_file: app.py pinned: false license: apache-2.0 +short_description: Generate realistic audio from video and 
text descriptions --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# HunyuanVideo-Foley + +
+<div align="center">
+<h2>🎵 Text-Video-to-Audio Synthesis</h2>
+<p>Generate realistic audio from video and text descriptions using AI</p>
+</div>
+ +## About + +HunyuanVideo-Foley is a multimodal diffusion model that generates high-quality audio effects (Foley audio) synchronized with video content. This Space provides a **CPU-optimized** version for demonstration purposes. + +### ⚠️ CPU Performance Notice + +This Space runs on **free CPU** which means: +- **Slower inference** (3-5 minutes per generation) +- **Limited concurrent users** +- **Reduced sample counts** (max 3 samples) + +For **faster performance**, consider: +- Using the original repository with GPU +- Running locally with CUDA support +- Upgrading to a GPU Space (if available) + +## Features + +- 🎬 **Video-to-Audio**: Generate audio effects from video content +- 📝 **Text Guidance**: Control generation with text descriptions +- 🎯 **Multiple Samples**: Generate up to 3 variations +- 🔧 **Adjustable Settings**: Control CFG scale and inference steps +- 📱 **User-Friendly**: Simple drag-and-drop interface + +## How to Use + +1. **Upload Video**: Drag and drop your video file (MP4, AVI, MOV) +2. **Add Description** (Optional): Describe the audio you want to generate +3. **Adjust Settings**: Modify CFG scale and inference steps if needed +4. **Generate**: Click "Generate Audio" and wait (3-5 minutes on CPU) +5. **Download**: Save your generated audio/video combinations + +## Tips for Best Results + +- 📏 **Video Length**: Keep videos under 30 seconds for faster processing +- 🎯 **Text Prompts**: Use simple, clear descriptions +- ⚡ **Settings**: Lower values process faster on CPU +- 🔄 **Multiple Attempts**: Try different settings if not satisfied + +## Technical Details + +- **Model**: HunyuanVideo-Foley-XXL +- **Architecture**: Multimodal diffusion transformer +- **Audio Quality**: 48kHz professional-grade output +- **Deployment**: CPU-optimized for Hugging Face Spaces + +## Original Project + +This is a **CPU deployment** of the original HunyuanVideo-Foley project: + +- 📄 **Paper**: [HunyuanVideo-Foley: Multimodal Diffusion with Representation Alignment](https://arxiv.org/abs/2508.16930) +- 💻 **GitHub**: [Tencent-Hunyuan/HunyuanVideo-Foley](https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley) +- 🤗 **Models**: [tencent/HunyuanVideo-Foley](https://huggingface.co/tencent/HunyuanVideo-Foley) + +## Citation + +```bibtex +@misc{shan2025hunyuanvideofoleymultimodaldiffusionrepresentation, + title={HunyuanVideo-Foley: Multimodal Diffusion with Representation Alignment for High-Fidelity Foley Audio Generation}, + author={Sizhe Shan and Qiulin Li and Yutao Cui and Miles Yang and Yuehai Wang and Qun Yang and Jin Zhou and Zhao Zhong}, + year={2025}, + eprint={2508.16930}, + archivePrefix={arXiv}, + primaryClass={eess.AS} +} +``` + +## License + +This project is licensed under the Apache 2.0 License. + +--- + +
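+## Programmatic Access
+
+The Space can also be driven from Python instead of the web UI. The sketch below uses `gradio_client`; the Space id and the `api_name` shown here are assumptions, so check this Space's "Use via API" panel for the exact endpoint and argument order:
+
+```python
+from gradio_client import Client, handle_file
+
+# Hypothetical Space id; replace with the actual path of this Space.
+client = Client("your-username/HunyuanVideo-Foley")
+
+result = client.predict(
+    handle_file("example.mp4"),     # video input
+    "footsteps crunching on ice",   # optional text prompt
+    2.0,                            # CFG scale (CPU-friendly default)
+    20,                             # inference steps (reduced for CPU)
+    1,                              # sample count (max 3 on CPU)
+    api_name="/process_inference",  # assumed name, derived from the handler function
+)
+print(result)
+```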
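+## Running Locally
+
+For GPU-speed generation, the same pipeline the app wraps can be called directly. This is a minimal sketch built from the helpers `app.py` imports (`load_model`, `feature_process`, `denoise_process`, `merge_audio_video`); the input and output file names are placeholders:
+
+```python
+import torch
+import torchaudio
+from huggingface_hub import snapshot_download
+
+from hunyuanvideo_foley.utils.model_utils import load_model, denoise_process
+from hunyuanvideo_foley.utils.feature_utils import feature_process
+from hunyuanvideo_foley.utils.media_utils import merge_audio_video
+
+# Fetch the released checkpoints (the same repo the Space downloads from).
+snapshot_download(repo_id="tencent/HunyuanVideo-Foley", local_dir="./pretrained_models")
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+model_dict, cfg = load_model("./pretrained_models/", "configs/hunyuanvideo-foley-xxl.yaml", device)
+
+# Extract visual and text features, then denoise them into audio.
+visual_feats, text_feats, audio_len_in_s = feature_process("input.mp4", "footsteps on ice", model_dict, cfg)
+audio, sample_rate = denoise_process(
+    visual_feats, text_feats, audio_len_in_s, model_dict, cfg,
+    guidance_scale=4.5, num_inference_steps=50, batch_size=1,
+)
+
+# Save the first sample and mux it back onto the source video.
+torchaudio.save("generated_audio.wav", audio[0], sample_rate)
+merge_audio_video("generated_audio.wav", "input.mp4", "output_with_audio.mp4")
+```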
+<div align="center">
+<p>🚀 Powered by Tencent Hunyuan | Optimized for CPU deployment</p>
+</div>
\ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..537b2d06447dd109dd12e5bf7515727d8e40c3e2 --- /dev/null +++ b/app.py @@ -0,0 +1,405 @@ +import os +import tempfile +import gradio as gr +import torch +import torchaudio +from loguru import logger +from typing import Optional, Tuple +import random +import numpy as np + +# Force CPU usage for Hugging Face Spaces +os.environ["CUDA_VISIBLE_DEVICES"] = "" + +from hunyuanvideo_foley.utils.model_utils import load_model +from hunyuanvideo_foley.utils.feature_utils import feature_process +from hunyuanvideo_foley.utils.model_utils import denoise_process +from hunyuanvideo_foley.utils.media_utils import merge_audio_video + +# Global variables for model storage +model_dict = None +cfg = None +device = None + +# Model path for Hugging Face Spaces - try to download automatically +MODEL_PATH = os.environ.get("HIFI_FOLEY_MODEL_PATH", "./pretrained_models/") +CONFIG_PATH = "configs/hunyuanvideo-foley-xxl.yaml" + +def setup_device(force_cpu: bool = True) -> torch.device: + """Setup computing device - force CPU for Hugging Face Spaces""" + if force_cpu: + device = torch.device("cpu") + logger.info("Using CPU device (forced for Hugging Face Spaces)") + else: + if torch.cuda.is_available(): + device = torch.device("cuda:0") + logger.info("Using CUDA device") + elif torch.backends.mps.is_available(): + device = torch.device("mps") + logger.info("Using MPS device") + else: + device = torch.device("cpu") + logger.info("Using CPU device") + + return device + +def download_models(): + """Download models from Hugging Face if not present""" + try: + from huggingface_hub import snapshot_download + logger.info("Downloading models from Hugging Face...") + + # Download the model files + snapshot_download( + repo_id="tencent/HunyuanVideo-Foley", + local_dir="./pretrained_models", + local_dir_use_symlinks=False + ) + + logger.info("Model download completed!") + return True + except Exception as e: + logger.error(f"Failed to download models: {str(e)}") + return False + +def auto_load_models() -> str: + """Automatically load preset models""" + global model_dict, cfg, device + + try: + # First try to download models if they don't exist + if not os.path.exists(MODEL_PATH) or not os.listdir(MODEL_PATH): + logger.info("Models not found locally, attempting to download...") + if not download_models(): + return "❌ Failed to download models from Hugging Face" + + if not os.path.exists(CONFIG_PATH): + return f"❌ Config file not found: {CONFIG_PATH}" + + # Force CPU usage for Hugging Face Spaces + device = setup_device(force_cpu=True) + + # Load model with CPU optimization + logger.info("Loading model on CPU...") + logger.info(f"Model path: {MODEL_PATH}") + logger.info(f"Config path: {CONFIG_PATH}") + + # Set torch to use fewer threads for CPU inference + torch.set_num_threads(2) + + model_dict, cfg = load_model(MODEL_PATH, CONFIG_PATH, device) + + logger.info("✅ Model loaded successfully on CPU!") + return "✅ Model loaded successfully on CPU!" 
+ + except Exception as e: + logger.error(f"Model loading failed: {str(e)}") + return f"❌ Model loading failed: {str(e)}" + +def infer_single_video( + video_file, + text_prompt: str, + guidance_scale: float = 2.0, # Lower for CPU + num_inference_steps: int = 20, # Reduced for CPU + sample_nums: int = 1 +) -> Tuple[list, str]: + """Single video inference optimized for CPU""" + global model_dict, cfg, device + + if model_dict is None or cfg is None: + return [], "❌ Please load the model first!" + + if video_file is None: + return [], "❌ Please upload a video file!" + + # Allow empty text prompt + if text_prompt is None: + text_prompt = "" + text_prompt = text_prompt.strip() + + try: + logger.info(f"Processing video: {video_file}") + logger.info(f"Text prompt: {text_prompt}") + logger.info("Running inference on CPU (this may take a while)...") + + # Feature processing + visual_feats, text_feats, audio_len_in_s = feature_process( + video_file, + text_prompt, + model_dict, + cfg + ) + + # Denoising process with CPU-optimized settings + logger.info(f"Generating {sample_nums} audio sample(s) on CPU...") + audio, sample_rate = denoise_process( + visual_feats, + text_feats, + audio_len_in_s, + model_dict, + cfg, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + batch_size=sample_nums + ) + + # Create temporary files to save results + temp_dir = tempfile.mkdtemp() + video_outputs = [] + + # Process each generated audio sample + for i in range(sample_nums): + # Save audio file + audio_output = os.path.join(temp_dir, f"generated_audio_{i+1}.wav") + torchaudio.save(audio_output, audio[i], sample_rate) + + # Merge video and audio + video_output = os.path.join(temp_dir, f"video_with_audio_{i+1}.mp4") + merge_audio_video(audio_output, video_file, video_output) + video_outputs.append(video_output) + + logger.info(f"Inference completed! Generated {sample_nums} samples.") + return video_outputs, f"✅ Generated {sample_nums} audio sample(s) successfully on CPU!" 
+ + except Exception as e: + logger.error(f"Inference failed: {str(e)}") + return [], f"❌ Inference failed: {str(e)}" + +def update_video_outputs(video_list, status_msg): + """Update video outputs based on the number of generated samples""" + # Initialize all outputs as None + outputs = [None] * 3 # Reduced to 3 for CPU + + # Set values based on generated videos + for i, video_path in enumerate(video_list[:3]): # Max 3 samples for CPU + outputs[i] = video_path + + # Return all outputs plus status message + return tuple(outputs + [status_msg]) + +def create_gradio_interface(): + """Create Gradio interface optimized for CPU deployment""" + + # Custom CSS with Hugging Face Spaces styling + css = """ + .gradio-container { + font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; + background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); + min-height: 100vh; + } + + .main-header { + text-align: center; + padding: 2rem 0; + background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); + border-radius: 20px; + margin-bottom: 2rem; + box-shadow: 0 8px 32px rgba(0,0,0,0.15); + } + + .main-header h1 { + color: white; + font-size: 3rem; + font-weight: 700; + margin-bottom: 0.5rem; + text-shadow: 0 2px 10px rgba(0,0,0,0.3); + } + + .main-header p { + color: rgba(255, 255, 255, 0.95); + font-size: 1.2rem; + font-weight: 300; + } + + .cpu-notice { + background: #fff3cd; + border: 1px solid #ffeaa7; + border-radius: 10px; + padding: 1rem; + margin: 1rem 0; + color: #856404; + } + """ + + with gr.Blocks(css=css, title="HunyuanVideo-Foley (CPU)") as app: + + # Main header + with gr.Column(elem_classes=["main-header"]): + gr.HTML(""" +

+            <h1>🎵 HunyuanVideo-Foley</h1>
+            <p>Text-Video-to-Audio Synthesis (CPU Version)</p>

+ """) + + # CPU Notice + gr.HTML(""" +
+        <div class="cpu-notice">
+            <strong>⚠️ CPU Deployment Notice:</strong> This Space runs on CPU, so inference is slower than on a GPU.
+            Each generation may take 3-5 minutes. For faster inference, consider running locally with a GPU.
+        </div>
+ """) + + # Usage Guide + gr.Markdown(""" + ### 📋 Quick Start Guide + **1.** Upload your video file **2.** Add optional text description **3.** Click Generate Audio (be patient!) + + 💡 **Tips for CPU usage:** + - Use shorter videos (< 30 seconds recommended) + - Simple text prompts work better + - Expect longer processing times + """) + + # Main interface + with gr.Row(): + # Input section + with gr.Column(scale=1): + gr.Markdown("### 📹 Video Input") + + video_input = gr.Video( + label="Upload Video", + info="Supported formats: MP4, AVI, MOV, etc. Shorter videos recommended for CPU.", + height=300 + ) + + text_input = gr.Textbox( + label="🎯 Audio Description (English)", + placeholder="A person walks on frozen ice", + lines=3, + info="Describe the audio you want to generate (optional)" + ) + + with gr.Row(): + guidance_scale = gr.Slider( + minimum=1.0, + maximum=5.0, + value=2.0, + step=0.1, + label="🎚️ CFG Scale (lower for CPU)", + ) + + inference_steps = gr.Slider( + minimum=10, + maximum=50, + value=20, + step=5, + label="⚡ Steps (reduced for CPU)", + ) + + sample_nums = gr.Slider( + minimum=1, + maximum=3, + value=1, + step=1, + label="🎲 Sample Nums (max 3 for CPU)", + ) + + generate_btn = gr.Button( + "🎵 Generate Audio (CPU)", + variant="primary" + ) + + # Results section + with gr.Column(scale=1): + gr.Markdown("### 🎥 Generated Results") + + # Reduced number of outputs for CPU + video_output_1 = gr.Video( + label="Sample 1", + height=250, + visible=True + ) + + with gr.Row(): + video_output_2 = gr.Video( + label="Sample 2", + height=200, + visible=False + ) + video_output_3 = gr.Video( + label="Sample 3", + height=200, + visible=False + ) + + result_text = gr.Textbox( + label="Status", + interactive=False, + lines=3 + ) + + # Event handlers + def process_inference(video_file, text_prompt, guidance_scale, inference_steps, sample_nums): + # Generate videos + video_list, status_msg = infer_single_video( + video_file, text_prompt, guidance_scale, inference_steps, int(sample_nums) + ) + # Update outputs with proper visibility + return update_video_outputs(video_list, status_msg) + + # Add dynamic visibility control + def update_visibility(sample_nums): + sample_nums = int(sample_nums) + return [ + gr.update(visible=True), # Sample 1 always visible + gr.update(visible=sample_nums >= 2), # Sample 2 + gr.update(visible=sample_nums >= 3), # Sample 3 + ] + + # Update visibility when sample_nums changes + sample_nums.change( + fn=update_visibility, + inputs=[sample_nums], + outputs=[video_output_1, video_output_2, video_output_3] + ) + + generate_btn.click( + fn=process_inference, + inputs=[video_input, text_input, guidance_scale, inference_steps, sample_nums], + outputs=[ + video_output_1, + video_output_2, + video_output_3, + result_text + ] + ) + + # Footer + gr.HTML(""" +
+        <div align="center">
+            <p>🚀 Powered by HunyuanVideo-Foley | Running on CPU for Hugging Face Spaces</p>
+            <p>For faster inference, visit the <a href="https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley">original repository</a></p>
+        </div>
+ """) + + return app + +def set_manual_seed(global_seed): + random.seed(global_seed) + np.random.seed(global_seed) + torch.manual_seed(global_seed) + +if __name__ == "__main__": + set_manual_seed(1) + # Setup logging + logger.remove() + logger.add(lambda msg: print(msg, end=''), level="INFO") + + # Auto-load model + logger.info("Starting CPU application and loading model...") + model_load_result = auto_load_models() + logger.info(model_load_result) + + # Create and launch Gradio app + app = create_gradio_interface() + + # Log completion status + if "successfully" in model_load_result: + logger.info("Application ready, model loaded on CPU") + + app.launch( + server_name="0.0.0.0", + server_port=7860, # Standard port for Hugging Face Spaces + share=False, + debug=False, + show_error=True + ) \ No newline at end of file diff --git a/configs/hunyuanvideo-foley-xxl.yaml b/configs/hunyuanvideo-foley-xxl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9b7da02713647fd3201d378d86924bfdf42e93c9 --- /dev/null +++ b/configs/hunyuanvideo-foley-xxl.yaml @@ -0,0 +1,49 @@ +model_config: + model_name: HunyuanVideo-Foley-XXL + model_type: 1d + model_precision: bf16 + model_kwargs: + depth_triple_blocks: 18 + depth_single_blocks: 36 + hidden_size: 1536 + num_heads: 12 + mlp_ratio: 4 + mlp_act_type: "gelu_tanh" + qkv_bias: True + qk_norm: True + qk_norm_type: "rms" + attn_mode: "torch" + embedder_type: "default" + interleaved_audio_visual_rope: True + enable_learnable_empty_visual_feat: True + sync_modulation: False + add_sync_feat_to_audio: True + cross_attention: True + use_attention_mask: False + condition_projection: "linear" + sync_feat_dim: 768 # syncformer 768 dim + condition_dim: 768 # clap 768 text condition dim (clip-text) + clip_dim: 768 # siglip2 visual dim + audio_vae_latent_dim: 128 + audio_frame_rate: 50 + patch_size: 1 + rope_dim_list: null + rope_theta: 10000 + text_length: 77 + clip_length: 64 + sync_length: 192 + use_mmaudio_singleblock: True + depth_triple_ssl_encoder: null + depth_single_ssl_encoder: 8 + use_repa_with_audiossl: True + +diffusion_config: + denoise_type: "flow" + flow_path_type: "linear" + flow_predict_type: "velocity" + flow_reverse: True + flow_solver: "euler" + sample_flow_shift: 1.0 + sample_use_flux_shift: False + flux_base_shift: 0.5 + flux_max_shift: 1.15 diff --git a/hunyuanvideo_foley/__init__.py b/hunyuanvideo_foley/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hunyuanvideo_foley/__pycache__/__init__.cpython-312.pyc b/hunyuanvideo_foley/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61662d26303a16fd02acb558569428bab8efd4b1 Binary files /dev/null and b/hunyuanvideo_foley/__pycache__/__init__.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/__pycache__/__init__.cpython-313.pyc b/hunyuanvideo_foley/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71b5507b7f193125a4b8ed840c414ab5ad36de33 Binary files /dev/null and b/hunyuanvideo_foley/__pycache__/__init__.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/__pycache__/constants.cpython-312.pyc b/hunyuanvideo_foley/__pycache__/constants.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5be810050d3223d2d4660812e4bb557707bad15f Binary files /dev/null and b/hunyuanvideo_foley/__pycache__/constants.cpython-312.pyc differ diff --git 
a/hunyuanvideo_foley/__pycache__/constants.cpython-313.pyc b/hunyuanvideo_foley/__pycache__/constants.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0842fa469b5edad26854f848e825529a845ec910 Binary files /dev/null and b/hunyuanvideo_foley/__pycache__/constants.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/constants.py b/hunyuanvideo_foley/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..81519407b44f33cfcbadd550f41a634b11821649 --- /dev/null +++ b/hunyuanvideo_foley/constants.py @@ -0,0 +1,57 @@ +"""Constants used throughout the HunyuanVideo-Foley project.""" + +from typing import Dict, List + +# Model configuration +DEFAULT_AUDIO_SAMPLE_RATE = 48000 +DEFAULT_VIDEO_FPS = 25 +DEFAULT_AUDIO_CHANNELS = 2 + +# Video processing +MAX_VIDEO_DURATION_SECONDS = 15.0 +MIN_VIDEO_DURATION_SECONDS = 1.0 + +# Audio processing +AUDIO_VAE_LATENT_DIM = 128 +AUDIO_FRAME_RATE = 75 # frames per second in latent space + +# Visual features +FPS_VISUAL: Dict[str, int] = { + "siglip2": 8, + "synchformer": 25 +} + +# Model paths (can be overridden by environment variables) +DEFAULT_MODEL_PATH = "./pretrained_models/" +DEFAULT_CONFIG_PATH = "configs/hunyuanvideo-foley-xxl.yaml" + +# Inference parameters +DEFAULT_GUIDANCE_SCALE = 4.5 +DEFAULT_NUM_INFERENCE_STEPS = 50 +MIN_GUIDANCE_SCALE = 1.0 +MAX_GUIDANCE_SCALE = 10.0 +MIN_INFERENCE_STEPS = 10 +MAX_INFERENCE_STEPS = 100 + +# Text processing +MAX_TEXT_LENGTH = 100 +DEFAULT_NEGATIVE_PROMPT = "noisy, harsh" + +# File extensions +SUPPORTED_VIDEO_EXTENSIONS: List[str] = [".mp4", ".avi", ".mov", ".mkv", ".webm"] +SUPPORTED_AUDIO_EXTENSIONS: List[str] = [".wav", ".mp3", ".flac", ".aac"] + +# Quality settings +AUDIO_QUALITY_SETTINGS: Dict[str, List[str]] = { + "high": ["-b:a", "192k"], + "medium": ["-b:a", "128k"], + "low": ["-b:a", "96k"] +} + +# Error messages +ERROR_MESSAGES: Dict[str, str] = { + "model_not_loaded": "Model is not loaded. Please load the model first.", + "invalid_video_format": "Unsupported video format. Supported formats: {formats}", + "video_too_long": f"Video duration exceeds maximum of {MAX_VIDEO_DURATION_SECONDS} seconds", + "ffmpeg_not_found": "ffmpeg not found. 
Please install ffmpeg: https://ffmpeg.org/download.html" +} \ No newline at end of file diff --git a/hunyuanvideo_foley/models/__init__.py b/hunyuanvideo_foley/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hunyuanvideo_foley/models/__pycache__/__init__.cpython-312.pyc b/hunyuanvideo_foley/models/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7256df1d54bdbca395cf379fc5ffdf81c560e139 Binary files /dev/null and b/hunyuanvideo_foley/models/__pycache__/__init__.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/__pycache__/__init__.cpython-313.pyc b/hunyuanvideo_foley/models/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06c16a43edf1fdd1e0c88011d8cd1c450d81250e Binary files /dev/null and b/hunyuanvideo_foley/models/__pycache__/__init__.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/__pycache__/hifi_foley.cpython-312.pyc b/hunyuanvideo_foley/models/__pycache__/hifi_foley.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1f8e5c911fa9516c071206f63cdb3559f853bcf Binary files /dev/null and b/hunyuanvideo_foley/models/__pycache__/hifi_foley.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/__pycache__/hifi_foley.cpython-313.pyc b/hunyuanvideo_foley/models/__pycache__/hifi_foley.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cef069430dd3393c7c49f0b5f77ad1134a3ecaf2 Binary files /dev/null and b/hunyuanvideo_foley/models/__pycache__/hifi_foley.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/__init__.py b/hunyuanvideo_foley/models/dac_vae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..51205ef6ded9c6735a988b76008e0f6bdce8e215 --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/__init__.py @@ -0,0 +1,16 @@ +__version__ = "1.0.0" + +# preserved here for legacy reasons +__model_version__ = "latest" + +import audiotools + +audiotools.ml.BaseModel.INTERN += ["dac.**"] +audiotools.ml.BaseModel.EXTERN += ["einops"] + + +from . import nn +from . import model +from . import utils +from .model import DAC +from .model import DACFile diff --git a/hunyuanvideo_foley/models/dac_vae/__main__.py b/hunyuanvideo_foley/models/dac_vae/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1fe6531c5bf82f731d8e07ec09c21d79aae4cfa --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/__main__.py @@ -0,0 +1,36 @@ +import sys + +import argbind + +from .utils import download +from .utils.decode import decode +from .utils.encode import encode + +STAGES = ["encode", "decode", "download"] + + +def run(stage: str): + """Run stages. + + Parameters + ---------- + stage : str + Stage to run + """ + if stage not in STAGES: + raise ValueError(f"Unknown command: {stage}. 
Allowed commands are {STAGES}") + stage_fn = globals()[stage] + + if stage == "download": + stage_fn() + return + + stage_fn() + + +if __name__ == "__main__": + group = sys.argv.pop(1) + args = argbind.parse_args(group=group) + + with argbind.scope(args): + run(group) diff --git a/hunyuanvideo_foley/models/dac_vae/__pycache__/__init__.cpython-312.pyc b/hunyuanvideo_foley/models/dac_vae/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14cf6031338150ea3a7f152c17008c66a6413af1 Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/__pycache__/__init__.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/__pycache__/__init__.cpython-313.pyc b/hunyuanvideo_foley/models/dac_vae/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f50d727a0848877df226320cc677261667551ce9 Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/__pycache__/__init__.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/model/__init__.py b/hunyuanvideo_foley/models/dac_vae/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..02a75b7ad6028f5c41b6a8285b0257d4c23bdfcf --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/model/__init__.py @@ -0,0 +1,4 @@ +from .base import CodecMixin +from .base import DACFile +from .dac import DAC +from .discriminator import Discriminator diff --git a/hunyuanvideo_foley/models/dac_vae/model/__pycache__/__init__.cpython-312.pyc b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1973b3c42287fe8b898d235449f0fbca61c7f671 Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/__init__.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/model/__pycache__/__init__.cpython-313.pyc b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83ceb0da01684fe8da99ed24a4ff9a03163b6387 Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/__init__.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/model/__pycache__/base.cpython-312.pyc b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/base.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..997357623388982f2ff70a5e4d0cf3ca4fdaeb2d Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/base.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/model/__pycache__/base.cpython-313.pyc b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/base.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..791f244d54f568235063163cd7958bf90cfddeb8 Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/base.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/model/__pycache__/dac.cpython-312.pyc b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/dac.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9368431bc265497e05dcf6e4de3c4739f6e4df2e Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/dac.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/model/__pycache__/dac.cpython-313.pyc b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/dac.cpython-313.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..d6f33b65ae4c386a21f4449e105e55e043a58b51 Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/dac.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/model/__pycache__/discriminator.cpython-312.pyc b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/discriminator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7041224daee46788150ab5a04367984f3071964e Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/discriminator.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/model/__pycache__/discriminator.cpython-313.pyc b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/discriminator.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32e5430cebc009a08f69aa50663bcf1de237438c Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/discriminator.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/model/base.py b/hunyuanvideo_foley/models/dac_vae/model/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e95a84149a767f256a54b7cc3241c09551f39061 --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/model/base.py @@ -0,0 +1,301 @@ +import math +from dataclasses import dataclass +from pathlib import Path +from typing import Union + +import numpy as np +import torch +import tqdm +from audiotools import AudioSignal +from torch import nn + +SUPPORTED_VERSIONS = ["1.0.0"] + + +@dataclass +class DACFile: + codes: torch.Tensor + + # Metadata + chunk_length: int + original_length: int + input_db: float + channels: int + sample_rate: int + padding: bool + dac_version: str + + def save(self, path): + artifacts = { + "codes": self.codes.numpy().astype(np.uint16), + "metadata": { + "input_db": self.input_db.numpy().astype(np.float32), + "original_length": self.original_length, + "sample_rate": self.sample_rate, + "chunk_length": self.chunk_length, + "channels": self.channels, + "padding": self.padding, + "dac_version": SUPPORTED_VERSIONS[-1], + }, + } + path = Path(path).with_suffix(".dac") + with open(path, "wb") as f: + np.save(f, artifacts) + return path + + @classmethod + def load(cls, path): + artifacts = np.load(path, allow_pickle=True)[()] + codes = torch.from_numpy(artifacts["codes"].astype(int)) + if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS: + raise RuntimeError( + f"Given file {path} can't be loaded with this version of descript-audio-codec." 
+ ) + return cls(codes=codes, **artifacts["metadata"]) + + +class CodecMixin: + @property + def padding(self): + if not hasattr(self, "_padding"): + self._padding = True + return self._padding + + @padding.setter + def padding(self, value): + assert isinstance(value, bool) + + layers = [ + l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d)) + ] + + for layer in layers: + if value: + if hasattr(layer, "original_padding"): + layer.padding = layer.original_padding + else: + layer.original_padding = layer.padding + layer.padding = tuple(0 for _ in range(len(layer.padding))) + + self._padding = value + + def get_delay(self): + # Any number works here, delay is invariant to input length + l_out = self.get_output_length(0) + L = l_out + + layers = [] + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + layers.append(layer) + + for layer in reversed(layers): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.ConvTranspose1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.Conv1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.ceil(L) + + l_in = L + + return (l_in - l_out) // 2 + + def get_output_length(self, input_length): + L = input_length + # Calculate output length + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.Conv1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.ConvTranspose1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.floor(L) + return L + + @torch.no_grad() + def compress( + self, + audio_path_or_signal: Union[str, Path, AudioSignal], + win_duration: float = 1.0, + verbose: bool = False, + normalize_db: float = -16, + n_quantizers: int = None, + ) -> DACFile: + """Processes an audio signal from a file or AudioSignal object into + discrete codes. This function processes the signal in short windows, + using constant GPU memory. 
+ + Parameters + ---------- + audio_path_or_signal : Union[str, Path, AudioSignal] + audio signal to reconstruct + win_duration : float, optional + window duration in seconds, by default 5.0 + verbose : bool, optional + by default False + normalize_db : float, optional + normalize db, by default -16 + + Returns + ------- + DACFile + Object containing compressed codes and metadata + required for decompression + """ + audio_signal = audio_path_or_signal + if isinstance(audio_signal, (str, Path)): + audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal)) + + self.eval() + original_padding = self.padding + original_device = audio_signal.device + + audio_signal = audio_signal.clone() + audio_signal = audio_signal.to_mono() + original_sr = audio_signal.sample_rate + + resample_fn = audio_signal.resample + loudness_fn = audio_signal.loudness + + # If audio is > 10 minutes long, use the ffmpeg versions + if audio_signal.signal_duration >= 10 * 60 * 60: + resample_fn = audio_signal.ffmpeg_resample + loudness_fn = audio_signal.ffmpeg_loudness + + original_length = audio_signal.signal_length + resample_fn(self.sample_rate) + input_db = loudness_fn() + + if normalize_db is not None: + audio_signal.normalize(normalize_db) + audio_signal.ensure_max_of_audio() + + nb, nac, nt = audio_signal.audio_data.shape + audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt) + win_duration = ( + audio_signal.signal_duration if win_duration is None else win_duration + ) + + if audio_signal.signal_duration <= win_duration: + # Unchunked compression (used if signal length < win duration) + self.padding = True + n_samples = nt + hop = nt + else: + # Chunked inference + self.padding = False + # Zero-pad signal on either side by the delay + audio_signal.zero_pad(self.delay, self.delay) + n_samples = int(win_duration * self.sample_rate) + # Round n_samples to nearest hop length multiple + n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length) + hop = self.get_output_length(n_samples) + + codes = [] + range_fn = range if not verbose else tqdm.trange + + for i in range_fn(0, nt, hop): + x = audio_signal[..., i : i + n_samples] + x = x.zero_pad(0, max(0, n_samples - x.shape[-1])) + + audio_data = x.audio_data.to(self.device) + audio_data = self.preprocess(audio_data, self.sample_rate) + _, c, _, _, _ = self.encode(audio_data, n_quantizers) + codes.append(c.to(original_device)) + chunk_length = c.shape[-1] + + codes = torch.cat(codes, dim=-1) + + dac_file = DACFile( + codes=codes, + chunk_length=chunk_length, + original_length=original_length, + input_db=input_db, + channels=nac, + sample_rate=original_sr, + padding=self.padding, + dac_version=SUPPORTED_VERSIONS[-1], + ) + + if n_quantizers is not None: + codes = codes[:, :n_quantizers, :] + + self.padding = original_padding + return dac_file + + @torch.no_grad() + def decompress( + self, + obj: Union[str, Path, DACFile], + verbose: bool = False, + ) -> AudioSignal: + """Reconstruct audio from a given .dac file + + Parameters + ---------- + obj : Union[str, Path, DACFile] + .dac file location or corresponding DACFile object. 
+ verbose : bool, optional + Prints progress if True, by default False + + Returns + ------- + AudioSignal + Object with the reconstructed audio + """ + self.eval() + if isinstance(obj, (str, Path)): + obj = DACFile.load(obj) + + original_padding = self.padding + self.padding = obj.padding + + range_fn = range if not verbose else tqdm.trange + codes = obj.codes + original_device = codes.device + chunk_length = obj.chunk_length + recons = [] + + for i in range_fn(0, codes.shape[-1], chunk_length): + c = codes[..., i : i + chunk_length].to(self.device) + z = self.quantizer.from_codes(c)[0] + r = self.decode(z) + recons.append(r.to(original_device)) + + recons = torch.cat(recons, dim=-1) + recons = AudioSignal(recons, self.sample_rate) + + resample_fn = recons.resample + loudness_fn = recons.loudness + + # If audio is > 10 minutes long, use the ffmpeg versions + if recons.signal_duration >= 10 * 60 * 60: + resample_fn = recons.ffmpeg_resample + loudness_fn = recons.ffmpeg_loudness + + if obj.input_db is not None: + recons.normalize(obj.input_db) + + resample_fn(obj.sample_rate) + + if obj.original_length is not None: + recons = recons[..., : obj.original_length] + loudness_fn() + recons.audio_data = recons.audio_data.reshape( + -1, obj.channels, obj.original_length + ) + else: + loudness_fn() + + self.padding = original_padding + return recons diff --git a/hunyuanvideo_foley/models/dac_vae/model/dac.py b/hunyuanvideo_foley/models/dac_vae/model/dac.py new file mode 100644 index 0000000000000000000000000000000000000000..3df2cbad1502774ce2519c1795b6e85571ac3fc1 --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/model/dac.py @@ -0,0 +1,410 @@ +import math +from typing import List +from typing import Union + +import numpy as np +import torch +from audiotools import AudioSignal +from audiotools.ml import BaseModel +from torch import nn + +from .base import CodecMixin +from ..nn.layers import Snake1d +from ..nn.layers import WNConv1d +from ..nn.layers import WNConvTranspose1d +from ..nn.quantize import ResidualVectorQuantize +from ..nn.vae_utils import DiagonalGaussianDistribution + + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + +class ResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=1), + ) + + def forward(self, x): + y = self.block(x) + pad = (x.shape[-1] - y.shape[-1]) // 2 + if pad > 0: + x = x[..., pad:-pad] + return x + y + + +class EncoderBlock(nn.Module): + def __init__(self, dim: int = 16, stride: int = 1): + super().__init__() + self.block = nn.Sequential( + ResidualUnit(dim // 2, dilation=1), + ResidualUnit(dim // 2, dilation=3), + ResidualUnit(dim // 2, dilation=9), + Snake1d(dim // 2), + WNConv1d( + dim // 2, + dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ), + ) + + def forward(self, x): + return self.block(x) + + +class Encoder(nn.Module): + def __init__( + self, + d_model: int = 64, + strides: list = [2, 4, 8, 8], + d_latent: int = 64, + ): + super().__init__() + # Create first convolution + self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] + + # Create EncoderBlocks that double channels as they downsample by `stride` + for stride in strides: + d_model *= 2 + self.block += [EncoderBlock(d_model, 
stride=stride)] + + # Create last convolution + self.block += [ + Snake1d(d_model), + WNConv1d(d_model, d_latent, kernel_size=3, padding=1), + ] + + # Wrap black into nn.Sequential + self.block = nn.Sequential(*self.block) + self.enc_dim = d_model + + def forward(self, x): + return self.block(x) + + +class DecoderBlock(nn.Module): + def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1): + super().__init__() + self.block = nn.Sequential( + Snake1d(input_dim), + WNConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + output_padding=stride % 2, + ), + ResidualUnit(output_dim, dilation=1), + ResidualUnit(output_dim, dilation=3), + ResidualUnit(output_dim, dilation=9), + ) + + def forward(self, x): + return self.block(x) + + +class Decoder(nn.Module): + def __init__( + self, + input_channel, + channels, + rates, + d_out: int = 1, + ): + super().__init__() + + # Add first conv layer + layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] + + # Add upsampling + MRF blocks + for i, stride in enumerate(rates): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + layers += [DecoderBlock(input_dim, output_dim, stride)] + + # Add final conv layer + layers += [ + Snake1d(output_dim), + WNConv1d(output_dim, d_out, kernel_size=7, padding=3), + nn.Tanh(), + ] + + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + + +class DAC(BaseModel, CodecMixin): + def __init__( + self, + encoder_dim: int = 64, + encoder_rates: List[int] = [2, 4, 8, 8], + latent_dim: int = None, + decoder_dim: int = 1536, + decoder_rates: List[int] = [8, 8, 4, 2], + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: bool = False, + sample_rate: int = 44100, + continuous: bool = False, + ): + super().__init__() + + self.encoder_dim = encoder_dim + self.encoder_rates = encoder_rates + self.decoder_dim = decoder_dim + self.decoder_rates = decoder_rates + self.sample_rate = sample_rate + self.continuous = continuous + + if latent_dim is None: + latent_dim = encoder_dim * (2 ** len(encoder_rates)) + + self.latent_dim = latent_dim + + self.hop_length = np.prod(encoder_rates) + self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim) + + if not continuous: + self.n_codebooks = n_codebooks + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.quantizer = ResidualVectorQuantize( + input_dim=latent_dim, + n_codebooks=n_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + else: + self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1) + self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1) + + self.decoder = Decoder( + latent_dim, + decoder_dim, + decoder_rates, + ) + self.sample_rate = sample_rate + self.apply(init_weights) + + self.delay = self.get_delay() + + @property + def dtype(self): + """Get the dtype of the model parameters.""" + # Return the dtype of the first parameter found + for param in self.parameters(): + return param.dtype + return torch.float32 # fallback + + @property + def device(self): + """Get the device of the model parameters.""" + # Return the device of the first parameter found + for param in self.parameters(): + return param.device + return torch.device('cpu') # fallback + + def preprocess(self, audio_data, sample_rate): + if sample_rate is None: + sample_rate = self.sample_rate + assert 
sample_rate == self.sample_rate + + length = audio_data.shape[-1] + right_pad = math.ceil(length / self.hop_length) * self.hop_length - length + audio_data = nn.functional.pad(audio_data, (0, right_pad)) + + return audio_data + + def encode( + self, + audio_data: torch.Tensor, + n_quantizers: int = None, + ): + """Encode given audio data and return quantized latent codes + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + n_quantizers : int, optional + Number of quantizers to use, by default None + If None, all quantizers are used. + + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + """ + z = self.encoder(audio_data) # [B x D x T] + if not self.continuous: + z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers) + else: + z = self.quant_conv(z) # [B x 2D x T] + z = DiagonalGaussianDistribution(z) + codes, latents, commitment_loss, codebook_loss = None, None, 0, 0 + + return z, codes, latents, commitment_loss, codebook_loss + + def decode(self, z: torch.Tensor): + """Decode given latent codes and return audio data + + Parameters + ---------- + z : Tensor[B x D x T] + Quantized continuous representation of input + length : int, optional + Number of samples in output audio, by default None + + Returns + ------- + dict + A dictionary with the following keys: + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + if not self.continuous: + audio = self.decoder(z) + else: + z = self.post_quant_conv(z) + audio = self.decoder(z) + + return audio + + def forward( + self, + audio_data: torch.Tensor, + sample_rate: int = None, + n_quantizers: int = None, + ): + """Model forward pass + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + sample_rate : int, optional + Sample rate of audio data in Hz, by default None + If None, defaults to `self.sample_rate` + n_quantizers : int, optional + Number of quantizers to use, by default None. + If None, all quantizers are used. + + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + "audio" : Tensor[B x 1 x length] + Decoded audio data. 
+ """ + length = audio_data.shape[-1] + audio_data = self.preprocess(audio_data, sample_rate) + if not self.continuous: + z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers) + + x = self.decode(z) + return { + "audio": x[..., :length], + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + else: + posterior, _, _, _, _ = self.encode(audio_data, n_quantizers) + z = posterior.sample() + x = self.decode(z) + + kl_loss = posterior.kl() + kl_loss = kl_loss.mean() + + return { + "audio": x[..., :length], + "z": z, + "kl_loss": kl_loss, + } + + +if __name__ == "__main__": + import numpy as np + from functools import partial + + model = DAC().to("cpu") + + for n, m in model.named_modules(): + o = m.extra_repr() + p = sum([np.prod(p.size()) for p in m.parameters()]) + fn = lambda o, p: o + f" {p/1e6:<.3f}M params." + setattr(m, "extra_repr", partial(fn, o=o, p=p)) + print(model) + print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) + + length = 88200 * 2 + x = torch.randn(1, 1, length).to(model.device) + x.requires_grad_(True) + x.retain_grad() + + # Make a forward pass + out = model(x)["audio"] + print("Input shape:", x.shape) + print("Output shape:", out.shape) + + # Create gradient variable + grad = torch.zeros_like(out) + grad[:, :, grad.shape[-1] // 2] = 1 + + # Make a backward pass + out.backward(grad) + + # Check non-zero values + gradmap = x.grad.squeeze(0) + gradmap = (gradmap != 0).sum(0) # sum across features + rf = (gradmap != 0).sum() + + print(f"Receptive field: {rf.item()}") + + x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100) + model.decompress(model.compress(x, verbose=True), verbose=True) diff --git a/hunyuanvideo_foley/models/dac_vae/model/discriminator.py b/hunyuanvideo_foley/models/dac_vae/model/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..09c79d1342ca46bef21daca64667577f05e61638 --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/model/discriminator.py @@ -0,0 +1,228 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from audiotools import AudioSignal +from audiotools import ml +from audiotools import STFTParams +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + act = kwargs.pop("act", True) + conv = weight_norm(nn.Conv1d(*args, **kwargs)) + if not act: + return conv + return nn.Sequential(conv, nn.LeakyReLU(0.1)) + + +def WNConv2d(*args, **kwargs): + act = kwargs.pop("act", True) + conv = weight_norm(nn.Conv2d(*args, **kwargs)) + if not act: + return conv + return nn.Sequential(conv, nn.LeakyReLU(0.1)) + + +class MPD(nn.Module): + def __init__(self, period): + super().__init__() + self.period = period + self.convs = nn.ModuleList( + [ + WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), + ] + ) + self.conv_post = WNConv2d( + 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False + ) + + def pad_to_period(self, x): + t = x.shape[-1] + x = F.pad(x, (0, self.period - t % self.period), mode="reflect") + return x + + def forward(self, x): + fmap = [] + + x = self.pad_to_period(x) + x = rearrange(x, "b c (l p) -> b c l p", p=self.period) + + for layer in self.convs: + x = layer(x) + fmap.append(x) + + x 
= self.conv_post(x) + fmap.append(x) + + return fmap + + +class MSD(nn.Module): + def __init__(self, rate: int = 1, sample_rate: int = 44100): + super().__init__() + self.convs = nn.ModuleList( + [ + WNConv1d(1, 16, 15, 1, padding=7), + WNConv1d(16, 64, 41, 4, groups=4, padding=20), + WNConv1d(64, 256, 41, 4, groups=16, padding=20), + WNConv1d(256, 1024, 41, 4, groups=64, padding=20), + WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), + WNConv1d(1024, 1024, 5, 1, padding=2), + ] + ) + self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) + self.sample_rate = sample_rate + self.rate = rate + + def forward(self, x): + x = AudioSignal(x, self.sample_rate) + x.resample(self.sample_rate // self.rate) + x = x.audio_data + + fmap = [] + + for l in self.convs: + x = l(x) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] + + +class MRD(nn.Module): + def __init__( + self, + window_length: int, + hop_factor: float = 0.25, + sample_rate: int = 44100, + bands: list = BANDS, + ): + """Complex multi-band spectrogram discriminator. + Parameters + ---------- + window_length : int + Window length of STFT. + hop_factor : float, optional + Hop factor of the STFT, defaults to ``0.25 * window_length``. + sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run discriminator over. + """ + super().__init__() + + self.window_length = window_length + self.hop_factor = hop_factor + self.sample_rate = sample_rate + self.stft_params = STFTParams( + window_length=window_length, + hop_length=int(window_length * hop_factor), + match_stride=True, + ) + + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + + ch = 32 + convs = lambda: nn.ModuleList( + [ + WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), + ] + ) + self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) + self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) + + def spectrogram(self, x): + x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) + x = torch.view_as_real(x.stft()) + x = rearrange(x, "b 1 f t c -> (b 1) c t f") + # Split into bands + x_bands = [x[..., b[0] : b[1]] for b in self.bands] + return x_bands + + def forward(self, x): + x_bands = self.spectrogram(x) + fmap = [] + + x = [] + for band, stack in zip(x_bands, self.band_convs): + for layer in stack: + band = layer(band) + fmap.append(band) + x.append(band) + + x = torch.cat(x, dim=-1) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class Discriminator(ml.BaseModel): + def __init__( + self, + rates: list = [], + periods: list = [2, 3, 5, 7, 11], + fft_sizes: list = [2048, 1024, 512], + sample_rate: int = 44100, + bands: list = BANDS, + ): + """Discriminator that combines multiple discriminators. + + Parameters + ---------- + rates : list, optional + sampling rates (in Hz) to run MSD at, by default [] + If empty, MSD is not used. 
+ periods : list, optional + periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] + fft_sizes : list, optional + Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] + sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run MRD at, by default `BANDS` + """ + super().__init__() + discs = [] + discs += [MPD(p) for p in periods] + discs += [MSD(r, sample_rate=sample_rate) for r in rates] + discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes] + self.discriminators = nn.ModuleList(discs) + + def preprocess(self, y): + # Remove DC offset + y = y - y.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + return y + + def forward(self, x): + x = self.preprocess(x) + fmaps = [d(x) for d in self.discriminators] + return fmaps + + +if __name__ == "__main__": + disc = Discriminator() + x = torch.zeros(1, 1, 44100) + results = disc(x) + for i, result in enumerate(results): + print(f"disc{i}") + for i, r in enumerate(result): + print(r.shape, r.mean(), r.min(), r.max()) + print() diff --git a/hunyuanvideo_foley/models/dac_vae/nn/__init__.py b/hunyuanvideo_foley/models/dac_vae/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6718c8b1a3d36c31655b030f4c515a144cde4db7 --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/nn/__init__.py @@ -0,0 +1,3 @@ +from . import layers +from . import loss +from . import quantize diff --git a/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/__init__.cpython-312.pyc b/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d0ce8c3db3caad0dc1f1ca8ed34ea7fa2766785 Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/__init__.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/__init__.cpython-313.pyc b/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d403b09b59ce7996ac2af0b52421113b68446cf Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/__init__.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/layers.cpython-312.pyc b/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/layers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbc168386e2ae56ebb97431958c9e4eb621ea0ba Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/layers.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/layers.cpython-313.pyc b/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/layers.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cccc1a06bde5006b48106bcbd6d83a790efd311 Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/layers.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/loss.cpython-312.pyc b/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/loss.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec2f4e05de4f45475a33633b7a0e4aa2a4884d1a Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/loss.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/loss.cpython-313.pyc 
b/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/loss.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1955f15c93492cb8c1d052508398b60e328815a Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/loss.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/quantize.cpython-312.pyc b/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/quantize.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01f0e8c343007975e4bec3f13e8490e6623b545e Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/quantize.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/quantize.cpython-313.pyc b/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/quantize.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae130e90c7602297cc69d40f93d032c9770dc3b8 Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/quantize.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/vae_utils.cpython-312.pyc b/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/vae_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..779686095f5751c00c76304407ea7b964da06143 Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/vae_utils.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/vae_utils.cpython-313.pyc b/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/vae_utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29ae286e4e1324debcebec0db3b3150d49bc4af3 Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/nn/__pycache__/vae_utils.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/nn/layers.py b/hunyuanvideo_foley/models/dac_vae/nn/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..44fbc2929715e11d843b24195d7042a528969a94 --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/nn/layers.py @@ -0,0 +1,33 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +# Scripting this brings model speed up 1.4x +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) diff --git a/hunyuanvideo_foley/models/dac_vae/nn/loss.py b/hunyuanvideo_foley/models/dac_vae/nn/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..9bb3dd6ce08a7a24f18f941eeb5b68fe9461e86b --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/nn/loss.py @@ -0,0 +1,368 @@ +import typing +from typing import List + +import torch +import torch.nn.functional as F +from audiotools import AudioSignal +from audiotools import STFTParams +from torch import nn + + +class L1Loss(nn.L1Loss): + """L1 Loss between AudioSignals. 
Defaults
+    to comparing ``audio_data``, but any
+    attribute of an AudioSignal can be used.
+
+    Parameters
+    ----------
+    attribute : str, optional
+        Attribute of signal to compare, defaults to ``audio_data``.
+    weight : float, optional
+        Weight of this loss, defaults to 1.0.
+
+    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
+    """
+
+    def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
+        self.attribute = attribute
+        self.weight = weight
+        super().__init__(**kwargs)
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """
+        Parameters
+        ----------
+        x : AudioSignal
+            Estimate AudioSignal
+        y : AudioSignal
+            Reference AudioSignal
+
+        Returns
+        -------
+        torch.Tensor
+            L1 loss between AudioSignal attributes.
+        """
+        if isinstance(x, AudioSignal):
+            x = getattr(x, self.attribute)
+            y = getattr(y, self.attribute)
+        return super().forward(x, y)
+
+
+class SISDRLoss(nn.Module):
+    """
+    Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
+    of estimated and reference audio signals or aligned features.
+
+    Parameters
+    ----------
+    scaling : bool, optional
+        Whether to use scale-invariant (True) or
+        signal-to-noise ratio (False), by default True
+    reduction : str, optional
+        How to reduce across the batch (either 'mean',
+        'sum', or 'none'), by default 'mean'
+    zero_mean : bool, optional
+        Zero mean the references and estimates before
+        computing the loss, by default True
+    clip_min : float, optional
+        The minimum possible loss value. Helps network
+        to not focus on making already good examples better, by default None
+    weight : float, optional
+        Weight of this loss, defaults to 1.0.
+
+    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
+    """
+
+    def __init__(
+        self,
+        scaling: bool = True,
+        reduction: str = "mean",
+        zero_mean: bool = True,
+        clip_min: float = None,
+        weight: float = 1.0,
+    ):
+        self.scaling = scaling
+        self.reduction = reduction
+        self.zero_mean = zero_mean
+        self.clip_min = clip_min
+        self.weight = weight
+        super().__init__()
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        eps = 1e-8
+        # nb, nc, nt
+        if isinstance(x, AudioSignal):
+            references = x.audio_data
+            estimates = y.audio_data
+        else:
+            references = x
+            estimates = y
+
+        nb = references.shape[0]
+        references = references.reshape(nb, 1, -1).permute(0, 2, 1)
+        estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
+
+        # samples now on axis 1
+        if self.zero_mean:
+            mean_reference = references.mean(dim=1, keepdim=True)
+            mean_estimate = estimates.mean(dim=1, keepdim=True)
+        else:
+            mean_reference = 0
+            mean_estimate = 0
+
+        _references = references - mean_reference
+        _estimates = estimates - mean_estimate
+
+        references_projection = (_references**2).sum(dim=-2) + eps
+        references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
+
+        scale = (
+            (references_on_estimates / references_projection).unsqueeze(1)
+            if self.scaling
+            else 1
+        )
+
+        e_true = scale * _references
+        e_res = _estimates - e_true
+
+        signal = (e_true**2).sum(dim=1)
+        noise = (e_res**2).sum(dim=1)
+        sdr = -10 * torch.log10(signal / noise + eps)
+
+        if self.clip_min is not None:
+            sdr = torch.clamp(sdr, min=self.clip_min)
+
+        if self.reduction == "mean":
+            sdr = sdr.mean()
+        elif self.reduction == "sum":
+            sdr = sdr.sum()
+        return sdr
+
+
+class MultiScaleSTFTLoss(nn.Module):
"""Computes the multi-scale STFT loss from [1]. + + Parameters + ---------- + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + References + ---------- + + 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. + "DDSP: Differentiable Digital Signal Processing." + International Conference on Learning Representations. 2019. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.loss_fn = loss_fn + self.log_weight = log_weight + self.mag_weight = mag_weight + self.clamp_eps = clamp_eps + self.weight = weight + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes multi-scale STFT between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Multi-scale STFT loss. + """ + loss = 0.0 + for s in self.stft_params: + x.stft(s.window_length, s.hop_length, s.window_type) + y.stft(s.window_length, s.hop_length, s.window_type) + loss += self.log_weight * self.loss_fn( + x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude) + return loss + + +class MelSpectrogramLoss(nn.Module): + """Compute distance between mel spectrograms. Can be used + in a multi-scale way. 
+ + Parameters + ---------- + n_mels : List[int] + Number of mels per STFT, by default [150, 80], + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + n_mels: List[int] = [150, 80], + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + mel_fmin: List[float] = [0.0, 0.0], + mel_fmax: List[float] = [None, None], + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.n_mels = n_mels + self.loss_fn = loss_fn + self.clamp_eps = clamp_eps + self.log_weight = log_weight + self.mag_weight = mag_weight + self.weight = weight + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes mel loss between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Mel loss. + """ + loss = 0.0 + for n_mels, fmin, fmax, s in zip( + self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params + ): + kwargs = { + "window_length": s.window_length, + "hop_length": s.hop_length, + "window_type": s.window_type, + } + x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + + loss += self.log_weight * self.loss_fn( + x_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + y_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x_mels, y_mels) + return loss + + +class GANLoss(nn.Module): + """ + Computes a discriminator loss, given a discriminator on + generated waveforms/spectrograms compared to ground truth + waveforms/spectrograms. Computes the loss for both the + discriminator and the generator in separate functions. 
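+
+    A minimal usage sketch (illustrative only; ``fake`` and ``real`` are
+    assumed to be AudioSignals, and ``Discriminator`` is the ensemble
+    defined in this package):
+
+    >>> gan_loss = GANLoss(Discriminator())
+    >>> loss_d = gan_loss.discriminator_loss(fake, real)
+    >>> loss_g, loss_feat = gan_loss.generator_loss(fake, real)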
+ """ + + def __init__(self, discriminator): + super().__init__() + self.discriminator = discriminator + + def forward(self, fake, real): + d_fake = self.discriminator(fake.audio_data) + d_real = self.discriminator(real.audio_data) + return d_fake, d_real + + def discriminator_loss(self, fake, real): + d_fake, d_real = self.forward(fake.clone().detach(), real) + + loss_d = 0 + for x_fake, x_real in zip(d_fake, d_real): + loss_d += torch.mean(x_fake[-1] ** 2) + loss_d += torch.mean((1 - x_real[-1]) ** 2) + return loss_d + + def generator_loss(self, fake, real): + d_fake, d_real = self.forward(fake, real) + + loss_g = 0 + for x_fake in d_fake: + loss_g += torch.mean((1 - x_fake[-1]) ** 2) + + loss_feature = 0 + + for i in range(len(d_fake)): + for j in range(len(d_fake[i]) - 1): + loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) + return loss_g, loss_feature diff --git a/hunyuanvideo_foley/models/dac_vae/nn/quantize.py b/hunyuanvideo_foley/models/dac_vae/nn/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..bc462bc56ab5191b65d58ebd8bcaff4af7fb5927 --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/nn/quantize.py @@ -0,0 +1,262 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +from .layers import WNConv1d + + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return 
self.embed_code(embed_id).transpose(1, 2)
+
+    def decode_latents(self, latents):
+        encodings = rearrange(latents, "b d t -> (b t) d")
+        codebook = self.codebook.weight  # codebook: (N x D)
+
+        # L2 normalize encodings and codebook (ViT-VQGAN)
+        encodings = F.normalize(encodings)
+        codebook = F.normalize(codebook)
+
+        # Compute euclidean distance with codebook
+        dist = (
+            encodings.pow(2).sum(1, keepdim=True)
+            - 2 * encodings @ codebook.t()
+            + codebook.pow(2).sum(1, keepdim=True).t()
+        )
+        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
+        z_q = self.decode_code(indices)
+        return z_q, indices
+
+
+class ResidualVectorQuantize(nn.Module):
+    """
+    Introduced in SoundStream: An end-to-end neural audio codec
+    https://arxiv.org/abs/2107.03312
+    """
+
+    def __init__(
+        self,
+        input_dim: int = 512,
+        n_codebooks: int = 9,
+        codebook_size: int = 1024,
+        codebook_dim: Union[int, list] = 8,
+        quantizer_dropout: float = 0.0,
+    ):
+        super().__init__()
+        if isinstance(codebook_dim, int):
+            codebook_dim = [codebook_dim for _ in range(n_codebooks)]
+
+        self.n_codebooks = n_codebooks
+        self.codebook_dim = codebook_dim
+        self.codebook_size = codebook_size
+
+        self.quantizers = nn.ModuleList(
+            [
+                VectorQuantize(input_dim, codebook_size, codebook_dim[i])
+                for i in range(n_codebooks)
+            ]
+        )
+        self.quantizer_dropout = quantizer_dropout
+
+    def forward(self, z, n_quantizers: int = None):
+        """Quantizes the input tensor using a fixed set of `n` codebooks and returns
+        the corresponding codebook vectors
+        Parameters
+        ----------
+        z : Tensor[B x D x T]
+        n_quantizers : int, optional
+            Number of quantizers to use
+            (``n_quantizers < self.n_codebooks``, e.g. for quantizer dropout)
+            Note: if `self.quantizer_dropout` is True, this argument is ignored
+            when in training mode, and a random number of quantizers is used.
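+            (Concretely: with ``quantizer_dropout > 0``, a random prefix of the
+            quantizers is kept for a random fraction of the batch, so the codec
+            also learns to reconstruct audio at lower bitrates.)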
+ Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + if self.training is False and i >= n_quantizers: + break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + return z_q, codes, latents, commitment_loss, codebook_loss + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. 
+
+        Parameters
+        ----------
+        latents : Tensor[B x N x T]
+            Continuous representation of input after projection
+
+        Returns
+        -------
+        Tensor[B x D x T]
+            Quantized representation of full-projected space
+        Tensor[B x D x T]
+            Quantized representation of latent space
+        Tensor[B x N x T]
+            Codebook indices recovered for each quantizer
+        """
+        z_q = 0
+        z_p = []
+        codes = []
+        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
+
+        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
+            0
+        ]
+        for i in range(n_codebooks):
+            j, k = dims[i], dims[i + 1]
+            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
+            z_p.append(z_p_i)
+            codes.append(codes_i)
+
+            z_q_i = self.quantizers[i].out_proj(z_p_i)
+            z_q = z_q + z_q_i
+
+        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
+
+
+if __name__ == "__main__":
+    rvq = ResidualVectorQuantize(quantizer_dropout=True)
+    x = torch.randn(16, 512, 80)
+    z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
+    print(latents.shape)
diff --git a/hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py b/hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a97597f5d5ae4aa19a194c24f3c17b2238224bcf
--- /dev/null
+++ b/hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py
@@ -0,0 +1,91 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+    def sample(self):
+        raise NotImplementedError()
+
+    def mode(self):
+        raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+    def __init__(self, value):
+        self.value = value
+
+    def sample(self):
+        return self.value
+
+    def mode(self):
+        return self.value
+
+
+class DiagonalGaussianDistribution(object):
+    def __init__(self, parameters, deterministic=False):
+        self.parameters = parameters
+        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+        self.deterministic = deterministic
+        self.std = torch.exp(0.5 * self.logvar)
+        self.var = torch.exp(self.logvar)
+        if self.deterministic:
+            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+    def sample(self):
+        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+        return x
+
+    def kl(self, other=None):
+        if self.deterministic:
+            return torch.Tensor([0.0])
+        else:
+            if other is None:
+                return 0.5 * torch.mean(
+                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+                    dim=[1, 2],
+                )
+            else:
+                return 0.5 * torch.mean(
+                    torch.pow(self.mean - other.mean, 2) / other.var
+                    + self.var / other.var
+                    - 1.0
+                    - self.logvar
+                    + other.logvar,
+                    dim=[1, 2],
+                )
+
+    def nll(self, sample, dims=[1, 2]):
+        if self.deterministic:
+            return torch.Tensor([0.0])
+        logtwopi = np.log(2.0 * np.pi)
+        return 0.5 * torch.sum(
+            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+            dim=dims,
+        )
+
+    def mode(self):
+        return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+    """
+    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+    Compute the KL divergence between two gaussians.
+    Shapes are automatically broadcasted, so batches can be compared to
+    scalars, among other use cases.
+    """
+    tensor = None
+    for obj in (mean1, logvar1, mean2, logvar2):
+        if isinstance(obj, torch.Tensor):
+            tensor = obj
+            break
+    assert tensor is not None, "at least one argument must be a Tensor"
+
+    # Force variances to be Tensors. Broadcasting helps convert scalars to
+    # Tensors, but it does not work for torch.exp().
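+    # For reference, the closed form computed below is, per element,
+    #   KL(N(m1, v1) || N(m2, v2)) = 0.5 * (log(v2 / v1) + v1 / v2 + (m1 - m2)^2 / v2 - 1)
+    # with v = exp(logvar); the return statement is this expression rewritten
+    # in terms of logvar1 and logvar2.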
+ logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)] + + return 0.5 * ( + -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/hunyuanvideo_foley/models/dac_vae/utils/__init__.py b/hunyuanvideo_foley/models/dac_vae/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3dce1ed49f1b4e4fe1cb42b054298911207e0e41 --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/utils/__init__.py @@ -0,0 +1,121 @@ +from pathlib import Path + +import argbind +from audiotools import ml + +from ..model import DAC +Accelerator = ml.Accelerator + +__MODEL_LATEST_TAGS__ = { + ("44khz", "8kbps"): "0.0.1", + ("24khz", "8kbps"): "0.0.4", + ("16khz", "8kbps"): "0.0.5", + ("44khz", "16kbps"): "1.0.0", +} + +__MODEL_URLS__ = { + ( + "44khz", + "0.0.1", + "8kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth", + ( + "24khz", + "0.0.4", + "8kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth", + ( + "16khz", + "0.0.5", + "8kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth", + ( + "44khz", + "1.0.0", + "16kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth", +} + + +@argbind.bind(group="download", positional=True, without_prefix=True) +def download( + model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest" +): + """ + Function that downloads the weights file from URL if a local cache is not found. + + Parameters + ---------- + model_type : str + The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". + model_bitrate: str + Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". + Only 44khz model supports 16kbps. + tag : str + The tag of the model to download. Defaults to "latest". + + Returns + ------- + Path + Directory path required to load model via audiotools. + """ + model_type = model_type.lower() + tag = tag.lower() + + assert model_type in [ + "44khz", + "24khz", + "16khz", + ], "model_type must be one of '44khz', '24khz', or '16khz'" + + assert model_bitrate in [ + "8kbps", + "16kbps", + ], "model_bitrate must be one of '8kbps', or '16kbps'" + + if tag == "latest": + tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)] + + download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None) + + if download_link is None: + raise ValueError( + f"Could not find model with tag {tag} and model type {model_type}" + ) + + local_path = ( + Path.home() + / ".cache" + / "descript" + / "dac" + / f"weights_{model_type}_{model_bitrate}_{tag}.pth" + ) + if not local_path.exists(): + local_path.parent.mkdir(parents=True, exist_ok=True) + + # Download the model + import requests + + response = requests.get(download_link) + + if response.status_code != 200: + raise ValueError( + f"Could not download model. 
Received response code {response.status_code}" + ) + local_path.write_bytes(response.content) + + return local_path + + +def load_model( + model_type: str = "44khz", + model_bitrate: str = "8kbps", + tag: str = "latest", + load_path: str = None, +): + if not load_path: + load_path = download( + model_type=model_type, model_bitrate=model_bitrate, tag=tag + ) + generator = DAC.load(load_path) + return generator diff --git a/hunyuanvideo_foley/models/dac_vae/utils/__pycache__/__init__.cpython-312.pyc b/hunyuanvideo_foley/models/dac_vae/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1a9a3fd6342769d3b5f2e8f1d638306ae56a8b5 Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/utils/__pycache__/__init__.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/utils/__pycache__/__init__.cpython-313.pyc b/hunyuanvideo_foley/models/dac_vae/utils/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..783f2a50bc5c9f9d80e509f8a31705fee1d18570 Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/utils/__pycache__/__init__.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/dac_vae/utils/decode.py b/hunyuanvideo_foley/models/dac_vae/utils/decode.py new file mode 100644 index 0000000000000000000000000000000000000000..00261a561251b1bef6f11e6594bf80de10b93ff2 --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/utils/decode.py @@ -0,0 +1,95 @@ +import warnings +from pathlib import Path + +import argbind +import numpy as np +import torch +from audiotools import AudioSignal +from tqdm import tqdm + +from ..model import DACFile +from . import load_model + +warnings.filterwarnings("ignore", category=UserWarning) + + +@argbind.bind(group="decode", positional=True, without_prefix=True) +@torch.inference_mode() +@torch.no_grad() +def decode( + input: str, + output: str = "", + weights_path: str = "", + model_tag: str = "latest", + model_bitrate: str = "8kbps", + device: str = "cuda", + model_type: str = "44khz", + verbose: bool = False, +): + """Decode audio from codes. + + Parameters + ---------- + input : str + Path to input directory or file + output : str, optional + Path to output directory, by default "". + If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. + weights_path : str, optional + Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the + model_tag and model_type. + model_tag : str, optional + Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. + model_bitrate: str + Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". + device : str, optional + Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU. + model_type : str, optional + The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. 
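+
+    A hedged usage sketch (paths are illustrative; the ``.dac`` files are
+    assumed to have been produced by the companion ``encode`` utility):
+
+    >>> decode("codes/", output="reconstructed/", device="cpu")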
+ """ + generator = load_model( + model_type=model_type, + model_bitrate=model_bitrate, + tag=model_tag, + load_path=weights_path, + ) + generator.to(device) + generator.eval() + + # Find all .dac files in input directory + _input = Path(input) + input_files = list(_input.glob("**/*.dac")) + + # If input is a .dac file, add it to the list + if _input.suffix == ".dac": + input_files.append(_input) + + # Create output directory + output = Path(output) + output.mkdir(parents=True, exist_ok=True) + + for i in tqdm(range(len(input_files)), desc=f"Decoding files"): + # Load file + artifact = DACFile.load(input_files[i]) + + # Reconstruct audio from codes + recons = generator.decompress(artifact, verbose=verbose) + + # Compute output path + relative_path = input_files[i].relative_to(input) + output_dir = output / relative_path.parent + if not relative_path.name: + output_dir = output + relative_path = input_files[i] + output_name = relative_path.with_suffix(".wav").name + output_path = output_dir / output_name + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Write to file + recons.write(output_path) + + +if __name__ == "__main__": + args = argbind.parse_args() + with argbind.scope(args): + decode() diff --git a/hunyuanvideo_foley/models/dac_vae/utils/encode.py b/hunyuanvideo_foley/models/dac_vae/utils/encode.py new file mode 100644 index 0000000000000000000000000000000000000000..c86946c3c6d6a7ff1d1ea883c600d9b93c41b7d9 --- /dev/null +++ b/hunyuanvideo_foley/models/dac_vae/utils/encode.py @@ -0,0 +1,94 @@ +import math +import warnings +from pathlib import Path + +import argbind +import numpy as np +import torch +from audiotools import AudioSignal +from audiotools.core import util +from tqdm import tqdm + +from . import load_model + +warnings.filterwarnings("ignore", category=UserWarning) + + +@argbind.bind(group="encode", positional=True, without_prefix=True) +@torch.inference_mode() +@torch.no_grad() +def encode( + input: str, + output: str = "", + weights_path: str = "", + model_tag: str = "latest", + model_bitrate: str = "8kbps", + n_quantizers: int = None, + device: str = "cuda", + model_type: str = "44khz", + win_duration: float = 5.0, + verbose: bool = False, +): + """Encode audio files in input path to .dac format. + + Parameters + ---------- + input : str + Path to input audio file or directory + output : str, optional + Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. + weights_path : str, optional + Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the + model_tag and model_type. + model_tag : str, optional + Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. + model_bitrate: str + Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". + n_quantizers : int, optional + Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate. + device : str, optional + Device to use, by default "cuda" + model_type : str, optional + The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. 
+ """ + generator = load_model( + model_type=model_type, + model_bitrate=model_bitrate, + tag=model_tag, + load_path=weights_path, + ) + generator.to(device) + generator.eval() + kwargs = {"n_quantizers": n_quantizers} + + # Find all audio files in input path + input = Path(input) + audio_files = util.find_audio(input) + + output = Path(output) + output.mkdir(parents=True, exist_ok=True) + + for i in tqdm(range(len(audio_files)), desc="Encoding files"): + # Load file + signal = AudioSignal(audio_files[i]) + + # Encode audio to .dac format + artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs) + + # Compute output path + relative_path = audio_files[i].relative_to(input) + output_dir = output / relative_path.parent + if not relative_path.name: + output_dir = output + relative_path = audio_files[i] + output_name = relative_path.with_suffix(".dac").name + output_path = output_dir / output_name + output_path.parent.mkdir(parents=True, exist_ok=True) + + artifact.save(output_path) + + +if __name__ == "__main__": + args = argbind.parse_args() + with argbind.scope(args): + encode() diff --git a/hunyuanvideo_foley/models/hifi_foley.py b/hunyuanvideo_foley/models/hifi_foley.py new file mode 100644 index 0000000000000000000000000000000000000000..1b5fffcca7a61c3baf686031e9ffdbc3696f3e22 --- /dev/null +++ b/hunyuanvideo_foley/models/hifi_foley.py @@ -0,0 +1,794 @@ +from typing import List, Tuple, Optional, Union, Dict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from einops.layers.torch import Rearrange +from diffusers.models import ModelMixin +from diffusers.configuration_utils import ConfigMixin, register_to_config + +from .nn.activation_layers import SwiGLU, get_activation_layer +from .nn.attn_layers import apply_rotary_emb, attention +from .nn.embed_layers import TimestepEmbedder, ConditionProjection, PatchEmbed1D +from .nn.mlp_layers import MLP, ConvMLP, FinalLayer1D, ChannelLastConv1d +from .nn.modulate_layers import ModulateDiT, ckpt_wrapper, apply_gate, modulate +from .nn.norm_layers import get_norm_layer +from .nn.posemb_layers import get_nd_rotary_pos_embed + +def interleave_two_sequences(x1: torch.Tensor, x2: torch.Tensor): + # [B, N1, H, C] & [B, N2, H, C] + B, N1, H, C = x1.shape + B, N2, H, C = x2.shape + assert x1.ndim == x2.ndim == 4 + + if N1 != N2: + x2 = x2.view(B, N2, -1).transpose(1, 2) + x2 = F.interpolate(x2, size=(N1), mode="nearest-exact") + x2 = x2.transpose(1, 2).view(B, N1, H, C) + x = torch.stack((x1, x2), dim=2) + x = x.reshape(B, N1 * 2, H, C) + return x + +def decouple_interleaved_two_sequences(x: torch.Tensor, len1: int, len2: int): + B, N, H, C = x.shape + assert N % 2 == 0 and N // 2 == len1 + + x = x.reshape(B, -1, 2, H, C) + x1 = x[:, :, 0] + x2 = x[:, :, 1] + if x2.shape[1] != len2: + x2 = x2.view(B, len1, H * C).transpose(1, 2) + x2 = F.interpolate(x2, size=(len2), mode="nearest-exact") + x2 = x2.transpose(1, 2).view(B, len2, H, C) + return x1, x2 + +class TwoStreamCABlock(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float, + mlp_act_type: str = "gelu_tanh", + qk_norm: bool = True, + qk_norm_type: str = "rms", + qkv_bias: bool = False, + attn_mode: str = "torch", + reverse: bool = False, + interleaved_audio_visual_rope: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.deterministic = False + self.reverse = 
reverse + self.attn_mode = attn_mode + self.num_heads = num_heads + self.hidden_size = hidden_size + head_dim = hidden_size // num_heads + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.interleaved_audio_visual_rope = interleaved_audio_visual_rope + + # Self attention for audio + visual + self.audio_mod = ModulateDiT(hidden_size, factor=9, act_layer=get_activation_layer("silu"), **factory_kwargs) + self.audio_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.audio_self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs) + qk_norm_layer = get_norm_layer(qk_norm_type) + self.audio_self_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.audio_self_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.audio_self_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + + # visual cond + self.v_cond_mod = ModulateDiT(hidden_size, factor=9, act_layer=get_activation_layer("silu"), **factory_kwargs) + self.v_cond_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.v_cond_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs) + self.v_cond_attn_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.v_cond_attn_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.v_cond_self_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + + self.max_text_len = 100 + self.rope_dim_list = None + + # audio and video norm for cross attention with text + self.audio_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.v_cond_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + # Cross attention: (video_audio) as query, text as key/value + self.audio_cross_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + self.v_cond_cross_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + self.text_cross_kv = nn.Linear(hidden_size, hidden_size * 2, bias=qkv_bias, **factory_kwargs) + + self.audio_cross_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.v_cond_cross_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.text_cross_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.audio_cross_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + self.v_cond_cross_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) + + # MLPs + self.audio_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.audio_mlp = MLP( + hidden_size, mlp_hidden_dim, act_layer=get_activation_layer(mlp_act_type), bias=True, **factory_kwargs + ) + + self.v_cond_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.v_cond_mlp = MLP( + hidden_size, mlp_hidden_dim, act_layer=get_activation_layer(mlp_act_type), bias=True, **factory_kwargs + ) + + def build_rope_for_text(self, 
text_len, head_dim, rope_dim_list=None): + target_ndim = 1 # n-d RoPE + rope_sizes = [text_len] + + if rope_dim_list is None: + rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] + assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer" + + text_freqs_cos, text_freqs_sin = get_nd_rotary_pos_embed( + rope_dim_list=rope_dim_list, + start=rope_sizes, + theta=10000, + use_real=True, + theta_rescale_factor=1.0, + ) + return text_freqs_cos, text_freqs_sin + + def set_attn_mode(self, new_mode): + if new_mode != "torch": + raise NotImplementedError(f"Only support 'torch' mode, got {new_mode}.") + self.attn_mode = new_mode + + def enable_deterministic(self): + self.deterministic = True + + def disable_deterministic(self): + self.deterministic = False + + def forward( + self, + audio: torch.Tensor, + cond: torch.Tensor, + v_cond: torch.Tensor, + attn_mask: torch.Tensor, + vec: torch.Tensor, + freqs_cis: tuple = None, + v_freqs_cis: tuple = None, + sync_vec: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Get modulation parameters + if sync_vec is not None: + assert sync_vec.ndim == 3 + (audio_mod1_shift, audio_mod1_scale, audio_mod1_gate, + audio_mod2_shift, audio_mod2_scale, audio_mod2_gate, + audio_mod3_shift, audio_mod3_scale, audio_mod3_gate, + ) = self.audio_mod(sync_vec).chunk(9, dim=-1) + else: + (audio_mod1_shift, audio_mod1_scale, audio_mod1_gate, + audio_mod2_shift, audio_mod2_scale, audio_mod2_gate, + audio_mod3_shift, audio_mod3_scale, audio_mod3_gate, + ) = self.audio_mod(vec).chunk(9, dim=-1) + + ( + v_cond_mod1_shift, + v_cond_mod1_scale, + v_cond_mod1_gate, + v_cond_mod2_shift, + v_cond_mod2_scale, + v_cond_mod2_gate, + v_cond_mod3_shift, + v_cond_mod3_scale, + v_cond_mod3_gate, + ) = self.v_cond_mod(vec).chunk(9, dim=-1) + + # 1. 
Self Attention for audio + visual + audio_modulated = self.audio_norm1(audio) + audio_modulated = modulate(audio_modulated, shift=audio_mod1_shift, scale=audio_mod1_scale) + audio_qkv = self.audio_self_attn_qkv(audio_modulated) + audio_q, audio_k, audio_v = rearrange(audio_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads) + audio_q = self.audio_self_q_norm(audio_q).to(audio_v) + audio_k = self.audio_self_k_norm(audio_k).to(audio_v) + + # Prepare visual cond for attention + v_cond_modulated = self.v_cond_norm1(v_cond) + v_cond_modulated = modulate(v_cond_modulated, shift=v_cond_mod1_shift, scale=v_cond_mod1_scale) + v_cond_qkv = self.v_cond_attn_qkv(v_cond_modulated) + v_cond_q, v_cond_k, v_cond_v = rearrange(v_cond_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads) + v_cond_q = self.v_cond_attn_q_norm(v_cond_q).to(v_cond_v) + v_cond_k = self.v_cond_attn_k_norm(v_cond_k).to(v_cond_v) + + # Apply RoPE if needed for audio and visual + if freqs_cis is not None: + if not self.interleaved_audio_visual_rope: + audio_qq, audio_kk = apply_rotary_emb(audio_q, audio_k, freqs_cis, head_first=False) + audio_q, audio_k = audio_qq, audio_kk + else: + ori_audio_len = audio_q.shape[1] + ori_v_con_len = v_cond_q.shape[1] + interleaved_audio_visual_q = interleave_two_sequences(audio_q, v_cond_q) + interleaved_audio_visual_k = interleave_two_sequences(audio_k, v_cond_k) + interleaved_audio_visual_qq, interleaved_audio_visual_kk = apply_rotary_emb( + interleaved_audio_visual_q, interleaved_audio_visual_k, freqs_cis, head_first=False + ) + audio_qq, v_cond_qq = decouple_interleaved_two_sequences( + interleaved_audio_visual_qq, ori_audio_len, ori_v_con_len + ) + audio_kk, v_cond_kk = decouple_interleaved_two_sequences( + interleaved_audio_visual_kk, ori_audio_len, ori_v_con_len + ) + audio_q, audio_k = audio_qq, audio_kk + v_cond_q, v_cond_k = v_cond_qq, v_cond_kk + + # Apply RoPE to visual if needed and not interleaved + if v_freqs_cis is not None and not self.interleaved_audio_visual_rope: + v_cond_qq, v_cond_kk = apply_rotary_emb(v_cond_q, v_cond_k, v_freqs_cis, head_first=False) + v_cond_q, v_cond_k = v_cond_qq, v_cond_kk + + # Concatenate for self-attention + q = torch.cat((v_cond_q, audio_q), dim=1) + k = torch.cat((v_cond_k, audio_k), dim=1) + v = torch.cat((v_cond_v, audio_v), dim=1) + + # Run self-attention + attn = attention(q, k, v, mode=self.attn_mode, attn_mask=attn_mask, deterministic=self.deterministic) + v_cond_attn, audio_attn = torch.split(attn, [v_cond.shape[1], audio.shape[1]], dim=1) + + # Apply self-attention output to audio and v_cond + audio = audio + apply_gate(self.audio_self_proj(audio_attn), gate=audio_mod1_gate) + v_cond = v_cond + apply_gate(self.v_cond_self_proj(v_cond_attn), gate=v_cond_mod1_gate) + + # 2. 
Cross Attention: (v_cond, audio) as query, text as key/value + # audio, v_cond modulation + audio_modulated = self.audio_norm2(audio) + audio_modulated = modulate(audio_modulated, shift=audio_mod2_shift, scale=audio_mod2_scale) + v_cond_modulated = self.v_cond_norm2(v_cond) + v_cond_modulated = modulate(v_cond_modulated, shift=v_cond_mod2_shift, scale=v_cond_mod2_scale) + + # Prepare audio query + audio_q = self.audio_cross_q(audio_modulated) + audio_q = rearrange(audio_q, "B L (H D) -> B L H D", H=self.num_heads) + audio_q = self.audio_cross_q_norm(audio_q) + + # Prepare v_cond query + v_cond_q = self.v_cond_cross_q(v_cond_modulated) + v_cond_q = rearrange(v_cond_q, "B L (H D) -> B L H D", H=self.num_heads) + v_cond_q = self.v_cond_cross_q_norm(v_cond_q) + + # Prepare text key/value + text_kv = self.text_cross_kv(cond) + text_k, text_v = rearrange(text_kv, "B L (K H D) -> K B L H D", K=2, H=self.num_heads) + text_k = self.text_cross_k_norm(text_k).to(text_v) + + # Apply RoPE to (v_cond, audio) query and text key if needed + head_dim = self.hidden_size // self.num_heads + audio_cross_freqs_cos, audio_cross_freqs_sin = self.build_rope_for_text(audio_q.shape[1], head_dim, rope_dim_list=self.rope_dim_list) + audio_cross_freqs_cis = (audio_cross_freqs_cos.to(audio_q.device), audio_cross_freqs_sin.to(audio_q.device)) + audio_q = apply_rotary_emb(audio_q, audio_q, audio_cross_freqs_cis, head_first=False)[0] + + v_cond_cross_freqs_cos, v_cond_cross_freqs_sin = self.build_rope_for_text(v_cond_q.shape[1], head_dim, rope_dim_list=self.rope_dim_list) + v_cond_cross_freqs_cis = (v_cond_cross_freqs_cos.to(v_cond_q.device), v_cond_cross_freqs_sin.to(v_cond_q.device)) + v_cond_q = apply_rotary_emb(v_cond_q, v_cond_q, v_cond_cross_freqs_cis, head_first=False)[0] + + text_len = text_k.shape[1] + + text_freqs_cos, text_freqs_sin = self.build_rope_for_text(text_len, head_dim, + rope_dim_list=self.rope_dim_list) + text_freqs_cis = (text_freqs_cos.to(text_k.device), text_freqs_sin.to(text_k.device)) + text_k = apply_rotary_emb(text_k, text_k, text_freqs_cis, head_first=False)[1] + + # Concat v_cond and audio for cross-attention + v_cond_audio_q = torch.cat([v_cond_q, audio_q], dim=1) + + # Run cross-attention + cross_attn = attention(v_cond_audio_q, text_k, text_v, mode=self.attn_mode, deterministic=self.deterministic) + v_cond_cross_attn, audio_cross_attn = torch.split(cross_attn, [v_cond.shape[1], audio.shape[1]], dim=1) + + # Apply cross-attention output + audio = audio + apply_gate(self.audio_cross_proj(audio_cross_attn), gate=audio_mod2_gate) + v_cond = v_cond + apply_gate(self.v_cond_cross_proj(v_cond_cross_attn), gate=v_cond_mod2_gate) + + # 3. 
Apply MLPs + audio = audio + apply_gate( + self.audio_mlp(modulate(self.audio_norm3(audio), shift=audio_mod3_shift, scale=audio_mod3_scale)), + gate=audio_mod3_gate, + ) + + # Apply visual MLP + v_cond = v_cond + apply_gate( + self.v_cond_mlp(modulate(self.v_cond_norm3(v_cond), shift=v_cond_mod3_shift, scale=v_cond_mod3_scale)), + gate=v_cond_mod3_gate, + ) + + return audio, cond, v_cond + +class SingleStreamBlock(nn.Module): + + def __init__(self, hidden_size: int, + num_heads: int, + mlp_ratio: float, + qk_norm_type: str = "rms", + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None,): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.hidden_size = hidden_size + self.num_heads = num_heads + + self.modulation = ModulateDiT( + hidden_size=hidden_size, + factor=6, + act_layer=get_activation_layer("silu"), + **factory_kwargs, + ) + self.linear_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True) + self.linear1 = ChannelLastConv1d(hidden_size, hidden_size, kernel_size=3, padding=1, **factory_kwargs) + self.linear2 = ConvMLP(hidden_size, hidden_size * mlp_ratio, kernel_size=3, padding=1, **factory_kwargs) + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False) + self.q_norm = nn.RMSNorm(hidden_size // num_heads) + self.k_norm = nn.RMSNorm(hidden_size // num_heads) + self.rearrange = Rearrange("B L (H D K) -> B H L D K", K=3, H=num_heads) + + def forward(self, x: torch.Tensor, cond: torch.Tensor,freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None): + assert cond.ndim == 3, "Condition should be in shape of [B, T, D]" + modulation = self.modulation(cond) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=-1) + x_norm1 = self.norm1(x) * (1 + scale_msa) + shift_msa + + qkv = self.linear_qkv(x_norm1) + q, k, v = self.rearrange(qkv).chunk(3, dim=-1) + q = q.squeeze(-1) + k = k.squeeze(-1) + v = v.squeeze(-1) + + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis, head_first=True) + + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = F.scaled_dot_product_attention(q, k, v) + out = rearrange(out, 'b h n d -> b n (h d)').contiguous() + + x = x + apply_gate(self.linear1(out),gate=gate_msa) + x_norm = self.norm2(x) * (1 + scale_mlp) + shift_mlp + x = x + apply_gate(self.linear2(x_norm), gate=gate_mlp) + + return x + +class HunyuanVideoFoley(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + model_config, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + model_args = model_config.model_config.model_kwargs + self.depth_triple_blocks = model_args.get("depth_triple_blocks", 19) + self.depth_single_blocks = model_args.get("depth_single_blocks", 38) + # Gradient checkpoint. + self.gradient_checkpoint = False + self.gradient_checkpoint_layers = None + if self.gradient_checkpoint: + assert self.gradient_checkpoint_layers <= self.depth_triple_blocks + self.depth_single_blocks, ( + f"Gradient checkpoint layers must be less or equal than the depth of the model. " + f"Got gradient_checkpoint_layers={self.gradient_checkpoint_layers} and depth={self.depth_triple_blocks + self.depth_single_blocks}." + ) + + self.interleaved_audio_visual_rope = model_args.get("interleaved_audio_visual_rope", False) + + # Condition projection. Default to linear projection. 
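+        # (Only "linear" is supported at the moment; any other value raises
+        # NotImplementedError when the projection layer is built below.)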
+ self.condition_projection = model_args.get("condition_projection", "linear") + self.condition_dim = model_args.get("condition_dim", None) + self.use_attention_mask = model_args.get("use_attention_mask", False) + + self.patch_size = model_args.get("patch_size", 1) + self.visual_in_channels = model_args.get("clip_dim", 768) + self.audio_vae_latent_dim = model_args.get("audio_vae_latent_dim", 128) + self.out_channels = self.audio_vae_latent_dim + self.unpatchify_channels = self.out_channels + self.reverse = model_args.get("reverse", False) + + self.num_heads = model_args.get("num_heads", 24) + self.hidden_size = model_args.get("hidden_size", 3072) + self.rope_dim_list = model_args.get("rope_dim_list", None) + self.mlp_ratio = model_args.get("mlp_ratio", 4.0) + self.mlp_act_type = model_args.get("mlp_act_type", "gelu_tanh") + + self.qkv_bias = model_args.get("qkv_bias", True) + self.qk_norm = model_args.get("qk_norm", True) + self.qk_norm_type = model_args.get("qk_norm_type", "rms") + self.attn_mode = model_args.get("attn_mode", "torch") + + self.embedder_type = model_args.get("embedder_type", "default") + + # sync condition things + self.sync_modulation = model_args.get("sync_modulation", False) + self.add_sync_feat_to_audio = model_args.get("add_sync_feat_to_audio", False) + self.sync_feat_dim = model_args.get("sync_feat_dim", 768) + self.sync_in_ksz = model_args.get("sync_in_ksz", 1) + + # condition tokens length + self.clip_len = model_args.get("clip_length", 64) + self.sync_len = model_args.get("sync_length", 192) + + if self.hidden_size % self.num_heads != 0: + raise ValueError(f"Hidden size {self.hidden_size} must be divisible by num_heads {self.num_heads}") + + # Build audio patchify layer and visual gated linear projection + self.patch_size = 1 + self.audio_embedder = PatchEmbed1D(self.patch_size, self.audio_vae_latent_dim, self.hidden_size, **factory_kwargs) + self.visual_proj = SwiGLU(self.visual_in_channels, hidden_dim=self.hidden_size, out_dim=self.hidden_size) + + # condition + if self.condition_projection == "linear": + self.cond_in = ConditionProjection( + self.condition_dim, self.hidden_size, get_activation_layer("silu"), **factory_kwargs + ) + else: + raise NotImplementedError(f"Unsupported condition_projection: {self.condition_projection}") + + # time modulation + self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs) + + # visual sync embedder if needed + if self.sync_in_ksz == 1: + sync_in_padding = 0 + elif self.sync_in_ksz == 3: + sync_in_padding = 1 + else: + raise ValueError + if self.sync_modulation or self.add_sync_feat_to_audio: + self.sync_in = nn.Sequential( + nn.Linear(self.sync_feat_dim, self.hidden_size), + nn.SiLU(), + ConvMLP(self.hidden_size, self.hidden_size * 4, kernel_size=self.sync_in_ksz, padding=sync_in_padding), + ) + self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, self.sync_feat_dim))) + + self.triple_blocks = nn.ModuleList( + [ + TwoStreamCABlock( + hidden_size=self.hidden_size, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + mlp_act_type=self.mlp_act_type, + qk_norm=self.qk_norm, + qk_norm_type=self.qk_norm_type, + qkv_bias=self.qkv_bias, + attn_mode=self.attn_mode, + reverse=self.reverse, + interleaved_audio_visual_rope=self.interleaved_audio_visual_rope, + **factory_kwargs, + ) + for _ in range(self.depth_triple_blocks) + ] + ) + + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock( + hidden_size=self.hidden_size, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + 
qk_norm_type=self.qk_norm_type, + **factory_kwargs, + ) + for _ in range(self.depth_single_blocks) + ] + ) + + self.final_layer = FinalLayer1D( + self.hidden_size, self.patch_size, self.out_channels, get_activation_layer("silu"), **factory_kwargs + ) + self.unpatchify_channels = self.out_channels + + self.empty_clip_feat = nn.Parameter(torch.zeros(1, self.visual_in_channels), requires_grad=True) + self.empty_sync_feat = nn.Parameter(torch.zeros(1, self.sync_feat_dim), requires_grad=True) + nn.init.constant_(self.empty_clip_feat, 0) + nn.init.constant_(self.empty_sync_feat, 0) + + def get_empty_string_sequence(self, bs=None) -> torch.Tensor: + if bs is None: + return self.empty_string_feat + else: + return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1) + + def get_empty_clip_sequence(self, bs=None, len=None) -> torch.Tensor: + len = len if len is not None else self.clip_len + if bs is None: + return self.empty_clip_feat.expand(len, -1) # 15s + else: + return self.empty_clip_feat.unsqueeze(0).expand(bs, len, -1) # 15s + + def get_empty_sync_sequence(self, bs=None, len=None) -> torch.Tensor: + len = len if len is not None else self.sync_len + if bs is None: + return self.empty_sync_feat.expand(len, -1) + else: + return self.empty_sync_feat.unsqueeze(0).expand(bs, len, -1) + + def build_rope_for_audio_visual(self, audio_emb_len, visual_cond_len): + assert self.patch_size == 1 + # ======================================== Build RoPE for audio tokens ====================================== + target_ndim = 1 # n-d RoPE + rope_sizes = [audio_emb_len] + head_dim = self.hidden_size // self.num_heads + rope_dim_list = self.rope_dim_list + if rope_dim_list is None: + rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] + assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer" + freqs_cos, freqs_sin = get_nd_rotary_pos_embed( + rope_dim_list=rope_dim_list, + start=rope_sizes, + theta=10000, + use_real=True, + theta_rescale_factor=1.0, + ) + + # ========================== Build RoPE for clip tokens ========================= + target_ndim = 1 # n-d RoPE + rope_sizes = [visual_cond_len] + head_dim = self.hidden_size // self.num_heads + rope_dim_list = self.rope_dim_list + if rope_dim_list is None: + rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] + assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer" + v_freqs_cos, v_freqs_sin = get_nd_rotary_pos_embed( + rope_dim_list=rope_dim_list, + start=rope_sizes, + theta=10000, + use_real=True, + theta_rescale_factor=1.0, + freq_scaling=1.0 * audio_emb_len / visual_cond_len, + ) + return freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin + + def build_rope_for_interleaved_audio_visual(self, total_len): + assert self.patch_size == 1 + # ========================== Build RoPE for audio tokens ======================== + target_ndim = 1 # n-d RoPE + rope_sizes = [total_len] + head_dim = self.hidden_size // self.num_heads + rope_dim_list = self.rope_dim_list + if rope_dim_list is None: + rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] + assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer" + freqs_cos, freqs_sin = get_nd_rotary_pos_embed( + rope_dim_list=rope_dim_list, + start=rope_sizes, + theta=10000, + use_real=True, + theta_rescale_factor=1.0, + ) + return freqs_cos, freqs_sin + + def set_attn_mode(self, new_mode): + for block in self.triple_blocks: + 
block.set_attn_mode(new_mode) + for block in self.single_blocks: + block.set_attn_mode(new_mode) + + def enable_deterministic(self): + for block in self.triple_blocks: + block.enable_deterministic() + for block in self.single_blocks: + block.enable_deterministic() + + def disable_deterministic(self): + for block in self.triple_blocks: + block.disable_deterministic() + for block in self.single_blocks: + block.disable_deterministic() + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, # Should be in range(0, 1000). + clip_feat: Optional[torch.Tensor] = None, + cond: torch.Tensor = None, + audio_mask: Optional[torch.Tensor] = None, + cond_mask: torch.Tensor = None, + sync_feat: Optional[torch.Tensor] = None, + drop_visual: Optional[List[bool]] = None, + return_dict: bool = True, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + out = {} + audio = x + bs, _, ol = x.shape + tl = ol // self.patch_size + + # Prepare learnable empty conditions for visual condition + if drop_visual is not None: + clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype) + sync_feat[drop_visual] = self.get_empty_sync_sequence().to(dtype=sync_feat.dtype) + + # ========================= Prepare time & visual modulation ========================= + vec = self.time_in(t) + sync_vec = None + if self.sync_modulation: + assert sync_feat is not None and sync_feat.shape[1] % 8 == 0 + sync_feat = sync_feat.view(bs, int(sync_feat.shape[1] / 8), 8, self.sync_feat_dim) + self.sync_pos_emb + sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim) # bs, num_segments * 8, channels + sync_vec = self.sync_in(sync_feat) # bs, num_segments * 8, c + sync_vec = ( + F.interpolate(sync_vec.transpose(1, 2), size=(tl), mode="nearest-exact").contiguous().transpose(1, 2) + ) # bs, tl, c + sync_vec = sync_vec + vec.unsqueeze(1) + elif self.add_sync_feat_to_audio: + assert sync_feat is not None and sync_feat.shape[1] % 8 == 0 + sync_feat = sync_feat.view(bs, sync_feat.shape[1] // 8, 8, self.sync_feat_dim) + self.sync_pos_emb + sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim) # bs, num_segments * 8, channels + sync_feat = self.sync_in(sync_feat) # bs, num_segments * 8, c + add_sync_feat_to_audio = ( + F.interpolate(sync_feat.transpose(1, 2), size=(tl), mode="nearest-exact").contiguous().transpose(1, 2) + ) # bs, tl, c + + # ========================= Get text, audio and video clip embedding ========================= + cond = self.cond_in(cond) + cond_seq_len = cond.shape[1] + + audio = self.audio_embedder(x) + audio_seq_len = audio.shape[1] + v_cond = self.visual_proj(clip_feat) + v_cond_seq_len = v_cond.shape[1] + + # ========================= Compute attention mask ========================= + attn_mask = None + if self.use_attention_mask: + assert cond_mask is not None + batch_size = audio.shape[0] + seq_len = cond_seq_len + v_cond_seq_len + audio_seq_len + + # get default audio_mask and v_cond_mask + audio_mask = torch.ones((batch_size, audio_seq_len), dtype=torch.bool, device=audio.device) + v_cond_mask = torch.ones((batch_size, v_cond_seq_len), dtype=torch.bool, device=audio.device) + + # batch_size x seq_len + concat_mask = torch.cat([cond_mask, v_cond_mask, audio_mask], dim=1) + # batch_size x 1 x seq_len x seq_len + attn_mask_1 = concat_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) + # batch_size x 1 x seq_len x seq_len + attn_mask_2 = attn_mask_1.transpose(2, 3) + # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of num_heads + attn_mask = (attn_mask_1 & 
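+                # NOTE: worked sketch of this mask, assuming one sample whose last
+                # text token is padding: concat_mask = [[1, 1, ..., 0]];
+                # attn_mask_1 repeats it along the query axis, attn_mask_2 along the
+                # key axis, so their AND keeps (i, j) only when both token i and
+                # token j are valid.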
                attn_mask_2).bool()
+            # avoid self-attention weights becoming NaN for text padding tokens
+            attn_mask[:, :, :, 0] = True
+
+        # ========================= Build RoPE for audio and clip tokens =========================
+        if self.interleaved_audio_visual_rope:
+            freqs_cos, freqs_sin = self.build_rope_for_interleaved_audio_visual(audio_seq_len * 2)
+            v_freqs_cos = v_freqs_sin = None
+        else:
+            freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin = self.build_rope_for_audio_visual(
+                audio_seq_len, v_cond_seq_len
+            )
+
+        # ========================= Pass through DiT blocks =========================
+        freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
+        v_freqs_cis = (v_freqs_cos, v_freqs_sin) if v_freqs_cos is not None else None
+
+        if self.add_sync_feat_to_audio:
+            add_sync_layer = 0
+            assert (
+                add_sync_layer < self.depth_triple_blocks
+            ), f"The layer that adds the sync feature to the audio stream should be in the triple-stream blocks (n: {self.depth_triple_blocks})."
+        # Triple-stream blocks
+        for layer_num, block in enumerate(self.triple_blocks):
+            if self.add_sync_feat_to_audio and layer_num == add_sync_layer:
+                audio = audio + add_sync_feat_to_audio
+            triple_block_args = [audio, cond, v_cond, attn_mask, vec, freqs_cis, v_freqs_cis, sync_vec]
+            if (
+                self.training
+                and self.gradient_checkpoint
+                and (self.gradient_checkpoint_layers == -1 or layer_num < self.gradient_checkpoint_layers)
+            ):
+                audio, cond, v_cond = torch.utils.checkpoint.checkpoint(
+                    ckpt_wrapper(block), *triple_block_args, use_reentrant=False
+                )
+            else:
+                audio, cond, v_cond = block(*triple_block_args)
+
+        x = audio
+        if sync_vec is not None:
+            vec = vec.unsqueeze(1).repeat(1, cond_seq_len + v_cond_seq_len, 1)
+            vec = torch.cat((vec, sync_vec), dim=1)
+
+        freqs_cos, freqs_sin, _, _ = self.build_rope_for_audio_visual(audio_seq_len, v_cond_seq_len)
+        if self.add_sync_feat_to_audio:
+            vec = add_sync_feat_to_audio + vec.unsqueeze(dim=1)
+        if len(self.single_blocks) > 0:
+            for layer_num, block in enumerate(self.single_blocks):
+                single_block_args = [
+                    x,
+                    vec,
+                    (freqs_cos, freqs_sin),
+                ]
+                if (
+                    self.training
+                    and self.gradient_checkpoint
+                    and (
+                        self.gradient_checkpoint_layers == -1
+                        or layer_num + len(self.triple_blocks) < self.gradient_checkpoint_layers
+                    )
+                ):
+                    x = torch.utils.checkpoint.checkpoint(ckpt_wrapper(block), *single_block_args, use_reentrant=False)
+                else:
+                    x = block(*single_block_args)
+
+        audio = x
+
+        # ========================= Final layer =========================
+        if sync_vec is not None:
+            vec = sync_vec
+        audio = self.final_layer(audio, vec)  # (N, T, patch_size * out_channels)
+        audio = self.unpatchify1d(audio, tl)
+
+        if return_dict:
+            out["x"] = audio
+            return out
+        return audio
+
+    def unpatchify1d(self, x, l):
+        # x: (N, L, patch_size * C)
+        # audio: (N, C, T), T == L * patch_size
+        c = self.unpatchify_channels
+        p = self.patch_size
+        assert l == x.shape[1]
+
+        x = x.reshape(shape=(x.shape[0], l, p, c))
+        x = torch.einsum("ntpc->nctp", x)
+        audio = x.reshape(shape=(x.shape[0], c, l * p))
+        return audio
+
+    def params_count(self):
+        counts = {
+            "triple": sum(
+                [
+                    sum(p.numel() for p in block.audio_cross_q.parameters())
+                    + sum(p.numel() for p in block.v_cond_cross_q.parameters())
+                    + sum(p.numel() for p in block.text_cross_kv.parameters())
+                    + sum(p.numel() for p in block.audio_self_attn_qkv.parameters())
+                    + sum(p.numel() for p in block.v_cond_attn_qkv.parameters())
+                    + sum(p.numel() for p in block.audio_mlp.parameters())
+                    + sum(p.numel() for p in
block.audio_self_proj.parameters()) + + sum(p.numel() for p in block.v_cond_self_proj.parameters()) + + sum(p.numel() for p in block.v_cond_mlp.parameters()) + for block in self.triple_blocks + ] + ), + "single": sum( + [ + sum(p.numel() for p in block.linear1.parameters()) + + sum(p.numel() for p in block.linear2.parameters()) + for block in self.single_blocks + ] + ), + "total": sum(p.numel() for p in self.parameters()), + } + + counts["attn+mlp"] = counts["triple"] + counts["single"] + return counts diff --git a/hunyuanvideo_foley/models/nn/__init__.py b/hunyuanvideo_foley/models/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hunyuanvideo_foley/models/nn/__pycache__/__init__.cpython-312.pyc b/hunyuanvideo_foley/models/nn/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8417010095d79ce34302eb9fff5e9e212581ebd Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/__init__.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/nn/__pycache__/__init__.cpython-313.pyc b/hunyuanvideo_foley/models/nn/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..800467b3f6df786612b9fd6939ecb22c6826b82b Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/__init__.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/nn/__pycache__/activation_layers.cpython-312.pyc b/hunyuanvideo_foley/models/nn/__pycache__/activation_layers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eed57fb2c6cfbdf03a5d61db7fc7aa10088de804 Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/activation_layers.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/nn/__pycache__/activation_layers.cpython-313.pyc b/hunyuanvideo_foley/models/nn/__pycache__/activation_layers.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a893c3a1f805493fb00bdf772babc4c6e4e5158c Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/activation_layers.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/nn/__pycache__/attn_layers.cpython-312.pyc b/hunyuanvideo_foley/models/nn/__pycache__/attn_layers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1b306642dd45a16ea9844856d4f9afa09dfaa1c Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/attn_layers.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/nn/__pycache__/attn_layers.cpython-313.pyc b/hunyuanvideo_foley/models/nn/__pycache__/attn_layers.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09c0305aab4fc4ab4a8996a47fe8adfad2fce8af Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/attn_layers.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/nn/__pycache__/embed_layers.cpython-312.pyc b/hunyuanvideo_foley/models/nn/__pycache__/embed_layers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bdb875572318d0e47775f39fe919fd54b0c5345 Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/embed_layers.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/nn/__pycache__/embed_layers.cpython-313.pyc b/hunyuanvideo_foley/models/nn/__pycache__/embed_layers.cpython-313.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..3c0b62dc8897c0f45a0109da50879618f15cae1a Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/embed_layers.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/nn/__pycache__/mlp_layers.cpython-312.pyc b/hunyuanvideo_foley/models/nn/__pycache__/mlp_layers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbc6e29d5f454f5db631ed2151c7c27cd07da296 Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/mlp_layers.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/nn/__pycache__/mlp_layers.cpython-313.pyc b/hunyuanvideo_foley/models/nn/__pycache__/mlp_layers.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e3c84b3ff0b3ef24802bbd48df8f460a3f48735 Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/mlp_layers.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/nn/__pycache__/modulate_layers.cpython-312.pyc b/hunyuanvideo_foley/models/nn/__pycache__/modulate_layers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..749d2ee4d023a9c8b9738c5dfd42930974b42398 Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/modulate_layers.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/nn/__pycache__/modulate_layers.cpython-313.pyc b/hunyuanvideo_foley/models/nn/__pycache__/modulate_layers.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d744d5a5fb5e12428be5598e366ab9b80e83e446 Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/modulate_layers.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/nn/__pycache__/norm_layers.cpython-312.pyc b/hunyuanvideo_foley/models/nn/__pycache__/norm_layers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e51f79ea16c268ee6ecf4af4dd88234ff840bf9a Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/norm_layers.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/nn/__pycache__/norm_layers.cpython-313.pyc b/hunyuanvideo_foley/models/nn/__pycache__/norm_layers.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6a3e2e5b81e5168797f465940e0f972038d25c7 Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/norm_layers.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/nn/__pycache__/posemb_layers.cpython-312.pyc b/hunyuanvideo_foley/models/nn/__pycache__/posemb_layers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17aa3dfe33c273d7c84ed4cc34806b26b24c3e9d Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/posemb_layers.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/nn/__pycache__/posemb_layers.cpython-313.pyc b/hunyuanvideo_foley/models/nn/__pycache__/posemb_layers.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fd9c2ca472d6f967e3c5bbc6e1f936c283f1548 Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/posemb_layers.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/nn/activation_layers.py b/hunyuanvideo_foley/models/nn/activation_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..55414cbd054546263e1217f363e9fe02e846a122 --- /dev/null +++ b/hunyuanvideo_foley/models/nn/activation_layers.py @@ -0,0 +1,44 @@ +import torch.nn as nn +import torch.nn.functional as F + +def 
get_activation_layer(act_type):
+    if act_type == "gelu":
+        return lambda: nn.GELU()
+    elif act_type == "gelu_tanh":
+        # Approximate `tanh` requires torch >= 1.13
+        return lambda: nn.GELU(approximate="tanh")
+    elif act_type == "relu":
+        return nn.ReLU
+    elif act_type == "silu":
+        return nn.SiLU
+    else:
+        raise ValueError(f"Unknown activation type: {act_type}")
+
+class SwiGLU(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        out_dim: int,
+    ):
+        """
+        Initialize the SwiGLU FeedForward module.
+
+        Args:
+            dim (int): Input dimension.
+            hidden_dim (int): Hidden dimension of the feedforward layer.
+            out_dim (int): Output dimension of the feedforward layer.
+
+        Attributes:
+            w1: Linear transformation for the first layer.
+            w2: Linear transformation for the second layer.
+            w3: Linear transformation for the third layer.
+        """
+        super().__init__()
+
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, out_dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
diff --git a/hunyuanvideo_foley/models/nn/attn_layers.py b/hunyuanvideo_foley/models/nn/attn_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..c954eebb33038369f3c7d71cda4fdd9c3a8d27dd
--- /dev/null
+++ b/hunyuanvideo_foley/models/nn/attn_layers.py
@@ -0,0 +1,546 @@
+import importlib.metadata
+import math
+from typing import Tuple, Union
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+try:
+    from flash_attn import (
+        flash_attn_qkvpacked_func,
+        flash_attn_kvpacked_func,
+        flash_attn_varlen_kvpacked_func,
+        flash_attn_varlen_qkvpacked_func,
+    )
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
+except ImportError:
+    flash_attn_qkvpacked_func = flash_attn_kvpacked_func = flash_attn_varlen_kvpacked_func = None
+    flash_attn_varlen_qkvpacked_func = index_first_axis = pad_input = unpad_input = None
from packaging import version
+from transformers.utils.import_utils import _is_package_available
+
+from .norm_layers import get_norm_layer
+
+def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
+    """
+    Reshape frequency tensor for broadcasting it with another tensor.
+
+    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
+    for the purpose of broadcasting the frequency tensor during element-wise operations.
+
+    Notes:
+        When using FlashMHAModified, head_first should be False.
+        When using Attention, head_first should be True.
+
+    Args:
+        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
+        x (torch.Tensor): Target tensor for broadcasting compatibility.
+        head_first (bool): head dimension first (except batch dim) or not.
+
+    Returns:
+        torch.Tensor: Reshaped frequency tensor.
+
+    Raises:
+        AssertionError: If the frequency tensor doesn't match the expected shape.
+        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
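+
+    Shape sketch (an added note, assuming head_first=False): for x of shape
+    [B, S, H, D] and freqs_cis = (cos, sin) each of shape [S, D], the returned
+    cos/sin tensors are viewed as [1, S, 1, D] so they broadcast across the
+    batch and head dimensions.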
+ """ + ndim = x.ndim + assert 0 <= 1 < ndim + + if isinstance(freqs_cis, tuple): + # freqs_cis: (cos, sin) in real space + if head_first: + assert freqs_cis[0].shape == ( + x.shape[-2], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}" + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + else: + assert freqs_cis[0].shape == ( + x.shape[1], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}" + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) + else: + # freqs_cis: values in complex space + if head_first: + assert freqs_cis.shape == ( + x.shape[-2], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}" + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + else: + assert freqs_cis.shape == ( + x.shape[1], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}" + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def rotate_half(x): + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + head_first: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D] + xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D] + freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential. + head_first (bool): head dimension first (except batch dim) or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
+ + """ + xk_out = None + if isinstance(freqs_cis, tuple): + cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D] + cos, sin = cos.to(xq.device), sin.to(xq.device) + # real * cos - imag * sin + # imag * cos + real * sin + xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq) + xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk) + else: + # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex) + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2] + freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2] + # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin) + # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2] + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk) + + return xq_out, xk_out + + +class BasicAttentionLayer(nn.Module): + def __init__(self, attn_mode="flash", deterministic=False): + super().__init__() + self.attn_mode = attn_mode + self.deterministic = deterministic + + def set_attn_mode(self, new_mode): + self.attn_mode = new_mode + + def enable_deterministic(self): + self.deterministic = True + + def disable_deterministic(self): + self.deterministic = False + + +MEMORY_LAYOUT = { + "self_flash": ( + lambda x: x, + lambda x: x, + ), + "cross_flash": ( + lambda x: x, + lambda x: x, + ), + "flash_torch_sp": ( + lambda x: x, + lambda x: x, + ), + "torch": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), + "vanilla": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), +} + + +# Copyed from https://github.com/huggingface/transformers/blob/b873234cb649a24865021f0d598627ce2b24d34a/src/transformers/modeling_flash_attention_utils.py#L33C1-L57C6 +def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]: + """ + Retrieves indexing data required to repad unpadded (ragged) tensors. + + Arguments: + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + indices (`torch.Tensor): + The indices of non-masked tokens from the flattened input sequence. + cu_seqlens (`torch.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + max_seqlen_in_batch (`int`): + Maximum sequence length in batch. 
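+
+    Worked example: for attention_mask = [[1, 1, 0], [1, 1, 1]], the sequence
+    lengths are [2, 3], so indices = [0, 1, 3, 4, 5], cu_seqlens = [0, 2, 5],
+    and max_seqlen_in_batch = 3.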
+ """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copyed from https://github.com/huggingface/transformers/blob/b873234cb649a24865021f0d598627ce2b24d34a/src/transformers/utils/import_utils.py#L822 +def is_flash_attn_greater_or_equal(library_version: str): + if not _is_package_available("flash_attn"): + return False + + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version) + + +def get_kv_seqlens_with_mask(attn_mask, k, v): + indices_k, cu_seqlens_k, max_seqlen_k = _get_unpad_data(attn_mask) + b, s1, a, d = k.shape + k = index_first_axis(k.reshape(b * s1, a, d), indices_k) + v = index_first_axis(v.reshape(b * s1, a, d), indices_k) + kv = torch.stack([k, v], dim=1) + return cu_seqlens_k, max_seqlen_k, kv + + +def get_q_seqlens(q): + bs, s, a, d = q.shape + cu_seqlens_q = torch.arange(0, (bs + 1) * s, step=s, dtype=torch.int32, device=q.device) + q = q.reshape(bs * s, a, d) + return cu_seqlens_q, s, q + +def flash_attn_no_pad( + qkv, key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None +): + # adapted from https://github.com/Dao-AILab/flash-attention/blob/13403e81157ba37ca525890f2f0f2137edf75311/flash_attn/flash_attention.py#L27 + batch_size = qkv.shape[0] + seqlen = qkv.shape[1] + nheads = qkv.shape[-2] + x = rearrange(qkv, "b s three h d -> b s (three h d)") + # x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch + # x_unpad, indices, cu_seqlens, max_s + unpad_results = unpad_input( + x, key_padding_mask + ) + + if len(unpad_results) == 4: + x_unpad, indices, cu_seqlens, max_s = unpad_results + elif len(unpad_results) == 5: + x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch = unpad_results + else: + raise ValueError + + x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads) + output_unpad = flash_attn_varlen_qkvpacked_func( + x_unpad, + cu_seqlens, + max_s, + dropout_p, + softmax_scale=softmax_scale, + causal=causal, + ) + output = rearrange( + pad_input( + rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen + ), + "b s (h d) -> b s h d", + h=nheads, + ) + return output + + +def attention( + q, + k, + v, + mode, + drop_rate=0, + attn_mask=None, + cond_mask=None, + causal=False, + deterministic=False, + cu_seqlens=None, + max_seqlen=None, + cu_seqlens_k=None, + max_seqlen_k=None, + img_seq_len=None, +): + """ + Perform QKV self attention. + + Args: + q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads. + k (torch.Tensor): Key tensor with shape [b, s1, a, d] + v (torch.Tensor): Value tensor with shape [b, s1, a, d] + mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'. + drop_rate (float): Dropout rate in attention map. (default: 0) + attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla). + (default: None) + causal (bool): Whether to use causal attention. (default: False) + deterministic (bool): Whether to use deterministic attention. (default: False) + cu_seqlens (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into q. 
max_seqlen (int): The maximum sequence length in the batch of q.
+        cu_seqlens_k (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
+            used to index into kv.
+        max_seqlen_k (int): The maximum sequence length in the batch of k and v.
+
+    Returns:
+        torch.Tensor: Output tensor after self attention with shape [b, s, a*d]
+    """
+    if mode in ["torch", "vanilla", "self_flash", "cross_flash"]:
+        if isinstance(q, tuple):
+            q = torch.cat(q, dim=1)
+        if isinstance(k, tuple):
+            k = torch.cat(k, dim=1)
+        if isinstance(v, tuple):
+            v = torch.cat(v, dim=1)
+        pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
+        q = pre_attn_layout(q)
+        k = pre_attn_layout(k)
+        v = pre_attn_layout(v)
+
+    if "flash" in mode:
+        assert (
+            flash_attn_qkvpacked_func is not None
+        ), "Flash attention is not available. Please install flash_attn first."
+        flash_kwargs = dict(dropout_p=drop_rate, causal=causal)
+        if deterministic:
+            if not is_flash_attn_greater_or_equal("2.4.1"):
+                raise ValueError(
+                    "Flash attention deterministic mode requires flash_attn>=2.4.1. Please upgrade flash_attn."
+                )
+            flash_kwargs["deterministic"] = deterministic
+
+        if mode == "self_flash":
+            qkv = torch.stack([q, k, v], dim=2)
+            if attn_mask is not None:
+                raise ValueError("Self attention does not support attention mask")
+            x = flash_attn_qkvpacked_func(qkv, **flash_kwargs)
+
+        elif mode == "cross_flash":
+            kv = torch.stack([k, v], dim=2)
+            if attn_mask is None:
+                x = flash_attn_kvpacked_func(q, kv, **flash_kwargs)
+            else:
+                b, s, a, h = q.shape
+                cu_seqlens_q, max_seqlen_q, q = get_q_seqlens(q)
+                cu_seqlens_k, max_seqlen_k, kv = get_kv_seqlens_with_mask(attn_mask, k, v)
+
+                attn_output = flash_attn_varlen_kvpacked_func(
+                    q,
+                    kv,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
+                    max_seqlen_q=max_seqlen_q,
+                    max_seqlen_k=max_seqlen_k,
+                    **flash_kwargs,
+                )
+                x = attn_output.reshape(b, s, a, h)
+    elif mode == "torch":
+        if attn_mask is not None and attn_mask.dtype != torch.bool:
+            attn_mask = attn_mask.to(q.dtype)
+        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
+
+    elif mode == "vanilla":
+        scale_factor = 1 / math.sqrt(q.size(-1))
+
+        b, a, s, _ = q.shape
+        s1 = k.size(2)
+        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
+        if causal:
+            # Only applied to self attention
+            assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
+            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
+            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+            attn_bias = attn_bias.to(q.dtype)
+
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attn_mask
+
+        # TODO(jarvizhang): Maybe force q and k to be float32 to avoid numerical overflow
+        attn = (q @ k.transpose(-2, -1)) * scale_factor
+        attn += attn_bias
+        attn = attn.softmax(dim=-1)
+        attn = torch.dropout(attn, p=drop_rate, train=True)
+        x = attn @ v
+    else:
+        raise NotImplementedError(f"Unsupported attention mode: {mode}")
+
+    if mode in ["torch", "vanilla", "self_flash", "cross_flash"]:
+        x = post_attn_layout(x).contiguous()
+    b, s, a, d = x.shape
+    out = x.reshape(b, s, -1)
+    return out
+
+
+class SelfAttentionLayer(BasicAttentionLayer):
+    def __init__(
+        self,
+        dim,
+        num_heads,
+        qkv_bias=True,
+        qk_norm=True,
+        attn_drop=0,
+        proj_drop=0,
+        dtype=None,
+        device=None,
+        norm_type="layer",
+        attn_mode="self_flash",
+
deterministic=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(attn_mode, deterministic) + self.dim = dim + self.num_heads = num_heads + assert self.dim % num_heads == 0, "dim must be divisible by num_heads" + self.head_dim = self.dim // num_heads + self.attn_drop = attn_drop + + # This assertion is aligned with flash attention + assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8" + + self.Wqkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **factory_kwargs) + + norm_layer = get_norm_layer(norm_type) + self.q_norm = ( + norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.k_norm = ( + norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + + self.out_proj = nn.Linear(dim, dim, bias=qkv_bias, **factory_kwargs) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, freqs_cis=None, attn_mask=None): + """ + Args: + x (torch.Tensor): (batch, seq_len, hidden_dim) (where hidden_dim = num heads * head dim) + freqs_cis (torch.Tensor, optional): (batch, hidden_dim // 2), RoPE for image + attn_mask (torch.Tensor, optional): (batch, seq_len, seq_len), mask for attention + """ + b, s, d = x.shape + + # Apply QKV projection + qkv = self.Wqkv(x) + qkv = qkv.view(b, s, 3, self.num_heads, self.head_dim) # [b, s, 3, a, d] + q, k, v = qkv.unbind(dim=2) # [b, s, a, d] + + # Apply QK-Norm if needed + q = self.q_norm(q) + k = self.k_norm(k) + + # Apply RoPE if needed + if freqs_cis is not None: + qq, kk = apply_rotary_emb(q, k, freqs_cis) + assert ( + qq.shape == q.shape and kk.shape == k.shape + ), f"qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}" + q, k = qq, kk + + # Apply self attention + context = attention( + q, + k, + v, + drop_rate=self.attn_drop if self.training else 0, + attn_mask=attn_mask, + mode=self.attn_mode, + deterministic=self.deterministic, + ) + out = self.out_proj(context) + out = self.proj_drop(out) + + return out + + +class CrossAttentionLayer(BasicAttentionLayer): + def __init__( + self, + qdim, + kdim, + num_heads, + qkv_bias=True, + qk_norm=True, + attn_drop=0, + proj_drop=0, + dtype=None, + device=None, + norm_type="layer", + attn_mode="cross_flash", + deterministic=False, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(attn_mode, deterministic) + self.qdim = qdim + self.kdim = kdim + self.num_heads = num_heads + assert self.qdim % num_heads == 0, "qdim must be divisible by num_heads" + self.head_dim = self.qdim // num_heads + self.attn_drop = attn_drop + + # This assertion is aligned with flash attention + assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8" + + self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs) + self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs) + + norm_layer = get_norm_layer(norm_type) + self.q_norm = ( + norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.k_norm = ( + norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + + self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, y, attn_mask=None): + """ + Args: + x (torch.Tensor): (batch, seq_len, hidden_dim) (where hidden_dim = num heads * head 
dim)
+            y (torch.Tensor): (batch, seq_len1, hidden_dim1)
+            attn_mask (torch.Tensor): (batch, seq_len1), mask for attention
+        """
+        b, s, d = x.shape
+        _, s1, d1 = y.shape
+
+        q = self.q_proj(x).view(b, s, self.num_heads, self.head_dim)
+        kv = self.kv_proj(y).view(b, s1, 2, self.num_heads, self.head_dim)
+        k, v = kv.unbind(dim=2)
+
+        # Apply QK-Norm if needed
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+
+        # Apply cross attention
+        context = attention(
+            q,
+            k,
+            v,
+            attn_mask=attn_mask,
+            drop_rate=self.attn_drop if self.training else 0,
+            mode=self.attn_mode,
+            deterministic=self.deterministic,
+        )
+        out = self.out_proj(context)
+        out = self.proj_drop(out)
+
+        return out
diff --git a/hunyuanvideo_foley/models/nn/embed_layers.py b/hunyuanvideo_foley/models/nn/embed_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd15167836fd5a5aec7b0c21af296082d7d1b2f4
--- /dev/null
+++ b/hunyuanvideo_foley/models/nn/embed_layers.py
@@ -0,0 +1,136 @@
+import math
+import torch
+import torch.nn as nn
+
+from ...utils.helper import to_2tuple, to_1tuple
+
+class PatchEmbed1D(nn.Module):
+    """1D Audio to Patch Embedding
+
+    A convolution-based approach to patchifying 1D audio with embedding projection.
+
+    Based on the impl in https://github.com/google-research/vision_transformer
+
+    Hacked together by / Copyright 2020 Ross Wightman
+    """
+
+    def __init__(
+        self,
+        patch_size=1,
+        in_chans=768,
+        embed_dim=768,
+        norm_layer=None,
+        flatten=True,
+        bias=True,
+        dtype=None,
+        device=None,
+    ):
+        factory_kwargs = {"dtype": dtype, "device": device}
+        super().__init__()
+        patch_size = to_1tuple(patch_size)
+        self.patch_size = patch_size
+        self.flatten = flatten
+
+        self.proj = nn.Conv1d(
+            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs
+        )
+        nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
+        if bias:
+            nn.init.zeros_(self.proj.bias)
+
+        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+    def forward(self, x):
+        assert (
+            x.shape[2] % self.patch_size[0] == 0
+        ), f"The token number ({x.shape[2]}) of x must be divisible by the patch size ({self.patch_size[0]})."
+
+        x = self.proj(x)
+        if self.flatten:
+            x = x.transpose(1, 2)  # BCN -> BNC
+        x = self.norm(x)
+        return x
+
+
+class ConditionProjection(nn.Module):
+    """
+    Projects condition embeddings. Dropout for classifier-free guidance is handled by the caller.
+
+    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
+    """
+
+    def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
+        factory_kwargs = {"dtype": dtype, "device": device}
+        super().__init__()
+        self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
+        self.act_1 = act_layer()
+        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)
+
+    def forward(self, caption):
+        hidden_states = self.linear_1(caption)
+        hidden_states = self.act_1(hidden_states)
+        hidden_states = self.linear_2(hidden_states)
+        return hidden_states
+
+
+def timestep_embedding(t, dim, max_period=10000):
+    """
+    Create sinusoidal timestep embeddings.
+
+    Args:
+        t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
+        dim (int): the dimension of the output.
+        max_period (int): controls the minimum frequency of the embeddings.
+
+    Returns:
+        embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
+
+    ..
ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, + hidden_size, + act_layer, + frequency_embedding_size=256, + max_period=10000, + out_size=None, + dtype=None, + device=None + ): + factory_kwargs = {'dtype': dtype, 'device': device} + super().__init__() + self.frequency_embedding_size = frequency_embedding_size + self.max_period = max_period + if out_size is None: + out_size = hidden_size + + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs), + act_layer(), + nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), + ) + nn.init.normal_(self.mlp[0].weight, std=0.02) + nn.init.normal_(self.mlp[2].weight, std=0.02) + + def forward(self, t): + t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype) + t_emb = self.mlp(t_freq) + return t_emb diff --git a/hunyuanvideo_foley/models/nn/mlp_layers.py b/hunyuanvideo_foley/models/nn/mlp_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..0434559025e280b74fbd7c181247aa1c2a41a409 --- /dev/null +++ b/hunyuanvideo_foley/models/nn/mlp_layers.py @@ -0,0 +1,149 @@ +# Modified from timm library: +# https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13 + +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .modulate_layers import modulate +from ...utils.helper import to_2tuple + +class MLP(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_channels, + hidden_channels=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + out_features = out_features or in_channels + hidden_channels = hidden_channels or in_channels + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity() + self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +# copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py +# only used when use_vanilla is True +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.in_layer = nn.Linear(in_dim, 
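+            # NOTE: usage sketch for `timestep_embedding` defined above, with a
+            # hypothetical batch of three diffusion timesteps:
+            #   t = torch.tensor([0.0, 250.0, 999.0])
+            #   emb = timestep_embedding(t, dim=256)  # -> (3, 256): cos half, then sin half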
hidden_dim, bias=True, **factory_kwargs) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class LinearWarpforSingle(nn.Module): + def __init__(self, in_dim: int, out_dim: int, bias=True, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.fc = nn.Linear(in_dim, out_dim, bias=bias, **factory_kwargs) + + def forward(self, x, y): + z = torch.cat([x, y], dim=2) + return self.fc(z) + +class FinalLayer1D(nn.Module): + def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + # Just use LayerNorm for the final layer + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.linear = nn.Linear(hidden_size, patch_size * out_channels, bias=True, **factory_kwargs) + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + # Here we don't distinguish between the modulate types. Just use the simple one. + self.adaLN_modulation = nn.Sequential( + act_layer(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs) + ) + # Zero-initialize the modulation + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift=shift, scale=scale) + x = self.linear(x) + return x + + +class ChannelLastConv1d(nn.Conv1d): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(0, 2, 1) + x = super().forward(x) + x = x.permute(0, 2, 1) + return x + + +class ConvMLP(nn.Module): + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + kernel_size: int = 3, + padding: int = 1, + device=None, + dtype=None, + ): + """ + Convolutional MLP module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + + Attributes: + w1: Linear transformation for the first layer. + w2: Linear transformation for the second layer. + w3: Linear transformation for the third layer. 
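+
+        Example (with the dim=3072, hidden_dim=4*3072=12288 instantiation used by
+        this model): the constructor below computes int(2 * 12288 / 3) = 8192,
+        already a multiple of 256, as the width of the gated convolutions.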
+ + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs) + self.w2 = ChannelLastConv1d(hidden_dim, dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs) + self.w3 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) diff --git a/hunyuanvideo_foley/models/nn/modulate_layers.py b/hunyuanvideo_foley/models/nn/modulate_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..5235233e996fc17c0476de04428a13bfcf3ba8fe --- /dev/null +++ b/hunyuanvideo_foley/models/nn/modulate_layers.py @@ -0,0 +1,49 @@ +from typing import Callable +import torch +import torch.nn as nn + +class ModulateDiT(nn.Module): + def __init__(self, hidden_size: int, factor: int, act_layer: Callable, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.act = act_layer() + self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs) + # Zero-initialize the modulation + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(self.act(x)) + + +def modulate(x, shift=None, scale=None): + if x.ndim == 3: + shift = shift.unsqueeze(1) if shift is not None and shift.ndim == 2 else None + scale = scale.unsqueeze(1) if scale is not None and scale.ndim == 2 else None + if scale is None and shift is None: + return x + elif shift is None: + return x * (1 + scale) + elif scale is None: + return x + shift + else: + return x * (1 + scale) + shift + + +def apply_gate(x, gate=None, tanh=False): + if gate is None: + return x + if gate.ndim == 2 and x.ndim == 3: + gate = gate.unsqueeze(1) + if tanh: + return x * gate.tanh() + else: + return x * gate + + +def ckpt_wrapper(module): + def ckpt_forward(*inputs): + outputs = module(*inputs) + return outputs + + return ckpt_forward diff --git a/hunyuanvideo_foley/models/nn/norm_layers.py b/hunyuanvideo_foley/models/nn/norm_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..7ad30b0ea4faeaa18e22ed25fcc44b97aee2d243 --- /dev/null +++ b/hunyuanvideo_foley/models/nn/norm_layers.py @@ -0,0 +1,70 @@ +import torch +import torch.nn as nn + +class RMSNorm(nn.Module): + def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6, + device=None, dtype=None): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. 
+ + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + + +def get_norm_layer(norm_layer): + """ + Get the normalization layer. + + Args: + norm_layer (str): The type of normalization layer. + + Returns: + norm_layer (nn.Module): The normalization layer. + """ + if norm_layer == "layer": + return nn.LayerNorm + elif norm_layer == "rms": + return RMSNorm + else: + raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") diff --git a/hunyuanvideo_foley/models/nn/posemb_layers.py b/hunyuanvideo_foley/models/nn/posemb_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..dbd188bbeb710d8155e758590446a51f5f0dd038 --- /dev/null +++ b/hunyuanvideo_foley/models/nn/posemb_layers.py @@ -0,0 +1,159 @@ +import torch +from typing import Union, Tuple + + +def _to_tuple(x, dim=2): + if isinstance(x, int): + return (x,) * dim + elif len(x) == dim: + return x + else: + raise ValueError(f"Expected length {dim} or int, but got {x}") + + +def get_meshgrid_nd(start, *args, dim=2): + """ + Get n-D meshgrid with start, stop and num. + + Args: + start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, + step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num + should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in + n-tuples. + *args: See above. + dim (int): Dimension of the meshgrid. Defaults to 2. + + Returns: + grid (np.ndarray): [dim, ...] + """ + if len(args) == 0: + # start is grid_size + num = _to_tuple(start, dim=dim) + start = (0,) * dim + stop = num + elif len(args) == 1: + # start is start, args[0] is stop, step is 1 + start = _to_tuple(start, dim=dim) + stop = _to_tuple(args[0], dim=dim) + num = [stop[i] - start[i] for i in range(dim)] + elif len(args) == 2: + # start is start, args[0] is stop, args[1] is num + start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0 + stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32 + num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124 + else: + raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") + + # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False) + axis_grid = [] + for i in range(dim): + a, b, n = start[i], stop[i], num[i] + g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] + axis_grid.append(g) + grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D] + grid = torch.stack(grid, dim=0) # [dim, W, H, D] + + return grid + + +################################################################################# +# Rotary Positional Embedding Functions # +################################################################################# +# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80 + + +def get_nd_rotary_pos_embed( + rope_dim_list, start, *args, theta=10000.0, use_real=False, theta_rescale_factor=1.0, freq_scaling=1.0 +): + """ + This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure. + + Args: + rope_dim_list (list of int): Dimension of each rope. 
len(rope_dim_list) should equal to n. + sum(rope_dim_list) should equal to head_dim of attention layer. + start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start, + args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. + *args: See above. + theta (float): Scaling factor for frequency computation. Defaults to 10000.0. + use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers. + Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real + part and an imaginary part separately. + theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0. + freq_scaling (float, optional): Frequence rescale factor, which is proposed in mmaudio. Defaults to 1.0. + + Returns: + pos_embed (torch.Tensor): [HW, D/2] + """ + + grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H] + + # use 1/ndim of dimensions to encode grid_axis + embs = [] + for i in range(len(rope_dim_list)): + emb = get_1d_rotary_pos_embed( + rope_dim_list[i], + grid[i].reshape(-1), + theta, + use_real=use_real, + theta_rescale_factor=theta_rescale_factor, + freq_scaling=freq_scaling, + ) # 2 x [WHD, rope_dim_list[i]] + embs.append(emb) + + if use_real: + cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2) + sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2) + return cos, sin + else: + emb = torch.cat(embs, dim=1) # (WHD, D/2) + return emb + + +def get_1d_rotary_pos_embed( + dim: int, + pos: Union[torch.FloatTensor, int], + theta: float = 10000.0, + use_real: bool = False, + theta_rescale_factor: float = 1.0, + freq_scaling: float = 1.0, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Precompute the frequency tensor for complex exponential (cis) with given dimensions. + (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.) + + This function calculates a frequency tensor with complex exponential using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + use_real (bool, optional): If True, return real part and imaginary part separately. + Otherwise, return complex numbers. + theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. + freq_scaling (float, optional): Frequence rescale factor, which is proposed in mmaudio. Defaults to 1.0. + + Returns: + freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2] + freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. 
[S, D] + """ + if isinstance(pos, int): + pos = torch.arange(pos).float() + + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + if theta_rescale_factor != 1.0: + theta *= theta_rescale_factor ** (dim / (dim - 1)) + + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2] + freqs *= freq_scaling + freqs = torch.outer(pos, freqs) # [S, D/2] + if use_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] + return freqs_cos, freqs_sin + else: + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] + return freqs_cis diff --git a/hunyuanvideo_foley/models/synchformer/__init__.py b/hunyuanvideo_foley/models/synchformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6331d467e30b9d24378025bf540e7430ff4fd7ad --- /dev/null +++ b/hunyuanvideo_foley/models/synchformer/__init__.py @@ -0,0 +1 @@ +from .synchformer import Synchformer diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/__init__.cpython-312.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6bd2373949f7c75e11b201c7d55290646c009a4 Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/__init__.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/__init__.cpython-313.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67c8b840aecff1da6b19f8596177ba473b6a041f Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/__init__.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/ast_model.cpython-312.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/ast_model.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65d041544a06271d266d7f4e970d8ef4e156177b Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/ast_model.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/ast_model.cpython-313.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/ast_model.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92bf67679e34b8680152e2af4365c7d7ade7ebeb Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/ast_model.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/modeling_ast.cpython-312.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/modeling_ast.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff5ac7c20466981dfaec48b85a759a3c7bb7502f Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/modeling_ast.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/modeling_ast.cpython-313.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/modeling_ast.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..157972264fdc52bbbe76238af47d18c672b4581b Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/modeling_ast.cpython-313.pyc differ 
diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/motionformer.cpython-312.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/motionformer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f7d92890865df12a0201e9493dd93214ebd30d1 Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/motionformer.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/motionformer.cpython-313.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/motionformer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e44fd9e5a05ac0c8cba93a4700da7510b24a5a28 Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/motionformer.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/synchformer.cpython-312.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/synchformer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..796bea889521a94ceaca7536b94cdcb170534995 Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/synchformer.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/synchformer.cpython-313.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/synchformer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c48d32ad2f7984d0318be84f7c102bf8941d062 Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/synchformer.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/utils.cpython-312.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1abeff774ab246a1a6f6a5525fef4319b291c3c8 Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/utils.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/utils.cpython-313.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed1af3cab9fddcfda9cc52df57ec2a7be5169ce7 Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/utils.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/video_model_builder.cpython-312.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/video_model_builder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1e39b538331705e76298b67b2ed12db7e213108 Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/video_model_builder.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/video_model_builder.cpython-313.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/video_model_builder.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..384881a3362ceb67899ff3617b75b90579be0332 Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/video_model_builder.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/vit_helper.cpython-312.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/vit_helper.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6ef2042a3ddd8efa942e4ffd9d16c6dcc7ccd5c Binary files /dev/null and 
diff --git a/hunyuanvideo_foley/models/synchformer/ast_model.py b/hunyuanvideo_foley/models/synchformer/ast_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f4394306ccd08de3a1e6bb556df8f42d2e4cacb
--- /dev/null
+++ b/hunyuanvideo_foley/models/synchformer/ast_model.py
@@ -0,0 +1,289 @@
+import logging
+
+import torch
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+
+from .modeling_ast import ASTForAudioClassification, ASTConfig
+from .motionformer import AveragePooling, BaseEncoderLayer, TemporalTransformerEncoderLayer
+from .utils import check_if_file_exists_else_download
+
+
+class AST(torch.nn.Module):
+    def __init__(
+        self,
+        extract_features: bool = False,
+        ckpt_path: str = None,
+        feat_type: str = None,
+        max_spec_t: int = None,
+        factorize_freq_time: bool = None,
+        agg_freq_module: str = None,
+        agg_time_module: str = None,
+        add_global_repr: bool = True,
+        agg_segments_module: str = None,
+        max_segments: int = None,
+    ) -> None:
+        """
+        extract_features: if True, then the model will return the features instead of head's output
+        ckpt_path: is not a path to a ckpt file, but a name of a model from the HuggingFace model hub.
+        feat_type: if extract_features is True, this parameter specifies the type of features to return
+        max_spec_t: if specified, then the model (pos emb) will be patched to support this length of spec
+        factorize_freq_time: if True, then the model will use a factorized freq/time aggregation
+        agg_freq_module: if specified, then the model will use this module for freq aggregation
+        agg_time_module: if specified, then the model will use this module for time aggregation
+        add_global_repr: if True, adds a global representation to the features (aggregation on segments)
+        agg_segments_module: if specified, then the model will use this module for segments aggregation
+        max_segments: if specified, the initialization of PE in the global agg module will use this value.
+                      This should correspond to the max number of segments per video (if None, 16 is used)
+        """
+        super().__init__()
+        self.extract_features = extract_features
+        self.ckpt_path = ckpt_path
+        self.max_spec_t = max_spec_t
+        self.max_segments = max_segments
+
+        # depending on whether the feat extractor was pre-trained contrastively or not, we need to
+        # load the state dict differently.
+
+        # if ckpt is specified, then load the model from the HuggingFace model hub, otherwise init a new model
+        if ckpt_path == "MIT/ast-finetuned-audioset-10-10-0.4593":
+            revision = "c1c0c66"  # fixing the revision for compatibility (V4.27.4)
+            self.config = ASTConfig.from_pretrained(ckpt_path, revision=revision)
+            full_model = ASTForAudioClassification.from_pretrained(ckpt_path, revision=revision)
+            logging.info(f"Loaded AST from {ckpt_path}")
+        else:
+            self.config = ASTConfig()
+            self.config.num_labels = 527  # 2 by default, audioset has 527 labels
+            full_model = ASTForAudioClassification(self.config)
+            logging.info("Initialized AST from scratch with the AST AudioSet config")
+
+        was_pt_on_avclip = ckpt_path is not None and ckpt_path.endswith(".pt")
+
+        # feature extractor
+        self.ast = full_model.audio_spectrogram_transformer
+
+        if self.extract_features:
+            # assign `feat_type` (use default if not specified)
+            self.feat_type = "last_hidden_state" if feat_type is None else feat_type
+            # define adapters if needed
+            self.factorize_freq_time = factorize_freq_time
+            # avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer)
+            transf_enc_layer_kwargs = dict(
+                d_model=self.config.hidden_size,
+                nhead=self.config.num_attention_heads,
+                dim_feedforward=self.config.intermediate_size,
+                activation=torch.nn.GELU(),
+                batch_first=True,
+                dropout=self.config.attention_probs_dropout_prob,
+                layer_norm_eps=1e-6,
+                norm_first=True,
+            )
+            if factorize_freq_time:
+                self.feat_type = "last_hidden_state"  # this feat_type supports factorization
+                # frequency aggregation
+                if agg_freq_module == "TransformerEncoderLayer":
+                    self.freq_attn_agg = FrequencyTransformerEncoderLayer(**transf_enc_layer_kwargs)
+                elif agg_freq_module == "AveragePooling":
+                    self.freq_attn_agg = AveragePooling(
+                        avg_pattern="BS D f t -> BS D t", then_permute_pattern="BS D t -> BS t D"
+                    )
+                # time aggregation
+                if agg_time_module == "TransformerEncoderLayer":
+                    self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs)
+                elif agg_time_module == "AveragePooling":
+                    self.temp_attn_agg = AveragePooling(avg_pattern="BS t D -> BS D")
+                elif "Identity" in agg_time_module:
+                    self.temp_attn_agg = torch.nn.Identity()
+            # define a global aggregation layer (aggregate over segments)
+            self.add_global_repr = add_global_repr
+            if add_global_repr:
+                if agg_segments_module == "TransformerEncoderLayer":
+                    # we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D)
+                    # we need to add pos emb (PE) because previously we added the same PE for each segment
+                    pos_max_len = max_segments if max_segments is not None else 16  # 16 = 10sec//0.64sec + 1
+                    self.global_attn_agg = TemporalTransformerEncoderLayer(
+                        add_pos_emb=True,
+                        pos_emb_drop=self.config.hidden_dropout_prob,
+                        pos_max_len=pos_max_len,
+                        **transf_enc_layer_kwargs,
+                    )
+                elif agg_segments_module == "AveragePooling":
+                    self.global_attn_agg = AveragePooling(avg_pattern="B S D -> B D")
+        else:
+            self.classifier = full_model.classifier
+
+        # AST.device fails with AttributeError. This is a workaround
+        self.device = full_model.device
+
+        # pre-trained on 12*101+2=1214 tokens, but we have less (e.g. 12*6+2=74)
+        self.patch_position_emb()
+
+        if was_pt_on_avclip:
+            # we need to filter out the state_dict of the AVCLIP model (has both A and V extractors)
+            # and keep only the state_dict of the feat extractor
+            check_if_file_exists_else_download(self.ckpt_path)
+            ckpt = torch.load(ckpt_path, map_location="cpu")
+            ckpt_weights = dict()
+            for k, v in ckpt["state_dict"].items():
+                if k.startswith(("module.a_encoder.", "a_encoder.")):
+                    k = k.replace("module.", "").replace("a_encoder.", "")
+                    ckpt_weights[k] = v
+            _load_status = self.load_state_dict(ckpt_weights, strict=False)
+            if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0:
+                logging.warning(
+                    f"Loading exact afeat_extractor ckpt from {self.ckpt_path} failed. \n"
+                    f"Missing keys ({len(_load_status.missing_keys)}): "
+                    f"{_load_status.missing_keys}, \n"
+                    f"Unexpected keys ({len(_load_status.unexpected_keys)}): "
+                    f"{_load_status.unexpected_keys} \n"
+                    f"temp_attn_agg are expected to be missing if ckpt was pt contrastively."
+                )
+            else:
+                logging.info(f"Loading afeat_extractor ckpt from {self.ckpt_path} succeeded.")
+
+        # print the number of parameters
+        logging.info(f"AST: {sum(p.numel() for p in self.parameters() if p.requires_grad):,}")
+
+    def forward(
+        self, x: torch.Tensor, for_loop: bool = False, cont_mask: torch.Tensor = None, **ast_kwargs
+    ) -> torch.Tensor:
+        """
+        x: (B, S, T, F) where S is the number of segments and F is the number of (mel) frequency bins,
+        ast_kwargs: additional arguments for the AST model
+        cont_mask: (B, S, T, F) where 0s are the values to be masked out
+        if `for_loop=True`, we use a for loop to extract features for each segment separately.
+        if `for_loop=False`, we extract features for all segments at once.
+        Using the for loop is slower but more memory-efficient, while processing all segments at once
+        is faster but uses more memory.
+        The for loop also allows controlling the memory footprint by varying the number of videos in a
+        batch (batch size) rather than the number of segments in a video.
+        """
+        B, S, T, F = x.shape
+
+        if for_loop:
+            assert cont_mask is None, "cont_mask is not supported with for_loop=True"
+            orig_shape_s = (B, 1, T, F)
+            # NOTE: since x is (B, S, T, F), and forward_segments expects (BS, T, F).
+            # (B, S, T, F)[:, s] is (B, T, F) or (BS, T, F) if S=1.
+            x = torch.cat(
+                [self.forward_segments(x[:, s], orig_shape_s, **ast_kwargs).unsqueeze(1) for s in range(S)], dim=1
+            )
+        else:
+            orig_shape = (B, S, T, F)
+            x = x.view(B * S, T, F)
+            if cont_mask is not None:
+                cont_mask = cont_mask.reshape(B * S, T, F)
+            # AST expects a tensor of shape (B*S, T, F).
+            x = self.forward_segments(x, orig_shape=orig_shape, cont_mask=cont_mask, **ast_kwargs)
+            # unpack the segments (using rest dimensions to support different shapes e.g. (BS, D) or (BS, t, D))
+            x = x.view(B, S, *x.shape[1:])
+        # x now is of shape (B, S, D) or (B, S, t, D) if `self.temp_attn_agg` is `Identity`
+
+        global_x = None
+        if self.extract_features and self.add_global_repr:  # lazy execution, throws AttributeError
+            assert len(x.shape) == 3, f"Local representation should be (B, S, D) {x.shape}"
+            global_x = self.global_attn_agg(x)  # (B, D)
+
+        return x, global_x  # x is (B, S, ...), global_x is (B, D) or None
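+    # Shape walkthrough (illustrative values, not taken from the repo's configs): x of shape (2, 14, 66, 128)
+    # is B=2 videos, S=14 segments, T=66 spectrogram frames, F=128 mel bins; with factorized freq/time
+    # aggregation, forward returns local features (2, 14, D) and, if add_global_repr, global features (2, D).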
+    def forward_segments(self, x, orig_shape: tuple, cont_mask: torch.Tensor = None, **ast_kwargs):
+        """x is (BS, T, F), where S is the number of segments; cont_mask is (BS, T, F): 0s to be masked out"""
+        # 'pooler_output': (B, D); or 'last_hidden_state': (B, T, D) where T is [CLS, DISTILL, ]
+        # x_mask is (B, T) where 0s are the values to be masked out
+        x, x_mask = self.ast(x, cont_mask=cont_mask, **ast_kwargs)
+
+        if self.extract_features:
+            x = self.get_features_by_type(x)
+            if self.factorize_freq_time:
+                x = self.restore_freq_temp_dims(x, orig_shape)  # (BS, D, f, t) <- (B*S, T, D)
+                if cont_mask is not None:
+                    # duplicating the mask for the latent dimension (D) to be compatible with the next func
+                    x_mask = x_mask.unsqueeze(-1).expand(-1, -1, self.config.hidden_size)
+                    x_mask = self.restore_freq_temp_dims(x_mask, orig_shape)  # (BS, D, f, t) <- (B*S, T, D)
+                    # again removing the latent
+                    x_mask = x_mask[:, 0, :, :]
+                else:
+                    x_mask = None
+                x = self.freq_attn_agg(x, x_mask)  # (BS, t, D)
+                x = self.temp_attn_agg(x)  # (BS, D) or (BS, t, D) if self.temp_attn_agg is Identity
+        else:
+            x = x["pooler_output"]
+            x = self.classifier(x)
+        return x
+
+    def get_features_by_type(self, x: BaseModelOutputWithPooling) -> torch.Tensor:
+        if self.feat_type == "pooler_output":
+            return x["pooler_output"]  # (B, D)
+        elif self.feat_type == "CLS":
+            return x["last_hidden_state"][:, 0, :]  # (B, D)
+        elif self.feat_type == "last_hidden_state":
+            return x["last_hidden_state"]  # (B, 2+T, D)
+        elif self.feat_type == "last_hidden_state_no_AUX":
+            return x["last_hidden_state"][:, 2:, :]  # (B, T, D) removing CLS and distill tokens
+        else:
+            raise ValueError(f"Unknown feature type: {self.feat_type}")
+
+    def restore_freq_temp_dims(self, feats, orig_shape: tuple):
+        """
+        feats are of shape (B*S, T, D)
+            where T = 2 + f * t (if feat_type == 'last_hidden_state')
+            where T = f * t (if feat_type == 'last_hidden_state_no_AUX')
+        Our goal is to make them of shape (B*S, D, f, t) where f and t are dimensions after patching.
+        From `self.ast.embeddings.patch_embeddings`, it follows that we could reshape feats with:
+            `feats.transpose(1, 2).view(B*S, D, f, t)`
+
+        (A similar function is defined for RGB features in `motionformer.py`)
+        """
+        B, S, T, F = orig_shape
+        D = self.config.hidden_size
+
+        # num patches in each dimension
+        f, t = self.ast.embeddings.get_shape(self.config)
+
+        if self.feat_type == "last_hidden_state":
+            feats = feats[:, 2:, :]  # removing CLS and distill tokens
+
+        feats = feats.permute(0, 2, 1)  # (B*S, D, T)
+        feats = feats.view(B * S, D, f, t)  # (B*S, D, f, t)
+
+        return feats
+
+    def patch_position_emb(self):
+        if self.max_spec_t is not None:
+            self.config.max_length = self.max_spec_t
+        f, t = self.ast.embeddings.get_shape(self.config)
+        shortened = self.ast.embeddings.position_embeddings[:, : f * t + 2].clone()  # +2 for CLS and distill tokens
+        self.ast.embeddings.position_embeddings = torch.nn.Parameter(shortened).to(self.device)
+
+    def to(self, device):
+        """AST.device fails with AttributeError. This is a workaround."""
+        self.device = torch.device(device)
+        return super().to(device)
+
+
+class FrequencyTransformerEncoderLayer(BaseEncoderLayer):
+    """This layer is used to aggregate the features along the frequency axis.
+    It follows the same logic as the spatio-temporal aggregation in the visual feature extractor.
+    Thus, it is recommended to check the definition of `BaseEncoderLayer` in `motionformer.py`"""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
+        """x: (B*S, D, f, t); if specified, x_mask (B*S, f, t), 0s are the values to be masked out"""
+        BS, D, f, t = x.shape
+
+        # time as a batch dimension
+        x = x.permute(0, 3, 2, 1)  # (B*S, t, f, D)
+        x = x.reshape(BS * t, f, D)  # .view() fails with non-contiguous memory
+        # similarly for the mask
+        if x_mask is not None:
+            x_mask = x_mask.permute(0, 2, 1)  # (B*S, t, f)
+            x_mask = x_mask.reshape(BS * t, f)
+
+        # apply encoder layer (BaseEncoderLayer.forward) - it will add a CLS token and output its representation
+        x = super().forward(x=x, x_mask=x_mask)  # (B*S*t, D)
+
+        # reshape back to (B*S, t, D)
+        x = x.view(BS, t, D)
+
+        return x  # (B*S, t, D)
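For orientation, a minimal usage sketch of the extractor defined above (the constructor mirrors `AST.__init__`, but the concrete argument values are assumptions for illustration; instantiating it without a checkpoint initializes an untrained AST backbone):

```python
import torch

# assumed illustrative configuration; see the AST.__init__ docstring above
ast = AST(
    extract_features=True,
    feat_type="last_hidden_state",
    max_spec_t=66,                       # patch the positional embeddings to 66 spec frames
    factorize_freq_time=True,
    agg_freq_module="TransformerEncoderLayer",
    agg_time_module="TransformerEncoderLayer",
    add_global_repr=False,
)
specs = torch.randn(2, 14, 66, 128)      # (B, S, T, F): 14 segments of 66 frames x 128 mel bins
local_feats, global_feats = ast(specs)   # (2, 14, D), and None since add_global_repr=False
```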
diff --git a/hunyuanvideo_foley/models/synchformer/compute_desync_score.py b/hunyuanvideo_foley/models/synchformer/compute_desync_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..936c4994bd6d68444980959f78aa710e4ba5a205
--- /dev/null
+++ b/hunyuanvideo_foley/models/synchformer/compute_desync_score.py
@@ -0,0 +1,214 @@
+import argparse
+import os
+import subprocess
+from pathlib import Path
+
+import torch
+import torchaudio
+import torchvision
+from omegaconf import OmegaConf
+
+from . import data_transforms
+from .synchformer import Synchformer
+from .data_transforms import make_class_grid, quantize_offset
+from .utils import check_if_file_exists_else_download, which_ffmpeg
+
+
+def prepare_inputs(batch, device):
+    aud = batch["audio"].to(device)
+    vid = batch["video"].to(device)
+
+    return aud, vid
+
+
+def get_test_transforms():
+    ts = [
+        data_transforms.EqualifyFromRight(),
+        data_transforms.RGBSpatialCrop(input_size=224, is_random=False),
+        data_transforms.TemporalCropAndOffset(
+            crop_len_sec=5,
+            max_off_sec=2,  # https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml
+            max_wiggle_sec=0.0,
+            do_offset=True,
+            offset_type="grid",
+            prob_oos=None,
+            grid_size=21,
+            segment_size_vframes=16,
+            n_segments=14,
+            step_size_seg=0.5,
+            vfps=25,
+        ),
+        data_transforms.GenerateMultipleSegments(
+            segment_size_vframes=16,
+            n_segments=14,
+            is_start_random=False,
+            step_size_seg=0.5,
+        ),
+        data_transforms.RGBToHalfToZeroOne(),
+        data_transforms.RGBNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # motionformer normalization
+        data_transforms.AudioMelSpectrogram(
+            sample_rate=16000,
+            win_length=400,  # 25 ms * 16 kHz
+            hop_length=160,  # 10 ms * 16 kHz
+            n_fft=1024,  # 2^(ceil(log2(window_size * sampling_rate)))
+            n_mels=128,  # as in AST
+        ),
+        data_transforms.AudioLog(),
+        data_transforms.PadOrTruncate(max_spec_t=66),
+        data_transforms.AudioNormalizeAST(mean=-4.2677393, std=4.5689974),  # AST, pre-trained on AudioSet
+        data_transforms.PermuteStreams(
+            einops_order_audio="S F T -> S 1 F T", einops_order_rgb="S T C H W -> S T C H W"  # same
+        ),
+    ]
+    transforms = torchvision.transforms.Compose(ts)
+
+    return transforms
+
+
+def get_video_and_audio(path, get_meta=False, start_sec=0, end_sec=None):
+    orig_path = 
path + # (Tv, 3, H, W) [0, 255, uint8]; (Ca, Ta) + rgb, audio, meta = torchvision.io.read_video(str(path), start_sec, end_sec, "sec", output_format="TCHW") + assert meta["video_fps"], f"No video fps for {orig_path}" + # (Ta) <- (Ca, Ta) + audio = audio.mean(dim=0) + # FIXME: this is legacy format of `meta` as it used to be loaded by VideoReader. + meta = { + "video": {"fps": [meta["video_fps"]]}, + "audio": {"framerate": [meta["audio_fps"]]}, + } + return rgb, audio, meta + + +def reencode_video(path, vfps=25, afps=16000, in_size=256): + assert which_ffmpeg() != "", "Is ffmpeg installed? Check if the conda environment is activated." + new_path = Path.cwd() / "vis" / f"{Path(path).stem}_{vfps}fps_{in_size}side_{afps}hz.mp4" + new_path.parent.mkdir(exist_ok=True) + new_path = str(new_path) + cmd = f"{which_ffmpeg()}" + # no info/error printing + cmd += " -hide_banner -loglevel panic" + cmd += f" -y -i {path}" + # 1) change fps, 2) resize: min(H,W)=MIN_SIDE (vertical vids are supported), 3) change audio framerate + cmd += f" -vf fps={vfps},scale=iw*{in_size}/'min(iw,ih)':ih*{in_size}/'min(iw,ih)',crop='trunc(iw/2)'*2:'trunc(ih/2)'*2" + cmd += f" -ar {afps}" + cmd += f" {new_path}" + subprocess.call(cmd.split()) + cmd = f"{which_ffmpeg()}" + cmd += " -hide_banner -loglevel panic" + cmd += f" -y -i {new_path}" + cmd += f" -acodec pcm_s16le -ac 1" + cmd += f' {new_path.replace(".mp4", ".wav")}' + subprocess.call(cmd.split()) + return new_path + + +def decode_single_video_prediction(off_logits, grid, item): + label = item["targets"]["offset_label"].item() + print("Ground Truth offset (sec):", f"{label:.2f} ({quantize_offset(grid, label)[-1].item()})") + print() + print("Prediction Results:") + off_probs = torch.softmax(off_logits, dim=-1) + k = min(off_probs.shape[-1], 5) + topk_logits, topk_preds = torch.topk(off_logits, k) + # remove batch dimension + assert len(topk_logits) == 1, "batch is larger than 1" + topk_logits = topk_logits[0] + topk_preds = topk_preds[0] + off_logits = off_logits[0] + off_probs = off_probs[0] + for target_hat in topk_preds: + print(f'p={off_probs[target_hat]:.4f} ({off_logits[target_hat]:.4f}), "{grid[target_hat]:.2f}" ({target_hat})') + return off_probs + + +def main(args): + vfps = 25 + afps = 16000 + in_size = 256 + # making the offset class grid similar to the one used in transforms, + # refer to the used one: https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml + max_off_sec = 2 + num_cls = 21 + + # checking if the provided video has the correct frame rates + print(f"Using video: {args.vid_path}") + v, _, info = torchvision.io.read_video(args.vid_path, pts_unit="sec") + _, H, W, _ = v.shape + if info["video_fps"] != vfps or info["audio_fps"] != afps or min(H, W) != in_size: + print(f'Reencoding. vfps: {info["video_fps"]} -> {vfps};', end=" ") + print(f'afps: {info["audio_fps"]} -> {afps};', end=" ") + print(f"{(H, W)} -> min(H, W)={in_size}") + args.vid_path = reencode_video(args.vid_path, vfps, afps, in_size) + else: + print(f'Skipping reencoding. 
vfps: {info["video_fps"]}; afps: {info["audio_fps"]}; min(H, W)={in_size}')
+
+    device = torch.device(args.device)
+
+    # load visual and audio streams
+    # rgb: (Tv, 3, H, W) in [0, 255], audio: (Ta,) in [-1, 1]
+    rgb, audio, meta = get_video_and_audio(args.vid_path, get_meta=True)
+
+    # making an item (dict) to apply transformations
+    # NOTE: here is how it works:
+    # For instance, if the model is trained on 5sec clips, the provided video is 9sec, and `v_start_i_sec=1.3`
+    # the transform will crop out a 5sec-clip from 1.3 to 6.3 seconds and shift the start of the audio
+    # track by `args.offset_sec` seconds. It means that if `offset_sec` > 0, the audio will
+    # start by `offset_sec` earlier than the rgb track.
+    # It is a good idea to use something in [-`max_off_sec`, `max_off_sec`] (-2, +2) seconds (see `grid`)
+    item = dict(
+        video=rgb,
+        audio=audio,
+        meta=meta,
+        path=args.vid_path,
+        split="test",
+        targets={
+            "v_start_i_sec": args.v_start_i_sec,
+            "offset_sec": args.offset_sec,
+        },
+    )
+
+    grid = make_class_grid(-max_off_sec, max_off_sec, num_cls)
+    if not (min(grid) <= item["targets"]["offset_sec"] <= max(grid)):
+        print(f'WARNING: offset_sec={item["targets"]["offset_sec"]} is outside the trained grid: {grid}')
+
+    # applying the test-time transform
+    item = get_test_transforms()(item)
+
+    # prepare inputs for inference
+    batch = torch.utils.data.default_collate([item])
+    aud, vid = prepare_inputs(batch, device)
+
+    # TODO:
+    # sanity check: we could take the input to the `model` and reconstruct a video from it.
+    # Use this check to make sure the input makes sense (audio should be ok but shifted as you specified)
+    # reconstruct_video_from_input(aud, vid, batch['meta'], args.vid_path, args.v_start_i_sec, args.offset_sec,
+    #                              vfps, afps)
+
+    # forward pass
+    with torch.set_grad_enabled(False):
+        with torch.autocast("cuda", enabled=True):
+            _, logits = synchformer(vid, aud)
+
+    # simply prints the results of the prediction
+    decode_single_video_prediction(logits, grid, item)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--exp_name", required=True, help="In a format: xx-xx-xxTxx-xx-xx")
+    parser.add_argument("--vid_path", required=True, help="A path to .mp4 video")
+    parser.add_argument("--offset_sec", type=float, default=0.0)
+    parser.add_argument("--v_start_i_sec", type=float, default=0.0)
+    parser.add_argument("--device", default="cuda:0")
+    args = parser.parse_args()
+
+    synchformer = Synchformer().cuda().eval()
+    synchformer.load_state_dict(
+        torch.load(
+            os.environ.get("SYNCHFORMER_WEIGHTS", "weights/synchformer.pth"),
+            weights_only=True,
+            map_location="cpu",
+        )
+    )
+
+    main(args)
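The predicted desync score is a distribution over this discrete offset grid. Below is a small self-contained sketch that mirrors `make_class_grid` / `quantize_offset` from `data_transforms.py` (the probe offset of 0.37 s is just an example):

```python
import numpy as np
import torch

# same construction as make_class_grid(-2, 2, 21): 21 classes, 0.2 s apart
grid = torch.from_numpy(np.linspace(-2.0, 2.0, 21)).float()

off_sec = 0.37
idx = (grid - off_sec).abs().argmin()  # nearest grid element
print(grid[idx].item(), idx.item())    # -> 0.4 (class index 12)
```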
diff --git a/hunyuanvideo_foley/models/synchformer/data_transforms.py b/hunyuanvideo_foley/models/synchformer/data_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..d59331eb3c27454a4c52bdbf8a8b85946c63c0a3
--- /dev/null
+++ b/hunyuanvideo_foley/models/synchformer/data_transforms.py
@@ -0,0 +1,1130 @@
+import logging
+import math
+import random
+from typing import Tuple
+import torch
+import torchvision
+import torchaudio
+import numpy as np
+import einops
+
+
+def sec2frames(sec, fps):
+    return int(sec * fps)
+
+
+def frames2sec(frames, fps):
+    return frames / fps
+
+
+class EqualifyFromRight(torch.nn.Module):
+
+    def __init__(self, clip_max_len_sec=10):
+        """
+        Takes the dataset item and makes sure both streams cover an equal duration (given their fps).
+        It, however, assumes that the signals are synched and trims the ending parts ('from the right').
+        """
+        super().__init__()
+        self.clip_max_len_sec = clip_max_len_sec
+
+    def forward(self, item):
+        """
+        `item`: {'video': (Tv, C, H, W), 'audio': (Ta,),
+                 'meta': {
+                     'audio': {'framerate': [float], 'duration': [float]}
+                     'video': {'fps': [float], 'duration': [float]}}
+        """
+        a_fps = item["meta"]["audio"]["framerate"][0]
+        v_fps = item["meta"]["video"]["fps"][0]
+
+        Ta = item["audio"].shape[0]
+        Tv, C, H, W = item["video"].shape
+
+        a_len_secs = Ta / a_fps
+        v_len_secs = Tv / v_fps
+        min_len = min(self.clip_max_len_sec, a_len_secs, v_len_secs)
+
+        a_frames_per_v_frame = a_fps // v_fps
+        v_len_frames = int(v_fps * min_len)
+        a_len_frames = int(a_frames_per_v_frame * v_len_frames)
+        # print(a_len_frames, v_len_frames)
+
+        assert a_len_frames <= Ta and v_len_frames <= Tv
+
+        item["audio"] = item["audio"][:a_len_frames]
+        item["video"] = item["video"][:v_len_frames, :, :, :]
+
+        return item
+
+
+class RGBSpatialCrop(torch.nn.Module):
+
+    def __init__(self, input_size, is_random):
+        super().__init__()
+        assert input_size is not None, f"input_size is `{input_size}`"
+        if isinstance(input_size, int):
+            input_size = (input_size, input_size)
+        self.input_size = input_size
+        self.is_random = is_random
+
+    @staticmethod
+    def get_random_crop_sides(vid, output_size):
+        """Slice parameters for random crop"""
+        h, w = vid.shape[-2:]
+        th, tw = output_size
+        if w == tw and h == th:
+            return 0, 0, h, w
+        i = random.randint(0, h - th)
+        j = random.randint(0, w - tw)
+        return i, j, th, tw
+
+    @staticmethod
+    def get_center_crop_sides(vid, output_size):
+        """Slice parameters for center crop"""
+        h, w = vid.shape[-2:]
+        th, tw = output_size
+
+        i = int(round((h - th) / 2.0))
+        j = int(round((w - tw) / 2.0))
+        return i, j, th, tw
+
+    def forward(self, item):
+        # (Tv, C, H, W)
+        vid = item["video"]
+        if self.is_random:
+            i, j, h, w = self.get_random_crop_sides(vid, self.input_size)
+        else:
+            i, j, h, w = self.get_center_crop_sides(vid, self.input_size)
+        item["video"] = vid[..., i : (i + h), j : (j + w)]
+        return item
+
+
+class Resize(torchvision.transforms.Resize):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, item):
+        item["video"] = super().forward(item["video"])
+        return item
+
+
+class RGBSpatialCropSometimesUpscale(torch.nn.Module):
+    """This (randomly) crops the input video, and with prob `sometimes_p` the crop is smaller but upscaled
+    to `target_input_size`"""
+
+    def __init__(self, sometimes_p, target_input_size, is_random, smaller_input_size=None):
+        super().__init__()
+        self.sometimes_p = sometimes_p
+        self.do_sometimes_upscale = sometimes_p is not None and sometimes_p > 0
+
+        self.crop_only = RGBSpatialCrop(target_input_size, is_random)
+
+        if self.do_sometimes_upscale:
+            self.crop_further_and_upscale = torchvision.transforms.Compose(
+                [
+                    RGBSpatialCrop(smaller_input_size, is_random),
+                    Resize(target_input_size, antialias=None),
+                ]
+            )
+
+    def forward(self, item):
+        assert len(item["video"].shape) == 4, (
+            f"{item['video'].shape}: if it is applied after GenerateMultipleClips, "
+            "augs should be applied to each clip separately, not to the whole video array. "
+            "Otherwise, ignore this warning (comment it)."
+        )
+        if self.do_sometimes_upscale and self.sometimes_p > torch.rand(1):
+            return self.crop_further_and_upscale(item)
+        else:
+            return self.crop_only(item)
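+# e.g. (hypothetical settings) RGBSpatialCropSometimesUpscale(sometimes_p=0.5, target_input_size=224,
+#      is_random=True, smaller_input_size=112) crops at 224 half of the time, and otherwise crops at 112
+#      and upscales back to 224, acting as a random zoom-in style augmentation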
+
+
+class RandomApplyColorDistortion(torch.nn.Module):
+
+    def __init__(self, p_gray_scale=0.0, p_color_jitter=0.0, s=1.0) -> None:
+        super().__init__()
+        self.p_gray_scale = p_gray_scale
+        self.p_color_jitter = p_color_jitter
+        self.s = s
+        assert 0 <= self.p_color_jitter <= 1 and 0 <= self.p_gray_scale <= 1, (p_color_jitter, p_gray_scale)
+        # SimCLR params
+        color_jitter = torchvision.transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
+        rand_color_jitter = torchvision.transforms.RandomApply([color_jitter], p_color_jitter)
+        rand_gray = torchvision.transforms.RandomGrayscale(p_gray_scale)
+        self.transforms = torchvision.transforms.Compose([rand_color_jitter, rand_gray])
+
+    def apply_to_single_clip(self, clip):
+        return self.transforms(clip)
+
+    def apply_to_each_clip(self, clips):
+        for i, clip in enumerate(clips):
+            clips[i] = self.apply_to_single_clip(clip)
+        return clips
+
+    def forward(self, item):
+        has_batch_dim = len(item["video"].shape) == 5
+        if has_batch_dim:
+            fn = self.apply_to_each_clip
+        else:
+            fn = self.apply_to_single_clip
+        item["video"] = fn(item["video"])
+        return item
+
+
+class ApplyColorJitterFrameWise(torch.nn.Module):
+
+    def __init__(self, s=1.0) -> None:
+        super().__init__()
+        self.s = s
+        # SimCLR params
+        self.transform = torchvision.transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
+
+    def apply_to_single_clip(self, clip):
+        for i, frame in enumerate(clip):
+            clip[i] = self.transform(frame)
+        return clip
+
+    def apply_to_each_clip(self, clips):
+        for i, clip in enumerate(clips):
+            clips[i] = self.apply_to_single_clip(clip)
+        return clips
+
+    def forward(self, item):
+        has_batch_dim = len(item["video"].shape) == 5
+        if has_batch_dim:
+            fn = self.apply_to_each_clip
+        else:
+            fn = self.apply_to_single_clip
+        item["video"] = fn(item["video"])
+        return item
+
+
+class RandomHorizontalFlip(torchvision.transforms.RandomHorizontalFlip):
+
+    def __init__(self, p=0.5):
+        super().__init__(p)
+
+    def apply_to_single_clip(self, clip):
+        return super().forward(clip)
+
+    def apply_to_each_clip(self, clips):
+        for i, clip in enumerate(clips):
+            clips[i] = self.apply_to_single_clip(clip)
+        return clips
+
+    def forward(self, item):
+        has_batch_dim = len(item["video"].shape) == 5
+        if has_batch_dim:
+            fn = self.apply_to_each_clip
+        else:
+            fn = self.apply_to_single_clip
+        item["video"] = fn(item["video"])
+        return item
+
+
+def make_class_grid(
+    leftmost_val,
+    rightmost_val,
+    grid_size,
+    add_extreme_offset: bool = False,
+    seg_size_vframes: int = None,
+    nseg: int = None,
+    step_size_seg: float = None,
+    vfps: float = None,
+):
+    assert grid_size >= 3, f"grid_size: {grid_size} does not make sense. If =2 -> (-1,1); =1 -> (-1); =0 -> ()"
+    grid = torch.from_numpy(np.linspace(leftmost_val, rightmost_val, grid_size)).float()
+    if add_extreme_offset:
+        assert all([seg_size_vframes, nseg, step_size_seg]), f"{seg_size_vframes} {nseg} {step_size_seg}"
+        seg_size_sec = seg_size_vframes / vfps
+        trim_size_in_seg = nseg - (1 - step_size_seg) * (nseg - 1)
+        extreme_value = trim_size_in_seg * seg_size_sec
+        grid = torch.cat([grid, torch.tensor([extreme_value])])  # adding extreme offset to the class grid
+    return grid
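+# Worked example (with the values used in this repo's test transforms, for illustration): 16-frame
+# segments at vfps=25 give seg_size_sec = 16 / 25 = 0.64; with nseg=14 and step_size_seg=0.5,
+# trim_size_in_seg = 14 - 0.5 * 13 = 7.5, so the extreme offset class sits at 7.5 * 0.64 = 4.8 s.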
+
+
+def quantize_offset(grid: torch.Tensor, off_sec: float) -> Tuple[float, int]:
+    """Takes in the offset in seconds and snaps it onto the closest grid element.
+    Returns the grid value and its index."""
+    closest_grid_el = (grid - off_sec).abs().argmin()
+    return grid[closest_grid_el], closest_grid_el
+
+
+def apply_a_jitter(a_start_i, a_len_frames, a_crop_len_frames, a_fps, max_a_jitter_sec):
+    max_a_start_i = a_len_frames - a_crop_len_frames
+    max_a_jitter_i = sec2frames(max_a_jitter_sec, a_fps)
+    max_a_jitter_i_left = min(a_start_i, max_a_jitter_i)
+    max_a_jitter_i_right = min(max_a_start_i - a_start_i, max_a_jitter_i)
+    # jitter is U[left, right]
+    a_jitter_i = random.randint(-max_a_jitter_i_left, max_a_jitter_i_right)
+    # apply jitter
+    a_start_i = a_start_i + a_jitter_i
+    # making sure that any value from `a_start_i + U[left, right]` will be inside of [0, len-crop] region
+    assert 0 <= a_start_i <= max_a_start_i, f"{a_jitter_i} {max_a_jitter_i_left} {max_a_jitter_i_right} {max_a_start_i}"
+    return a_start_i, a_jitter_i
+
+
+class TemporalCropAndOffset(torch.nn.Module):
+
+    def __init__(
+        self,
+        crop_len_sec: float,
+        max_off_sec: float,
+        offset_type="grid",
+        do_offset: bool = True,
+        grid_size: int = None,
+        max_wiggle_sec: float = None,
+        add_doubt_cls: bool = False,
+        segment_size_vframes: int = None,
+        n_segments: int = None,
+        step_size_seg: float = None,
+        vfps: float = None,
+        prob_oos: float = None,
+    ):
+        super().__init__()
+        self.crop_len_sec = crop_len_sec
+        self.do_offset = do_offset
+        self.grid_size = grid_size
+        self.offset_type = offset_type
+        self.max_off_sec = max_off_sec
+        self.max_a_jitter_sec = max_wiggle_sec
+        if do_offset:
+            if offset_type == "grid":
+                self.class_grid = make_class_grid(
+                    -max_off_sec,
+                    max_off_sec,
+                    grid_size,
+                    add_doubt_cls,
+                    segment_size_vframes,
+                    n_segments,
+                    step_size_seg,
+                    vfps,
+                )
+                logging.info(f"Offsets class grid: {self.class_grid}")
+                if self.max_a_jitter_sec is not None:
+                    assert (max_wiggle_sec - 1e-6) <= (
+                        (self.class_grid[1] - self.class_grid[0]) / 2
+                    ), f"{self.class_grid}"
+            elif offset_type == "uniform":
+                self.off_dist = torch.distributions.uniform.Uniform(-max_off_sec, max_off_sec)
+                logging.info(f"Offset uniform distribution: {self.off_dist}")
+            elif offset_type == "uniform_binary":
+                self.itu_t_range = (-0.125, 0.045)
+                self.prob_oos = prob_oos
+                self.ins_dist = torch.distributions.uniform.Uniform(self.itu_t_range[0], self.itu_t_range[1])
+                self.off_dist = torch.distributions.uniform.Uniform(-max_off_sec, max_off_sec)
+            else:
+                raise NotImplementedError(f"Unknown offset type: {offset_type}")
+
+    def forward(self, item):
+        vid = item["video"]
+        aud = item["audio"]
+        v_len_frames, C, H, W = vid.shape
+        a_len_frames = aud.shape[0]
+
+        v_fps = int(item["meta"]["video"]["fps"][0])
+        a_fps = int(item["meta"]["audio"]["framerate"][0])
+
+        v_crop_len_frames = sec2frames(self.crop_len_sec, v_fps)
+        a_crop_len_frames = sec2frames(self.crop_len_sec, a_fps)
+
+        if self.do_offset:
+            # trying to get the offset parameters (for instance during valid and test we have fixed offsets)
+            offset_sec = item["targets"].get("offset_sec", None)
+            v_start_i_sec = item["targets"].get("v_start_i_sec", None)
+            if "offset_target" in item["targets"]:
+                is_oos = item["targets"]["offset_target"].get("oos", None)
+            # train-time
+            if offset_sec is None and v_start_i_sec is None:
+                # aud starts `offset_sec` earlier than it should; aud has what will be shown after offset_sec
+                if self.offset_type == "grid":
+                    offset_sec = random.choice(self.class_grid.tolist())
+                elif self.offset_type == "uniform":
+                    offset_sec = self.off_dist.sample().item()
+                elif self.offset_type == "uniform_binary":
+                    # in-sync: Uniform(-0.125, 0.045)
+                    # out-of-sync: Uniform(-5.5, 5.5) and resampled until not in Uniform(-0.125, 0.045)
+                    # first, we sample whether the offset is out-of-sync with probability prob_oos
+                    is_oos = (torch.rand(1) < self.prob_oos).item()
+                    if is_oos:
+                        # second, we sample the offset itself (if in the in-sync range, trying again)
+                        offset_sec = self.off_dist.sample().item()
+                        while self.itu_t_range[0] <= offset_sec <= self.itu_t_range[1]:
+                            offset_sec = self.off_dist.sample().item()
+                    else:
+                        offset_sec = self.ins_dist.sample().item()
+                offset_sec = round(offset_sec, 2)
+                v_start_max_sec = frames2sec(v_len_frames - v_crop_len_frames, v_fps)
+                assert v_start_max_sec > 0, f'{v_len_frames} {v_crop_len_frames} {v_fps} @ {item["path"]}'
+                # `v_start_sec` IS NOT rounded to the fps grid
+                v_start_sec = random.uniform(max(0, -offset_sec), min(v_start_max_sec, v_start_max_sec - offset_sec))
+                assert 0 <= v_start_sec <= v_start_max_sec, f'{v_start_sec} {v_start_max_sec} {item["path"]}'
+                v_start_i = sec2frames(v_start_sec, v_fps)
+                # `v_start_i_sec` IS rounded to the fps grid
+                v_start_i_sec = frames2sec(v_start_i, v_fps)
+            else:
+                offset_sec = round(offset_sec, 2)
+                v_start_i = sec2frames(v_start_i_sec, v_fps)
+            v_end_i = v_start_i + v_crop_len_frames
+            # `a_start_i` depends on the rounded value `v_start_i_sec`, otherwise
+            # (v_start_sec) we have ±0.1 jittering
+            a_start_i = sec2frames(v_start_i_sec + offset_sec, a_fps)
+        else:
+            offset_sec = 0.0
+            is_random_crop = item["split"] == "train"
+            v_start_i, v_end_i = self.get_crop_idx(v_len_frames, v_crop_len_frames, is_random=is_random_crop)
+            v_start_i_sec = frames2sec(v_start_i, v_fps)
+            a_start_i = sec2frames(v_start_i_sec, a_fps)
+
+        # sometimes due to the rounding error e.g. v_start_sec = 1.505 but sec2frames(1.505, 25) = 1.48
+        # given offset is -1.5, the a_start_i will be a small negative value. 
(likely a_fps * 1/v_fps * 0.5) + if a_start_i < 0: + how_much_out = a_start_i + logging.info(f'a_start_i is negative ({how_much_out}) at {item["path"]}') + if abs(how_much_out) <= a_fps / v_fps: + logging.info("fixing it") + a_start_i += abs(how_much_out) + else: + raise Exception(f'{how_much_out} {item["path"]}') + + if self.max_a_jitter_sec is not None and self.max_a_jitter_sec > 0: + a_start_i, a_jitter_i = apply_a_jitter( + a_start_i, a_len_frames, a_crop_len_frames, a_fps, self.max_a_jitter_sec + ) + item["meta"]["a_jitter_i"] = a_jitter_i + + a_end_i = a_start_i + a_crop_len_frames + + assert v_start_i < v_end_i and a_start_i < a_end_i + assert aud.shape[0] >= a_end_i, f'{aud.shape} {a_end_i} {item["path"]}' + assert vid.shape[0] >= v_end_i, f'{vid.shape} {v_end_i} {item["path"]}' + + vid, aud = vid[v_start_i:v_end_i, :, :, :], aud[a_start_i:a_end_i] + + item["video"] = vid + item["audio"] = aud + + assert item["video"].shape[0] == v_fps * self.crop_len_sec, f'{item["video"].shape} {item["path"]}' + assert item["audio"].shape[0] == a_fps * self.crop_len_sec, f'{item["audio"].shape} {item["path"]}' + + # caching parameters + if self.do_offset: + if self.offset_type == "grid": + offset_label, offset_target = quantize_offset(self.class_grid, offset_sec) + elif self.offset_type == "uniform": + offset_label, offset_target = offset_sec, offset_sec + elif self.offset_type == "uniform_binary": + offset_label, offset_target = offset_sec, {"oos": is_oos, "offset": offset_sec} + item["targets"]["offset_sec"] = offset_sec + item["targets"]["v_start_i_sec"] = v_start_i_sec + item["targets"]["offset_label"] = offset_label + # assert 'offset_target' not in item['targets'], f'{item["targets"]}. What passed it there?' + item["targets"]["offset_target"] = offset_target + + return item + + def get_crop_idx(self, len_frames: int, crop_len_frames: int, is_random=True): + if len_frames == crop_len_frames: + return 0, len_frames + if is_random: + left_i = random.randint(0, len_frames - crop_len_frames) + else: + left_i = int(round((len_frames - crop_len_frames) / 2.0)) + return left_i, left_i + crop_len_frames + + +class GenerateMultipleSegments(torch.nn.Module): + """ + Given an item with video and audio, generates a batch of `n_segments` segments + of length `segment_size_vframes` (if None, the max number of segments will be made). + If `is_start_random` is True, the starting position of the 1st segment will be random but respecting + n_segments. + `audio_jitter_sec` is the amount of audio offset in seconds. 
+ """ + + def __init__( + self, + segment_size_vframes: int, + n_segments: int = None, + is_start_random: bool = False, + audio_jitter_sec: float = 0.0, + step_size_seg: float = 1, + ): + super().__init__() + self.segment_size_vframes = segment_size_vframes + self.n_segments = n_segments + self.is_start_random = is_start_random + self.audio_jitter_sec = audio_jitter_sec + self.step_size_seg = step_size_seg + logging.info(f"Segment step size: {self.step_size_seg}") + + def forward(self, item): + v_len_frames, C, H, W = item["video"].shape + a_len_frames = item["audio"].shape[0] + + v_fps = int(item["meta"]["video"]["fps"][0]) + a_fps = int(item["meta"]["audio"]["framerate"][0]) + + ## Determining the number of segments + # segment size + segment_size_vframes = self.segment_size_vframes + segment_size_aframes = sec2frames(frames2sec(self.segment_size_vframes, v_fps), a_fps) + # step size (stride) + stride_vframes = int(self.step_size_seg * segment_size_vframes) + stride_aframes = int(self.step_size_seg * segment_size_aframes) + # calculating the number of segments. (W - F + 2P) / S + 1 + n_segments_max_v = math.floor((v_len_frames - segment_size_vframes) / stride_vframes) + 1 + n_segments_max_a = math.floor((a_len_frames - segment_size_aframes) / stride_aframes) + 1 + # making sure audio and video can accommodate the same number of segments + n_segments_max = min(n_segments_max_v, n_segments_max_a) + n_segments = n_segments_max if self.n_segments is None else self.n_segments + + assert n_segments <= n_segments_max, ( + f"cant make {n_segments} segs of len {self.segment_size_vframes} in a vid " + f'of len {v_len_frames} for {item["path"]}' + ) + + # (n_segments, 2) each + v_ranges, a_ranges = self.get_sequential_seg_ranges( + v_len_frames, a_len_frames, v_fps, a_fps, n_segments, segment_size_aframes + ) + + # segmenting original streams (n_segments, segment_size_frames, C, H, W) + item["video"] = torch.stack([item["video"][s:e] for s, e in v_ranges], dim=0) + item["audio"] = torch.stack([item["audio"][s:e] for s, e in a_ranges], dim=0) + return item + + def get_sequential_seg_ranges(self, v_len_frames, a_len_frames, v_fps, a_fps, n_seg, seg_size_aframes): + # if is_start_random is True, the starting position of the 1st segment will + # be random but respecting n_segments like so: "-CCCCCCCC---" (maybe with fixed overlap), + # else the segments are taken from the middle of the video respecting n_segments: "--CCCCCCCC--" + + seg_size_vframes = self.segment_size_vframes # for brevity + + # calculating the step size in frames + step_size_vframes = int(self.step_size_seg * seg_size_vframes) + step_size_aframes = int(self.step_size_seg * seg_size_aframes) + + # calculating the length of the sequence of segments (and in frames) + seg_seq_len = n_seg * self.step_size_seg + (1 - self.step_size_seg) + vframes_seg_seq_len = int(seg_seq_len * seg_size_vframes) + aframes_seg_seq_len = int(seg_seq_len * seg_size_aframes) + + # doing temporal crop + max_v_start_i = v_len_frames - vframes_seg_seq_len + if self.is_start_random: + v_start_i = random.randint(0, max_v_start_i) + else: + v_start_i = max_v_start_i // 2 + a_start_i = sec2frames(frames2sec(v_start_i, v_fps), a_fps) # vid frames -> seconds -> aud frames + + # make segments starts + v_start_seg_i = torch.tensor([v_start_i + i * step_size_vframes for i in range(n_seg)]).int() + a_start_seg_i = torch.tensor([a_start_i + i * step_size_aframes for i in range(n_seg)]).int() + + # apply jitter to audio + if self.audio_jitter_sec > 0: + jitter_aframes = 
+        n_segments_max_v = math.floor((v_len_frames - segment_size_vframes) / stride_vframes) + 1
+        n_segments_max_a = math.floor((a_len_frames - segment_size_aframes) / stride_aframes) + 1
+        # making sure audio and video can accommodate the same number of segments
+        n_segments_max = min(n_segments_max_v, n_segments_max_a)
+        n_segments = n_segments_max if self.n_segments is None else self.n_segments
+
+        assert n_segments <= n_segments_max, (
+            f"cant make {n_segments} segs of len {self.segment_size_vframes} in a vid "
+            f'of len {v_len_frames} for {item["path"]}'
+        )
+
+        # (n_segments, 2) each
+        v_ranges, a_ranges = self.get_sequential_seg_ranges(
+            v_len_frames, a_len_frames, v_fps, a_fps, n_segments, segment_size_aframes
+        )
+
+        # segmenting original streams (n_segments, segment_size_frames, C, H, W)
+        item["video"] = torch.stack([item["video"][s:e] for s, e in v_ranges], dim=0)
+        item["audio"] = torch.stack([item["audio"][s:e] for s, e in a_ranges], dim=0)
+        return item
+
+    def get_sequential_seg_ranges(self, v_len_frames, a_len_frames, v_fps, a_fps, n_seg, seg_size_aframes):
+        # if is_start_random is True, the starting position of the 1st segment will
+        # be random but respecting n_segments like so: "-CCCCCCCC---" (maybe with fixed overlap),
+        # else the segments are taken from the middle of the video respecting n_segments: "--CCCCCCCC--"
+
+        seg_size_vframes = self.segment_size_vframes  # for brevity
+
+        # calculating the step size in frames
+        step_size_vframes = int(self.step_size_seg * seg_size_vframes)
+        step_size_aframes = int(self.step_size_seg * seg_size_aframes)
+
+        # calculating the length of the sequence of segments (and in frames)
+        seg_seq_len = n_seg * self.step_size_seg + (1 - self.step_size_seg)
+        vframes_seg_seq_len = int(seg_seq_len * seg_size_vframes)
+        aframes_seg_seq_len = int(seg_seq_len * seg_size_aframes)
+
+        # doing temporal crop
+        max_v_start_i = v_len_frames - vframes_seg_seq_len
+        if self.is_start_random:
+            v_start_i = random.randint(0, max_v_start_i)
+        else:
+            v_start_i = max_v_start_i // 2
+        a_start_i = sec2frames(frames2sec(v_start_i, v_fps), a_fps)  # vid frames -> seconds -> aud frames
+
+        # make segments starts
+        v_start_seg_i = torch.tensor([v_start_i + i * step_size_vframes for i in range(n_seg)]).int()
+        a_start_seg_i = torch.tensor([a_start_i + i * step_size_aframes for i in range(n_seg)]).int()
+
+        # apply jitter to audio
+        if self.audio_jitter_sec > 0:
+            jitter_aframes = sec2frames(self.audio_jitter_sec, a_fps)
+            # making sure after applying jitter, the audio is still within the audio boundaries
+            jitter_aframes = min(jitter_aframes, a_start_i, a_len_frames - a_start_i - aframes_seg_seq_len)
+            a_start_seg_i += random.randint(-jitter_aframes, jitter_aframes)  # applying jitter to segments
+
+        # make segment ends
+        v_ends_seg_i = v_start_seg_i + seg_size_vframes
+        a_ends_seg_i = a_start_seg_i + seg_size_aframes  # using the adjusted a_start_seg_i (with jitter)
+
+        # make ranges
+        v_ranges = torch.stack([v_start_seg_i, v_ends_seg_i], dim=1)
+        a_ranges = torch.stack([a_start_seg_i, a_ends_seg_i], dim=1)
+        assert (a_ranges >= 0).all() and (a_ranges <= a_len_frames).all(), f"{a_ranges} out of {a_len_frames}"
+        assert (v_ranges <= v_len_frames).all(), f"{v_ranges} out of {v_len_frames}"
+        return v_ranges, a_ranges
+
+
+class TemporalCropAndOffsetForSyncabilityTraining(torch.nn.Module):
+
+    def __init__(
+        self,
+        max_off_sec: float,
+        do_offset: bool = True,
+        grid_size: int = None,
+        max_wiggle_sec: float = None,
+        segment_size_vframes: int = None,
+        n_segments: int = None,
+        step_size_seg: float = None,
+        vfps: float = None,
+    ):
+        super().__init__()
+        seg_size_sec = segment_size_vframes / vfps
+        trim_size_in_seg = n_segments - (1 - step_size_seg) * (n_segments - 1)
+        self.crop_len_sec = round(trim_size_in_seg * seg_size_sec, 2)
+        logging.info(f"Crop len: {self.crop_len_sec}")
+        self.do_offset = do_offset
+        self.grid_size = grid_size
+        self.max_off_sec = max_off_sec
+        self.max_a_jitter_sec = max_wiggle_sec
+        self.segment_size_vframes = segment_size_vframes
+        self.n_segments = n_segments
+        self.step_size_seg = step_size_seg
+        self.prob_syncable = 0.5
+        if do_offset:
+            self.class_grid = make_class_grid(-max_off_sec, max_off_sec, grid_size)
+            logging.info(f"Offset class grid: {self.class_grid}")
+            if self.max_a_jitter_sec is not None:
+                assert (max_wiggle_sec - 1e-6) <= ((self.class_grid[1] - self.class_grid[0]) / 2), f"{self.class_grid}"
+
+    def forward(self, item):
+        vid = item["video"]
+        aud = item["audio"]
+        v_len_frames, C, H, W = vid.shape
+        a_len_frames = aud.shape[0]
+
+        v_fps = int(item["meta"]["video"]["fps"][0])
+        a_fps = int(item["meta"]["audio"]["framerate"][0])
+
+        v_crop_len_frames = sec2frames(self.crop_len_sec, v_fps)
+        a_crop_len_frames = sec2frames(self.crop_len_sec, a_fps)
+
+        if self.do_offset:
+            # trying to get the offset parameters (for instance during valid and test we have fixed offsets)
+            offset_sec = item["targets"].get("offset_sec", None)
+            v_start_i_sec = item["targets"].get("v_start_i_sec", None)
+            # train-time
+            if offset_sec is None and v_start_i_sec is None:
+
+                # for the syncability training, we want to have a syncable or non-syncable offset with 50% prob
+                offset_is_syncable = random.random() < self.prob_syncable  # 1=syncable, 0=non-syncable
+                if offset_is_syncable:
+                    offset_sec = random.choice(self.class_grid.tolist())
+                else:
+                    offset_sec = random.choice([-self.crop_len_sec, self.crop_len_sec])  # either - or + offset
+                # aud starts `offset_sec` earlier than it should; aud has what will be shown after offset_sec
+
+                offset_sec = round(offset_sec, 2)
+                v_start_max_sec = frames2sec(v_len_frames - v_crop_len_frames, v_fps)
+                assert v_start_max_sec > 0, f'{v_len_frames} {v_crop_len_frames} {v_fps} @ {item["path"]}'
+                # `v_start_sec` IS NOT rounded to the fps grid
+                v_start_sec = random.uniform(max(0, -offset_sec), min(v_start_max_sec, v_start_max_sec - offset_sec))
+                assert 0 <= v_start_sec <= v_start_max_sec, f'{v_start_sec} 
{v_start_max_sec} {item["path"]}' + v_start_i = sec2frames(v_start_sec, v_fps) + v_end_i = v_start_i + v_crop_len_frames + # `v_start_i_sec` IS rounded to the fps grid + v_start_i_sec = frames2sec(v_start_i, v_fps) + # `a_start_i` depends on the rounded value `v_start_i_sec`, otherwise + # (v_start_sec) we have ±0.1 jittering + a_start_i = sec2frames(v_start_i_sec + offset_sec, a_fps) + if self.max_a_jitter_sec is not None and self.max_a_jitter_sec > 0: + a_start_i, a_jitter_i = apply_a_jitter( + a_start_i, a_len_frames, a_crop_len_frames, a_fps, self.max_a_jitter_sec + ) + item["meta"]["a_jitter_i"] = a_jitter_i + a_end_i = a_start_i + a_crop_len_frames + else: + offset_sec = round(offset_sec, 2) + v_start_i = sec2frames(v_start_i_sec, v_fps) + a_start_i = sec2frames(v_start_i_sec + offset_sec, a_fps) + v_end_i = v_start_i + v_crop_len_frames + a_end_i = a_start_i + a_crop_len_frames + else: + offset_sec = 0.0 + is_random_crop = item["split"] == "train" + v_start_i, v_end_i = self.get_crop_idx(v_len_frames, v_crop_len_frames, is_random=is_random_crop) + v_start_i_sec = frames2sec(v_start_i, v_fps) + a_start_i = sec2frames(v_start_i_sec, a_fps) + if self.max_a_jitter_sec is not None and self.max_a_jitter_sec > 0: + a_start_i, a_jitter_i = apply_a_jitter( + a_start_i, a_len_frames, a_crop_len_frames, a_fps, self.max_a_jitter_sec + ) + item["meta"]["a_jitter_i"] = a_jitter_i + a_end_i = a_start_i + a_crop_len_frames + + # sometimes due to the rounding error e.g. v_start_sec = 1.505 but sec2frames(1.505, 25) = 1.48 + # given offset is -1.5, the a_start_i will be a small negative value. (likely a_fps * 1/v_fps * 0.5) + if a_start_i < 0: + how_much_out = a_start_i + logging.info(f'a_start_i is negative ({how_much_out}) at {item["path"]}') + if abs(how_much_out) <= a_fps / v_fps: + logging.info("fixing it") + a_start_i += abs(how_much_out) + a_end_i += abs(how_much_out) + else: + raise Exception(f'{how_much_out} {item["path"]}') + + assert v_start_i < v_end_i and a_start_i < a_end_i + assert aud.shape[0] >= a_end_i, f'{aud.shape} {a_end_i} {item["path"]}' + assert vid.shape[0] >= v_end_i, f'{vid.shape} {v_end_i} {item["path"]}' + + vid, aud = vid[v_start_i:v_end_i, :, :, :], aud[a_start_i:a_end_i] + + item["video"] = vid + item["audio"] = aud + + assert item["video"].shape[0] == int(v_fps * self.crop_len_sec), f'{item["video"].shape} {item["path"]}' + assert item["audio"].shape[0] == int(a_fps * self.crop_len_sec), f'{item["audio"].shape} {item["path"]}' + + # caching parameters + if self.do_offset: + # NOTE: this is useless for the extreme offsetting + offset_label, offset_target = quantize_offset(self.class_grid, offset_sec) + item["targets"]["offset_sec"] = offset_sec + item["targets"]["offset_label"] = offset_label + # assert 'offset_target' not in item['targets'], f'{item["targets"]}. What passed it there?' 
+ item["targets"]["offset_target"] = offset_target + item["targets"]["v_start_i_sec"] = v_start_i_sec + item["targets"]["sync_target"] = int(offset_is_syncable) + + return item + + def get_crop_idx(self, len_frames: int, crop_len_frames: int, is_random=True): + if len_frames == crop_len_frames: + return 0, len_frames + if is_random: + left_i = random.randint(0, len_frames - crop_len_frames) + else: + left_i = int(round((len_frames - crop_len_frames) / 2.0)) + return left_i, left_i + crop_len_frames + + +class RGBToFloatToZeroOne(torch.nn.Module): + + def __init__(self) -> None: + super().__init__() + + def forward(self, item): + item["video"] = item["video"].to(torch.float32).div(255.0) + return item + + +class RGBToHalfToZeroOne(torch.nn.Module): + + def __init__(self) -> None: + super().__init__() + + def forward(self, item): + item["video"] = item["video"].half().div(255.0) + return item + + +class RGBNormalize(torchvision.transforms.Normalize): + """The same as the torchvision`s but with different interface for the dict. + This should work for any shape (..., C, H, W)""" + + def __init__(self, mean, std, inplace=False): + super().__init__(mean, std, inplace) + logging.info(f"RGBNormalize: mean={mean}, std={std}") + + def forward(self, item): + item["video"] = super().forward(item["video"]) + item["meta"]["video"]["norm_stats"] = {"mean": torch.as_tensor(self.mean), "std": torch.as_tensor(self.std)} + return item + + +class AudioRandomVolume(torch.nn.Module): + + def __init__(self, p: float, **kwargs): + super().__init__() + transform = torchaudio.transforms.Vol(**kwargs) + self.transform = torchvision.transforms.RandomApply([transform], p) + + def apply_to_single_clip(self, clip): + return self.transform(clip) + + def apply_to_each_clip(self, clips): + for i, clip in enumerate(clips): + clips[i] = self.apply_to_single_clip(clip) + return clips + + def forward(self, item): + has_batch_dim = len(item["audio"].shape) == 2 + if has_batch_dim: + fn = self.apply_to_each_clip + else: + fn = self.apply_to_single_clip + item["audio"] = fn(item["audio"]) + return item + + +class AudioRandomLowpassFilter(torch.nn.Module): + + def __init__(self, p: float, cutoff_freq: float, Q: float = 0.707): + super().__init__() + self.p = p + self.cutoff_freq = cutoff_freq + self.Q = Q + + def apply_to_single_clip(self, clip, sr): + if self.p > torch.rand(1): + return torchaudio.functional.lowpass_biquad(clip, sr, self.cutoff_freq, self.Q) + else: + return clip + + def apply_to_each_clip(self, clips, sr): + for i, clip in enumerate(clips): + clips[i] = self.apply_to_single_clip(clip, sr) + return clips + + def forward(self, item): + has_batch_dim = len(item["audio"].shape) == 2 + sr = int(item["meta"]["audio"]["framerate"][0]) + if has_batch_dim: + fn = self.apply_to_each_clip + else: + fn = self.apply_to_single_clip + item["audio"] = fn(item["audio"], sr) + return item + + +class AudioRandomPitchShift(torch.nn.Module): + + def __init__(self, p: float, shift: int) -> None: + super().__init__() + self.p = p + self.shift = shift + + def apply_to_single_clip(self, wave, sr): + if self.p > torch.rand(1): + effects = [["pitch", f"{self.shift}"], ["rate", f"{sr}"]] + wave = wave.unsqueeze(0) + wave, _ = torchaudio.sox_effects.apply_effects_tensor(wave, sr, effects) + wave = wave.squeeze(0) + return wave + + def apply_to_each_clip(self, waves, sr): + for i, wave in enumerate(waves): + waves[i] = self.apply_to_single_clip(wave, sr) + return waves + + def forward(self, item): + has_batch_dim = len(item["audio"].shape) 
+    def forward(self, item):
+        has_batch_dim = len(item["audio"].shape) == 2
+        sr = int(item["meta"]["audio"]["framerate"][0])
+        if has_batch_dim:
+            fn = self.apply_to_each_clip
+        else:
+            fn = self.apply_to_single_clip
+        item["audio"] = fn(item["audio"], sr)
+        return item
+
+
+class AudioRandomReverb(torch.nn.Module):
+
+    def __init__(self, p: float) -> None:
+        super().__init__()
+        self.p = p
+        self.effects = [["reverb", "-w"]]
+
+    def apply_to_single_clip(self, wave, fps):
+        if self.p > torch.rand(1):
+            wave = wave.unsqueeze(0)
+            wave, _ = torchaudio.sox_effects.apply_effects_tensor(wave, fps, self.effects)
+            wave = wave.mean(dim=0)
+        return wave
+
+    def apply_to_each_clip(self, waves, fps):
+        for i, wave in enumerate(waves):
+            waves[i] = self.apply_to_single_clip(wave, fps)
+        return waves
+
+    def forward(self, item):
+        has_batch_dim = len(item["audio"].shape) == 2
+        sr = int(item["meta"]["audio"]["framerate"][0])
+        if has_batch_dim:
+            fn = self.apply_to_each_clip
+        else:
+            fn = self.apply_to_single_clip
+        item["audio"] = fn(item["audio"], sr)
+        return item
+
+
+class AudioRandomGaussNoise(torch.nn.Module):
+
+    def __init__(self, p: float, amplitude=0.01) -> None:
+        super().__init__()
+        self.p = p
+        self.amplitude = amplitude
+
+    def apply_to_single_clip(self, wave):
+        if self.p > torch.rand(1):
+            noise = torch.randn_like(wave, dtype=wave.dtype)
+            wave = wave + self.amplitude * noise
+        return wave
+
+    def apply_to_each_clip(self, waves):
+        for i, wave in enumerate(waves):
+            waves[i] = self.apply_to_single_clip(wave)
+        return waves
+
+    def forward(self, item):
+        has_batch_dim = len(item["audio"].shape) == 2
+        if has_batch_dim:
+            fn = self.apply_to_each_clip
+        else:
+            fn = self.apply_to_single_clip
+        item["audio"] = fn(item["audio"])
+        return item
+
+
+class AudioMelSpectrogram(torch.nn.Module):
+
+    def __init__(self, **kwargs):
+        super().__init__()
+        self.spec = torchaudio.transforms.MelSpectrogram(**kwargs)
+
+    def forward(self, item):
+        item["audio"] = self.spec(item["audio"])  # safe for batched input
+        return item
+
+
+class AudioLog(torch.nn.Module):
+
+    def __init__(self, eps=1e-6) -> None:
+        super().__init__()
+        self.eps = eps
+
+    def forward(self, item):
+        item["audio"] = torch.log(item["audio"] + self.eps)
+        return item
+
+
+class PadOrTruncate(torch.nn.Module):
+
+    def __init__(self, max_spec_t: int, pad_mode: str = "constant", pad_value: float = 0.0):
+        super().__init__()
+        self.max_spec_t = max_spec_t
+        self.pad_mode = pad_mode
+        self.pad_value = pad_value
+
+    def forward(self, item):
+        item["audio"] = self.pad_or_truncate(item["audio"])
+        return item
+
+    def pad_or_truncate(self, audio):
+        difference = self.max_spec_t - audio.shape[-1]  # safe for batched input
+        # pad or truncate, depending on difference
+        if difference > 0:
+            # pad the last dim (time) -> (..., n_mels, 0+time+difference)  # safe for batched input
+            pad_dims = (0, difference)
+            audio = torch.nn.functional.pad(audio, pad_dims, self.pad_mode, self.pad_value)
+        elif difference < 0:
+            logging.warning(f"Truncating spec ({audio.shape}) to max_spec_t ({self.max_spec_t}).")
+            audio = audio[..., : self.max_spec_t]  # safe for batched input
+        return audio
+
+
+class AudioNormalizeAST(torch.nn.Module):
+    """Normalizes with the specified mean and std; as in AST, the result is additionally
+    halved, i.e. audio -> (audio - mean) / (2 * std)"""
+
+    def __init__(self, mean: float, std: float) -> None:
+        super().__init__()
+        self.mean = mean
+        self.std = std
+
+    def forward(self, item):
+        item["audio"] = (item["audio"] - self.mean) / (2 * self.std)
+        item["meta"]["audio"]["norm_stats"] = {"mean": self.mean, "std": self.std}
+        return item
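+# e.g. with the AudioSet stats used in the test transforms (mean=-4.2677393, std=4.5689974),
+# a log-mel value of 0.0 maps to (0.0 - (-4.2677393)) / (2 * 4.5689974) ≈ 0.467.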
+
+
+class PermuteStreams(torch.nn.Module):
+
+    def __init__(self, einops_order_audio: str, einops_order_rgb: str) -> None:
+        '''For example:
+        einops_order_audio: "S F T -> S T F"
+        einops_order_rgb: "S T C H W -> S C T H W"'''
+        super().__init__()
+        self.einops_order_audio = einops_order_audio
+        self.einops_order_rgb = einops_order_rgb
+
+    def forward(self, item):
+        if self.einops_order_audio is not None:
+            item["audio"] = einops.rearrange(item["audio"], self.einops_order_audio).contiguous()
+        if self.einops_order_rgb is not None:
+            item["video"] = einops.rearrange(item["video"], self.einops_order_rgb).contiguous()
+        return item
+
+
+class ResampleAudio(torch.nn.Module):
+
+    def __init__(self, new_fps: int):
+        super().__init__()
+        self.new_fps = new_fps
+
+    def forward(self, item):
+        orig_fps = int(item["meta"]["audio"]["framerate"][0])
+        item["meta"]["audio"]["orig_shape"] = item["audio"].shape
+        if orig_fps != self.new_fps:
+            item["audio"] = torchaudio.functional.resample(item["audio"], orig_fps, self.new_fps)
+            item["meta"]["audio"]["framerate"][0] = self.new_fps
+        return item
+
+
+class ResampleRGB(torch.nn.Module):
+
+    def __init__(self, new_fps: int) -> None:
+        super().__init__()
+        self.new_fps = new_fps
+
+    def forward(self, item):
+        orig_fps = float(item["meta"]["video"]["fps"][0])
+        item["meta"]["video"]["orig_shape"] = item["video"].shape
+        if orig_fps != self.new_fps:
+            duration_sec = item["video"].shape[0] / orig_fps
+            indices = torch.arange(0, orig_fps * duration_sec - 1e-9, orig_fps / self.new_fps)
+            # basically, rounding
+            indices = indices.to(dtype=torch.long)
+            item["video"] = item["video"][indices]
+            item["meta"]["video"]["fps"][0] = self.new_fps
+        return item
+
+
+class ResizeAndLetterboxPad(torch.nn.Module):
+    """Adapted from the WACV'24 Amazon challenge"""
+
+    def __init__(self, new_h, new_w):
+        super().__init__()
+        self.new_h = new_h
+        self.new_w = new_w
+        self.aspect_ratio = new_w / new_h
+
+    def forward(self, item):
+        item["video"] = self.resize_and_pad(item["video"])
+        return item
+
+    def resize_and_pad(self, rgb: torch.Tensor):
+        _, _, height, width = rgb.shape
+        current_aspect_ratio = width / height
+        if current_aspect_ratio > self.aspect_ratio:
+            scaled_height = round(self.new_w / current_aspect_ratio)
+            rgb = torchvision.transforms.functional.resize(rgb, (scaled_height, self.new_w), antialias=None)
+            top = (self.new_h - scaled_height) // 2
+            bottom = self.new_h - (scaled_height + top)
+            rgb = torch.nn.ConstantPad2d((0, 0, top, bottom), 0)(rgb)
+        elif current_aspect_ratio < self.aspect_ratio:
+            scaled_width = round(self.new_h * current_aspect_ratio)
+            rgb = torchvision.transforms.functional.resize(rgb, (self.new_h, scaled_width), antialias=None)
+            left = (self.new_w - scaled_width) // 2
+            right = self.new_w - (scaled_width + left)
+            rgb = torch.nn.ConstantPad2d((left, right, 0, 0), 0)(rgb)
+        return rgb
+
+
+class ResampleResizeLetterboxPad(torch.nn.Module):
+
+    def __init__(self, afps, vfps, new_h, new_w) -> None:
+        super().__init__()
+        self.transforms = torchvision.transforms.Compose(
+            [ResampleAudio(new_fps=afps), ResampleRGB(new_fps=vfps), ResizeAndLetterboxPad(new_h=new_h, new_w=new_w)]
+        )
+
+    def forward(self, x: dict) -> dict:
+        return self.transforms(x)
+
+
+class DoNothing(torch.nn.Module):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__()
+
+    def forward(self, x: dict) -> dict:
+        return x
+
+
+if __name__ == "__main__":
+    grid = make_class_grid(-1, 1, 21)
+    grid = make_class_grid(-2, 2, 41)
+    print("grid:", grid)
+    print("value quantization:", 
quantize_offset(grid, 0.06)) + v_fps = 25.0 + duration = 10.0 + + input = { + "video": torch.randint(0, 256, (int(duration * v_fps), 3, 720 // 2, 1280 // 2), dtype=torch.uint8), + "audio": torch.arange(221184 - 1).float(), + "targets": {}, + "meta": { + "video": {"duration": [duration], "fps": [v_fps]}, + "audio": {"duration": [duration], "framerate": [22050.0]}, + "subtitles": {"duration": []}, + "cc": {"duration": []}, + }, + "path": "/home/nvme/data/vggsound/video/-5cWCaoEDlE_261000_271000.mp4", + "split": "train", + } + + print(input["audio"].shape, input["video"].shape) + + fn = EqualifyFromRight(clip_max_len_sec=10) + input = fn(input) + print(input["audio"].shape, input["video"].shape) + + fn = RGBSpatialCrop((224, 224), is_random=True) + # fn = RGBSpatialCrop((112, 112), is_random=True) + input = fn(input) + print(input["audio"].shape, input["video"].shape, input["meta"]["audio"]) + + fn = Resize((224, 224)) + input = fn(input) + print(input["audio"].shape, input["video"].shape, input["meta"]["audio"]) + + fn = GenerateMultipleSegments( + segment_size_vframes=16, n_segments=14, is_start_random=False, audio_jitter_sec=0.05, step_size_seg=0.5 + ) + input = fn(input) + print(input["audio"].shape, input["video"].shape, input["meta"]["audio"]) + + fn = RandomApplyColorDistortion(p_gray_scale=0.5, p_color_jitter=0.5, s=1.0) + input = fn(input) + print(input["audio"].shape, input["video"].shape, input["meta"]["audio"]) + + fn = RGBToFloatToZeroOne() + input = fn(input) + print(input["audio"].shape, input["video"].shape, input["meta"]["audio"]) + print(input["meta"]) + + fn = RGBNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + input = fn(input) + print(input["audio"].shape, input["video"].shape, input["meta"]["audio"]) + print(input["video"].mean(dim=(0, 2, 3))) + print(input["meta"]) + + fn = AudioRandomReverb(p=1.0) + input = fn(input) + + fn = AudioRandomVolume(p=1.0, gain=2.0, gain_type="amplitude") + input = fn(input) + print(input["audio"].shape, input["video"].shape, input["meta"]["audio"]) + + fn = AudioRandomPitchShift(p=1.0, shift=1000) + input = fn(input) + print(input["audio"].shape, input["video"].shape, input["meta"]["audio"]) + + fn = AudioRandomLowpassFilter(p=1.0, cutoff_freq=100) + input = fn(input) + print(input["audio"].shape, input["video"].shape, input["meta"]["audio"]) + + fn = AudioRandomGaussNoise(p=1.0, amplitude=0.01) + input = fn(input) + print(input["audio"].shape, input["video"].shape, input["meta"]["audio"]) + + fn = AudioLog() + input = fn(input) + print(input["audio"].shape, input["video"].shape, input["meta"]["audio"]) + + # audio only + input = { + "audio": torch.arange(221184).float(), + "meta": { + "video": {"duration": [10.0], "fps": [10.0]}, + "audio": {"duration": [11.0], "framerate": [22050.0]}, + "subtitles": {"duration": []}, + "cc": {"duration": []}, + }, + "path": "/home/nvme/data/vggsound/video/-5cWCaoEDlE_261000_271000.mp4", + } + + print(input["audio"].shape) + + fn = AudioLog() + input = fn(input) + print(input["audio"].shape, input["meta"]["audio"]) + print(input["meta"]) + print(input["audio"].min(), input["audio"].max()) diff --git a/hunyuanvideo_foley/models/synchformer/divided_224_16x4.yaml b/hunyuanvideo_foley/models/synchformer/divided_224_16x4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f9d20b76302a8af7928391643bd4b2d184e970aa --- /dev/null +++ b/hunyuanvideo_foley/models/synchformer/divided_224_16x4.yaml @@ -0,0 +1,84 @@ +TRAIN: + ENABLE: True + DATASET: Ssv2 + BATCH_SIZE: 32 + EVAL_PERIOD: 5 + 
CHECKPOINT_PERIOD: 5 + AUTO_RESUME: True + CHECKPOINT_EPOCH_RESET: True + CHECKPOINT_FILE_PATH: /checkpoint/fmetze/neurips_sota/40944587/checkpoints/checkpoint_epoch_00035.pyth +DATA: + NUM_FRAMES: 16 + SAMPLING_RATE: 4 + TRAIN_JITTER_SCALES: [256, 320] + TRAIN_CROP_SIZE: 224 + TEST_CROP_SIZE: 224 + INPUT_CHANNEL_NUM: [3] + MEAN: [0.5, 0.5, 0.5] + STD: [0.5, 0.5, 0.5] + PATH_TO_DATA_DIR: /private/home/mandelapatrick/slowfast/data/ssv2 + PATH_PREFIX: /datasets01/SomethingV2/092720/20bn-something-something-v2-frames + INV_UNIFORM_SAMPLE: True + RANDOM_FLIP: False + REVERSE_INPUT_CHANNEL: True + USE_RAND_AUGMENT: True + RE_PROB: 0.0 + USE_REPEATED_AUG: False + USE_RANDOM_RESIZE_CROPS: False + COLORJITTER: False + GRAYSCALE: False + GAUSSIAN: False +SOLVER: + BASE_LR: 1e-4 + LR_POLICY: steps_with_relative_lrs + LRS: [1, 0.1, 0.01] + STEPS: [0, 20, 30] + MAX_EPOCH: 35 + MOMENTUM: 0.9 + WEIGHT_DECAY: 5e-2 + WARMUP_EPOCHS: 0.0 + OPTIMIZING_METHOD: adamw + USE_MIXED_PRECISION: True + SMOOTHING: 0.2 +SLOWFAST: + ALPHA: 8 +VIT: + PATCH_SIZE: 16 + PATCH_SIZE_TEMP: 2 + CHANNELS: 3 + EMBED_DIM: 768 + DEPTH: 12 + NUM_HEADS: 12 + MLP_RATIO: 4 + QKV_BIAS: True + VIDEO_INPUT: True + TEMPORAL_RESOLUTION: 8 + USE_MLP: True + DROP: 0.0 + POS_DROPOUT: 0.0 + DROP_PATH: 0.2 + IM_PRETRAINED: True + HEAD_DROPOUT: 0.0 + HEAD_ACT: tanh + PRETRAINED_WEIGHTS: vit_1k + ATTN_LAYER: divided +MODEL: + NUM_CLASSES: 174 + ARCH: slow + MODEL_NAME: VisionTransformer + LOSS_FUNC: cross_entropy +TEST: + ENABLE: True + DATASET: Ssv2 + BATCH_SIZE: 64 + NUM_ENSEMBLE_VIEWS: 1 + NUM_SPATIAL_CROPS: 3 +DATA_LOADER: + NUM_WORKERS: 4 + PIN_MEMORY: True +NUM_GPUS: 8 +NUM_SHARDS: 4 +RNG_SEED: 0 +OUTPUT_DIR: . +TENSORBOARD: + ENABLE: True diff --git a/hunyuanvideo_foley/models/synchformer/modeling_ast.py b/hunyuanvideo_foley/models/synchformer/modeling_ast.py new file mode 100644 index 0000000000000000000000000000000000000000..f456753ecfff180dd36a3d2ff3e50a47ab735d52 --- /dev/null +++ b/hunyuanvideo_foley/models/synchformer/modeling_ast.py @@ -0,0 +1,673 @@ +# coding=utf-8 +# Copyright 2022 MIT and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+# Modified by v-iashin to support token masking
+
+"""PyTorch Audio Spectrogram Transformer (AST) model."""
+
+import math
+from typing import Dict, List, Optional, Set, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
+from transformers.modeling_utils import PreTrainedModel
+from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from transformers.models.audio_spectrogram_transformer.modeling_audio_spectrogram_transformer import ASTConfig
+from transformers.utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+)
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "ASTConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "MIT/ast-finetuned-audioset-10-10-0.4593"
+_EXPECTED_OUTPUT_SHAPE = [1, 1214, 768]
+
+# Audio classification docstring
+_SEQ_CLASS_CHECKPOINT = "MIT/ast-finetuned-audioset-10-10-0.4593"
+_SEQ_CLASS_EXPECTED_OUTPUT = "'Speech'"
+_SEQ_CLASS_EXPECTED_LOSS = 0.17
+
+
+AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "MIT/ast-finetuned-audioset-10-10-0.4593",
+    # See all Audio Spectrogram Transformer models at https://huggingface.co/models?filter=ast
+]
+
+
+class ASTEmbeddings(nn.Module):
+    """
+    Construct the CLS token, position and patch embeddings.
+    """
+
+    def __init__(self, config: ASTConfig) -> None:
+        super().__init__()
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+        self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+        self.patch_embeddings = ASTPatchEmbeddings(config)
+
+        frequency_out_dimension, time_out_dimension = self.get_shape(config)
+        num_patches = frequency_out_dimension * time_out_dimension
+        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.config = config
+
+    def get_shape(self, config):
+        # see Karpathy's cs231n notes on how to calculate the output dimensions
+        # https://cs231n.github.io/convolutional-networks/#conv
+        frequency_out_dimension = (config.num_mel_bins - config.patch_size) // config.frequency_stride + 1
+        time_out_dimension = (config.max_length - config.patch_size) // config.time_stride + 1
+
+        return frequency_out_dimension, time_out_dimension
+
+    def forward(self, input_values: torch.Tensor) -> torch.Tensor:
+        batch_size = input_values.shape[0]
+        embeddings = self.patch_embeddings(input_values)
+
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+        distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
+        embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
+        embeddings = embeddings + self.position_embeddings
+        embeddings = self.dropout(embeddings)
+
+        return embeddings
+
+
+class ASTPatchEmbeddings(nn.Module):
+    """
+    This class turns `input_values` into the initial `hidden_states` (patch embeddings) of shape `(batch_size,
+    seq_length, hidden_size)` to be consumed by a Transformer.
+ """ + + def __init__(self, config): + super().__init__() + + patch_size = config.patch_size + frequency_stride = config.frequency_stride + time_stride = config.time_stride + + self.projection = nn.Conv2d( + 1, config.hidden_size, kernel_size=(patch_size, patch_size), stride=(frequency_stride, time_stride) + ) + + def forward(self, input_values: torch.Tensor) -> torch.Tensor: + input_values = input_values.unsqueeze(1) + input_values = input_values.transpose(2, 3) + embeddings = self.projection(input_values).flatten(2).transpose(1, 2) + return embeddings + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->AST +class ASTSelfAttention(nn.Module): + def __init__(self, config: ASTConfig) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + tok_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # apply masking if provided, tok_mask is (BS, N): 1s - keep; attention_scores is (BS, H, N, N) + if tok_mask is not None: + BS, N = tok_mask.shape + attention_scores = attention_scores.masked_fill(tok_mask.view(BS, 1, 1, N) == 0, float("-inf")) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
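+        # (concretely: dropout zeroes individual entries of the (BS, H, N, N) probability
+        # matrix, so a given key token becomes invisible to some queries for this pass)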
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->AST +class ASTSelfOutput(nn.Module): + """ + The residual connection is defined in ASTLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: ASTConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->AST +class ASTAttention(nn.Module): + def __init__(self, config: ASTConfig) -> None: + super().__init__() + self.attention = ASTSelfAttention(config) + self.output = ASTSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + tok_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, tok_mask, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->AST +class ASTIntermediate(nn.Module): + def __init__(self, config: ASTConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTOutput with 
ViT->AST +class ASTOutput(nn.Module): + def __init__(self, config: ASTConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states + input_tensor + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST +class ASTLayer(nn.Module): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config: ASTConfig) -> None: + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = ASTAttention(config) + self.intermediate = ASTIntermediate(config) + self.output = ASTOutput(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + tok_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in AST, layernorm is applied before self-attention + tok_mask, + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = attention_output + hidden_states + + # in AST, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.output(layer_output, hidden_states) + + outputs = (layer_output,) + outputs + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->AST +class ASTEncoder(nn.Module): + def __init__(self, config: ASTConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([ASTLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + tok_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + tok_mask, + layer_head_mask, + ) + else: + layer_outputs = layer_module(hidden_states, tok_mask, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if 
output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ASTPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ASTConfig + base_model_prefix = "audio_spectrogram_transformer" + main_input_name = "input_values" + supports_gradient_checkpointing = True + + # Copied from transformers.models.deit.modeling_deit.DeiTPreTrainedModel._init_weights + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + # Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._set_gradient_checkpointing with ViT->AST + def _set_gradient_checkpointing(self, module: ASTEncoder, value: bool = False) -> None: + if isinstance(module, ASTEncoder): + module.gradient_checkpointing = value + + +AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`ASTConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING = r""" + Args: + input_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See + [`ASTFeatureExtractor.__call__`] for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +@add_start_docstrings( + "The bare AST Model transformer outputting raw hidden-states without any specific head on top.", + AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING, +) +class ASTModel(ASTPreTrainedModel): + def __init__(self, config: ASTConfig): + super().__init__(config) + self.config = config + + self.embeddings = ASTEmbeddings(config) + self.encoder = ASTEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> ASTPatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_values: Optional[torch.Tensor] = None, + cont_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_values is None: + raise ValueError("You have to specify input_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings(input_values) + + # transforms the mask that has spectrogram dims to the token masking which is obtained after patching. + # Due to the ovelap in patching, getting the token mask from spectrogram mask is not straightforward, + # because one 16x16 content patch is encoded in two tokens if stride is <16. So, to get the mask for + # tokens I will apply the patching func (self.embeddings) to the tensor with infinities at the masked + # content position. For infs, the patching fn will return nans, which I'll use to get the token mask. 
+ if cont_mask is not None: + indicator = torch.ones_like(input_values).to(input_values.dtype) + # replace content mask (0s) with infs + indicator[~cont_mask] = torch.inf + # apply patching; now nans are where the content mask was + with torch.no_grad(): + indicator = self.embeddings(indicator) # BS, N, D + # replace nans with 0s; these are the tokens that correspond to the masked content + tok_mask = ~torch.isnan(indicator) + # since all values in the D-dimension (latent) will also be nans, we can just use the first el + tok_mask = tok_mask[:, :, 0] # (BS, 2+num_patches) -- 2 is from CLS and DISTIL tokens + else: + tok_mask = None + + encoder_outputs = self.encoder( + embedding_output, + tok_mask=tok_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + + pooled_output = (sequence_output[:, 0] + sequence_output[:, 1]) / 2 + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return ( + BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ), + tok_mask, + ) + + +class ASTMLPHead(nn.Module): + def __init__(self, config: ASTConfig): + super().__init__() + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dense = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + + def forward(self, hidden_state): + hidden_state = self.layernorm(hidden_state) + hidden_state = self.dense(hidden_state) + return hidden_state + + +@add_start_docstrings( + """ + Audio Spectrogram Transformer model with an audio classification head on top (a linear layer on top of the pooled + output) e.g. for datasets like AudioSet, Speech Commands v2. + """, + AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING, +) +class ASTForAudioClassification(ASTPreTrainedModel): + def __init__(self, config: ASTConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.audio_spectrogram_transformer = ASTModel(config) + + # Classifier head + self.classifier = ASTMLPHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_SEQ_CLASS_CHECKPOINT, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_SEQ_CLASS_EXPECTED_OUTPUT, + expected_loss=_SEQ_CLASS_EXPECTED_LOSS, + ) + def forward( + self, + input_values: Optional[torch.Tensor] = None, + cont_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the audio classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.audio_spectrogram_transformer(
+            input_values,
+            cont_mask=cont_mask,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        # the modified ASTModel returns a (BaseModelOutputWithPooling, tok_mask) pair when
+        # return_dict=True, so unpack it before reading the pooled output
+        if return_dict:
+            outputs, _tok_mask = outputs
+            pooled_output = outputs.pooler_output
+        else:
+            pooled_output = outputs[1]
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
diff --git a/hunyuanvideo_foley/models/synchformer/motionformer.py b/hunyuanvideo_foley/models/synchformer/motionformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9980a6f6d667699a275b10c6f613a30493566713
--- /dev/null
+++ b/hunyuanvideo_foley/models/synchformer/motionformer.py
@@ -0,0 +1,397 @@
+import logging
+from pathlib import Path
+
+import einops
+import torch
+from omegaconf import OmegaConf
+from timm.layers import trunc_normal_
+from torch import nn
+
+from .utils import check_if_file_exists_else_download
+from .video_model_builder import VisionTransformer
+
+
+FILE2URL = {
+    # cfg
+    "motionformer_224_16x4.yaml": "https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/motionformer_224_16x4.yaml",
+    "joint_224_16x4.yaml": "https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/joint_224_16x4.yaml",
+    "divided_224_16x4.yaml": "https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/divided_224_16x4.yaml",
+    # ckpt
+    "ssv2_motionformer_224_16x4.pyth": "https://dl.fbaipublicfiles.com/motionformer/ssv2_motionformer_224_16x4.pyth",
+    "ssv2_joint_224_16x4.pyth": "https://dl.fbaipublicfiles.com/motionformer/ssv2_joint_224_16x4.pyth",
+    "ssv2_divided_224_16x4.pyth": "https://dl.fbaipublicfiles.com/motionformer/ssv2_divided_224_16x4.pyth",
+}
+
+
+class MotionFormer(VisionTransformer):
+    """This class serves three purposes:
+        1. Renames the class to MotionFormer.
+        2. Downloads the cfg from the original repo and patches it if needed.
+        3. Takes care of feature extraction by redefining .forward()
+            - if `extract_features=True` and `factorize_space_time=False`,
+              the output is of shape (B, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8
+            - if `extract_features=True` and `factorize_space_time=True`, the output is of shape (B*S, D)
+              and spatial and temporal transformer encoder layers are used.
+ - if `extract_features=True` and `factorize_space_time=True` as well as `add_global_repr=True` + the output is of shape (B, D) and spatial and temporal transformer encoder layers + are used as well as the global representation is extracted from segments (extra pos emb + is added). + """ + + def __init__( + self, + extract_features: bool = False, + ckpt_path: str = None, + factorize_space_time: bool = None, + agg_space_module: str = None, + agg_time_module: str = None, + add_global_repr: bool = True, + agg_segments_module: str = None, + max_segments: int = None, + ): + self.extract_features = extract_features + self.ckpt_path = ckpt_path + self.factorize_space_time = factorize_space_time + + if self.ckpt_path is not None: + check_if_file_exists_else_download(self.ckpt_path, FILE2URL) + ckpt = torch.load(self.ckpt_path, map_location="cpu") + mformer_ckpt2cfg = { + "ssv2_motionformer_224_16x4.pyth": "motionformer_224_16x4.yaml", + "ssv2_joint_224_16x4.pyth": "joint_224_16x4.yaml", + "ssv2_divided_224_16x4.pyth": "divided_224_16x4.yaml", + } + # init from motionformer ckpt or from our Stage I ckpt + # depending on whether the feat extractor was pre-trained on AVCLIPMoCo or not, we need to + # load the state dict differently + was_pt_on_avclip = self.ckpt_path.endswith(".pt") # checks if it is a stage I ckpt (FIXME: a bit generic) + if self.ckpt_path.endswith(tuple(mformer_ckpt2cfg.keys())): + cfg_fname = mformer_ckpt2cfg[Path(self.ckpt_path).name] + elif was_pt_on_avclip: + # TODO: this is a hack, we should be able to get the cfg from the ckpt (earlier ckpt didn't have it) + s1_cfg = ckpt.get("args", None) # Stage I cfg + if s1_cfg is not None: + s1_vfeat_extractor_ckpt_path = s1_cfg.model.params.vfeat_extractor.params.ckpt_path + # if the stage I ckpt was initialized from a motionformer ckpt or train from scratch + if s1_vfeat_extractor_ckpt_path is not None: + cfg_fname = mformer_ckpt2cfg[Path(s1_vfeat_extractor_ckpt_path).name] + else: + cfg_fname = "divided_224_16x4.yaml" + else: + cfg_fname = "divided_224_16x4.yaml" + else: + raise ValueError(f"ckpt_path {self.ckpt_path} is not supported.") + else: + was_pt_on_avclip = False + cfg_fname = "divided_224_16x4.yaml" + # logging.info(f'No ckpt_path provided, using {cfg_fname} config.') + + if cfg_fname in ["motionformer_224_16x4.yaml", "divided_224_16x4.yaml"]: + pos_emb_type = "separate" + elif cfg_fname == "joint_224_16x4.yaml": + pos_emb_type = "joint" + + self.mformer_cfg_path = Path(__file__).absolute().parent / cfg_fname + + check_if_file_exists_else_download(self.mformer_cfg_path, FILE2URL) + mformer_cfg = OmegaConf.load(self.mformer_cfg_path) + logging.info(f"Loading MotionFormer config from {self.mformer_cfg_path.absolute()}") + + # patch the cfg (from the default cfg defined in the repo `Motionformer/slowfast/config/defaults.py`) + mformer_cfg.VIT.ATTN_DROPOUT = 0.0 + mformer_cfg.VIT.POS_EMBED = pos_emb_type + mformer_cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE = True + mformer_cfg.VIT.APPROX_ATTN_TYPE = "none" # guessing + mformer_cfg.VIT.APPROX_ATTN_DIM = 64 # from ckpt['cfg'] + + # finally init VisionTransformer with the cfg + super().__init__(mformer_cfg) + + # load the ckpt now if ckpt is provided and not from AVCLIPMoCo-pretrained ckpt + if (self.ckpt_path is not None) and (not was_pt_on_avclip): + _ckpt_load_status = self.load_state_dict(ckpt["model_state"], strict=False) + if len(_ckpt_load_status.missing_keys) > 0 or len(_ckpt_load_status.unexpected_keys) > 0: + logging.warning( + f"Loading exact vfeat_extractor ckpt from 
{self.ckpt_path} failed. "
+                    f"Missing keys: {_ckpt_load_status.missing_keys}, "
+                    f"Unexpected keys: {_ckpt_load_status.unexpected_keys}"
+                )
+            else:
+                logging.info(f"Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.")
+
+        if self.extract_features:
+            assert isinstance(self.norm, nn.LayerNorm), "early x[:, 1:, :] may not be safe for per-tr weights"
+            # pre-logits are Sequential(nn.Linear(emb, emb), act) and `act` is tanh but see the logger
+            self.pre_logits = nn.Identity()
+            # we don't need the classification head (saving memory)
+            self.head = nn.Identity()
+            self.head_drop = nn.Identity()
+            # avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer)
+            transf_enc_layer_kwargs = dict(
+                d_model=self.embed_dim,
+                nhead=self.num_heads,
+                activation=nn.GELU(),
+                batch_first=True,
+                dim_feedforward=self.mlp_ratio * self.embed_dim,
+                dropout=self.drop_rate,
+                layer_norm_eps=1e-6,
+                norm_first=True,
+            )
+            # define adapters if needed
+            if self.factorize_space_time:
+                if agg_space_module == "TransformerEncoderLayer":
+                    self.spatial_attn_agg = SpatialTransformerEncoderLayer(**transf_enc_layer_kwargs)
+                elif agg_space_module == "AveragePooling":
+                    self.spatial_attn_agg = AveragePooling(
+                        avg_pattern="BS D t h w -> BS D t", then_permute_pattern="BS D t -> BS t D"
+                    )
+                if agg_time_module == "TransformerEncoderLayer":
+                    self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs)
+                elif agg_time_module == "AveragePooling":
+                    self.temp_attn_agg = AveragePooling(avg_pattern="BS t D -> BS D")
+                elif "Identity" in agg_time_module:
+                    self.temp_attn_agg = nn.Identity()
+            # define a global aggregation layer (aggregate over segments)
+            self.add_global_repr = add_global_repr
+            if add_global_repr:
+                if agg_segments_module == "TransformerEncoderLayer":
+                    # we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D)
+                    # we need to add pos emb (PE) because previously we added the same PE for each segment
+                    pos_max_len = max_segments if max_segments is not None else 16  # 16 = 10sec // 0.64sec + 1
+                    self.global_attn_agg = TemporalTransformerEncoderLayer(
+                        add_pos_emb=True,
+                        pos_emb_drop=mformer_cfg.VIT.POS_DROPOUT,
+                        pos_max_len=pos_max_len,
+                        **transf_enc_layer_kwargs,
+                    )
+                elif agg_segments_module == "AveragePooling":
+                    self.global_attn_agg = AveragePooling(avg_pattern="B S D -> B D")
+
+        if was_pt_on_avclip:
+            # we need to filter out the state_dict of the AVCLIP model (has both A and V extractors)
+            # and keep only the state_dict of the feat extractor
+            ckpt_weights = dict()
+            for k, v in ckpt["state_dict"].items():
+                if k.startswith(("module.v_encoder.", "v_encoder.")):
+                    k = k.replace("module.", "").replace("v_encoder.", "")
+                    ckpt_weights[k] = v
+            _load_status = self.load_state_dict(ckpt_weights, strict=False)
+            if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0:
+                logging.warning(
+                    f"Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed. \n"
+                    f"Missing keys ({len(_load_status.missing_keys)}): "
+                    f"{_load_status.missing_keys}, \n"
+                    f"Unexpected keys ({len(_load_status.unexpected_keys)}): "
+                    f"{_load_status.unexpected_keys} \n"
+                    f"temp_attn_agg weights are expected to be missing if the ckpt was pre-trained contrastively."
+                )
+            else:
+                logging.info(f"Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.")
+
+        # patch_embed is not used in MotionFormer, only patch_embed_3d, because cfg.VIT.PATCH_SIZE_TEMP > 1
+        # but it is used to calculate the number of patches, so we need to keep it
+        self.patch_embed.requires_grad_(False)
+
+    def forward(self, x):
+        """
+        x is of shape (B, S, C, T, H, W) where S is the number of segments.
+        """
+        # Batch, Segments, Channels, T=frames, Height, Width
+        B, S, C, T, H, W = x.shape
+        # Motionformer expects a tensor of shape (1, B, C, T, H, W).
+        # The first dimension (1) is a dummy dimension and won't be used:
+        # see `video_model_builder.video_input`.
+        # x = x.unsqueeze(0)  # (1, B, S, C, T, H, W)
+
+        orig_shape = (B, S, C, T, H, W)
+        x = x.view(B * S, C, T, H, W)  # flatten batch and segments
+        x = self.forward_segments(x, orig_shape=orig_shape)
+        # unpack the segments (using rest dimensions to support different shapes e.g. (BS, D) or (BS, t, D))
+        x = x.view(B, S, *x.shape[1:])
+        # x is now of shape (B, S, D), or (B, S, t, D) if `self.temp_attn_agg` is `Identity`
+
+        return x  # x is (B, S, ...)
+
+    def forward_segments(self, x, orig_shape: tuple) -> torch.Tensor:
+        """x is of shape (B*S, C, T, H, W), flattened over batch and segments."""
+        x, x_mask = self.forward_features(x)
+
+        assert self.extract_features
+
+        # (BS, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8
+        x = x[:, 1:, :]  # without the CLS token for efficiency (should be safe for LayerNorm and FC)
+        x = self.norm(x)
+        x = self.pre_logits(x)
+        if self.factorize_space_time:
+            x = self.restore_spatio_temp_dims(x, orig_shape)  # (B*S, D, t, h, w) <- (B*S, t*h*w, D)
+
+            x = self.spatial_attn_agg(x, x_mask)  # (B*S, t, D)
+            x = self.temp_attn_agg(x)  # (B*S, D) or (BS, t, D) if `self.temp_attn_agg` is `Identity`
+
+        return x
+
+    def restore_spatio_temp_dims(self, feats: torch.Tensor, orig_shape: tuple) -> torch.Tensor:
+        """
+        feats are of shape (B*S, T, D) where T = t * h * w (the CLS token was already removed).
+        Our goal is to make them of shape (B*S, D, t, h, w) where t, h, w are the temporal and spatial dims.
+        From `self.patch_embed_3d`, it follows that we could reshape feats with:
+            `feats.transpose(1, 2).view(B*S, D, t, h, w)`
+        """
+        B, S, C, T, H, W = orig_shape
+        D = self.embed_dim
+
+        # num patches in each dimension
+        t = T // self.patch_embed_3d.z_block_size
+        h = self.patch_embed_3d.height
+        w = self.patch_embed_3d.width
+
+        feats = feats.permute(0, 2, 1)  # (B*S, D, T)
+        feats = feats.view(B * S, D, t, h, w)  # (B*S, D, t, h, w)
+
+        return feats
+
+
+class BaseEncoderLayer(nn.TransformerEncoderLayer):
+    """
+    This is a wrapper around nn.TransformerEncoderLayer that adds a CLS token
+    to the sequence and outputs the CLS token's representation.
+    This base class parents both SpatialEncoderLayer and TemporalEncoderLayer for the RGB stream
+    and the FrequencyEncoderLayer and TemporalEncoderLayer for the audio stream.
+    We also, optionally, add a positional embedding to the input sequence, which
+    allows us to reuse it for global aggregation (of segments) for both streams.
+ """ + + def __init__( + self, + add_pos_emb: bool = False, + pos_emb_drop: float = None, + pos_max_len: int = None, + *args_transformer_enc, + **kwargs_transformer_enc, + ): + super().__init__(*args_transformer_enc, **kwargs_transformer_enc) + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.self_attn.embed_dim)) + trunc_normal_(self.cls_token, std=0.02) + + # add positional embedding + self.add_pos_emb = add_pos_emb + if add_pos_emb: + self.pos_max_len = 1 + pos_max_len # +1 (for CLS) + self.pos_emb = nn.Parameter(torch.zeros(1, self.pos_max_len, self.self_attn.embed_dim)) + self.pos_drop = nn.Dropout(pos_emb_drop) + trunc_normal_(self.pos_emb, std=0.02) + + self.apply(self._init_weights) + + def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None): + """x is of shape (B, N, D); if provided x_mask is of shape (B, N)""" + batch_dim = x.shape[0] + + # add CLS token + cls_tokens = self.cls_token.expand(batch_dim, -1, -1) # expanding to match batch dimension + x = torch.cat((cls_tokens, x), dim=-2) # (batch_dim, 1+seq_len, D) + if x_mask is not None: + cls_mask = torch.ones((batch_dim, 1), dtype=torch.bool, device=x_mask.device) # 1=keep; 0=mask + x_mask_w_cls = torch.cat((cls_mask, x_mask), dim=-1) # (batch_dim, 1+seq_len) + B, N = x_mask_w_cls.shape + # torch expects (N, N) or (B*num_heads, N, N) mask (sadness ahead); torch masks + x_mask_w_cls = ( + x_mask_w_cls.reshape(B, 1, 1, N) + .expand(-1, self.self_attn.num_heads, N, -1) + .reshape(B * self.self_attn.num_heads, N, N) + ) + assert x_mask_w_cls.dtype == x_mask_w_cls.bool().dtype, "x_mask_w_cls.dtype != bool" + x_mask_w_cls = ~x_mask_w_cls # invert mask (1=mask) + else: + x_mask_w_cls = None + + # add positional embedding + if self.add_pos_emb: + seq_len = x.shape[1] # (don't even think about moving it before the CLS token concatenation) + assert seq_len <= self.pos_max_len, f"Seq len ({seq_len}) > pos_max_len ({self.pos_max_len})" + x = x + self.pos_emb[:, :seq_len, :] + x = self.pos_drop(x) + + # apply encoder layer (calls nn.TransformerEncoderLayer.forward); + x = super().forward(src=x, src_mask=x_mask_w_cls) # (batch_dim, 1+seq_len, D) + + # CLS token is expected to hold spatial information for each frame + x = x[:, 0, :] # (batch_dim, D) + + return x + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {"cls_token", "pos_emb"} + + +class SpatialTransformerEncoderLayer(BaseEncoderLayer): + """Aggregates spatial dimensions by applying attention individually to each frame.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor: + """x is of shape (B*S, D, t, h, w) where S is the number of segments. 
+ if specified x_mask (B*S, t, h, w), 0=masked, 1=kept + Returns a tensor of shape (B*S, t, D) pooling spatial information for each frame.""" + BS, D, t, h, w = x.shape + + # time as a batch dimension and flatten spatial dimensions as sequence + x = einops.rearrange(x, "BS D t h w -> (BS t) (h w) D") + # similar to mask + if x_mask is not None: + x_mask = einops.rearrange(x_mask, "BS t h w -> (BS t) (h w)") + + # apply encoder layer (BaseEncoderLayer.forward) - it will add CLS token and output its representation + x = super().forward(x=x, x_mask=x_mask) # (B*S*t, D) + + # reshape back to (B*S, t, D) + x = einops.rearrange(x, "(BS t) D -> BS t D", BS=BS, t=t) + + # (B*S, t, D) + return x + + +class TemporalTransformerEncoderLayer(BaseEncoderLayer): + """Aggregates temporal dimension with attention. Also used with pos emb as global aggregation + in both streams.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + """x is of shape (B*S, t, D) where S is the number of segments. + Returns a tensor of shape (B*S, D) pooling temporal information.""" + BS, t, D = x.shape + + # apply encoder layer (BaseEncoderLayer.forward) - it will add CLS token and output its representation + x = super().forward(x) # (B*S, D) + + return x # (B*S, D) + + +class AveragePooling(nn.Module): + + def __init__(self, avg_pattern: str, then_permute_pattern: str = None) -> None: + """patterns are e.g. "bs t d -> bs d" """ + super().__init__() + # TODO: need to register them as buffers (but fails because these are strings) + self.reduce_fn = "mean" + self.avg_pattern = avg_pattern + self.then_permute_pattern = then_permute_pattern + + def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor: + x = einops.reduce(x, self.avg_pattern, self.reduce_fn) + if self.then_permute_pattern is not None: + x = einops.rearrange(x, self.then_permute_pattern) + return x diff --git a/hunyuanvideo_foley/models/synchformer/synchformer.py b/hunyuanvideo_foley/models/synchformer/synchformer.py new file mode 100644 index 0000000000000000000000000000000000000000..1238bbc68dc451a56121ba7ab1fc00aa290420c3 --- /dev/null +++ b/hunyuanvideo_foley/models/synchformer/synchformer.py @@ -0,0 +1,355 @@ +import logging +import math +from typing import Any, Mapping + +import einops +import numpy as np +import torch +import torchaudio +from torch import nn +from torch.nn import functional as F + +from .motionformer import MotionFormer +from .ast_model import AST +from .utils import Config + + +class Synchformer(nn.Module): + + def __init__(self): + super().__init__() + + self.vfeat_extractor = MotionFormer( + extract_features=True, + factorize_space_time=True, + agg_space_module="TransformerEncoderLayer", + agg_time_module="torch.nn.Identity", + add_global_repr=False, + ) + self.afeat_extractor = AST( + extract_features=True, + max_spec_t=66, + factorize_freq_time=True, + agg_freq_module="TransformerEncoderLayer", + agg_time_module="torch.nn.Identity", + add_global_repr=False, + ) + + # # bridging the s3d latent dim (1024) into what is specified in the config + # # to match e.g. 
the transformer dim + self.vproj = nn.Linear(in_features=768, out_features=768) + self.aproj = nn.Linear(in_features=768, out_features=768) + self.transformer = GlobalTransformer( + tok_pdrop=0.0, embd_pdrop=0.1, resid_pdrop=0.1, attn_pdrop=0.1, n_layer=3, n_head=8, n_embd=768 + ) + + def forward(self, vis): + B, S, Tv, C, H, W = vis.shape + vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W) + # feat extractors return a tuple of segment-level and global features (ignored for sync) + # (B, S, tv, D), e.g. (B, 7, 8, 768) + vis = self.vfeat_extractor(vis) + return vis + + def compare_v_a(self, vis: torch.Tensor, aud: torch.Tensor): + vis = self.vproj(vis) + aud = self.aproj(aud) + + B, S, tv, D = vis.shape + B, S, ta, D = aud.shape + vis = vis.view(B, S * tv, D) # (B, S*tv, D) + aud = aud.view(B, S * ta, D) # (B, S*ta, D) + # print(vis.shape, aud.shape) + + # self.transformer will concatenate the vis and aud in one sequence with aux tokens, + # ie `CvvvvMaaaaaa`, and will return the logits for the CLS tokens + logits = self.transformer(vis, aud) # (B, cls); or (B, cls) and (B, 2) if DoubtingTransformer + + return logits + + def extract_vfeats(self, vis): + B, S, Tv, C, H, W = vis.shape + vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W) + # feat extractors return a tuple of segment-level and global features (ignored for sync) + # (B, S, tv, D), e.g. (B, 7, 8, 768) + vis = self.vfeat_extractor(vis) + return vis + + def extract_afeats(self, aud): + B, S, _, Fa, Ta = aud.shape + aud = aud.view(B, S, Fa, Ta).permute(0, 1, 3, 2) # (B, S, Ta, F) + # (B, S, ta, D), e.g. (B, 7, 6, 768) + aud, _ = self.afeat_extractor(aud) + return aud + + def compute_loss(self, logits, targets, loss_fn: str = None): + loss = None + if targets is not None: + if loss_fn is None or loss_fn == "cross_entropy": + # logits: (B, cls) and targets: (B,) + loss = F.cross_entropy(logits, targets) + else: + raise NotImplementedError(f"Loss {loss_fn} not implemented") + return loss + + def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True): + # discard all entries except vfeat_extractor + # sd = {k: v for k, v in sd.items() if k.startswith('vfeat_extractor')} + + return super().load_state_dict(sd, strict) + + +class RandInitPositionalEncoding(nn.Module): + """Random inited trainable pos embedding. 
It is just applied on the sequence, thus respects no priors.""" + + def __init__(self, block_shape: list, n_embd: int): + super().__init__() + self.block_shape = block_shape + self.n_embd = n_embd + self.pos_emb = nn.Parameter(torch.randn(1, *block_shape, n_embd)) + + def forward(self, token_embeddings): + return token_embeddings + self.pos_emb + + +class GlobalTransformer(torch.nn.Module): + """Same as in SparseSync but without the selector transformers and the head""" + + def __init__( + self, + tok_pdrop=0.0, + embd_pdrop=0.1, + resid_pdrop=0.1, + attn_pdrop=0.1, + n_layer=3, + n_head=8, + n_embd=768, + pos_emb_block_shape=[ + 198, + ], + n_off_head_out=21, + ) -> None: + super().__init__() + self.config = Config( + embd_pdrop=embd_pdrop, + resid_pdrop=resid_pdrop, + attn_pdrop=attn_pdrop, + n_layer=n_layer, + n_head=n_head, + n_embd=n_embd, + ) + # input norm + self.vis_in_lnorm = torch.nn.LayerNorm(n_embd) + self.aud_in_lnorm = torch.nn.LayerNorm(n_embd) + # aux tokens + self.OFF_tok = torch.nn.Parameter(torch.randn(1, 1, n_embd)) + self.MOD_tok = torch.nn.Parameter(torch.randn(1, 1, n_embd)) + # whole token dropout + self.tok_pdrop = tok_pdrop + self.tok_drop_vis = torch.nn.Dropout1d(tok_pdrop) + self.tok_drop_aud = torch.nn.Dropout1d(tok_pdrop) + # maybe add pos emb + self.pos_emb_cfg = RandInitPositionalEncoding( + block_shape=pos_emb_block_shape, + n_embd=n_embd, + ) + # the stem + self.drop = torch.nn.Dropout(embd_pdrop) + self.blocks = torch.nn.Sequential(*[Block(self.config) for _ in range(n_layer)]) + # pre-output norm + self.ln_f = torch.nn.LayerNorm(n_embd) + # maybe add a head + self.off_head = torch.nn.Linear(in_features=n_embd, out_features=n_off_head_out) + + def forward(self, v: torch.Tensor, a: torch.Tensor, targets=None, attempt_to_apply_heads=True): + B, Sv, D = v.shape + B, Sa, D = a.shape + # broadcasting special tokens to the batch size + off_tok = einops.repeat(self.OFF_tok, "1 1 d -> b 1 d", b=B) + mod_tok = einops.repeat(self.MOD_tok, "1 1 d -> b 1 d", b=B) + # norm + v, a = self.vis_in_lnorm(v), self.aud_in_lnorm(a) + # maybe whole token dropout + if self.tok_pdrop > 0: + v, a = self.tok_drop_vis(v), self.tok_drop_aud(a) + # (B, 1+Sv+1+Sa, D) + x = torch.cat((off_tok, v, mod_tok, a), dim=1) + # maybe add pos emb + if hasattr(self, "pos_emb_cfg"): + x = self.pos_emb_cfg(x) + # dropout -> stem -> norm + x = self.drop(x) + x = self.blocks(x) + x = self.ln_f(x) + # maybe add heads + if attempt_to_apply_heads and hasattr(self, "off_head"): + x = self.off_head(x[:, 0, :]) + return x + + +class SelfAttention(nn.Module): + """ + A vanilla multi-head masked self-attention layer with a projection at the end. + It is possible to use torch.nn.MultiheadAttention here but I am including an + explicit implementation here to show that there is nothing too scary here. 
+ """ + + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads + self.key = nn.Linear(config.n_embd, config.n_embd) + self.query = nn.Linear(config.n_embd, config.n_embd) + self.value = nn.Linear(config.n_embd, config.n_embd) + # regularization + self.attn_drop = nn.Dropout(config.attn_pdrop) + self.resid_drop = nn.Dropout(config.resid_pdrop) + # output projection + self.proj = nn.Linear(config.n_embd, config.n_embd) + # # causal mask to ensure that attention is only applied to the left in the input sequence + # mask = torch.tril(torch.ones(config.block_size, + # config.block_size)) + # if hasattr(config, "n_unmasked"): + # mask[:config.n_unmasked, :config.n_unmasked] = 1 + # self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size)) + self.n_head = config.n_head + + def forward(self, x): + B, T, C = x.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + # att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + y = self.attn_drop(att) @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_drop(self.proj(y)) + + return y + + +class Block(nn.Module): + """an unassuming Transformer block""" + + def __init__(self, config): + super().__init__() + self.ln1 = nn.LayerNorm(config.n_embd) + self.ln2 = nn.LayerNorm(config.n_embd) + self.attn = SelfAttention(config) + self.mlp = nn.Sequential( + nn.Linear(config.n_embd, 4 * config.n_embd), + nn.GELU(), # nice + nn.Linear(4 * config.n_embd, config.n_embd), + nn.Dropout(config.resid_pdrop), + ) + + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + return x + + +def make_class_grid( + leftmost_val, + rightmost_val, + grid_size, + add_extreme_offset: bool = False, + seg_size_vframes: int = None, + nseg: int = None, + step_size_seg: float = None, + vfps: float = None, +): + assert grid_size >= 3, f"grid_size: {grid_size} doesnot make sense. 
If =2 -> (-1,1); =1 -> (-1); =0 -> ()" + grid = torch.from_numpy(np.linspace(leftmost_val, rightmost_val, grid_size)).float() + if add_extreme_offset: + assert all([seg_size_vframes, nseg, step_size_seg]), f"{seg_size_vframes} {nseg} {step_size_seg}" + seg_size_sec = seg_size_vframes / vfps + trim_size_in_seg = nseg - (1 - step_size_seg) * (nseg - 1) + extreme_value = trim_size_in_seg * seg_size_sec + grid = torch.cat([grid, torch.tensor([extreme_value])]) # adding extreme offset to the class grid + return grid + + +# from synchformer +def pad_or_truncate(audio: torch.Tensor, max_spec_t: int, pad_mode: str = "constant", pad_value: float = 0.0): + difference = max_spec_t - audio.shape[-1] # safe for batched input + # pad or truncate, depending on difference + if difference > 0: + # pad the last dim (time) -> (..., n_mels, 0+time+difference) # safe for batched input + pad_dims = (0, difference) + audio = torch.nn.functional.pad(audio, pad_dims, pad_mode, pad_value) + elif difference < 0: + print(f"Truncating spec ({audio.shape}) to max_spec_t ({max_spec_t}).") + audio = audio[..., :max_spec_t] # safe for batched input + return audio + + +def encode_audio_with_sync( + synchformer: Synchformer, x: torch.Tensor, mel: torchaudio.transforms.MelSpectrogram +) -> torch.Tensor: + b, t = x.shape + + # partition the video + segment_size = 10240 + step_size = 10240 // 2 + num_segments = (t - segment_size) // step_size + 1 + segments = [] + for i in range(num_segments): + segments.append(x[:, i * step_size : i * step_size + segment_size]) + x = torch.stack(segments, dim=1) # (B, S, T, C, H, W) + + x = mel(x) + x = torch.log(x + 1e-6) + x = pad_or_truncate(x, 66) + + mean = -4.2677393 + std = 4.5689974 + x = (x - mean) / (2 * std) + # x: B * S * 128 * 66 + x = synchformer.extract_afeats(x.unsqueeze(2)) + return x + + +def read_audio(filename, expected_length=int(16000 * 4)): + waveform, sr = torchaudio.load(filename) + waveform = waveform.mean(dim=0) + + if sr != 16000: + resampler = torchaudio.transforms.Resample(sr, 16000) + waveform = resampler[sr](waveform) + + waveform = waveform[:expected_length] + if waveform.shape[0] != expected_length: + raise ValueError(f"Audio {filename} is too short") + + waveform = waveform.squeeze() + + return waveform + + +if __name__ == "__main__": + synchformer = Synchformer().cuda().eval() + + # mmaudio provided synchformer ckpt + synchformer.load_state_dict( + torch.load( + os.environ.get("SYNCHFORMER_WEIGHTS", f"weights/synchformer.pth"), + weights_only=True, + map_location="cpu", + ) + ) + + sync_mel_spectrogram = torchaudio.transforms.MelSpectrogram( + sample_rate=16000, + win_length=400, + hop_length=160, + n_fft=1024, + n_mels=128, + ) diff --git a/hunyuanvideo_foley/models/synchformer/utils.py b/hunyuanvideo_foley/models/synchformer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..05595cc15b925f52ccd07fea8f131ec810f56bd7 --- /dev/null +++ b/hunyuanvideo_foley/models/synchformer/utils.py @@ -0,0 +1,87 @@ +from hashlib import md5 +from pathlib import Path +import subprocess + +import requests +from tqdm import tqdm + +PARENT_LINK = "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a" +FNAME2LINK = { + # S3: Synchability: AudioSet (run 2) + "24-01-22T20-34-52.pt": f"{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/24-01-22T20-34-52.pt", + "cfg-24-01-22T20-34-52.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/cfg-24-01-22T20-34-52.yaml", + # S2: Synchformer: AudioSet (run 2) + "24-01-04T16-39-21.pt": 
f"{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/24-01-04T16-39-21.pt", + "cfg-24-01-04T16-39-21.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml", + # S2: Synchformer: AudioSet (run 1) + "23-08-28T11-23-23.pt": f"{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/23-08-28T11-23-23.pt", + "cfg-23-08-28T11-23-23.yaml": f"{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/cfg-23-08-28T11-23-23.yaml", + # S2: Synchformer: LRS3 (run 2) + "23-12-23T18-33-57.pt": f"{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/23-12-23T18-33-57.pt", + "cfg-23-12-23T18-33-57.yaml": f"{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/cfg-23-12-23T18-33-57.yaml", + # S2: Synchformer: VGS (run 2) + "24-01-02T10-00-53.pt": f"{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/24-01-02T10-00-53.pt", + "cfg-24-01-02T10-00-53.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/cfg-24-01-02T10-00-53.yaml", + # SparseSync: ft VGGSound-Full + "22-09-21T21-00-52.pt": f"{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/22-09-21T21-00-52.pt", + "cfg-22-09-21T21-00-52.yaml": f"{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/cfg-22-09-21T21-00-52.yaml", + # SparseSync: ft VGGSound-Sparse + "22-07-28T15-49-45.pt": f"{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/22-07-28T15-49-45.pt", + "cfg-22-07-28T15-49-45.yaml": f"{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/cfg-22-07-28T15-49-45.yaml", + # SparseSync: only pt on LRS3 + "22-07-13T22-25-49.pt": f"{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/22-07-13T22-25-49.pt", + "cfg-22-07-13T22-25-49.yaml": f"{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/cfg-22-07-13T22-25-49.yaml", + # SparseSync: feature extractors + "ResNetAudio-22-08-04T09-51-04.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-04T09-51-04.pt", # 2s + "ResNetAudio-22-08-03T23-14-49.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-49.pt", # 3s + "ResNetAudio-22-08-03T23-14-28.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-28.pt", # 4s + "ResNetAudio-22-06-24T08-10-33.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T08-10-33.pt", # 5s + "ResNetAudio-22-06-24T17-31-07.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T17-31-07.pt", # 6s + "ResNetAudio-22-06-24T23-57-11.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T23-57-11.pt", # 7s + "ResNetAudio-22-06-25T04-35-42.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-25T04-35-42.pt", # 8s +} + + +def check_if_file_exists_else_download(path, fname2link=FNAME2LINK, chunk_size=1024): + """Checks if file exists, if not downloads it from the link to the path""" + path = Path(path) + if not path.exists(): + path.parent.mkdir(exist_ok=True, parents=True) + link = fname2link.get(path.name, None) + if link is None: + raise ValueError( + f"Cant find the checkpoint file: {path}.", f"Please download it manually and ensure the path exists." 
+            )
+        with requests.get(link, stream=True) as r:
+            total_size = int(r.headers.get("content-length", 0))
+            with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
+                with open(path, "wb") as f:
+                    for data in r.iter_content(chunk_size=chunk_size):
+                        if data:
+                            f.write(data)
+                            pbar.update(len(data))
+
+
+def which_ffmpeg() -> str:
+    """Determines the path to the ffmpeg binary
+    Returns:
+        str -- path to the binary
+    """
+    result = subprocess.run(["which", "ffmpeg"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    ffmpeg_path = result.stdout.decode("utf-8").replace("\n", "")
+    return ffmpeg_path
+
+
+def get_md5sum(path):
+    hash_md5 = md5()
+    with open(path, "rb") as f:
+        for chunk in iter(lambda: f.read(4096 * 8), b""):
+            hash_md5.update(chunk)
+    md5sum = hash_md5.hexdigest()
+    return md5sum
+
+
+class Config:
+    def __init__(self, **kwargs):
+        for k, v in kwargs.items():
+            setattr(self, k, v)
diff --git a/hunyuanvideo_foley/models/synchformer/video_model_builder.py b/hunyuanvideo_foley/models/synchformer/video_model_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..190df1a5f066c2c06ab41178fc1174c7956bc599 --- /dev/null +++ b/hunyuanvideo_foley/models/synchformer/video_model_builder.py @@ -0,0 +1,270 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# Copyright 2020 Ross Wightman
+# Modified Model definition
+
+from collections import OrderedDict
+from functools import partial
+
+import torch
+import torch.nn as nn
+from timm.layers import trunc_normal_
+
+from .vit_helper import PatchEmbed, PatchEmbed3D, DividedSpaceTimeBlock
+
+
+class VisionTransformer(nn.Module):
+    """Vision Transformer with support for patch or hybrid CNN input stage"""
+
+    def __init__(self, cfg):
+        super().__init__()
+        self.img_size = cfg.DATA.TRAIN_CROP_SIZE
+        self.patch_size = cfg.VIT.PATCH_SIZE
+        self.in_chans = cfg.VIT.CHANNELS
+        if cfg.TRAIN.DATASET == "Epickitchens":
+            self.num_classes = [97, 300]
+        else:
+            self.num_classes = cfg.MODEL.NUM_CLASSES
+        self.embed_dim = cfg.VIT.EMBED_DIM
+        self.depth = cfg.VIT.DEPTH
+        self.num_heads = cfg.VIT.NUM_HEADS
+        self.mlp_ratio = cfg.VIT.MLP_RATIO
+        self.qkv_bias = cfg.VIT.QKV_BIAS
+        self.drop_rate = cfg.VIT.DROP
+        self.drop_path_rate = cfg.VIT.DROP_PATH
+        self.head_dropout = cfg.VIT.HEAD_DROPOUT
+        self.video_input = cfg.VIT.VIDEO_INPUT
+        self.temporal_resolution = cfg.VIT.TEMPORAL_RESOLUTION
+        self.use_mlp = cfg.VIT.USE_MLP
+        self.num_features = self.embed_dim
+        norm_layer = partial(nn.LayerNorm, eps=1e-6)
+        self.attn_drop_rate = cfg.VIT.ATTN_DROPOUT
+        self.head_act = cfg.VIT.HEAD_ACT
+        self.cfg = cfg
+
+        # Patch Embedding
+        self.patch_embed = PatchEmbed(
+            img_size=224, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim
+        )
+
+        # 3D Patch Embedding
+        self.patch_embed_3d = PatchEmbed3D(
+            img_size=self.img_size,
+            temporal_resolution=self.temporal_resolution,
+            patch_size=self.patch_size,
+            in_chans=self.in_chans,
+            embed_dim=self.embed_dim,
+            z_block_size=self.cfg.VIT.PATCH_SIZE_TEMP,
+        )
+        self.patch_embed_3d.proj.weight.data = torch.zeros_like(self.patch_embed_3d.proj.weight.data)
+
+        # Number of patches
+        if self.video_input:
+            num_patches = self.patch_embed.num_patches * self.temporal_resolution
+        else:
+            num_patches = self.patch_embed.num_patches
+        self.num_patches = num_patches
+
+        # CLS token
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+        trunc_normal_(self.cls_token, std=0.02)
+
+        # Positional embedding
+        
self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim)) + self.pos_drop = nn.Dropout(p=cfg.VIT.POS_DROPOUT) + trunc_normal_(self.pos_embed, std=0.02) + + if self.cfg.VIT.POS_EMBED == "joint": + self.st_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim)) + trunc_normal_(self.st_embed, std=0.02) + elif self.cfg.VIT.POS_EMBED == "separate": + self.temp_embed = nn.Parameter(torch.zeros(1, self.temporal_resolution, self.embed_dim)) + + # Layer Blocks + dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)] + if self.cfg.VIT.ATTN_LAYER == "divided": + self.blocks = nn.ModuleList( + [ + DividedSpaceTimeBlock( + attn_type=cfg.VIT.ATTN_LAYER, + dim=self.embed_dim, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + drop=self.drop_rate, + attn_drop=self.attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + ) + for i in range(self.depth) + ] + ) + + self.norm = norm_layer(self.embed_dim) + + # MLP head + if self.use_mlp: + hidden_dim = self.embed_dim + if self.head_act == "tanh": + # logging.info("Using TanH activation in MLP") + act = nn.Tanh() + elif self.head_act == "gelu": + # logging.info("Using GELU activation in MLP") + act = nn.GELU() + else: + # logging.info("Using ReLU activation in MLP") + act = nn.ReLU() + self.pre_logits = nn.Sequential( + OrderedDict( + [ + ("fc", nn.Linear(self.embed_dim, hidden_dim)), + ("act", act), + ] + ) + ) + else: + self.pre_logits = nn.Identity() + + # Classifier Head + self.head_drop = nn.Dropout(p=self.head_dropout) + if isinstance(self.num_classes, (list,)) and len(self.num_classes) > 1: + for a, i in enumerate(range(len(self.num_classes))): + setattr(self, "head%d" % a, nn.Linear(self.embed_dim, self.num_classes[i])) + else: + self.head = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() + + # Initialize weights + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + if self.cfg.VIT.POS_EMBED == "joint": + return {"pos_embed", "cls_token", "st_embed"} + else: + return {"pos_embed", "cls_token", "temp_embed"} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=""): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + # if self.video_input: + # x = x[0] + B = x.shape[0] + + # Tokenize input + # if self.cfg.VIT.PATCH_SIZE_TEMP > 1: + # for simplicity of mapping between content dimensions (input x) and token dims (after patching) + # we use the same trick as for AST (see modeling_ast.ASTModel.forward for the details): + + # apply patching on input + x = self.patch_embed_3d(x) + tok_mask = None + + # else: + # tok_mask = None + # # 2D tokenization + # if self.video_input: + # x = x.permute(0, 2, 1, 3, 4) + # (B, T, C, H, W) = x.shape + # x = x.reshape(B * T, C, H, W) + + # x = self.patch_embed(x) + + # if self.video_input: + # (B2, T2, D2) = x.shape + # x = x.reshape(B, T * T2, D2) + + # Append CLS token + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + # if tok_mask is not None: + # # prepend 
1(=keep) to the mask to account for the CLS token as well
+        #     tok_mask = torch.cat((torch.ones_like(tok_mask[:, [0]]), tok_mask), dim=1)
+
+        # Interpolate positional embeddings
+        # if self.cfg.DATA.TRAIN_CROP_SIZE != 224:
+        #     pos_embed = self.pos_embed
+        #     N = pos_embed.shape[1] - 1
+        #     npatch = int((x.size(1) - 1) / self.temporal_resolution)
+        #     class_emb = pos_embed[:, 0]
+        #     pos_embed = pos_embed[:, 1:]
+        #     dim = x.shape[-1]
+        #     pos_embed = torch.nn.functional.interpolate(
+        #         pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+        #         scale_factor=math.sqrt(npatch / N),
+        #         mode='bicubic',
+        #     )
+        #     pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        #     new_pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
+        # else:
+        new_pos_embed = self.pos_embed
+        npatch = self.patch_embed.num_patches
+
+        # Add positional embeddings to input
+        if self.video_input:
+            if self.cfg.VIT.POS_EMBED == "separate":
+                cls_embed = self.pos_embed[:, 0, :].unsqueeze(1)
+                tile_pos_embed = new_pos_embed[:, 1:, :].repeat(1, self.temporal_resolution, 1)
+                tile_temporal_embed = self.temp_embed.repeat_interleave(npatch, 1)
+                total_pos_embed = tile_pos_embed + tile_temporal_embed
+                total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1)
+                x = x + total_pos_embed
+            elif self.cfg.VIT.POS_EMBED == "joint":
+                x = x + self.st_embed
+        else:
+            # image input
+            x = x + new_pos_embed
+
+        # Apply positional dropout
+        x = self.pos_drop(x)
+
+        # Encoding using transformer layers
+        for i, blk in enumerate(self.blocks):
+            x = blk(
+                x,
+                seq_len=npatch,
+                num_frames=self.temporal_resolution,
+                approx=self.cfg.VIT.APPROX_ATTN_TYPE,
+                num_landmarks=self.cfg.VIT.APPROX_ATTN_DIM,
+                tok_mask=tok_mask,
+            )
+
+        ### v-iashin: I moved it to the forward pass
+        # x = self.norm(x)[:, 0]
+        # x = self.pre_logits(x)
+        ###
+        return x, tok_mask
+
+    # def forward(self, x):
+    #     x = self.forward_features(x)
+    #     ### v-iashin: here. This should leave the same forward output as before
+    #     x = self.norm(x)[:, 0]
+    #     x = self.pre_logits(x)
+    #     ###
+    #     x = self.head_drop(x)
+    #     if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1:
+    #         output = []
+    #         for head in range(len(self.num_classes)):
+    #             x_out = getattr(self, "head%d" % head)(x)
+    #             if not self.training:
+    #                 x_out = torch.nn.functional.softmax(x_out, dim=-1)
+    #             output.append(x_out)
+    #         return output
+    #     else:
+    #         x = self.head(x)
+    #         if not self.training:
+    #             x = torch.nn.functional.softmax(x, dim=-1)
+    #         return x
diff --git a/hunyuanvideo_foley/models/synchformer/vit_helper.py b/hunyuanvideo_foley/models/synchformer/vit_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..29739530ce8692f3124b3ae748f11a4a06aa5fc8 --- /dev/null +++ b/hunyuanvideo_foley/models/synchformer/vit_helper.py @@ -0,0 +1,364 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
+# Copyright 2020 Ross Wightman
+# Modified Model definition
+"""Video models."""
+
+import math
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+from timm.layers import to_2tuple
+from torch import einsum
+from torch.nn import functional as F
+
+default_cfgs = {
+    "vit_1k": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth",
+    "vit_1k_large": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth",
+}
+
+
+def qkv_attn(q, k, v, tok_mask: torch.Tensor = None):
+    sim = einsum("b i d, b j d -> b i j", q, k)
+    # apply masking if provided, tok_mask is (B*S*H, N): 1s - keep; sim is (B*S*H, N, N)
+    if tok_mask is not None:
+        BSH, N = tok_mask.shape
+        sim = sim.masked_fill(tok_mask.view(BSH, 1, N) == 0, float("-inf"))  # 1 - broadcasts across N
+    attn = sim.softmax(dim=-1)
+    out = einsum("b i j, b j d -> b i d", attn, v)
+    return out
+
+
+class DividedAttention(nn.Module):
+
+    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim**-0.5
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(dim, dim)
+
+        # zero-init the qkv projection; fill the output projection with ones
+        self.qkv.weight.data.fill_(0)
+        self.qkv.bias.data.fill_(0)
+        self.proj.weight.data.fill_(1)
+        self.proj.bias.data.fill_(0)
+
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x, einops_from, einops_to, tok_mask: torch.Tensor = None, **einops_dims):
+        # num of heads variable
+        h = self.num_heads
+
+        # project x to q, k, v values
+        q, k, v = self.qkv(x).chunk(3, dim=-1)
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
+        if tok_mask is not None:
+            # replicate token mask across heads (b, n) -> (b, h, n) -> (b*h, n) -- same as qkv but w/o d
+            assert len(tok_mask.shape) == 2
+            tok_mask = tok_mask.unsqueeze(1).expand(-1, h, -1).reshape(-1, tok_mask.shape[1])
+
+        # Scale q
+        q *= self.scale
+
+        # Take out cls_q, cls_k, cls_v
+        (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v))
+        # the same for masking
+        if tok_mask is not None:
+            cls_mask, mask_ = tok_mask[:, 0:1], tok_mask[:, 1:]
+        else:
+            cls_mask, mask_ = None, None
+
+        # let CLS token attend to key / values of all patches across time and space
+        cls_out = qkv_attn(cls_q, k, v, tok_mask=tok_mask)
+
+        # rearrange across time or space
+        q_, k_, v_ = map(lambda t: rearrange(t, f"{einops_from} -> {einops_to}", **einops_dims), (q_, k_, v_))
+
+        # expand CLS token keys and values across time or space and concat
+        r = q_.shape[0] // cls_k.shape[0]
+        cls_k, cls_v = map(lambda t: repeat(t, "b () d -> (b r) () d", r=r), (cls_k, cls_v))
+
+        k_ = torch.cat((cls_k, k_), dim=1)
+        v_ = torch.cat((cls_v, v_), dim=1)
+
+        # the same for masking (if provided)
+        if tok_mask is not None:
+            # since mask does not have the latent dim (d), we need to remove it from einops dims
+            mask_ = rearrange(mask_, f"{einops_from} -> {einops_to}".replace(" d", ""), **einops_dims)
+            cls_mask = repeat(cls_mask, "b () -> (b r) ()", r=r)  # expand cls_mask across time or space
+            mask_ = torch.cat((cls_mask, mask_), dim=1)
+
+        # attention
+        out = qkv_attn(q_, k_, v_, tok_mask=mask_)
+
+        # merge back time or space
+        out = rearrange(out, f"{einops_to} -> {einops_from}", **einops_dims)
+
+        # concat back the cls token
+        out = torch.cat((cls_out, out), 
dim=1)
+
+        # merge back the heads
+        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
+
+        # output projection
+        x = self.proj(out)
+        x = self.proj_drop(x)
+        return x
+
+
+class DividedSpaceTimeBlock(nn.Module):
+
+    def __init__(
+        self,
+        dim=768,
+        num_heads=12,
+        attn_type="divided",
+        mlp_ratio=4.0,
+        qkv_bias=False,
+        drop=0.0,
+        attn_drop=0.0,
+        drop_path=0.0,
+        act_layer=nn.GELU,
+        norm_layer=nn.LayerNorm,
+    ):
+        super().__init__()
+
+        self.einops_from_space = "b (f n) d"
+        self.einops_to_space = "(b f) n d"
+        self.einops_from_time = "b (f n) d"
+        self.einops_to_time = "(b n) f d"
+
+        self.norm1 = norm_layer(dim)
+
+        self.attn = DividedAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+
+        self.timeattn = DividedAttention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop
+        )
+
+        # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.drop_path = nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+        self.norm3 = norm_layer(dim)
+
+    def forward(self, x, seq_len=196, num_frames=8, approx="none", num_landmarks=128, tok_mask: torch.Tensor = None):
+        time_output = self.timeattn(
+            self.norm3(x), self.einops_from_time, self.einops_to_time, n=seq_len, tok_mask=tok_mask
+        )
+        time_residual = x + time_output
+
+        space_output = self.attn(
+            self.norm1(time_residual), self.einops_from_space, self.einops_to_space, f=num_frames, tok_mask=tok_mask
+        )
+        space_residual = time_residual + self.drop_path(space_output)
+
+        x = space_residual
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class Mlp(nn.Module):
+
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """Image to Patch Embedding"""
+
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        img_size = img_size if type(img_size) is tuple else to_2tuple(img_size)
+        patch_size = patch_size if type(patch_size) is tuple else to_2tuple(patch_size)
+        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        x = self.proj(x).flatten(2).transpose(1, 2)
+        return x
+
+
+class PatchEmbed3D(nn.Module):
+    """Image to Patch Embedding"""
+
+    def __init__(
+        self,
+        img_size=224,
+        temporal_resolution=4,
+        in_chans=3,
+        patch_size=16,
+        z_block_size=2,
+        embed_dim=768,
+        flatten=True,
+    ):
+        super().__init__()
+        self.height = img_size // patch_size
+        self.width = img_size // patch_size
+        ### v-iashin: these two are incorrect
+        # self.frames = (temporal_resolution // z_block_size)
+        # self.num_patches = self.height * self.width * self.frames
+        self.z_block_size = z_block_size
+        ###
+        self.proj = nn.Conv3d(
+            in_chans,
+            embed_dim,
+            
kernel_size=(z_block_size, patch_size, patch_size),
+            stride=(z_block_size, patch_size, patch_size),
+        )
+        self.flatten = flatten
+
+    def forward(self, x):
+        B, C, T, H, W = x.shape
+        x = self.proj(x)
+        if self.flatten:
+            x = x.flatten(2).transpose(1, 2)
+        return x
+
+
+class HeadMLP(nn.Module):
+
+    def __init__(self, n_input, n_classes, n_hidden=512, p=0.1):
+        super(HeadMLP, self).__init__()
+        self.n_input = n_input
+        self.n_classes = n_classes
+        self.n_hidden = n_hidden
+        if n_hidden is None:
+            # use linear classifier
+            self.block_forward = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_input, n_classes, bias=True))
+        else:
+            # use simple MLP classifier
+            self.block_forward = nn.Sequential(
+                nn.Dropout(p=p),
+                nn.Linear(n_input, n_hidden, bias=True),
+                nn.BatchNorm1d(n_hidden),
+                nn.ReLU(inplace=True),
+                nn.Dropout(p=p),
+                nn.Linear(n_hidden, n_classes, bias=True),
+            )
+        print(f"Dropout-MLP: {p}")
+
+    def forward(self, x):
+        return self.block_forward(x)
+
+
+def _conv_filter(state_dict, patch_size=16):
+    """convert patch embedding weight from manual patchify + linear proj to conv"""
+    out_dict = {}
+    for k, v in state_dict.items():
+        if "patch_embed.proj.weight" in k:
+            v = v.reshape((v.shape[0], 3, patch_size, patch_size))
+        out_dict[k] = v
+    return out_dict
+
+
+def adapt_input_conv(in_chans, conv_weight, agg="sum"):
+    conv_type = conv_weight.dtype
+    conv_weight = conv_weight.float()
+    O, I, J, K = conv_weight.shape
+    if in_chans == 1:
+        if I > 3:
+            assert conv_weight.shape[1] % 3 == 0
+            # For models with space2depth stems
+            conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
+            conv_weight = conv_weight.sum(dim=2, keepdim=False)
+        else:
+            if agg == "sum":
+                print("Summing conv1 weights")
+                conv_weight = conv_weight.sum(dim=1, keepdim=True)
+            else:
+                print("Averaging conv1 weights")
+                conv_weight = conv_weight.mean(dim=1, keepdim=True)
+    elif in_chans != 3:
+        if I != 3:
+            raise NotImplementedError("Weight format not supported by conversion.")
+        else:
+            if agg == "sum":
+                print("Summing conv1 weights")
+                repeat = int(math.ceil(in_chans / 3))
+                conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
+                conv_weight *= 3 / float(in_chans)
+            else:
+                print("Averaging conv1 weights")
+                conv_weight = conv_weight.mean(dim=1, keepdim=True)
+                conv_weight = conv_weight.repeat(1, in_chans, 1, 1)
+    conv_weight = conv_weight.to(conv_type)
+    return conv_weight
+
+
+def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
+    # Load state dict
+    assert cfg.VIT.PRETRAINED_WEIGHTS in default_cfgs, f"{cfg.VIT.PRETRAINED_WEIGHTS} not in [vit_1k, vit_1k_large]"
+    state_dict = torch.hub.load_state_dict_from_url(url=default_cfgs[cfg.VIT.PRETRAINED_WEIGHTS])
+
+    if filter_fn is not None:
+        state_dict = filter_fn(state_dict)
+
+    input_convs = "patch_embed.proj"
+    if input_convs is not None and in_chans != 3:
+        if isinstance(input_convs, str):
+            input_convs = (input_convs,)
+        for input_conv_name in input_convs:
+            weight_name = input_conv_name + ".weight"
+            try:
+                state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name], agg="avg")
+                print(f"Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)")
+            except NotImplementedError:
+                del state_dict[weight_name]
+                strict = False
+                print(f"Unable to convert pretrained {input_conv_name} weights, using random init for this layer.")
+
+    classifier_name = "head"
+    label_offset = cfg.get("label_offset", 0)
+    pretrain_classes = 1000
+    if num_classes != pretrain_classes:
+        # completely discard 
fully connected if model num_classes doesn't match pretrained weights
+        del state_dict[classifier_name + ".weight"]
+        del state_dict[classifier_name + ".bias"]
+        strict = False
+    elif label_offset > 0:
+        # special case for pretrained weights with an extra background class in pretrained weights
+        classifier_weight = state_dict[classifier_name + ".weight"]
+        state_dict[classifier_name + ".weight"] = classifier_weight[label_offset:]
+        classifier_bias = state_dict[classifier_name + ".bias"]
+        state_dict[classifier_name + ".bias"] = classifier_bias[label_offset:]
+
+    loaded_state = state_dict
+    self_state = model.state_dict()
+    all_names = set(self_state.keys())
+    saved_names = set([])
+    for name, param in loaded_state.items():
+        if "module." in name:
+            name = name.replace("module.", "")
+        if name in self_state.keys() and param.shape == self_state[name].shape:
+            saved_names.add(name)
+            self_state[name].copy_(param)
+        else:
+            print(f"didn't load: {name} of shape: {param.shape}")
+    print("Missing Keys:")
+    print(all_names - saved_names)
diff --git a/hunyuanvideo_foley/utils/__init__.py b/hunyuanvideo_foley/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hunyuanvideo_foley/utils/__pycache__/__init__.cpython-312.pyc b/hunyuanvideo_foley/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ca77588c6a0f723bb634860c77fcb4cdcfd1ead Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/__init__.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/utils/__pycache__/__init__.cpython-313.pyc b/hunyuanvideo_foley/utils/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3eb0c8455ff9e2fbd8c0b379a70280b0e5242439 Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/__init__.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/utils/__pycache__/config_utils.cpython-312.pyc b/hunyuanvideo_foley/utils/__pycache__/config_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c5f88e71027e453b930a83e2ae0aadf23d3a057 Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/config_utils.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/utils/__pycache__/config_utils.cpython-313.pyc b/hunyuanvideo_foley/utils/__pycache__/config_utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87840261b58868ed5c287cf825b36421a42ff394 Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/config_utils.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/utils/__pycache__/feature_utils.cpython-312.pyc b/hunyuanvideo_foley/utils/__pycache__/feature_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e85af7b40960c6204595bf1fd83da5f2643f415e Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/feature_utils.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/utils/__pycache__/feature_utils.cpython-313.pyc b/hunyuanvideo_foley/utils/__pycache__/feature_utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f50455f2bb1971c21da20d1979c966b8060d2d5 Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/feature_utils.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/utils/__pycache__/helper.cpython-312.pyc b/hunyuanvideo_foley/utils/__pycache__/helper.cpython-312.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..03bd596162c662e3b5e539562110e67f32ab1cfb Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/helper.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/utils/__pycache__/helper.cpython-313.pyc b/hunyuanvideo_foley/utils/__pycache__/helper.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7985f406d537b40f2a34c1318d08997eae4b060 Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/helper.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/utils/__pycache__/media_utils.cpython-312.pyc b/hunyuanvideo_foley/utils/__pycache__/media_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92561d8239b90bcb3fdbd5e3bc7df74352d94bc6 Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/media_utils.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/utils/__pycache__/media_utils.cpython-313.pyc b/hunyuanvideo_foley/utils/__pycache__/media_utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5d902564adfd96af1ac11ea48f17cbb184422b4 Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/media_utils.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/utils/__pycache__/model_utils.cpython-312.pyc b/hunyuanvideo_foley/utils/__pycache__/model_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b04d192c9e90ea387d0404756e7a8bf131e24ef Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/model_utils.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/utils/__pycache__/model_utils.cpython-313.pyc b/hunyuanvideo_foley/utils/__pycache__/model_utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd5172c98502b6572776c6f4dbd4f6c90ab071dd Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/model_utils.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/utils/config_utils.py b/hunyuanvideo_foley/utils/config_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f1bdf9108dfebf68c0a27cfdd675c8cc5e4c1d88 --- /dev/null +++ b/hunyuanvideo_foley/utils/config_utils.py @@ -0,0 +1,109 @@ +"""Configuration utilities for the HunyuanVideo-Foley project.""" + +import yaml +from pathlib import Path +from typing import Any, Dict, List, Union + +class AttributeDict: + + def __init__(self, data: Union[Dict, List, Any]): + if isinstance(data, dict): + for key, value in data.items(): + if isinstance(value, (dict, list)): + value = AttributeDict(value) + setattr(self, self._sanitize_key(key), value) + elif isinstance(data, list): + self._list = [AttributeDict(item) if isinstance(item, (dict, list)) else item + for item in data] + else: + self._value = data + + def _sanitize_key(self, key: str) -> str: + import re + sanitized = re.sub(r'[^a-zA-Z0-9_]', '_', str(key)) + if sanitized[0].isdigit(): + sanitized = f'_{sanitized}' + return sanitized + + def __getitem__(self, key): + if hasattr(self, '_list'): + return self._list[key] + return getattr(self, self._sanitize_key(key)) + + def __setitem__(self, key, value): + if hasattr(self, '_list'): + self._list[key] = value + else: + setattr(self, self._sanitize_key(key), value) + + def __iter__(self): + if hasattr(self, '_list'): + return iter(self._list) + return iter(self.__dict__.keys()) + + def __len__(self): + if hasattr(self, '_list'): + return len(self._list) + return len(self.__dict__) + + def get(self, key, default=None): + try: + return 
self[key] + except (KeyError, AttributeError, IndexError): + return default + + def keys(self): + if hasattr(self, '_list'): + return range(len(self._list)) + elif hasattr(self, '_value'): + return [] + else: + return [key for key in self.__dict__.keys() if not key.startswith('_')] + + def values(self): + if hasattr(self, '_list'): + return self._list + elif hasattr(self, '_value'): + return [self._value] + else: + return [value for key, value in self.__dict__.items() if not key.startswith('_')] + + def items(self): + if hasattr(self, '_list'): + return enumerate(self._list) + elif hasattr(self, '_value'): + return [] + else: + return [(key, value) for key, value in self.__dict__.items() if not key.startswith('_')] + + def __repr__(self): + if hasattr(self, '_list'): + return f"AttributeDict({self._list})" + elif hasattr(self, '_value'): + return f"AttributeDict({self._value})" + return f"AttributeDict({dict(self.__dict__)})" + + def to_dict(self) -> Union[Dict, List, Any]: + if hasattr(self, '_list'): + return [item.to_dict() if isinstance(item, AttributeDict) else item + for item in self._list] + elif hasattr(self, '_value'): + return self._value + else: + result = {} + for key, value in self.__dict__.items(): + if isinstance(value, AttributeDict): + result[key] = value.to_dict() + else: + result[key] = value + return result + +def load_yaml(file_path: str, encoding: str = 'utf-8') -> AttributeDict: + try: + with open(file_path, 'r', encoding=encoding) as file: + data = yaml.safe_load(file) + return AttributeDict(data) + except FileNotFoundError: + raise FileNotFoundError(f"YAML file not found: {file_path}") + except yaml.YAMLError as e: + raise yaml.YAMLError(f"YAML format error: {e}") diff --git a/hunyuanvideo_foley/utils/feature_utils.py b/hunyuanvideo_foley/utils/feature_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..39f474479fdac18000e89d4db0b5211e35dffe23 --- /dev/null +++ b/hunyuanvideo_foley/utils/feature_utils.py @@ -0,0 +1,156 @@ +"""Feature extraction utilities for video and text processing.""" + +import os +import numpy as np +import torch +import av +from PIL import Image +from einops import rearrange +from typing import Any, Dict, List, Union, Tuple +from loguru import logger + +from .config_utils import AttributeDict +from ..constants import FPS_VISUAL, MAX_VIDEO_DURATION_SECONDS + + +class FeatureExtractionError(Exception): + """Exception raised for feature extraction errors.""" + pass + +def get_frames_av( + video_path: str, + fps: float, + max_length: float = None, +) -> Tuple[np.ndarray, float]: + end_sec = max_length if max_length is not None else 15 + next_frame_time_for_each_fps = 0.0 + time_delta_for_each_fps = 1 / fps + + all_frames = [] + output_frames = [] + + with av.open(video_path) as container: + stream = container.streams.video[0] + ori_fps = stream.guessed_rate + stream.thread_type = "AUTO" + for packet in container.demux(stream): + for frame in packet.decode(): + frame_time = frame.time + if frame_time < 0: + continue + if frame_time > end_sec: + break + + frame_np = None + + this_time = frame_time + while this_time >= next_frame_time_for_each_fps: + if frame_np is None: + frame_np = frame.to_ndarray(format="rgb24") + + output_frames.append(frame_np) + next_frame_time_for_each_fps += time_delta_for_each_fps + + output_frames = np.stack(output_frames) + + vid_len_in_s = len(output_frames) / fps + if max_length is not None and len(output_frames) > int(max_length * fps): + output_frames = output_frames[: int(max_length * fps)] 
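+        # e.g. (assumed values) fps=8, max_length=15.0 -> at most int(15.0 * 8) = 120
+        # resampled frames are kept, and the reported duration is clamped to max_length.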
+        vid_len_in_s = max_length
+
+    return output_frames, vid_len_in_s
+
+@torch.inference_mode()
+def encode_video_with_siglip2(x: torch.Tensor, model_dict, batch_size: int = -1):
+    b, t, c, h, w = x.shape
+    if batch_size < 0:
+        batch_size = b * t
+    x = rearrange(x, "b t c h w -> (b t) c h w")
+    outputs = []
+    for i in range(0, b * t, batch_size):
+        outputs.append(model_dict.siglip2_model.get_image_features(pixel_values=x[i : i + batch_size]))
+    res = torch.cat(outputs, dim=0)
+    res = rearrange(res, "(b t) d -> b t d", b=b)
+    return res
+
+@torch.inference_mode()
+def encode_video_with_sync(x: torch.Tensor, model_dict, batch_size: int = -1):
+    """
+    The input video x should be sampled at 24 fps or higher.
+    Input:
+        x: tensor of shape [B, T, C, H, W]
+        batch_size: the batch size for Synchformer inference
+    """
+    b, t, c, h, w = x.shape
+    assert c == 3 and h == 224 and w == 224
+
+    segment_size = 16
+    step_size = 8
+    num_segments = (t - segment_size) // step_size + 1
+    segments = []
+    for i in range(num_segments):
+        segments.append(x[:, i * step_size : i * step_size + segment_size])
+    x = torch.stack(segments, dim=1).cuda()  # (B, num_segments, segment_size, 3, 224, 224)
+
+    outputs = []
+    if batch_size < 0:
+        batch_size = b * num_segments
+    x = rearrange(x, "b s t c h w -> (b s) 1 t c h w")
+    for i in range(0, b * num_segments, batch_size):
+        with torch.autocast(device_type="cuda", enabled=True, dtype=torch.half):
+            outputs.append(model_dict.syncformer_model(x[i : i + batch_size]))
+    x = torch.cat(outputs, dim=0)  # [b * num_segments, 1, 8, 768]
+    x = rearrange(x, "(b s) 1 t d -> b (s t) d", b=b)
+    return x
+
+
+@torch.inference_mode()
+def encode_video_features(video_path, model_dict):
+    visual_features = {}
+    # siglip2 visual features
+    frames, ori_vid_len_in_s = get_frames_av(video_path, FPS_VISUAL["siglip2"])
+    images = [Image.fromarray(frame).convert('RGB') for frame in frames]
+    images = [model_dict.siglip2_preprocess(image) for image in images]  # [T, C, H, W]
+    clip_frames = torch.stack(images).to(model_dict.device).unsqueeze(0)
+    visual_features['siglip2_feat'] = encode_video_with_siglip2(clip_frames, model_dict).to(model_dict.device)
+
+    # synchformer visual features
+    frames, ori_vid_len_in_s = get_frames_av(video_path, FPS_VISUAL["synchformer"])
+    images = torch.from_numpy(frames).permute(0, 3, 1, 2)  # [T, C, H, W]
+    sync_frames = model_dict.syncformer_preprocess(images).unsqueeze(0)  # [1, T, 3, 224, 224]
+    # [1, num_segments * 8, channel_dim], e.g. 
[1, 240, 768] for 10s video + visual_features['syncformer_feat'] = encode_video_with_sync(sync_frames, model_dict) + + vid_len_in_s = sync_frames.shape[1] / FPS_VISUAL["synchformer"] + visual_features = AttributeDict(visual_features) + + return visual_features, vid_len_in_s + +@torch.inference_mode() +def encode_text_feat(text: List[str], model_dict): + # x: (B, L) + inputs = model_dict.clap_tokenizer(text, padding=True, return_tensors="pt").to(model_dict.device) + outputs = model_dict.clap_model(**inputs, output_hidden_states=True, return_dict=True) + return outputs.last_hidden_state, outputs.attentions + + +def feature_process(video_path, prompt, model_dict, cfg): + visual_feats, audio_len_in_s = encode_video_features(video_path, model_dict) + neg_prompt = "noisy, harsh" + prompts = [neg_prompt, prompt] + text_feat_res, text_feat_mask = encode_text_feat(prompts, model_dict) + + text_feat = text_feat_res[1:] + uncond_text_feat = text_feat_res[:1] + + if cfg.model_config.model_kwargs.text_length < text_feat.shape[1]: + text_seq_length = cfg.model_config.model_kwargs.text_length + text_feat = text_feat[:, :text_seq_length] + uncond_text_feat = uncond_text_feat[:, :text_seq_length] + + text_feats = AttributeDict({ + 'text_feat': text_feat, + 'uncond_text_feat': uncond_text_feat, + }) + + return visual_feats, text_feats, audio_len_in_s diff --git a/hunyuanvideo_foley/utils/helper.py b/hunyuanvideo_foley/utils/helper.py new file mode 100644 index 0000000000000000000000000000000000000000..04840dfcdf847e3f61ec327ebc691b0dbd139da4 --- /dev/null +++ b/hunyuanvideo_foley/utils/helper.py @@ -0,0 +1,134 @@ +import collections.abc +from itertools import repeat +import importlib +import yaml +import time + +def default(value, default_val): + return default_val if value is None else value + + +def default_dtype(value, default_val): + if value is not None: + assert isinstance(value, type(default_val)), f"Expect {type(default_val)}, got {type(value)}." + return value + return default_val + + +def repeat_interleave(lst, num_repeats): + return [item for item in lst for _ in range(num_repeats)] + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + x = tuple(x) + if len(x) == 1: + x = tuple(repeat(x[0], n)) + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) + + +def as_tuple(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + if x is None or isinstance(x, (int, float, str)): + return (x,) + else: + raise ValueError(f"Unknown type {type(x)}") + + +def as_list_of_2tuple(x): + x = as_tuple(x) + if len(x) == 1: + x = (x[0], x[0]) + assert len(x) % 2 == 0, f"Expect even length, got {len(x)}." 
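+    # e.g. as_list_of_2tuple((1, 2, 3, 4)) -> [(1, 2), (3, 4)]; a scalar such as
+    # as_list_of_2tuple(5) first becomes (5, 5) and yields [(5, 5)]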
+ lst = [] + for i in range(0, len(x), 2): + lst.append((x[i], x[i + 1])) + return lst + + +def find_multiple(n: int, k: int) -> int: + assert k > 0 + if n % k == 0: + return n + return n - (n % k) + k + + +def merge_dicts(dict1, dict2): + for key, value in dict2.items(): + if key in dict1 and isinstance(dict1[key], dict) and isinstance(value, dict): + merge_dicts(dict1[key], value) + else: + dict1[key] = value + return dict1 + + +def merge_yaml_files(file_list): + merged_config = {} + + for file in file_list: + with open(file, "r", encoding="utf-8") as f: + config = yaml.safe_load(f) + if config: + # Remove the first level + for key, value in config.items(): + if isinstance(value, dict): + merged_config = merge_dicts(merged_config, value) + else: + merged_config[key] = value + + return merged_config + + +def merge_dict(file_list): + merged_config = {} + + for file in file_list: + with open(file, "r", encoding="utf-8") as f: + config = yaml.safe_load(f) + if config: + merged_config = merge_dicts(merged_config, config) + + return merged_config + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def readable_time(seconds): + """ Convert time seconds to a readable format: DD Days, HH Hours, MM Minutes, SS Seconds """ + seconds = int(seconds) + days, seconds = divmod(seconds, 86400) + hours, seconds = divmod(seconds, 3600) + minutes, seconds = divmod(seconds, 60) + if days > 0: + return f"{days} Days, {hours} Hours, {minutes} Minutes, {seconds} Seconds" + if hours > 0: + return f"{hours} Hours, {minutes} Minutes, {seconds} Seconds" + if minutes > 0: + return f"{minutes} Minutes, {seconds} Seconds" + return f"{seconds} Seconds" + + +def get_obj_from_cfg(cfg, reload=False): + if isinstance(cfg, str): + return get_obj_from_str(cfg, reload) + elif isinstance(cfg, (list, tuple,)): + return tuple([get_obj_from_str(c, reload) for c in cfg]) + else: + raise NotImplementedError(f"Not implemented for {type(cfg)}.") diff --git a/hunyuanvideo_foley/utils/media_utils.py b/hunyuanvideo_foley/utils/media_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d80e6cbfe5164caeb7340fa2fbbcb8a11ddd9c13 --- /dev/null +++ b/hunyuanvideo_foley/utils/media_utils.py @@ -0,0 +1,101 @@ +"""Media utilities for audio/video processing.""" + +import os +import subprocess +from pathlib import Path +from typing import Optional + +from loguru import logger + + +class MediaProcessingError(Exception): + """Exception raised for media processing errors.""" + pass + + +def merge_audio_video( + audio_path: str, + video_path: str, + output_path: str, + overwrite: bool = True, + quality: str = "high" +) -> str: + """ + Merge audio and video files using ffmpeg. 
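+
+    Example (illustrative paths):
+        merge_audio_video("foley.wav", "video.mp4", "outputs/video_with_audio.mp4", quality="medium")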
+ + Args: + audio_path: Path to input audio file + video_path: Path to input video file + output_path: Path for output video file + overwrite: Whether to overwrite existing output file + quality: Quality setting ('high', 'medium', 'low') + + Returns: + Path to the output file + + Raises: + MediaProcessingError: If input files don't exist or ffmpeg fails + FileNotFoundError: If ffmpeg is not installed + """ + # Validate input files + if not os.path.exists(audio_path): + raise MediaProcessingError(f"Audio file not found: {audio_path}") + if not os.path.exists(video_path): + raise MediaProcessingError(f"Video file not found: {video_path}") + + # Create output directory if needed + output_dir = Path(output_path).parent + output_dir.mkdir(parents=True, exist_ok=True) + + # Quality settings + quality_settings = { + "high": ["-b:a", "192k"], + "medium": ["-b:a", "128k"], + "low": ["-b:a", "96k"] + } + + # Build ffmpeg command + ffmpeg_command = [ + "ffmpeg", + "-i", video_path, + "-i", audio_path, + "-c:v", "copy", + "-c:a", "aac", + "-ac", "2", + "-af", "pan=stereo|c0=c0|c1=c0", + "-map", "0:v:0", + "-map", "1:a:0", + *quality_settings.get(quality, quality_settings["high"]), + ] + + if overwrite: + ffmpeg_command.append("-y") + + ffmpeg_command.append(output_path) + + try: + logger.info(f"Merging audio '{audio_path}' with video '{video_path}'") + process = subprocess.Popen( + ffmpeg_command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + stdout, stderr = process.communicate() + + if process.returncode != 0: + error_msg = f"FFmpeg failed with return code {process.returncode}: {stderr}" + logger.error(error_msg) + raise MediaProcessingError(error_msg) + else: + logger.info(f"Successfully merged video saved to: {output_path}") + + except FileNotFoundError: + raise FileNotFoundError( + "ffmpeg not found. 
Please install ffmpeg: " + "https://ffmpeg.org/download.html" + ) + except Exception as e: + raise MediaProcessingError(f"Unexpected error during media processing: {e}") + + return output_path diff --git a/hunyuanvideo_foley/utils/model_utils.py b/hunyuanvideo_foley/utils/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..354ea223c7ef8ae4f8759f6763028d27e6ad13bc --- /dev/null +++ b/hunyuanvideo_foley/utils/model_utils.py @@ -0,0 +1,241 @@ +import torch +import os +from loguru import logger +from torchvision import transforms +from torchvision.transforms import v2 +from diffusers.utils.torch_utils import randn_tensor +from transformers import AutoTokenizer, AutoModel, ClapTextModelWithProjection +from ..models.dac_vae.model.dac import DAC +from ..models.synchformer import Synchformer +from ..models.hifi_foley import HunyuanVideoFoley +from .config_utils import load_yaml, AttributeDict +from .schedulers import FlowMatchDiscreteScheduler +from tqdm import tqdm + +def load_state_dict(model, model_path): + logger.info(f"Loading model state dict from: {model_path}") + state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False) + + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + + if missing_keys: + logger.warning(f"Missing keys in state dict ({len(missing_keys)} keys):") + for key in missing_keys: + logger.warning(f" - {key}") + else: + logger.info("No missing keys found") + + if unexpected_keys: + logger.warning(f"Unexpected keys in state dict ({len(unexpected_keys)} keys):") + for key in unexpected_keys: + logger.warning(f" - {key}") + else: + logger.info("No unexpected keys found") + + logger.info("Model state dict loaded successfully") + return model + +def load_model(model_path, config_path, device): + logger.info("Starting model loading process...") + logger.info(f"Configuration file: {config_path}") + logger.info(f"Model weights dir: {model_path}") + logger.info(f"Target device: {device}") + + cfg = load_yaml(config_path) + logger.info("Configuration loaded successfully") + + # HunyuanVideoFoley + logger.info("Loading HunyuanVideoFoley main model...") + foley_model = HunyuanVideoFoley(cfg, dtype=torch.bfloat16, device=device).to(device=device, dtype=torch.bfloat16) + foley_model = load_state_dict(foley_model, os.path.join(model_path, "hunyuanvideo_foley.pth")) + foley_model.eval() + logger.info("HunyuanVideoFoley model loaded and set to evaluation mode") + + # DAC-VAE + dac_path = os.path.join(model_path, "vae_128d_48k.pth") + logger.info(f"Loading DAC VAE model from: {dac_path}") + dac_model = DAC.load(dac_path) + dac_model = dac_model.to(device) + dac_model.requires_grad_(False) + dac_model.eval() + logger.info("DAC VAE model loaded successfully") + + # Siglip2 visual-encoder + logger.info("Loading SigLIP2 visual encoder...") + siglip2_preprocess = transforms.Compose([ + transforms.Resize((512, 512)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + siglip2_model = AutoModel.from_pretrained("google/siglip2-base-patch16-512").to(device).eval() + logger.info("SigLIP2 model and preprocessing pipeline loaded successfully") + + # clap text-encoder + logger.info("Loading CLAP text encoder...") + clap_tokenizer = AutoTokenizer.from_pretrained("laion/larger_clap_general") + clap_model = ClapTextModelWithProjection.from_pretrained("laion/larger_clap_general").to(device) + logger.info("CLAP tokenizer and model loaded successfully") + + # 
syncformer + syncformer_path = os.path.join(model_path, "synchformer_state_dict.pth") + logger.info(f"Loading Synchformer model from: {syncformer_path}") + syncformer_preprocess = v2.Compose( + [ + v2.Resize(224, interpolation=v2.InterpolationMode.BICUBIC), + v2.CenterCrop(224), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ] + ) + + syncformer_model = Synchformer() + syncformer_model.load_state_dict(torch.load(syncformer_path, weights_only=False, map_location="cpu")) + syncformer_model = syncformer_model.to(device).eval() + logger.info("Synchformer model and preprocessing pipeline loaded successfully") + + + logger.info("Creating model dictionary with attribute access...") + model_dict = AttributeDict({ + 'foley_model': foley_model, + 'dac_model': dac_model, + 'siglip2_preprocess': siglip2_preprocess, + 'siglip2_model': siglip2_model, + 'clap_tokenizer': clap_tokenizer, + 'clap_model': clap_model, + 'syncformer_preprocess': syncformer_preprocess, + 'syncformer_model': syncformer_model, + 'device': device, + }) + + logger.info("All models loaded successfully!") + logger.info("Available model components:") + for key in model_dict.keys(): + logger.info(f" - {key}") + logger.info("Models can be accessed via attribute notation (e.g., models.foley_model)") + + return model_dict, cfg + +def retrieve_timesteps( + scheduler, + num_inference_steps, + device, + **kwargs, +): + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def prepare_latents(scheduler, batch_size, num_channels_latents, length, dtype, device): + shape = (batch_size, num_channels_latents, int(length)) + latents = randn_tensor(shape, device=device, dtype=dtype) + + # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler + if hasattr(scheduler, "init_noise_sigma"): + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * scheduler.init_noise_sigma + + return latents + + +@torch.no_grad() +def denoise_process(visual_feats, text_feats, audio_len_in_s, model_dict, cfg, guidance_scale=4.5, num_inference_steps=50, batch_size=1): + + target_dtype = model_dict.foley_model.dtype + autocast_enabled = target_dtype != torch.float32 + device = model_dict.device + + scheduler = FlowMatchDiscreteScheduler( + shift=cfg.diffusion_config.sample_flow_shift, + reverse=cfg.diffusion_config.flow_reverse, + solver=cfg.diffusion_config.flow_solver, + use_flux_shift=cfg.diffusion_config.sample_use_flux_shift, + flux_base_shift=cfg.diffusion_config.flux_base_shift, + flux_max_shift=cfg.diffusion_config.flux_max_shift, + ) + + timesteps, num_inference_steps = retrieve_timesteps( + scheduler, + num_inference_steps, + device, + ) + + latents = prepare_latents( + scheduler, + batch_size=batch_size, + num_channels_latents=cfg.model_config.model_kwargs.audio_vae_latent_dim, + length=audio_len_in_s * cfg.model_config.model_kwargs.audio_frame_rate, + dtype=target_dtype, + device=device, + ) + + # Denoise loop + for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Denoising steps"): + # noise latents + latent_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents + latent_input = scheduler.scale_model_input(latent_input, t) + + t_expand = t.repeat(latent_input.shape[0]) + + # siglip2 features + siglip2_feat = visual_feats.siglip2_feat.repeat(batch_size, 1, 1) # Repeat for batch_size + uncond_siglip2_feat 
= model_dict.foley_model.get_empty_clip_sequence( + bs=batch_size, len=siglip2_feat.shape[1] + ).to(device) + + if guidance_scale is not None and guidance_scale > 1.0: + siglip2_feat_input = torch.cat([uncond_siglip2_feat, siglip2_feat], dim=0) + else: + siglip2_feat_input = siglip2_feat + + # syncformer features + syncformer_feat = visual_feats.syncformer_feat.repeat(batch_size, 1, 1) # Repeat for batch_size + uncond_syncformer_feat = model_dict.foley_model.get_empty_sync_sequence( + bs=batch_size, len=syncformer_feat.shape[1] + ).to(device) + if guidance_scale is not None and guidance_scale > 1.0: + syncformer_feat_input = torch.cat([uncond_syncformer_feat, syncformer_feat], dim=0) + else: + syncformer_feat_input = syncformer_feat + + # text features + text_feat_repeated = text_feats.text_feat.repeat(batch_size, 1, 1) # Repeat for batch_size + uncond_text_feat_repeated = text_feats.uncond_text_feat.repeat(batch_size, 1, 1) # Repeat for batch_size + if guidance_scale is not None and guidance_scale > 1.0: + text_feat_input = torch.cat([uncond_text_feat_repeated, text_feat_repeated], dim=0) + else: + text_feat_input = text_feat_repeated + + with torch.autocast(device_type=device.type, enabled=autocast_enabled, dtype=target_dtype): + # Predict the noise residual + noise_pred = model_dict.foley_model( + x=latent_input, + t=t_expand, + cond=text_feat_input, + clip_feat=siglip2_feat_input, + sync_feat=syncformer_feat_input, + return_dict=True, + )["x"] + + noise_pred = noise_pred.to(dtype=torch.float32) + + if guidance_scale is not None and guidance_scale > 1.0: + # Perform classifier-free guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # Compute the previous noisy sample x_t -> x_t-1 + latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # Post-process the latents to audio + + with torch.no_grad(): + audio = model_dict.dac_model.decode(latents) + audio = audio.float().cpu() + + audio = audio[:, :int(audio_len_in_s*model_dict.dac_model.sample_rate)] + + return audio, model_dict.dac_model.sample_rate + + diff --git a/hunyuanvideo_foley/utils/schedulers/__init__.py b/hunyuanvideo_foley/utils/schedulers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3fea4433560c92fc9d8569993447d0bdb456dc9e --- /dev/null +++ b/hunyuanvideo_foley/utils/schedulers/__init__.py @@ -0,0 +1,2 @@ +from diffusers.schedulers import DDPMScheduler, EulerDiscreteScheduler +from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler \ No newline at end of file diff --git a/hunyuanvideo_foley/utils/schedulers/__pycache__/__init__.cpython-312.pyc b/hunyuanvideo_foley/utils/schedulers/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf9cbc935de5371a4cb668e5303dd33e845f8368 Binary files /dev/null and b/hunyuanvideo_foley/utils/schedulers/__pycache__/__init__.cpython-312.pyc differ diff --git a/hunyuanvideo_foley/utils/schedulers/__pycache__/__init__.cpython-313.pyc b/hunyuanvideo_foley/utils/schedulers/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..381e78dd35cabf6f12fbf6da052bcf4562808b45 Binary files /dev/null and b/hunyuanvideo_foley/utils/schedulers/__pycache__/__init__.cpython-313.pyc differ diff --git a/hunyuanvideo_foley/utils/schedulers/__pycache__/scheduling_flow_match_discrete.cpython-312.pyc 
b/hunyuanvideo_foley/utils/schedulers/__pycache__/scheduling_flow_match_discrete.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96c4539046ffa747401e185d9fda122792d003a6 Binary files /dev/null and b/hunyuanvideo_foley/utils/schedulers/__pycache__/scheduling_flow_match_discrete.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/utils/schedulers/__pycache__/scheduling_flow_match_discrete.cpython-313.pyc b/hunyuanvideo_foley/utils/schedulers/__pycache__/scheduling_flow_match_discrete.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a2af07df956649977e8a493af347471de260eab Binary files /dev/null and b/hunyuanvideo_foley/utils/schedulers/__pycache__/scheduling_flow_match_discrete.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/utils/schedulers/scheduling_flow_match_discrete.py b/hunyuanvideo_foley/utils/schedulers/scheduling_flow_match_discrete.py new file mode 100644 index 0000000000000000000000000000000000000000..2f58814479643b42eb18158db2fcc2a29544424e --- /dev/null +++ b/hunyuanvideo_foley/utils/schedulers/scheduling_flow_match_discrete.py @@ -0,0 +1,376 @@
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput, logging
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+@dataclass
+class FlowMatchDiscreteSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+
+    Args:
+        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+    """
+
+    prev_sample: torch.FloatTensor
+
+
+class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Euler scheduler.
+
+    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+    methods the library implements for all schedulers such as loading and saving.
+
+    Args:
+        num_train_timesteps (`int`, defaults to 1000):
+            The number of diffusion steps to train the model.
+        shift (`float`, defaults to 1.0):
+            The shift value for the timestep schedule.
+        reverse (`bool`, defaults to `True`):
+            Whether to reverse the timestep schedule. 
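+        solver (`str`, defaults to `"euler"`):
+            The sampling solver; one of `"euler"`, `"heun-2"`, `"midpoint-2"` or `"kutta-4"`
+            (see `supported_solver`).
+        use_flux_shift (`bool`, defaults to `False`):
+            Whether to apply the Flux-style, token-count-dependent timestep shift (bounded by
+            `flux_base_shift` and `flux_max_shift`) instead of the SD3-style `shift`.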
+ """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + reverse: bool = True, + solver: str = "euler", + use_flux_shift: bool = False, + flux_base_shift: float = 0.5, + flux_max_shift: float = 1.15, + n_tokens: Optional[int] = None, + ): + sigmas = torch.linspace(1, 0, num_train_timesteps + 1) + + if not reverse: + sigmas = sigmas.flip(0) + + self.sigmas = sigmas + # the value fed to model + self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32) + self.timesteps_full = (sigmas * num_train_timesteps).to(dtype=torch.float32) + + self._step_index = None + self._begin_index = None + + self.supported_solver = [ + "euler", + "heun-2", "midpoint-2", + "kutta-4", + ] + if solver not in self.supported_solver: + raise ValueError(f"Solver {solver} not supported. Supported solvers: {self.supported_solver}") + + # empty dt and derivative (for heun) + self.derivative_1 = None + self.derivative_2 = None + self.derivative_3 = None + self.dt = None + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + @property + def state_in_first_order(self): + return self.derivative_1 is None + + @property + def state_in_second_order(self): + return self.derivative_2 is None + + @property + def state_in_third_order(self): + return self.derivative_3 is None + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, + n_tokens: int = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + n_tokens (`int`, *optional*): + Number of tokens in the input sequence. 
+ """ + self.num_inference_steps = num_inference_steps + + sigmas = torch.linspace(1, 0, num_inference_steps + 1) + + # Apply timestep shift + if self.config.use_flux_shift: + assert isinstance(n_tokens, int), "n_tokens should be provided for flux shift" + mu = self.get_lin_function(y1=self.config.flux_base_shift, y2=self.config.flux_max_shift)(n_tokens) + sigmas = self.flux_time_shift(mu, 1.0, sigmas) + elif self.config.shift != 1.: + sigmas = self.sd3_time_shift(sigmas) + + if not self.config.reverse: + sigmas = 1 - sigmas + + self.sigmas = sigmas + self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(dtype=torch.float32, device=device) + self.timesteps_full = (sigmas * self.config.num_train_timesteps).to(dtype=torch.float32, device=device) + + # empty dt and derivative (for kutta) + self.derivative_1 = None + self.derivative_2 = None + self.derivative_3 = None + self.dt = None + + # Reset step index + self._step_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + return sample + + @staticmethod + def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15): + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + @staticmethod + def flux_time_shift(mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def sd3_time_shift(self, t: torch.Tensor): + return (self.config.shift * t) / (1 + (self.config.shift - 1) * t) + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + pred_uncond: torch.FloatTensor = None, + generator: Optional[torch.Generator] = None, + n_tokens: Optional[int] = None, + return_dict: bool = True, + ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + n_tokens (`int`, *optional*): + Number of tokens in the input sequence. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. 
+
+    def step(
+        self,
+        model_output: torch.FloatTensor,
+        timestep: Union[float, torch.FloatTensor],
+        sample: torch.FloatTensor,
+        pred_uncond: torch.FloatTensor = None,
+        generator: Optional[torch.Generator] = None,
+        n_tokens: Optional[int] = None,
+        return_dict: bool = True,
+    ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
+        """
+        Predict the sample at the previous timestep by integrating the flow-matching ODE with the configured
+        solver, propagating the diffusion process from the learned model output.
+
+        Args:
+            model_output (`torch.FloatTensor`):
+                The direct output from the learned diffusion model.
+            timestep (`float`):
+                The current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor`):
+                A current instance of a sample created by the diffusion process.
+            pred_uncond (`torch.FloatTensor`, *optional*):
+                The unconditional model output; currently only upcast alongside `model_output`.
+            generator (`torch.Generator`, *optional*):
+                A random number generator.
+            n_tokens (`int`, *optional*):
+                Number of tokens in the input sequence.
+            return_dict (`bool`):
+                Whether or not to return a [`FlowMatchDiscreteSchedulerOutput`] or a tuple.
+
+        Returns:
+            [`FlowMatchDiscreteSchedulerOutput`] or `tuple`:
+                If `return_dict` is `True`, a [`FlowMatchDiscreteSchedulerOutput`] is returned; otherwise a
+                tuple is returned where the first element is the sample tensor.
+        """
+        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
+            raise ValueError(
+                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+                " `FlowMatchDiscreteScheduler.step()` is not supported. Make sure to pass"
+                " one of the `scheduler.timesteps` as a timestep."
+            )
+
+        if self.step_index is None:
+            self._init_step_index(timestep)
+
+        # Upcast to avoid precision issues when computing prev_sample
+        sample = sample.to(torch.float32)
+        model_output = model_output.to(torch.float32)
+        pred_uncond = pred_uncond.to(torch.float32) if pred_uncond is not None else None
+
+        sigma = self.sigmas[self.step_index]
+        sigma_next = self.sigmas[self.step_index + 1]
+
+        if self.config.solver == "euler":
+            derivative, dt, sample, last_inner_step = self.first_order_method(model_output, sigma, sigma_next, sample)
+        elif self.config.solver in ["heun-2", "midpoint-2"]:
+            derivative, dt, sample, last_inner_step = self.second_order_method(model_output, sigma, sigma_next, sample)
+        elif self.config.solver == "kutta-4":
+            derivative, dt, sample, last_inner_step = self.fourth_order_method(model_output, sigma, sigma_next, sample)
+        else:
+            raise ValueError(f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}")
+
+        prev_sample = sample + derivative * dt
+
+        # Only advance the step index once the last inner solver stage has run
+        if last_inner_step:
+            self._step_index += 1
+
+        if not return_dict:
+            return (prev_sample,)
+
+        return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
+
+    def first_order_method(self, model_output, sigma, sigma_next, sample):
+        derivative = model_output.float()
+        dt = sigma_next - sigma
+        return derivative, dt, sample, True
+
+    def second_order_method(self, model_output, sigma, sigma_next, sample):
+        if self.state_in_first_order:
+            # First stage: cache the derivative, step size, and base sample
+            self.derivative_1 = model_output
+            self.dt = sigma_next - sigma
+            self.sample = sample
+
+            derivative = model_output
+            if self.config.solver == "heun-2":
+                dt = self.dt
+            elif self.config.solver == "midpoint-2":
+                dt = self.dt / 2
+            else:
+                raise NotImplementedError(f"Solver {self.config.solver} not supported.")
+            last_inner_step = False
+
+        else:
+            # Second stage: combine the derivatives and step from the cached sample
+            if self.config.solver == "heun-2":
+                derivative = 0.5 * (self.derivative_1 + model_output)
+            elif self.config.solver == "midpoint-2":
+                derivative = model_output
+            else:
+                raise NotImplementedError(f"Solver {self.config.solver} not supported.")
+
+            # take the previous timestep & sample
+            dt = self.dt
+            sample = self.sample
+            last_inner_step = True
+
+            # free dt and derivative
+            # Note: this puts the scheduler back in "first order mode"
+            self.derivative_1 = None
+            self.dt = None
+            self.sample = None
+
+        return derivative, dt, sample, last_inner_step
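+
+    # Four-stage Runge-Kutta: the first three stages advance by dt/2, dt/2,
+    # and dt, and the final update combines the four cached derivatives with
+    # the classic weights 1/6, 1/3, 1/3, 1/6.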
+    def fourth_order_method(self, model_output, sigma, sigma_next, sample):
+        if self.state_in_first_order:
+            # First stage: cache the derivative, step size, and base sample
+            self.derivative_1 = model_output
+            self.dt = sigma_next - sigma
+            self.sample = sample
+            derivative = model_output
+            dt = self.dt / 2
+            last_inner_step = False
+
+        elif self.state_in_second_order:
+            self.derivative_2 = model_output
+            derivative = model_output
+            dt = self.dt / 2
+            last_inner_step = False
+
+        elif self.state_in_third_order:
+            self.derivative_3 = model_output
+            derivative = model_output
+            dt = self.dt
+            last_inner_step = False
+
+        else:
+            # Final stage: weighted combination of the four derivatives
+            derivative = (
+                1 / 6 * self.derivative_1
+                + 1 / 3 * self.derivative_2
+                + 1 / 3 * self.derivative_3
+                + 1 / 6 * model_output
+            )
+
+            # take the previous timestep & sample
+            dt = self.dt
+            sample = self.sample
+            last_inner_step = True
+
+            # free dt and derivatives
+            # Note: this puts the scheduler back in "first order mode"
+            self.derivative_1 = None
+            self.derivative_2 = None
+            self.derivative_3 = None
+            self.dt = None
+            self.sample = None
+
+        return derivative, dt, sample, last_inner_step
+
+    def __len__(self):
+        return self.config.num_train_timesteps
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f95a2ef41a38066e9991861ef6c8e9a299fea561
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,52 @@
+# Core ML dependencies - CPU optimized
+torch>=2.0.0,<2.3.0
+torchvision>=0.15.0,<0.18.0
+torchaudio>=2.0.0,<2.3.0
+numpy>=1.21.0,<1.27.0
+scipy>=1.7.0
+
+# Deep Learning frameworks
+diffusers>=0.24.0
+timm>=0.9.0
+accelerate>=0.20.0
+
+# Transformers and NLP - use stable version
+transformers>=4.35.0,<4.50.0
+sentencepiece>=0.1.99
+
+# Audio processing - use HTTPS version for HF Spaces
+audiotools @ git+https://github.com/descriptinc/audiotools.git
+
+# Video/Image processing
+pillow>=8.3.0
+av>=10.0.0
+einops>=0.6.0
+
+# Configuration and utilities
+pyyaml>=6.0
+omegaconf>=2.3.0
+easydict>=1.9.0
+loguru>=0.6.0
+tqdm>=4.64.0
+setuptools>=65.0.0
+
+# Data handling
+pandas>=1.3.0
+pyarrow>=10.0.0
+
+# Web interface - compatible version for HF Spaces
+gradio>=4.0.0,<5.0.0
+
+# Network
+urllib3>=1.26.0,<3.0.0
+
+# Hugging Face integration
+huggingface_hub>=0.16.0
+datasets>=2.14.0
+
+# Additional dependencies for stability
+packaging>=21.0
+typing-extensions>=4.0.0
+
+# Optional: reduce memory usage
+psutil>=5.8.0
\ No newline at end of file
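
For orientation, this is roughly how the new scheduler is driven from a denoising loop. A minimal sketch under stated assumptions: `toy_velocity_model` is a placeholder for the Foley transformer, the latent shape is illustrative, and the import path is the one created by this diff:

```python
import torch
from hunyuanvideo_foley.utils.schedulers import FlowMatchDiscreteScheduler

def toy_velocity_model(x, t):
    # Placeholder for the real flow-matching network.
    return -x

scheduler = FlowMatchDiscreteScheduler(shift=1.0, solver="euler")
scheduler.set_timesteps(num_inference_steps=50)

latents = torch.randn(1, 128, 64)  # illustrative latent shape
for t in scheduler.timesteps:
    v = toy_velocity_model(latents, t)
    latents = scheduler.step(v, t, latents, return_dict=False)[0]
```

Note that with `solver="heun-2"` or `"kutta-4"`, `step` must be called two or four times per logical timestep (once per inner stage), since the step index only advances after the last inner stage.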