diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..41f00ec385d56b350b4b23174d68b52581bb989d 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,3 +1,4 @@
+# Standard LFS file types for Hugging Face
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +34,45 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+# Media files
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+*.avi filter=lfs diff=lfs merge=lfs -text
+*.mov filter=lfs diff=lfs merge=lfs -text
+*.mkv filter=lfs diff=lfs merge=lfs -text
+*.webm filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
+*.mp3 filter=lfs diff=lfs merge=lfs -text
+*.flac filter=lfs diff=lfs merge=lfs -text
+# Project specific files
+examples/7_result.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/8_video.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/1_result.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/2_video.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/6_result.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/7_video.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/8_result.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/3_video.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/5_video.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/4_result.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/5_result.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/1_video.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/3_result.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/6_video.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/2_result.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/4_video.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/MovieGenAudioBenchSfx/video_with_audio/0.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/MovieGenAudioBenchSfx/video_with_audio/4.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/MovieGenAudioBenchSfx/video_with_audio/6.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/MovieGenAudioBenchSfx/video_with_audio/7.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/MovieGenAudioBenchSfx/video_with_audio/1.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/MovieGenAudioBenchSfx/video_with_audio/2.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/MovieGenAudioBenchSfx/video_with_audio/3.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/MovieGenAudioBenchSfx/video_with_audio/5.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/MovieGenAudioBenchSfx/video_with_audio/8.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/MovieGenAudioBenchSfx/video_with_audio/9.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/data_pipeline.png filter=lfs diff=lfs merge=lfs -text
+assets/model_arch.png filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index 5dea084d42bc69ae5f90b6561770329d558def99..d268e5421b4f12324c301d14278c95f28ce40449 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,96 @@
---
-title: Hunyuanvideo Foley
-emoji: 😻
+title: HunyuanVideo-Foley
+emoji: 🎵
colorFrom: blue
-colorTo: gray
+colorTo: purple
sdk: gradio
-sdk_version: 5.44.1
+sdk_version: 4.44.0
app_file: app.py
pinned: false
license: apache-2.0
+short_description: Generate realistic audio from video and text descriptions
---
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# HunyuanVideo-Foley
+
+<div align="center">
+
+<h3>🎵 Text-Video-to-Audio Synthesis</h3>
+
+<p>Generate realistic audio from video and text descriptions using AI</p>
+
+</div>
+
+## About
+
+HunyuanVideo-Foley is a multimodal diffusion model that generates high-quality audio effects (Foley audio) synchronized with video content. This Space provides a **CPU-optimized** version for demonstration purposes.
+
+### ⚠️ CPU Performance Notice
+
+This Space runs on **free CPU** hardware, which means:
+- **Slower inference** (3-5 minutes per generation)
+- **Limited concurrent users**
+- **Reduced sample counts** (max 3 samples)
+
+For **faster performance**, consider:
+- Using the original repository with GPU
+- Running locally with CUDA support
+- Upgrading to a GPU Space (if available)
+
+## Features
+
+- 🎬 **Video-to-Audio**: Generate audio effects from video content
+- 📝 **Text Guidance**: Control generation with text descriptions
+- 🎯 **Multiple Samples**: Generate up to 3 variations
+- 🔧 **Adjustable Settings**: Control CFG scale and inference steps
+- 📱 **User-Friendly**: Simple drag-and-drop interface
+
+## How to Use
+
+1. **Upload Video**: Drag and drop your video file (MP4, AVI, MOV)
+2. **Add Description** (Optional): Describe the audio you want to generate
+3. **Adjust Settings**: Modify CFG scale and inference steps if needed
+4. **Generate**: Click "Generate Audio" and wait (3-5 minutes on CPU)
+5. **Download**: Save your generated audio/video combinations
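+
+If you prefer to script these steps rather than use the UI, the snippet below is a minimal sketch of the same pipeline `app.py` runs (file names are placeholders; it assumes the weights and config are already present locally):
+
+```python
+import torch
+import torchaudio
+
+from hunyuanvideo_foley.utils.model_utils import load_model, denoise_process
+from hunyuanvideo_foley.utils.feature_utils import feature_process
+from hunyuanvideo_foley.utils.media_utils import merge_audio_video
+
+device = torch.device("cpu")
+model_dict, cfg = load_model("./pretrained_models/", "configs/hunyuanvideo-foley-xxl.yaml", device)
+
+# Extract visual and text features, then run the flow-matching denoiser
+visual_feats, text_feats, audio_len_in_s = feature_process(
+    "input.mp4", "A person walks on frozen ice", model_dict, cfg
+)
+audio, sample_rate = denoise_process(
+    visual_feats, text_feats, audio_len_in_s, model_dict, cfg,
+    guidance_scale=2.0, num_inference_steps=20, batch_size=1,
+)
+
+# Save the first sample and mux it back onto the source video
+torchaudio.save("generated_audio.wav", audio[0], sample_rate)
+merge_audio_video("generated_audio.wav", "input.mp4", "output_with_audio.mp4")
+```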
+
+## Tips for Best Results
+
+- 📏 **Video Length**: Keep videos under 30 seconds for faster processing
+- 🎯 **Text Prompts**: Use simple, clear descriptions
+- ⚡ **Settings**: Lower values process faster on CPU
+- 🔄 **Multiple Attempts**: Try different settings if not satisfied
+
+## Technical Details
+
+- **Model**: HunyuanVideo-Foley-XXL
+- **Architecture**: Multimodal diffusion transformer
+- **Audio Quality**: 48kHz professional-grade output
+- **Deployment**: CPU-optimized for Hugging Face Spaces
+
+## Original Project
+
+This is a **CPU deployment** of the original HunyuanVideo-Foley project:
+
+- 📄 **Paper**: [HunyuanVideo-Foley: Multimodal Diffusion with Representation Alignment](https://arxiv.org/abs/2508.16930)
+- 💻 **GitHub**: [Tencent-Hunyuan/HunyuanVideo-Foley](https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley)
+- 🤗 **Models**: [tencent/HunyuanVideo-Foley](https://huggingface.co/tencent/HunyuanVideo-Foley)
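+
+As a sketch of the same call `app.py` makes at startup, you can pre-download the pretrained weights with `huggingface_hub`:
+
+```python
+from huggingface_hub import snapshot_download
+
+# Fetch all model files into ./pretrained_models
+snapshot_download(
+    repo_id="tencent/HunyuanVideo-Foley",
+    local_dir="./pretrained_models",
+)
+```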
+
+## Citation
+
+```bibtex
+@misc{shan2025hunyuanvideofoleymultimodaldiffusionrepresentation,
+ title={HunyuanVideo-Foley: Multimodal Diffusion with Representation Alignment for High-Fidelity Foley Audio Generation},
+ author={Sizhe Shan and Qiulin Li and Yutao Cui and Miles Yang and Yuehai Wang and Qun Yang and Jin Zhou and Zhao Zhong},
+ year={2025},
+ eprint={2508.16930},
+ archivePrefix={arXiv},
+ primaryClass={eess.AS}
+}
+```
+
+## License
+
+This project is licensed under the Apache 2.0 License.
+
+---
+
+<div align="center">
+
+🚀 Powered by Tencent Hunyuan | Optimized for CPU deployment
+
+</div>
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..537b2d06447dd109dd12e5bf7515727d8e40c3e2
--- /dev/null
+++ b/app.py
@@ -0,0 +1,405 @@
+import os
+import tempfile
+import gradio as gr
+import torch
+import torchaudio
+from loguru import logger
+from typing import Optional, Tuple
+import random
+import numpy as np
+
+# Force CPU usage for Hugging Face Spaces
+os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+from hunyuanvideo_foley.utils.model_utils import load_model, denoise_process
+from hunyuanvideo_foley.utils.feature_utils import feature_process
+from hunyuanvideo_foley.utils.media_utils import merge_audio_video
+
+# Global variables for model storage
+model_dict = None
+cfg = None
+device = None
+
+# Model path for Hugging Face Spaces - try to download automatically
+MODEL_PATH = os.environ.get("HIFI_FOLEY_MODEL_PATH", "./pretrained_models/")
+CONFIG_PATH = "configs/hunyuanvideo-foley-xxl.yaml"
+
+def setup_device(force_cpu: bool = True) -> torch.device:
+ """Setup computing device - force CPU for Hugging Face Spaces"""
+ if force_cpu:
+ device = torch.device("cpu")
+ logger.info("Using CPU device (forced for Hugging Face Spaces)")
+ else:
+ if torch.cuda.is_available():
+ device = torch.device("cuda:0")
+ logger.info("Using CUDA device")
+ elif torch.backends.mps.is_available():
+ device = torch.device("mps")
+ logger.info("Using MPS device")
+ else:
+ device = torch.device("cpu")
+ logger.info("Using CPU device")
+
+ return device
+
+def download_models():
+ """Download models from Hugging Face if not present"""
+ try:
+ from huggingface_hub import snapshot_download
+ logger.info("Downloading models from Hugging Face...")
+
+ # Download the model files
+ snapshot_download(
+ repo_id="tencent/HunyuanVideo-Foley",
+ local_dir="./pretrained_models",
+ local_dir_use_symlinks=False
+ )
+
+ logger.info("Model download completed!")
+ return True
+ except Exception as e:
+ logger.error(f"Failed to download models: {str(e)}")
+ return False
+
+def auto_load_models() -> str:
+ """Automatically load preset models"""
+ global model_dict, cfg, device
+
+ try:
+ # First try to download models if they don't exist
+ if not os.path.exists(MODEL_PATH) or not os.listdir(MODEL_PATH):
+ logger.info("Models not found locally, attempting to download...")
+ if not download_models():
+ return "❌ Failed to download models from Hugging Face"
+
+ if not os.path.exists(CONFIG_PATH):
+ return f"❌ Config file not found: {CONFIG_PATH}"
+
+ # Force CPU usage for Hugging Face Spaces
+ device = setup_device(force_cpu=True)
+
+ # Load model with CPU optimization
+ logger.info("Loading model on CPU...")
+ logger.info(f"Model path: {MODEL_PATH}")
+ logger.info(f"Config path: {CONFIG_PATH}")
+
+ # Set torch to use fewer threads for CPU inference
+ torch.set_num_threads(2)
+
+ model_dict, cfg = load_model(MODEL_PATH, CONFIG_PATH, device)
+
+ logger.info("✅ Model loaded successfully on CPU!")
+ return "✅ Model loaded successfully on CPU!"
+
+ except Exception as e:
+ logger.error(f"Model loading failed: {str(e)}")
+ return f"❌ Model loading failed: {str(e)}"
+
+def infer_single_video(
+ video_file,
+ text_prompt: str,
+ guidance_scale: float = 2.0, # Lower for CPU
+ num_inference_steps: int = 20, # Reduced for CPU
+ sample_nums: int = 1
+) -> Tuple[list, str]:
+ """Single video inference optimized for CPU"""
+ global model_dict, cfg, device
+
+ if model_dict is None or cfg is None:
+ return [], "❌ Please load the model first!"
+
+ if video_file is None:
+ return [], "❌ Please upload a video file!"
+
+ # Allow empty text prompt
+ if text_prompt is None:
+ text_prompt = ""
+ text_prompt = text_prompt.strip()
+
+ try:
+ logger.info(f"Processing video: {video_file}")
+ logger.info(f"Text prompt: {text_prompt}")
+ logger.info("Running inference on CPU (this may take a while)...")
+
+ # Feature processing
+ visual_feats, text_feats, audio_len_in_s = feature_process(
+ video_file,
+ text_prompt,
+ model_dict,
+ cfg
+ )
+
+ # Denoising process with CPU-optimized settings
+ logger.info(f"Generating {sample_nums} audio sample(s) on CPU...")
+ audio, sample_rate = denoise_process(
+ visual_feats,
+ text_feats,
+ audio_len_in_s,
+ model_dict,
+ cfg,
+ guidance_scale=guidance_scale,
+ num_inference_steps=num_inference_steps,
+ batch_size=sample_nums
+ )
+
+ # Create temporary files to save results
+ temp_dir = tempfile.mkdtemp()
+ video_outputs = []
+
+ # Process each generated audio sample
+ for i in range(sample_nums):
+ # Save audio file
+ audio_output = os.path.join(temp_dir, f"generated_audio_{i+1}.wav")
+ torchaudio.save(audio_output, audio[i], sample_rate)
+
+ # Merge video and audio
+ video_output = os.path.join(temp_dir, f"video_with_audio_{i+1}.mp4")
+ merge_audio_video(audio_output, video_file, video_output)
+ video_outputs.append(video_output)
+
+ logger.info(f"Inference completed! Generated {sample_nums} samples.")
+ return video_outputs, f"✅ Generated {sample_nums} audio sample(s) successfully on CPU!"
+
+ except Exception as e:
+ logger.error(f"Inference failed: {str(e)}")
+ return [], f"❌ Inference failed: {str(e)}"
+
+def update_video_outputs(video_list, status_msg):
+ """Update video outputs based on the number of generated samples"""
+ # Initialize all outputs as None
+ outputs = [None] * 3 # Reduced to 3 for CPU
+
+ # Set values based on generated videos
+ for i, video_path in enumerate(video_list[:3]): # Max 3 samples for CPU
+ outputs[i] = video_path
+
+ # Return all outputs plus status message
+ return tuple(outputs + [status_msg])
+
+def create_gradio_interface():
+ """Create Gradio interface optimized for CPU deployment"""
+
+ # Custom CSS with Hugging Face Spaces styling
+ css = """
+ .gradio-container {
+ font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
+ background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
+ min-height: 100vh;
+ }
+
+ .main-header {
+ text-align: center;
+ padding: 2rem 0;
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+ border-radius: 20px;
+ margin-bottom: 2rem;
+ box-shadow: 0 8px 32px rgba(0,0,0,0.15);
+ }
+
+ .main-header h1 {
+ color: white;
+ font-size: 3rem;
+ font-weight: 700;
+ margin-bottom: 0.5rem;
+ text-shadow: 0 2px 10px rgba(0,0,0,0.3);
+ }
+
+ .main-header p {
+ color: rgba(255, 255, 255, 0.95);
+ font-size: 1.2rem;
+ font-weight: 300;
+ }
+
+ .cpu-notice {
+ background: #fff3cd;
+ border: 1px solid #ffeaa7;
+ border-radius: 10px;
+ padding: 1rem;
+ margin: 1rem 0;
+ color: #856404;
+ }
+ """
+
+ with gr.Blocks(css=css, title="HunyuanVideo-Foley (CPU)") as app:
+
+ # Main header
+ with gr.Column(elem_classes=["main-header"]):
+ gr.HTML("""
+ 🎵 HunyuanVideo-Foley
+ Text-Video-to-Audio Synthesis (CPU Version)
+ """)
+
+ # CPU Notice
+ gr.HTML("""
+
+ ⚠️ CPU Deployment Notice: This Space runs on CPU which means inference will be slower than GPU version.
+ Each generation may take 3-5 minutes. For faster inference, consider running locally with GPU.
+
+ """)
+
+ # Usage Guide
+ gr.Markdown("""
+ ### 📋 Quick Start Guide
+ **1.** Upload your video file **2.** Add optional text description **3.** Click Generate Audio (be patient!)
+
+ 💡 **Tips for CPU usage:**
+ - Use shorter videos (< 30 seconds recommended)
+ - Simple text prompts work better
+ - Expect longer processing times
+ """)
+
+ # Main interface
+ with gr.Row():
+ # Input section
+ with gr.Column(scale=1):
+ gr.Markdown("### 📹 Video Input")
+
+                video_input = gr.Video(
+                    # gr.Video has no `info` kwarg, so the format hint lives in the label
+                    label="Upload Video (MP4, AVI, MOV, etc.; shorter videos recommended for CPU)",
+                    height=300
+                )
+
+ text_input = gr.Textbox(
+ label="🎯 Audio Description (English)",
+ placeholder="A person walks on frozen ice",
+ lines=3,
+ info="Describe the audio you want to generate (optional)"
+ )
+
+ with gr.Row():
+ guidance_scale = gr.Slider(
+ minimum=1.0,
+ maximum=5.0,
+ value=2.0,
+ step=0.1,
+ label="🎚️ CFG Scale (lower for CPU)",
+ )
+
+ inference_steps = gr.Slider(
+ minimum=10,
+ maximum=50,
+ value=20,
+ step=5,
+ label="⚡ Steps (reduced for CPU)",
+ )
+
+ sample_nums = gr.Slider(
+ minimum=1,
+ maximum=3,
+ value=1,
+ step=1,
+ label="🎲 Sample Nums (max 3 for CPU)",
+ )
+
+ generate_btn = gr.Button(
+ "🎵 Generate Audio (CPU)",
+ variant="primary"
+ )
+
+ # Results section
+ with gr.Column(scale=1):
+ gr.Markdown("### 🎥 Generated Results")
+
+ # Reduced number of outputs for CPU
+ video_output_1 = gr.Video(
+ label="Sample 1",
+ height=250,
+ visible=True
+ )
+
+ with gr.Row():
+ video_output_2 = gr.Video(
+ label="Sample 2",
+ height=200,
+ visible=False
+ )
+ video_output_3 = gr.Video(
+ label="Sample 3",
+ height=200,
+ visible=False
+ )
+
+ result_text = gr.Textbox(
+ label="Status",
+ interactive=False,
+ lines=3
+ )
+
+ # Event handlers
+ def process_inference(video_file, text_prompt, guidance_scale, inference_steps, sample_nums):
+ # Generate videos
+ video_list, status_msg = infer_single_video(
+ video_file, text_prompt, guidance_scale, inference_steps, int(sample_nums)
+ )
+ # Update outputs with proper visibility
+ return update_video_outputs(video_list, status_msg)
+
+ # Add dynamic visibility control
+ def update_visibility(sample_nums):
+ sample_nums = int(sample_nums)
+ return [
+ gr.update(visible=True), # Sample 1 always visible
+ gr.update(visible=sample_nums >= 2), # Sample 2
+ gr.update(visible=sample_nums >= 3), # Sample 3
+ ]
+
+ # Update visibility when sample_nums changes
+ sample_nums.change(
+ fn=update_visibility,
+ inputs=[sample_nums],
+ outputs=[video_output_1, video_output_2, video_output_3]
+ )
+
+ generate_btn.click(
+ fn=process_inference,
+ inputs=[video_input, text_input, guidance_scale, inference_steps, sample_nums],
+ outputs=[
+ video_output_1,
+ video_output_2,
+ video_output_3,
+ result_text
+ ]
+ )
+
+ # Footer
+ gr.HTML("""
+
+
🚀 Powered by HunyuanVideo-Foley | Running on CPU for Hugging Face Spaces
+
For faster inference, visit the original repository
+
+ """)
+
+ return app
+
+def set_manual_seed(global_seed):
+ random.seed(global_seed)
+ np.random.seed(global_seed)
+ torch.manual_seed(global_seed)
+
+if __name__ == "__main__":
+ set_manual_seed(1)
+ # Setup logging
+ logger.remove()
+ logger.add(lambda msg: print(msg, end=''), level="INFO")
+
+ # Auto-load model
+ logger.info("Starting CPU application and loading model...")
+ model_load_result = auto_load_models()
+ logger.info(model_load_result)
+
+ # Create and launch Gradio app
+ app = create_gradio_interface()
+
+ # Log completion status
+ if "successfully" in model_load_result:
+ logger.info("Application ready, model loaded on CPU")
+
+ app.launch(
+ server_name="0.0.0.0",
+ server_port=7860, # Standard port for Hugging Face Spaces
+ share=False,
+ debug=False,
+ show_error=True
+ )
\ No newline at end of file
diff --git a/configs/hunyuanvideo-foley-xxl.yaml b/configs/hunyuanvideo-foley-xxl.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9b7da02713647fd3201d378d86924bfdf42e93c9
--- /dev/null
+++ b/configs/hunyuanvideo-foley-xxl.yaml
@@ -0,0 +1,49 @@
+model_config:
+ model_name: HunyuanVideo-Foley-XXL
+ model_type: 1d
+ model_precision: bf16
+ model_kwargs:
+ depth_triple_blocks: 18
+ depth_single_blocks: 36
+ hidden_size: 1536
+ num_heads: 12
+ mlp_ratio: 4
+ mlp_act_type: "gelu_tanh"
+ qkv_bias: True
+ qk_norm: True
+ qk_norm_type: "rms"
+ attn_mode: "torch"
+ embedder_type: "default"
+ interleaved_audio_visual_rope: True
+ enable_learnable_empty_visual_feat: True
+ sync_modulation: False
+ add_sync_feat_to_audio: True
+ cross_attention: True
+ use_attention_mask: False
+ condition_projection: "linear"
+ sync_feat_dim: 768 # syncformer 768 dim
+ condition_dim: 768 # clap 768 text condition dim (clip-text)
+ clip_dim: 768 # siglip2 visual dim
+ audio_vae_latent_dim: 128
+ audio_frame_rate: 50
+ patch_size: 1
+ rope_dim_list: null
+ rope_theta: 10000
+ text_length: 77
+ clip_length: 64
+ sync_length: 192
+ use_mmaudio_singleblock: True
+ depth_triple_ssl_encoder: null
+ depth_single_ssl_encoder: 8
+ use_repa_with_audiossl: True
+
+diffusion_config:
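+  # Flow matching: linear probability path with velocity prediction, sampled with an Euler solver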
+ denoise_type: "flow"
+ flow_path_type: "linear"
+ flow_predict_type: "velocity"
+ flow_reverse: True
+ flow_solver: "euler"
+ sample_flow_shift: 1.0
+ sample_use_flux_shift: False
+ flux_base_shift: 0.5
+ flux_max_shift: 1.15
diff --git a/hunyuanvideo_foley/__init__.py b/hunyuanvideo_foley/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hunyuanvideo_foley/__pycache__/__init__.cpython-312.pyc b/hunyuanvideo_foley/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..61662d26303a16fd02acb558569428bab8efd4b1
Binary files /dev/null and b/hunyuanvideo_foley/__pycache__/__init__.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/__pycache__/__init__.cpython-313.pyc b/hunyuanvideo_foley/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..71b5507b7f193125a4b8ed840c414ab5ad36de33
Binary files /dev/null and b/hunyuanvideo_foley/__pycache__/__init__.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/__pycache__/constants.cpython-312.pyc b/hunyuanvideo_foley/__pycache__/constants.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5be810050d3223d2d4660812e4bb557707bad15f
Binary files /dev/null and b/hunyuanvideo_foley/__pycache__/constants.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/__pycache__/constants.cpython-313.pyc b/hunyuanvideo_foley/__pycache__/constants.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0842fa469b5edad26854f848e825529a845ec910
Binary files /dev/null and b/hunyuanvideo_foley/__pycache__/constants.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/constants.py b/hunyuanvideo_foley/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..81519407b44f33cfcbadd550f41a634b11821649
--- /dev/null
+++ b/hunyuanvideo_foley/constants.py
@@ -0,0 +1,57 @@
+"""Constants used throughout the HunyuanVideo-Foley project."""
+
+from typing import Dict, List
+
+# Model configuration
+DEFAULT_AUDIO_SAMPLE_RATE = 48000
+DEFAULT_VIDEO_FPS = 25
+DEFAULT_AUDIO_CHANNELS = 2
+
+# Video processing
+MAX_VIDEO_DURATION_SECONDS = 15.0
+MIN_VIDEO_DURATION_SECONDS = 1.0
+
+# Audio processing
+AUDIO_VAE_LATENT_DIM = 128
+AUDIO_FRAME_RATE = 75 # frames per second in latent space
+
+# Visual features
+FPS_VISUAL: Dict[str, int] = {
+ "siglip2": 8,
+ "synchformer": 25
+}
+
+# Model paths (can be overridden by environment variables)
+DEFAULT_MODEL_PATH = "./pretrained_models/"
+DEFAULT_CONFIG_PATH = "configs/hunyuanvideo-foley-xxl.yaml"
+
+# Inference parameters
+DEFAULT_GUIDANCE_SCALE = 4.5
+DEFAULT_NUM_INFERENCE_STEPS = 50
+MIN_GUIDANCE_SCALE = 1.0
+MAX_GUIDANCE_SCALE = 10.0
+MIN_INFERENCE_STEPS = 10
+MAX_INFERENCE_STEPS = 100
+
+# Text processing
+MAX_TEXT_LENGTH = 100
+DEFAULT_NEGATIVE_PROMPT = "noisy, harsh"
+
+# File extensions
+SUPPORTED_VIDEO_EXTENSIONS: List[str] = [".mp4", ".avi", ".mov", ".mkv", ".webm"]
+SUPPORTED_AUDIO_EXTENSIONS: List[str] = [".wav", ".mp3", ".flac", ".aac"]
+
+# Quality settings
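+# (values are ffmpeg audio-bitrate arguments, e.g. "-b:a 192k")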
+AUDIO_QUALITY_SETTINGS: Dict[str, List[str]] = {
+ "high": ["-b:a", "192k"],
+ "medium": ["-b:a", "128k"],
+ "low": ["-b:a", "96k"]
+}
+
+# Error messages
+ERROR_MESSAGES: Dict[str, str] = {
+ "model_not_loaded": "Model is not loaded. Please load the model first.",
+ "invalid_video_format": "Unsupported video format. Supported formats: {formats}",
+ "video_too_long": f"Video duration exceeds maximum of {MAX_VIDEO_DURATION_SECONDS} seconds",
+ "ffmpeg_not_found": "ffmpeg not found. Please install ffmpeg: https://ffmpeg.org/download.html"
+}
\ No newline at end of file
diff --git a/hunyuanvideo_foley/models/__init__.py b/hunyuanvideo_foley/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hunyuanvideo_foley/models/__pycache__/__init__.cpython-312.pyc b/hunyuanvideo_foley/models/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7256df1d54bdbca395cf379fc5ffdf81c560e139
Binary files /dev/null and b/hunyuanvideo_foley/models/__pycache__/__init__.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/models/__pycache__/__init__.cpython-313.pyc b/hunyuanvideo_foley/models/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..06c16a43edf1fdd1e0c88011d8cd1c450d81250e
Binary files /dev/null and b/hunyuanvideo_foley/models/__pycache__/__init__.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/models/__pycache__/hifi_foley.cpython-312.pyc b/hunyuanvideo_foley/models/__pycache__/hifi_foley.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d1f8e5c911fa9516c071206f63cdb3559f853bcf
Binary files /dev/null and b/hunyuanvideo_foley/models/__pycache__/hifi_foley.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/models/__pycache__/hifi_foley.cpython-313.pyc b/hunyuanvideo_foley/models/__pycache__/hifi_foley.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cef069430dd3393c7c49f0b5f77ad1134a3ecaf2
Binary files /dev/null and b/hunyuanvideo_foley/models/__pycache__/hifi_foley.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/models/dac_vae/__init__.py b/hunyuanvideo_foley/models/dac_vae/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..51205ef6ded9c6735a988b76008e0f6bdce8e215
--- /dev/null
+++ b/hunyuanvideo_foley/models/dac_vae/__init__.py
@@ -0,0 +1,16 @@
+__version__ = "1.0.0"
+
+# preserved here for legacy reasons
+__model_version__ = "latest"
+
+import audiotools
+
+audiotools.ml.BaseModel.INTERN += ["dac.**"]
+audiotools.ml.BaseModel.EXTERN += ["einops"]
+
+
+from . import nn
+from . import model
+from . import utils
+from .model import DAC
+from .model import DACFile
diff --git a/hunyuanvideo_foley/models/dac_vae/__main__.py b/hunyuanvideo_foley/models/dac_vae/__main__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1fe6531c5bf82f731d8e07ec09c21d79aae4cfa
--- /dev/null
+++ b/hunyuanvideo_foley/models/dac_vae/__main__.py
@@ -0,0 +1,36 @@
+import sys
+
+import argbind
+
+from .utils import download
+from .utils.decode import decode
+from .utils.encode import encode
+
+STAGES = ["encode", "decode", "download"]
+
+
+def run(stage: str):
+ """Run stages.
+
+ Parameters
+ ----------
+ stage : str
+ Stage to run
+ """
+ if stage not in STAGES:
+ raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
+ stage_fn = globals()[stage]
+
+ if stage == "download":
+ stage_fn()
+ return
+
+ stage_fn()
+
+
+if __name__ == "__main__":
+ group = sys.argv.pop(1)
+ args = argbind.parse_args(group=group)
+
+ with argbind.scope(args):
+ run(group)
diff --git a/hunyuanvideo_foley/models/dac_vae/__pycache__/__init__.cpython-312.pyc b/hunyuanvideo_foley/models/dac_vae/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14cf6031338150ea3a7f152c17008c66a6413af1
Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/__pycache__/__init__.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/models/dac_vae/__pycache__/__init__.cpython-313.pyc b/hunyuanvideo_foley/models/dac_vae/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f50d727a0848877df226320cc677261667551ce9
Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/__pycache__/__init__.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/models/dac_vae/model/__init__.py b/hunyuanvideo_foley/models/dac_vae/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..02a75b7ad6028f5c41b6a8285b0257d4c23bdfcf
--- /dev/null
+++ b/hunyuanvideo_foley/models/dac_vae/model/__init__.py
@@ -0,0 +1,4 @@
+from .base import CodecMixin
+from .base import DACFile
+from .dac import DAC
+from .discriminator import Discriminator
diff --git a/hunyuanvideo_foley/models/dac_vae/model/__pycache__/__init__.cpython-312.pyc b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1973b3c42287fe8b898d235449f0fbca61c7f671
Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/__init__.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/models/dac_vae/model/__pycache__/__init__.cpython-313.pyc b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..83ceb0da01684fe8da99ed24a4ff9a03163b6387
Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/__init__.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/models/dac_vae/model/__pycache__/base.cpython-312.pyc b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/base.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..997357623388982f2ff70a5e4d0cf3ca4fdaeb2d
Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/base.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/models/dac_vae/model/__pycache__/base.cpython-313.pyc b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/base.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..791f244d54f568235063163cd7958bf90cfddeb8
Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/base.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/models/dac_vae/model/__pycache__/dac.cpython-312.pyc b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/dac.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9368431bc265497e05dcf6e4de3c4739f6e4df2e
Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/dac.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/models/dac_vae/model/__pycache__/dac.cpython-313.pyc b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/dac.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d6f33b65ae4c386a21f4449e105e55e043a58b51
Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/dac.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/models/dac_vae/model/__pycache__/discriminator.cpython-312.pyc b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/discriminator.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7041224daee46788150ab5a04367984f3071964e
Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/discriminator.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/models/dac_vae/model/__pycache__/discriminator.cpython-313.pyc b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/discriminator.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..32e5430cebc009a08f69aa50663bcf1de237438c
Binary files /dev/null and b/hunyuanvideo_foley/models/dac_vae/model/__pycache__/discriminator.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/models/dac_vae/model/base.py b/hunyuanvideo_foley/models/dac_vae/model/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..e95a84149a767f256a54b7cc3241c09551f39061
--- /dev/null
+++ b/hunyuanvideo_foley/models/dac_vae/model/base.py
@@ -0,0 +1,301 @@
+import math
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Union
+
+import numpy as np
+import torch
+import tqdm
+from audiotools import AudioSignal
+from torch import nn
+
+SUPPORTED_VERSIONS = ["1.0.0"]
+
+
+@dataclass
+class DACFile:
+ codes: torch.Tensor
+
+ # Metadata
+ chunk_length: int
+ original_length: int
+ input_db: float
+ channels: int
+ sample_rate: int
+ padding: bool
+ dac_version: str
+
+ def save(self, path):
+ artifacts = {
+ "codes": self.codes.numpy().astype(np.uint16),
+ "metadata": {
+ "input_db": self.input_db.numpy().astype(np.float32),
+ "original_length": self.original_length,
+ "sample_rate": self.sample_rate,
+ "chunk_length": self.chunk_length,
+ "channels": self.channels,
+ "padding": self.padding,
+ "dac_version": SUPPORTED_VERSIONS[-1],
+ },
+ }
+ path = Path(path).with_suffix(".dac")
+ with open(path, "wb") as f:
+ np.save(f, artifacts)
+ return path
+
+ @classmethod
+ def load(cls, path):
+ artifacts = np.load(path, allow_pickle=True)[()]
+ codes = torch.from_numpy(artifacts["codes"].astype(int))
+ if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
+ raise RuntimeError(
+ f"Given file {path} can't be loaded with this version of descript-audio-codec."
+ )
+ return cls(codes=codes, **artifacts["metadata"])
+
+
+class CodecMixin:
+ @property
+ def padding(self):
+ if not hasattr(self, "_padding"):
+ self._padding = True
+ return self._padding
+
+ @padding.setter
+ def padding(self, value):
+ assert isinstance(value, bool)
+
+ layers = [
+ l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
+ ]
+
+ for layer in layers:
+ if value:
+ if hasattr(layer, "original_padding"):
+ layer.padding = layer.original_padding
+ else:
+ layer.original_padding = layer.padding
+ layer.padding = tuple(0 for _ in range(len(layer.padding)))
+
+ self._padding = value
+
+ def get_delay(self):
+ # Any number works here, delay is invariant to input length
+ l_out = self.get_output_length(0)
+ L = l_out
+
+ layers = []
+ for layer in self.modules():
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
+ layers.append(layer)
+
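+        # Walk the layers output-to-input, inverting each one's length formula
+        # to find the input length needed to produce l_out output samples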
+ for layer in reversed(layers):
+ d = layer.dilation[0]
+ k = layer.kernel_size[0]
+ s = layer.stride[0]
+
+ if isinstance(layer, nn.ConvTranspose1d):
+ L = ((L - d * (k - 1) - 1) / s) + 1
+ elif isinstance(layer, nn.Conv1d):
+ L = (L - 1) * s + d * (k - 1) + 1
+
+ L = math.ceil(L)
+
+ l_in = L
+
+ return (l_in - l_out) // 2
+
+ def get_output_length(self, input_length):
+ L = input_length
+ # Calculate output length
+ for layer in self.modules():
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
+ d = layer.dilation[0]
+ k = layer.kernel_size[0]
+ s = layer.stride[0]
+
+ if isinstance(layer, nn.Conv1d):
+ L = ((L - d * (k - 1) - 1) / s) + 1
+ elif isinstance(layer, nn.ConvTranspose1d):
+ L = (L - 1) * s + d * (k - 1) + 1
+
+ L = math.floor(L)
+ return L
+
+ @torch.no_grad()
+ def compress(
+ self,
+ audio_path_or_signal: Union[str, Path, AudioSignal],
+ win_duration: float = 1.0,
+ verbose: bool = False,
+ normalize_db: float = -16,
+ n_quantizers: int = None,
+ ) -> DACFile:
+ """Processes an audio signal from a file or AudioSignal object into
+ discrete codes. This function processes the signal in short windows,
+ using constant GPU memory.
+
+ Parameters
+ ----------
+ audio_path_or_signal : Union[str, Path, AudioSignal]
+            audio signal to compress
+        win_duration : float, optional
+            window duration in seconds, by default 1.0
+ verbose : bool, optional
+ by default False
+ normalize_db : float, optional
+ normalize db, by default -16
+
+ Returns
+ -------
+ DACFile
+ Object containing compressed codes and metadata
+ required for decompression
+ """
+ audio_signal = audio_path_or_signal
+ if isinstance(audio_signal, (str, Path)):
+ audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
+
+ self.eval()
+ original_padding = self.padding
+ original_device = audio_signal.device
+
+ audio_signal = audio_signal.clone()
+ audio_signal = audio_signal.to_mono()
+ original_sr = audio_signal.sample_rate
+
+ resample_fn = audio_signal.resample
+ loudness_fn = audio_signal.loudness
+
+ # If audio is > 10 minutes long, use the ffmpeg versions
+ if audio_signal.signal_duration >= 10 * 60 * 60:
+ resample_fn = audio_signal.ffmpeg_resample
+ loudness_fn = audio_signal.ffmpeg_loudness
+
+ original_length = audio_signal.signal_length
+ resample_fn(self.sample_rate)
+ input_db = loudness_fn()
+
+ if normalize_db is not None:
+ audio_signal.normalize(normalize_db)
+ audio_signal.ensure_max_of_audio()
+
+ nb, nac, nt = audio_signal.audio_data.shape
+ audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
+ win_duration = (
+ audio_signal.signal_duration if win_duration is None else win_duration
+ )
+
+ if audio_signal.signal_duration <= win_duration:
+ # Unchunked compression (used if signal length < win duration)
+ self.padding = True
+ n_samples = nt
+ hop = nt
+ else:
+ # Chunked inference
+ self.padding = False
+ # Zero-pad signal on either side by the delay
+ audio_signal.zero_pad(self.delay, self.delay)
+ n_samples = int(win_duration * self.sample_rate)
+ # Round n_samples to nearest hop length multiple
+ n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
+ hop = self.get_output_length(n_samples)
+
+ codes = []
+ range_fn = range if not verbose else tqdm.trange
+
+ for i in range_fn(0, nt, hop):
+ x = audio_signal[..., i : i + n_samples]
+ x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
+
+ audio_data = x.audio_data.to(self.device)
+ audio_data = self.preprocess(audio_data, self.sample_rate)
+ _, c, _, _, _ = self.encode(audio_data, n_quantizers)
+ codes.append(c.to(original_device))
+ chunk_length = c.shape[-1]
+
+ codes = torch.cat(codes, dim=-1)
+
+ dac_file = DACFile(
+ codes=codes,
+ chunk_length=chunk_length,
+ original_length=original_length,
+ input_db=input_db,
+ channels=nac,
+ sample_rate=original_sr,
+ padding=self.padding,
+ dac_version=SUPPORTED_VERSIONS[-1],
+ )
+
+ if n_quantizers is not None:
+ codes = codes[:, :n_quantizers, :]
+
+ self.padding = original_padding
+ return dac_file
+
+ @torch.no_grad()
+ def decompress(
+ self,
+ obj: Union[str, Path, DACFile],
+ verbose: bool = False,
+ ) -> AudioSignal:
+ """Reconstruct audio from a given .dac file
+
+ Parameters
+ ----------
+ obj : Union[str, Path, DACFile]
+ .dac file location or corresponding DACFile object.
+ verbose : bool, optional
+ Prints progress if True, by default False
+
+ Returns
+ -------
+ AudioSignal
+ Object with the reconstructed audio
+ """
+ self.eval()
+ if isinstance(obj, (str, Path)):
+ obj = DACFile.load(obj)
+
+ original_padding = self.padding
+ self.padding = obj.padding
+
+ range_fn = range if not verbose else tqdm.trange
+ codes = obj.codes
+ original_device = codes.device
+ chunk_length = obj.chunk_length
+ recons = []
+
+ for i in range_fn(0, codes.shape[-1], chunk_length):
+ c = codes[..., i : i + chunk_length].to(self.device)
+ z = self.quantizer.from_codes(c)[0]
+ r = self.decode(z)
+ recons.append(r.to(original_device))
+
+ recons = torch.cat(recons, dim=-1)
+ recons = AudioSignal(recons, self.sample_rate)
+
+ resample_fn = recons.resample
+ loudness_fn = recons.loudness
+
+ # If audio is > 10 minutes long, use the ffmpeg versions
+ if recons.signal_duration >= 10 * 60 * 60:
+ resample_fn = recons.ffmpeg_resample
+ loudness_fn = recons.ffmpeg_loudness
+
+ if obj.input_db is not None:
+ recons.normalize(obj.input_db)
+
+ resample_fn(obj.sample_rate)
+
+ if obj.original_length is not None:
+ recons = recons[..., : obj.original_length]
+ loudness_fn()
+ recons.audio_data = recons.audio_data.reshape(
+ -1, obj.channels, obj.original_length
+ )
+ else:
+ loudness_fn()
+
+ self.padding = original_padding
+ return recons
diff --git a/hunyuanvideo_foley/models/dac_vae/model/dac.py b/hunyuanvideo_foley/models/dac_vae/model/dac.py
new file mode 100644
index 0000000000000000000000000000000000000000..3df2cbad1502774ce2519c1795b6e85571ac3fc1
--- /dev/null
+++ b/hunyuanvideo_foley/models/dac_vae/model/dac.py
@@ -0,0 +1,410 @@
+import math
+from typing import List
+from typing import Union
+
+import numpy as np
+import torch
+from audiotools import AudioSignal
+from audiotools.ml import BaseModel
+from torch import nn
+
+from .base import CodecMixin
+from ..nn.layers import Snake1d
+from ..nn.layers import WNConv1d
+from ..nn.layers import WNConvTranspose1d
+from ..nn.quantize import ResidualVectorQuantize
+from ..nn.vae_utils import DiagonalGaussianDistribution
+
+
+def init_weights(m):
+ if isinstance(m, nn.Conv1d):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+
+
+class ResidualUnit(nn.Module):
+ def __init__(self, dim: int = 16, dilation: int = 1):
+ super().__init__()
+ pad = ((7 - 1) * dilation) // 2
+ self.block = nn.Sequential(
+ Snake1d(dim),
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
+ Snake1d(dim),
+ WNConv1d(dim, dim, kernel_size=1),
+ )
+
+ def forward(self, x):
+ y = self.block(x)
+ pad = (x.shape[-1] - y.shape[-1]) // 2
+ if pad > 0:
+ x = x[..., pad:-pad]
+ return x + y
+
+
+class EncoderBlock(nn.Module):
+ def __init__(self, dim: int = 16, stride: int = 1):
+ super().__init__()
+ self.block = nn.Sequential(
+ ResidualUnit(dim // 2, dilation=1),
+ ResidualUnit(dim // 2, dilation=3),
+ ResidualUnit(dim // 2, dilation=9),
+ Snake1d(dim // 2),
+ WNConv1d(
+ dim // 2,
+ dim,
+ kernel_size=2 * stride,
+ stride=stride,
+ padding=math.ceil(stride / 2),
+ ),
+ )
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ d_model: int = 64,
+ strides: list = [2, 4, 8, 8],
+ d_latent: int = 64,
+ ):
+ super().__init__()
+ # Create first convolution
+ self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
+
+ # Create EncoderBlocks that double channels as they downsample by `stride`
+ for stride in strides:
+ d_model *= 2
+ self.block += [EncoderBlock(d_model, stride=stride)]
+
+ # Create last convolution
+ self.block += [
+ Snake1d(d_model),
+ WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
+ ]
+
+        # Wrap block into nn.Sequential
+ self.block = nn.Sequential(*self.block)
+ self.enc_dim = d_model
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class DecoderBlock(nn.Module):
+ def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
+ super().__init__()
+ self.block = nn.Sequential(
+ Snake1d(input_dim),
+ WNConvTranspose1d(
+ input_dim,
+ output_dim,
+ kernel_size=2 * stride,
+ stride=stride,
+ padding=math.ceil(stride / 2),
+ output_padding=stride % 2,
+ ),
+ ResidualUnit(output_dim, dilation=1),
+ ResidualUnit(output_dim, dilation=3),
+ ResidualUnit(output_dim, dilation=9),
+ )
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ input_channel,
+ channels,
+ rates,
+ d_out: int = 1,
+ ):
+ super().__init__()
+
+ # Add first conv layer
+ layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
+
+ # Add upsampling + MRF blocks
+ for i, stride in enumerate(rates):
+ input_dim = channels // 2**i
+ output_dim = channels // 2 ** (i + 1)
+ layers += [DecoderBlock(input_dim, output_dim, stride)]
+
+ # Add final conv layer
+ layers += [
+ Snake1d(output_dim),
+ WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
+ nn.Tanh(),
+ ]
+
+ self.model = nn.Sequential(*layers)
+
+ def forward(self, x):
+ return self.model(x)
+
+
+class DAC(BaseModel, CodecMixin):
+ def __init__(
+ self,
+ encoder_dim: int = 64,
+ encoder_rates: List[int] = [2, 4, 8, 8],
+ latent_dim: int = None,
+ decoder_dim: int = 1536,
+ decoder_rates: List[int] = [8, 8, 4, 2],
+ n_codebooks: int = 9,
+ codebook_size: int = 1024,
+ codebook_dim: Union[int, list] = 8,
+ quantizer_dropout: bool = False,
+ sample_rate: int = 44100,
+ continuous: bool = False,
+ ):
+ super().__init__()
+
+ self.encoder_dim = encoder_dim
+ self.encoder_rates = encoder_rates
+ self.decoder_dim = decoder_dim
+ self.decoder_rates = decoder_rates
+ self.sample_rate = sample_rate
+ self.continuous = continuous
+
+ if latent_dim is None:
+ latent_dim = encoder_dim * (2 ** len(encoder_rates))
+
+ self.latent_dim = latent_dim
+
+ self.hop_length = np.prod(encoder_rates)
+ self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
+
+ if not continuous:
+ self.n_codebooks = n_codebooks
+ self.codebook_size = codebook_size
+ self.codebook_dim = codebook_dim
+ self.quantizer = ResidualVectorQuantize(
+ input_dim=latent_dim,
+ n_codebooks=n_codebooks,
+ codebook_size=codebook_size,
+ codebook_dim=codebook_dim,
+ quantizer_dropout=quantizer_dropout,
+ )
+ else:
+ self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1)
+ self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1)
+
+ self.decoder = Decoder(
+ latent_dim,
+ decoder_dim,
+ decoder_rates,
+ )
+ self.sample_rate = sample_rate
+ self.apply(init_weights)
+
+ self.delay = self.get_delay()
+
+ @property
+ def dtype(self):
+ """Get the dtype of the model parameters."""
+ # Return the dtype of the first parameter found
+ for param in self.parameters():
+ return param.dtype
+ return torch.float32 # fallback
+
+ @property
+ def device(self):
+ """Get the device of the model parameters."""
+ # Return the device of the first parameter found
+ for param in self.parameters():
+ return param.device
+ return torch.device('cpu') # fallback
+
+ def preprocess(self, audio_data, sample_rate):
+ if sample_rate is None:
+ sample_rate = self.sample_rate
+ assert sample_rate == self.sample_rate
+
+ length = audio_data.shape[-1]
+ right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
+ audio_data = nn.functional.pad(audio_data, (0, right_pad))
+
+ return audio_data
+
+ def encode(
+ self,
+ audio_data: torch.Tensor,
+ n_quantizers: int = None,
+ ):
+ """Encode given audio data and return quantized latent codes
+
+ Parameters
+ ----------
+ audio_data : Tensor[B x 1 x T]
+ Audio data to encode
+ n_quantizers : int, optional
+ Number of quantizers to use, by default None
+ If None, all quantizers are used.
+
+ Returns
+ -------
+        tuple
+            A tuple of (z, codes, latents, commitment_loss, codebook_loss):
+
+            "z" : Tensor[B x D x T]
+                Quantized continuous representation of input
+            "codes" : Tensor[B x N x T]
+                Codebook indices for each codebook
+                (quantized discrete representation of input)
+            "latents" : Tensor[B x N*D x T]
+                Projected latents (continuous representation of input
+                before quantization)
+            "vq/commitment_loss" : Tensor[1]
+                Commitment loss to train encoder to predict vectors closer
+                to codebook entries
+            "vq/codebook_loss" : Tensor[1]
+                Codebook loss to update the codebook
+
+            In continuous (VAE) mode, z is a DiagonalGaussianDistribution and
+            the remaining entries are None / 0.
+ """
+ z = self.encoder(audio_data) # [B x D x T]
+ if not self.continuous:
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
+ else:
+ z = self.quant_conv(z) # [B x 2D x T]
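+            # The doubled channels hold the mean and log-variance of a
+            # diagonal Gaussian posterior (VAE-style), which is sampled
+            # downstream instead of vector-quantized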
+ z = DiagonalGaussianDistribution(z)
+ codes, latents, commitment_loss, codebook_loss = None, None, 0, 0
+
+ return z, codes, latents, commitment_loss, codebook_loss
+
+ def decode(self, z: torch.Tensor):
+ """Decode given latent codes and return audio data
+
+ Parameters
+ ----------
+ z : Tensor[B x D x T]
+ Quantized continuous representation of input
+
+        Returns
+        -------
+        Tensor[B x 1 x length]
+            Decoded audio data.
+ """
+ if not self.continuous:
+ audio = self.decoder(z)
+ else:
+ z = self.post_quant_conv(z)
+ audio = self.decoder(z)
+
+ return audio
+
+ def forward(
+ self,
+ audio_data: torch.Tensor,
+ sample_rate: int = None,
+ n_quantizers: int = None,
+ ):
+ """Model forward pass
+
+ Parameters
+ ----------
+ audio_data : Tensor[B x 1 x T]
+ Audio data to encode
+ sample_rate : int, optional
+ Sample rate of audio data in Hz, by default None
+ If None, defaults to `self.sample_rate`
+ n_quantizers : int, optional
+ Number of quantizers to use, by default None.
+ If None, all quantizers are used.
+
+ Returns
+ -------
+ dict
+ A dictionary with the following keys:
+ "z" : Tensor[B x D x T]
+ Quantized continuous representation of input
+ "codes" : Tensor[B x N x T]
+ Codebook indices for each codebook
+ (quantized discrete representation of input)
+ "latents" : Tensor[B x N*D x T]
+ Projected latents (continuous representation of input before quantization)
+ "vq/commitment_loss" : Tensor[1]
+ Commitment loss to train encoder to predict vectors closer to codebook
+ entries
+ "vq/codebook_loss" : Tensor[1]
+ Codebook loss to update the codebook
+ "length" : int
+ Number of samples in input audio
+ "audio" : Tensor[B x 1 x length]
+ Decoded audio data.
+ """
+ length = audio_data.shape[-1]
+ audio_data = self.preprocess(audio_data, sample_rate)
+ if not self.continuous:
+ z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)
+
+ x = self.decode(z)
+ return {
+ "audio": x[..., :length],
+ "z": z,
+ "codes": codes,
+ "latents": latents,
+ "vq/commitment_loss": commitment_loss,
+ "vq/codebook_loss": codebook_loss,
+ }
+ else:
+ posterior, _, _, _, _ = self.encode(audio_data, n_quantizers)
+ z = posterior.sample()
+ x = self.decode(z)
+
+ kl_loss = posterior.kl()
+ kl_loss = kl_loss.mean()
+
+ return {
+ "audio": x[..., :length],
+ "z": z,
+ "kl_loss": kl_loss,
+ }
+
+
+if __name__ == "__main__":
+ import numpy as np
+ from functools import partial
+
+ model = DAC().to("cpu")
+
+ for n, m in model.named_modules():
+ o = m.extra_repr()
+ p = sum([np.prod(p.size()) for p in m.parameters()])
+ fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
+ setattr(m, "extra_repr", partial(fn, o=o, p=p))
+ print(model)
+ print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
+
+ length = 88200 * 2
+ x = torch.randn(1, 1, length).to(model.device)
+ x.requires_grad_(True)
+ x.retain_grad()
+
+ # Make a forward pass
+ out = model(x)["audio"]
+ print("Input shape:", x.shape)
+ print("Output shape:", out.shape)
+
+ # Create gradient variable
+ grad = torch.zeros_like(out)
+ grad[:, :, grad.shape[-1] // 2] = 1
+
+ # Make a backward pass
+ out.backward(grad)
+
+ # Check non-zero values
+ gradmap = x.grad.squeeze(0)
+ gradmap = (gradmap != 0).sum(0) # sum across features
+ rf = (gradmap != 0).sum()
+
+ print(f"Receptive field: {rf.item()}")
+
+ x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
+ model.decompress(model.compress(x, verbose=True), verbose=True)
diff --git a/hunyuanvideo_foley/models/dac_vae/model/discriminator.py b/hunyuanvideo_foley/models/dac_vae/model/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..09c79d1342ca46bef21daca64667577f05e61638
--- /dev/null
+++ b/hunyuanvideo_foley/models/dac_vae/model/discriminator.py
@@ -0,0 +1,228 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from audiotools import AudioSignal
+from audiotools import ml
+from audiotools import STFTParams
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+
+def WNConv1d(*args, **kwargs):
+ act = kwargs.pop("act", True)
+ conv = weight_norm(nn.Conv1d(*args, **kwargs))
+ if not act:
+ return conv
+ return nn.Sequential(conv, nn.LeakyReLU(0.1))
+
+
+def WNConv2d(*args, **kwargs):
+ act = kwargs.pop("act", True)
+ conv = weight_norm(nn.Conv2d(*args, **kwargs))
+ if not act:
+ return conv
+ return nn.Sequential(conv, nn.LeakyReLU(0.1))
+
+
+class MPD(nn.Module):
+ def __init__(self, period):
+ super().__init__()
+ self.period = period
+ self.convs = nn.ModuleList(
+ [
+ WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
+ WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
+ WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
+ WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
+ WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
+ ]
+ )
+ self.conv_post = WNConv2d(
+ 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
+ )
+
+ def pad_to_period(self, x):
+ t = x.shape[-1]
+ x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
+ return x
+
+ def forward(self, x):
+ fmap = []
+
+ x = self.pad_to_period(x)
+ x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
+
+ for layer in self.convs:
+ x = layer(x)
+ fmap.append(x)
+
+ x = self.conv_post(x)
+ fmap.append(x)
+
+ return fmap
+
+
+class MSD(nn.Module):
+ def __init__(self, rate: int = 1, sample_rate: int = 44100):
+ super().__init__()
+ self.convs = nn.ModuleList(
+ [
+ WNConv1d(1, 16, 15, 1, padding=7),
+ WNConv1d(16, 64, 41, 4, groups=4, padding=20),
+ WNConv1d(64, 256, 41, 4, groups=16, padding=20),
+ WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
+ WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
+ WNConv1d(1024, 1024, 5, 1, padding=2),
+ ]
+ )
+ self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
+ self.sample_rate = sample_rate
+ self.rate = rate
+
+ def forward(self, x):
+ x = AudioSignal(x, self.sample_rate)
+ x.resample(self.sample_rate // self.rate)
+ x = x.audio_data
+
+ fmap = []
+
+ for l in self.convs:
+ x = l(x)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+
+ return fmap
+
+
+BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
+
+
+class MRD(nn.Module):
+ def __init__(
+ self,
+ window_length: int,
+ hop_factor: float = 0.25,
+ sample_rate: int = 44100,
+ bands: list = BANDS,
+ ):
+ """Complex multi-band spectrogram discriminator.
+ Parameters
+ ----------
+ window_length : int
+ Window length of STFT.
+ hop_factor : float, optional
+ Hop factor of the STFT, defaults to ``0.25 * window_length``.
+ sample_rate : int, optional
+ Sampling rate of audio in Hz, by default 44100
+ bands : list, optional
+ Bands to run discriminator over.
+ """
+ super().__init__()
+
+ self.window_length = window_length
+ self.hop_factor = hop_factor
+ self.sample_rate = sample_rate
+ self.stft_params = STFTParams(
+ window_length=window_length,
+ hop_length=int(window_length * hop_factor),
+ match_stride=True,
+ )
+
+ n_fft = window_length // 2 + 1
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
+ self.bands = bands
+
+ ch = 32
+ convs = lambda: nn.ModuleList(
+ [
+ WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+ WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
+ ]
+ )
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
+ self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
+
+ def spectrogram(self, x):
+ x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
+ x = torch.view_as_real(x.stft())
+ x = rearrange(x, "b 1 f t c -> (b 1) c t f")
+ # Split into bands
+ x_bands = [x[..., b[0] : b[1]] for b in self.bands]
+ return x_bands
+
+ def forward(self, x):
+ x_bands = self.spectrogram(x)
+ fmap = []
+
+ x = []
+ for band, stack in zip(x_bands, self.band_convs):
+ for layer in stack:
+ band = layer(band)
+ fmap.append(band)
+ x.append(band)
+
+ x = torch.cat(x, dim=-1)
+ x = self.conv_post(x)
+ fmap.append(x)
+
+ return fmap
+
+
+class Discriminator(ml.BaseModel):
+ def __init__(
+ self,
+ rates: list = [],
+ periods: list = [2, 3, 5, 7, 11],
+ fft_sizes: list = [2048, 1024, 512],
+ sample_rate: int = 44100,
+ bands: list = BANDS,
+ ):
+ """Discriminator that combines multiple discriminators.
+
+ Parameters
+ ----------
+ rates : list, optional
+ sampling rates (in Hz) to run MSD at, by default []
+ If empty, MSD is not used.
+ periods : list, optional
+ periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
+ fft_sizes : list, optional
+ Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
+ sample_rate : int, optional
+ Sampling rate of audio in Hz, by default 44100
+ bands : list, optional
+ Bands to run MRD at, by default `BANDS`
+ """
+ super().__init__()
+ discs = []
+ discs += [MPD(p) for p in periods]
+ discs += [MSD(r, sample_rate=sample_rate) for r in rates]
+ discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
+ self.discriminators = nn.ModuleList(discs)
+
+ def preprocess(self, y):
+ # Remove DC offset
+ y = y - y.mean(dim=-1, keepdims=True)
+ # Peak normalize the volume of input audio
+ y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+ return y
+
+ def forward(self, x):
+ x = self.preprocess(x)
+ fmaps = [d(x) for d in self.discriminators]
+ return fmaps
+
+
+if __name__ == "__main__":
+ disc = Discriminator()
+ x = torch.zeros(1, 1, 44100)
+ results = disc(x)
+ for i, result in enumerate(results):
+ print(f"disc{i}")
+ for i, r in enumerate(result):
+ print(r.shape, r.mean(), r.min(), r.max())
+ print()
diff --git a/hunyuanvideo_foley/models/dac_vae/nn/__init__.py b/hunyuanvideo_foley/models/dac_vae/nn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6718c8b1a3d36c31655b030f4c515a144cde4db7
--- /dev/null
+++ b/hunyuanvideo_foley/models/dac_vae/nn/__init__.py
@@ -0,0 +1,3 @@
+from . import layers
+from . import loss
+from . import quantize
diff --git a/hunyuanvideo_foley/models/dac_vae/nn/layers.py b/hunyuanvideo_foley/models/dac_vae/nn/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..44fbc2929715e11d843b24195d7042a528969a94
--- /dev/null
+++ b/hunyuanvideo_foley/models/dac_vae/nn/layers.py
@@ -0,0 +1,33 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+
+def WNConv1d(*args, **kwargs):
+ return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+# Scripting this brings model speed up 1.4x
+@torch.jit.script
+def snake(x, alpha):
+ shape = x.shape
+ x = x.reshape(shape[0], shape[1], -1)
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
+ x = x.reshape(shape)
+ return x
+
+
+class Snake1d(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
+
+ def forward(self, x):
+ return snake(x, self.alpha)
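+
+
+# Minimal usage sketch (illustrative, not part of the upstream DAC file):
+# Snake1d applies x + (1/alpha) * sin^2(alpha * x) per channel.
+if __name__ == "__main__":
+ act = Snake1d(channels=64)
+ dummy = torch.randn(2, 64, 16000) # (batch, channels, time)
+ out = act(dummy)
+ assert out.shape == dummy.shape
+ print(out.shape)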
diff --git a/hunyuanvideo_foley/models/dac_vae/nn/loss.py b/hunyuanvideo_foley/models/dac_vae/nn/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bb3dd6ce08a7a24f18f941eeb5b68fe9461e86b
--- /dev/null
+++ b/hunyuanvideo_foley/models/dac_vae/nn/loss.py
@@ -0,0 +1,368 @@
+import typing
+from typing import List
+
+import torch
+import torch.nn.functional as F
+from audiotools import AudioSignal
+from audiotools import STFTParams
+from torch import nn
+
+
+class L1Loss(nn.L1Loss):
+ """L1 Loss between AudioSignals. Defaults
+ to comparing ``audio_data``, but any
+ attribute of an AudioSignal can be used.
+
+ Parameters
+ ----------
+ attribute : str, optional
+ Attribute of signal to compare, defaults to ``audio_data``.
+ weight : float, optional
+ Weight of this loss, defaults to 1.0.
+
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
+ """
+
+ def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
+ self.attribute = attribute
+ self.weight = weight
+ super().__init__(**kwargs)
+
+ def forward(self, x: AudioSignal, y: AudioSignal):
+ """
+ Parameters
+ ----------
+ x : AudioSignal
+ Estimate AudioSignal
+ y : AudioSignal
+ Reference AudioSignal
+
+ Returns
+ -------
+ torch.Tensor
+ L1 loss between AudioSignal attributes.
+ """
+ if isinstance(x, AudioSignal):
+ x = getattr(x, self.attribute)
+ y = getattr(y, self.attribute)
+ return super().forward(x, y)
+
+
+class SISDRLoss(nn.Module):
+ """
+ Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
+ of estimated and reference audio signals or aligned features.
+
+ Parameters
+ ----------
+ scaling : bool, optional
+ Whether to use scale-invariant (True) or
+ signal-to-noise ratio (False), by default True
+ reduction : str, optional
+ How to reduce across the batch (either 'mean',
+ 'sum', or 'none'), by default 'mean'
+ zero_mean : bool, optional
+ Zero-mean the references and estimates before
+ computing the loss, by default True
+ clip_min : float, optional
+ The minimum possible loss value. Helps the network
+ not to focus on making already-good examples better, by default None
+ weight : float, optional
+ Weight of this loss, defaults to 1.0.
+
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
+ """
+
+ def __init__(
+ self,
+ scaling: bool = True,
+ reduction: str = "mean",
+ zero_mean: bool = True,
+ clip_min: float = None,
+ weight: float = 1.0,
+ ):
+ self.scaling = scaling
+ self.reduction = reduction
+ self.zero_mean = zero_mean
+ self.clip_min = clip_min
+ self.weight = weight
+ super().__init__()
+
+ def forward(self, x: AudioSignal, y: AudioSignal):
+ eps = 1e-8
+ # nb, nc, nt
+ if isinstance(x, AudioSignal):
+ references = x.audio_data
+ estimates = y.audio_data
+ else:
+ references = x
+ estimates = y
+
+ nb = references.shape[0]
+ references = references.reshape(nb, 1, -1).permute(0, 2, 1)
+ estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
+
+ # samples now on axis 1
+ if self.zero_mean:
+ mean_reference = references.mean(dim=1, keepdim=True)
+ mean_estimate = estimates.mean(dim=1, keepdim=True)
+ else:
+ mean_reference = 0
+ mean_estimate = 0
+
+ _references = references - mean_reference
+ _estimates = estimates - mean_estimate
+
+ references_projection = (_references**2).sum(dim=-2) + eps
+ references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
+
+ scale = (
+ (references_on_estimates / references_projection).unsqueeze(1)
+ if self.scaling
+ else 1
+ )
+
+ e_true = scale * _references
+ e_res = _estimates - e_true
+
+ signal = (e_true**2).sum(dim=1)
+ noise = (e_res**2).sum(dim=1)
+ sdr = -10 * torch.log10(signal / noise + eps)
+
+ if self.clip_min is not None:
+ sdr = torch.clamp(sdr, min=self.clip_min)
+
+ if self.reduction == "mean":
+ sdr = sdr.mean()
+ elif self.reduction == "sum":
+ sdr = sdr.sum()
+ return sdr
+
+
+class MultiScaleSTFTLoss(nn.Module):
+ """Computes the multi-scale STFT loss from [1].
+
+ Parameters
+ ----------
+ window_lengths : List[int], optional
+ Length of each window of each STFT, by default [2048, 512]
+ loss_fn : typing.Callable, optional
+ Loss function used to compare the spectrograms, by default nn.L1Loss()
+ clamp_eps : float, optional
+ Lower clamp applied to the magnitude before taking the log, by default 1e-5
+ mag_weight : float, optional
+ Weight of raw magnitude portion of loss, by default 1.0
+ log_weight : float, optional
+ Weight of log magnitude portion of loss, by default 1.0
+ pow : float, optional
+ Power to raise magnitude to before taking log, by default 2.0
+ weight : float, optional
+ Weight of this loss, by default 1.0
+ match_stride : bool, optional
+ Whether to match the stride of convolutional layers, by default False
+
+ References
+ ----------
+
+ 1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
+ "DDSP: Differentiable Digital Signal Processing."
+ International Conference on Learning Representations. 2019.
+
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
+ """
+
+ def __init__(
+ self,
+ window_lengths: List[int] = [2048, 512],
+ loss_fn: typing.Callable = nn.L1Loss(),
+ clamp_eps: float = 1e-5,
+ mag_weight: float = 1.0,
+ log_weight: float = 1.0,
+ pow: float = 2.0,
+ weight: float = 1.0,
+ match_stride: bool = False,
+ window_type: str = None,
+ ):
+ super().__init__()
+ self.stft_params = [
+ STFTParams(
+ window_length=w,
+ hop_length=w // 4,
+ match_stride=match_stride,
+ window_type=window_type,
+ )
+ for w in window_lengths
+ ]
+ self.loss_fn = loss_fn
+ self.log_weight = log_weight
+ self.mag_weight = mag_weight
+ self.clamp_eps = clamp_eps
+ self.weight = weight
+ self.pow = pow
+
+ def forward(self, x: AudioSignal, y: AudioSignal):
+ """Computes multi-scale STFT between an estimate and a reference
+ signal.
+
+ Parameters
+ ----------
+ x : AudioSignal
+ Estimate signal
+ y : AudioSignal
+ Reference signal
+
+ Returns
+ -------
+ torch.Tensor
+ Multi-scale STFT loss.
+ """
+ loss = 0.0
+ for s in self.stft_params:
+ x.stft(s.window_length, s.hop_length, s.window_type)
+ y.stft(s.window_length, s.hop_length, s.window_type)
+ loss += self.log_weight * self.loss_fn(
+ x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
+ y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
+ )
+ loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
+ return loss
+
+
+class MelSpectrogramLoss(nn.Module):
+ """Compute distance between mel spectrograms. Can be used
+ in a multi-scale way.
+
+ Parameters
+ ----------
+ n_mels : List[int], optional
+ Number of mels per STFT, by default [150, 80]
+ window_lengths : List[int], optional
+ Length of each window of each STFT, by default [2048, 512]
+ loss_fn : typing.Callable, optional
+ Loss function used to compare the mel spectrograms, by default nn.L1Loss()
+ clamp_eps : float, optional
+ Lower clamp applied to the magnitude before taking the log, by default 1e-5
+ mag_weight : float, optional
+ Weight of raw magnitude portion of loss, by default 1.0
+ log_weight : float, optional
+ Weight of log magnitude portion of loss, by default 1.0
+ pow : float, optional
+ Power to raise magnitude to before taking log, by default 2.0
+ weight : float, optional
+ Weight of this loss, by default 1.0
+ match_stride : bool, optional
+ Whether to match the stride of convolutional layers, by default False
+
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
+ """
+
+ def __init__(
+ self,
+ n_mels: List[int] = [150, 80],
+ window_lengths: List[int] = [2048, 512],
+ loss_fn: typing.Callable = nn.L1Loss(),
+ clamp_eps: float = 1e-5,
+ mag_weight: float = 1.0,
+ log_weight: float = 1.0,
+ pow: float = 2.0,
+ weight: float = 1.0,
+ match_stride: bool = False,
+ mel_fmin: List[float] = [0.0, 0.0],
+ mel_fmax: List[float] = [None, None],
+ window_type: str = None,
+ ):
+ super().__init__()
+ self.stft_params = [
+ STFTParams(
+ window_length=w,
+ hop_length=w // 4,
+ match_stride=match_stride,
+ window_type=window_type,
+ )
+ for w in window_lengths
+ ]
+ self.n_mels = n_mels
+ self.loss_fn = loss_fn
+ self.clamp_eps = clamp_eps
+ self.log_weight = log_weight
+ self.mag_weight = mag_weight
+ self.weight = weight
+ self.mel_fmin = mel_fmin
+ self.mel_fmax = mel_fmax
+ self.pow = pow
+
+ def forward(self, x: AudioSignal, y: AudioSignal):
+ """Computes mel loss between an estimate and a reference
+ signal.
+
+ Parameters
+ ----------
+ x : AudioSignal
+ Estimate signal
+ y : AudioSignal
+ Reference signal
+
+ Returns
+ -------
+ torch.Tensor
+ Mel loss.
+ """
+ loss = 0.0
+ for n_mels, fmin, fmax, s in zip(
+ self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
+ ):
+ kwargs = {
+ "window_length": s.window_length,
+ "hop_length": s.hop_length,
+ "window_type": s.window_type,
+ }
+ x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+ y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+
+ loss += self.log_weight * self.loss_fn(
+ x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+ y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+ )
+ loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
+ return loss
+
+
+class GANLoss(nn.Module):
+ """
+ Computes a discriminator loss, given a discriminator on
+ generated waveforms/spectrograms compared to ground truth
+ waveforms/spectrograms. Computes the loss for both the
+ discriminator and the generator in separate functions.
+ """
+
+ def __init__(self, discriminator):
+ super().__init__()
+ self.discriminator = discriminator
+
+ def forward(self, fake, real):
+ d_fake = self.discriminator(fake.audio_data)
+ d_real = self.discriminator(real.audio_data)
+ return d_fake, d_real
+
+ def discriminator_loss(self, fake, real):
+ d_fake, d_real = self.forward(fake.clone().detach(), real)
+
+ loss_d = 0
+ for x_fake, x_real in zip(d_fake, d_real):
+ loss_d += torch.mean(x_fake[-1] ** 2)
+ loss_d += torch.mean((1 - x_real[-1]) ** 2)
+ return loss_d
+
+ def generator_loss(self, fake, real):
+ d_fake, d_real = self.forward(fake, real)
+
+ loss_g = 0
+ for x_fake in d_fake:
+ loss_g += torch.mean((1 - x_fake[-1]) ** 2)
+
+ loss_feature = 0
+
+ for i in range(len(d_fake)):
+ for j in range(len(d_fake[i]) - 1):
+ loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
+ return loss_g, loss_feature
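+
+# Minimal wiring sketch (illustrative; the import path below is hypothetical):
+# GANLoss pairs a Discriminator with least-squares GAN objectives plus an L1
+# feature-matching term. `fake` and `real` are assumed to be AudioSignals.
+#
+# from ..model.discriminator import Discriminator # hypothetical path
+# gan = GANLoss(Discriminator())
+# loss_d = gan.discriminator_loss(fake, real) # step the discriminator
+# loss_g, loss_feat = gan.generator_loss(fake, real) # step the generator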
diff --git a/hunyuanvideo_foley/models/dac_vae/nn/quantize.py b/hunyuanvideo_foley/models/dac_vae/nn/quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc462bc56ab5191b65d58ebd8bcaff4af7fb5927
--- /dev/null
+++ b/hunyuanvideo_foley/models/dac_vae/nn/quantize.py
@@ -0,0 +1,262 @@
+from typing import Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+from .layers import WNConv1d
+
+
+class VectorQuantize(nn.Module):
+ """
+ Implementation of VQ similar to Karpathy's repo:
+ https://github.com/karpathy/deep-vector-quantization
+ Additionally uses following tricks from Improved VQGAN
+ (https://arxiv.org/pdf/2110.04627.pdf):
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
+ for improved codebook usage
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
+ improves training stability
+ """
+
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
+ super().__init__()
+ self.codebook_size = codebook_size
+ self.codebook_dim = codebook_dim
+
+ self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
+ self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
+
+ def forward(self, z):
+ """Quantized the input tensor using a fixed codebook and returns
+ the corresponding codebook vectors
+
+ Parameters
+ ----------
+ z : Tensor[B x D x T]
+
+ Returns
+ -------
+ Tensor[B x D x T]
+ Quantized continuous representation of input
+ Tensor[1]
+ Commitment loss to train encoder to predict vectors closer to codebook
+ entries
+ Tensor[1]
+ Codebook loss to update the codebook
+ Tensor[B x T]
+ Codebook indices (quantized discrete representation of input)
+ Tensor[B x D x T]
+ Projected latents (continuous representation of input before quantization)
+ """
+
+ # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
+ z_e = self.in_proj(z) # z_e : (B x D x T)
+ z_q, indices = self.decode_latents(z_e)
+
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
+
+ z_q = (
+ z_e + (z_q - z_e).detach()
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
+
+ z_q = self.out_proj(z_q)
+
+ return z_q, commitment_loss, codebook_loss, indices, z_e
+
+ def embed_code(self, embed_id):
+ return F.embedding(embed_id, self.codebook.weight)
+
+ def decode_code(self, embed_id):
+ return self.embed_code(embed_id).transpose(1, 2)
+
+ def decode_latents(self, latents):
+ encodings = rearrange(latents, "b d t -> (b t) d")
+ codebook = self.codebook.weight # codebook: (N x D)
+
+ # L2 normalize encodings and codebook (ViT-VQGAN)
+ encodings = F.normalize(encodings)
+ codebook = F.normalize(codebook)
+
+ # Compute euclidean distance with codebook
+ dist = (
+ encodings.pow(2).sum(1, keepdim=True)
+ - 2 * encodings @ codebook.t()
+ + codebook.pow(2).sum(1, keepdim=True).t()
+ )
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
+ z_q = self.decode_code(indices)
+ return z_q, indices
+
+
+class ResidualVectorQuantize(nn.Module):
+ """
+ Introduced in SoundStream: An end2end neural audio codec
+ https://arxiv.org/abs/2107.03312
+ """
+
+ def __init__(
+ self,
+ input_dim: int = 512,
+ n_codebooks: int = 9,
+ codebook_size: int = 1024,
+ codebook_dim: Union[int, list] = 8,
+ quantizer_dropout: float = 0.0,
+ ):
+ super().__init__()
+ if isinstance(codebook_dim, int):
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
+
+ self.n_codebooks = n_codebooks
+ self.codebook_dim = codebook_dim
+ self.codebook_size = codebook_size
+
+ self.quantizers = nn.ModuleList(
+ [
+ VectorQuantize(input_dim, codebook_size, codebook_dim[i])
+ for i in range(n_codebooks)
+ ]
+ )
+ self.quantizer_dropout = quantizer_dropout
+
+ def forward(self, z, n_quantizers: int = None):
+ """Quantized the input tensor using a fixed set of `n` codebooks and returns
+ the corresponding codebook vectors
+ Parameters
+ ----------
+ z : Tensor[B x D x T]
+ n_quantizers : int, optional
+ No. of quantizers to use
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
+ when in training mode, and a random number of quantizers is used.
+ Returns
+ -------
+ tuple
+ A 5-tuple of:
+
+ z_q : Tensor[B x D x T]
+ Quantized continuous representation of input
+ codes : Tensor[B x N x T]
+ Codebook indices for each codebook
+ (quantized discrete representation of input)
+ latents : Tensor[B x N*D x T]
+ Projected latents (continuous representation of input before quantization)
+ commitment_loss : Tensor[1]
+ Commitment loss to train encoder to predict vectors closer to codebook
+ entries
+ codebook_loss : Tensor[1]
+ Codebook loss to update the codebook
+ z_q = 0
+ residual = z
+ commitment_loss = 0
+ codebook_loss = 0
+
+ codebook_indices = []
+ latents = []
+
+ if n_quantizers is None:
+ n_quantizers = self.n_codebooks
+ if self.training:
+ n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
+ dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
+ n_quantizers = n_quantizers.to(z.device)
+
+ for i, quantizer in enumerate(self.quantizers):
+ if self.training is False and i >= n_quantizers:
+ break
+
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
+ residual
+ )
+
+ # Create mask to apply quantizer dropout
+ mask = (
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
+ )
+ z_q = z_q + z_q_i * mask[:, None, None]
+ residual = residual - z_q_i
+
+ # Sum losses
+ commitment_loss += (commitment_loss_i * mask).mean()
+ codebook_loss += (codebook_loss_i * mask).mean()
+
+ codebook_indices.append(indices_i)
+ latents.append(z_e_i)
+
+ codes = torch.stack(codebook_indices, dim=1)
+ latents = torch.cat(latents, dim=1)
+
+ return z_q, codes, latents, commitment_loss, codebook_loss
+
+ def from_codes(self, codes: torch.Tensor):
+ """Given the quantized codes, reconstruct the continuous representation
+ Parameters
+ ----------
+ codes : Tensor[B x N x T]
+ Quantized discrete representation of input
+ Returns
+ -------
+ Tensor[B x D x T]
+ Quantized continuous representation of input
+ """
+ z_q = 0.0
+ z_p = []
+ n_codebooks = codes.shape[1]
+ for i in range(n_codebooks):
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
+ z_p.append(z_p_i)
+
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
+ z_q = z_q + z_q_i
+ return z_q, torch.cat(z_p, dim=1), codes
+
+ def from_latents(self, latents: torch.Tensor):
+ """Given the unquantized latents, reconstruct the
+ continuous representation after quantization.
+
+ Parameters
+ ----------
+ latents : Tensor[B x N x T]
+ Continuous representation of input after projection
+
+ Returns
+ -------
+ Tensor[B x D x T]
+ Quantized representation of full-projected space
+ Tensor[B x D x T]
+ Quantized representation of latent space
+ """
+ z_q = 0
+ z_p = []
+ codes = []
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
+
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
+ for i in range(n_codebooks):
+ j, k = dims[i], dims[i + 1]
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
+ z_p.append(z_p_i)
+ codes.append(codes_i)
+
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
+ z_q = z_q + z_q_i
+
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
+
+
+if __name__ == "__main__":
+ rvq = ResidualVectorQuantize(quantizer_dropout=True)
+ x = torch.randn(16, 512, 80)
+ # forward returns a tuple, not a dict
+ z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
+ print(latents.shape)
diff --git a/hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py b/hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a97597f5d5ae4aa19a194c24f3c17b2238224bcf
--- /dev/null
+++ b/hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py
@@ -0,0 +1,91 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * torch.mean(
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+ dim=[1, 2],
+ )
+ else:
+ return 0.5 * torch.mean(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var
+ - 1.0
+ - self.logvar
+ + other.logvar,
+ dim=[1, 2],
+ )
+
+ def nll(self, sample, dims=[1, 2]):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims,
+ )
+
+ def mode(self):
+ return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]
+
+ return 0.5 * (
+ -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ )
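+
+
+# Minimal usage sketch (illustrative): `parameters` packs mean and logvar along
+# the channel dimension, as produced by a VAE encoder head.
+if __name__ == "__main__":
+ params = torch.randn(2, 2 * 64, 80) # (batch, 2 * latent_dim, time)
+ dist = DiagonalGaussianDistribution(params)
+ z = dist.sample() # reparameterized sample, shape (2, 64, 80)
+ kl = dist.kl() # KL to a standard normal, shape (2,)
+ print(z.shape, kl.shape)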
diff --git a/hunyuanvideo_foley/models/dac_vae/utils/__init__.py b/hunyuanvideo_foley/models/dac_vae/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dce1ed49f1b4e4fe1cb42b054298911207e0e41
--- /dev/null
+++ b/hunyuanvideo_foley/models/dac_vae/utils/__init__.py
@@ -0,0 +1,121 @@
+from pathlib import Path
+
+import argbind
+from audiotools import ml
+
+from ..model import DAC
+Accelerator = ml.Accelerator
+
+__MODEL_LATEST_TAGS__ = {
+ ("44khz", "8kbps"): "0.0.1",
+ ("24khz", "8kbps"): "0.0.4",
+ ("16khz", "8kbps"): "0.0.5",
+ ("44khz", "16kbps"): "1.0.0",
+}
+
+__MODEL_URLS__ = {
+ (
+ "44khz",
+ "0.0.1",
+ "8kbps",
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
+ (
+ "24khz",
+ "0.0.4",
+ "8kbps",
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
+ (
+ "16khz",
+ "0.0.5",
+ "8kbps",
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth",
+ (
+ "44khz",
+ "1.0.0",
+ "16kbps",
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth",
+}
+
+
+@argbind.bind(group="download", positional=True, without_prefix=True)
+def download(
+ model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest"
+):
+ """
+ Function that downloads the weights file from URL if a local cache is not found.
+
+ Parameters
+ ----------
+ model_type : str
+ The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz".
+ model_bitrate: str
+ Bitrate of the model. Must be one of "8kbps" or "16kbps". Defaults to "8kbps".
+ Only 44khz model supports 16kbps.
+ tag : str
+ The tag of the model to download. Defaults to "latest".
+
+ Returns
+ -------
+ Path
+ Directory path required to load model via audiotools.
+ """
+ model_type = model_type.lower()
+ tag = tag.lower()
+
+ assert model_type in [
+ "44khz",
+ "24khz",
+ "16khz",
+ ], "model_type must be one of '44khz', '24khz', or '16khz'"
+
+ assert model_bitrate in [
+ "8kbps",
+ "16kbps",
+ ], "model_bitrate must be one of '8kbps', or '16kbps'"
+
+ if tag == "latest":
+ tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)]
+
+ download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None)
+
+ if download_link is None:
+ raise ValueError(
+ f"Could not find model with tag {tag} and model type {model_type}"
+ )
+
+ local_path = (
+ Path.home()
+ / ".cache"
+ / "descript"
+ / "dac"
+ / f"weights_{model_type}_{model_bitrate}_{tag}.pth"
+ )
+ if not local_path.exists():
+ local_path.parent.mkdir(parents=True, exist_ok=True)
+
+ # Download the model
+ import requests
+
+ response = requests.get(download_link)
+
+ if response.status_code != 200:
+ raise ValueError(
+ f"Could not download model. Received response code {response.status_code}"
+ )
+ local_path.write_bytes(response.content)
+
+ return local_path
+
+
+def load_model(
+ model_type: str = "44khz",
+ model_bitrate: str = "8kbps",
+ tag: str = "latest",
+ load_path: str = None,
+):
+ if not load_path:
+ load_path = download(
+ model_type=model_type, model_bitrate=model_bitrate, tag=tag
+ )
+ generator = DAC.load(load_path)
+ return generator
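+
+
+# Usage sketch (illustrative; downloads weights on first call):
+# generator = load_model(model_type="44khz", model_bitrate="8kbps", tag="latest")
+# generator = load_model(load_path="/path/to/weights.pth") # or load local weights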
diff --git a/hunyuanvideo_foley/models/dac_vae/utils/decode.py b/hunyuanvideo_foley/models/dac_vae/utils/decode.py
new file mode 100644
index 0000000000000000000000000000000000000000..00261a561251b1bef6f11e6594bf80de10b93ff2
--- /dev/null
+++ b/hunyuanvideo_foley/models/dac_vae/utils/decode.py
@@ -0,0 +1,95 @@
+import warnings
+from pathlib import Path
+
+import argbind
+import numpy as np
+import torch
+from audiotools import AudioSignal
+from tqdm import tqdm
+
+from ..model import DACFile
+from . import load_model
+
+warnings.filterwarnings("ignore", category=UserWarning)
+
+
+@argbind.bind(group="decode", positional=True, without_prefix=True)
+@torch.inference_mode()
+@torch.no_grad()
+def decode(
+ input: str,
+ output: str = "",
+ weights_path: str = "",
+ model_tag: str = "latest",
+ model_bitrate: str = "8kbps",
+ device: str = "cuda",
+ model_type: str = "44khz",
+ verbose: bool = False,
+):
+ """Decode audio from codes.
+
+ Parameters
+ ----------
+ input : str
+ Path to input directory or file
+ output : str, optional
+ Path to output directory, by default "".
+ If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
+ weights_path : str, optional
+ Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
+ model_tag and model_type.
+ model_tag : str, optional
+ Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
+ model_bitrate: str
+ Bitrate of the model. Must be one of "8kbps" or "16kbps". Defaults to "8kbps".
+ device : str, optional
+ Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU.
+ model_type : str, optional
+ The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
+ """
+ generator = load_model(
+ model_type=model_type,
+ model_bitrate=model_bitrate,
+ tag=model_tag,
+ load_path=weights_path,
+ )
+ generator.to(device)
+ generator.eval()
+
+ # Find all .dac files in input directory
+ _input = Path(input)
+ input_files = list(_input.glob("**/*.dac"))
+
+ # If input is a .dac file, add it to the list
+ if _input.suffix == ".dac":
+ input_files.append(_input)
+
+ # Create output directory
+ output = Path(output)
+ output.mkdir(parents=True, exist_ok=True)
+
+ for i in tqdm(range(len(input_files)), desc="Decoding files"):
+ # Load file
+ artifact = DACFile.load(input_files[i])
+
+ # Reconstruct audio from codes
+ recons = generator.decompress(artifact, verbose=verbose)
+
+ # Compute output path
+ relative_path = input_files[i].relative_to(input)
+ output_dir = output / relative_path.parent
+ if not relative_path.name:
+ output_dir = output
+ relative_path = input_files[i]
+ output_name = relative_path.with_suffix(".wav").name
+ output_path = output_dir / output_name
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+
+ # Write to file
+ recons.write(output_path)
+
+
+if __name__ == "__main__":
+ args = argbind.parse_args()
+ with argbind.scope(args):
+ decode()
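+
+# Example invocation (a sketch; exact flags follow argbind conventions):
+# python -m hunyuanvideo_foley.models.dac_vae.utils.decode /path/to/codes --output /path/to/wavs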
diff --git a/hunyuanvideo_foley/models/dac_vae/utils/encode.py b/hunyuanvideo_foley/models/dac_vae/utils/encode.py
new file mode 100644
index 0000000000000000000000000000000000000000..c86946c3c6d6a7ff1d1ea883c600d9b93c41b7d9
--- /dev/null
+++ b/hunyuanvideo_foley/models/dac_vae/utils/encode.py
@@ -0,0 +1,94 @@
+import math
+import warnings
+from pathlib import Path
+
+import argbind
+import numpy as np
+import torch
+from audiotools import AudioSignal
+from audiotools.core import util
+from tqdm import tqdm
+
+from . import load_model
+
+warnings.filterwarnings("ignore", category=UserWarning)
+
+
+@argbind.bind(group="encode", positional=True, without_prefix=True)
+@torch.inference_mode()
+@torch.no_grad()
+def encode(
+ input: str,
+ output: str = "",
+ weights_path: str = "",
+ model_tag: str = "latest",
+ model_bitrate: str = "8kbps",
+ n_quantizers: int = None,
+ device: str = "cuda",
+ model_type: str = "44khz",
+ win_duration: float = 5.0,
+ verbose: bool = False,
+):
+ """Encode audio files in input path to .dac format.
+
+ Parameters
+ ----------
+ input : str
+ Path to input audio file or directory
+ output : str, optional
+ Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
+ weights_path : str, optional
+ Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
+ model_tag and model_type.
+ model_tag : str, optional
+ Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
+ model_bitrate: str
+ Bitrate of the model. Must be one of "8kbps" or "16kbps". Defaults to "8kbps".
+ n_quantizers : int, optional
+ Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate.
+ device : str, optional
+ Device to use, by default "cuda"
+ model_type : str, optional
+ The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
+ """
+ generator = load_model(
+ model_type=model_type,
+ model_bitrate=model_bitrate,
+ tag=model_tag,
+ load_path=weights_path,
+ )
+ generator.to(device)
+ generator.eval()
+ kwargs = {"n_quantizers": n_quantizers}
+
+ # Find all audio files in input path
+ input = Path(input)
+ audio_files = util.find_audio(input)
+
+ output = Path(output)
+ output.mkdir(parents=True, exist_ok=True)
+
+ for i in tqdm(range(len(audio_files)), desc="Encoding files"):
+ # Load file
+ signal = AudioSignal(audio_files[i])
+
+ # Encode audio to .dac format
+ artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs)
+
+ # Compute output path
+ relative_path = audio_files[i].relative_to(input)
+ output_dir = output / relative_path.parent
+ if not relative_path.name:
+ output_dir = output
+ relative_path = audio_files[i]
+ output_name = relative_path.with_suffix(".dac").name
+ output_path = output_dir / output_name
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+
+ artifact.save(output_path)
+
+
+if __name__ == "__main__":
+ args = argbind.parse_args()
+ with argbind.scope(args):
+ encode()
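+
+# Example invocation (a sketch; exact flags follow argbind conventions):
+# python -m hunyuanvideo_foley.models.dac_vae.utils.encode /path/to/audio --output /path/to/codes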
diff --git a/hunyuanvideo_foley/models/hifi_foley.py b/hunyuanvideo_foley/models/hifi_foley.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b5fffcca7a61c3baf686031e9ffdbc3696f3e22
--- /dev/null
+++ b/hunyuanvideo_foley/models/hifi_foley.py
@@ -0,0 +1,794 @@
+from typing import List, Tuple, Optional, Union, Dict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from einops.layers.torch import Rearrange
+from diffusers.models import ModelMixin
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+
+from .nn.activation_layers import SwiGLU, get_activation_layer
+from .nn.attn_layers import apply_rotary_emb, attention
+from .nn.embed_layers import TimestepEmbedder, ConditionProjection, PatchEmbed1D
+from .nn.mlp_layers import MLP, ConvMLP, FinalLayer1D, ChannelLastConv1d
+from .nn.modulate_layers import ModulateDiT, ckpt_wrapper, apply_gate, modulate
+from .nn.norm_layers import get_norm_layer
+from .nn.posemb_layers import get_nd_rotary_pos_embed
+
+def interleave_two_sequences(x1: torch.Tensor, x2: torch.Tensor):
+ # [B, N1, H, C] & [B, N2, H, C]
+ assert x1.ndim == x2.ndim == 4 # check before unpacking, which would otherwise fail first
+ B, N1, H, C = x1.shape
+ B, N2, H, C = x2.shape
+
+ if N1 != N2:
+ x2 = x2.view(B, N2, -1).transpose(1, 2)
+ x2 = F.interpolate(x2, size=(N1), mode="nearest-exact")
+ x2 = x2.transpose(1, 2).view(B, N1, H, C)
+ x = torch.stack((x1, x2), dim=2)
+ x = x.reshape(B, N1 * 2, H, C)
+ return x
+
+def decouple_interleaved_two_sequences(x: torch.Tensor, len1: int, len2: int):
+ B, N, H, C = x.shape
+ assert N % 2 == 0 and N // 2 == len1
+
+ x = x.reshape(B, -1, 2, H, C)
+ x1 = x[:, :, 0]
+ x2 = x[:, :, 1]
+ if x2.shape[1] != len2:
+ x2 = x2.view(B, len1, H * C).transpose(1, 2)
+ x2 = F.interpolate(x2, size=(len2), mode="nearest-exact")
+ x2 = x2.transpose(1, 2).view(B, len2, H, C)
+ return x1, x2
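+
+# Shape sketch (illustrative): interleaving resamples the second sequence to the
+# first one's length and zips them into [B, 2 * N1, H, C]; decoupling inverts it.
+# a = torch.randn(2, 128, 8, 64) # audio tokens
+# v = torch.randn(2, 32, 8, 64) # visual tokens
+# x = interleave_two_sequences(a, v) # -> (2, 256, 8, 64)
+# a2, v2 = decouple_interleaved_two_sequences(x, 128, 32) # original shapes back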
+
+class TwoStreamCABlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ mlp_ratio: float,
+ mlp_act_type: str = "gelu_tanh",
+ qk_norm: bool = True,
+ qk_norm_type: str = "rms",
+ qkv_bias: bool = False,
+ attn_mode: str = "torch",
+ reverse: bool = False,
+ interleaved_audio_visual_rope: bool = False,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+
+ self.deterministic = False
+ self.reverse = reverse
+ self.attn_mode = attn_mode
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ head_dim = hidden_size // num_heads
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+
+ self.interleaved_audio_visual_rope = interleaved_audio_visual_rope
+
+ # Self attention for audio + visual
+ self.audio_mod = ModulateDiT(hidden_size, factor=9, act_layer=get_activation_layer("silu"), **factory_kwargs)
+ self.audio_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+ self.audio_self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
+ qk_norm_layer = get_norm_layer(qk_norm_type)
+ self.audio_self_q_norm = (
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+ self.audio_self_k_norm = (
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+ self.audio_self_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
+
+ # visual cond
+ self.v_cond_mod = ModulateDiT(hidden_size, factor=9, act_layer=get_activation_layer("silu"), **factory_kwargs)
+ self.v_cond_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+ self.v_cond_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
+ self.v_cond_attn_q_norm = (
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+ self.v_cond_attn_k_norm = (
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+ self.v_cond_self_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
+
+ self.max_text_len = 100
+ self.rope_dim_list = None
+
+ # audio and video norm for cross attention with text
+ self.audio_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+ self.v_cond_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+
+ # Cross attention: (video_audio) as query, text as key/value
+ self.audio_cross_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
+ self.v_cond_cross_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
+ self.text_cross_kv = nn.Linear(hidden_size, hidden_size * 2, bias=qkv_bias, **factory_kwargs)
+
+ self.audio_cross_q_norm = (
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+ self.v_cond_cross_q_norm = (
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+ self.text_cross_k_norm = (
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+ self.audio_cross_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
+ self.v_cond_cross_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
+
+ # MLPs
+ self.audio_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+ self.audio_mlp = MLP(
+ hidden_size, mlp_hidden_dim, act_layer=get_activation_layer(mlp_act_type), bias=True, **factory_kwargs
+ )
+
+ self.v_cond_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+ self.v_cond_mlp = MLP(
+ hidden_size, mlp_hidden_dim, act_layer=get_activation_layer(mlp_act_type), bias=True, **factory_kwargs
+ )
+
+ def build_rope_for_text(self, text_len, head_dim, rope_dim_list=None):
+ target_ndim = 1 # n-d RoPE
+ rope_sizes = [text_len]
+
+ if rope_dim_list is None:
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
+ assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal the head_dim of the attention layer"
+
+ text_freqs_cos, text_freqs_sin = get_nd_rotary_pos_embed(
+ rope_dim_list=rope_dim_list,
+ start=rope_sizes,
+ theta=10000,
+ use_real=True,
+ theta_rescale_factor=1.0,
+ )
+ return text_freqs_cos, text_freqs_sin
+
+ def set_attn_mode(self, new_mode):
+ if new_mode != "torch":
+ raise NotImplementedError(f"Only support 'torch' mode, got {new_mode}.")
+ self.attn_mode = new_mode
+
+ def enable_deterministic(self):
+ self.deterministic = True
+
+ def disable_deterministic(self):
+ self.deterministic = False
+
+ def forward(
+ self,
+ audio: torch.Tensor,
+ cond: torch.Tensor,
+ v_cond: torch.Tensor,
+ attn_mask: torch.Tensor,
+ vec: torch.Tensor,
+ freqs_cis: tuple = None,
+ v_freqs_cis: tuple = None,
+ sync_vec: torch.Tensor = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ # Get modulation parameters
+ if sync_vec is not None:
+ assert sync_vec.ndim == 3
+ (audio_mod1_shift, audio_mod1_scale, audio_mod1_gate,
+ audio_mod2_shift, audio_mod2_scale, audio_mod2_gate,
+ audio_mod3_shift, audio_mod3_scale, audio_mod3_gate,
+ ) = self.audio_mod(sync_vec).chunk(9, dim=-1)
+ else:
+ (audio_mod1_shift, audio_mod1_scale, audio_mod1_gate,
+ audio_mod2_shift, audio_mod2_scale, audio_mod2_gate,
+ audio_mod3_shift, audio_mod3_scale, audio_mod3_gate,
+ ) = self.audio_mod(vec).chunk(9, dim=-1)
+
+ (
+ v_cond_mod1_shift,
+ v_cond_mod1_scale,
+ v_cond_mod1_gate,
+ v_cond_mod2_shift,
+ v_cond_mod2_scale,
+ v_cond_mod2_gate,
+ v_cond_mod3_shift,
+ v_cond_mod3_scale,
+ v_cond_mod3_gate,
+ ) = self.v_cond_mod(vec).chunk(9, dim=-1)
+
+ # 1. Self Attention for audio + visual
+ audio_modulated = self.audio_norm1(audio)
+ audio_modulated = modulate(audio_modulated, shift=audio_mod1_shift, scale=audio_mod1_scale)
+ audio_qkv = self.audio_self_attn_qkv(audio_modulated)
+ audio_q, audio_k, audio_v = rearrange(audio_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
+ audio_q = self.audio_self_q_norm(audio_q).to(audio_v)
+ audio_k = self.audio_self_k_norm(audio_k).to(audio_v)
+
+ # Prepare visual cond for attention
+ v_cond_modulated = self.v_cond_norm1(v_cond)
+ v_cond_modulated = modulate(v_cond_modulated, shift=v_cond_mod1_shift, scale=v_cond_mod1_scale)
+ v_cond_qkv = self.v_cond_attn_qkv(v_cond_modulated)
+ v_cond_q, v_cond_k, v_cond_v = rearrange(v_cond_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
+ v_cond_q = self.v_cond_attn_q_norm(v_cond_q).to(v_cond_v)
+ v_cond_k = self.v_cond_attn_k_norm(v_cond_k).to(v_cond_v)
+
+ # Apply RoPE if needed for audio and visual
+ if freqs_cis is not None:
+ if not self.interleaved_audio_visual_rope:
+ audio_qq, audio_kk = apply_rotary_emb(audio_q, audio_k, freqs_cis, head_first=False)
+ audio_q, audio_k = audio_qq, audio_kk
+ else:
+ ori_audio_len = audio_q.shape[1]
+ ori_v_con_len = v_cond_q.shape[1]
+ interleaved_audio_visual_q = interleave_two_sequences(audio_q, v_cond_q)
+ interleaved_audio_visual_k = interleave_two_sequences(audio_k, v_cond_k)
+ interleaved_audio_visual_qq, interleaved_audio_visual_kk = apply_rotary_emb(
+ interleaved_audio_visual_q, interleaved_audio_visual_k, freqs_cis, head_first=False
+ )
+ audio_qq, v_cond_qq = decouple_interleaved_two_sequences(
+ interleaved_audio_visual_qq, ori_audio_len, ori_v_con_len
+ )
+ audio_kk, v_cond_kk = decouple_interleaved_two_sequences(
+ interleaved_audio_visual_kk, ori_audio_len, ori_v_con_len
+ )
+ audio_q, audio_k = audio_qq, audio_kk
+ v_cond_q, v_cond_k = v_cond_qq, v_cond_kk
+
+ # Apply RoPE to visual if needed and not interleaved
+ if v_freqs_cis is not None and not self.interleaved_audio_visual_rope:
+ v_cond_qq, v_cond_kk = apply_rotary_emb(v_cond_q, v_cond_k, v_freqs_cis, head_first=False)
+ v_cond_q, v_cond_k = v_cond_qq, v_cond_kk
+
+ # Concatenate for self-attention
+ q = torch.cat((v_cond_q, audio_q), dim=1)
+ k = torch.cat((v_cond_k, audio_k), dim=1)
+ v = torch.cat((v_cond_v, audio_v), dim=1)
+
+ # Run self-attention
+ attn = attention(q, k, v, mode=self.attn_mode, attn_mask=attn_mask, deterministic=self.deterministic)
+ v_cond_attn, audio_attn = torch.split(attn, [v_cond.shape[1], audio.shape[1]], dim=1)
+
+ # Apply self-attention output to audio and v_cond
+ audio = audio + apply_gate(self.audio_self_proj(audio_attn), gate=audio_mod1_gate)
+ v_cond = v_cond + apply_gate(self.v_cond_self_proj(v_cond_attn), gate=v_cond_mod1_gate)
+
+ # 2. Cross Attention: (v_cond, audio) as query, text as key/value
+ # audio, v_cond modulation
+ audio_modulated = self.audio_norm2(audio)
+ audio_modulated = modulate(audio_modulated, shift=audio_mod2_shift, scale=audio_mod2_scale)
+ v_cond_modulated = self.v_cond_norm2(v_cond)
+ v_cond_modulated = modulate(v_cond_modulated, shift=v_cond_mod2_shift, scale=v_cond_mod2_scale)
+
+ # Prepare audio query
+ audio_q = self.audio_cross_q(audio_modulated)
+ audio_q = rearrange(audio_q, "B L (H D) -> B L H D", H=self.num_heads)
+ audio_q = self.audio_cross_q_norm(audio_q)
+
+ # Prepare v_cond query
+ v_cond_q = self.v_cond_cross_q(v_cond_modulated)
+ v_cond_q = rearrange(v_cond_q, "B L (H D) -> B L H D", H=self.num_heads)
+ v_cond_q = self.v_cond_cross_q_norm(v_cond_q)
+
+ # Prepare text key/value
+ text_kv = self.text_cross_kv(cond)
+ text_k, text_v = rearrange(text_kv, "B L (K H D) -> K B L H D", K=2, H=self.num_heads)
+ text_k = self.text_cross_k_norm(text_k).to(text_v)
+
+ # Apply RoPE to (v_cond, audio) query and text key if needed
+ head_dim = self.hidden_size // self.num_heads
+ audio_cross_freqs_cos, audio_cross_freqs_sin = self.build_rope_for_text(audio_q.shape[1], head_dim, rope_dim_list=self.rope_dim_list)
+ audio_cross_freqs_cis = (audio_cross_freqs_cos.to(audio_q.device), audio_cross_freqs_sin.to(audio_q.device))
+ audio_q = apply_rotary_emb(audio_q, audio_q, audio_cross_freqs_cis, head_first=False)[0]
+
+ v_cond_cross_freqs_cos, v_cond_cross_freqs_sin = self.build_rope_for_text(v_cond_q.shape[1], head_dim, rope_dim_list=self.rope_dim_list)
+ v_cond_cross_freqs_cis = (v_cond_cross_freqs_cos.to(v_cond_q.device), v_cond_cross_freqs_sin.to(v_cond_q.device))
+ v_cond_q = apply_rotary_emb(v_cond_q, v_cond_q, v_cond_cross_freqs_cis, head_first=False)[0]
+
+ text_len = text_k.shape[1]
+
+ text_freqs_cos, text_freqs_sin = self.build_rope_for_text(text_len, head_dim,
+ rope_dim_list=self.rope_dim_list)
+ text_freqs_cis = (text_freqs_cos.to(text_k.device), text_freqs_sin.to(text_k.device))
+ text_k = apply_rotary_emb(text_k, text_k, text_freqs_cis, head_first=False)[1]
+
+ # Concat v_cond and audio for cross-attention
+ v_cond_audio_q = torch.cat([v_cond_q, audio_q], dim=1)
+
+ # Run cross-attention
+ cross_attn = attention(v_cond_audio_q, text_k, text_v, mode=self.attn_mode, deterministic=self.deterministic)
+ v_cond_cross_attn, audio_cross_attn = torch.split(cross_attn, [v_cond.shape[1], audio.shape[1]], dim=1)
+
+ # Apply cross-attention output
+ audio = audio + apply_gate(self.audio_cross_proj(audio_cross_attn), gate=audio_mod2_gate)
+ v_cond = v_cond + apply_gate(self.v_cond_cross_proj(v_cond_cross_attn), gate=v_cond_mod2_gate)
+
+ # 3. Apply MLPs
+ audio = audio + apply_gate(
+ self.audio_mlp(modulate(self.audio_norm3(audio), shift=audio_mod3_shift, scale=audio_mod3_scale)),
+ gate=audio_mod3_gate,
+ )
+
+ # Apply visual MLP
+ v_cond = v_cond + apply_gate(
+ self.v_cond_mlp(modulate(self.v_cond_norm3(v_cond), shift=v_cond_mod3_shift, scale=v_cond_mod3_scale)),
+ gate=v_cond_mod3_gate,
+ )
+
+ return audio, cond, v_cond
+
+class SingleStreamBlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ mlp_ratio: float,
+ qk_norm_type: str = "rms",
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+
+ self.hidden_size = hidden_size
+ self.num_heads = num_heads
+
+ self.modulation = ModulateDiT(
+ hidden_size=hidden_size,
+ factor=6,
+ act_layer=get_activation_layer("silu"),
+ **factory_kwargs,
+ )
+ self.linear_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True)
+ self.linear1 = ChannelLastConv1d(hidden_size, hidden_size, kernel_size=3, padding=1, **factory_kwargs)
+ self.linear2 = ConvMLP(hidden_size, int(hidden_size * mlp_ratio), kernel_size=3, padding=1, **factory_kwargs)
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False)
+ self.q_norm = nn.RMSNorm(hidden_size // num_heads)
+ self.k_norm = nn.RMSNorm(hidden_size // num_heads)
+ self.rearrange = Rearrange("B L (H D K) -> B H L D K", K=3, H=num_heads)
+
+ def forward(self, x: torch.Tensor, cond: torch.Tensor, freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None):
+ assert cond.ndim == 3, "Condition should have shape [B, T, D]"
+ modulation = self.modulation(cond)
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=-1)
+ x_norm1 = self.norm1(x) * (1 + scale_msa) + shift_msa
+
+ qkv = self.linear_qkv(x_norm1)
+ q, k, v = self.rearrange(qkv).chunk(3, dim=-1)
+ q = q.squeeze(-1)
+ k = k.squeeze(-1)
+ v = v.squeeze(-1)
+
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+ q, k = apply_rotary_emb(q, k, freqs_cis, head_first=True)
+
+ q = q.contiguous()
+ k = k.contiguous()
+ v = v.contiguous()
+ out = F.scaled_dot_product_attention(q, k, v)
+ out = rearrange(out, 'b h n d -> b n (h d)').contiguous()
+
+ x = x + apply_gate(self.linear1(out), gate=gate_msa)
+ x_norm = self.norm2(x) * (1 + scale_mlp) + shift_mlp
+ x = x + apply_gate(self.linear2(x_norm), gate=gate_mlp)
+
+ return x
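+
+# Usage sketch (illustrative; the sizes below are made-up): x is the token
+# sequence [B, L, D] and cond a per-token modulation signal [B, T, D];
+# freqs_cis are the (cos, sin) RoPE tables.
+# block = SingleStreamBlock(hidden_size=1536, num_heads=12, mlp_ratio=4)
+# x = block(x, cond, freqs_cis=(freqs_cos, freqs_sin))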
+
+class HunyuanVideoFoley(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ model_config,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+
+ model_args = model_config.model_config.model_kwargs
+ self.depth_triple_blocks = model_args.get("depth_triple_blocks", 19)
+ self.depth_single_blocks = model_args.get("depth_single_blocks", 38)
+ # Gradient checkpoint.
+ self.gradient_checkpoint = False
+ self.gradient_checkpoint_layers = None
+ if self.gradient_checkpoint:
+ assert self.gradient_checkpoint_layers <= self.depth_triple_blocks + self.depth_single_blocks, (
+ f"Gradient checkpoint layers must be less or equal than the depth of the model. "
+ f"Got gradient_checkpoint_layers={self.gradient_checkpoint_layers} and depth={self.depth_triple_blocks + self.depth_single_blocks}."
+ )
+
+ self.interleaved_audio_visual_rope = model_args.get("interleaved_audio_visual_rope", False)
+
+ # Condition projection. Default to linear projection.
+ self.condition_projection = model_args.get("condition_projection", "linear")
+ self.condition_dim = model_args.get("condition_dim", None)
+ self.use_attention_mask = model_args.get("use_attention_mask", False)
+
+ self.patch_size = model_args.get("patch_size", 1)
+ self.visual_in_channels = model_args.get("clip_dim", 768)
+ self.audio_vae_latent_dim = model_args.get("audio_vae_latent_dim", 128)
+ self.out_channels = self.audio_vae_latent_dim
+ self.unpatchify_channels = self.out_channels
+ self.reverse = model_args.get("reverse", False)
+
+ self.num_heads = model_args.get("num_heads", 24)
+ self.hidden_size = model_args.get("hidden_size", 3072)
+ self.rope_dim_list = model_args.get("rope_dim_list", None)
+ self.mlp_ratio = model_args.get("mlp_ratio", 4.0)
+ self.mlp_act_type = model_args.get("mlp_act_type", "gelu_tanh")
+
+ self.qkv_bias = model_args.get("qkv_bias", True)
+ self.qk_norm = model_args.get("qk_norm", True)
+ self.qk_norm_type = model_args.get("qk_norm_type", "rms")
+ self.attn_mode = model_args.get("attn_mode", "torch")
+
+ self.embedder_type = model_args.get("embedder_type", "default")
+
+ # sync condition things
+ self.sync_modulation = model_args.get("sync_modulation", False)
+ self.add_sync_feat_to_audio = model_args.get("add_sync_feat_to_audio", False)
+ self.sync_feat_dim = model_args.get("sync_feat_dim", 768)
+ self.sync_in_ksz = model_args.get("sync_in_ksz", 1)
+
+ # condition tokens length
+ self.clip_len = model_args.get("clip_length", 64)
+ self.sync_len = model_args.get("sync_length", 192)
+
+ if self.hidden_size % self.num_heads != 0:
+ raise ValueError(f"Hidden size {self.hidden_size} must be divisible by num_heads {self.num_heads}")
+
+ # Build audio patchify layer and visual gated linear projection
+ self.patch_size = 1
+ self.audio_embedder = PatchEmbed1D(self.patch_size, self.audio_vae_latent_dim, self.hidden_size, **factory_kwargs)
+ self.visual_proj = SwiGLU(self.visual_in_channels, hidden_dim=self.hidden_size, out_dim=self.hidden_size)
+
+ # condition
+ if self.condition_projection == "linear":
+ self.cond_in = ConditionProjection(
+ self.condition_dim, self.hidden_size, get_activation_layer("silu"), **factory_kwargs
+ )
+ else:
+ raise NotImplementedError(f"Unsupported condition_projection: {self.condition_projection}")
+
+ # time modulation
+ self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
+
+ # visual sync embedder if needed
+ if self.sync_in_ksz == 1:
+ sync_in_padding = 0
+ elif self.sync_in_ksz == 3:
+ sync_in_padding = 1
+ else:
+ raise ValueError
+ if self.sync_modulation or self.add_sync_feat_to_audio:
+ self.sync_in = nn.Sequential(
+ nn.Linear(self.sync_feat_dim, self.hidden_size),
+ nn.SiLU(),
+ ConvMLP(self.hidden_size, self.hidden_size * 4, kernel_size=self.sync_in_ksz, padding=sync_in_padding),
+ )
+ self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, self.sync_feat_dim)))
+
+ self.triple_blocks = nn.ModuleList(
+ [
+ TwoStreamCABlock(
+ hidden_size=self.hidden_size,
+ num_heads=self.num_heads,
+ mlp_ratio=self.mlp_ratio,
+ mlp_act_type=self.mlp_act_type,
+ qk_norm=self.qk_norm,
+ qk_norm_type=self.qk_norm_type,
+ qkv_bias=self.qkv_bias,
+ attn_mode=self.attn_mode,
+ reverse=self.reverse,
+ interleaved_audio_visual_rope=self.interleaved_audio_visual_rope,
+ **factory_kwargs,
+ )
+ for _ in range(self.depth_triple_blocks)
+ ]
+ )
+
+ self.single_blocks = nn.ModuleList(
+ [
+ SingleStreamBlock(
+ hidden_size=self.hidden_size,
+ num_heads=self.num_heads,
+ mlp_ratio=self.mlp_ratio,
+ qk_norm_type=self.qk_norm_type,
+ **factory_kwargs,
+ )
+ for _ in range(self.depth_single_blocks)
+ ]
+ )
+
+ self.final_layer = FinalLayer1D(
+ self.hidden_size, self.patch_size, self.out_channels, get_activation_layer("silu"), **factory_kwargs
+ )
+ self.unpatchify_channels = self.out_channels
+
+ self.empty_clip_feat = nn.Parameter(torch.zeros(1, self.visual_in_channels), requires_grad=True)
+ self.empty_sync_feat = nn.Parameter(torch.zeros(1, self.sync_feat_dim), requires_grad=True)
+ nn.init.constant_(self.empty_clip_feat, 0)
+ nn.init.constant_(self.empty_sync_feat, 0)
+
+ def get_empty_string_sequence(self, bs=None) -> torch.Tensor:
+ if bs is None:
+ return self.empty_string_feat
+ else:
+ return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1)
+
+ def get_empty_clip_sequence(self, bs=None, len=None) -> torch.Tensor:
+ len = len if len is not None else self.clip_len
+ if bs is None:
+ return self.empty_clip_feat.expand(len, -1) # 15s
+ else:
+ return self.empty_clip_feat.unsqueeze(0).expand(bs, len, -1) # 15s
+
+ def get_empty_sync_sequence(self, bs=None, len=None) -> torch.Tensor:
+ len = len if len is not None else self.sync_len
+ if bs is None:
+ return self.empty_sync_feat.expand(len, -1)
+ else:
+ return self.empty_sync_feat.unsqueeze(0).expand(bs, len, -1)
+
+ def build_rope_for_audio_visual(self, audio_emb_len, visual_cond_len):
+ assert self.patch_size == 1
+ # ======================================== Build RoPE for audio tokens ======================================
+ target_ndim = 1 # n-d RoPE
+ rope_sizes = [audio_emb_len]
+ head_dim = self.hidden_size // self.num_heads
+ rope_dim_list = self.rope_dim_list
+ if rope_dim_list is None:
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
+        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) must equal the attention head_dim"
+ freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
+ rope_dim_list=rope_dim_list,
+ start=rope_sizes,
+ theta=10000,
+ use_real=True,
+ theta_rescale_factor=1.0,
+ )
+
+ # ========================== Build RoPE for clip tokens =========================
+ target_ndim = 1 # n-d RoPE
+ rope_sizes = [visual_cond_len]
+ head_dim = self.hidden_size // self.num_heads
+ rope_dim_list = self.rope_dim_list
+ if rope_dim_list is None:
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
+        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) must equal the attention head_dim"
+ v_freqs_cos, v_freqs_sin = get_nd_rotary_pos_embed(
+ rope_dim_list=rope_dim_list,
+ start=rope_sizes,
+ theta=10000,
+ use_real=True,
+ theta_rescale_factor=1.0,
+ freq_scaling=1.0 * audio_emb_len / visual_cond_len,
+ )
+ return freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin
+
+ def build_rope_for_interleaved_audio_visual(self, total_len):
+ assert self.patch_size == 1
+ # ========================== Build RoPE for audio tokens ========================
+ target_ndim = 1 # n-d RoPE
+ rope_sizes = [total_len]
+ head_dim = self.hidden_size // self.num_heads
+ rope_dim_list = self.rope_dim_list
+ if rope_dim_list is None:
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
+        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) must equal the attention head_dim"
+ freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
+ rope_dim_list=rope_dim_list,
+ start=rope_sizes,
+ theta=10000,
+ use_real=True,
+ theta_rescale_factor=1.0,
+ )
+ return freqs_cos, freqs_sin
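+
+    # Shape sketch (illustrative; hidden_size=1536 and num_heads=12 assumed, so
+    # head_dim=128): build_rope_for_interleaved_audio_visual(2 * 750) yields cos
+    # and sin tables of shape (1500, 128). The non-interleaved variant instead
+    # rescales the visual frequencies by audio_emb_len / visual_cond_len so that
+    # audio and visual tokens share a common temporal axis.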
+
+ def set_attn_mode(self, new_mode):
+ for block in self.triple_blocks:
+ block.set_attn_mode(new_mode)
+ for block in self.single_blocks:
+ block.set_attn_mode(new_mode)
+
+ def enable_deterministic(self):
+ for block in self.triple_blocks:
+ block.enable_deterministic()
+ for block in self.single_blocks:
+ block.enable_deterministic()
+
+ def disable_deterministic(self):
+ for block in self.triple_blocks:
+ block.disable_deterministic()
+ for block in self.single_blocks:
+ block.disable_deterministic()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ t: torch.Tensor, # Should be in range(0, 1000).
+ clip_feat: Optional[torch.Tensor] = None,
+ cond: torch.Tensor = None,
+ audio_mask: Optional[torch.Tensor] = None,
+ cond_mask: torch.Tensor = None,
+ sync_feat: Optional[torch.Tensor] = None,
+ drop_visual: Optional[List[bool]] = None,
+ return_dict: bool = True,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ out = {}
+ audio = x
+ bs, _, ol = x.shape
+ tl = ol // self.patch_size
+
+ # Prepare learnable empty conditions for visual condition
+ if drop_visual is not None:
+ clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)
+ sync_feat[drop_visual] = self.get_empty_sync_sequence().to(dtype=sync_feat.dtype)
+
+ # ========================= Prepare time & visual modulation =========================
+ vec = self.time_in(t)
+ sync_vec = None
+ if self.sync_modulation:
+ assert sync_feat is not None and sync_feat.shape[1] % 8 == 0
+            sync_feat = sync_feat.view(bs, sync_feat.shape[1] // 8, 8, self.sync_feat_dim) + self.sync_pos_emb
+ sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim) # bs, num_segments * 8, channels
+ sync_vec = self.sync_in(sync_feat) # bs, num_segments * 8, c
+ sync_vec = (
+ F.interpolate(sync_vec.transpose(1, 2), size=(tl), mode="nearest-exact").contiguous().transpose(1, 2)
+ ) # bs, tl, c
+ sync_vec = sync_vec + vec.unsqueeze(1)
+ elif self.add_sync_feat_to_audio:
+ assert sync_feat is not None and sync_feat.shape[1] % 8 == 0
+ sync_feat = sync_feat.view(bs, sync_feat.shape[1] // 8, 8, self.sync_feat_dim) + self.sync_pos_emb
+ sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim) # bs, num_segments * 8, channels
+ sync_feat = self.sync_in(sync_feat) # bs, num_segments * 8, c
+ add_sync_feat_to_audio = (
+ F.interpolate(sync_feat.transpose(1, 2), size=(tl), mode="nearest-exact").contiguous().transpose(1, 2)
+ ) # bs, tl, c
+
+ # ========================= Get text, audio and video clip embedding =========================
+ cond = self.cond_in(cond)
+ cond_seq_len = cond.shape[1]
+
+ audio = self.audio_embedder(x)
+ audio_seq_len = audio.shape[1]
+ v_cond = self.visual_proj(clip_feat)
+ v_cond_seq_len = v_cond.shape[1]
+
+ # ========================= Compute attention mask =========================
+ attn_mask = None
+ if self.use_attention_mask:
+ assert cond_mask is not None
+ batch_size = audio.shape[0]
+ seq_len = cond_seq_len + v_cond_seq_len + audio_seq_len
+
+ # get default audio_mask and v_cond_mask
+ audio_mask = torch.ones((batch_size, audio_seq_len), dtype=torch.bool, device=audio.device)
+ v_cond_mask = torch.ones((batch_size, v_cond_seq_len), dtype=torch.bool, device=audio.device)
+
+ # batch_size x seq_len
+ concat_mask = torch.cat([cond_mask, v_cond_mask, audio_mask], dim=1)
+ # batch_size x 1 x seq_len x seq_len
+ attn_mask_1 = concat_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
+ # batch_size x 1 x seq_len x seq_len
+ attn_mask_2 = attn_mask_1.transpose(2, 3)
+ # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of num_heads
+ attn_mask = (attn_mask_1 & attn_mask_2).bool()
+ # avoids self-attention weight being NaN for text padding tokens
+ attn_mask[:, :, :, 0] = True
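+            # Mask sketch (illustrative): with cond_mask = [1, 1, 0] and all
+            # audio/visual tokens valid, attn_mask[i, j] = mask[i] & mask[j], so a
+            # padded text token is excluded both as query and as key; forcing
+            # column 0 True keeps its softmax row from being entirely masked (NaN).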
+
+ # ========================= Build rope for audio and clip tokens =========================
+ if self.interleaved_audio_visual_rope:
+ freqs_cos, freqs_sin = self.build_rope_for_interleaved_audio_visual(audio_seq_len * 2)
+ v_freqs_cos = v_freqs_sin = None
+ else:
+ freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin = self.build_rope_for_audio_visual(
+ audio_seq_len, v_cond_seq_len
+ )
+
+ # ========================= Pass through DiT blocks =========================
+ freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
+ v_freqs_cis = (v_freqs_cos, v_freqs_sin) if v_freqs_cos is not None else None
+
+ if self.add_sync_feat_to_audio:
+ add_sync_layer = 0
+ assert (
+ add_sync_layer < self.depth_triple_blocks
+ ), f"The layer to add mel_spectrogram feature and sync feature should in the triple_stream_blocks (n: {self.depth_triple_blocks})."
+ # Triple-stream blocks
+ for layer_num, block in enumerate(self.triple_blocks):
+ if self.add_sync_feat_to_audio and layer_num == add_sync_layer:
+ audio = audio + add_sync_feat_to_audio
+ triple_block_args = [audio, cond, v_cond, attn_mask, vec, freqs_cis, v_freqs_cis, sync_vec]
+ if (
+ self.training
+ and self.gradient_checkpoint
+ and (self.gradient_checkpoint_layers == -1 or layer_num < self.gradient_checkpoint_layers)
+ ):
+ audio, cond, v_cond = torch.utils.checkpoint.checkpoint(
+ ckpt_wrapper(block), *triple_block_args, use_reentrant=False
+ )
+ else:
+ audio, cond, v_cond = block(*triple_block_args)
+
+ x = audio
+ if sync_vec is not None:
+ vec = vec.unsqueeze(1).repeat(1, cond_seq_len + v_cond_seq_len, 1)
+ vec = torch.cat((vec, sync_vec), dim=1)
+
+ freqs_cos, freqs_sin, _, _ = self.build_rope_for_audio_visual(audio_seq_len, v_cond_seq_len)
+ if self.add_sync_feat_to_audio:
+ vec = add_sync_feat_to_audio + vec.unsqueeze(dim=1)
+ if len(self.single_blocks) > 0:
+ for layer_num, block in enumerate(self.single_blocks):
+ single_block_args = [
+ x,
+ vec,
+ (freqs_cos, freqs_sin),
+ ]
+ if (
+ self.training
+ and self.gradient_checkpoint
+ and (
+ self.gradient_checkpoint_layers == -1
+ or layer_num + len(self.triple_blocks) < self.gradient_checkpoint_layers
+ )
+ ):
+ x = torch.utils.checkpoint.checkpoint(ckpt_wrapper(block), *single_block_args, use_reentrant=False)
+ else:
+ x = block(*single_block_args)
+
+ audio = x
+
+ # ========================= Final layer =========================
+ if sync_vec is not None:
+ vec = sync_vec
+ audio = self.final_layer(audio, vec) # (N, T, patch_size * out_channels)
+ audio = self.unpatchify1d(audio, tl)
+
+ if return_dict:
+ out["x"] = audio
+ return out
+ return audio
+
+ def unpatchify1d(self, x, l):
+ # x: (N, L, patch_size * C)
+ # audio: (N, C, T), T == L * patch_size
+ c = self.unpatchify_channels
+ p = self.patch_size
+ assert l == x.shape[1]
+
+ x = x.reshape(shape=(x.shape[0], l, p, c))
+ x = torch.einsum("ntpc->nctp", x)
+ audio = x.reshape(shape=(x.shape[0], c, l * p))
+ return audio
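+
+    # Round-trip sketch (illustrative): with patch_size=1, unpatchify1d reduces to
+    # a transpose from (N, L, C) back to (N, C, L), e.g. (2, 750, 128) -> (2, 128, 750),
+    # inverting the BCN -> BNC flattening performed by PatchEmbed1D.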
+
+ def params_count(self):
+ counts = {
+ "triple": sum(
+ [
+ sum(p.numel() for p in block.audio_cross_q.parameters())
+ + sum(p.numel() for p in block.v_cond_cross_q.parameters())
+ + sum(p.numel() for p in block.text_cross_kv.parameters())
+ + sum(p.numel() for p in block.audio_self_attn_qkv.parameters())
+ + sum(p.numel() for p in block.v_cond_attn_qkv.parameters())
+ + sum(p.numel() for p in block.audio_mlp.parameters())
+ + sum(p.numel() for p in block.audio_self_proj.parameters())
+ + sum(p.numel() for p in block.v_cond_self_proj.parameters())
+ + sum(p.numel() for p in block.v_cond_mlp.parameters())
+ for block in self.triple_blocks
+ ]
+ ),
+ "single": sum(
+ [
+ sum(p.numel() for p in block.linear1.parameters())
+ + sum(p.numel() for p in block.linear2.parameters())
+ for block in self.single_blocks
+ ]
+ ),
+ "total": sum(p.numel() for p in self.parameters()),
+ }
+
+ counts["attn+mlp"] = counts["triple"] + counts["single"]
+ return counts
diff --git a/hunyuanvideo_foley/models/nn/__init__.py b/hunyuanvideo_foley/models/nn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hunyuanvideo_foley/models/nn/__pycache__/__init__.cpython-312.pyc b/hunyuanvideo_foley/models/nn/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f8417010095d79ce34302eb9fff5e9e212581ebd
Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/__init__.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/models/nn/__pycache__/__init__.cpython-313.pyc b/hunyuanvideo_foley/models/nn/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..800467b3f6df786612b9fd6939ecb22c6826b82b
Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/__init__.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/models/nn/__pycache__/activation_layers.cpython-312.pyc b/hunyuanvideo_foley/models/nn/__pycache__/activation_layers.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eed57fb2c6cfbdf03a5d61db7fc7aa10088de804
Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/activation_layers.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/models/nn/__pycache__/activation_layers.cpython-313.pyc b/hunyuanvideo_foley/models/nn/__pycache__/activation_layers.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a893c3a1f805493fb00bdf772babc4c6e4e5158c
Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/activation_layers.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/models/nn/__pycache__/attn_layers.cpython-312.pyc b/hunyuanvideo_foley/models/nn/__pycache__/attn_layers.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b1b306642dd45a16ea9844856d4f9afa09dfaa1c
Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/attn_layers.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/models/nn/__pycache__/attn_layers.cpython-313.pyc b/hunyuanvideo_foley/models/nn/__pycache__/attn_layers.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..09c0305aab4fc4ab4a8996a47fe8adfad2fce8af
Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/attn_layers.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/models/nn/__pycache__/embed_layers.cpython-312.pyc b/hunyuanvideo_foley/models/nn/__pycache__/embed_layers.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5bdb875572318d0e47775f39fe919fd54b0c5345
Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/embed_layers.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/models/nn/__pycache__/embed_layers.cpython-313.pyc b/hunyuanvideo_foley/models/nn/__pycache__/embed_layers.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3c0b62dc8897c0f45a0109da50879618f15cae1a
Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/embed_layers.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/models/nn/__pycache__/mlp_layers.cpython-312.pyc b/hunyuanvideo_foley/models/nn/__pycache__/mlp_layers.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fbc6e29d5f454f5db631ed2151c7c27cd07da296
Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/mlp_layers.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/models/nn/__pycache__/mlp_layers.cpython-313.pyc b/hunyuanvideo_foley/models/nn/__pycache__/mlp_layers.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1e3c84b3ff0b3ef24802bbd48df8f460a3f48735
Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/mlp_layers.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/models/nn/__pycache__/modulate_layers.cpython-312.pyc b/hunyuanvideo_foley/models/nn/__pycache__/modulate_layers.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..749d2ee4d023a9c8b9738c5dfd42930974b42398
Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/modulate_layers.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/models/nn/__pycache__/modulate_layers.cpython-313.pyc b/hunyuanvideo_foley/models/nn/__pycache__/modulate_layers.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d744d5a5fb5e12428be5598e366ab9b80e83e446
Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/modulate_layers.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/models/nn/__pycache__/norm_layers.cpython-312.pyc b/hunyuanvideo_foley/models/nn/__pycache__/norm_layers.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e51f79ea16c268ee6ecf4af4dd88234ff840bf9a
Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/norm_layers.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/models/nn/__pycache__/norm_layers.cpython-313.pyc b/hunyuanvideo_foley/models/nn/__pycache__/norm_layers.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d6a3e2e5b81e5168797f465940e0f972038d25c7
Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/norm_layers.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/models/nn/__pycache__/posemb_layers.cpython-312.pyc b/hunyuanvideo_foley/models/nn/__pycache__/posemb_layers.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..17aa3dfe33c273d7c84ed4cc34806b26b24c3e9d
Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/posemb_layers.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/models/nn/__pycache__/posemb_layers.cpython-313.pyc b/hunyuanvideo_foley/models/nn/__pycache__/posemb_layers.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1fd9c2ca472d6f967e3c5bbc6e1f936c283f1548
Binary files /dev/null and b/hunyuanvideo_foley/models/nn/__pycache__/posemb_layers.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/models/nn/activation_layers.py b/hunyuanvideo_foley/models/nn/activation_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..55414cbd054546263e1217f363e9fe02e846a122
--- /dev/null
+++ b/hunyuanvideo_foley/models/nn/activation_layers.py
@@ -0,0 +1,44 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+def get_activation_layer(act_type):
+ if act_type == "gelu":
+ return lambda: nn.GELU()
+ elif act_type == "gelu_tanh":
+ # Approximate `tanh` requires torch >= 1.13
+ return lambda: nn.GELU(approximate="tanh")
+ elif act_type == "relu":
+ return nn.ReLU
+ elif act_type == "silu":
+ return nn.SiLU
+ else:
+ raise ValueError(f"Unknown activation type: {act_type}")
+
+class SwiGLU(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ out_dim: int,
+ ):
+ """
+ Initialize the SwiGLU FeedForward module.
+
+        Args:
+            dim (int): Input dimension.
+            hidden_dim (int): Hidden dimension of the feedforward layer.
+            out_dim (int): Output dimension.
+
+        Attributes:
+            w1: Projection for the SiLU-activated branch (dim -> hidden_dim).
+            w2: Output projection (hidden_dim -> out_dim).
+            w3: Projection for the linear gating branch (dim -> hidden_dim).
+
+ """
+ super().__init__()
+
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, out_dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+ def forward(self, x):
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
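+
+# Usage sketch (illustrative): SwiGLU gates F.silu(w1(x)) with w3(x) before the
+# output projection, e.g. SwiGLU(dim=1024, hidden_dim=768, out_dim=768) maps a
+# (B, N, 1024) tensor to (B, N, 768).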
diff --git a/hunyuanvideo_foley/models/nn/attn_layers.py b/hunyuanvideo_foley/models/nn/attn_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..c954eebb33038369f3c7d71cda4fdd9c3a8d27dd
--- /dev/null
+++ b/hunyuanvideo_foley/models/nn/attn_layers.py
@@ -0,0 +1,546 @@
+import importlib.metadata
+import math
+from typing import Tuple, Union
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+try:
+ from flash_attn import (
+ flash_attn_qkvpacked_func,
+ flash_attn_kvpacked_func,
+ flash_attn_varlen_kvpacked_func,
+ flash_attn_varlen_qkvpacked_func,
+ )
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
+except ImportError:
+    flash_attn_qkvpacked_func = flash_attn_kvpacked_func = None
+    flash_attn_varlen_kvpacked_func = flash_attn_varlen_qkvpacked_func = None
+    index_first_axis = pad_input = unpad_input = None
+from packaging import version
+from transformers.utils.import_utils import _is_package_available
+
+from .norm_layers import get_norm_layer
+
+def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
+ """
+ Reshape frequency tensor for broadcasting it with another tensor.
+
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
+
+ Notes:
+ When using FlashMHAModified, head_first should be False.
+ When using Attention, head_first should be True.
+
+ Args:
+ freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
+ head_first (bool): head dimension first (except batch dim) or not.
+
+ Returns:
+ torch.Tensor: Reshaped frequency tensor.
+
+ Raises:
+ AssertionError: If the frequency tensor doesn't match the expected shape.
+ AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
+ """
+ ndim = x.ndim
+ assert 0 <= 1 < ndim
+
+ if isinstance(freqs_cis, tuple):
+ # freqs_cis: (cos, sin) in real space
+ if head_first:
+ assert freqs_cis[0].shape == (
+ x.shape[-2],
+ x.shape[-1],
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
+ shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+ else:
+ assert freqs_cis[0].shape == (
+ x.shape[1],
+ x.shape[-1],
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+ return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
+ else:
+ # freqs_cis: values in complex space
+ if head_first:
+ assert freqs_cis.shape == (
+ x.shape[-2],
+ x.shape[-1],
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
+ shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+ else:
+ assert freqs_cis.shape == (
+ x.shape[1],
+ x.shape[-1],
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+ return freqs_cis.view(*shape)
+
+
+def rotate_half(x):
+ x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
+ return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+ head_first: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Apply rotary embeddings to input tensors using the given frequency tensor.
+
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
+ returned as real tensors.
+
+ Args:
+ xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
+ xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
+ freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
+ head_first (bool): head dimension first (except batch dim) or not.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+
+ """
+ xk_out = None
+ if isinstance(freqs_cis, tuple):
+ cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
+ cos, sin = cos.to(xq.device), sin.to(xq.device)
+ # real * cos - imag * sin
+ # imag * cos + real * sin
+ xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
+ xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
+ else:
+ # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
+ # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
+ # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2]
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
+
+ return xq_out, xk_out
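+
+# Sanity sketch (illustrative): a zero rotation is the identity, since cos=1 and
+# sin=0 give xq_out = xq * 1 + rotate_half(xq) * 0. In the real-pair path the
+# (cos, sin) tables of shape [S, D] broadcast over batch and heads as [1, S, 1, D].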
+
+
+class BasicAttentionLayer(nn.Module):
+ def __init__(self, attn_mode="flash", deterministic=False):
+ super().__init__()
+ self.attn_mode = attn_mode
+ self.deterministic = deterministic
+
+ def set_attn_mode(self, new_mode):
+ self.attn_mode = new_mode
+
+ def enable_deterministic(self):
+ self.deterministic = True
+
+ def disable_deterministic(self):
+ self.deterministic = False
+
+
+MEMORY_LAYOUT = {
+ "self_flash": (
+ lambda x: x,
+ lambda x: x,
+ ),
+ "cross_flash": (
+ lambda x: x,
+ lambda x: x,
+ ),
+ "flash_torch_sp": (
+ lambda x: x,
+ lambda x: x,
+ ),
+ "torch": (
+ lambda x: x.transpose(1, 2),
+ lambda x: x.transpose(1, 2),
+ ),
+ "vanilla": (
+ lambda x: x.transpose(1, 2),
+ lambda x: x.transpose(1, 2),
+ ),
+}
+
+
+# Copied from https://github.com/huggingface/transformers/blob/b873234cb649a24865021f0d598627ce2b24d34a/src/transformers/modeling_flash_attention_utils.py#L33C1-L57C6
+def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
+ """
+ Retrieves indexing data required to repad unpadded (ragged) tensors.
+
+ Arguments:
+ attention_mask (`torch.Tensor`):
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
+
+ Return:
+        indices (`torch.Tensor`):
+ The indices of non-masked tokens from the flattened input sequence.
+ cu_seqlens (`torch.Tensor`):
+ The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+ max_seqlen_in_batch (`int`):
+ Maximum sequence length in batch.
+ """
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
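+
+# Example (illustrative): for attention_mask = [[1, 1, 0], [1, 0, 0]],
+# seqlens_in_batch = [2, 1], indices = [0, 1, 3], cu_seqlens = [0, 2, 3],
+# and max_seqlen_in_batch = 2.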
+
+
+# Copied from https://github.com/huggingface/transformers/blob/b873234cb649a24865021f0d598627ce2b24d34a/src/transformers/utils/import_utils.py#L822
+def is_flash_attn_greater_or_equal(library_version: str):
+ if not _is_package_available("flash_attn"):
+ return False
+
+ return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)
+
+
+def get_kv_seqlens_with_mask(attn_mask, k, v):
+ indices_k, cu_seqlens_k, max_seqlen_k = _get_unpad_data(attn_mask)
+ b, s1, a, d = k.shape
+ k = index_first_axis(k.reshape(b * s1, a, d), indices_k)
+ v = index_first_axis(v.reshape(b * s1, a, d), indices_k)
+ kv = torch.stack([k, v], dim=1)
+ return cu_seqlens_k, max_seqlen_k, kv
+
+
+def get_q_seqlens(q):
+ bs, s, a, d = q.shape
+ cu_seqlens_q = torch.arange(0, (bs + 1) * s, step=s, dtype=torch.int32, device=q.device)
+ q = q.reshape(bs * s, a, d)
+ return cu_seqlens_q, s, q
+
+def flash_attn_no_pad(
+ qkv, key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None
+):
+ # adapted from https://github.com/Dao-AILab/flash-attention/blob/13403e81157ba37ca525890f2f0f2137edf75311/flash_attn/flash_attention.py#L27
+ batch_size = qkv.shape[0]
+ seqlen = qkv.shape[1]
+ nheads = qkv.shape[-2]
+ x = rearrange(qkv, "b s three h d -> b s (three h d)")
+    # unpad_input returns (x_unpad, indices, cu_seqlens, max_s) in older flash_attn
+    # versions, with used_seqlens_in_batch appended as a fifth value in newer ones.
+    unpad_results = unpad_input(x, key_padding_mask)
+
+ if len(unpad_results) == 4:
+ x_unpad, indices, cu_seqlens, max_s = unpad_results
+ elif len(unpad_results) == 5:
+ x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch = unpad_results
+ else:
+        raise ValueError(f"Unexpected number of values ({len(unpad_results)}) returned by unpad_input")
+
+ x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
+ output_unpad = flash_attn_varlen_qkvpacked_func(
+ x_unpad,
+ cu_seqlens,
+ max_s,
+ dropout_p,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+ output = rearrange(
+ pad_input(
+ rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen
+ ),
+ "b s (h d) -> b s h d",
+ h=nheads,
+ )
+ return output
+
+
+def attention(
+ q,
+ k,
+ v,
+ mode,
+ drop_rate=0,
+ attn_mask=None,
+ cond_mask=None,
+ causal=False,
+ deterministic=False,
+ cu_seqlens=None,
+ max_seqlen=None,
+ cu_seqlens_k=None,
+ max_seqlen_k=None,
+ img_seq_len=None,
+):
+ """
+ Perform QKV self attention.
+
+ Args:
+ q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
+ k (torch.Tensor): Key tensor with shape [b, s1, a, d]
+ v (torch.Tensor): Value tensor with shape [b, s1, a, d]
+ mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
+ drop_rate (float): Dropout rate in attention map. (default: 0)
+ attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
+ (default: None)
+ causal (bool): Whether to use causal attention. (default: False)
+ deterministic (bool): Whether to use deterministic attention. (default: False)
+ cu_seqlens (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
+ used to index into q.
+ max_seqlen (int): The maximum sequence length in the batch of q.
+ cu_seqlens_k (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
+ used to index into kv.
+ max_seqlen_k (int): The maximum sequence length in the batch of k and v.
+
+ Returns:
+ torch.Tensor: Output tensor after self attention with shape [b, s, ad]
+ """
+ if mode in ["torch", "vanilla", "self_flash", "cross_flash"]:
+ if isinstance(q, tuple):
+ q = torch.cat(q, dim=1)
+ if isinstance(k, tuple):
+ k = torch.cat(k, dim=1)
+ if isinstance(v, tuple):
+ v = torch.cat(v, dim=1)
+ pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
+ q = pre_attn_layout(q)
+ k = pre_attn_layout(k)
+ v = pre_attn_layout(v)
+
+ if "flash" in mode:
+ assert (
+ flash_attn_qkvpacked_func is not None
+ ), "Flash attention is not available. Please install flash_attn first."
+ flash_kwargs = dict(dropout_p=drop_rate, causal=causal)
+ if deterministic:
+ if not is_flash_attn_greater_or_equal("2.4.1"):
+                raise ValueError(
+                    "Flash attention deterministic mode requires flash_attn>=2.4.1. Please upgrade flash_attn."
+                )
+ flash_kwargs["deterministic"] = deterministic
+
+ if mode == "self_flash":
+ qkv = torch.stack([q, k, v], dim=2)
+ if attn_mask is not None:
+ raise ValueError("Self attention does not support attention mask")
+ x = flash_attn_qkvpacked_func(qkv, **flash_kwargs)
+
+ elif mode == "cross_flash":
+ kv = torch.stack([k, v], dim=2)
+ if attn_mask is None:
+ x = flash_attn_kvpacked_func(q, kv, **flash_kwargs)
+ else:
+ b, s, a, h = q.shape
+ cu_seqlens_q, max_seqlen_q, q = get_q_seqlens(q)
+ cu_seqlens_k, max_seqlen_k, kv = get_kv_seqlens_with_mask(attn_mask, k, v)
+
+ attn_output = flash_attn_varlen_kvpacked_func(
+ q,
+ kv,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
+ **flash_kwargs,
+ )
+ x = attn_output.reshape(b, s, a, h)
+    elif mode == "torch":
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
+ attn_mask = attn_mask.to(q.dtype)
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
+
+ elif mode == "vanilla":
+ scale_factor = 1 / math.sqrt(q.size(-1))
+
+ b, a, s, _ = q.shape
+ s1 = k.size(2)
+ attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
+ if causal:
+ # Only applied to self attention
+ assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
+ temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+            attn_bias = attn_bias.to(q.dtype)
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+ else:
+ attn_bias += attn_mask
+
+ # TODO(jarvizhang): Maybe force q and k to be float32 to avoid numerical overflow
+ attn = (q @ k.transpose(-2, -1)) * scale_factor
+ attn += attn_bias
+ attn = attn.softmax(dim=-1)
+ attn = torch.dropout(attn, p=drop_rate, train=True)
+ x = attn @ v
+ else:
+ raise NotImplementedError(f"Unsupported attention mode: {mode}")
+
+ if mode in ["torch", "vanilla", "self_flash", "cross_flash"]:
+ x = post_attn_layout(x).contiguous()
+ b, s, a, d = x.shape
+ out = x.reshape(b, s, -1)
+ return out
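+
+# Call sketch (illustrative): with q/k/v of shape [b, s, a, d], mode="torch"
+# transposes to [b, a, s, d] for F.scaled_dot_product_attention, then the heads
+# are re-merged so the output is [b, s, a * d].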
+
+
+class SelfAttentionLayer(BasicAttentionLayer):
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ qkv_bias=True,
+ qk_norm=True,
+ attn_drop=0,
+ proj_drop=0,
+ dtype=None,
+ device=None,
+ norm_type="layer",
+ attn_mode="self_flash",
+ deterministic=False,
+ ) -> None:
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__(attn_mode, deterministic)
+ self.dim = dim
+ self.num_heads = num_heads
+ assert self.dim % num_heads == 0, "dim must be divisible by num_heads"
+ self.head_dim = self.dim // num_heads
+ self.attn_drop = attn_drop
+
+ # This assertion is aligned with flash attention
+ assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
+
+ self.Wqkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **factory_kwargs)
+
+ norm_layer = get_norm_layer(norm_type)
+ self.q_norm = (
+ norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+ self.k_norm = (
+ norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+
+ self.out_proj = nn.Linear(dim, dim, bias=qkv_bias, **factory_kwargs)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, freqs_cis=None, attn_mask=None):
+ """
+ Args:
+ x (torch.Tensor): (batch, seq_len, hidden_dim) (where hidden_dim = num heads * head dim)
+            freqs_cis (tuple of torch.Tensor, optional): (cos, sin) of shape (seq_len, head_dim), rotary position embedding
+ attn_mask (torch.Tensor, optional): (batch, seq_len, seq_len), mask for attention
+ """
+ b, s, d = x.shape
+
+ # Apply QKV projection
+ qkv = self.Wqkv(x)
+ qkv = qkv.view(b, s, 3, self.num_heads, self.head_dim) # [b, s, 3, a, d]
+ q, k, v = qkv.unbind(dim=2) # [b, s, a, d]
+
+ # Apply QK-Norm if needed
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+
+ # Apply RoPE if needed
+ if freqs_cis is not None:
+ qq, kk = apply_rotary_emb(q, k, freqs_cis)
+ assert (
+ qq.shape == q.shape and kk.shape == k.shape
+ ), f"qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}"
+ q, k = qq, kk
+
+ # Apply self attention
+ context = attention(
+ q,
+ k,
+ v,
+ drop_rate=self.attn_drop if self.training else 0,
+ attn_mask=attn_mask,
+ mode=self.attn_mode,
+ deterministic=self.deterministic,
+ )
+ out = self.out_proj(context)
+ out = self.proj_drop(out)
+
+ return out
+
+
+class CrossAttentionLayer(BasicAttentionLayer):
+ def __init__(
+ self,
+ qdim,
+ kdim,
+ num_heads,
+ qkv_bias=True,
+ qk_norm=True,
+ attn_drop=0,
+ proj_drop=0,
+ dtype=None,
+ device=None,
+ norm_type="layer",
+ attn_mode="cross_flash",
+ deterministic=False,
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__(attn_mode, deterministic)
+ self.qdim = qdim
+ self.kdim = kdim
+ self.num_heads = num_heads
+ assert self.qdim % num_heads == 0, "qdim must be divisible by num_heads"
+ self.head_dim = self.qdim // num_heads
+ self.attn_drop = attn_drop
+
+ # This assertion is aligned with flash attention
+ assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
+
+ self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
+ self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
+
+ norm_layer = get_norm_layer(norm_type)
+ self.q_norm = (
+ norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+ self.k_norm = (
+ norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+
+ self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, y, attn_mask=None):
+ """
+ Args:
+ x (torch.Tensor): (batch, seq_len, hidden_dim) (where hidden_dim = num heads * head dim)
+ y (torch.Tensor): (batch, seq_len1, hidden_dim1)
+ attn_mask (torch.Tensor): (batch, seq_len1), mask for attention
+ """
+ b, s, d = x.shape
+ _, s1, d1 = y.shape
+
+ q = self.q_proj(x).view(b, s, self.num_heads, self.head_dim)
+ kv = self.kv_proj(y).view(b, s1, 2, self.num_heads, self.head_dim)
+ k, v = kv.unbind(dim=2)
+
+ # Apply QK-Norm if needed
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+
+ # Apply cross attention
+ context = attention(
+ q,
+ k,
+ v,
+ attn_mask=attn_mask,
+ drop_rate=self.attn_drop if self.training else 0,
+ mode=self.attn_mode,
+ deterministic=self.deterministic,
+ )
+ out = self.out_proj(context)
+ out = self.proj_drop(out)
+
+ return out
diff --git a/hunyuanvideo_foley/models/nn/embed_layers.py b/hunyuanvideo_foley/models/nn/embed_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd15167836fd5a5aec7b0c21af296082d7d1b2f4
--- /dev/null
+++ b/hunyuanvideo_foley/models/nn/embed_layers.py
@@ -0,0 +1,136 @@
+import math
+import torch
+import torch.nn as nn
+
+from ...utils.helper import to_1tuple
+
+class PatchEmbed1D(nn.Module):
+ """1D Audio to Patch Embedding
+
+    A convolution-based approach to patchifying a 1D audio sequence w/ embedding projection.
+
+ Based on the impl in https://github.com/google-research/vision_transformer
+
+ Hacked together by / Copyright 2020 Ross Wightman
+ """
+
+ def __init__(
+ self,
+ patch_size=1,
+ in_chans=768,
+ embed_dim=768,
+ norm_layer=None,
+ flatten=True,
+ bias=True,
+ dtype=None,
+ device=None,
+ ):
+ factory_kwargs = {"dtype": dtype, "device": device}
+ super().__init__()
+ patch_size = to_1tuple(patch_size)
+ self.patch_size = patch_size
+ self.flatten = flatten
+
+ self.proj = nn.Conv1d(
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs
+ )
+ nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
+ if bias:
+ nn.init.zeros_(self.proj.bias)
+
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x):
+        assert (
+            x.shape[2] % self.patch_size[0] == 0
+        ), f"The token number ({x.shape[2]}) of x must be divisible by the patch_size ({self.patch_size[0]})."
+
+ x = self.proj(x)
+ if self.flatten:
+ x = x.transpose(1, 2) # BCN -> BNC
+ x = self.norm(x)
+ return x
+
+
+class ConditionProjection(nn.Module):
+ """
+    Projects condition embeddings into the model hidden size.
+
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
+ """
+
+ def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
+ factory_kwargs = {'dtype': dtype, 'device': device}
+ super().__init__()
+ self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
+ self.act_1 = act_layer()
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)
+
+ def forward(self, caption):
+ hidden_states = self.linear_1(caption)
+ hidden_states = self.act_1(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+def timestep_embedding(t, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+
+ Args:
+ t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
+ dim (int): the dimension of the output.
+ max_period (int): controls the minimum frequency of the embeddings.
+
+ Returns:
+ embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
+
+ .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ """
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period)
+ * torch.arange(start=0, end=half, dtype=torch.float32)
+ / half
+ ).to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat(
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+ )
+ return embedding
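+
+# Shape sketch (illustrative): timestep_embedding(torch.tensor([0.0, 500.0]), 256)
+# returns a (2, 256) tensor; the first 128 columns are cosines and the last 128
+# are sines of t times geometrically spaced frequencies running from 1 down
+# toward 1/max_period.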
+
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+ def __init__(self,
+ hidden_size,
+ act_layer,
+ frequency_embedding_size=256,
+ max_period=10000,
+ out_size=None,
+ dtype=None,
+ device=None
+ ):
+ factory_kwargs = {'dtype': dtype, 'device': device}
+ super().__init__()
+ self.frequency_embedding_size = frequency_embedding_size
+ self.max_period = max_period
+ if out_size is None:
+ out_size = hidden_size
+
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
+ act_layer(),
+ nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
+ )
+ nn.init.normal_(self.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.mlp[2].weight, std=0.02)
+
+ def forward(self, t):
+ t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
+ t_emb = self.mlp(t_freq)
+ return t_emb
diff --git a/hunyuanvideo_foley/models/nn/mlp_layers.py b/hunyuanvideo_foley/models/nn/mlp_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..0434559025e280b74fbd7c181247aa1c2a41a409
--- /dev/null
+++ b/hunyuanvideo_foley/models/nn/mlp_layers.py
@@ -0,0 +1,149 @@
+# Modified from timm library:
+# https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
+
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .modulate_layers import modulate
+from ...utils.helper import to_2tuple
+
+class MLP(nn.Module):
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+
+ def __init__(
+ self,
+ in_channels,
+ hidden_channels=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ norm_layer=None,
+ bias=True,
+ drop=0.0,
+ use_conv=False,
+ device=None,
+ dtype=None,
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ out_features = out_features or in_channels
+ hidden_channels = hidden_channels or in_channels
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+
+ self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs)
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity()
+ self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs)
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.norm(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+# copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
+# only used when use_vanilla is True
+class MLPEmbedder(nn.Module):
+ def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
+ self.silu = nn.SiLU()
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.out_layer(self.silu(self.in_layer(x)))
+
+
+class LinearWarpforSingle(nn.Module):
+ def __init__(self, in_dim: int, out_dim: int, bias=True, device=None, dtype=None):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.fc = nn.Linear(in_dim, out_dim, bias=bias, **factory_kwargs)
+
+ def forward(self, x, y):
+ z = torch.cat([x, y], dim=2)
+ return self.fc(z)
+
+class FinalLayer1D(nn.Module):
+ def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+
+ # Just use LayerNorm for the final layer
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+ self.linear = nn.Linear(hidden_size, patch_size * out_channels, bias=True, **factory_kwargs)
+ nn.init.zeros_(self.linear.weight)
+ nn.init.zeros_(self.linear.bias)
+
+ # Here we don't distinguish between the modulate types. Just use the simple one.
+ self.adaLN_modulation = nn.Sequential(
+ act_layer(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs)
+ )
+ # Zero-initialize the modulation
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = modulate(self.norm_final(x), shift=shift, scale=scale)
+ x = self.linear(x)
+ return x
+
+
+class ChannelLastConv1d(nn.Conv1d):
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = x.permute(0, 2, 1)
+ x = super().forward(x)
+ x = x.permute(0, 2, 1)
+ return x
+
+
+class ConvMLP(nn.Module):
+
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ multiple_of: int = 256,
+ kernel_size: int = 3,
+ padding: int = 1,
+ device=None,
+ dtype=None,
+ ):
+ """
+ Convolutional MLP module.
+
+        Args:
+            dim (int): Input dimension.
+            hidden_dim (int): Hidden dimension of the feedforward layer.
+            multiple_of (int): Value to ensure the hidden dimension is a multiple of this value.
+            kernel_size (int): Kernel size of the channel-last 1D convolutions.
+            padding (int): Padding of the channel-last 1D convolutions.
+
+        Attributes:
+            w1: Channel-last convolution for the SiLU-activated branch.
+            w2: Channel-last convolution projecting back to `dim`.
+            w3: Channel-last convolution for the linear gating branch.
+
+ """
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+ self.w1 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs)
+ self.w2 = ChannelLastConv1d(hidden_dim, dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs)
+ self.w3 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs)
+
+ def forward(self, x):
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
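+
+# Hidden-width sketch (illustrative): ConvMLP(dim=1536, hidden_dim=6144) first
+# shrinks the hidden width to int(2 * 6144 / 3) = 4096, then rounds it up to a
+# multiple of 256 (still 4096), so w1/w3 map 1536 -> 4096 and w2 maps 4096 -> 1536.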
diff --git a/hunyuanvideo_foley/models/nn/modulate_layers.py b/hunyuanvideo_foley/models/nn/modulate_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..5235233e996fc17c0476de04428a13bfcf3ba8fe
--- /dev/null
+++ b/hunyuanvideo_foley/models/nn/modulate_layers.py
@@ -0,0 +1,49 @@
+from typing import Callable
+import torch
+import torch.nn as nn
+
+class ModulateDiT(nn.Module):
+ def __init__(self, hidden_size: int, factor: int, act_layer: Callable, dtype=None, device=None):
+ factory_kwargs = {"dtype": dtype, "device": device}
+ super().__init__()
+ self.act = act_layer()
+ self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
+ # Zero-initialize the modulation
+ nn.init.zeros_(self.linear.weight)
+ nn.init.zeros_(self.linear.bias)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.linear(self.act(x))
+
+
+def modulate(x, shift=None, scale=None):
+ if x.ndim == 3:
+ shift = shift.unsqueeze(1) if shift is not None and shift.ndim == 2 else None
+ scale = scale.unsqueeze(1) if scale is not None and scale.ndim == 2 else None
+ if scale is None and shift is None:
+ return x
+ elif shift is None:
+ return x * (1 + scale)
+ elif scale is None:
+ return x + shift
+ else:
+ return x * (1 + scale) + shift
+
+
+def apply_gate(x, gate=None, tanh=False):
+ if gate is None:
+ return x
+ if gate.ndim == 2 and x.ndim == 3:
+ gate = gate.unsqueeze(1)
+ if tanh:
+ return x * gate.tanh()
+ else:
+ return x * gate
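+
+# AdaLN sketch (illustrative): given per-sample shift/scale/gate of shape (B, C)
+# from a zero-initialized ModulateDiT head, a block computes
+#   h = modulate(norm(x), shift, scale)  # x * (1 + scale) + shift, broadcast over tokens
+#   x = x + apply_gate(attn(h), gate)    # gate starts at 0, so the block is the identity at init
+# where `norm` and `attn` stand for the block's own normalization and attention.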
+
+
+def ckpt_wrapper(module):
+ def ckpt_forward(*inputs):
+ outputs = module(*inputs)
+ return outputs
+
+ return ckpt_forward
diff --git a/hunyuanvideo_foley/models/nn/norm_layers.py b/hunyuanvideo_foley/models/nn/norm_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ad30b0ea4faeaa18e22ed25fcc44b97aee2d243
--- /dev/null
+++ b/hunyuanvideo_foley/models/nn/norm_layers.py
@@ -0,0 +1,70 @@
+import torch
+import torch.nn as nn
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6,
+ device=None, dtype=None):
+ """
+ Initialize the RMSNorm normalization layer.
+
+ Args:
+ dim (int): The dimension of the input tensor.
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+ Attributes:
+ eps (float): A small value added to the denominator for numerical stability.
+ weight (nn.Parameter): Learnable scaling parameter.
+
+ """
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ super().__init__()
+ self.eps = eps
+ if elementwise_affine:
+ self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
+
+ def _norm(self, x):
+ """
+ Apply the RMSNorm normalization to the input tensor.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The normalized tensor.
+
+ """
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ """
+ Forward pass through the RMSNorm layer.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The output tensor after applying RMSNorm.
+
+ """
+ output = self._norm(x.float()).type_as(x)
+ if hasattr(self, "weight"):
+ output = output * self.weight
+ return output
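+
+# Numeric sketch (illustrative): for x = [3.0, 4.0], the root mean square is
+# sqrt((9 + 16) / 2) ≈ 3.536, so RMSNorm(2)(x) ≈ [0.849, 1.131] under the default
+# unit weight (ignoring eps).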
+
+
+def get_norm_layer(norm_layer):
+ """
+ Get the normalization layer.
+
+ Args:
+ norm_layer (str): The type of normalization layer.
+
+ Returns:
+ norm_layer (nn.Module): The normalization layer.
+ """
+ if norm_layer == "layer":
+ return nn.LayerNorm
+ elif norm_layer == "rms":
+ return RMSNorm
+ else:
+ raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
diff --git a/hunyuanvideo_foley/models/nn/posemb_layers.py b/hunyuanvideo_foley/models/nn/posemb_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbd188bbeb710d8155e758590446a51f5f0dd038
--- /dev/null
+++ b/hunyuanvideo_foley/models/nn/posemb_layers.py
@@ -0,0 +1,159 @@
+import torch
+from typing import Union, Tuple
+
+
+def _to_tuple(x, dim=2):
+ if isinstance(x, int):
+ return (x,) * dim
+ elif len(x) == dim:
+ return x
+ else:
+ raise ValueError(f"Expected length {dim} or int, but got {x}")
+
+
+def get_meshgrid_nd(start, *args, dim=2):
+ """
+ Get n-D meshgrid with start, stop and num.
+
+ Args:
+ start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
+ step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
+ should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
+ n-tuples.
+ *args: See above.
+ dim (int): Dimension of the meshgrid. Defaults to 2.
+
+ Returns:
+        grid (torch.Tensor): [dim, ...]
+ """
+ if len(args) == 0:
+ # start is grid_size
+ num = _to_tuple(start, dim=dim)
+ start = (0,) * dim
+ stop = num
+ elif len(args) == 1:
+ # start is start, args[0] is stop, step is 1
+ start = _to_tuple(start, dim=dim)
+ stop = _to_tuple(args[0], dim=dim)
+ num = [stop[i] - start[i] for i in range(dim)]
+ elif len(args) == 2:
+ # start is start, args[0] is stop, args[1] is num
+ start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
+ stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
+ num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
+ else:
+ raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
+
+ # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
+ axis_grid = []
+ for i in range(dim):
+ a, b, n = start[i], stop[i], num[i]
+ g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
+ axis_grid.append(g)
+ grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
+ grid = torch.stack(grid, dim=0) # [dim, W, H, D]
+
+ return grid
+
+
+#################################################################################
+# Rotary Positional Embedding Functions #
+#################################################################################
+# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
+
+
+def get_nd_rotary_pos_embed(
+ rope_dim_list, start, *args, theta=10000.0, use_real=False, theta_rescale_factor=1.0, freq_scaling=1.0
+):
+ """
+    This is an n-d version of precompute_freqs_cis, i.e. a RoPE for tokens with n-d structure.
+
+ Args:
+        rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal n.
+            sum(rope_dim_list) should equal the head_dim of the attention layer.
+ start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
+ args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
+ *args: See above.
+ theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
+ use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+            Some libraries such as TensorRT do not support the complex64 data type, so it is useful to provide the
+            real and imaginary parts separately.
+ theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
+        freq_scaling (float, optional): Frequency rescale factor, as proposed in MMAudio. Defaults to 1.0.
+
+ Returns:
+        pos_embed: (cos, sin) tensors of shape [S, D] each if use_real, otherwise a complex tensor of shape [S, D/2].
+ """
+
+ grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H]
+
+ # use 1/ndim of dimensions to encode grid_axis
+ embs = []
+ for i in range(len(rope_dim_list)):
+ emb = get_1d_rotary_pos_embed(
+ rope_dim_list[i],
+ grid[i].reshape(-1),
+ theta,
+ use_real=use_real,
+ theta_rescale_factor=theta_rescale_factor,
+ freq_scaling=freq_scaling,
+ ) # 2 x [WHD, rope_dim_list[i]]
+ embs.append(emb)
+
+ if use_real:
+ cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
+ sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
+ return cos, sin
+ else:
+ emb = torch.cat(embs, dim=1) # (WHD, D/2)
+ return emb
+
+
+def get_1d_rotary_pos_embed(
+ dim: int,
+ pos: Union[torch.FloatTensor, int],
+ theta: float = 10000.0,
+ use_real: bool = False,
+ theta_rescale_factor: float = 1.0,
+ freq_scaling: float = 1.0,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Precompute the frequency tensor for complex exponential (cis) with given dimensions.
+ (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
+
+ This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
+    and the position indices 'pos'. The 'theta' parameter scales the frequencies.
+ The returned tensor contains complex values in complex64 data type.
+
+ Args:
+ dim (int): Dimension of the frequency tensor.
+ pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
+ use_real (bool, optional): If True, return real part and imaginary part separately.
+ Otherwise, return complex numbers.
+ theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
+        freq_scaling (float, optional): Frequency rescale factor, as proposed in MMAudio. Defaults to 1.0.
+
+ Returns:
+ freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
+ freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
+ """
+ if isinstance(pos, int):
+ pos = torch.arange(pos).float()
+
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
+ # has some connection to NTK literature
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+ if theta_rescale_factor != 1.0:
+        theta *= theta_rescale_factor ** (dim / (dim - 2))
+
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
+ freqs *= freq_scaling
+ freqs = torch.outer(pos, freqs) # [S, D/2]
+ if use_real:
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
+ return freqs_cos, freqs_sin
+ else:
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
+ return freqs_cis
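+
+# Usage sketch (illustrative): a 1-D RoPE table for 750 audio tokens with
+# head_dim = 64, mirroring how the DiT builds its audio rope:
+#   cos, sin = get_nd_rotary_pos_embed([64], [750], theta=10000, use_real=True)
+#   # cos.shape == sin.shape == (750, 64)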
diff --git a/hunyuanvideo_foley/models/synchformer/__init__.py b/hunyuanvideo_foley/models/synchformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6331d467e30b9d24378025bf540e7430ff4fd7ad
--- /dev/null
+++ b/hunyuanvideo_foley/models/synchformer/__init__.py
@@ -0,0 +1 @@
+from .synchformer import Synchformer
diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/__init__.cpython-312.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d6bd2373949f7c75e11b201c7d55290646c009a4
Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/__init__.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/__init__.cpython-313.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..67c8b840aecff1da6b19f8596177ba473b6a041f
Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/__init__.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/ast_model.cpython-312.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/ast_model.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..65d041544a06271d266d7f4e970d8ef4e156177b
Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/ast_model.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/ast_model.cpython-313.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/ast_model.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..92bf67679e34b8680152e2af4365c7d7ade7ebeb
Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/ast_model.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/modeling_ast.cpython-312.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/modeling_ast.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ff5ac7c20466981dfaec48b85a759a3c7bb7502f
Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/modeling_ast.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/modeling_ast.cpython-313.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/modeling_ast.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..157972264fdc52bbbe76238af47d18c672b4581b
Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/modeling_ast.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/motionformer.cpython-312.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/motionformer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8f7d92890865df12a0201e9493dd93214ebd30d1
Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/motionformer.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/motionformer.cpython-313.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/motionformer.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e44fd9e5a05ac0c8cba93a4700da7510b24a5a28
Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/motionformer.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/synchformer.cpython-312.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/synchformer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..796bea889521a94ceaca7536b94cdcb170534995
Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/synchformer.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/synchformer.cpython-313.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/synchformer.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5c48d32ad2f7984d0318be84f7c102bf8941d062
Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/synchformer.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/utils.cpython-312.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1abeff774ab246a1a6f6a5525fef4319b291c3c8
Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/utils.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/utils.cpython-313.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/utils.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ed1af3cab9fddcfda9cc52df57ec2a7be5169ce7
Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/utils.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/video_model_builder.cpython-312.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/video_model_builder.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f1e39b538331705e76298b67b2ed12db7e213108
Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/video_model_builder.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/video_model_builder.cpython-313.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/video_model_builder.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..384881a3362ceb67899ff3617b75b90579be0332
Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/video_model_builder.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/vit_helper.cpython-312.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/vit_helper.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b6ef2042a3ddd8efa942e4ffd9d16c6dcc7ccd5c
Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/vit_helper.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/models/synchformer/__pycache__/vit_helper.cpython-313.pyc b/hunyuanvideo_foley/models/synchformer/__pycache__/vit_helper.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4f09aeec0fc09976c2e7027d6c695656c3e80fbf
Binary files /dev/null and b/hunyuanvideo_foley/models/synchformer/__pycache__/vit_helper.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/models/synchformer/ast_model.py b/hunyuanvideo_foley/models/synchformer/ast_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f4394306ccd08de3a1e6bb556df8f42d2e4cacb
--- /dev/null
+++ b/hunyuanvideo_foley/models/synchformer/ast_model.py
@@ -0,0 +1,289 @@
+import logging
+
+import torch
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+
+from .modeling_ast import ASTForAudioClassification, ASTConfig
+from .motionformer import AveragePooling, BaseEncoderLayer, TemporalTransformerEncoderLayer
+from .utils import check_if_file_exists_else_download
+
+
+class AST(torch.nn.Module):
+ def __init__(
+ self,
+ extract_features: bool = False,
+ ckpt_path: str = None,
+ feat_type: str = None,
+ max_spec_t: int = None,
+ factorize_freq_time: bool = None,
+ agg_freq_module: str = None,
+ agg_time_module: str = None,
+ add_global_repr: bool = True,
+ agg_segments_module: str = None,
+ max_segments: int = None,
+ ) -> None:
+ """
+ extract_features: if True, then the model will return the features instead of head's output
+ ckpt_path: either a model name from the HuggingFace model hub, or a path to an AVCLIP .pt checkpoint
+ feat_type: if extract_features is True, this parameter specifies the type of features to return
+ max_spec_t: if specified, then the model (pos emb) will be patched to support this length of spec
+ factorize_freq_time: if True, then the model will use a factorized freq/time aggregation
+ agg_freq_module: if specified, then the model will use this module for freq aggregation
+ agg_time_module: if specified, then the model will use this module for time aggregation
+ add_global_repr: if True, adds a global representation to the features (aggregation on segments)
+ agg_segments_module: if specified, then the model will use this module for segments aggregation
+ max_segments: if specified, the initialization of PE in the global agg module will use this value.
+ This should correspond to the max number of segments per video (if None, 16 is used)
+ """
+ super().__init__()
+ self.extract_features = extract_features
+ self.ckpt_path = ckpt_path
+ self.max_spec_t = max_spec_t
+ self.max_segments = max_segments
+
+ # depending on whether the feat extractor was pre-trained contrastively or not, we need to
+ # load the state dict differently.
+
+ # if ckpt is specified, then load the model from the HuggingFace model hub, otherwise init a new model
+ if ckpt_path == "MIT/ast-finetuned-audioset-10-10-0.4593":
+ revision = "c1c0c66" # fixing the revision for compatibility (V4.27.4)
+ self.config = ASTConfig.from_pretrained(ckpt_path, revision=revision)
+ full_model = ASTForAudioClassification.from_pretrained(ckpt_path, revision=revision)
+ logging.info(f"Loaded AST from {ckpt_path}")
+ else:
+ self.config = ASTConfig()
+ self.config.num_labels = 527 # 2 by default, audioset has 527 labels
+ full_model = ASTForAudioClassification(self.config)
+ logging.info("Initialized AST from scratch with the AST AudioSet config")
+
+ was_pt_on_avclip = ckpt_path is not None and ckpt_path.endswith(".pt")
+
+ # feature extractor
+ self.ast = full_model.audio_spectrogram_transformer
+
+ if self.extract_features:
+ # assign `feat_type` (use default if not specified)
+ self.feat_type = "last_hidden_state" if feat_type is None else feat_type
+ # define adapters if needed
+ self.factorize_freq_time = factorize_freq_time
+ # avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer)
+ transf_enc_layer_kwargs = dict(
+ d_model=self.config.hidden_size,
+ nhead=self.config.num_attention_heads,
+ dim_feedforward=self.config.intermediate_size,
+ activation=torch.nn.GELU(),
+ batch_first=True,
+ dropout=self.config.attention_probs_dropout_prob,
+ layer_norm_eps=1e-6,
+ norm_first=True,
+ )
+ if factorize_freq_time:
+ self.feat_type = "last_hidden_state" # this feat_type supports factorization
+ # frequency aggregation
+ if agg_freq_module == "TransformerEncoderLayer":
+ self.freq_attn_agg = FrequencyTransformerEncoderLayer(**transf_enc_layer_kwargs)
+ elif agg_freq_module == "AveragePooling":
+ self.freq_attn_agg = AveragePooling(
+ avg_pattern="BS D f t -> BS D t", then_permute_pattern="BS D t -> BS t D"
+ )
+ # time aggregation
+ if agg_time_module == "TransformerEncoderLayer":
+ self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs)
+ elif agg_time_module == "AveragePooling":
+ self.temp_attn_agg = AveragePooling(avg_pattern="BS t D -> BS D")
+ elif "Identity" in agg_time_module:
+ self.temp_attn_agg = torch.nn.Identity()
+ # define a global aggregation layer (aggregates over segments)
+ self.add_global_repr = add_global_repr
+ if add_global_repr:
+ if agg_segments_module == "TransformerEncoderLayer":
+ # we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D)
+ # we need to add pos emb (PE) because previously we added the same PE for each segment
+ pos_max_len = max_segments if max_segments is not None else 16 # 16 = 10sec//0.64sec + 1
+ self.global_attn_agg = TemporalTransformerEncoderLayer(
+ add_pos_emb=True,
+ pos_emb_drop=self.config.hidden_dropout_prob,
+ pos_max_len=pos_max_len,
+ **transf_enc_layer_kwargs,
+ )
+ elif agg_segments_module == "AveragePooling":
+ self.global_attn_agg = AveragePooling(avg_pattern="B S D -> B D")
+ else:
+ self.classifier = full_model.classifier
+
+ # AST.device fails with AttributeError. This is a workaround
+ self.device = full_model.device
+
+ # pre-trained on 12*101+2=1214 tokens, but we have fewer (e.g. 12*6+2=74)
+ self.patch_position_emb()
+
+ if was_pt_on_avclip:
+ # we need to filter out the state_dict of the AVCLIP model (has both A and V extractors)
+ # and keep only the state_dict of the feat extractor
+ check_if_file_exists_else_download(self.ckpt_path)
+ ckpt = torch.load(ckpt_path, map_location="cpu")
+ ckpt_weights = dict()
+ for k, v in ckpt["state_dict"].items():
+ if k.startswith(("module.a_encoder.", "a_encoder.")):
+ k = k.replace("module.", "").replace("a_encoder.", "")
+ ckpt_weights[k] = v
+ _load_status = self.load_state_dict(ckpt_weights, strict=False)
+ if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0:
+ logging.warning(
+ f"Loading exact afeat_extractor ckpt from {self.ckpt_path} failed. \n"
+ f"Missing keys ({len(_load_status.missing_keys)}): "
+ f"{_load_status.missing_keys}, \n"
+ f"Unexpected keys ({len(_load_status.unexpected_keys)}): "
+ f"{_load_status.unexpected_keys} \n"
+ f"temp_attn_agg are expected to be missing if ckpt was pt contrastively."
+ )
+ else:
+ logging.info(f"Loading afeat_extractor ckpt from {self.ckpt_path} succeeded.")
+
+ # print the number of parameters
+ logging.info(f"AST: {sum(p.numel() for p in self.parameters() if p.requires_grad):,}")
+
+ def forward(
+ self, x: torch.Tensor, for_loop: bool = False, cont_mask: torch.Tensor = None, **ast_kwargs
+ ) -> torch.Tensor:
+ """
+ x: (B, S, T, F) where S is number of segments, F is number of (mel) frequency bins,
+ ast_kwargs: additional arguments for the AST model
+ cont_mask: (B, S, T, F) where 0s are the values to be masked out
+ if `for_loop=True`, we use a for loop to extract features for each segment separately.
+ if `for_loop=False`, we extract features for all segments at once.
+ Using the for loop is slower but more memory-efficient, while processing all segments at once
+ is faster but uses more memory.
+ Using for loop allows to control the memory footprint by varying the number of videos in a
+ batch (batch size) rather than the number of segments in a video.
+ """
+ B, S, T, F = x.shape
+
+ if for_loop:
+ assert cont_mask is None, "cont_mask is not supported with for_loop=True"
+ orig_shape_s = (B, 1, T, F)
+ # NOTE: x is (B, S, T, F) while forward_segments expects (BS, T, F),
+ # so we feed one segment at a time: x[:, s] is (B, T, F), i.e. (BS, T, F) with S=1.
+ x = torch.cat(
+ [self.forward_segments(x[:, s], orig_shape_s, **ast_kwargs).unsqueeze(1) for s in range(S)], dim=1
+ )
+ else:
+ orig_shape = (B, S, T, F)
+ x = x.view(B * S, T, F)
+ if cont_mask is not None:
+ cont_mask = cont_mask.reshape(B * S, T, F)
+ # AST expects a tensor of shape (B*S, T, F).
+ x = self.forward_segments(x, orig_shape=orig_shape, cont_mask=cont_mask, **ast_kwargs)
+ # unpack the segments (using rest dimensions to support different shapes e.g. (BS, D) or (BS, t, D))
+ x = x.view(B, S, *x.shape[1:])
+ # x now is of shape (B, S, D) or (B, S, t, D) if `self.temp_attn_agg` is `Identity`
+
+ global_x = None
+ if self.extract_features and self.add_global_repr: # lazy execution, throws AttributeError
+ assert len(x.shape) == 3, f"Local representation should be (B, S, D) {x.shape}"
+ global_x = self.global_attn_agg(x) # (B, D)
+
+ return x, global_x # x is (B, S, ...), global_x is (B, D) or None
+
+ def forward_segments(self, x, orig_shape: tuple, cont_mask: torch.Tensor = None, **ast_kwargs):
+ """x is (BS, T, F), where S is the number of segments; cont_mask is (BS, T, F): 0s to be masked out"""
+ # 'pooler_output': (B, D); or 'last_hidden_state': (B, T, D) where T is [CLS, DISTILL, ...]
+ # x_mask is (B, T) where 0s are the values to be masked out
+ x, x_mask = self.ast(x, cont_mask=cont_mask, **ast_kwargs)
+
+ if self.extract_features:
+ x = self.get_features_by_type(x)
+ if self.factorize_freq_time:
+ x = self.restore_freq_temp_dims(x, orig_shape) # (BS, D, f, t) <- (B*S, T, D)
+ if cont_mask is not None:
+ # duplicating the mask for the latent dimension (D) to be compatible with the next func
+ x_mask = x_mask.unsqueeze(-1).expand(-1, -1, self.config.hidden_size)
+ x_mask = self.restore_freq_temp_dims(x_mask, orig_shape) # (BS, D, f, t) <- (B*S, T, D)
+ # again removing the latent
+ x_mask = x_mask[:, 0, :, :]
+ else:
+ x_mask = None
+ x = self.freq_attn_agg(x, x_mask) # (BS, t, D)
+ x = self.temp_attn_agg(x) # (BS, D) or (BS, t, D) if self.temp_attn_agg is Identity
+ else:
+ x = x["pooler_output"]
+ x = self.classifier(x)
+ return x
+
+ def get_features_by_type(self, x: BaseModelOutputWithPooling) -> torch.Tensor:
+ if self.feat_type == "pooler_output":
+ return x["pooler_output"] # (B, D)
+ elif self.feat_type == "CLS":
+ return x["last_hidden_state"][:, 0, :] # (B, D)
+ elif self.feat_type == "last_hidden_state":
+ return x["last_hidden_state"] # (B, 2+T, D)
+ elif self.feat_type == "last_hidden_state_no_AUX":
+ return x["last_hidden_state"][:, 2:, :] # (B, T, D) removing CLS and distill tokens
+ else:
+ raise ValueError(f"Unknown feature type: {self.feat_type}")
+
+ def restore_freq_temp_dims(self, feats, orig_shape: tuple):
+ """
+ feats are of shape (B*S, T, D)
+ where T = 2 + f * t (if feat_type == 'last_hidden_state')
+ where T = f * t (if feat_type == 'last_hidden_state_no_AUX')
+ Our goal is to make them of shape (B*S, D, f, t), where f and t are the dimensions after patching.
+ From `self.ast.embeddings.patch_embeddings`, it follows that we could reshape feats:
+ `feats.transpose(1, 2).view(B*S, D, f, t)`
+
+ (A similar function is defined for RGB features in `motionformer.py`.)
+ """
+ B, S, T, F = orig_shape
+ D = self.config.hidden_size
+
+ # num patches in each dimension
+ f, t = self.ast.embeddings.get_shape(self.config)
+
+ if self.feat_type == "last_hidden_state":
+ feats = feats[:, 2:, :] # removing CLS and distill tokens
+
+ feats = feats.permute(0, 2, 1) # (B*S, D, T)
+ feats = feats.view(B * S, D, f, t) # (B*S, D, f, t)
+
+ return feats
+
+ def patch_position_emb(self):
+ if self.max_spec_t is not None:
+ self.config.max_length = self.max_spec_t
+ f, t = self.ast.embeddings.get_shape(self.config)
+ shortened = self.ast.embeddings.position_embeddings[:, : f * t + 2].clone() # +2 for CLS and distill tokens
+ self.ast.embeddings.position_embeddings = torch.nn.Parameter(shortened).to(self.device)
+
+ def to(self, device):
+ """AST.device fails with AttributeError. This is a workaround."""
+ self.device = torch.device(device)
+ return super().to(device)
+
+
+class FrequencyTransformerEncoderLayer(BaseEncoderLayer):
+ """This layer is used to aggregate the features along the frequency axis.
+ It follows the same logic as spatio-temporal aggregation in visual feature extractor.
+ Thus, it is recommended to check the definition of `BaseEncoderLayer` in `motionformer.py`"""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
+ """x: (B*S, D, f, t); if specified x_mask (B*S, f, t), 0s are the values to be masked out"""
+ BS, D, f, t = x.shape
+
+ # time as a batch dimension
+ x = x.permute(0, 3, 2, 1) # (B*S, t, f, D)
+ x = x.reshape(BS * t, f, D) # .view() fails with non-contiguous memory
+ # similar to mask
+ if x_mask is not None:
+ x_mask = x_mask.permute(0, 2, 1) # (B*S, t, f)
+ x_mask = x_mask.reshape(BS * t, f)
+
+ # apply encoder layer (BaseEncoderLayer.forward) - it will add CLS token and output its representation
+ x = super().forward(x=x, x_mask=x_mask) # (B*S*t, D)
+
+ # reshape back to (B*S, t, D)
+ x = x.view(BS, t, D)
+
+ return x # (B*S, t, D)
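For context, a hedged sketch of driving this wrapper as a segment-level feature extractor. The configuration values (14 segments, 66 spectrogram frames, 128 mel bins) mirror the test-time transforms in `compute_desync_score.py` below; initializing without `ckpt_path` yields random weights, and whether `AveragePooling` ignores the mask argument is assumed here, so treat this purely as a shape check:

```python
import torch
from hunyuanvideo_foley.models.synchformer.ast_model import AST  # path as added in this diff

ast = AST(
    extract_features=True,
    feat_type="last_hidden_state",
    max_spec_t=66,
    factorize_freq_time=True,
    agg_freq_module="AveragePooling",
    agg_time_module="AveragePooling",
    add_global_repr=False,
)
specs = torch.randn(2, 14, 66, 128)     # (B, S, T, F): 2 clips, 14 segments each
local_feats, global_feats = ast(specs)  # local_feats: (B, S, D); global_feats is None here
```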
diff --git a/hunyuanvideo_foley/models/synchformer/compute_desync_score.py b/hunyuanvideo_foley/models/synchformer/compute_desync_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..936c4994bd6d68444980959f78aa710e4ba5a205
--- /dev/null
+++ b/hunyuanvideo_foley/models/synchformer/compute_desync_score.py
@@ -0,0 +1,214 @@
+import argparse
+import os
+import subprocess
+from pathlib import Path
+
+import torch
+import torchaudio
+import torchvision
+from omegaconf import OmegaConf
+
+from . import data_transforms
+from .synchformer import Synchformer
+from .data_transforms import make_class_grid, quantize_offset
+from .utils import check_if_file_exists_else_download, which_ffmpeg
+
+
+def prepare_inputs(batch, device):
+ aud = batch["audio"].to(device)
+ vid = batch["video"].to(device)
+
+ return aud, vid
+
+
+def get_test_transforms():
+ ts = [
+ data_transforms.EqualifyFromRight(),
+ data_transforms.RGBSpatialCrop(input_size=224, is_random=False),
+ data_transforms.TemporalCropAndOffset(
+ crop_len_sec=5,
+ max_off_sec=2, # https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml
+ max_wiggle_sec=0.0,
+ do_offset=True,
+ offset_type="grid",
+ prob_oos=None,
+ grid_size=21,
+ segment_size_vframes=16,
+ n_segments=14,
+ step_size_seg=0.5,
+ vfps=25,
+ ),
+ data_transforms.GenerateMultipleSegments(
+ segment_size_vframes=16,
+ n_segments=14,
+ is_start_random=False,
+ step_size_seg=0.5,
+ ),
+ data_transforms.RGBToHalfToZeroOne(),
+ data_transforms.RGBNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # motionformer normalization
+ data_transforms.AudioMelSpectrogram(
+ sample_rate=16000,
+ win_length=400, # 25 ms * 16 kHz
+ hop_length=160, # 10 ms * 16 kHz
+ n_fft=1024, # 2^(ceil(log2(window_size * sampling_rate)))
+ n_mels=128, # as in AST
+ ),
+ data_transforms.AudioLog(),
+ data_transforms.PadOrTruncate(max_spec_t=66),
+ data_transforms.AudioNormalizeAST(mean=-4.2677393, std=4.5689974), # AST, pre-trained on AudioSet
+ data_transforms.PermuteStreams(
+ einops_order_audio="S F T -> S 1 F T", einops_order_rgb="S T C H W -> S T C H W" # same
+ ),
+ ]
+ transforms = torchvision.transforms.Compose(ts)
+
+ return transforms
+
+
+def get_video_and_audio(path, get_meta=False, start_sec=0, end_sec=None):
+ orig_path = path
+ # (Tv, 3, H, W) [0, 255, uint8]; (Ca, Ta)
+ rgb, audio, meta = torchvision.io.read_video(str(path), start_sec, end_sec, "sec", output_format="TCHW")
+ assert meta["video_fps"], f"No video fps for {orig_path}"
+ # (Ta) <- (Ca, Ta)
+ audio = audio.mean(dim=0)
+ # FIXME: this is legacy format of `meta` as it used to be loaded by VideoReader.
+ meta = {
+ "video": {"fps": [meta["video_fps"]]},
+ "audio": {"framerate": [meta["audio_fps"]]},
+ }
+ return rgb, audio, meta
+
+
+def reencode_video(path, vfps=25, afps=16000, in_size=256):
+ assert which_ffmpeg() != "", "Is ffmpeg installed? Check if the conda environment is activated."
+ new_path = Path.cwd() / "vis" / f"{Path(path).stem}_{vfps}fps_{in_size}side_{afps}hz.mp4"
+ new_path.parent.mkdir(exist_ok=True)
+ new_path = str(new_path)
+ cmd = f"{which_ffmpeg()}"
+ # no info/error printing
+ cmd += " -hide_banner -loglevel panic"
+ cmd += f" -y -i {path}"
+ # 1) change fps, 2) resize: min(H,W)=MIN_SIDE (vertical vids are supported), 3) change audio framerate
+ cmd += f" -vf fps={vfps},scale=iw*{in_size}/'min(iw,ih)':ih*{in_size}/'min(iw,ih)',crop='trunc(iw/2)'*2:'trunc(ih/2)'*2"
+ cmd += f" -ar {afps}"
+ cmd += f" {new_path}"
+ subprocess.call(cmd.split())
+ cmd = f"{which_ffmpeg()}"
+ cmd += " -hide_banner -loglevel panic"
+ cmd += f" -y -i {new_path}"
+ cmd += f" -acodec pcm_s16le -ac 1"
+ cmd += f' {new_path.replace(".mp4", ".wav")}'
+ subprocess.call(cmd.split())
+ return new_path
+
+
+def decode_single_video_prediction(off_logits, grid, item):
+ label = item["targets"]["offset_label"].item()
+ print("Ground Truth offset (sec):", f"{label:.2f} ({quantize_offset(grid, label)[-1].item()})")
+ print()
+ print("Prediction Results:")
+ off_probs = torch.softmax(off_logits, dim=-1)
+ k = min(off_probs.shape[-1], 5)
+ topk_logits, topk_preds = torch.topk(off_logits, k)
+ # remove batch dimension
+ assert len(topk_logits) == 1, "batch is larger than 1"
+ topk_logits = topk_logits[0]
+ topk_preds = topk_preds[0]
+ off_logits = off_logits[0]
+ off_probs = off_probs[0]
+ for target_hat in topk_preds:
+ print(f'p={off_probs[target_hat]:.4f} ({off_logits[target_hat]:.4f}), "{grid[target_hat]:.2f}" ({target_hat})')
+ return off_probs
+
+
+def main(args):
+ vfps = 25
+ afps = 16000
+ in_size = 256
+ # making the offset class grid similar to the one used in transforms,
+ # refer to the used one: https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml
+ max_off_sec = 2
+ num_cls = 21
+
+ # checking if the provided video has the correct frame rates
+ print(f"Using video: {args.vid_path}")
+ v, _, info = torchvision.io.read_video(args.vid_path, pts_unit="sec")
+ _, H, W, _ = v.shape
+ if info["video_fps"] != vfps or info["audio_fps"] != afps or min(H, W) != in_size:
+ print(f'Reencoding. vfps: {info["video_fps"]} -> {vfps};', end=" ")
+ print(f'afps: {info["audio_fps"]} -> {afps};', end=" ")
+ print(f"{(H, W)} -> min(H, W)={in_size}")
+ args.vid_path = reencode_video(args.vid_path, vfps, afps, in_size)
+ else:
+ print(f'Skipping reencoding. vfps: {info["video_fps"]}; afps: {info["audio_fps"]}; min(H, W)={in_size}')
+
+ device = torch.device(args.device)
+
+ # load visual and audio streams
+ # rgb: (Tv, 3, H, W) in [0, 255], audio: (Ta,) in [-1, 1]
+ rgb, audio, meta = get_video_and_audio(args.vid_path, get_meta=True)
+
+ # making an item (dict) to apply transformations
+ # NOTE: here is how it works:
+ # For instance, if the model is trained on 5sec clips, the provided video is 9sec, and `v_start_i_sec=1.3`
+ # the transform will crop out a 5sec-clip from 1.3 to 6.3 seconds and shift the start of the audio
+ # track by `args.offset_sec` seconds. It means that if `offset_sec` > 0, the audio will
+ # start by `offset_sec` earlier than the rgb track.
+ # It is a good idea to use something in [-`max_off_sec`, `max_off_sec`] (-2, +2) seconds (see `grid`)
+ item = dict(
+ video=rgb,
+ audio=audio,
+ meta=meta,
+ path=args.vid_path,
+ split="test",
+ targets={
+ "v_start_i_sec": args.v_start_i_sec,
+ "offset_sec": args.offset_sec,
+ },
+ )
+
+ grid = make_class_grid(-max_off_sec, max_off_sec, num_cls)
+ if not (min(grid) <= item["targets"]["offset_sec"] <= max(grid)):
+ print(f'WARNING: offset_sec={item["targets"]["offset_sec"]} is outside the trained grid: {grid}')
+
+ # applying the test-time transform
+ item = get_test_transforms()(item)
+
+ # prepare inputs for inference
+ batch = torch.utils.data.default_collate([item])
+ aud, vid = prepare_inputs(batch, device)
+
+ # TODO:
+ # sanity check: take the input to the `model` and reconstruct a video from it.
+ # Use this check to make sure the input makes sense (audio should be ok but shifted as specified)
+ # reconstruct_video_from_input(aud, vid, batch['meta'], args.vid_path, args.v_start_i_sec, args.offset_sec,
+ # vfps, afps)
+
+ # forward pass
+ with torch.set_grad_enabled(False):
+ with torch.autocast("cuda", enabled=True):
+ _, logits = synchformer(vid, aud)
+
+ # simply prints the results of the prediction
+ decode_single_video_prediction(logits, grid, item)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--exp_name", required=True, help="In a format: xx-xx-xxTxx-xx-xx")
+ parser.add_argument("--vid_path", required=True, help="A path to .mp4 video")
+ parser.add_argument("--offset_sec", type=float, default=0.0)
+ parser.add_argument("--v_start_i_sec", type=float, default=0.0)
+ parser.add_argument("--device", default="cuda:0")
+ args = parser.parse_args()
+
+ synchformer = Synchformer().cuda().eval()
+ synchformer.load_state_dict(
+ torch.load(
+ os.environ.get("SYNCHFORMER_WEIGHTS", "weights/synchformer.pth"),
+ weights_only=True,
+ map_location="cpu",
+ )
+ )
+
+ main(args)
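A hedged sketch of the same model driven programmatically, mirroring the `__main__` block above. The checkpoint path is whatever local file you point `SYNCHFORMER_WEIGHTS` at (an assumption, not shipped with this diff), and the input shapes follow the test-time transforms (14 half-overlapping 16-frame segments at 224x224; mel spectrograms of 128 bins by 66 frames):

```python
import os
import torch
from hunyuanvideo_foley.models.synchformer import Synchformer

model = Synchformer().eval()
state = torch.load(
    os.environ.get("SYNCHFORMER_WEIGHTS", "weights/synchformer.pth"),  # assumed local path
    weights_only=True,
    map_location="cpu",
)
model.load_state_dict(state)

with torch.inference_mode():
    vid = torch.randn(1, 14, 16, 3, 224, 224)  # (B, S, Tv, C, H, W), as produced by the transforms
    aud = torch.randn(1, 14, 1, 128, 66)       # (B, S, 1, F, Ta), after PermuteStreams
    _, logits = model(vid, aud)                # logits over the 21-class offset grid
```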
diff --git a/hunyuanvideo_foley/models/synchformer/data_transforms.py b/hunyuanvideo_foley/models/synchformer/data_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..d59331eb3c27454a4c52bdbf8a8b85946c63c0a3
--- /dev/null
+++ b/hunyuanvideo_foley/models/synchformer/data_transforms.py
@@ -0,0 +1,1130 @@
+import logging
+import math
+import random
+from typing import Tuple
+import torch
+import torchvision
+import torchaudio
+import numpy as np
+import einops
+
+
+def sec2frames(sec, fps):
+ return int(sec * fps)
+
+
+def frames2sec(frames, fps):
+ return frames / fps
+
+
+class EqualifyFromRight(torch.nn.Module):
+
+ def __init__(self, clip_max_len_sec=10):
+ """
+ Takes the dataset item and makes sure its streams cover an equal duration, given their fps.
+ It assumes the signals are synced and trims the ending parts ('from the right').
+ """
+ super().__init__()
+ self.clip_max_len_sec = clip_max_len_sec
+
+ def forward(self, item):
+ """
+ `item`: {'video': (Tv, C, H, W), 'audio': (Ta,),
+ 'meta': {
+ 'audio': {'framerate': [float], 'duration': [float]}
+ 'video': {'fps': [float], 'duration': [float]}}
+ """
+ a_fps = item["meta"]["audio"]["framerate"][0]
+ v_fps = item["meta"]["video"]["fps"][0]
+
+ Ta = item["audio"].shape[0]
+ Tv, C, H, W = item["video"].shape
+
+ a_len_secs = Ta / a_fps
+ v_len_secs = Tv / v_fps
+ min_len = min(self.clip_max_len_sec, a_len_secs, v_len_secs)
+
+ a_frames_per_v_frame = a_fps // v_fps
+ v_len_frames = int(v_fps * min_len)
+ a_len_frames = int(a_frames_per_v_frame * v_len_frames)
+ # print(a_len_frames, v_len_frames)
+
+ assert a_len_frames <= Ta and v_len_frames <= Tv
+
+ item["audio"] = item["audio"][:a_len_frames]
+ item["video"] = item["video"][:v_len_frames, :, :, :]
+
+ return item
+
+
+class RGBSpatialCrop(torch.nn.Module):
+
+ def __init__(self, input_size, is_random):
+ super().__init__()
+ assert input_size is not None, f"input_size is `{input_size}`"
+ if isinstance(input_size, int):
+ input_size = (input_size, input_size)
+ self.input_size = input_size
+ self.is_random = is_random
+
+ @staticmethod
+ def get_random_crop_sides(vid, output_size):
+ """Slice parameters for random crop"""
+ h, w = vid.shape[-2:]
+ th, tw = output_size
+ if w == tw and h == th:
+ return 0, 0, h, w
+ i = random.randint(0, h - th)
+ j = random.randint(0, w - tw)
+ return i, j, th, tw
+
+ @staticmethod
+ def get_center_crop_sides(vid, output_size):
+ """Slice parameters for center crop"""
+ h, w = vid.shape[-2:]
+ th, tw = output_size
+
+ i = int(round((h - th) / 2.0))
+ j = int(round((w - tw) / 2.0))
+ return i, j, th, tw
+
+ def forward(self, item):
+ # (Tv, C, H, W)
+ vid = item["video"]
+ if self.is_random:
+ i, j, h, w = self.get_random_crop_sides(vid, self.input_size)
+ else:
+ i, j, h, w = self.get_center_crop_sides(vid, self.input_size)
+ item["video"] = vid[..., i : (i + h), j : (j + w)]
+ return item
+
+
+class Resize(torchvision.transforms.Resize):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, item):
+ item["video"] = super().forward(item["video"])
+ return item
+
+
+class RGBSpatialCropSometimesUpscale(torch.nn.Module):
+ """This (randomly) crops the input video and with prob `sometimes_p` this crop is smaller but upscaled
+ to `target_input_size`"""
+
+ def __init__(self, sometimes_p, target_input_size, is_random, smaller_input_size=None):
+ super().__init__()
+ self.sometimes_p = sometimes_p
+ self.do_sometimes_upscale = sometimes_p is not None and sometimes_p > 0
+
+ self.crop_only = RGBSpatialCrop(target_input_size, is_random)
+
+ if self.do_sometimes_upscale:
+ self.crop_further_and_upscale = torchvision.transforms.Compose(
+ [
+ RGBSpatialCrop(smaller_input_size, is_random),
+ Resize(target_input_size, antialias=None),
+ ]
+ )
+
+ def forward(self, item):
+ assert len(item["video"].shape) == 4, (
+ f"{item['video'].shape}: if it is applied after GenerateMultipleClips,"
+ "augs should be applied to each clip separately, not to the whole video array. "
+ "Otherwise, ignore this warning (comment it)."
+ )
+ if self.do_sometimes_upscale and self.sometimes_p > torch.rand(1):
+ return self.crop_further_and_upscale(item)
+ else:
+ return self.crop_only(item)
+
+
+class RandomApplyColorDistortion(torch.nn.Module):
+
+ def __init__(self, p_gray_scale=0.0, p_color_jitter=0.0, s=1.0) -> None:
+ super().__init__()
+ self.p_gray_scale = p_gray_scale
+ self.p_color_jitter = p_color_jitter
+ self.s = s
+ assert 0 <= self.p_color_jitter <= 1 and 0 <= self.p_gray_scale <= 1, (p_color_jitter, p_gray_scale)
+ # SimCLR params
+ color_jitter = torchvision.transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
+ rand_color_jitter = torchvision.transforms.RandomApply([color_jitter], p_color_jitter)
+ rand_gray = torchvision.transforms.RandomGrayscale(p_gray_scale)
+ self.transforms = torchvision.transforms.Compose([rand_color_jitter, rand_gray])
+
+ def apply_to_single_clip(self, clip):
+ return self.transforms(clip)
+
+ def apply_to_each_clip(self, clips):
+ for i, clip in enumerate(clips):
+ clips[i] = self.apply_to_single_clip(clip)
+ return clips
+
+ def forward(self, item):
+ has_batch_dim = len(item["video"].shape) == 5
+ if has_batch_dim:
+ fn = self.apply_to_each_clip
+ else:
+ fn = self.apply_to_single_clip
+ item["video"] = fn(item["video"])
+ return item
+
+
+class ApplyColorJitterFrameWise(torch.nn.Module):
+
+ def __init__(self, s=1.0) -> None:
+ super().__init__()
+ self.s = s
+ # SimCLR params
+ self.transform = torchvision.transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
+
+ def apply_to_single_clip(self, clip):
+ for i, frame in enumerate(clip):
+ clip[i] = self.transform(frame)
+ return clip
+
+ def apply_to_each_clip(self, clips):
+ for i, clip in enumerate(clips):
+ clips[i] = self.apply_to_single_clip(clip)
+ return clips
+
+ def forward(self, item):
+ has_batch_dim = len(item["video"].shape) == 5
+ if has_batch_dim:
+ fn = self.apply_to_each_clip
+ else:
+ fn = self.apply_to_single_clip
+ item["video"] = fn(item["video"])
+ return item
+
+
+class RandomHorizontalFlip(torchvision.transforms.RandomHorizontalFlip):
+
+ def __init__(self, p=0.5):
+ super().__init__(p)
+
+ def apply_to_single_clip(self, clip):
+ return super().forward(clip)
+
+ def apply_to_each_clip(self, clips):
+ for i, clip in enumerate(clips):
+ clips[i] = self.apply_to_single_clip(clip)
+ return clips
+
+ def forward(self, item):
+ has_batch_dim = len(item["video"].shape) == 5
+ if has_batch_dim:
+ fn = self.apply_to_each_clip
+ else:
+ fn = self.apply_to_single_clip
+ item["video"] = fn(item["video"])
+ return item
+
+
+def make_class_grid(
+ leftmost_val,
+ rightmost_val,
+ grid_size,
+ add_extreme_offset: bool = False,
+ seg_size_vframes: int = None,
+ nseg: int = None,
+ step_size_seg: float = None,
+ vfps: float = None,
+):
+ assert grid_size >= 3, f"grid_size: {grid_size} does not make sense. If =2 -> (-1,1); =1 -> (-1); =0 -> ()"
+ grid = torch.from_numpy(np.linspace(leftmost_val, rightmost_val, grid_size)).float()
+ if add_extreme_offset:
+ assert all([seg_size_vframes, nseg, step_size_seg]), f"{seg_size_vframes} {nseg} {step_size_seg}"
+ seg_size_sec = seg_size_vframes / vfps
+ trim_size_in_seg = nseg - (1 - step_size_seg) * (nseg - 1)
+ extreme_value = trim_size_in_seg * seg_size_sec
+ grid = torch.cat([grid, torch.tensor([extreme_value])]) # adding extreme offset to the class grid
+ return grid
+
+
+def quantize_offset(grid: torch.Tensor, off_sec: float) -> Tuple[float, int]:
+ """Takes in the offset in seconds and snaps it onto the closest grid element.
+ Returns the grid value and its index."""
+ closest_grid_el = (grid - off_sec).abs().argmin()
+ return grid[closest_grid_el], closest_grid_el
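A tiny worked example of the two helpers above, using the 21-class, +-2 s grid configured elsewhere in this diff (the import path is the one added by this change):

```python
from hunyuanvideo_foley.models.synchformer.data_transforms import make_class_grid, quantize_offset

grid = make_class_grid(-2.0, 2.0, 21)     # 21 values spaced 0.2 s apart: -2.0, -1.8, ..., 2.0
label, idx = quantize_offset(grid, 0.27)  # snap 0.27 s onto the grid
print(f"{label.item():.1f}", idx.item())  # 0.2 11 -- the closest grid element and its index
```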
+
+
+def apply_a_jitter(a_start_i, a_len_frames, a_crop_len_frames, a_fps, max_a_jitter_sec):
+ max_a_start_i = a_len_frames - a_crop_len_frames
+ max_a_jitter_i = sec2frames(max_a_jitter_sec, a_fps)
+ max_a_jitter_i_left = min(a_start_i, max_a_jitter_i)
+ max_a_jitter_i_right = min(max_a_start_i - a_start_i, max_a_jitter_i)
+ # jitter is U[left, right]
+ a_jitter_i = random.randint(-max_a_jitter_i_left, max_a_jitter_i_right)
+ # apply jitter
+ a_start_i = a_start_i + a_jitter_i
+ # making sure that any value from `a_start_i + U[left, right]` will be inside of [0, len-crop] region
+ assert 0 <= a_start_i <= max_a_start_i, f"{a_jitter_i} {max_a_jitter_i_left} {max_a_jitter_i_right} {max_a_start_i}"
+ return a_start_i, a_jitter_i
+
+
+class TemporalCropAndOffset(torch.nn.Module):
+
+ def __init__(
+ self,
+ crop_len_sec: float,
+ max_off_sec: float,
+ offset_type="grid",
+ do_offset: bool = True,
+ grid_size: int = None,
+ max_wiggle_sec: float = None,
+ add_doubt_cls: bool = False,
+ segment_size_vframes: int = None,
+ n_segments: int = None,
+ step_size_seg: float = None,
+ vfps: float = None,
+ prob_oos: float = None,
+ ):
+ super().__init__()
+ self.crop_len_sec = crop_len_sec
+ self.do_offset = do_offset
+ self.grid_size = grid_size
+ self.offset_type = offset_type
+ self.max_off_sec = max_off_sec
+ self.max_a_jitter_sec = max_wiggle_sec
+ if do_offset:
+ if offset_type == "grid":
+ self.class_grid = make_class_grid(
+ -max_off_sec,
+ max_off_sec,
+ grid_size,
+ add_doubt_cls,
+ segment_size_vframes,
+ n_segments,
+ step_size_seg,
+ vfps,
+ )
+ logging.info(f"Offsets class grid: {self.class_grid}")
+ if self.max_a_jitter_sec is not None:
+ assert (max_wiggle_sec - 1e-6) <= (
+ (self.class_grid[1] - self.class_grid[0]) / 2
+ ), f"{self.class_grid}"
+ elif offset_type == "uniform":
+ self.off_dist = torch.distributions.uniform.Uniform(-max_off_sec, max_off_sec)
+ logging.info(f"Offset uniform distribution: {self.off_dist}")
+ elif offset_type == "uniform_binary":
+ self.itu_t_range = (-0.125, 0.045)
+ self.prob_oos = prob_oos
+ self.ins_dist = torch.distributions.uniform.Uniform(self.itu_t_range[0], self.itu_t_range[1])
+ self.off_dist = torch.distributions.uniform.Uniform(-max_off_sec, max_off_sec)
+ else:
+ raise NotImplementedError(f"Unknown offset type: {offset_type}")
+
+ def forward(self, item):
+ vid = item["video"]
+ aud = item["audio"]
+ v_len_frames, C, H, W = vid.shape
+ a_len_frames = aud.shape[0]
+
+ v_fps = int(item["meta"]["video"]["fps"][0])
+ a_fps = int(item["meta"]["audio"]["framerate"][0])
+
+ v_crop_len_frames = sec2frames(self.crop_len_sec, v_fps)
+ a_crop_len_frames = sec2frames(self.crop_len_sec, a_fps)
+
+ if self.do_offset:
+ # trying to get the offset parameters (for instance during valid and test we have fixed offsets)
+ offset_sec = item["targets"].get("offset_sec", None)
+ v_start_i_sec = item["targets"].get("v_start_i_sec", None)
+ if "offset_target" in item["targets"]:
+ is_oos = item["targets"]["offset_target"].get("oos", None)
+ # train-time
+ if offset_sec is None and v_start_i_sec is None:
+ # aud starts `offset_sec` earlier than it should; aud has what will be shown after offset_sec
+ if self.offset_type == "grid":
+ offset_sec = random.choice(self.class_grid.tolist())
+ elif self.offset_type == "uniform":
+ offset_sec = self.off_dist.sample().item()
+ elif self.offset_type == "uniform_binary":
+ # in-sync: Uniform(-0.125, 0.045)
+ # out-of-sync: Uniform(-5.5, 5.5) and resampled until not in Uniform(-0.125, 0.045)
+ # first, we sample whether the offset is out-of-sync with probability prob_oos
+ is_oos = (torch.rand(1) < self.prob_oos).item()
+ if is_oos:
+ # second, we sample the offset itself (if in in-sync range, trying again)
+ offset_sec = self.off_dist.sample().item()
+ while self.itu_t_range[0] <= offset_sec <= self.itu_t_range[1]:
+ offset_sec = self.off_dist.sample().item()
+ else:
+ offset_sec = self.ins_dist.sample().item()
+ offset_sec = round(offset_sec, 2)
+ v_start_max_sec = frames2sec(v_len_frames - v_crop_len_frames, v_fps)
+ assert v_start_max_sec > 0, f'{v_len_frames} {v_crop_len_frames} {v_fps} @ {item["path"]}'
+ # `v_start_sec` IS NOT rounded to the fps grid
+ v_start_sec = random.uniform(max(0, -offset_sec), min(v_start_max_sec, v_start_max_sec - offset_sec))
+ assert 0 <= v_start_sec <= v_start_max_sec, f'{v_start_sec} {v_start_max_sec} {item["path"]}'
+ v_start_i = sec2frames(v_start_sec, v_fps)
+ # `v_start_i_sec` IS rounded to the fps grid
+ v_start_i_sec = frames2sec(v_start_i, v_fps)
+ else:
+ offset_sec = round(offset_sec, 2)
+ v_start_i = sec2frames(v_start_i_sec, v_fps)
+ v_end_i = v_start_i + v_crop_len_frames
+ # `a_start_i` depends on the rounded value `v_start_i_sec`, otherwise
+ # (v_start_sec) we have ±0.1 jittering
+ a_start_i = sec2frames(v_start_i_sec + offset_sec, a_fps)
+ else:
+ offset_sec = 0.0
+ is_random_crop = item["split"] == "train"
+ v_start_i, v_end_i = self.get_crop_idx(v_len_frames, v_crop_len_frames, is_random=is_random_crop)
+ v_start_i_sec = frames2sec(v_start_i, v_fps)
+ a_start_i = sec2frames(v_start_i_sec, a_fps)
+
+ # sometimes, due to rounding, e.g. v_start_sec = 1.505 snaps to v_start_i_sec = 1.48 on the 25 fps grid;
+ # given an offset of -1.5, a_start_i then becomes a small negative value (roughly a_fps * 1/v_fps * 0.5)
+ if a_start_i < 0:
+ how_much_out = a_start_i
+ logging.info(f'a_start_i is negative ({how_much_out}) at {item["path"]}')
+ if abs(how_much_out) <= a_fps / v_fps:
+ logging.info("fixing it")
+ a_start_i += abs(how_much_out)
+ else:
+ raise Exception(f'{how_much_out} {item["path"]}')
+
+ if self.max_a_jitter_sec is not None and self.max_a_jitter_sec > 0:
+ a_start_i, a_jitter_i = apply_a_jitter(
+ a_start_i, a_len_frames, a_crop_len_frames, a_fps, self.max_a_jitter_sec
+ )
+ item["meta"]["a_jitter_i"] = a_jitter_i
+
+ a_end_i = a_start_i + a_crop_len_frames
+
+ assert v_start_i < v_end_i and a_start_i < a_end_i
+ assert aud.shape[0] >= a_end_i, f'{aud.shape} {a_end_i} {item["path"]}'
+ assert vid.shape[0] >= v_end_i, f'{vid.shape} {v_end_i} {item["path"]}'
+
+ vid, aud = vid[v_start_i:v_end_i, :, :, :], aud[a_start_i:a_end_i]
+
+ item["video"] = vid
+ item["audio"] = aud
+
+ assert item["video"].shape[0] == v_fps * self.crop_len_sec, f'{item["video"].shape} {item["path"]}'
+ assert item["audio"].shape[0] == a_fps * self.crop_len_sec, f'{item["audio"].shape} {item["path"]}'
+
+ # caching parameters
+ if self.do_offset:
+ if self.offset_type == "grid":
+ offset_label, offset_target = quantize_offset(self.class_grid, offset_sec)
+ elif self.offset_type == "uniform":
+ offset_label, offset_target = offset_sec, offset_sec
+ elif self.offset_type == "uniform_binary":
+ offset_label, offset_target = offset_sec, {"oos": is_oos, "offset": offset_sec}
+ item["targets"]["offset_sec"] = offset_sec
+ item["targets"]["v_start_i_sec"] = v_start_i_sec
+ item["targets"]["offset_label"] = offset_label
+ # assert 'offset_target' not in item['targets'], f'{item["targets"]}. What passed it there?'
+ item["targets"]["offset_target"] = offset_target
+
+ return item
+
+ def get_crop_idx(self, len_frames: int, crop_len_frames: int, is_random=True):
+ if len_frames == crop_len_frames:
+ return 0, len_frames
+ if is_random:
+ left_i = random.randint(0, len_frames - crop_len_frames)
+ else:
+ left_i = int(round((len_frames - crop_len_frames) / 2.0))
+ return left_i, left_i + crop_len_frames
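To make the rounding edge case handled above concrete, here is the arithmetic for the example in the comment (illustrative numbers only; the two helpers are restated from the top of this file so the snippet runs standalone):

```python
def sec2frames(sec, fps): return int(sec * fps)   # as defined above
def frames2sec(frames, fps): return frames / fps  # as defined above

v_fps, a_fps = 25, 16000

v_start_sec = 1.505                           # sampled start, NOT on the fps grid
v_start_i = sec2frames(v_start_sec, v_fps)    # int(1.505 * 25) = 37 frames
v_start_i_sec = frames2sec(v_start_i, v_fps)  # 37 / 25 = 1.48 s, now ON the grid

offset_sec = -1.5
a_start_i = sec2frames(v_start_i_sec + offset_sec, a_fps)  # int(-0.02 * 16000) = -320
# |-320| <= a_fps / v_fps = 640, so forward() shifts it back to 0 instead of raising
```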
+
+
+class GenerateMultipleSegments(torch.nn.Module):
+ """
+ Given an item with video and audio, generates a batch of `n_segments` segments
+ of length `segment_size_vframes` (if `n_segments` is None, the maximum number of segments is made).
+ If `is_start_random` is True, the starting position of the 1st segment will be random but respecting
+ n_segments.
+ `audio_jitter_sec` is the amount of audio offset in seconds.
+ """
+
+ def __init__(
+ self,
+ segment_size_vframes: int,
+ n_segments: int = None,
+ is_start_random: bool = False,
+ audio_jitter_sec: float = 0.0,
+ step_size_seg: float = 1,
+ ):
+ super().__init__()
+ self.segment_size_vframes = segment_size_vframes
+ self.n_segments = n_segments
+ self.is_start_random = is_start_random
+ self.audio_jitter_sec = audio_jitter_sec
+ self.step_size_seg = step_size_seg
+ logging.info(f"Segment step size: {self.step_size_seg}")
+
+ def forward(self, item):
+ v_len_frames, C, H, W = item["video"].shape
+ a_len_frames = item["audio"].shape[0]
+
+ v_fps = int(item["meta"]["video"]["fps"][0])
+ a_fps = int(item["meta"]["audio"]["framerate"][0])
+
+ ## Determining the number of segments
+ # segment size
+ segment_size_vframes = self.segment_size_vframes
+ segment_size_aframes = sec2frames(frames2sec(self.segment_size_vframes, v_fps), a_fps)
+ # step size (stride)
+ stride_vframes = int(self.step_size_seg * segment_size_vframes)
+ stride_aframes = int(self.step_size_seg * segment_size_aframes)
+ # calculating the number of segments. (W - F + 2P) / S + 1
+ n_segments_max_v = math.floor((v_len_frames - segment_size_vframes) / stride_vframes) + 1
+ n_segments_max_a = math.floor((a_len_frames - segment_size_aframes) / stride_aframes) + 1
+ # making sure audio and video can accommodate the same number of segments
+ n_segments_max = min(n_segments_max_v, n_segments_max_a)
+ n_segments = n_segments_max if self.n_segments is None else self.n_segments
+
+ assert n_segments <= n_segments_max, (
+ f"cant make {n_segments} segs of len {self.segment_size_vframes} in a vid "
+ f'of len {v_len_frames} for {item["path"]}'
+ )
+
+ # (n_segments, 2) each
+ v_ranges, a_ranges = self.get_sequential_seg_ranges(
+ v_len_frames, a_len_frames, v_fps, a_fps, n_segments, segment_size_aframes
+ )
+
+ # segmenting original streams (n_segments, segment_size_frames, C, H, W)
+ item["video"] = torch.stack([item["video"][s:e] for s, e in v_ranges], dim=0)
+ item["audio"] = torch.stack([item["audio"][s:e] for s, e in a_ranges], dim=0)
+ return item
+
+ def get_sequential_seg_ranges(self, v_len_frames, a_len_frames, v_fps, a_fps, n_seg, seg_size_aframes):
+ # if is_start_random is True, the starting position of the 1st segment will
+ # be random but respecting n_segments like so: "-CCCCCCCC---" (maybe with fixed overlap),
+ # else the segments are taken from the middle of the video respecting n_segments: "--CCCCCCCC--"
+
+ seg_size_vframes = self.segment_size_vframes # for brevity
+
+ # calculating the step size in frames
+ step_size_vframes = int(self.step_size_seg * seg_size_vframes)
+ step_size_aframes = int(self.step_size_seg * seg_size_aframes)
+
+ # calculating the length of the sequence of segments (and in frames)
+ seg_seq_len = n_seg * self.step_size_seg + (1 - self.step_size_seg)
+ vframes_seg_seq_len = int(seg_seq_len * seg_size_vframes)
+ aframes_seg_seq_len = int(seg_seq_len * seg_size_aframes)
+
+ # doing temporal crop
+ max_v_start_i = v_len_frames - vframes_seg_seq_len
+ if self.is_start_random:
+ v_start_i = random.randint(0, max_v_start_i)
+ else:
+ v_start_i = max_v_start_i // 2
+ a_start_i = sec2frames(frames2sec(v_start_i, v_fps), a_fps) # vid frames -> seconds -> aud frames
+
+ # make segments starts
+ v_start_seg_i = torch.tensor([v_start_i + i * step_size_vframes for i in range(n_seg)]).int()
+ a_start_seg_i = torch.tensor([a_start_i + i * step_size_aframes for i in range(n_seg)]).int()
+
+ # apply jitter to audio
+ if self.audio_jitter_sec > 0:
+ jitter_aframes = sec2frames(self.audio_jitter_sec, a_fps)
+ # making sure after applying jitter, the audio is still within the audio boundaries
+ jitter_aframes = min(jitter_aframes, a_start_i, a_len_frames - a_start_i - aframes_seg_seq_len)
+ a_start_seg_i += random.randint(-jitter_aframes, jitter_aframes) # applying jitter to segments
+
+ # make segment ends
+ v_ends_seg_i = v_start_seg_i + seg_size_vframes
+ a_ends_seg_i = a_start_seg_i + seg_size_aframes # using the adjusted a_start_seg_i (with jitter)
+
+ # make ranges
+ v_ranges = torch.stack([v_start_seg_i, v_ends_seg_i], dim=1)
+ a_ranges = torch.stack([a_start_seg_i, a_ends_seg_i], dim=1)
+ assert (a_ranges >= 0).all() and (a_ranges <= a_len_frames).all(), f"{a_ranges} out of {a_len_frames}"
+ assert (v_ranges <= v_len_frames).all(), f"{v_ranges} out of {v_len_frames}"
+ return v_ranges, a_ranges
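As a quick check of the `(W - F) / S + 1` counting above, the values used by the test-time transforms (16-frame segments, step 0.5, 25 fps) yield exactly the 14 segments requested there:

```python
import math

v_len_frames = 120                 # a 4.8 s trim at 25 fps (the length 14 half-overlapping segments need)
seg_size_vframes = 16
step_size_seg = 0.5
stride_vframes = int(step_size_seg * seg_size_vframes)  # 8 frames

n_segments_max = math.floor((v_len_frames - seg_size_vframes) / stride_vframes) + 1
print(n_segments_max)              # 14
```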
+
+
+class TemporalCropAndOffsetForSyncabilityTraining(torch.nn.Module):
+
+ def __init__(
+ self,
+ max_off_sec: float,
+ do_offset: bool = True,
+ grid_size: int = None,
+ max_wiggle_sec: float = None,
+ segment_size_vframes: int = None,
+ n_segments: int = None,
+ step_size_seg: float = None,
+ vfps: float = None,
+ ):
+ super().__init__()
+ seg_size_sec = segment_size_vframes / vfps
+ trim_size_in_seg = n_segments - (1 - step_size_seg) * (n_segments - 1)
+ self.crop_len_sec = round(trim_size_in_seg * seg_size_sec, 2)
+ logging.info(f"Crop len: {self.crop_len_sec}")
+ self.do_offset = do_offset
+ self.grid_size = grid_size
+ self.max_off_sec = max_off_sec
+ self.max_a_jitter_sec = max_wiggle_sec
+ self.segment_size_vframes = segment_size_vframes
+ self.n_segments = n_segments
+ self.step_size_seg = step_size_seg
+ self.prob_syncable = 0.5
+ if do_offset:
+ self.class_grid = make_class_grid(-max_off_sec, max_off_sec, grid_size)
+ logging.info(f"Offset class grid: {self.class_grid}")
+ if self.max_a_jitter_sec is not None:
+ assert (max_wiggle_sec - 1e-6) <= ((self.class_grid[1] - self.class_grid[0]) / 2), f"{self.class_grid}"
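Plugging the test-time values (16-frame segments, 14 segments, step 0.5, 25 fps) into the crop-length computation above shows where the 4.8 s trim comes from:

```python
segment_size_vframes, n_segments, step_size_seg, vfps = 16, 14, 0.5, 25

seg_size_sec = segment_size_vframes / vfps                               # 0.64 s per segment
trim_size_in_seg = n_segments - (1 - step_size_seg) * (n_segments - 1)   # 7.5 effective segments
crop_len_sec = round(trim_size_in_seg * seg_size_sec, 2)                 # 7.5 * 0.64 = 4.8 s
print(crop_len_sec)  # 4.8
```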
+
+ def forward(self, item):
+ vid = item["video"]
+ aud = item["audio"]
+ v_len_frames, C, H, W = vid.shape
+ a_len_frames = aud.shape[0]
+
+ v_fps = int(item["meta"]["video"]["fps"][0])
+ a_fps = int(item["meta"]["audio"]["framerate"][0])
+
+ v_crop_len_frames = sec2frames(self.crop_len_sec, v_fps)
+ a_crop_len_frames = sec2frames(self.crop_len_sec, a_fps)
+
+ if self.do_offset:
+ # trying to get the offset parameters (for instance during valid and test we have fixed offsets)
+ offset_sec = item["targets"].get("offset_sec", None)
+ v_start_i_sec = item["targets"].get("v_start_i_sec", None)
+ # train-time
+ if offset_sec is None and v_start_i_sec is None:
+
+ # for the syncability training, we want to have a syncable or non-syncable offset with 50% prob
+ offset_is_syncable = random.random() < self.prob_syncable # 1=syncable, 0=non-syncable
+ if offset_is_syncable:
+ offset_sec = random.choice(self.class_grid.tolist())
+ else:
+ offset_sec = random.choice([-self.crop_len_sec, self.crop_len_sec]) # either - or + offset
+ # aud starts `offset_sec` earlier than it should; aud has what will be shown after offset_sec
+
+ offset_sec = round(offset_sec, 2)
+ v_start_max_sec = frames2sec(v_len_frames - v_crop_len_frames, v_fps)
+ assert v_start_max_sec > 0, f'{v_len_frames} {v_crop_len_frames} {v_fps} @ {item["path"]}'
+ # `v_start_sec` IS NOT rounded to the fps grid
+ v_start_sec = random.uniform(max(0, -offset_sec), min(v_start_max_sec, v_start_max_sec - offset_sec))
+ assert 0 <= v_start_sec <= v_start_max_sec, f'{v_start_sec} {v_start_max_sec} {item["path"]}'
+ v_start_i = sec2frames(v_start_sec, v_fps)
+ v_end_i = v_start_i + v_crop_len_frames
+ # `v_start_i_sec` IS rounded to the fps grid
+ v_start_i_sec = frames2sec(v_start_i, v_fps)
+ # `a_start_i` depends on the rounded value `v_start_i_sec`, otherwise
+ # (v_start_sec) we have ±0.1 jittering
+ a_start_i = sec2frames(v_start_i_sec + offset_sec, a_fps)
+ if self.max_a_jitter_sec is not None and self.max_a_jitter_sec > 0:
+ a_start_i, a_jitter_i = apply_a_jitter(
+ a_start_i, a_len_frames, a_crop_len_frames, a_fps, self.max_a_jitter_sec
+ )
+ item["meta"]["a_jitter_i"] = a_jitter_i
+ a_end_i = a_start_i + a_crop_len_frames
+ else:
+ offset_sec = round(offset_sec, 2)
+ v_start_i = sec2frames(v_start_i_sec, v_fps)
+ a_start_i = sec2frames(v_start_i_sec + offset_sec, a_fps)
+ v_end_i = v_start_i + v_crop_len_frames
+ a_end_i = a_start_i + a_crop_len_frames
+ else:
+ offset_sec = 0.0
+ is_random_crop = item["split"] == "train"
+ v_start_i, v_end_i = self.get_crop_idx(v_len_frames, v_crop_len_frames, is_random=is_random_crop)
+ v_start_i_sec = frames2sec(v_start_i, v_fps)
+ a_start_i = sec2frames(v_start_i_sec, a_fps)
+ if self.max_a_jitter_sec is not None and self.max_a_jitter_sec > 0:
+ a_start_i, a_jitter_i = apply_a_jitter(
+ a_start_i, a_len_frames, a_crop_len_frames, a_fps, self.max_a_jitter_sec
+ )
+ item["meta"]["a_jitter_i"] = a_jitter_i
+ a_end_i = a_start_i + a_crop_len_frames
+
+ # sometimes, due to rounding, e.g. v_start_sec = 1.505 snaps to v_start_i_sec = 1.48 on the 25 fps grid;
+ # given an offset of -1.5, a_start_i then becomes a small negative value (roughly a_fps * 1/v_fps * 0.5)
+ if a_start_i < 0:
+ how_much_out = a_start_i
+ logging.info(f'a_start_i is negative ({how_much_out}) at {item["path"]}')
+ if abs(how_much_out) <= a_fps / v_fps:
+ logging.info("fixing it")
+ a_start_i += abs(how_much_out)
+ a_end_i += abs(how_much_out)
+ else:
+ raise Exception(f'{how_much_out} {item["path"]}')
+
+ assert v_start_i < v_end_i and a_start_i < a_end_i
+ assert aud.shape[0] >= a_end_i, f'{aud.shape} {a_end_i} {item["path"]}'
+ assert vid.shape[0] >= v_end_i, f'{vid.shape} {v_end_i} {item["path"]}'
+
+ vid, aud = vid[v_start_i:v_end_i, :, :, :], aud[a_start_i:a_end_i]
+
+ item["video"] = vid
+ item["audio"] = aud
+
+ assert item["video"].shape[0] == int(v_fps * self.crop_len_sec), f'{item["video"].shape} {item["path"]}'
+ assert item["audio"].shape[0] == int(a_fps * self.crop_len_sec), f'{item["audio"].shape} {item["path"]}'
+
+ # caching parameters
+ if self.do_offset:
+ # NOTE: this is useless for the extreme offsetting
+ offset_label, offset_target = quantize_offset(self.class_grid, offset_sec)
+ item["targets"]["offset_sec"] = offset_sec
+ item["targets"]["offset_label"] = offset_label
+ # assert 'offset_target' not in item['targets'], f'{item["targets"]}. What passed it there?'
+ item["targets"]["offset_target"] = offset_target
+ item["targets"]["v_start_i_sec"] = v_start_i_sec
+ item["targets"]["sync_target"] = int(offset_is_syncable)
+
+ return item
+
+ def get_crop_idx(self, len_frames: int, crop_len_frames: int, is_random=True):
+ if len_frames == crop_len_frames:
+ return 0, len_frames
+ if is_random:
+ left_i = random.randint(0, len_frames - crop_len_frames)
+ else:
+ left_i = int(round((len_frames - crop_len_frames) / 2.0))
+ return left_i, left_i + crop_len_frames
+
+
+class RGBToFloatToZeroOne(torch.nn.Module):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def forward(self, item):
+ item["video"] = item["video"].to(torch.float32).div(255.0)
+ return item
+
+
+class RGBToHalfToZeroOne(torch.nn.Module):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def forward(self, item):
+ item["video"] = item["video"].half().div(255.0)
+ return item
+
+
+class RGBNormalize(torchvision.transforms.Normalize):
+ """The same as the torchvision`s but with different interface for the dict.
+ This should work for any shape (..., C, H, W)"""
+
+ def __init__(self, mean, std, inplace=False):
+ super().__init__(mean, std, inplace)
+ logging.info(f"RGBNormalize: mean={mean}, std={std}")
+
+ def forward(self, item):
+ item["video"] = super().forward(item["video"])
+ item["meta"]["video"]["norm_stats"] = {"mean": torch.as_tensor(self.mean), "std": torch.as_tensor(self.std)}
+ return item
+
+
+class AudioRandomVolume(torch.nn.Module):
+
+ def __init__(self, p: float, **kwargs):
+ super().__init__()
+ transform = torchaudio.transforms.Vol(**kwargs)
+ self.transform = torchvision.transforms.RandomApply([transform], p)
+
+ def apply_to_single_clip(self, clip):
+ return self.transform(clip)
+
+ def apply_to_each_clip(self, clips):
+ for i, clip in enumerate(clips):
+ clips[i] = self.apply_to_single_clip(clip)
+ return clips
+
+ def forward(self, item):
+ has_batch_dim = len(item["audio"].shape) == 2
+ if has_batch_dim:
+ fn = self.apply_to_each_clip
+ else:
+ fn = self.apply_to_single_clip
+ item["audio"] = fn(item["audio"])
+ return item
+
+
+class AudioRandomLowpassFilter(torch.nn.Module):
+
+ def __init__(self, p: float, cutoff_freq: float, Q: float = 0.707):
+ super().__init__()
+ self.p = p
+ self.cutoff_freq = cutoff_freq
+ self.Q = Q
+
+ def apply_to_single_clip(self, clip, sr):
+ if self.p > torch.rand(1):
+ return torchaudio.functional.lowpass_biquad(clip, sr, self.cutoff_freq, self.Q)
+ else:
+ return clip
+
+ def apply_to_each_clip(self, clips, sr):
+ for i, clip in enumerate(clips):
+ clips[i] = self.apply_to_single_clip(clip, sr)
+ return clips
+
+ def forward(self, item):
+ has_batch_dim = len(item["audio"].shape) == 2
+ sr = int(item["meta"]["audio"]["framerate"][0])
+ if has_batch_dim:
+ fn = self.apply_to_each_clip
+ else:
+ fn = self.apply_to_single_clip
+ item["audio"] = fn(item["audio"], sr)
+ return item
+
+
+class AudioRandomPitchShift(torch.nn.Module):
+
+ def __init__(self, p: float, shift: int) -> None:
+ super().__init__()
+ self.p = p
+ self.shift = shift
+
+ def apply_to_single_clip(self, wave, sr):
+ if self.p > torch.rand(1):
+ effects = [["pitch", f"{self.shift}"], ["rate", f"{sr}"]]
+ wave = wave.unsqueeze(0)
+ wave, _ = torchaudio.sox_effects.apply_effects_tensor(wave, sr, effects)
+ wave = wave.squeeze(0)
+ return wave
+
+ def apply_to_each_clip(self, waves, sr):
+ for i, wave in enumerate(waves):
+ waves[i] = self.apply_to_single_clip(wave, sr)
+ return waves
+
+ def forward(self, item):
+ has_batch_dim = len(item["audio"].shape) == 2
+ sr = int(item["meta"]["audio"]["framerate"][0])
+ if has_batch_dim:
+ fn = self.apply_to_each_clip
+ else:
+ fn = self.apply_to_single_clip
+ item["audio"] = fn(item["audio"], sr)
+ return item
+
+
+class AudioRandomReverb(torch.nn.Module):
+
+ def __init__(self, p: float) -> None:
+ super().__init__()
+ self.p = p
+ self.effects = [["reverb", "-w"]]
+
+ def apply_to_single_clip(self, wave, fps):
+ if self.p > torch.rand(1):
+ wave = wave.unsqueeze(0)
+ wave, _ = torchaudio.sox_effects.apply_effects_tensor(wave, fps, self.effects)
+ wave = wave.mean(dim=0)
+ return wave
+
+ def apply_to_each_clip(self, waves, fps):
+ for i, wave in enumerate(waves):
+ waves[i] = self.apply_to_single_clip(wave, fps)
+ return waves
+
+ def forward(self, item):
+ has_batch_dim = len(item["audio"].shape) == 2
+ sr = int(item["meta"]["audio"]["framerate"][0])
+ if has_batch_dim:
+ fn = self.apply_to_each_clip
+ else:
+ fn = self.apply_to_single_clip
+ item["audio"] = fn(item["audio"], sr)
+ return item
+
+
+class AudioRandomGaussNoise(torch.nn.Module):
+
+ def __init__(self, p: float, amplitude=0.01) -> None:
+ super().__init__()
+ self.p = p
+ self.amplitude = amplitude
+
+ def apply_to_single_clip(self, wave):
+ if self.p > torch.rand(1):
+ noise = torch.randn_like(wave, dtype=wave.dtype)
+ wave = wave + self.amplitude * noise
+ return wave
+
+ def apply_to_each_clip(self, waves):
+ for i, wave in enumerate(waves):
+ waves[i] = self.apply_to_single_clip(wave)
+ return waves
+
+ def forward(self, item):
+ has_batch_dim = len(item["audio"].shape) == 2
+ if has_batch_dim:
+ fn = self.apply_to_each_clip
+ else:
+ fn = self.apply_to_single_clip
+ item["audio"] = fn(item["audio"])
+ return item
+
+
+class AudioMelSpectrogram(torch.nn.Module):
+
+ def __init__(self, **kwargs):
+ super().__init__()
+ self.spec = torchaudio.transforms.MelSpectrogram(**kwargs)
+
+ def forward(self, item):
+ item["audio"] = self.spec(item["audio"]) # safe for batched input
+ return item
+
+
+class AudioLog(torch.nn.Module):
+
+ def __init__(self, eps=1e-6) -> None:
+ super().__init__()
+ self.eps = eps
+
+ def forward(self, item):
+ item["audio"] = torch.log(item["audio"] + self.eps)
+ return item
+
+
+class PadOrTruncate(torch.nn.Module):
+
+ def __init__(self, max_spec_t: int, pad_mode: str = "constant", pad_value: float = 0.0):
+ super().__init__()
+ self.max_spec_t = max_spec_t
+ self.pad_mode = pad_mode
+ self.pad_value = pad_value
+
+ def forward(self, item):
+ item["audio"] = self.pad_or_truncate(item["audio"])
+ return item
+
+ def pad_or_truncate(self, audio):
+ difference = self.max_spec_t - audio.shape[-1] # safe for batched input
+ # pad or truncate, depending on difference
+ if difference > 0:
+ # pad the last dim (time) -> (..., n_mels, 0+time+difference) # safe for batched input
+ pad_dims = (0, difference)
+ audio = torch.nn.functional.pad(audio, pad_dims, self.pad_mode, self.pad_value)
+ elif difference < 0:
+ logging.warning(f"Truncating spec ({audio.shape}) to max_spec_t ({self.max_spec_t}).")
+ audio = audio[..., : self.max_spec_t] # safe for batched input
+ return audio
+
+
+class AudioNormalizeAST(torch.nn.Module):
+ """Normalization is done with two specified mean and std (half)"""
+
+ def __init__(self, mean: float, std: float) -> None:
+ super().__init__()
+ self.mean = mean
+ self.std = std
+
+ def forward(self, item):
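+ # dividing by 2*std follows the AST feature extractor, which normalizes inputs
+ # to zero mean and std 0.5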
+ item["audio"] = (item["audio"] - self.mean) / (2 * self.std)
+ item["meta"]["audio"]["norm_stats"] = {"mean": self.mean, "std": self.std}
+ return item
+
+
+class PermuteStreams(torch.nn.Module):
+
+ def __init__(self, einops_order_audio: str, einops_order_rgb: str) -> None:
+ '''For example:
+ einops_order_audio: "S F T -> S T F"
+ einops_order_rgb: "S T C H W -> S C T H W"'''
+ super().__init__()
+ self.einops_order_audio = einops_order_audio
+ self.einops_order_rgb = einops_order_rgb
+
+ def forward(self, item):
+ if self.einops_order_audio is not None:
+ item["audio"] = einops.rearrange(item["audio"], self.einops_order_audio).contiguous()
+ if self.einops_order_rgb is not None:
+ item["video"] = einops.rearrange(item["video"], self.einops_order_rgb).contiguous()
+ return item
+
+
+class ResampleAudio(torch.nn.Module):
+
+ def __init__(self, new_fps: int):
+ super().__init__()
+ self.new_fps = new_fps
+
+ def forward(self, item):
+ orig_fps = int(item["meta"]["audio"]["framerate"][0])
+ item["meta"]["audio"]["orig_shape"] = item["audio"].shape
+ if orig_fps != self.new_fps:
+ item["audio"] = torchaudio.functional.resample(item["audio"], orig_fps, self.new_fps)
+ item["meta"]["audio"]["framerate"][0] = self.new_fps
+ return item
+
+
+class ResampleRGB(torch.nn.Module):
+
+ def __init__(self, new_fps: int) -> None:
+ super().__init__()
+ self.new_fps = new_fps
+
+ def forward(self, item):
+ orig_fps = float(item["meta"]["video"]["fps"][0])
+ item["meta"]["video"]["orig_shape"] = item["video"].shape
+ if orig_fps != self.new_fps:
+ duration_sec = item["video"].shape[0] / orig_fps
+ indices = torch.arange(0, orig_fps * duration_sec - 1e-9, orig_fps / self.new_fps)
+ # truncate to integer frame indices (floors, e.g. 25 fps -> 10 fps keeps frames 0, 2, 5, 7, 10, ...)
+ indices = indices.to(dtype=torch.long)
+ item["video"] = item["video"][indices]
+ item["meta"]["video"]["fps"][0] = self.new_fps
+ return item
+
+
+class ResizeAndLetterboxPad(torch.nn.Module):
+ """Adapted from WACV24 Amazon`s challenge"""
+
+ def __init__(self, new_h, new_w):
+ super().__init__()
+ self.new_h = new_h
+ self.new_w = new_w
+ self.aspect_ratio = new_w / new_h
+
+ def forward(self, item):
+ item["video"] = self.resize_and_pad(item["video"])
+ return item
+
+ def resize_and_pad(self, rgb: torch.Tensor):
+ _, _, height, width = rgb.shape
+ current_aspect_ratio = width / height
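+ # e.g. letterboxing 1280x720 to 224x224: aspect 1.78 > 1.0, so resize to
+ # 126x224 and zero-pad 49 rows at the top and 49 at the bottom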
+ if current_aspect_ratio > self.aspect_ratio:
+ scaled_height = round(self.new_w / current_aspect_ratio)
+ rgb = torchvision.transforms.functional.resize(rgb, (scaled_height, self.new_w), antialias=None)
+ top = (self.new_h - scaled_height) // 2
+ bottom = self.new_h - (scaled_height + top)
+ rgb = torch.nn.ConstantPad2d((0, 0, top, bottom), 0)(rgb)
+ elif current_aspect_ratio < self.aspect_ratio:
+ scaled_width = round(self.new_h * current_aspect_ratio)
+ rgb = torchvision.transforms.functional.resize(rgb, (self.new_h, scaled_width), antialias=None)
+ left = (self.new_w - scaled_width) // 2
+ right = self.new_w - (scaled_width + left)
+ rgb = torch.nn.ConstantPad2d((left, right, 0, 0), 0)(rgb)
+ return rgb
+
+
+class ResampleResizeLetterboxPad(torch.nn.Module):
+
+ def __init__(self, afps, vfps, new_h, new_w) -> None:
+ super().__init__()
+ self.transforms = torchvision.transforms.Compose(
+ [ResampleAudio(new_fps=afps), ResampleRGB(new_fps=vfps), ResizeAndLetterboxPad(new_h=new_h, new_w=new_w)]
+ )
+
+ def forward(self, x: dict) -> dict:
+ return self.transforms(x)
+
+
+class DoNothing(torch.nn.Module):
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__()
+
+ def forward(self, x: dict) -> dict:
+ return x
+
+
+if __name__ == "__main__":
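+ # smoke test: push dummy video/audio through the transforms and print shapes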
+ grid = make_class_grid(-1, 1, 21)
+ grid = make_class_grid(-2, 2, 41)
+ print("grid:", grid)
+ print("value quantization:", quantize_offset(grid, 0.06))
+ v_fps = 25.0
+ duration = 10.0
+
+ input = {
+ "video": torch.randint(0, 256, (int(duration * v_fps), 3, 720 // 2, 1280 // 2), dtype=torch.uint8),
+ "audio": torch.arange(221184 - 1).float(),
+ "targets": {},
+ "meta": {
+ "video": {"duration": [duration], "fps": [v_fps]},
+ "audio": {"duration": [duration], "framerate": [22050.0]},
+ "subtitles": {"duration": []},
+ "cc": {"duration": []},
+ },
+ "path": "/home/nvme/data/vggsound/video/-5cWCaoEDlE_261000_271000.mp4",
+ "split": "train",
+ }
+
+ print(input["audio"].shape, input["video"].shape)
+
+ fn = EqualifyFromRight(clip_max_len_sec=10)
+ input = fn(input)
+ print(input["audio"].shape, input["video"].shape)
+
+ fn = RGBSpatialCrop((224, 224), is_random=True)
+ # fn = RGBSpatialCrop((112, 112), is_random=True)
+ input = fn(input)
+ print(input["audio"].shape, input["video"].shape, input["meta"]["audio"])
+
+ fn = Resize((224, 224))
+ input = fn(input)
+ print(input["audio"].shape, input["video"].shape, input["meta"]["audio"])
+
+ fn = GenerateMultipleSegments(
+ segment_size_vframes=16, n_segments=14, is_start_random=False, audio_jitter_sec=0.05, step_size_seg=0.5
+ )
+ input = fn(input)
+ print(input["audio"].shape, input["video"].shape, input["meta"]["audio"])
+
+ fn = RandomApplyColorDistortion(p_gray_scale=0.5, p_color_jitter=0.5, s=1.0)
+ input = fn(input)
+ print(input["audio"].shape, input["video"].shape, input["meta"]["audio"])
+
+ fn = RGBToFloatToZeroOne()
+ input = fn(input)
+ print(input["audio"].shape, input["video"].shape, input["meta"]["audio"])
+ print(input["meta"])
+
+ fn = RGBNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ input = fn(input)
+ print(input["audio"].shape, input["video"].shape, input["meta"]["audio"])
+ print(input["video"].mean(dim=(0, 2, 3)))
+ print(input["meta"])
+
+ fn = AudioRandomReverb(p=1.0)
+ input = fn(input)
+
+ fn = AudioRandomVolume(p=1.0, gain=2.0, gain_type="amplitude")
+ input = fn(input)
+ print(input["audio"].shape, input["video"].shape, input["meta"]["audio"])
+
+ fn = AudioRandomPitchShift(p=1.0, shift=1000)
+ input = fn(input)
+ print(input["audio"].shape, input["video"].shape, input["meta"]["audio"])
+
+ fn = AudioRandomLowpassFilter(p=1.0, cutoff_freq=100)
+ input = fn(input)
+ print(input["audio"].shape, input["video"].shape, input["meta"]["audio"])
+
+ fn = AudioRandomGaussNoise(p=1.0, amplitude=0.01)
+ input = fn(input)
+ print(input["audio"].shape, input["video"].shape, input["meta"]["audio"])
+
+ fn = AudioLog()
+ input = fn(input)
+ print(input["audio"].shape, input["video"].shape, input["meta"]["audio"])
+
+ # audio only
+ input = {
+ "audio": torch.arange(221184).float(),
+ "meta": {
+ "video": {"duration": [10.0], "fps": [10.0]},
+ "audio": {"duration": [11.0], "framerate": [22050.0]},
+ "subtitles": {"duration": []},
+ "cc": {"duration": []},
+ },
+ "path": "/home/nvme/data/vggsound/video/-5cWCaoEDlE_261000_271000.mp4",
+ }
+
+ print(input["audio"].shape)
+
+ fn = AudioLog()
+ input = fn(input)
+ print(input["audio"].shape, input["meta"]["audio"])
+ print(input["meta"])
+ print(input["audio"].min(), input["audio"].max())
diff --git a/hunyuanvideo_foley/models/synchformer/divided_224_16x4.yaml b/hunyuanvideo_foley/models/synchformer/divided_224_16x4.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f9d20b76302a8af7928391643bd4b2d184e970aa
--- /dev/null
+++ b/hunyuanvideo_foley/models/synchformer/divided_224_16x4.yaml
@@ -0,0 +1,84 @@
+TRAIN:
+ ENABLE: True
+ DATASET: Ssv2
+ BATCH_SIZE: 32
+ EVAL_PERIOD: 5
+ CHECKPOINT_PERIOD: 5
+ AUTO_RESUME: True
+ CHECKPOINT_EPOCH_RESET: True
+ CHECKPOINT_FILE_PATH: /checkpoint/fmetze/neurips_sota/40944587/checkpoints/checkpoint_epoch_00035.pyth
+DATA:
+ NUM_FRAMES: 16
+ SAMPLING_RATE: 4
+ TRAIN_JITTER_SCALES: [256, 320]
+ TRAIN_CROP_SIZE: 224
+ TEST_CROP_SIZE: 224
+ INPUT_CHANNEL_NUM: [3]
+ MEAN: [0.5, 0.5, 0.5]
+ STD: [0.5, 0.5, 0.5]
+ PATH_TO_DATA_DIR: /private/home/mandelapatrick/slowfast/data/ssv2
+ PATH_PREFIX: /datasets01/SomethingV2/092720/20bn-something-something-v2-frames
+ INV_UNIFORM_SAMPLE: True
+ RANDOM_FLIP: False
+ REVERSE_INPUT_CHANNEL: True
+ USE_RAND_AUGMENT: True
+ RE_PROB: 0.0
+ USE_REPEATED_AUG: False
+ USE_RANDOM_RESIZE_CROPS: False
+ COLORJITTER: False
+ GRAYSCALE: False
+ GAUSSIAN: False
+SOLVER:
+ BASE_LR: 1e-4
+ LR_POLICY: steps_with_relative_lrs
+ LRS: [1, 0.1, 0.01]
+ STEPS: [0, 20, 30]
+ MAX_EPOCH: 35
+ MOMENTUM: 0.9
+ WEIGHT_DECAY: 5e-2
+ WARMUP_EPOCHS: 0.0
+ OPTIMIZING_METHOD: adamw
+ USE_MIXED_PRECISION: True
+ SMOOTHING: 0.2
+SLOWFAST:
+ ALPHA: 8
+VIT:
+ PATCH_SIZE: 16
+ PATCH_SIZE_TEMP: 2
+ CHANNELS: 3
+ EMBED_DIM: 768
+ DEPTH: 12
+ NUM_HEADS: 12
+ MLP_RATIO: 4
+ QKV_BIAS: True
+ VIDEO_INPUT: True
+ TEMPORAL_RESOLUTION: 8
+ USE_MLP: True
+ DROP: 0.0
+ POS_DROPOUT: 0.0
+ DROP_PATH: 0.2
+ IM_PRETRAINED: True
+ HEAD_DROPOUT: 0.0
+ HEAD_ACT: tanh
+ PRETRAINED_WEIGHTS: vit_1k
+ ATTN_LAYER: divided
+MODEL:
+ NUM_CLASSES: 174
+ ARCH: slow
+ MODEL_NAME: VisionTransformer
+ LOSS_FUNC: cross_entropy
+TEST:
+ ENABLE: True
+ DATASET: Ssv2
+ BATCH_SIZE: 64
+ NUM_ENSEMBLE_VIEWS: 1
+ NUM_SPATIAL_CROPS: 3
+DATA_LOADER:
+ NUM_WORKERS: 4
+ PIN_MEMORY: True
+NUM_GPUS: 8
+NUM_SHARDS: 4
+RNG_SEED: 0
+OUTPUT_DIR: .
+TENSORBOARD:
+ ENABLE: True
diff --git a/hunyuanvideo_foley/models/synchformer/modeling_ast.py b/hunyuanvideo_foley/models/synchformer/modeling_ast.py
new file mode 100644
index 0000000000000000000000000000000000000000..f456753ecfff180dd36a3d2ff3e50a47ab735d52
--- /dev/null
+++ b/hunyuanvideo_foley/models/synchformer/modeling_ast.py
@@ -0,0 +1,673 @@
+# coding=utf-8
+# Copyright 2022 MIT and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Modified by v-iashin to support token masking
+
+"""PyTorch Audio Spectrogram Transformer (AST) model."""
+
+import math
+from typing import Dict, List, Optional, Set, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
+from transformers.modeling_utils import PreTrainedModel
+from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from transformers.models.audio_spectrogram_transformer.modeling_audio_spectrogram_transformer import ASTConfig
+from transformers.utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+)
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "ASTConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "MIT/ast-finetuned-audioset-10-10-0.4593"
+_EXPECTED_OUTPUT_SHAPE = [1, 1214, 768]
+
+# Audio classification docstring
+_SEQ_CLASS_CHECKPOINT = "MIT/ast-finetuned-audioset-10-10-0.4593"
+_SEQ_CLASS_EXPECTED_OUTPUT = "'Speech'"
+_SEQ_CLASS_EXPECTED_LOSS = 0.17
+
+
+AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "MIT/ast-finetuned-audioset-10-10-0.4593",
+ # See all Audio Spectrogram Transformer models at https://huggingface.co/models?filter=ast
+]
+
+
+class ASTEmbeddings(nn.Module):
+ """
+ Construct the CLS token, position and patch embeddings.
+ """
+
+ def __init__(self, config: ASTConfig) -> None:
+ super().__init__()
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+ self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+ self.patch_embeddings = ASTPatchEmbeddings(config)
+
+ frequency_out_dimension, time_out_dimension = self.get_shape(config)
+ num_patches = frequency_out_dimension * time_out_dimension
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.config = config
+
+ def get_shape(self, config):
+ # see Karpathy's cs231n blog on how to calculate the output dimensions
+ # https://cs231n.github.io/convolutional-networks/#conv
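+ # e.g. with the default ASTConfig (num_mel_bins=128, patch_size=16, both strides=10,
+ # max_length=1024): (128-16)//10+1 = 12 and (1024-16)//10+1 = 101, i.e. 12*101 = 1212
+ # patches; +2 for the CLS and distillation tokens gives 1214 (cf. _EXPECTED_OUTPUT_SHAPE)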
+ frequency_out_dimension = (config.num_mel_bins - config.patch_size) // config.frequency_stride + 1
+ time_out_dimension = (config.max_length - config.patch_size) // config.time_stride + 1
+
+ return frequency_out_dimension, time_out_dimension
+
+ def forward(self, input_values: torch.Tensor) -> torch.Tensor:
+ batch_size = input_values.shape[0]
+ embeddings = self.patch_embeddings(input_values)
+
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+ distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
+ embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
+ embeddings = embeddings + self.position_embeddings
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
+
+class ASTPatchEmbeddings(nn.Module):
+ """
+ This class turns `input_values` into the initial `hidden_states` (patch embeddings) of shape `(batch_size,
+ seq_length, hidden_size)` to be consumed by a Transformer.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ patch_size = config.patch_size
+ frequency_stride = config.frequency_stride
+ time_stride = config.time_stride
+
+ self.projection = nn.Conv2d(
+ 1, config.hidden_size, kernel_size=(patch_size, patch_size), stride=(frequency_stride, time_stride)
+ )
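+ # with the default strides (10) smaller than the 16x16 kernel, patches overlap in
+ # frequency and time; this overlap is why token masking in ASTModel needs the nan trick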
+
+ def forward(self, input_values: torch.Tensor) -> torch.Tensor:
+ input_values = input_values.unsqueeze(1)
+ input_values = input_values.transpose(2, 3)
+ embeddings = self.projection(input_values).flatten(2).transpose(1, 2)
+ return embeddings
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->AST
+class ASTSelfAttention(nn.Module):
+ def __init__(self, config: ASTConfig) -> None:
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ tok_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+ # apply masking if provided, tok_mask is (BS, N): 1s - keep; attention_scores is (BS, H, N, N)
+ if tok_mask is not None:
+ BS, N = tok_mask.shape
+ attention_scores = attention_scores.masked_fill(tok_mask.view(BS, 1, 1, N) == 0, float("-inf"))
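+ # -inf scores get zero weight after softmax; a fully-masked row would produce NaNs,
+ # but tok_mask as built in ASTModel.forward never masks the CLS/DISTIL positions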
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->AST
+class ASTSelfOutput(nn.Module):
+ """
+ The residual connection is defined in ASTLayer instead of here (as is the case with other models), due to the
+ layernorm applied before each block.
+ """
+
+ def __init__(self, config: ASTConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->AST
+class ASTAttention(nn.Module):
+ def __init__(self, config: ASTConfig) -> None:
+ super().__init__()
+ self.attention = ASTSelfAttention(config)
+ self.output = ASTSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads: Set[int]) -> None:
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.attention.query = prune_linear_layer(self.attention.query, index)
+ self.attention.key = prune_linear_layer(self.attention.key, index)
+ self.attention.value = prune_linear_layer(self.attention.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ tok_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ self_outputs = self.attention(hidden_states, tok_mask, head_mask, output_attentions)
+
+ attention_output = self.output(self_outputs[0], hidden_states)
+
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->AST
+class ASTIntermediate(nn.Module):
+ def __init__(self, config: ASTConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->AST
+class ASTOutput(nn.Module):
+ def __init__(self, config: ASTConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ hidden_states = hidden_states + input_tensor
+
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST
+class ASTLayer(nn.Module):
+ """This corresponds to the Block class in the timm implementation."""
+
+ def __init__(self, config: ASTConfig) -> None:
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = ASTAttention(config)
+ self.intermediate = ASTIntermediate(config)
+ self.output = ASTOutput(config)
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ tok_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ self_attention_outputs = self.attention(
+ self.layernorm_before(hidden_states), # in AST, layernorm is applied before self-attention
+ tok_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ # first residual connection
+ hidden_states = attention_output + hidden_states
+
+ # in AST, layernorm is also applied after self-attention
+ layer_output = self.layernorm_after(hidden_states)
+ layer_output = self.intermediate(layer_output)
+
+ # second residual connection is done here
+ layer_output = self.output(layer_output, hidden_states)
+
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->AST
+class ASTEncoder(nn.Module):
+ def __init__(self, config: ASTConfig) -> None:
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([ASTLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ tok_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ) -> Union[tuple, BaseModelOutput]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ tok_mask,
+ layer_head_mask,
+ )
+ else:
+ layer_outputs = layer_module(hidden_states, tok_mask, layer_head_mask, output_attentions)
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class ASTPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = ASTConfig
+ base_model_prefix = "audio_spectrogram_transformer"
+ main_input_name = "input_values"
+ supports_gradient_checkpointing = True
+
+ # Copied from transformers.models.deit.modeling_deit.DeiTPreTrainedModel._init_weights
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+ # `trunc_normal_cpu` not implemented in `half` issues
+ module.weight.data = nn.init.trunc_normal_(
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+ ).to(module.weight.dtype)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ # Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._set_gradient_checkpointing with ViT->AST
+ def _set_gradient_checkpointing(self, module: ASTEncoder, value: bool = False) -> None:
+ if isinstance(module, ASTEncoder):
+ module.gradient_checkpointing = value
+
+
+AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`ASTConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING = r"""
+ Args:
+ input_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
+ [`ASTFeatureExtractor.__call__`] for details.
+
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare AST Model transformer outputting raw hidden-states without any specific head on top.",
+ AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING,
+)
+class ASTModel(ASTPreTrainedModel):
+ def __init__(self, config: ASTConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = ASTEmbeddings(config)
+ self.encoder = ASTEncoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> ASTPatchEmbeddings:
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPooling,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor] = None,
+ cont_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_values is None:
+ raise ValueError("You have to specify input_values")
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(input_values)
+
+ # transforms the mask that has spectrogram dims to the token masking which is obtained after patching.
+ # Due to the overlap in patching, getting the token mask from the spectrogram mask is not straightforward,
+ # because one 16x16 content patch is encoded in two tokens if stride is <16. So, to get the mask for
+ # tokens I will apply the patching func (self.embeddings) to the tensor with infinities at the masked
+ # content position. For infs, the patching fn will return nans, which I'll use to get the token mask.
+ if cont_mask is not None:
+ indicator = torch.ones_like(input_values).to(input_values.dtype)
+ # replace content mask (0s) with infs
+ indicator[~cont_mask] = torch.inf
+ # apply patching; now nans are where the content mask was
+ with torch.no_grad():
+ indicator = self.embeddings(indicator) # BS, N, D
+ # replace nans with 0s; these are the tokens that correspond to the masked content
+ tok_mask = ~torch.isnan(indicator)
+ # since all values in the D-dimension (latent) will also be nans, we can just use the first el
+ tok_mask = tok_mask[:, :, 0] # (BS, 2+num_patches) -- 2 is from CLS and DISTIL tokens
+ else:
+ tok_mask = None
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ tok_mask=tok_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output)
+
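+ # DeiT-style pooling: average the CLS and distillation token representations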
+ pooled_output = (sequence_output[:, 0] + sequence_output[:, 1]) / 2
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return (
+ BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ ),
+ tok_mask,
+ )
+
+
+class ASTMLPHead(nn.Module):
+ def __init__(self, config: ASTConfig):
+ super().__init__()
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dense = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+
+ def forward(self, hidden_state):
+ hidden_state = self.layernorm(hidden_state)
+ hidden_state = self.dense(hidden_state)
+ return hidden_state
+
+
+@add_start_docstrings(
+ """
+ Audio Spectrogram Transformer model with an audio classification head on top (a linear layer on top of the pooled
+ output) e.g. for datasets like AudioSet, Speech Commands v2.
+ """,
+ AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING,
+)
+class ASTForAudioClassification(ASTPreTrainedModel):
+ def __init__(self, config: ASTConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.audio_spectrogram_transformer = ASTModel(config)
+
+ # Classifier head
+ self.classifier = ASTMLPHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(AUDIO_SPECTROGRAM_TRANSFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_SEQ_CLASS_CHECKPOINT,
+ output_type=SequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+ expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
+ )
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor] = None,
+ cont_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, SequenceClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the audio classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.audio_spectrogram_transformer(
+ input_values,
+ cont_mask=cont_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ # the modified ASTModel returns (BaseModelOutputWithPooling, tok_mask) when
+ # return_dict=True, so unpack the model output before indexing into it
+ if return_dict:
+ outputs, _ = outputs
+ pooled_output = outputs[1]
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/hunyuanvideo_foley/models/synchformer/motionformer.py b/hunyuanvideo_foley/models/synchformer/motionformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9980a6f6d667699a275b10c6f613a30493566713
--- /dev/null
+++ b/hunyuanvideo_foley/models/synchformer/motionformer.py
@@ -0,0 +1,397 @@
+import logging
+from pathlib import Path
+
+import einops
+import torch
+from omegaconf import OmegaConf
+from timm.layers import trunc_normal_
+from torch import nn
+
+from .utils import check_if_file_exists_else_download
+from .video_model_builder import VisionTransformer
+
+
+FILE2URL = {
+ # cfg
+ "motionformer_224_16x4.yaml": "https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/motionformer_224_16x4.yaml",
+ "joint_224_16x4.yaml": "https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/joint_224_16x4.yaml",
+ "divided_224_16x4.yaml": "https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/divided_224_16x4.yaml",
+ # ckpt
+ "ssv2_motionformer_224_16x4.pyth": "https://dl.fbaipublicfiles.com/motionformer/ssv2_motionformer_224_16x4.pyth",
+ "ssv2_joint_224_16x4.pyth": "https://dl.fbaipublicfiles.com/motionformer/ssv2_joint_224_16x4.pyth",
+ "ssv2_divided_224_16x4.pyth": "https://dl.fbaipublicfiles.com/motionformer/ssv2_divided_224_16x4.pyth",
+}
+
+
+class MotionFormer(VisionTransformer):
+ """This class serves three puposes:
+ 1. Renames the class to MotionFormer.
+ 2. Downloads the cfg from the original repo and patches it if needed.
+ 3. Takes care of feature extraction by redefining .forward()
+ - if `extract_features=True` and `factorize_space_time=False`,
+ the output is of shape (B, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8
+ - if `extract_features=True` and `factorize_space_time=True`, the output is of shape (B*S, D)
+ and spatial and temporal transformer encoder layers are used.
+ - if `extract_features=True` and `factorize_space_time=True` as well as `add_global_repr=True`
+ the output is of shape (B, D) and spatial and temporal transformer encoder layers
+ are used as well as the global representation is extracted from segments (extra pos emb
+ is added).
+ """
+
+ def __init__(
+ self,
+ extract_features: bool = False,
+ ckpt_path: str = None,
+ factorize_space_time: bool = None,
+ agg_space_module: str = None,
+ agg_time_module: str = None,
+ add_global_repr: bool = True,
+ agg_segments_module: str = None,
+ max_segments: int = None,
+ ):
+ self.extract_features = extract_features
+ self.ckpt_path = ckpt_path
+ self.factorize_space_time = factorize_space_time
+
+ if self.ckpt_path is not None:
+ check_if_file_exists_else_download(self.ckpt_path, FILE2URL)
+ ckpt = torch.load(self.ckpt_path, map_location="cpu")
+ mformer_ckpt2cfg = {
+ "ssv2_motionformer_224_16x4.pyth": "motionformer_224_16x4.yaml",
+ "ssv2_joint_224_16x4.pyth": "joint_224_16x4.yaml",
+ "ssv2_divided_224_16x4.pyth": "divided_224_16x4.yaml",
+ }
+ # init from motionformer ckpt or from our Stage I ckpt
+ # depending on whether the feat extractor was pre-trained on AVCLIPMoCo or not, we need to
+ # load the state dict differently
+ was_pt_on_avclip = self.ckpt_path.endswith(".pt") # checks if it is a stage I ckpt (FIXME: a bit generic)
+ if self.ckpt_path.endswith(tuple(mformer_ckpt2cfg.keys())):
+ cfg_fname = mformer_ckpt2cfg[Path(self.ckpt_path).name]
+ elif was_pt_on_avclip:
+ # TODO: this is a hack, we should be able to get the cfg from the ckpt (earlier ckpt didn't have it)
+ s1_cfg = ckpt.get("args", None) # Stage I cfg
+ if s1_cfg is not None:
+ s1_vfeat_extractor_ckpt_path = s1_cfg.model.params.vfeat_extractor.params.ckpt_path
+ # if the stage I ckpt was initialized from a motionformer ckpt or train from scratch
+ if s1_vfeat_extractor_ckpt_path is not None:
+ cfg_fname = mformer_ckpt2cfg[Path(s1_vfeat_extractor_ckpt_path).name]
+ else:
+ cfg_fname = "divided_224_16x4.yaml"
+ else:
+ cfg_fname = "divided_224_16x4.yaml"
+ else:
+ raise ValueError(f"ckpt_path {self.ckpt_path} is not supported.")
+ else:
+ was_pt_on_avclip = False
+ cfg_fname = "divided_224_16x4.yaml"
+ # logging.info(f'No ckpt_path provided, using {cfg_fname} config.')
+
+ if cfg_fname in ["motionformer_224_16x4.yaml", "divided_224_16x4.yaml"]:
+ pos_emb_type = "separate"
+ elif cfg_fname == "joint_224_16x4.yaml":
+ pos_emb_type = "joint"
+
+ self.mformer_cfg_path = Path(__file__).absolute().parent / cfg_fname
+
+ check_if_file_exists_else_download(self.mformer_cfg_path, FILE2URL)
+ mformer_cfg = OmegaConf.load(self.mformer_cfg_path)
+ logging.info(f"Loading MotionFormer config from {self.mformer_cfg_path.absolute()}")
+
+ # patch the cfg (from the default cfg defined in the repo `Motionformer/slowfast/config/defaults.py`)
+ mformer_cfg.VIT.ATTN_DROPOUT = 0.0
+ mformer_cfg.VIT.POS_EMBED = pos_emb_type
+ mformer_cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE = True
+ mformer_cfg.VIT.APPROX_ATTN_TYPE = "none" # guessing
+ mformer_cfg.VIT.APPROX_ATTN_DIM = 64 # from ckpt['cfg']
+
+ # finally init VisionTransformer with the cfg
+ super().__init__(mformer_cfg)
+
+ # load the ckpt now if ckpt is provided and not from AVCLIPMoCo-pretrained ckpt
+ if (self.ckpt_path is not None) and (not was_pt_on_avclip):
+ _ckpt_load_status = self.load_state_dict(ckpt["model_state"], strict=False)
+ if len(_ckpt_load_status.missing_keys) > 0 or len(_ckpt_load_status.unexpected_keys) > 0:
+ logging.warning(
+ f"Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed."
+ f"Missing keys: {_ckpt_load_status.missing_keys}, "
+ f"Unexpected keys: {_ckpt_load_status.unexpected_keys}"
+ )
+ else:
+ logging.info(f"Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.")
+
+ if self.extract_features:
+ assert isinstance(self.norm, nn.LayerNorm), "early x[:, 1:, :] may not be safe for per-tr weights"
+ # pre-logits are Sequential(nn.Linear(emb, emd), act) and `act` is tanh but see the logger
+ self.pre_logits = nn.Identity()
+ # we don't need the classification head (saving memory)
+ self.head = nn.Identity()
+ self.head_drop = nn.Identity()
+ # avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer)
+ transf_enc_layer_kwargs = dict(
+ d_model=self.embed_dim,
+ nhead=self.num_heads,
+ activation=nn.GELU(),
+ batch_first=True,
+ dim_feedforward=self.mlp_ratio * self.embed_dim,
+ dropout=self.drop_rate,
+ layer_norm_eps=1e-6,
+ norm_first=True,
+ )
+ # define adapters if needed
+ if self.factorize_space_time:
+ if agg_space_module == "TransformerEncoderLayer":
+ self.spatial_attn_agg = SpatialTransformerEncoderLayer(**transf_enc_layer_kwargs)
+ elif agg_space_module == "AveragePooling":
+ self.spatial_attn_agg = AveragePooling(
+ avg_pattern="BS D t h w -> BS D t", then_permute_pattern="BS D t -> BS t D"
+ )
+ if agg_time_module == "TransformerEncoderLayer":
+ self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs)
+ elif agg_time_module == "AveragePooling":
+ self.temp_attn_agg = AveragePooling(avg_pattern="BS t D -> BS D")
+ elif "Identity" in agg_time_module:
+ self.temp_attn_agg = nn.Identity()
+ # define a global aggregation layer (aggregate over segments)
+ self.add_global_repr = add_global_repr
+ if add_global_repr:
+ if agg_segments_module == "TransformerEncoderLayer":
+ # we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D)
+ # we need to add pos emb (PE) because previously we added the same PE for each segment
+ pos_max_len = max_segments if max_segments is not None else 16 # 16 = 10sec//0.64sec + 1
+ self.global_attn_agg = TemporalTransformerEncoderLayer(
+ add_pos_emb=True,
+ pos_emb_drop=mformer_cfg.VIT.POS_DROPOUT,
+ pos_max_len=pos_max_len,
+ **transf_enc_layer_kwargs,
+ )
+ elif agg_segments_module == "AveragePooling":
+ self.global_attn_agg = AveragePooling(avg_pattern="B S D -> B D")
+
+ if was_pt_on_avclip:
+ # we need to filter out the state_dict of the AVCLIP model (has both A and V extractors)
+ # and keep only the state_dict of the feat extractor
+ ckpt_weights = dict()
+ for k, v in ckpt["state_dict"].items():
+ if k.startswith(("module.v_encoder.", "v_encoder.")):
+ k = k.replace("module.", "").replace("v_encoder.", "")
+ ckpt_weights[k] = v
+ _load_status = self.load_state_dict(ckpt_weights, strict=False)
+ if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0:
+ logging.warning(
+ f"Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed. \n"
+ f"Missing keys ({len(_load_status.missing_keys)}): "
+ f"{_load_status.missing_keys}, \n"
+ f"Unexpected keys ({len(_load_status.unexpected_keys)}): "
+ f"{_load_status.unexpected_keys} \n"
+ f"temp_attn_agg are expected to be missing if ckpt was pt contrastively."
+ )
+ else:
+ logging.info(f"Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.")
+
+ # patch_embed is not used in MotionFormer (only patch_embed_3d is, because cfg.VIT.PATCH_SIZE_TEMP > 1),
+ # but it is used to calculate the number of patches, so we need to keep it (frozen)
+ self.patch_embed.requires_grad_(False)
+
+ def forward(self, x):
+ """
+ x is of shape (B, S, C, T, H, W) where S is the number of segments.
+ """
+ # Batch, Segments, Channels, T=frames, Height, Width
+ B, S, C, T, H, W = x.shape
+ # Motionformer expects a tensor of shape (1, B, C, T, H, W); the leading dummy
+ # dimension mimics the list-of-inputs convention of the original SlowFast code
+ # and is otherwise unused: see `video_model_builder.video_input`.
+ # x = x.unsqueeze(0) # (1, B, S, C, T, H, W)
+
+ orig_shape = (B, S, C, T, H, W)
+ x = x.view(B * S, C, T, H, W) # flatten batch and segments
+ x = self.forward_segments(x, orig_shape=orig_shape)
+ # x is (B*S, D), or (B*S, t, D) if `self.temp_attn_agg` is `Identity`;
+ # unpack the segments (using the rest dims to support both shapes)
+ x = x.view(B, S, *x.shape[1:])
+
+ return x # x is (B, S, ...)
+
+ def forward_segments(self, x, orig_shape: tuple) -> torch.Tensor:
+ """x is of shape (1, BS, C, T, H, W) where S is the number of segments."""
+ x, x_mask = self.forward_features(x)
+
+ assert self.extract_features
+
+ # (BS, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8
+ x = x[:, 1:, :] # without the CLS token for efficiency (should be safe for LayerNorm and FC)
+ x = self.norm(x)
+ x = self.pre_logits(x)
+ if self.factorize_space_time:
+ x = self.restore_spatio_temp_dims(x, orig_shape) # (B*S, D, t, h, w) <- (B*S, t*h*w, D)
+
+ x = self.spatial_attn_agg(x, x_mask) # (B*S, t, D)
+ x = self.temp_attn_agg(x) # (B*S, D) or (BS, t, D) if `self.temp_attn_agg` is `Identity`
+
+ return x
+
+ def restore_spatio_temp_dims(self, feats: torch.Tensor, orig_shape: tuple) -> torch.Tensor:
+ """
+ feats are of shape (B*S, T, D) where T = (224 // 16) * (224 // 16) * 8 (CLS token already dropped).
+ Our goal is to make them of shape (B*S, D, t, h, w) where t is the temporal and h, w the spatial dims.
+ From `self.patch_embed_3d`, it follows that we can reshape feats with:
+ `feats.transpose(1, 2).view(B*S, D, t, h, w)`
+ """
+ B, S, C, T, H, W = orig_shape
+ D = self.embed_dim
+
+ # num patches in each dimension
+ t = T // self.patch_embed_3d.z_block_size
+ h = self.patch_embed_3d.height
+ w = self.patch_embed_3d.width
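+ # e.g. 16 frames with temporal patch size 2 give t=8; 224x224 frames with
+ # 16x16 patches give h=w=14, so t*h*w = 8*14*14 = 1568 tokens per segment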
+
+ feats = feats.permute(0, 2, 1) # (B*S, D, T)
+ feats = feats.view(B * S, D, t, h, w) # (B*S, D, t, h, w)
+
+ return feats
+
+
+class BaseEncoderLayer(nn.TransformerEncoderLayer):
+ """
+ This is a wrapper around nn.TransformerEncoderLayer that adds a CLS token
+ to the sequence and outputs the CLS token's representation.
+ This base class parents both SpatialEncoderLayer and TemporalEncoderLayer for the RGB stream
+ and the FrequencyEncoderLayer and TemporalEncoderLayer for the audio stream.
+ We also, optionally, add a positional embedding to the input sequence, which
+ allows reusing it for global aggregation (of segments) in both streams.
+ """
+
+ def __init__(
+ self,
+ add_pos_emb: bool = False,
+ pos_emb_drop: float = None,
+ pos_max_len: int = None,
+ *args_transformer_enc,
+ **kwargs_transformer_enc,
+ ):
+ super().__init__(*args_transformer_enc, **kwargs_transformer_enc)
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.self_attn.embed_dim))
+ trunc_normal_(self.cls_token, std=0.02)
+
+ # add positional embedding
+ self.add_pos_emb = add_pos_emb
+ if add_pos_emb:
+ self.pos_max_len = 1 + pos_max_len # +1 (for CLS)
+ self.pos_emb = nn.Parameter(torch.zeros(1, self.pos_max_len, self.self_attn.embed_dim))
+ self.pos_drop = nn.Dropout(pos_emb_drop)
+ trunc_normal_(self.pos_emb, std=0.02)
+
+ self.apply(self._init_weights)
+
+ def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None):
+ """x is of shape (B, N, D); if provided x_mask is of shape (B, N)"""
+ batch_dim = x.shape[0]
+
+ # add CLS token
+ cls_tokens = self.cls_token.expand(batch_dim, -1, -1) # expanding to match batch dimension
+ x = torch.cat((cls_tokens, x), dim=-2) # (batch_dim, 1+seq_len, D)
+ if x_mask is not None:
+ cls_mask = torch.ones((batch_dim, 1), dtype=torch.bool, device=x_mask.device) # 1=keep; 0=mask
+ x_mask_w_cls = torch.cat((cls_mask, x_mask), dim=-1) # (batch_dim, 1+seq_len)
+ B, N = x_mask_w_cls.shape
+ # torch expects an (N, N) or (B*num_heads, N, N) mask where True means "masked" (sadness ahead)
+ x_mask_w_cls = (
+ x_mask_w_cls.reshape(B, 1, 1, N)
+ .expand(-1, self.self_attn.num_heads, N, -1)
+ .reshape(B * self.self_attn.num_heads, N, N)
+ )
+ assert x_mask_w_cls.dtype == x_mask_w_cls.bool().dtype, "x_mask_w_cls.dtype != bool"
+ x_mask_w_cls = ~x_mask_w_cls # invert mask (1=mask)
+ else:
+ x_mask_w_cls = None
+
+ # add positional embedding
+ if self.add_pos_emb:
+ seq_len = x.shape[1] # (don't even think about moving it before the CLS token concatenation)
+ assert seq_len <= self.pos_max_len, f"Seq len ({seq_len}) > pos_max_len ({self.pos_max_len})"
+ x = x + self.pos_emb[:, :seq_len, :]
+ x = self.pos_drop(x)
+
+ # apply encoder layer (calls nn.TransformerEncoderLayer.forward);
+ x = super().forward(src=x, src_mask=x_mask_w_cls) # (batch_dim, 1+seq_len, D)
+
+ # CLS token is expected to hold spatial information for each frame
+ x = x[:, 0, :] # (batch_dim, D)
+
+ return x
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {"cls_token", "pos_emb"}
+
+
+class SpatialTransformerEncoderLayer(BaseEncoderLayer):
+ """Aggregates spatial dimensions by applying attention individually to each frame."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
+ """x is of shape (B*S, D, t, h, w) where S is the number of segments.
+ if specified x_mask (B*S, t, h, w), 0=masked, 1=kept
+ Returns a tensor of shape (B*S, t, D) pooling spatial information for each frame."""
+ BS, D, t, h, w = x.shape
+
+ # time as a batch dimension and flatten spatial dimensions as sequence
+ x = einops.rearrange(x, "BS D t h w -> (BS t) (h w) D")
+ # similar to mask
+ if x_mask is not None:
+ x_mask = einops.rearrange(x_mask, "BS t h w -> (BS t) (h w)")
+
+ # apply encoder layer (BaseEncoderLayer.forward) - it will add CLS token and output its representation
+ x = super().forward(x=x, x_mask=x_mask) # (B*S*t, D)
+
+ # reshape back to (B*S, t, D)
+ x = einops.rearrange(x, "(BS t) D -> BS t D", BS=BS, t=t)
+
+ # (B*S, t, D)
+ return x
+
+
+class TemporalTransformerEncoderLayer(BaseEncoderLayer):
+ """Aggregates temporal dimension with attention. Also used with pos emb as global aggregation
+ in both streams."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, x):
+ """x is of shape (B*S, t, D) where S is the number of segments.
+ Returns a tensor of shape (B*S, D) pooling temporal information."""
+ BS, t, D = x.shape
+
+ # apply encoder layer (BaseEncoderLayer.forward) - it will add CLS token and output its representation
+ x = super().forward(x) # (B*S, D)
+
+ return x # (B*S, D)
+
+
+class AveragePooling(nn.Module):
+
+ def __init__(self, avg_pattern: str, then_permute_pattern: str = None) -> None:
+ """patterns are e.g. "bs t d -> bs d" """
+ super().__init__()
+ # TODO: need to register them as buffers (but fails because these are strings)
+ self.reduce_fn = "mean"
+ self.avg_pattern = avg_pattern
+ self.then_permute_pattern = then_permute_pattern
+
+ def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
+ x = einops.reduce(x, self.avg_pattern, self.reduce_fn)
+ if self.then_permute_pattern is not None:
+ x = einops.rearrange(x, self.then_permute_pattern)
+ return x
diff --git a/hunyuanvideo_foley/models/synchformer/synchformer.py b/hunyuanvideo_foley/models/synchformer/synchformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1238bbc68dc451a56121ba7ab1fc00aa290420c3
--- /dev/null
+++ b/hunyuanvideo_foley/models/synchformer/synchformer.py
@@ -0,0 +1,355 @@
+import logging
+import math
+from typing import Any, Mapping
+
+import einops
+import numpy as np
+import torch
+import torchaudio
+from torch import nn
+from torch.nn import functional as F
+
+from .motionformer import MotionFormer
+from .ast_model import AST
+from .utils import Config
+
+
+class Synchformer(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+
+ self.vfeat_extractor = MotionFormer(
+ extract_features=True,
+ factorize_space_time=True,
+ agg_space_module="TransformerEncoderLayer",
+ agg_time_module="torch.nn.Identity",
+ add_global_repr=False,
+ )
+ self.afeat_extractor = AST(
+ extract_features=True,
+ max_spec_t=66,
+ factorize_freq_time=True,
+ agg_freq_module="TransformerEncoderLayer",
+ agg_time_module="torch.nn.Identity",
+ add_global_repr=False,
+ )
+
+ # project the extracted features into the transformer dim (768 -> 768 here;
+ # kept from SparseSync, where it bridged the 1024-dim S3D latent)
+ self.vproj = nn.Linear(in_features=768, out_features=768)
+ self.aproj = nn.Linear(in_features=768, out_features=768)
+ self.transformer = GlobalTransformer(
+ tok_pdrop=0.0, embd_pdrop=0.1, resid_pdrop=0.1, attn_pdrop=0.1, n_layer=3, n_head=8, n_embd=768
+ )
+
+ def forward(self, vis):
+ B, S, Tv, C, H, W = vis.shape
+ vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W)
+ # feat extractors return a tuple of segment-level and global features (ignored for sync)
+ # (B, S, tv, D), e.g. (B, 7, 8, 768)
+ vis = self.vfeat_extractor(vis)
+ return vis
+
+ def compare_v_a(self, vis: torch.Tensor, aud: torch.Tensor):
+ vis = self.vproj(vis)
+ aud = self.aproj(aud)
+
+ B, S, tv, D = vis.shape
+ B, S, ta, D = aud.shape
+ vis = vis.view(B, S * tv, D) # (B, S*tv, D)
+ aud = aud.view(B, S * ta, D) # (B, S*ta, D)
+ # print(vis.shape, aud.shape)
+
+ # self.transformer will concatenate the vis and aud in one sequence with aux tokens,
+ # ie `CvvvvMaaaaaa`, and will return the logits for the CLS tokens
+ logits = self.transformer(vis, aud) # (B, cls); or (B, cls) and (B, 2) if DoubtingTransformer
+
+ return logits
+
+ def extract_vfeats(self, vis):
+ B, S, Tv, C, H, W = vis.shape
+ vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W)
+ # feat extractors return a tuple of segment-level and global features (ignored for sync)
+ # (B, S, tv, D), e.g. (B, 7, 8, 768)
+ vis = self.vfeat_extractor(vis)
+ return vis
+
+ def extract_afeats(self, aud):
+ B, S, _, Fa, Ta = aud.shape
+ aud = aud.view(B, S, Fa, Ta).permute(0, 1, 3, 2) # (B, S, Ta, F)
+ # (B, S, ta, D), e.g. (B, 7, 6, 768)
+ aud, _ = self.afeat_extractor(aud)
+ return aud
+
+ def compute_loss(self, logits, targets, loss_fn: str = None):
+ loss = None
+ if targets is not None:
+ if loss_fn is None or loss_fn == "cross_entropy":
+ # logits: (B, cls) and targets: (B,)
+ loss = F.cross_entropy(logits, targets)
+ else:
+ raise NotImplementedError(f"Loss {loss_fn} not implemented")
+ return loss
+
+ def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True):
+ # discard all entries except vfeat_extractor
+ # sd = {k: v for k, v in sd.items() if k.startswith('vfeat_extractor')}
+
+ return super().load_state_dict(sd, strict)
+
+
+class RandInitPositionalEncoding(nn.Module):
+ """Random inited trainable pos embedding. It is just applied on the sequence, thus respects no priors."""
+
+ def __init__(self, block_shape: list, n_embd: int):
+ super().__init__()
+ self.block_shape = block_shape
+ self.n_embd = n_embd
+ self.pos_emb = nn.Parameter(torch.randn(1, *block_shape, n_embd))
+
+ def forward(self, token_embeddings):
+ return token_embeddings + self.pos_emb
+
+
+class GlobalTransformer(torch.nn.Module):
+ """Same as in SparseSync but without the selector transformers and the head"""
+
+ def __init__(
+ self,
+ tok_pdrop=0.0,
+ embd_pdrop=0.1,
+ resid_pdrop=0.1,
+ attn_pdrop=0.1,
+ n_layer=3,
+ n_head=8,
+ n_embd=768,
+ pos_emb_block_shape=[
+ 198,
+ ],
+ n_off_head_out=21,
+ ) -> None:
+ super().__init__()
+ self.config = Config(
+ embd_pdrop=embd_pdrop,
+ resid_pdrop=resid_pdrop,
+ attn_pdrop=attn_pdrop,
+ n_layer=n_layer,
+ n_head=n_head,
+ n_embd=n_embd,
+ )
+ # input norm
+ self.vis_in_lnorm = torch.nn.LayerNorm(n_embd)
+ self.aud_in_lnorm = torch.nn.LayerNorm(n_embd)
+ # aux tokens
+ self.OFF_tok = torch.nn.Parameter(torch.randn(1, 1, n_embd))
+ self.MOD_tok = torch.nn.Parameter(torch.randn(1, 1, n_embd))
+ # whole token dropout
+ self.tok_pdrop = tok_pdrop
+ self.tok_drop_vis = torch.nn.Dropout1d(tok_pdrop)
+ self.tok_drop_aud = torch.nn.Dropout1d(tok_pdrop)
+ # maybe add pos emb
+ self.pos_emb_cfg = RandInitPositionalEncoding(
+ block_shape=pos_emb_block_shape,
+ n_embd=n_embd,
+ )
+ # the stem
+ self.drop = torch.nn.Dropout(embd_pdrop)
+ self.blocks = torch.nn.Sequential(*[Block(self.config) for _ in range(n_layer)])
+ # pre-output norm
+ self.ln_f = torch.nn.LayerNorm(n_embd)
+ # maybe add a head
+ self.off_head = torch.nn.Linear(in_features=n_embd, out_features=n_off_head_out)
+
+ def forward(self, v: torch.Tensor, a: torch.Tensor, targets=None, attempt_to_apply_heads=True):
+ B, Sv, D = v.shape
+ B, Sa, D = a.shape
+ # broadcasting special tokens to the batch size
+ off_tok = einops.repeat(self.OFF_tok, "1 1 d -> b 1 d", b=B)
+ mod_tok = einops.repeat(self.MOD_tok, "1 1 d -> b 1 d", b=B)
+ # norm
+ v, a = self.vis_in_lnorm(v), self.aud_in_lnorm(a)
+ # maybe whole token dropout
+ if self.tok_pdrop > 0:
+ v, a = self.tok_drop_vis(v), self.tok_drop_aud(a)
+ # (B, 1+Sv+1+Sa, D)
+ x = torch.cat((off_tok, v, mod_tok, a), dim=1)
+ # maybe add pos emb
+ if hasattr(self, "pos_emb_cfg"):
+ x = self.pos_emb_cfg(x)
+ # dropout -> stem -> norm
+ x = self.drop(x)
+ x = self.blocks(x)
+ x = self.ln_f(x)
+ # maybe add heads
+ if attempt_to_apply_heads and hasattr(self, "off_head"):
+ x = self.off_head(x[:, 0, :])
+ return x
+
+
+class SelfAttention(nn.Module):
+ """
+ A vanilla multi-head masked self-attention layer with a projection at the end.
+ It is possible to use torch.nn.MultiheadAttention here but I am including an
+ explicit implementation here to show that there is nothing too scary here.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ assert config.n_embd % config.n_head == 0
+ # key, query, value projections for all heads
+ self.key = nn.Linear(config.n_embd, config.n_embd)
+ self.query = nn.Linear(config.n_embd, config.n_embd)
+ self.value = nn.Linear(config.n_embd, config.n_embd)
+ # regularization
+ self.attn_drop = nn.Dropout(config.attn_pdrop)
+ self.resid_drop = nn.Dropout(config.resid_pdrop)
+ # output projection
+ self.proj = nn.Linear(config.n_embd, config.n_embd)
+ # # causal mask to ensure that attention is only applied to the left in the input sequence
+ # mask = torch.tril(torch.ones(config.block_size,
+ # config.block_size))
+ # if hasattr(config, "n_unmasked"):
+ # mask[:config.n_unmasked, :config.n_unmasked] = 1
+ # self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
+ self.n_head = config.n_head
+
+ def forward(self, x):
+ B, T, C = x.size()
+
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+ k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+ q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+ v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+
+ # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+ # att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
+ att = F.softmax(att, dim=-1)
+ y = self.attn_drop(att) @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+
+ # output projection
+ y = self.resid_drop(self.proj(y))
+
+ return y
+
+
+class Block(nn.Module):
+ """an unassuming Transformer block"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.ln1 = nn.LayerNorm(config.n_embd)
+ self.ln2 = nn.LayerNorm(config.n_embd)
+ self.attn = SelfAttention(config)
+ self.mlp = nn.Sequential(
+ nn.Linear(config.n_embd, 4 * config.n_embd),
+ nn.GELU(), # nice
+ nn.Linear(4 * config.n_embd, config.n_embd),
+ nn.Dropout(config.resid_pdrop),
+ )
+
+ def forward(self, x):
+ x = x + self.attn(self.ln1(x))
+ x = x + self.mlp(self.ln2(x))
+ return x
+
+
+def make_class_grid(
+ leftmost_val,
+ rightmost_val,
+ grid_size,
+ add_extreme_offset: bool = False,
+ seg_size_vframes: int = None,
+ nseg: int = None,
+ step_size_seg: float = None,
+ vfps: float = None,
+):
+    assert grid_size >= 3, f"grid_size: {grid_size} does not make sense. If =2 -> (-1,1); =1 -> (-1); =0 -> ()"
+ grid = torch.from_numpy(np.linspace(leftmost_val, rightmost_val, grid_size)).float()
+ if add_extreme_offset:
+        assert all([seg_size_vframes, nseg, step_size_seg, vfps]), f"{seg_size_vframes} {nseg} {step_size_seg} {vfps}"
+ seg_size_sec = seg_size_vframes / vfps
+ trim_size_in_seg = nseg - (1 - step_size_seg) * (nseg - 1)
+ extreme_value = trim_size_in_seg * seg_size_sec
+ grid = torch.cat([grid, torch.tensor([extreme_value])]) # adding extreme offset to the class grid
+ return grid
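+
+
+# A minimal usage sketch (values assumed, not taken from a config):
+#
+#   grid = make_class_grid(-2.0, 2.0, 21)
+#   # -> tensor([-2.0, -1.8, ..., 1.8, 2.0]), i.e. 21 offsets spaced 0.2 s
+#   # apart, one per output unit of the 21-way offset head (n_off_head_out=21)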
+
+
+# from synchformer
+def pad_or_truncate(audio: torch.Tensor, max_spec_t: int, pad_mode: str = "constant", pad_value: float = 0.0):
+ difference = max_spec_t - audio.shape[-1] # safe for batched input
+ # pad or truncate, depending on difference
+ if difference > 0:
+ # pad the last dim (time) -> (..., n_mels, 0+time+difference) # safe for batched input
+ pad_dims = (0, difference)
+ audio = torch.nn.functional.pad(audio, pad_dims, pad_mode, pad_value)
+ elif difference < 0:
+ print(f"Truncating spec ({audio.shape}) to max_spec_t ({max_spec_t}).")
+ audio = audio[..., :max_spec_t] # safe for batched input
+ return audio
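+
+
+# Usage sketch (shapes are illustrative): for a batch of log-mel spectrograms
+# shaped (B, n_mels, T), the last (time) dim is right-padded or trimmed to 66:
+#
+#   spec = torch.zeros(2, 128, 50)
+#   spec = pad_or_truncate(spec, max_spec_t=66)  # -> (2, 128, 66)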
+
+
+def encode_audio_with_sync(
+ synchformer: Synchformer, x: torch.Tensor, mel: torchaudio.transforms.MelSpectrogram
+) -> torch.Tensor:
+ b, t = x.shape
+
+    # partition the waveform into overlapping segments
+ segment_size = 10240
+ step_size = 10240 // 2
+ num_segments = (t - segment_size) // step_size + 1
+ segments = []
+ for i in range(num_segments):
+ segments.append(x[:, i * step_size : i * step_size + segment_size])
+    x = torch.stack(segments, dim=1)  # (B, S, segment_size)
+
+ x = mel(x)
+ x = torch.log(x + 1e-6)
+ x = pad_or_truncate(x, 66)
+
+ mean = -4.2677393
+ std = 4.5689974
+ x = (x - mean) / (2 * std)
+ # x: B * S * 128 * 66
+ x = synchformer.extract_afeats(x.unsqueeze(2))
+ return x
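+
+
+# Worked numbers for the segmentation above (16 kHz audio assumed): each
+# segment is 10240 samples = 0.64 s with a 0.32 s hop, so a 4 s clip
+# (t = 64000) yields (64000 - 10240) // 5120 + 1 = 11 overlapping segments;
+# each segment's mel spectrogram (hop 160) has 10240 // 160 + 1 = 65 frames,
+# padded to max_spec_t = 66.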
+
+
+def read_audio(filename, expected_length=int(16000 * 4)):
+ waveform, sr = torchaudio.load(filename)
+ waveform = waveform.mean(dim=0)
+
+ if sr != 16000:
+ resampler = torchaudio.transforms.Resample(sr, 16000)
+        waveform = resampler(waveform)
+
+ waveform = waveform[:expected_length]
+ if waveform.shape[0] != expected_length:
+ raise ValueError(f"Audio {filename} is too short")
+
+ waveform = waveform.squeeze()
+
+ return waveform
+
+
+if __name__ == "__main__":
+ synchformer = Synchformer().cuda().eval()
+
+ # mmaudio provided synchformer ckpt
+ synchformer.load_state_dict(
+ torch.load(
+            os.environ.get("SYNCHFORMER_WEIGHTS", "weights/synchformer.pth"),
+ weights_only=True,
+ map_location="cpu",
+ )
+ )
+
+ sync_mel_spectrogram = torchaudio.transforms.MelSpectrogram(
+ sample_rate=16000,
+ win_length=400,
+ hop_length=160,
+ n_fft=1024,
+ n_mels=128,
+ )
diff --git a/hunyuanvideo_foley/models/synchformer/utils.py b/hunyuanvideo_foley/models/synchformer/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..05595cc15b925f52ccd07fea8f131ec810f56bd7
--- /dev/null
+++ b/hunyuanvideo_foley/models/synchformer/utils.py
@@ -0,0 +1,87 @@
+from hashlib import md5
+from pathlib import Path
+import subprocess
+
+import requests
+from tqdm import tqdm
+
+PARENT_LINK = "https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a"
+FNAME2LINK = {
+ # S3: Synchability: AudioSet (run 2)
+ "24-01-22T20-34-52.pt": f"{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/24-01-22T20-34-52.pt",
+ "cfg-24-01-22T20-34-52.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/cfg-24-01-22T20-34-52.yaml",
+ # S2: Synchformer: AudioSet (run 2)
+ "24-01-04T16-39-21.pt": f"{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/24-01-04T16-39-21.pt",
+ "cfg-24-01-04T16-39-21.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml",
+ # S2: Synchformer: AudioSet (run 1)
+ "23-08-28T11-23-23.pt": f"{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/23-08-28T11-23-23.pt",
+ "cfg-23-08-28T11-23-23.yaml": f"{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/cfg-23-08-28T11-23-23.yaml",
+ # S2: Synchformer: LRS3 (run 2)
+ "23-12-23T18-33-57.pt": f"{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/23-12-23T18-33-57.pt",
+ "cfg-23-12-23T18-33-57.yaml": f"{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/cfg-23-12-23T18-33-57.yaml",
+ # S2: Synchformer: VGS (run 2)
+ "24-01-02T10-00-53.pt": f"{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/24-01-02T10-00-53.pt",
+ "cfg-24-01-02T10-00-53.yaml": f"{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/cfg-24-01-02T10-00-53.yaml",
+ # SparseSync: ft VGGSound-Full
+ "22-09-21T21-00-52.pt": f"{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/22-09-21T21-00-52.pt",
+ "cfg-22-09-21T21-00-52.yaml": f"{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/cfg-22-09-21T21-00-52.yaml",
+ # SparseSync: ft VGGSound-Sparse
+ "22-07-28T15-49-45.pt": f"{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/22-07-28T15-49-45.pt",
+ "cfg-22-07-28T15-49-45.yaml": f"{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/cfg-22-07-28T15-49-45.yaml",
+ # SparseSync: only pt on LRS3
+ "22-07-13T22-25-49.pt": f"{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/22-07-13T22-25-49.pt",
+ "cfg-22-07-13T22-25-49.yaml": f"{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/cfg-22-07-13T22-25-49.yaml",
+ # SparseSync: feature extractors
+ "ResNetAudio-22-08-04T09-51-04.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-04T09-51-04.pt", # 2s
+ "ResNetAudio-22-08-03T23-14-49.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-49.pt", # 3s
+ "ResNetAudio-22-08-03T23-14-28.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-28.pt", # 4s
+ "ResNetAudio-22-06-24T08-10-33.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T08-10-33.pt", # 5s
+ "ResNetAudio-22-06-24T17-31-07.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T17-31-07.pt", # 6s
+ "ResNetAudio-22-06-24T23-57-11.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-24T23-57-11.pt", # 7s
+ "ResNetAudio-22-06-25T04-35-42.pt": f"{PARENT_LINK}/sync/ResNetAudio-22-06-25T04-35-42.pt", # 8s
+}
+
+
+def check_if_file_exists_else_download(path, fname2link=FNAME2LINK, chunk_size=1024):
+ """Checks if file exists, if not downloads it from the link to the path"""
+ path = Path(path)
+ if not path.exists():
+ path.parent.mkdir(exist_ok=True, parents=True)
+ link = fname2link.get(path.name, None)
+ if link is None:
+            raise ValueError(
+                f"Can't find the checkpoint file: {path}. Please download it manually and ensure the path exists."
+            )
+        with requests.get(link, stream=True) as r:
+            total_size = int(r.headers.get("content-length", 0))
+            with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
+                with open(path, "wb") as f:
+                    for data in r.iter_content(chunk_size=chunk_size):
+                        if data:
+                            f.write(data)
+                            pbar.update(len(data))
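+
+
+# Usage sketch: fetch a checkpoint listed in FNAME2LINK on first use, e.g.
+#
+#   check_if_file_exists_else_download("weights/24-01-04T16-39-21.pt")
+#   # downloads the Synchformer AudioSet checkpoint into weights/ if absent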
+
+
+def which_ffmpeg() -> str:
+    """Determines the path to the ffmpeg binary
+    Returns:
+        str -- path to the binary
+    """
+    result = subprocess.run(["which", "ffmpeg"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    ffmpeg_path = result.stdout.decode("utf-8").strip()
+    return ffmpeg_path
+
+
+def get_md5sum(path):
+ hash_md5 = md5()
+ with open(path, "rb") as f:
+ for chunk in iter(lambda: f.read(4096 * 8), b""):
+ hash_md5.update(chunk)
+ md5sum = hash_md5.hexdigest()
+ return md5sum
+
+
+class Config:
+ def __init__(self, **kwargs):
+ for k, v in kwargs.items():
+ setattr(self, k, v)
diff --git a/hunyuanvideo_foley/models/synchformer/video_model_builder.py b/hunyuanvideo_foley/models/synchformer/video_model_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..190df1a5f066c2c06ab41178fc1174c7956bc599
--- /dev/null
+++ b/hunyuanvideo_foley/models/synchformer/video_model_builder.py
@@ -0,0 +1,270 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# Copyright 2020 Ross Wightman
+# Modified Model definition
+
+from collections import OrderedDict
+from functools import partial
+
+import torch
+import torch.nn as nn
+from timm.layers import trunc_normal_
+
+from .vit_helper import PatchEmbed, PatchEmbed3D, DividedSpaceTimeBlock
+
+
+class VisionTransformer(nn.Module):
+ """Vision Transformer with support for patch or hybrid CNN input stage"""
+
+ def __init__(self, cfg):
+ super().__init__()
+ self.img_size = cfg.DATA.TRAIN_CROP_SIZE
+ self.patch_size = cfg.VIT.PATCH_SIZE
+ self.in_chans = cfg.VIT.CHANNELS
+ if cfg.TRAIN.DATASET == "Epickitchens":
+ self.num_classes = [97, 300]
+ else:
+ self.num_classes = cfg.MODEL.NUM_CLASSES
+ self.embed_dim = cfg.VIT.EMBED_DIM
+ self.depth = cfg.VIT.DEPTH
+ self.num_heads = cfg.VIT.NUM_HEADS
+ self.mlp_ratio = cfg.VIT.MLP_RATIO
+ self.qkv_bias = cfg.VIT.QKV_BIAS
+ self.drop_rate = cfg.VIT.DROP
+ self.drop_path_rate = cfg.VIT.DROP_PATH
+ self.head_dropout = cfg.VIT.HEAD_DROPOUT
+ self.video_input = cfg.VIT.VIDEO_INPUT
+ self.temporal_resolution = cfg.VIT.TEMPORAL_RESOLUTION
+ self.use_mlp = cfg.VIT.USE_MLP
+ self.num_features = self.embed_dim
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+ self.attn_drop_rate = cfg.VIT.ATTN_DROPOUT
+ self.head_act = cfg.VIT.HEAD_ACT
+ self.cfg = cfg
+
+ # Patch Embedding
+ self.patch_embed = PatchEmbed(
+ img_size=224, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim
+ )
+
+ # 3D Patch Embedding
+ self.patch_embed_3d = PatchEmbed3D(
+ img_size=self.img_size,
+ temporal_resolution=self.temporal_resolution,
+ patch_size=self.patch_size,
+ in_chans=self.in_chans,
+ embed_dim=self.embed_dim,
+ z_block_size=self.cfg.VIT.PATCH_SIZE_TEMP,
+ )
+ self.patch_embed_3d.proj.weight.data = torch.zeros_like(self.patch_embed_3d.proj.weight.data)
+
+ # Number of patches
+ if self.video_input:
+ num_patches = self.patch_embed.num_patches * self.temporal_resolution
+ else:
+ num_patches = self.patch_embed.num_patches
+ self.num_patches = num_patches
+
+ # CLS token
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+ trunc_normal_(self.cls_token, std=0.02)
+
+ # Positional embedding
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim))
+ self.pos_drop = nn.Dropout(p=cfg.VIT.POS_DROPOUT)
+ trunc_normal_(self.pos_embed, std=0.02)
+
+ if self.cfg.VIT.POS_EMBED == "joint":
+ self.st_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
+ trunc_normal_(self.st_embed, std=0.02)
+ elif self.cfg.VIT.POS_EMBED == "separate":
+ self.temp_embed = nn.Parameter(torch.zeros(1, self.temporal_resolution, self.embed_dim))
+
+ # Layer Blocks
+ dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)]
+ if self.cfg.VIT.ATTN_LAYER == "divided":
+ self.blocks = nn.ModuleList(
+ [
+ DividedSpaceTimeBlock(
+ attn_type=cfg.VIT.ATTN_LAYER,
+ dim=self.embed_dim,
+ num_heads=self.num_heads,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=self.qkv_bias,
+ drop=self.drop_rate,
+ attn_drop=self.attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ )
+ for i in range(self.depth)
+ ]
+ )
+
+ self.norm = norm_layer(self.embed_dim)
+
+ # MLP head
+ if self.use_mlp:
+ hidden_dim = self.embed_dim
+ if self.head_act == "tanh":
+ # logging.info("Using TanH activation in MLP")
+ act = nn.Tanh()
+ elif self.head_act == "gelu":
+ # logging.info("Using GELU activation in MLP")
+ act = nn.GELU()
+ else:
+ # logging.info("Using ReLU activation in MLP")
+ act = nn.ReLU()
+ self.pre_logits = nn.Sequential(
+ OrderedDict(
+ [
+ ("fc", nn.Linear(self.embed_dim, hidden_dim)),
+ ("act", act),
+ ]
+ )
+ )
+ else:
+ self.pre_logits = nn.Identity()
+
+ # Classifier Head
+ self.head_drop = nn.Dropout(p=self.head_dropout)
+ if isinstance(self.num_classes, (list,)) and len(self.num_classes) > 1:
+            for i in range(len(self.num_classes)):
+                setattr(self, "head%d" % i, nn.Linear(self.embed_dim, self.num_classes[i]))
+ else:
+ self.head = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
+
+ # Initialize weights
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ if self.cfg.VIT.POS_EMBED == "joint":
+ return {"pos_embed", "cls_token", "st_embed"}
+ else:
+ return {"pos_embed", "cls_token", "temp_embed"}
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=""):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ # if self.video_input:
+ # x = x[0]
+ B = x.shape[0]
+
+ # Tokenize input
+ # if self.cfg.VIT.PATCH_SIZE_TEMP > 1:
+ # for simplicity of mapping between content dimensions (input x) and token dims (after patching)
+ # we use the same trick as for AST (see modeling_ast.ASTModel.forward for the details):
+
+ # apply patching on input
+ x = self.patch_embed_3d(x)
+ tok_mask = None
+
+ # else:
+ # tok_mask = None
+ # # 2D tokenization
+ # if self.video_input:
+ # x = x.permute(0, 2, 1, 3, 4)
+ # (B, T, C, H, W) = x.shape
+ # x = x.reshape(B * T, C, H, W)
+
+ # x = self.patch_embed(x)
+
+ # if self.video_input:
+ # (B2, T2, D2) = x.shape
+ # x = x.reshape(B, T * T2, D2)
+
+ # Append CLS token
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+ # if tok_mask is not None:
+ # # prepend 1(=keep) to the mask to account for the CLS token as well
+ # tok_mask = torch.cat((torch.ones_like(tok_mask[:, [0]]), tok_mask), dim=1)
+
+        # Interpolate positional embeddings
+ # if self.cfg.DATA.TRAIN_CROP_SIZE != 224:
+ # pos_embed = self.pos_embed
+ # N = pos_embed.shape[1] - 1
+ # npatch = int((x.size(1) - 1) / self.temporal_resolution)
+ # class_emb = pos_embed[:, 0]
+ # pos_embed = pos_embed[:, 1:]
+ # dim = x.shape[-1]
+ # pos_embed = torch.nn.functional.interpolate(
+ # pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+ # scale_factor=math.sqrt(npatch / N),
+ # mode='bicubic',
+ # )
+ # pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ # new_pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
+ # else:
+ new_pos_embed = self.pos_embed
+ npatch = self.patch_embed.num_patches
+
+ # Add positional embeddings to input
+ if self.video_input:
+ if self.cfg.VIT.POS_EMBED == "separate":
+ cls_embed = self.pos_embed[:, 0, :].unsqueeze(1)
+ tile_pos_embed = new_pos_embed[:, 1:, :].repeat(1, self.temporal_resolution, 1)
+ tile_temporal_embed = self.temp_embed.repeat_interleave(npatch, 1)
+ total_pos_embed = tile_pos_embed + tile_temporal_embed
+ total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1)
+ x = x + total_pos_embed
+ elif self.cfg.VIT.POS_EMBED == "joint":
+ x = x + self.st_embed
+ else:
+ # image input
+ x = x + new_pos_embed
+
+ # Apply positional dropout
+ x = self.pos_drop(x)
+
+ # Encoding using transformer layers
+ for i, blk in enumerate(self.blocks):
+ x = blk(
+ x,
+ seq_len=npatch,
+ num_frames=self.temporal_resolution,
+ approx=self.cfg.VIT.APPROX_ATTN_TYPE,
+ num_landmarks=self.cfg.VIT.APPROX_ATTN_DIM,
+ tok_mask=tok_mask,
+ )
+
+ ### v-iashin: I moved it to the forward pass
+ # x = self.norm(x)[:, 0]
+ # x = self.pre_logits(x)
+ ###
+ return x, tok_mask
+
+ # def forward(self, x):
+ # x = self.forward_features(x)
+ # ### v-iashin: here. This should leave the same forward output as before
+ # x = self.norm(x)[:, 0]
+ # x = self.pre_logits(x)
+ # ###
+ # x = self.head_drop(x)
+ # if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1:
+ # output = []
+ # for head in range(len(self.num_classes)):
+ # x_out = getattr(self, "head%d" % head)(x)
+ # if not self.training:
+ # x_out = torch.nn.functional.softmax(x_out, dim=-1)
+ # output.append(x_out)
+ # return output
+ # else:
+ # x = self.head(x)
+ # if not self.training:
+ # x = torch.nn.functional.softmax(x, dim=-1)
+ # return x
diff --git a/hunyuanvideo_foley/models/synchformer/vit_helper.py b/hunyuanvideo_foley/models/synchformer/vit_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..29739530ce8692f3124b3ae748f11a4a06aa5fc8
--- /dev/null
+++ b/hunyuanvideo_foley/models/synchformer/vit_helper.py
@@ -0,0 +1,364 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# Copyright 2020 Ross Wightman
+# Modified Model definition
+"""Video models."""
+
+import math
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+from timm.layers import to_2tuple
+from torch import einsum
+from torch.nn import functional as F
+
+default_cfgs = {
+ "vit_1k": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth",
+ "vit_1k_large": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth",
+}
+
+
+def qkv_attn(q, k, v, tok_mask: torch.Tensor = None):
+ sim = einsum("b i d, b j d -> b i j", q, k)
+ # apply masking if provided, tok_mask is (B*S*H, N): 1s - keep; sim is (B*S*H, H, N, N)
+ if tok_mask is not None:
+ BSH, N = tok_mask.shape
+ sim = sim.masked_fill(tok_mask.view(BSH, 1, N) == 0, float("-inf")) # 1 - broadcasts across N
+ attn = sim.softmax(dim=-1)
+ out = einsum("b i j, b j d -> b i d", attn, v)
+ return out
+
+
+class DividedAttention(nn.Module):
+
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+
+        # zero-init qkv: with zero q/k/v the attention branch outputs exactly
+        # zero at init, so the residual connection starts as an identity map
+        self.qkv.weight.data.fill_(0)
+        self.qkv.bias.data.fill_(0)
+        self.proj.weight.data.fill_(1)
+        self.proj.bias.data.fill_(0)
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, einops_from, einops_to, tok_mask: torch.Tensor = None, **einops_dims):
+ # num of heads variable
+ h = self.num_heads
+
+        # project x to q, k, v values
+ q, k, v = self.qkv(x).chunk(3, dim=-1)
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
+ if tok_mask is not None:
+ # replicate token mask across heads (b, n) -> (b, h, n) -> (b*h, n) -- same as qkv but w/o d
+ assert len(tok_mask.shape) == 2
+ tok_mask = tok_mask.unsqueeze(1).expand(-1, h, -1).reshape(-1, tok_mask.shape[1])
+
+ # Scale q
+ q *= self.scale
+
+ # Take out cls_q, cls_k, cls_v
+ (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v))
+ # the same for masking
+ if tok_mask is not None:
+ cls_mask, mask_ = tok_mask[:, 0:1], tok_mask[:, 1:]
+ else:
+ cls_mask, mask_ = None, None
+
+ # let CLS token attend to key / values of all patches across time and space
+ cls_out = qkv_attn(cls_q, k, v, tok_mask=tok_mask)
+
+ # rearrange across time or space
+ q_, k_, v_ = map(lambda t: rearrange(t, f"{einops_from} -> {einops_to}", **einops_dims), (q_, k_, v_))
+
+ # expand CLS token keys and values across time or space and concat
+ r = q_.shape[0] // cls_k.shape[0]
+ cls_k, cls_v = map(lambda t: repeat(t, "b () d -> (b r) () d", r=r), (cls_k, cls_v))
+
+ k_ = torch.cat((cls_k, k_), dim=1)
+ v_ = torch.cat((cls_v, v_), dim=1)
+
+ # the same for masking (if provided)
+ if tok_mask is not None:
+ # since mask does not have the latent dim (d), we need to remove it from einops dims
+ mask_ = rearrange(mask_, f"{einops_from} -> {einops_to}".replace(" d", ""), **einops_dims)
+ cls_mask = repeat(cls_mask, "b () -> (b r) ()", r=r) # expand cls_mask across time or space
+ mask_ = torch.cat((cls_mask, mask_), dim=1)
+
+ # attention
+ out = qkv_attn(q_, k_, v_, tok_mask=mask_)
+
+ # merge back time or space
+ out = rearrange(out, f"{einops_to} -> {einops_from}", **einops_dims)
+
+ # concat back the cls token
+ out = torch.cat((cls_out, out), dim=1)
+
+ # merge back the heads
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
+
+ ## to out
+ x = self.proj(out)
+ x = self.proj_drop(x)
+ return x
+
+
+class DividedSpaceTimeBlock(nn.Module):
+
+ def __init__(
+ self,
+ dim=768,
+ num_heads=12,
+ attn_type="divided",
+ mlp_ratio=4.0,
+ qkv_bias=False,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ ):
+ super().__init__()
+
+ self.einops_from_space = "b (f n) d"
+ self.einops_to_space = "(b f) n d"
+ self.einops_from_time = "b (f n) d"
+ self.einops_to_time = "(b n) f d"
+
+ self.norm1 = norm_layer(dim)
+
+ self.attn = DividedAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+
+ self.timeattn = DividedAttention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop
+ )
+
+ # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.drop_path = nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ self.norm3 = norm_layer(dim)
+
+ def forward(self, x, seq_len=196, num_frames=8, approx="none", num_landmarks=128, tok_mask: torch.Tensor = None):
+ time_output = self.timeattn(
+ self.norm3(x), self.einops_from_time, self.einops_to_time, n=seq_len, tok_mask=tok_mask
+ )
+ time_residual = x + time_output
+
+ space_output = self.attn(
+ self.norm1(time_residual), self.einops_from_space, self.einops_to_space, f=num_frames, tok_mask=tok_mask
+ )
+ space_residual = time_residual + self.drop_path(space_output)
+
+ x = space_residual
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class Mlp(nn.Module):
+
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """Image to Patch Embedding"""
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+ super().__init__()
+ img_size = img_size if type(img_size) is tuple else to_2tuple(img_size)
+        patch_size = patch_size if type(patch_size) is tuple else to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
+
+
+class PatchEmbed3D(nn.Module):
+ """Image to Patch Embedding"""
+
+ def __init__(
+ self,
+ img_size=224,
+ temporal_resolution=4,
+ in_chans=3,
+ patch_size=16,
+ z_block_size=2,
+ embed_dim=768,
+ flatten=True,
+ ):
+ super().__init__()
+ self.height = img_size // patch_size
+ self.width = img_size // patch_size
+ ### v-iashin: these two are incorrect
+ # self.frames = (temporal_resolution // z_block_size)
+ # self.num_patches = self.height * self.width * self.frames
+ self.z_block_size = z_block_size
+ ###
+ self.proj = nn.Conv3d(
+ in_chans,
+ embed_dim,
+ kernel_size=(z_block_size, patch_size, patch_size),
+ stride=(z_block_size, patch_size, patch_size),
+ )
+ self.flatten = flatten
+
+ def forward(self, x):
+ B, C, T, H, W = x.shape
+ x = self.proj(x)
+ if self.flatten:
+ x = x.flatten(2).transpose(1, 2)
+ return x
+
+
+class HeadMLP(nn.Module):
+
+ def __init__(self, n_input, n_classes, n_hidden=512, p=0.1):
+ super(HeadMLP, self).__init__()
+ self.n_input = n_input
+ self.n_classes = n_classes
+ self.n_hidden = n_hidden
+ if n_hidden is None:
+ # use linear classifier
+ self.block_forward = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_input, n_classes, bias=True))
+ else:
+ # use simple MLP classifier
+ self.block_forward = nn.Sequential(
+ nn.Dropout(p=p),
+ nn.Linear(n_input, n_hidden, bias=True),
+ nn.BatchNorm1d(n_hidden),
+ nn.ReLU(inplace=True),
+ nn.Dropout(p=p),
+ nn.Linear(n_hidden, n_classes, bias=True),
+ )
+ print(f"Dropout-NLP: {p}")
+
+ def forward(self, x):
+ return self.block_forward(x)
+
+
+def _conv_filter(state_dict, patch_size=16):
+ """convert patch embedding weight from manual patchify + linear proj to conv"""
+ out_dict = {}
+ for k, v in state_dict.items():
+ if "patch_embed.proj.weight" in k:
+ v = v.reshape((v.shape[0], 3, patch_size, patch_size))
+ out_dict[k] = v
+ return out_dict
+
+
+def adapt_input_conv(in_chans, conv_weight, agg="sum"):
+ conv_type = conv_weight.dtype
+ conv_weight = conv_weight.float()
+ O, I, J, K = conv_weight.shape
+ if in_chans == 1:
+ if I > 3:
+ assert conv_weight.shape[1] % 3 == 0
+ # For models with space2depth stems
+ conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
+ conv_weight = conv_weight.sum(dim=2, keepdim=False)
+ else:
+ if agg == "sum":
+ print("Summing conv1 weights")
+ conv_weight = conv_weight.sum(dim=1, keepdim=True)
+ else:
+ print("Averaging conv1 weights")
+ conv_weight = conv_weight.mean(dim=1, keepdim=True)
+ elif in_chans != 3:
+ if I != 3:
+ raise NotImplementedError("Weight format not supported by conversion.")
+ else:
+ if agg == "sum":
+ print("Summing conv1 weights")
+ repeat = int(math.ceil(in_chans / 3))
+ conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
+ conv_weight *= 3 / float(in_chans)
+ else:
+ print("Averaging conv1 weights")
+ conv_weight = conv_weight.mean(dim=1, keepdim=True)
+ conv_weight = conv_weight.repeat(1, in_chans, 1, 1)
+ conv_weight = conv_weight.to(conv_type)
+ return conv_weight
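+
+
+# Worked example (hypothetical weights): adapting an RGB patch-embedding conv
+# to single-channel input with agg="sum",
+#
+#   w = torch.randn(768, 3, 16, 16)          # (O, I, J, K) pretrained weights
+#   w1 = adapt_input_conv(1, w, agg="sum")   # -> (768, 1, 16, 16), RGB summed
+#
+# summing keeps the response to a grayscale image equal to the RGB response
+# when all three input channels are identical.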
+
+
+def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
+ # Load state dict
+ assert f"{cfg.VIT.PRETRAINED_WEIGHTS} not in [vit_1k, vit_1k_large]"
+ state_dict = torch.hub.load_state_dict_from_url(url=default_cfgs[cfg.VIT.PRETRAINED_WEIGHTS])
+
+ if filter_fn is not None:
+ state_dict = filter_fn(state_dict)
+
+ input_convs = "patch_embed.proj"
+ if input_convs is not None and in_chans != 3:
+ if isinstance(input_convs, str):
+ input_convs = (input_convs,)
+ for input_conv_name in input_convs:
+ weight_name = input_conv_name + ".weight"
+ try:
+ state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name], agg="avg")
+ print(f"Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)")
+ except NotImplementedError as e:
+ del state_dict[weight_name]
+ strict = False
+ print(f"Unable to convert pretrained {input_conv_name} weights, using random init for this layer.")
+
+ classifier_name = "head"
+ label_offset = cfg.get("label_offset", 0)
+ pretrain_classes = 1000
+ if num_classes != pretrain_classes:
+ # completely discard fully connected if model num_classes doesn't match pretrained weights
+ del state_dict[classifier_name + ".weight"]
+ del state_dict[classifier_name + ".bias"]
+ strict = False
+ elif label_offset > 0:
+ # special case for pretrained weights with an extra background class in pretrained weights
+ classifier_weight = state_dict[classifier_name + ".weight"]
+ state_dict[classifier_name + ".weight"] = classifier_weight[label_offset:]
+ classifier_bias = state_dict[classifier_name + ".bias"]
+ state_dict[classifier_name + ".bias"] = classifier_bias[label_offset:]
+
+ loaded_state = state_dict
+ self_state = model.state_dict()
+ all_names = set(self_state.keys())
+ saved_names = set([])
+ for name, param in loaded_state.items():
+ if "module." in name:
+ name = name.replace("module.", "")
+ if name in self_state.keys() and param.shape == self_state[name].shape:
+ saved_names.add(name)
+ self_state[name].copy_(param)
+ else:
+ print(f"didnt load: {name} of shape: {param.shape}")
+ print("Missing Keys:")
+ print(all_names - saved_names)
diff --git a/hunyuanvideo_foley/utils/__init__.py b/hunyuanvideo_foley/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hunyuanvideo_foley/utils/__pycache__/__init__.cpython-312.pyc b/hunyuanvideo_foley/utils/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3ca77588c6a0f723bb634860c77fcb4cdcfd1ead
Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/__init__.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/utils/__pycache__/__init__.cpython-313.pyc b/hunyuanvideo_foley/utils/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3eb0c8455ff9e2fbd8c0b379a70280b0e5242439
Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/__init__.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/utils/__pycache__/config_utils.cpython-312.pyc b/hunyuanvideo_foley/utils/__pycache__/config_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c5f88e71027e453b930a83e2ae0aadf23d3a057
Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/config_utils.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/utils/__pycache__/config_utils.cpython-313.pyc b/hunyuanvideo_foley/utils/__pycache__/config_utils.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..87840261b58868ed5c287cf825b36421a42ff394
Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/config_utils.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/utils/__pycache__/feature_utils.cpython-312.pyc b/hunyuanvideo_foley/utils/__pycache__/feature_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e85af7b40960c6204595bf1fd83da5f2643f415e
Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/feature_utils.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/utils/__pycache__/feature_utils.cpython-313.pyc b/hunyuanvideo_foley/utils/__pycache__/feature_utils.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3f50455f2bb1971c21da20d1979c966b8060d2d5
Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/feature_utils.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/utils/__pycache__/helper.cpython-312.pyc b/hunyuanvideo_foley/utils/__pycache__/helper.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..03bd596162c662e3b5e539562110e67f32ab1cfb
Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/helper.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/utils/__pycache__/helper.cpython-313.pyc b/hunyuanvideo_foley/utils/__pycache__/helper.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e7985f406d537b40f2a34c1318d08997eae4b060
Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/helper.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/utils/__pycache__/media_utils.cpython-312.pyc b/hunyuanvideo_foley/utils/__pycache__/media_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..92561d8239b90bcb3fdbd5e3bc7df74352d94bc6
Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/media_utils.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/utils/__pycache__/media_utils.cpython-313.pyc b/hunyuanvideo_foley/utils/__pycache__/media_utils.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c5d902564adfd96af1ac11ea48f17cbb184422b4
Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/media_utils.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/utils/__pycache__/model_utils.cpython-312.pyc b/hunyuanvideo_foley/utils/__pycache__/model_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4b04d192c9e90ea387d0404756e7a8bf131e24ef
Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/model_utils.cpython-312.pyc differ
diff --git a/hunyuanvideo_foley/utils/__pycache__/model_utils.cpython-313.pyc b/hunyuanvideo_foley/utils/__pycache__/model_utils.cpython-313.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cd5172c98502b6572776c6f4dbd4f6c90ab071dd
Binary files /dev/null and b/hunyuanvideo_foley/utils/__pycache__/model_utils.cpython-313.pyc differ
diff --git a/hunyuanvideo_foley/utils/config_utils.py b/hunyuanvideo_foley/utils/config_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1bdf9108dfebf68c0a27cfdd675c8cc5e4c1d88
--- /dev/null
+++ b/hunyuanvideo_foley/utils/config_utils.py
@@ -0,0 +1,109 @@
+"""Configuration utilities for the HunyuanVideo-Foley project."""
+
+import yaml
+from pathlib import Path
+from typing import Any, Dict, List, Union
+
+class AttributeDict:
+    """Recursively wraps dicts (and lists) so keys are accessible as attributes."""
+
+ def __init__(self, data: Union[Dict, List, Any]):
+ if isinstance(data, dict):
+ for key, value in data.items():
+ if isinstance(value, (dict, list)):
+ value = AttributeDict(value)
+ setattr(self, self._sanitize_key(key), value)
+ elif isinstance(data, list):
+ self._list = [AttributeDict(item) if isinstance(item, (dict, list)) else item
+ for item in data]
+ else:
+ self._value = data
+
+ def _sanitize_key(self, key: str) -> str:
+ import re
+ sanitized = re.sub(r'[^a-zA-Z0-9_]', '_', str(key))
+        if sanitized and sanitized[0].isdigit():
+            sanitized = f'_{sanitized}'
+ return sanitized
+
+ def __getitem__(self, key):
+ if hasattr(self, '_list'):
+ return self._list[key]
+ return getattr(self, self._sanitize_key(key))
+
+ def __setitem__(self, key, value):
+ if hasattr(self, '_list'):
+ self._list[key] = value
+ else:
+ setattr(self, self._sanitize_key(key), value)
+
+ def __iter__(self):
+ if hasattr(self, '_list'):
+ return iter(self._list)
+ return iter(self.__dict__.keys())
+
+ def __len__(self):
+ if hasattr(self, '_list'):
+ return len(self._list)
+ return len(self.__dict__)
+
+ def get(self, key, default=None):
+ try:
+ return self[key]
+ except (KeyError, AttributeError, IndexError):
+ return default
+
+ def keys(self):
+ if hasattr(self, '_list'):
+ return range(len(self._list))
+ elif hasattr(self, '_value'):
+ return []
+ else:
+ return [key for key in self.__dict__.keys() if not key.startswith('_')]
+
+ def values(self):
+ if hasattr(self, '_list'):
+ return self._list
+ elif hasattr(self, '_value'):
+ return [self._value]
+ else:
+ return [value for key, value in self.__dict__.items() if not key.startswith('_')]
+
+ def items(self):
+ if hasattr(self, '_list'):
+ return enumerate(self._list)
+ elif hasattr(self, '_value'):
+ return []
+ else:
+ return [(key, value) for key, value in self.__dict__.items() if not key.startswith('_')]
+
+ def __repr__(self):
+ if hasattr(self, '_list'):
+ return f"AttributeDict({self._list})"
+ elif hasattr(self, '_value'):
+ return f"AttributeDict({self._value})"
+ return f"AttributeDict({dict(self.__dict__)})"
+
+ def to_dict(self) -> Union[Dict, List, Any]:
+ if hasattr(self, '_list'):
+ return [item.to_dict() if isinstance(item, AttributeDict) else item
+ for item in self._list]
+ elif hasattr(self, '_value'):
+ return self._value
+ else:
+ result = {}
+ for key, value in self.__dict__.items():
+ if isinstance(value, AttributeDict):
+ result[key] = value.to_dict()
+ else:
+ result[key] = value
+ return result
+
+def load_yaml(file_path: str, encoding: str = 'utf-8') -> AttributeDict:
+ try:
+ with open(file_path, 'r', encoding=encoding) as file:
+ data = yaml.safe_load(file)
+ return AttributeDict(data)
+ except FileNotFoundError:
+ raise FileNotFoundError(f"YAML file not found: {file_path}")
+ except yaml.YAMLError as e:
+ raise yaml.YAMLError(f"YAML format error: {e}")
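+
+
+# Usage sketch (the config path is hypothetical):
+#
+#   cfg = load_yaml("configs/hunyuanvideo-foley.yaml")
+#   cfg.model_config.model_kwargs.text_length  # nested keys via attributes
+#   cfg.get("missing_key")                     # dict-style access still works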
diff --git a/hunyuanvideo_foley/utils/feature_utils.py b/hunyuanvideo_foley/utils/feature_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..39f474479fdac18000e89d4db0b5211e35dffe23
--- /dev/null
+++ b/hunyuanvideo_foley/utils/feature_utils.py
@@ -0,0 +1,156 @@
+"""Feature extraction utilities for video and text processing."""
+
+import os
+import numpy as np
+import torch
+import av
+from PIL import Image
+from einops import rearrange
+from typing import Any, Dict, List, Union, Tuple
+from loguru import logger
+
+from .config_utils import AttributeDict
+from ..constants import FPS_VISUAL, MAX_VIDEO_DURATION_SECONDS
+
+
+class FeatureExtractionError(Exception):
+ """Exception raised for feature extraction errors."""
+ pass
+
+def get_frames_av(
+ video_path: str,
+ fps: float,
+ max_length: float = None,
+) -> Tuple[np.ndarray, float]:
+ end_sec = max_length if max_length is not None else 15
+ next_frame_time_for_each_fps = 0.0
+ time_delta_for_each_fps = 1 / fps
+
+ all_frames = []
+ output_frames = []
+
+ with av.open(video_path) as container:
+ stream = container.streams.video[0]
+ ori_fps = stream.guessed_rate
+ stream.thread_type = "AUTO"
+ for packet in container.demux(stream):
+ for frame in packet.decode():
+ frame_time = frame.time
+ if frame_time < 0:
+ continue
+ if frame_time > end_sec:
+ break
+
+ frame_np = None
+
+ this_time = frame_time
+ while this_time >= next_frame_time_for_each_fps:
+ if frame_np is None:
+ frame_np = frame.to_ndarray(format="rgb24")
+
+ output_frames.append(frame_np)
+ next_frame_time_for_each_fps += time_delta_for_each_fps
+
+ output_frames = np.stack(output_frames)
+
+ vid_len_in_s = len(output_frames) / fps
+ if max_length is not None and len(output_frames) > int(max_length * fps):
+ output_frames = output_frames[: int(max_length * fps)]
+ vid_len_in_s = max_length
+
+ return output_frames, vid_len_in_s
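+
+
+# Resampling behaviour, worked through: with fps=8 the loop emits a frame each
+# time the decoded timestamp crosses a multiple of 1/8 s, so a 30 fps source
+# is subsampled (roughly every 4th frame) while a 4 fps source has frames
+# duplicated to keep the output rate fixed at 8 frames per second.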
+
+@torch.inference_mode()
+def encode_video_with_siglip2(x: torch.Tensor, model_dict, batch_size: int = -1):
+ b, t, c, h, w = x.shape
+ if batch_size < 0:
+ batch_size = b * t
+ x = rearrange(x, "b t c h w -> (b t) c h w")
+ outputs = []
+ for i in range(0, b * t, batch_size):
+ outputs.append(model_dict.siglip2_model.get_image_features(pixel_values=x[i : i + batch_size]))
+ res = torch.cat(outputs, dim=0)
+ res = rearrange(res, "(b t) d -> b t d", b=b)
+ return res
+
+@torch.inference_mode()
+def encode_video_with_sync(x: torch.Tensor, model_dict, batch_size: int = -1):
+ """
+ The input video of x is best to be in fps of 24 of greater than 24.
+ Input:
+ x: tensor in shape of [B, T, C, H, W]
+ batch_size: the batch_size for synchformer inference
+ """
+ b, t, c, h, w = x.shape
+ assert c == 3 and h == 224 and w == 224
+
+ segment_size = 16
+ step_size = 8
+ num_segments = (t - segment_size) // step_size + 1
+ segments = []
+ for i in range(num_segments):
+ segments.append(x[:, i * step_size : i * step_size + segment_size])
+ x = torch.stack(segments, dim=1).cuda() # (B, num_segments, segment_size, 3, 224, 224)
+
+ outputs = []
+ if batch_size < 0:
+ batch_size = b * num_segments
+ x = rearrange(x, "b s t c h w -> (b s) 1 t c h w")
+ for i in range(0, b * num_segments, batch_size):
+ with torch.autocast(device_type="cuda", enabled=True, dtype=torch.half):
+ outputs.append(model_dict.syncformer_model(x[i : i + batch_size]))
+ x = torch.cat(outputs, dim=0) # [b * num_segments, 1, 8, 768]
+ x = rearrange(x, "(b s) 1 t d -> b (s t) d", b=b)
+ return x
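+
+
+# Shape walk-through (assumes FPS_VISUAL["synchformer"] == 25, consistent with
+# the "[1, 240, 768] for 10s video" note below): a 10 s clip gives t = 250
+# frames, (250 - 16) // 8 + 1 = 30 segments, and 8 tokens per segment, so the
+# returned tensor is (B, 30 * 8, 768) = (B, 240, 768).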
+
+
+@torch.inference_mode()
+def encode_video_features(video_path, model_dict):
+ visual_features = {}
+ # siglip2 visual features
+ frames, ori_vid_len_in_s = get_frames_av(video_path, FPS_VISUAL["siglip2"])
+ images = [Image.fromarray(frame).convert('RGB') for frame in frames]
+ images = [model_dict.siglip2_preprocess(image) for image in images] # [T, C, H, W]
+ clip_frames = torch.stack(images).to(model_dict.device).unsqueeze(0)
+ visual_features['siglip2_feat'] = encode_video_with_siglip2(clip_frames, model_dict).to(model_dict.device)
+
+ # synchformer visual features
+ frames, ori_vid_len_in_s = get_frames_av(video_path, FPS_VISUAL["synchformer"])
+ images = torch.from_numpy(frames).permute(0, 3, 1, 2) # [T, C, H, W]
+ sync_frames = model_dict.syncformer_preprocess(images).unsqueeze(0) # [1, T, 3, 224, 224]
+ # [1, num_segments * 8, channel_dim], e.g. [1, 240, 768] for 10s video
+ visual_features['syncformer_feat'] = encode_video_with_sync(sync_frames, model_dict)
+
+ vid_len_in_s = sync_frames.shape[1] / FPS_VISUAL["synchformer"]
+ visual_features = AttributeDict(visual_features)
+
+ return visual_features, vid_len_in_s
+
+@torch.inference_mode()
+def encode_text_feat(text: List[str], model_dict):
+    # returns per-token hidden states (B, L, D) and the attention maps
+ inputs = model_dict.clap_tokenizer(text, padding=True, return_tensors="pt").to(model_dict.device)
+ outputs = model_dict.clap_model(**inputs, output_hidden_states=True, return_dict=True)
+ return outputs.last_hidden_state, outputs.attentions
+
+
+def feature_process(video_path, prompt, model_dict, cfg):
+ visual_feats, audio_len_in_s = encode_video_features(video_path, model_dict)
+ neg_prompt = "noisy, harsh"
+ prompts = [neg_prompt, prompt]
+ text_feat_res, text_feat_mask = encode_text_feat(prompts, model_dict)
+
+ text_feat = text_feat_res[1:]
+ uncond_text_feat = text_feat_res[:1]
+
+ if cfg.model_config.model_kwargs.text_length < text_feat.shape[1]:
+ text_seq_length = cfg.model_config.model_kwargs.text_length
+ text_feat = text_feat[:, :text_seq_length]
+ uncond_text_feat = uncond_text_feat[:, :text_seq_length]
+
+ text_feats = AttributeDict({
+ 'text_feat': text_feat,
+ 'uncond_text_feat': uncond_text_feat,
+ })
+
+ return visual_feats, text_feats, audio_len_in_s
diff --git a/hunyuanvideo_foley/utils/helper.py b/hunyuanvideo_foley/utils/helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..04840dfcdf847e3f61ec327ebc691b0dbd139da4
--- /dev/null
+++ b/hunyuanvideo_foley/utils/helper.py
@@ -0,0 +1,134 @@
+import collections.abc
+from itertools import repeat
+import importlib
+import yaml
+import time
+
+def default(value, default_val):
+ return default_val if value is None else value
+
+
+def default_dtype(value, default_val):
+ if value is not None:
+ assert isinstance(value, type(default_val)), f"Expect {type(default_val)}, got {type(value)}."
+ return value
+ return default_val
+
+
+def repeat_interleave(lst, num_repeats):
+ return [item for item in lst for _ in range(num_repeats)]
+
+
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ x = tuple(x)
+ if len(x) == 1:
+ x = tuple(repeat(x[0], n))
+ return x
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
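+
+
+# e.g. to_2tuple(7) -> (7, 7); to_2tuple((3, 5)) -> (3, 5); a one-element
+# iterable is broadcast: to_2tuple([4]) -> (4, 4)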
+
+
+def as_tuple(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return tuple(x)
+ if x is None or isinstance(x, (int, float, str)):
+ return (x,)
+ else:
+ raise ValueError(f"Unknown type {type(x)}")
+
+
+def as_list_of_2tuple(x):
+ x = as_tuple(x)
+ if len(x) == 1:
+ x = (x[0], x[0])
+ assert len(x) % 2 == 0, f"Expect even length, got {len(x)}."
+ lst = []
+ for i in range(0, len(x), 2):
+ lst.append((x[i], x[i + 1]))
+ return lst
+
+
+def find_multiple(n: int, k: int) -> int:
+ assert k > 0
+ if n % k == 0:
+ return n
+ return n - (n % k) + k
+
+
+def merge_dicts(dict1, dict2):
+ for key, value in dict2.items():
+ if key in dict1 and isinstance(dict1[key], dict) and isinstance(value, dict):
+ merge_dicts(dict1[key], value)
+ else:
+ dict1[key] = value
+ return dict1
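+
+
+# Example: merge_dicts({"a": {"x": 1}, "b": 2}, {"a": {"y": 3}, "b": 4})
+# -> {"a": {"x": 1, "y": 3}, "b": 4}; nested dicts are merged recursively and
+# other values from dict2 win. Note that dict1 is modified in place.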
+
+
+def merge_yaml_files(file_list):
+ merged_config = {}
+
+ for file in file_list:
+ with open(file, "r", encoding="utf-8") as f:
+ config = yaml.safe_load(f)
+ if config:
+ # Remove the first level
+ for key, value in config.items():
+ if isinstance(value, dict):
+ merged_config = merge_dicts(merged_config, value)
+ else:
+ merged_config[key] = value
+
+ return merged_config
+
+
+def merge_dict(file_list):
+ merged_config = {}
+
+ for file in file_list:
+ with open(file, "r", encoding="utf-8") as f:
+ config = yaml.safe_load(f)
+ if config:
+ merged_config = merge_dicts(merged_config, config)
+
+ return merged_config
+
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+def readable_time(seconds):
+ """ Convert time seconds to a readable format: DD Days, HH Hours, MM Minutes, SS Seconds """
+ seconds = int(seconds)
+ days, seconds = divmod(seconds, 86400)
+ hours, seconds = divmod(seconds, 3600)
+ minutes, seconds = divmod(seconds, 60)
+ if days > 0:
+ return f"{days} Days, {hours} Hours, {minutes} Minutes, {seconds} Seconds"
+ if hours > 0:
+ return f"{hours} Hours, {minutes} Minutes, {seconds} Seconds"
+ if minutes > 0:
+ return f"{minutes} Minutes, {seconds} Seconds"
+ return f"{seconds} Seconds"
+
+
+def get_obj_from_cfg(cfg, reload=False):
+ if isinstance(cfg, str):
+ return get_obj_from_str(cfg, reload)
+ elif isinstance(cfg, (list, tuple,)):
+ return tuple([get_obj_from_str(c, reload) for c in cfg])
+ else:
+ raise NotImplementedError(f"Not implemented for {type(cfg)}.")
diff --git a/hunyuanvideo_foley/utils/media_utils.py b/hunyuanvideo_foley/utils/media_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d80e6cbfe5164caeb7340fa2fbbcb8a11ddd9c13
--- /dev/null
+++ b/hunyuanvideo_foley/utils/media_utils.py
@@ -0,0 +1,101 @@
+"""Media utilities for audio/video processing."""
+
+import os
+import subprocess
+from pathlib import Path
+from typing import Optional
+
+from loguru import logger
+
+
+class MediaProcessingError(Exception):
+ """Exception raised for media processing errors."""
+ pass
+
+
+def merge_audio_video(
+ audio_path: str,
+ video_path: str,
+ output_path: str,
+ overwrite: bool = True,
+ quality: str = "high"
+) -> str:
+ """
+ Merge audio and video files using ffmpeg.
+
+ Args:
+ audio_path: Path to input audio file
+ video_path: Path to input video file
+ output_path: Path for output video file
+ overwrite: Whether to overwrite existing output file
+ quality: Quality setting ('high', 'medium', 'low')
+
+ Returns:
+ Path to the output file
+
+ Raises:
+ MediaProcessingError: If input files don't exist or ffmpeg fails
+ FileNotFoundError: If ffmpeg is not installed
+ """
+ # Validate input files
+ if not os.path.exists(audio_path):
+ raise MediaProcessingError(f"Audio file not found: {audio_path}")
+ if not os.path.exists(video_path):
+ raise MediaProcessingError(f"Video file not found: {video_path}")
+
+ # Create output directory if needed
+ output_dir = Path(output_path).parent
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Quality settings
+ quality_settings = {
+ "high": ["-b:a", "192k"],
+ "medium": ["-b:a", "128k"],
+ "low": ["-b:a", "96k"]
+ }
+
+ # Build ffmpeg command
+ ffmpeg_command = [
+ "ffmpeg",
+ "-i", video_path,
+ "-i", audio_path,
+ "-c:v", "copy",
+ "-c:a", "aac",
+ "-ac", "2",
+ "-af", "pan=stereo|c0=c0|c1=c0",
+ "-map", "0:v:0",
+ "-map", "1:a:0",
+ *quality_settings.get(quality, quality_settings["high"]),
+ ]
+
+ if overwrite:
+ ffmpeg_command.append("-y")
+
+ ffmpeg_command.append(output_path)
+
+ try:
+ logger.info(f"Merging audio '{audio_path}' with video '{video_path}'")
+ process = subprocess.Popen(
+ ffmpeg_command,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True
+ )
+ stdout, stderr = process.communicate()
+
+ if process.returncode != 0:
+ error_msg = f"FFmpeg failed with return code {process.returncode}: {stderr}"
+ logger.error(error_msg)
+ raise MediaProcessingError(error_msg)
+ else:
+ logger.info(f"Successfully merged video saved to: {output_path}")
+
+ except FileNotFoundError:
+ raise FileNotFoundError(
+ "ffmpeg not found. Please install ffmpeg: "
+ "https://ffmpeg.org/download.html"
+ )
+ except Exception as e:
+ raise MediaProcessingError(f"Unexpected error during media processing: {e}")
+
+ return output_path
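+
+
+# Usage sketch (paths are illustrative):
+#
+#   merge_audio_video("generated.wav", "input.mp4", "output/result.mp4")
+#   # copies the video stream untouched and encodes the audio as 192k AAC,
+#   # duplicated to stereo via the pan filter above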
diff --git a/hunyuanvideo_foley/utils/model_utils.py b/hunyuanvideo_foley/utils/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..354ea223c7ef8ae4f8759f6763028d27e6ad13bc
--- /dev/null
+++ b/hunyuanvideo_foley/utils/model_utils.py
@@ -0,0 +1,241 @@
+import torch
+import os
+from loguru import logger
+from torchvision import transforms
+from torchvision.transforms import v2
+from diffusers.utils.torch_utils import randn_tensor
+from transformers import AutoTokenizer, AutoModel, ClapTextModelWithProjection
+from ..models.dac_vae.model.dac import DAC
+from ..models.synchformer import Synchformer
+from ..models.hifi_foley import HunyuanVideoFoley
+from .config_utils import load_yaml, AttributeDict
+from .schedulers import FlowMatchDiscreteScheduler
+from tqdm import tqdm
+
+def load_state_dict(model, model_path):
+ logger.info(f"Loading model state dict from: {model_path}")
+ state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False)
+
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+
+ if missing_keys:
+ logger.warning(f"Missing keys in state dict ({len(missing_keys)} keys):")
+ for key in missing_keys:
+ logger.warning(f" - {key}")
+ else:
+ logger.info("No missing keys found")
+
+ if unexpected_keys:
+ logger.warning(f"Unexpected keys in state dict ({len(unexpected_keys)} keys):")
+ for key in unexpected_keys:
+ logger.warning(f" - {key}")
+ else:
+ logger.info("No unexpected keys found")
+
+ logger.info("Model state dict loaded successfully")
+ return model
+
+def load_model(model_path, config_path, device):
+ logger.info("Starting model loading process...")
+ logger.info(f"Configuration file: {config_path}")
+ logger.info(f"Model weights dir: {model_path}")
+ logger.info(f"Target device: {device}")
+
+ cfg = load_yaml(config_path)
+ logger.info("Configuration loaded successfully")
+
+ # HunyuanVideoFoley
+ logger.info("Loading HunyuanVideoFoley main model...")
+ foley_model = HunyuanVideoFoley(cfg, dtype=torch.bfloat16, device=device).to(device=device, dtype=torch.bfloat16)
+ foley_model = load_state_dict(foley_model, os.path.join(model_path, "hunyuanvideo_foley.pth"))
+ foley_model.eval()
+ logger.info("HunyuanVideoFoley model loaded and set to evaluation mode")
+
+ # DAC-VAE
+ dac_path = os.path.join(model_path, "vae_128d_48k.pth")
+ logger.info(f"Loading DAC VAE model from: {dac_path}")
+ dac_model = DAC.load(dac_path)
+ dac_model = dac_model.to(device)
+ dac_model.requires_grad_(False)
+ dac_model.eval()
+ logger.info("DAC VAE model loaded successfully")
+
+ # Siglip2 visual-encoder
+ logger.info("Loading SigLIP2 visual encoder...")
+ siglip2_preprocess = transforms.Compose([
+ transforms.Resize((512, 512)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ ])
+ siglip2_model = AutoModel.from_pretrained("google/siglip2-base-patch16-512").to(device).eval()
+ logger.info("SigLIP2 model and preprocessing pipeline loaded successfully")
+
+ # clap text-encoder
+ logger.info("Loading CLAP text encoder...")
+ clap_tokenizer = AutoTokenizer.from_pretrained("laion/larger_clap_general")
+ clap_model = ClapTextModelWithProjection.from_pretrained("laion/larger_clap_general").to(device)
+ logger.info("CLAP tokenizer and model loaded successfully")
+
+ # syncformer
+ syncformer_path = os.path.join(model_path, "synchformer_state_dict.pth")
+ logger.info(f"Loading Synchformer model from: {syncformer_path}")
+ syncformer_preprocess = v2.Compose(
+ [
+ v2.Resize(224, interpolation=v2.InterpolationMode.BICUBIC),
+ v2.CenterCrop(224),
+ v2.ToImage(),
+ v2.ToDtype(torch.float32, scale=True),
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ ]
+ )
+
+ syncformer_model = Synchformer()
+ syncformer_model.load_state_dict(torch.load(syncformer_path, weights_only=False, map_location="cpu"))
+ syncformer_model = syncformer_model.to(device).eval()
+ logger.info("Synchformer model and preprocessing pipeline loaded successfully")
+
+ logger.info("Creating model dictionary with attribute access...")
+ model_dict = AttributeDict({
+ 'foley_model': foley_model,
+ 'dac_model': dac_model,
+ 'siglip2_preprocess': siglip2_preprocess,
+ 'siglip2_model': siglip2_model,
+ 'clap_tokenizer': clap_tokenizer,
+ 'clap_model': clap_model,
+ 'syncformer_preprocess': syncformer_preprocess,
+ 'syncformer_model': syncformer_model,
+ 'device': device,
+ })
+
+ logger.info("All models loaded successfully!")
+ logger.info("Available model components:")
+ for key in model_dict.keys():
+ logger.info(f" - {key}")
+ logger.info("Models can be accessed via attribute notation (e.g., models.foley_model)")
+
+ return model_dict, cfg
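+# Illustrative usage sketch (the weights directory and config path below are
+# assumptions for demonstration, not fixed values from this repo):
+#
+#     import torch
+#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+#     models, cfg = load_model("./pretrained_models", "./configs/foley.yaml", device)
+#     models.foley_model  # components are reachable via AttributeDict attribute access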
+
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps,
+ device,
+ **kwargs,
+):
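+    """Configure the scheduler's timesteps and return them with the (unchanged) step count."""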
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+def prepare_latents(scheduler, batch_size, num_channels_latents, length, dtype, device):
+ shape = (batch_size, num_channels_latents, int(length))
+ latents = randn_tensor(shape, device=device, dtype=dtype)
+
+ # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
+ if hasattr(scheduler, "init_noise_sigma"):
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * scheduler.init_noise_sigma
+
+ return latents
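+# Shape arithmetic (illustrative): the checkpoint name `vae_128d_48k.pth` suggests
+# 128 latent channels, so with, say, an `audio_frame_rate` of 50 latent frames per
+# second, 8 s of audio would give latents of shape (batch_size, 128, 400). Both
+# numbers come from the YAML config; treat them as assumptions, not constants.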
+
+
+@torch.no_grad()
+def denoise_process(visual_feats, text_feats, audio_len_in_s, model_dict, cfg, guidance_scale=4.5, num_inference_steps=50, batch_size=1):
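+    """Run the classifier-free-guidance flow-matching denoising loop and decode the
+    resulting latents to a waveform. Returns an `(audio, sample_rate)` tuple."""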
+ target_dtype = model_dict.foley_model.dtype
+ autocast_enabled = target_dtype != torch.float32
+ device = model_dict.device
+
+ scheduler = FlowMatchDiscreteScheduler(
+ shift=cfg.diffusion_config.sample_flow_shift,
+ reverse=cfg.diffusion_config.flow_reverse,
+ solver=cfg.diffusion_config.flow_solver,
+ use_flux_shift=cfg.diffusion_config.sample_use_flux_shift,
+ flux_base_shift=cfg.diffusion_config.flux_base_shift,
+ flux_max_shift=cfg.diffusion_config.flux_max_shift,
+ )
+
+ timesteps, num_inference_steps = retrieve_timesteps(
+ scheduler,
+ num_inference_steps,
+ device,
+ )
+
+ latents = prepare_latents(
+ scheduler,
+ batch_size=batch_size,
+ num_channels_latents=cfg.model_config.model_kwargs.audio_vae_latent_dim,
+ length=audio_len_in_s * cfg.model_config.model_kwargs.audio_frame_rate,
+ dtype=target_dtype,
+ device=device,
+ )
+
+ # Denoise loop
+ for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Denoising steps"):
+ # noise latents
+ latent_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
+ latent_input = scheduler.scale_model_input(latent_input, t)
+
+ t_expand = t.repeat(latent_input.shape[0])
+
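+        # NOTE: the three conditioning blocks below do not depend on the timestep
+        # and could be computed once before the loop; they are recomputed per step
+        # here, which is redundant but harmless.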
+ # siglip2 features
+ siglip2_feat = visual_feats.siglip2_feat.repeat(batch_size, 1, 1) # Repeat for batch_size
+ uncond_siglip2_feat = model_dict.foley_model.get_empty_clip_sequence(
+ bs=batch_size, len=siglip2_feat.shape[1]
+ ).to(device)
+
+ if guidance_scale is not None and guidance_scale > 1.0:
+ siglip2_feat_input = torch.cat([uncond_siglip2_feat, siglip2_feat], dim=0)
+ else:
+ siglip2_feat_input = siglip2_feat
+
+ # syncformer features
+ syncformer_feat = visual_feats.syncformer_feat.repeat(batch_size, 1, 1) # Repeat for batch_size
+ uncond_syncformer_feat = model_dict.foley_model.get_empty_sync_sequence(
+ bs=batch_size, len=syncformer_feat.shape[1]
+ ).to(device)
+ if guidance_scale is not None and guidance_scale > 1.0:
+ syncformer_feat_input = torch.cat([uncond_syncformer_feat, syncformer_feat], dim=0)
+ else:
+ syncformer_feat_input = syncformer_feat
+
+ # text features
+ text_feat_repeated = text_feats.text_feat.repeat(batch_size, 1, 1) # Repeat for batch_size
+ uncond_text_feat_repeated = text_feats.uncond_text_feat.repeat(batch_size, 1, 1) # Repeat for batch_size
+ if guidance_scale is not None and guidance_scale > 1.0:
+ text_feat_input = torch.cat([uncond_text_feat_repeated, text_feat_repeated], dim=0)
+ else:
+ text_feat_input = text_feat_repeated
+
+ with torch.autocast(device_type=device.type, enabled=autocast_enabled, dtype=target_dtype):
+ # Predict the noise residual
+ noise_pred = model_dict.foley_model(
+ x=latent_input,
+ t=t_expand,
+ cond=text_feat_input,
+ clip_feat=siglip2_feat_input,
+ sync_feat=syncformer_feat_input,
+ return_dict=True,
+ )["x"]
+
+ noise_pred = noise_pred.to(dtype=torch.float32)
+
+ if guidance_scale is not None and guidance_scale > 1.0:
+ # Perform classifier-free guidance
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # Compute the previous noisy sample x_t -> x_t-1
+ latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+    # Decode the denoised latents back to a waveform with the DAC VAE
+    audio = model_dict.dac_model.decode(latents)
+    audio = audio.float().cpu()
+
+    # Trim to the requested duration in samples
+    audio = audio[:, :int(audio_len_in_s * model_dict.dac_model.sample_rate)]
+
+ return audio, model_dict.dac_model.sample_rate
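+# Illustrative end-to-end call (assumes `visual_feats` and `text_feats` were
+# produced by this project's feature-extraction utilities, defined elsewhere,
+# and that `models`/`cfg` came from load_model above):
+#
+#     audio, sr = denoise_process(visual_feats, text_feats, audio_len_in_s=8.0,
+#                                 model_dict=models, cfg=cfg,
+#                                 guidance_scale=4.5, num_inference_steps=50)
+#     # `audio` is a float32 CPU tensor trimmed to 8 s; `sr` is the DAC sample rate.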
+
+
diff --git a/hunyuanvideo_foley/utils/schedulers/__init__.py b/hunyuanvideo_foley/utils/schedulers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fea4433560c92fc9d8569993447d0bdb456dc9e
--- /dev/null
+++ b/hunyuanvideo_foley/utils/schedulers/__init__.py
@@ -0,0 +1,2 @@
+from diffusers.schedulers import DDPMScheduler, EulerDiscreteScheduler
+from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
\ No newline at end of file
diff --git a/hunyuanvideo_foley/utils/schedulers/scheduling_flow_match_discrete.py b/hunyuanvideo_foley/utils/schedulers/scheduling_flow_match_discrete.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f58814479643b42eb18158db2fcc2a29544424e
--- /dev/null
+++ b/hunyuanvideo_foley/utils/schedulers/scheduling_flow_match_discrete.py
@@ -0,0 +1,376 @@
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput, logging
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class FlowMatchDiscreteSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+
+ Args:
+        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, length)` for audio latents):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ """
+
+ prev_sample: torch.FloatTensor
+
+
+class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
+ """
+    Discrete flow-matching scheduler supporting Euler, Heun-2, midpoint-2, and
+    classic Runge-Kutta-4 solvers.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+        shift (`float`, defaults to 1.0):
+            The shift value for the timestep schedule.
+        reverse (`bool`, defaults to `True`):
+            Whether to reverse the timestep schedule.
+        solver (`str`, defaults to `"euler"`):
+            The ODE solver to use: one of `"euler"`, `"heun-2"`, `"midpoint-2"`, or `"kutta-4"`.
+        use_flux_shift (`bool`, defaults to `False`):
+            Whether to use the Flux-style, token-count-dependent time shift instead
+            of the SD3-style shift.
+    """
+
+ _compatibles = []
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ shift: float = 1.0,
+ reverse: bool = True,
+ solver: str = "euler",
+ use_flux_shift: bool = False,
+ flux_base_shift: float = 0.5,
+ flux_max_shift: float = 1.15,
+ n_tokens: Optional[int] = None,
+ ):
+ sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
+
+ if not reverse:
+ sigmas = sigmas.flip(0)
+
+ self.sigmas = sigmas
+ # the value fed to model
+ self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
+ self.timesteps_full = (sigmas * num_train_timesteps).to(dtype=torch.float32)
+
+ self._step_index = None
+ self._begin_index = None
+
+ self.supported_solver = [
+ "euler",
+ "heun-2", "midpoint-2",
+ "kutta-4",
+ ]
+ if solver not in self.supported_solver:
+ raise ValueError(f"Solver {solver} not supported. Supported solvers: {self.supported_solver}")
+
+        # Stored derivatives and dt for the multi-stage solvers (heun-2 / midpoint-2 / kutta-4)
+ self.derivative_1 = None
+ self.derivative_2 = None
+ self.derivative_3 = None
+ self.dt = None
+
+ @property
+ def step_index(self):
+ """
+ The index counter for current timestep. It will increase 1 after each scheduler step.
+ """
+ return self._step_index
+
+ @property
+ def begin_index(self):
+ """
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+ """
+ return self._begin_index
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+ def set_begin_index(self, begin_index: int = 0):
+ """
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+ Args:
+ begin_index (`int`):
+ The begin index for the scheduler.
+ """
+ self._begin_index = begin_index
+
+ def _sigma_to_t(self, sigma):
+ return sigma * self.config.num_train_timesteps
+
+ @property
+ def state_in_first_order(self):
+ return self.derivative_1 is None
+
+ @property
+ def state_in_second_order(self):
+ return self.derivative_2 is None
+
+ @property
+ def state_in_third_order(self):
+ return self.derivative_3 is None
+
+    def set_timesteps(self, num_inference_steps: int,
+                      device: Optional[Union[str, torch.device]] = None,
+                      n_tokens: Optional[int] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ n_tokens (`int`, *optional*):
+ Number of tokens in the input sequence.
+ """
+ self.num_inference_steps = num_inference_steps
+
+ sigmas = torch.linspace(1, 0, num_inference_steps + 1)
+
+ # Apply timestep shift
+ if self.config.use_flux_shift:
+ assert isinstance(n_tokens, int), "n_tokens should be provided for flux shift"
+ mu = self.get_lin_function(y1=self.config.flux_base_shift, y2=self.config.flux_max_shift)(n_tokens)
+ sigmas = self.flux_time_shift(mu, 1.0, sigmas)
+ elif self.config.shift != 1.:
+ sigmas = self.sd3_time_shift(sigmas)
+
+ if not self.config.reverse:
+ sigmas = 1 - sigmas
+
+ self.sigmas = sigmas
+ self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(dtype=torch.float32, device=device)
+ self.timesteps_full = (sigmas * self.config.num_train_timesteps).to(dtype=torch.float32, device=device)
+
+        # Reset the stored derivatives and dt used by the multi-stage solvers
+ self.derivative_1 = None
+ self.derivative_2 = None
+ self.derivative_3 = None
+ self.dt = None
+
+ # Reset step index
+ self._step_index = None
+
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
+ if schedule_timesteps is None:
+ schedule_timesteps = self.timesteps
+
+ indices = (schedule_timesteps == timestep).nonzero()
+
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ pos = 1 if len(indices) > 1 else 0
+
+ return indices[pos].item()
+
+ def _init_step_index(self, timestep):
+ if self.begin_index is None:
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ self._step_index = self.index_for_timestep(timestep)
+ else:
+ self._step_index = self._begin_index
+
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
+ return sample
+
+ @staticmethod
+ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15):
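+        """Return the line through (x1, y1) and (x2, y2), used to derive the Flux shift `mu` from a token count."""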
+ m = (y2 - y1) / (x2 - x1)
+ b = y1 - m * x1
+ return lambda x: m * x + b
+
+ @staticmethod
+ def flux_time_shift(mu: float, sigma: float, t: torch.Tensor):
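+        """Flux-style sigmoid time shift: t' = exp(mu) / (exp(mu) + (1/t - 1) ** sigma)."""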
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+ def sd3_time_shift(self, t: torch.Tensor):
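+        """SD3-style time shift: t' = s * t / (1 + (s - 1) * t), where s = config.shift."""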
+ return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ sample: torch.FloatTensor,
+ pred_uncond: torch.FloatTensor = None,
+ generator: Optional[torch.Generator] = None,
+ n_tokens: Optional[int] = None,
+ return_dict: bool = True,
+ ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`):
+ The direct output from learned diffusion model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ A current instance of a sample created by the diffusion process.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ n_tokens (`int`, *optional*):
+ Number of tokens in the input sequence.
+ return_dict (`bool`):
+                Whether or not to return a [`FlowMatchDiscreteSchedulerOutput`] or a plain tuple.
+
+        Returns:
+            [`FlowMatchDiscreteSchedulerOutput`] or `tuple`:
+                If `return_dict` is `True`, a [`FlowMatchDiscreteSchedulerOutput`] is returned;
+                otherwise a tuple whose first element is the sample tensor.
+ """
+
+ if (
+ isinstance(timestep, int)
+ or isinstance(timestep, torch.IntTensor)
+ or isinstance(timestep, torch.LongTensor)
+ ):
+ raise ValueError(
+ (
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
+ " one of the `scheduler.timesteps` as a timestep."
+ ),
+ )
+
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ # Upcast to avoid precision issues when computing prev_sample
+ sample = sample.to(torch.float32)
+ model_output = model_output.to(torch.float32)
+ pred_uncond = pred_uncond.to(torch.float32) if pred_uncond is not None else None
+
+        sigma = self.sigmas[self.step_index]
+        sigma_next = self.sigmas[self.step_index + 1]
+
+ last_inner_step = True
+ if self.config.solver == "euler":
+ derivative, dt, sample, last_inner_step = self.first_order_method(model_output, sigma, sigma_next, sample)
+ elif self.config.solver in ["heun-2", "midpoint-2"]:
+ derivative, dt, sample, last_inner_step = self.second_order_method(model_output, sigma, sigma_next, sample)
+ elif self.config.solver == "kutta-4":
+ derivative, dt, sample, last_inner_step = self.fourth_order_method(model_output, sigma, sigma_next, sample)
+ else:
+ raise ValueError(f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}")
+
+ prev_sample = sample + derivative * dt
+
+        # prev_sample intentionally stays in float32; the caller handles any dtype
+        # conversion (e.g. under autocast) before the next model call.
+
+ # upon completion increase step index by one
+ if last_inner_step:
+ self._step_index += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
+
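+    # The multi-stage solvers below spread their stages across successive `step()`
+    # calls: intermediate stages return last_inner_step=False, which keeps the step
+    # index (and hence sigma) fixed until the final stage combines the stored
+    # derivatives into one update. Each stage costs one model evaluation.
+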
+ def first_order_method(self, model_output, sigma, sigma_next, sample):
+ derivative = model_output.float()
+ dt = sigma_next - sigma
+ return derivative, dt, sample, True
+
+ def second_order_method(self, model_output, sigma, sigma_next, sample):
+ if self.state_in_first_order:
+ # store for 2nd order step
+ self.derivative_1 = model_output
+ self.dt = sigma_next - sigma
+ self.sample = sample
+
+ derivative = model_output
+ if self.config.solver == 'heun-2':
+ dt = self.dt
+ elif self.config.solver == 'midpoint-2':
+ dt = self.dt / 2
+ else:
+ raise NotImplementedError(f"Solver {self.config.solver} not supported.")
+ last_inner_step = False
+
+ else:
+ if self.config.solver == 'heun-2':
+ derivative = 0.5 * (self.derivative_1 + model_output)
+ elif self.config.solver == 'midpoint-2':
+ derivative = model_output
+ else:
+ raise NotImplementedError(f"Solver {self.config.solver} not supported.")
+
+ # 3. take prev timestep & sample
+ dt = self.dt
+ sample = self.sample
+ last_inner_step = True
+
+ # free dt and derivative
+ # Note, this puts the scheduler in "first order mode"
+ self.derivative_1 = None
+ self.dt = None
+ self.sample = None
+
+ return derivative, dt, sample, last_inner_step
+
+ def fourth_order_method(self, model_output, sigma, sigma_next, sample):
+ if self.state_in_first_order:
+ self.derivative_1 = model_output
+ self.dt = sigma_next - sigma
+ self.sample = sample
+ derivative = model_output
+ dt = self.dt / 2
+ last_inner_step = False
+
+        elif self.state_in_second_order:
+            self.derivative_2 = model_output
+            derivative = model_output
+            dt = self.dt / 2
+            # Classic RK4 evaluates every stage from the step's base sample, not
+            # from the intermediate sample produced by the previous stage.
+            sample = self.sample
+            last_inner_step = False
+
+        elif self.state_in_third_order:
+            self.derivative_3 = model_output
+            derivative = model_output
+            dt = self.dt
+            sample = self.sample
+            last_inner_step = False
+
+ else:
+ derivative = 1/6 * self.derivative_1 + 1/3 * self.derivative_2 + 1/3 * self.derivative_3 + 1/6 * model_output
+
+ # 3. take prev timestep & sample
+ dt = self.dt
+ sample = self.sample
+ last_inner_step = True
+
+ # free dt and derivative
+ # Note, this puts the scheduler in "first order mode"
+ self.derivative_1 = None
+ self.derivative_2 = None
+ self.derivative_3 = None
+ self.dt = None
+ self.sample = None
+
+ return derivative, dt, sample, last_inner_step
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f95a2ef41a38066e9991861ef6c8e9a299fea561
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,52 @@
+# Core ML dependencies - CPU optimized
+torch>=2.0.0,<2.3.0
+torchvision>=0.15.0,<0.18.0
+torchaudio>=2.0.0,<2.3.0
+numpy>=1.21.0,<1.27.0
+scipy>=1.7.0
+
+# Deep Learning frameworks
+diffusers>=0.24.0
+timm>=0.9.0
+accelerate>=0.20.0
+
+# Transformers and NLP - use stable version
+transformers>=4.35.0,<4.50.0
+sentencepiece>=0.1.99
+
+# Audio processing - use HTTPS version for HF Spaces
+audiotools @ git+https://github.com/descriptinc/audiotools.git
+
+# Video/Image processing
+pillow>=8.3.0
+av>=10.0.0
+einops>=0.6.0
+
+# Configuration and utilities
+pyyaml>=6.0
+omegaconf>=2.3.0
+easydict>=1.9.0
+loguru>=0.6.0
+tqdm>=4.64.0
+setuptools>=65.0.0
+
+# Data handling
+pandas>=1.3.0
+pyarrow>=10.0.0
+
+# Web interface - compatible version for HF Spaces
+gradio>=4.0.0,<5.0.0
+
+# Network
+urllib3>=1.26.0,<3.0.0
+
+# Hugging Face integration
+huggingface_hub>=0.16.0
+datasets>=2.14.0
+
+# Additional dependencies for stability
+packaging>=21.0
+typing-extensions>=4.0.0
+
+# Optional: reduce memory usage
+psutil>=5.8.0
\ No newline at end of file