Spaces:
Running on Zero
Running on Zero
Upload 100 files
Browse files. This view is limited to 50 files because it contains too many changes. See raw diff.
- packages/ltx-core/README.md +280 -0
- packages/ltx-core/pyproject.toml +37 -0
- packages/ltx-core/src/ltx_core/__init__.py +0 -0
- packages/ltx-core/src/ltx_core/components/__init__.py +10 -0
- packages/ltx-core/src/ltx_core/components/diffusion_steps.py +22 -0
- packages/ltx-core/src/ltx_core/components/guiders.py +198 -0
- packages/ltx-core/src/ltx_core/components/noisers.py +35 -0
- packages/ltx-core/src/ltx_core/components/patchifiers.py +348 -0
- packages/ltx-core/src/ltx_core/components/protocols.py +101 -0
- packages/ltx-core/src/ltx_core/components/schedulers.py +129 -0
- packages/ltx-core/src/ltx_core/conditioning/__init__.py +12 -0
- packages/ltx-core/src/ltx_core/conditioning/exceptions.py +4 -0
- packages/ltx-core/src/ltx_core/conditioning/item.py +20 -0
- packages/ltx-core/src/ltx_core/conditioning/types/__init__.py +9 -0
- packages/ltx-core/src/ltx_core/conditioning/types/keyframe_cond.py +53 -0
- packages/ltx-core/src/ltx_core/conditioning/types/latent_cond.py +44 -0
- packages/ltx-core/src/ltx_core/guidance/__init__.py +15 -0
- packages/ltx-core/src/ltx_core/guidance/perturbations.py +79 -0
- packages/ltx-core/src/ltx_core/loader/__init__.py +48 -0
- packages/ltx-core/src/ltx_core/loader/fuse_loras.py +100 -0
- packages/ltx-core/src/ltx_core/loader/kernels.py +72 -0
- packages/ltx-core/src/ltx_core/loader/module_ops.py +14 -0
- packages/ltx-core/src/ltx_core/loader/primitives.py +109 -0
- packages/ltx-core/src/ltx_core/loader/registry.py +84 -0
- packages/ltx-core/src/ltx_core/loader/sd_ops.py +127 -0
- packages/ltx-core/src/ltx_core/loader/sft_loader.py +63 -0
- packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py +101 -0
- packages/ltx-core/src/ltx_core/model/__init__.py +8 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/__init__.py +27 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/attention.py +71 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py +480 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/causal_conv_2d.py +110 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/causality_axis.py +10 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/downsample.py +110 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/model_configurator.py +123 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/ops.py +76 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/resnet.py +176 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/upsample.py +106 -0
- packages/ltx-core/src/ltx_core/model/audio_vae/vocoder.py +123 -0
- packages/ltx-core/src/ltx_core/model/common/__init__.py +9 -0
- packages/ltx-core/src/ltx_core/model/common/normalization.py +59 -0
- packages/ltx-core/src/ltx_core/model/model_protocol.py +10 -0
- packages/ltx-core/src/ltx_core/model/transformer/__init__.py +24 -0
- packages/ltx-core/src/ltx_core/model/transformer/adaln.py +34 -0
- packages/ltx-core/src/ltx_core/model/transformer/attention.py +185 -0
- packages/ltx-core/src/ltx_core/model/transformer/feed_forward.py +15 -0
- packages/ltx-core/src/ltx_core/model/transformer/gelu_approx.py +10 -0
- packages/ltx-core/src/ltx_core/model/transformer/modality.py +23 -0
- packages/ltx-core/src/ltx_core/model/transformer/model.py +468 -0
- packages/ltx-core/src/ltx_core/model/transformer/model_configurator.py +237 -0
packages/ltx-core/README.md
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LTX-Core
|
| 2 |
+
|
| 3 |
+
The foundational library for the LTX-2 Audio-Video generation model. This package contains the raw model definitions, component implementations, and loading logic used by `ltx-pipelines` and `ltx-trainer`.
|
| 4 |
+
|
| 5 |
+
## 📦 What's Inside?
|
| 6 |
+
|
| 7 |
+
- **`components/`**: Modular diffusion components (Schedulers, Guiders, Noisers, Patchifiers) following standard protocols
|
| 8 |
+
- **`conditioning/`**: Tools for preparing latent states and applying conditioning (image, video, keyframes)
|
| 9 |
+
- **`guidance/`**: Perturbation system for fine-grained control over attention mechanisms
|
| 10 |
+
- **`loader/`**: Utilities for loading weights from `.safetensors`, fusing LoRAs, and managing memory
|
| 11 |
+
- **`model/`**: PyTorch implementations of the LTX-2 Transformer, Video VAE, Audio VAE, Vocoder and Upscaler
|
| 12 |
+
- **`text_encoders/gemma`**: Gemma text encoder implementation with tokenizers, feature extractors, and separate encoders for audio-video and video-only generation
|
| 13 |
+
|
| 14 |
+
## 🚀 Quick Start
|
| 15 |
+
|
| 16 |
+
`ltx-core` provides the building blocks (models, components, and utilities) needed to construct inference flows. For ready-made inference pipelines use [`ltx-pipelines`](../ltx-pipelines/) or [`ltx-trainer`](../ltx-trainer/) for training.
|
| 17 |
+
|
| 18 |
+
## 🔧 Installation
|
| 19 |
+
|
| 20 |
+
```bash
|
| 21 |
+
# From the repository root
|
| 22 |
+
uv sync --frozen
|
| 23 |
+
|
| 24 |
+
# Or install as a package
|
| 25 |
+
pip install -e packages/ltx-core
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
## Building Blocks Overview
|
| 29 |
+
|
| 30 |
+
`ltx-core` provides modular components that can be combined to build custom inference flows:
|
| 31 |
+
|
| 32 |
+
### Core Models
|
| 33 |
+
|
| 34 |
+
- **Transformer** ([`model/transformer/`](src/ltx_core/model/transformer/)): The 48-layer LTX-2 transformer with cross-modal attention for joint audio-video processing. Expects inputs in [`Modality`](src/ltx_core/model/transformer/modality.py) format
|
| 35 |
+
- **Video VAE** ([`model/video_vae/`](src/ltx_core/model/video_vae/)): Encodes/decodes video pixels to/from latent space with temporal and spatial compression
|
| 36 |
+
- **Audio VAE** ([`model/audio_vae/`](src/ltx_core/model/audio_vae/)): Encodes/decodes audio spectrograms to/from latent space
|
| 37 |
+
- **Vocoder** ([`model/audio_vae/`](src/ltx_core/model/audio_vae/)): Neural vocoder that converts mel spectrograms to audio waveforms
|
| 38 |
+
- **Text Encoder** ([`text_encoders/`](src/ltx_core/text_encoders/)): Gemma-based encoder that produces separate embeddings for video and audio conditioning
|
| 39 |
+
- **Spatial Upscaler** ([`model/upsampler/`](src/ltx_core/model/upsampler/)): Upsamples latent representations for higher-resolution generation
|
| 40 |
+
|
| 41 |
+
### Diffusion Components
|
| 42 |
+
|
| 43 |
+
- **Schedulers** ([`components/schedulers.py`](src/ltx_core/components/schedulers.py)): Noise schedules (LTX2Scheduler, LinearQuadratic, Beta) that control the denoising process
|
| 44 |
+
- **Guiders** ([`components/guiders.py`](src/ltx_core/components/guiders.py)): Guidance strategies (CFG, STG, APG) for controlling generation quality and adherence to prompts
|
| 45 |
+
- **Noisers** ([`components/noisers.py`](src/ltx_core/components/noisers.py)): Add noise to latents according to the diffusion schedule
|
| 46 |
+
- **Patchifiers** ([`components/patchifiers.py`](src/ltx_core/components/patchifiers.py)): Convert between spatial latents `[B, C, F, H, W]` and sequence format `[B, seq_len, dim]` for transformer processing
|
| 47 |
+
|
| 48 |
+
### Conditioning & Control
|
| 49 |
+
|
| 50 |
+
- **Conditioning** ([`conditioning/`](src/ltx_core/conditioning/)): Tools for preparing and applying various conditioning types (image, video, keyframes)
|
| 51 |
+
- **Guidance** ([`guidance/`](src/ltx_core/guidance/)): Perturbation system for fine-grained control over attention mechanisms (e.g., skipping specific attention layers)
|
| 52 |
+
|
| 53 |
+
### Utilities
|
| 54 |
+
|
| 55 |
+
- **Loader** ([`loader/`](src/ltx_core/loader/)): Model loading from `.safetensors`, LoRA fusion, weight remapping, and memory management
|
| 56 |
+
|
| 57 |
+
For complete, production-ready pipeline implementations that combine these building blocks, see the [`ltx-pipelines`](../ltx-pipelines/) package.
|
| 58 |
+
|
| 59 |
+
---
|
| 60 |
+
|
| 61 |
+
# Architecture Overview
|
| 62 |
+
|
| 63 |
+
This section provides a deep dive into the internal architecture of the LTX-2 Audio-Video generation model.
|
| 64 |
+
|
| 65 |
+
## Table of Contents
|
| 66 |
+
|
| 67 |
+
1. [High-Level Architecture](#high-level-architecture)
|
| 68 |
+
2. [The Transformer](#the-transformer)
|
| 69 |
+
3. [Video VAE](#video-vae)
|
| 70 |
+
4. [Audio VAE](#audio-vae)
|
| 71 |
+
5. [Text Encoding (Gemma)](#text-encoding-gemma)
|
| 72 |
+
6. [Spatial Upscaler](#upscaler)
|
| 73 |
+
7. [Data Flow](#data-flow)
|
| 74 |
+
|
| 75 |
+
---
|
| 76 |
+
|
| 77 |
+
## High-Level Architecture
|
| 78 |
+
|
| 79 |
+
LTX-2 is a **joint Audio-Video diffusion transformer** that processes both modalities simultaneously in a unified architecture. Unlike traditional models that handle video and audio separately, LTX-2 uses cross-modal attention to enable natural synchronization.
|
| 80 |
+
|
| 81 |
+
```text
|
| 82 |
+
┌─────────────────────────────────────────────────────────────┐
|
| 83 |
+
│ INPUT PREPARATION │
|
| 84 |
+
│ │
|
| 85 |
+
│ Video Pixels → Video VAE Encoder → Video Latents │
|
| 86 |
+
│ Audio Waveform → Audio VAE Encoder → Audio Latents │
|
| 87 |
+
│ Text Prompt → Gemma Encoder → Text Embeddings │
|
| 88 |
+
└─────────────────────────────────────────────────────────────┘
|
| 89 |
+
↓
|
| 90 |
+
┌─────────────────────────────────────────────────────────────┐
|
| 91 |
+
│ LTX-2 TRANSFORMER (48 Blocks) │
|
| 92 |
+
│ │
|
| 93 |
+
│ ┌──────────────┐ ┌──────────────┐ │
|
| 94 |
+
│ │ Video Stream │ │ Audio Stream │ │
|
| 95 |
+
│ │ │ │ │ │
|
| 96 |
+
│ │ Self-Attn │ │ Self-Attn │ │
|
| 97 |
+
│ │ Cross-Attn │ │ Cross-Attn │ │
|
| 98 |
+
│ │ │◄────────────►│ │ │
|
| 99 |
+
│ │ A↔V Cross │ │ A↔V Cross │ │
|
| 100 |
+
│ │ Feed-Forward │ │ Feed-Forward │ │
|
| 101 |
+
│ └──────────────┘ └──────────────┘ │
|
| 102 |
+
└─────────────────────────────────────────────────────────────┘
|
| 103 |
+
↓
|
| 104 |
+
┌─────────────────────────────────────────────────────────────┐
|
| 105 |
+
│ OUTPUT DECODING │
|
| 106 |
+
│ │
|
| 107 |
+
│ Video Latents → Video VAE Decoder → Video Pixels │
|
| 108 |
+
│ Audio Latents → Audio VAE Decoder → Mel Spectrogram │
|
| 109 |
+
│ Mel Spectrogram → Vocoder → Audio Waveform │
|
| 110 |
+
└─────────────────────────────────────────────────────────────┘
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
---
|
| 114 |
+
|
| 115 |
+
## The Transformer
|
| 116 |
+
|
| 117 |
+
The core of LTX-2 is a 48-layer transformer that processes both video and audio tokens simultaneously.
|
| 118 |
+
|
| 119 |
+
### Model Structure
|
| 120 |
+
|
| 121 |
+
**Source**: [`src/ltx_core/model/transformer/model.py`](src/ltx_core/model/transformer/model.py)
|
| 122 |
+
|
| 123 |
+
The `LTXModel` class implements the transformer. It supports both video-only and audio-video generation modes. For actual usage, see the [`ltx-pipelines`](../ltx-pipelines/) package which handles model loading and initialization.
|
| 124 |
+
|
| 125 |
+
### Transformer Block Architecture
|
| 126 |
+
|
| 127 |
+
**Source**: [`src/ltx_core/model/transformer/transformer.py`](src/ltx_core/model/transformer/transformer.py)
|
| 128 |
+
|
| 129 |
+
```text
|
| 130 |
+
┌─────────────────────────────────────────────────────────────┐
|
| 131 |
+
│ TRANSFORMER BLOCK │
|
| 132 |
+
│ │
|
| 133 |
+
│ VIDEO PATH: │
|
| 134 |
+
│ Input → RMSNorm → AdaLN → Self-Attn (attn1) │
|
| 135 |
+
│ → RMSNorm → Cross-Attn (attn2, text) │
|
| 136 |
+
│ → RMSNorm → AdaLN → A↔V Cross-Attn │
|
| 137 |
+
│ → RMSNorm → AdaLN → Feed-Forward (ff) → Output │
|
| 138 |
+
│ │
|
| 139 |
+
│ AUDIO PATH: │
|
| 140 |
+
│ Input → RMSNorm → AdaLN → Self-Attn (audio_attn1) │
|
| 141 |
+
│ → RMSNorm → Cross-Attn (audio_attn2, text) │
|
| 142 |
+
│ → RMSNorm → AdaLN → A↔V Cross-Attn │
|
| 143 |
+
│ → RMSNorm → AdaLN → Feed-Forward (audio_ff) │
|
| 144 |
+
│ │
|
| 145 |
+
│ AdaLN (Adaptive Layer Normalization): │
|
| 146 |
+
│ - Uses scale_shift_table (6 params) for video/audio │
|
| 147 |
+
│ - Uses scale_shift_table_a2v_ca (5 params) for A↔V CA │
|
| 148 |
+
│ - Conditioned on per-token timestep embeddings │
|
| 149 |
+
└─────────────────────────────────────────────────────────────┘
|
| 150 |
+
```
|
| 151 |
+
|
| 152 |
+
### Perturbations
|
| 153 |
+
|
| 154 |
+
The transformer supports [**perturbations**](src/ltx_core/guidance/perturbations.py) that selectively skip attention operations.
|
| 155 |
+
|
| 156 |
+
Perturbations allow you to disable specific attention mechanisms during inference, which is useful for guidance techniques like STG (Spatio-Temporal Guidance).
|
| 157 |
+
|
| 158 |
+
**Supported Perturbation Types**:
|
| 159 |
+
|
| 160 |
+
- `SKIP_VIDEO_SELF_ATTN`: Skip video self-attention
|
| 161 |
+
- `SKIP_AUDIO_SELF_ATTN`: Skip audio self-attention
|
| 162 |
+
- `SKIP_A2V_CROSS_ATTN`: Skip audio-to-video cross-attention
|
| 163 |
+
- `SKIP_V2A_CROSS_ATTN`: Skip video-to-audio cross-attention
|
| 164 |
+
|
| 165 |
+
Perturbations are used internally by guidance mechanisms like STG (Spatio-Temporal Guidance). For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.
|
| 166 |
+
|
| 167 |
+
---
|
| 168 |
+
|
| 169 |
+
## Video VAE
|
| 170 |
+
|
| 171 |
+
The Video VAE ([`src/ltx_core/model/video_vae/`](src/ltx_core/model/video_vae/)) encodes video pixels into latent representations and decodes them back.
|
| 172 |
+
|
| 173 |
+
### Architecture
|
| 174 |
+
|
| 175 |
+
- **Encoder**: Compresses `[B, 3, F, H, W]` pixels → `[B, 128, F', H/32, W/32]` latents
|
| 176 |
+
- Where `F' = 1 + (F-1)/8` (frame count must satisfy `(F-1) % 8 == 0`)
|
| 177 |
+
- Example: `[B, 3, 33, 512, 512]` → `[B, 128, 5, 16, 16]`
|
| 178 |
+
- **Decoder**: Expands `[B, 128, F, H, W]` latents → `[B, 3, F', H*32, W*32]` pixels
|
| 179 |
+
- Where `F' = 1 + (F-1)*8`
|
| 180 |
+
- Example: `[B, 128, 5, 16, 16]` → `[B, 3, 33, 512, 512]`
|
| 181 |
+
|
| 182 |
+
The Video VAE is used internally by pipelines for encoding video pixels to latents and decoding latents back to pixels. For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.
|
| 183 |
+
|
| 184 |
+
---
|
| 185 |
+
|
| 186 |
+
## Audio VAE
|
| 187 |
+
|
| 188 |
+
The Audio VAE ([`src/ltx_core/model/audio_vae/`](src/ltx_core/model/audio_vae/)) processes audio spectrograms.
|
| 189 |
+
|
| 190 |
+
### Audio VAE Architecture
|
| 191 |
+
|
| 192 |
+
- **Encoder**: Compresses mel spectrogram `[B, mel_bins, T]` → `[B, 8, T/4, 16]` latents
|
| 193 |
+
- Temporal downsampling: 4× (`LATENT_DOWNSAMPLE_FACTOR = 4`)
|
| 194 |
+
- Frequency bins: Fixed 16 mel bins in latent space
|
| 195 |
+
- Latent channels: 8
|
| 196 |
+
- **Decoder**: Expands `[B, 8, T, 16]` latents → mel spectrogram `[B, mel_bins, T*4]`
|
| 197 |
+
- **Vocoder**: Converts mel spectrogram → audio waveform
|
| 198 |
+
|
| 199 |
+
**Downsampling**:
|
| 200 |
+
|
| 201 |
+
- Temporal: 4× (time steps)
|
| 202 |
+
- Frequency: Variable (input mel_bins → fixed 16 in latent space)
|
| 203 |
+
|
| 204 |
+
The Audio VAE is used internally by pipelines for encoding mel spectrograms to latents and decoding latents back to mel spectrograms. The vocoder converts mel spectrograms to audio waveforms. For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.
|
| 205 |
+
|
| 206 |
+
---
|
| 207 |
+
|
| 208 |
+
## Text Encoding (Gemma)
|
| 209 |
+
|
| 210 |
+
LTX-2 uses **Gemma** (Google's open LLM) as the text encoder, located in [`src/ltx_core/text_encoders/gemma/`](src/ltx_core/text_encoders/gemma/).
|
| 211 |
+
|
| 212 |
+
### Text Encoder Architecture
|
| 213 |
+
|
| 214 |
+
- **Tokenizer**: Converts text → token IDs
|
| 215 |
+
- **Gemma Model**: Processes tokens → embeddings
|
| 216 |
+
- **Text Projection**: Uses `PixArtAlphaTextProjection` to project caption embeddings
|
| 217 |
+
- Two-layer MLP with GELU (tanh approximation) or SiLU activation
|
| 218 |
+
- Projects from caption channels (3840) to model dimensions
|
| 219 |
+
- **Feature Extractor**: Extracts video/audio-specific embeddings
|
| 220 |
+
- **Separate Encoders**:
|
| 221 |
+
- `AVEncoder`: For audio-video generation (outputs separate video and audio contexts)
|
| 222 |
+
- `VideoOnlyEncoder`: For video-only generation
|
| 223 |
+
|
| 224 |
+
### System Prompts
|
| 225 |
+
|
| 226 |
+
System prompts are also used to enhance the user's prompt.
|
| 227 |
+
|
| 228 |
+
- **Text-to-Video**: [`gemma_t2v_system_prompt.txt`](src/ltx_core/text_encoders/gemma/encoders/prompts/gemma_t2v_system_prompt.txt)
|
| 229 |
+
- **Image-to-Video**: [`gemma_i2v_system_prompt.txt`](src/ltx_core/text_encoders/gemma/encoders/prompts/gemma_i2v_system_prompt.txt)
|
| 230 |
+
|
| 231 |
+
**Important**: Video and audio receive **different** context embeddings, even from the same prompt. This allows better modality-specific conditioning.
|
| 232 |
+
|
| 233 |
+
**Output Format**:
|
| 234 |
+
|
| 235 |
+
- Video context: `[B, seq_len, 4096]` - Video-specific text embeddings
|
| 236 |
+
- Audio context: `[B, seq_len, 2048]` - Audio-specific text embeddings
|
| 237 |
+
|
| 238 |
+
The text encoder is used internally by pipelines. For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.
|
| 239 |
+
|
| 240 |
+
---
|
| 241 |
+
|
| 242 |
+
## Upscaler
|
| 243 |
+
|
| 244 |
+
The Upscaler ([`src/ltx_core/model/upsampler/`](src/ltx_core/model/upsampler/)) upsamples latent representations for higher-resolution output.
|
| 245 |
+
|
| 246 |
+
The spatial upsampler is used internally by two-stage pipelines (e.g., [`TI2VidTwoStagesPipeline`](../ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py), [`ICLoraPipeline`](../ltx-pipelines/src/ltx_pipelines/ic_lora.py)) to upsample low-resolution latents before final VAE decoding. For usage examples, see the [`ltx-pipelines`](../ltx-pipelines/) package.
|
| 247 |
+
|
| 248 |
+
---
|
| 249 |
+
|
| 250 |
+
## Data Flow
|
| 251 |
+
|
| 252 |
+
### Complete Generation Pipeline
|
| 253 |
+
|
| 254 |
+
Here's how all the components work together conceptually ([`src/ltx_core/components/`](src/ltx_core/components/)):
|
| 255 |
+
|
| 256 |
+
**Pipeline Steps**:
|
| 257 |
+
|
| 258 |
+
1. **Text Encoding**: Text prompt → Gemma encoder → separate video/audio embeddings
|
| 259 |
+
2. **Latent Initialization**: Initialize noise latents in spatial format `[B, C, F, H, W]`
|
| 260 |
+
3. **Patchification**: Convert spatial latents to sequence format `[B, seq_len, dim]` for transformer
|
| 261 |
+
4. **Sigma Schedule**: Generate noise schedule (adapts to token count)
|
| 262 |
+
5. **Denoising Loop**: Iteratively denoise using transformer predictions
|
| 263 |
+
- Create Modality inputs with per-token timesteps and RoPE positions
|
| 264 |
+
- Forward pass through transformer (conditional and unconditional for CFG)
|
| 265 |
+
- Apply guidance (CFG, STG, etc.)
|
| 266 |
+
- Update latents using diffusion step (Euler, etc.)
|
| 267 |
+
6. **Unpatchification**: Convert sequence back to spatial format
|
| 268 |
+
7. **VAE Decoding**: Decode latents to pixel space (with optional upsampling for two-stage)
|
| 269 |
+
|
| 270 |
+
- [`TI2VidTwoStagesPipeline`](../ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py) - Two-stage text-to-video (recommended)
|
| 271 |
+
- [`ICLoraPipeline`](../ltx-pipelines/src/ltx_pipelines/ic_lora.py) - Video-to-video with IC-LoRA control
|
| 272 |
+
- [`DistilledPipeline`](../ltx-pipelines/src/ltx_pipelines/distilled.py) - Fast inference with distilled model
|
| 273 |
+
- [`KeyframeInterpolationPipeline`](../ltx-pipelines/src/ltx_pipelines/keyframe_interpolation.py) - Keyframe-based interpolation
|
| 274 |
+
|
| 275 |
+
See the [ltx-pipelines README](../ltx-pipelines/README.md) for usage examples.
|
| 276 |
+
|
| 277 |
+
## 🔗 Related Projects
|
| 278 |
+
|
| 279 |
+
- **[ltx-pipelines](../ltx-pipelines/)** - High-level pipeline implementations for text-to-video, image-to-video, and video-to-video
|
| 280 |
+
- **[ltx-trainer](../ltx-trainer/)** - Training and fine-tuning tools
|
packages/ltx-core/pyproject.toml
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "ltx-core"
|
| 3 |
+
version = "1.0.0"
|
| 4 |
+
description = "Core implementation of Lightricks' LTX-2 model"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.10"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"torch~=2.7",
|
| 9 |
+
"torchaudio",
|
| 10 |
+
"einops",
|
| 11 |
+
"numpy",
|
| 12 |
+
"transformers",
|
| 13 |
+
"safetensors",
|
| 14 |
+
"accelerate",
|
| 15 |
+
"scipy>=1.14",
|
| 16 |
+
]
|
| 17 |
+
|
| 18 |
+
[project.optional-dependencies]
|
| 19 |
+
xformers = ["xformers"]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
[tool.uv.sources]
|
| 23 |
+
xformers = { index = "pytorch" }
|
| 24 |
+
|
| 25 |
+
[[tool.uv.index]]
|
| 26 |
+
name = "pytorch"
|
| 27 |
+
url = "https://download.pytorch.org/whl/cu129"
|
| 28 |
+
explicit = true
|
| 29 |
+
|
| 30 |
+
[build-system]
|
| 31 |
+
requires = ["uv_build>=0.9.8,<0.10.0"]
|
| 32 |
+
build-backend = "uv_build"
|
| 33 |
+
|
| 34 |
+
[dependency-groups]
|
| 35 |
+
dev = [
|
| 36 |
+
"scikit-image>=0.25.2",
|
| 37 |
+
]
|
packages/ltx-core/src/ltx_core/__init__.py
ADDED
|
File without changes
|
packages/ltx-core/src/ltx_core/components/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Diffusion pipeline components.
|
| 3 |
+
Submodules:
|
| 4 |
+
diffusion_steps - Diffusion stepping algorithms (EulerDiffusionStep)
|
| 5 |
+
guiders - Guidance strategies (CFGGuider, STGGuider, APG variants)
|
| 6 |
+
noisers - Noise samplers (GaussianNoiser)
|
| 7 |
+
patchifiers - Latent patchification (VideoLatentPatchifier, AudioPatchifier)
|
| 8 |
+
protocols - Protocol definitions (Patchifier, etc.)
|
| 9 |
+
schedulers - Sigma schedulers (LTX2Scheduler, LinearQuadraticScheduler)
|
| 10 |
+
"""
|
packages/ltx-core/src/ltx_core/components/diffusion_steps.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from ltx_core.components.protocols import DiffusionStepProtocol
|
| 4 |
+
from ltx_core.utils import to_velocity
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class EulerDiffusionStep(DiffusionStepProtocol):
    """First-order Euler update for diffusion sampling.

    Advances the sample from the noise level at ``step_index`` to the next
    one by converting the denoised prediction into a velocity and applying
    ``sample + velocity * dt``. The update is accumulated in float32 and
    cast back to the sample's original dtype.
    """

    def step(
        self, sample: torch.Tensor, denoised_sample: torch.Tensor, sigmas: torch.Tensor, step_index: int
    ) -> torch.Tensor:
        # Step size between the current noise level and the next one.
        dt = sigmas[step_index + 1] - sigmas[step_index]
        velocity = to_velocity(sample, sigmas[step_index], denoised_sample)
        # Compute in float32 for numerical stability, then restore dtype.
        updated = sample.to(torch.float32) + velocity.to(torch.float32) * dt
        return updated.to(sample.dtype)
|
packages/ltx-core/src/ltx_core/components/guiders.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ltx_core.components.protocols import GuiderProtocol
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass(frozen=True)
class CFGGuider(GuiderProtocol):
    """Classifier-free guidance (CFG) guider.

    Produces the guidance delta ``(scale - 1) * (cond - uncond)``, which
    pushes the denoising trajectory toward the conditioned prediction.

    Attributes:
        scale: Guidance strength. 1.0 means no guidance; larger values
            increase adherence to the conditioning.
    """

    scale: float

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        difference = cond - uncond
        return difference * (self.scale - 1)

    def enabled(self) -> bool:
        # Guidance is a no-op exactly when scale == 1.0.
        return self.scale != 1.0
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass(frozen=True)
class CFGStarRescalingGuider(GuiderProtocol):
    """CFG guider with rescaling of the unconditioned branch.

    Before taking the CFG delta, the unconditioned sample is rescaled by
    its projection coefficient onto the conditioned sample, so the
    resulting guidance moves mostly along the conditioning axis instead of
    offsetting the denoising direction.

    Attributes:
        scale: Global guidance strength. 1.0 corresponds to no extra
            guidance beyond the base model prediction; values > 1.0
            increase the influence of the conditioned sample relative to
            the unconditioned one.
    """

    scale: float

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        # Rescale the negative branch to align with the conditioned sample's norm.
        scaled_uncond = uncond * projection_coef(cond, uncond)
        return (cond - scaled_uncond) * (self.scale - 1)

    def enabled(self) -> bool:
        return self.scale != 1.0
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@dataclass(frozen=True)
class STGGuider(GuiderProtocol):
    """Spatio-temporal guidance (STG) guider.

    Computes the delta between the conditioned denoised sample and a
    perturbed one — the result of denoising with perturbations applied,
    e.g. certain attentions acting as passthrough for some layers and
    modalities — scaled by the guidance strength.

    Attributes:
        scale: Global STG strength. 0.0 disables the guidance; larger
            values apply a stronger correction in the direction of
            (pos_denoised - perturbed_denoised).
    """

    scale: float

    def delta(self, pos_denoised: torch.Tensor, perturbed_denoised: torch.Tensor) -> torch.Tensor:
        correction = pos_denoised - perturbed_denoised
        return correction * self.scale

    def enabled(self) -> bool:
        # A zero scale makes the correction vanish entirely.
        return self.scale != 0.0
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@dataclass(frozen=True)
class LtxAPGGuider(GuiderProtocol):
    """Adaptive projected guidance (APG) guider.

    The raw guidance delta (cond - uncond) is decomposed into components
    parallel and orthogonal to the conditioned sample; ``eta`` weights the
    parallel component, and the recombined delta is scaled by
    ``scale - 1``. This keeps the update moving mostly along the
    conditioning axis within the distribution. Optionally, deltas whose L2
    norm exceeds ``norm_threshold`` are clipped down to it first.

    Attributes:
        scale: Strength of the guidance correction. 1.0 disables it;
            higher values move more aggressively along the recombined
            guidance direction.
        eta: Weight of the component parallel to the conditioned sample.
            1.0 keeps the full parallel component; values in [0, 1]
            attenuate it, values > 1.0 amplify motion along the
            conditioning direction.
        norm_threshold: If > 0, the guidance delta's L2 norm (over the
            last three dims, per sample) is capped at this value,
            suppressing noisy or unstable updates when the signal is large.
    """

    scale: float
    eta: float = 1.0
    norm_threshold: float = 0.0

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        raw = cond - uncond
        if self.norm_threshold > 0:
            # Clip the delta so its per-sample L2 norm never exceeds the threshold.
            norm = raw.norm(p=2, dim=[-1, -2, -3], keepdim=True)
            clip = torch.minimum(torch.ones_like(raw), self.norm_threshold / norm)
            raw = raw * clip
        # Decompose relative to the conditioned sample.
        parallel = projection_coef(raw, cond) * cond
        orthogonal = raw - parallel
        recombined = parallel * self.eta + orthogonal
        return recombined * (self.scale - 1)

    def enabled(self) -> bool:
        return self.scale != 1.0
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
@dataclass(frozen=False)
class LegacyStatefulAPGGuider(GuiderProtocol):
    """
    Stateful APG (adaptive projected guidance) with optional momentum.

    The (cond - uncond) delta is optionally accumulated into a running
    momentum average, then clipped to a maximum L2 norm, and finally
    decomposed into components parallel and orthogonal to the conditioned
    sample. The parallel component is weighted by ``eta`` and the combined
    result is multiplied by ``scale``.

    Attributes:
        scale: Multiplier applied to the final projected guidance delta.
        eta: Weight of the component parallel to the conditioned sample.
            1.0 keeps it fully, values in [0, 1] attenuate it, values > 1.0
            amplify it.
        norm_threshold: Maximum L2 norm allowed for the guidance delta;
            larger deltas are rescaled down to this norm. Disabled when not
            positive.
        momentum: Accumulation coefficient applied across calls:
            ``running_avg = momentum * running_avg + guidance``. Disabled
            when zero.
        running_avg: Internal accumulator for the momentum term. Do not
            reuse one guider instance across several denoisings or different
            modalities, otherwise the accumulated average is shared between
            them.
    """

    scale: float
    eta: float
    norm_threshold: float = 5.0
    momentum: float = 0.0
    # It is the user's responsibility not to use the same guider for several
    # denoisings or different modalities, so the accumulated average is not
    # shared across them.
    running_avg: torch.Tensor | None = None

    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        correction = cond - uncond

        if self.momentum != 0:
            # Accumulate guidance across calls; the first call seeds the buffer.
            if self.running_avg is None:
                self.running_avg = correction.clone()
            else:
                self.running_avg = self.momentum * self.running_avg + correction
            correction = self.running_avg

        if self.norm_threshold > 0:
            # Clip the per-sample L2 norm (taken over the last three dims)
            # down to the configured threshold.
            correction_norm = correction.norm(p=2, dim=[-1, -2, -3], keepdim=True)
            shrink = torch.minimum(torch.ones_like(correction), self.norm_threshold / correction_norm)
            correction = correction * shrink

        # Split the guidance into components along and across the conditioned
        # sample, reweight the parallel part, and apply the global scale.
        parallel = projection_coef(correction, cond) * cond
        orthogonal = correction - parallel
        return (parallel * self.eta + orthogonal) * self.scale

    def enabled(self) -> bool:
        return self.scale != 0.0
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def projection_coef(to_project: torch.Tensor, project_onto: torch.Tensor) -> torch.Tensor:
    """
    Per-sample scalar projection coefficient of ``to_project`` onto ``project_onto``.

    Both tensors are flattened per batch element; the result has shape
    ``(batch, 1)`` and equals ``<a, b> / (||b||^2 + 1e-8)``, so that
    ``coef * project_onto`` approximates the component of ``to_project``
    parallel to ``project_onto``.

    NOTE(review): the coefficient keeps shape ``(batch, 1)``; multiplying it
    with a higher-rank tensor relies on broadcasting being valid for the
    caller's shapes — confirm against call sites.
    """
    flat_source = to_project.flatten(start_dim=1)
    flat_target = project_onto.flatten(start_dim=1)
    inner = (flat_source * flat_target).sum(dim=1, keepdim=True)
    # Epsilon guards against division by zero for an all-zero target.
    target_sq_norm = (flat_target * flat_target).sum(dim=1, keepdim=True) + 1e-8
    return inner / target_sq_norm
|
packages/ltx-core/src/ltx_core/components/noisers.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import replace
|
| 2 |
+
from typing import Protocol
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from ltx_core.types import LatentState
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Noiser(Protocol):
    """Protocol for adding noise to a latent state during diffusion.

    Implementations are callables that blend noise into ``latent_state``
    according to ``noise_scale`` and return a new ``LatentState``; they
    should not mutate the input state in place.
    """

    def __call__(self, latent_state: LatentState, noise_scale: float) -> LatentState: ...
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class GaussianNoiser(Noiser):
    """
    Adds Gaussian noise to a latent state, scaled by the denoise mask.

    Each element is blended as ``noise * w + latent * (1 - w)`` where
    ``w = denoise_mask * noise_scale``: positions with weight 1 are fully
    replaced by noise, positions with weight 0 are left untouched.
    """

    def __init__(self, generator: torch.Generator):
        super().__init__()

        # Dedicated RNG keeps noise generation reproducible and isolated
        # from the global torch random state.
        self.generator = generator

    def __call__(self, latent_state: LatentState, noise_scale: float = 1.0) -> LatentState:
        source = latent_state.latent
        gaussian = torch.randn(
            *source.shape,
            device=source.device,
            dtype=source.dtype,
            generator=self.generator,
        )
        blend_weight = latent_state.denoise_mask * noise_scale
        mixed = gaussian * blend_weight + source * (1 - blend_weight)
        # Return a new state; cast back in case the mask arithmetic promoted the dtype.
        return replace(latent_state, latent=mixed.to(source.dtype))
|
packages/ltx-core/src/ltx_core/components/patchifiers.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import einops
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from ltx_core.components.protocols import Patchifier
|
| 8 |
+
from ltx_core.types import AudioLatentShape, SpatioTemporalScaleFactors, VideoLatentShape
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class VideoLatentPatchifier(Patchifier):
    """Converts video latents of shape (B, C, F, H, W) to and from flattened patch-token sequences.

    Temporal patch size is fixed to 1; spatial patch size is ``patch_size`` in
    both height and width.
    """

    def __init__(self, patch_size: int):
        # Patch sizes for video latents.
        self._patch_size = (
            1,  # temporal dimension
            patch_size,  # height dimension
            patch_size,  # width dimension
        )

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """Patch size as (temporal, height, width)."""
        return self._patch_size

    def get_token_count(self, tgt_shape: VideoLatentShape) -> int:
        """Number of patch tokens produced for a latent of the given shape."""
        # Spatio-temporal extent (dims after batch and channels) divided by
        # the per-patch volume.
        return math.prod(tgt_shape.to_torch_shape()[2:]) // math.prod(self._patch_size)

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        """Flatten a (B, C, F, H, W) latent into (B, tokens, features) patch tokens."""
        latents = einops.rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=self._patch_size[0],
            p2=self._patch_size[1],
            p3=self._patch_size[2],
        )

        return latents

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: VideoLatentShape,
    ) -> torch.Tensor:
        """Inverse of ``patchify``: restore a (B, C, F, H, W) latent from patch tokens."""
        assert self._patch_size[0] == 1, "Temporal patch size must be 1 for symmetric patchifier"

        # Number of patches along each axis of the output grid.
        patch_grid_frames = output_shape.frames // self._patch_size[0]
        patch_grid_height = output_shape.height // self._patch_size[1]
        patch_grid_width = output_shape.width // self._patch_size[2]

        latents = einops.rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q)",
            f=patch_grid_frames,
            h=patch_grid_height,
            w=patch_grid_width,
            p=self._patch_size[1],
            q=self._patch_size[2],
        )

        return latents

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return the per-dimension bounds [inclusive start, exclusive end) for every
        patch produced by `patchify`. The bounds are expressed in the original
        video grid coordinates: frame/time, height, and width.

        The resulting tensor is shaped `[batch_size, 3, num_patches, 2]`, where:
        - axis 1 (size 3) enumerates (frame/time, height, width) dimensions
        - axis 3 (size 2) stores `[start, end)` indices within each dimension

        Args:
            output_shape: Video grid description containing frames, height, and width.
            device: Device of the latent tensor.
        """
        if not isinstance(output_shape, VideoLatentShape):
            raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates")

        frames = output_shape.frames
        height = output_shape.height
        width = output_shape.width
        batch_size = output_shape.batch

        # Validate inputs to ensure positive dimensions
        assert frames > 0, f"frames must be positive, got {frames}"
        assert height > 0, f"height must be positive, got {height}"
        assert width > 0, f"width must be positive, got {width}"
        assert batch_size > 0, f"batch_size must be positive, got {batch_size}"

        # Generate grid coordinates for each dimension (frame, height, width)
        # We use torch.arange to create the starting coordinates for each patch.
        # indexing='ij' ensures the dimensions are in the order (frame, height, width).
        grid_coords = torch.meshgrid(
            torch.arange(start=0, end=frames, step=self._patch_size[0], device=device),
            torch.arange(start=0, end=height, step=self._patch_size[1], device=device),
            torch.arange(start=0, end=width, step=self._patch_size[2], device=device),
            indexing="ij",
        )

        # Stack the grid coordinates to create the start coordinates tensor.
        # Shape becomes (3, grid_f, grid_h, grid_w)
        patch_starts = torch.stack(grid_coords, dim=0)

        # Create a tensor containing the size of a single patch:
        # (frame_patch_size, height_patch_size, width_patch_size).
        # Reshape to (3, 1, 1, 1) to enable broadcasting when adding to the start coordinates.
        patch_size_delta = torch.tensor(
            self._patch_size,
            device=patch_starts.device,
            dtype=patch_starts.dtype,
        ).view(3, 1, 1, 1)

        # Calculate end coordinates: start + patch_size
        # Shape becomes (3, grid_f, grid_h, grid_w)
        patch_ends = patch_starts + patch_size_delta

        # Stack start and end coordinates together along the last dimension
        # Shape becomes (3, grid_f, grid_h, grid_w, 2), where the last dimension is [start, end]
        latent_coords = torch.stack((patch_starts, patch_ends), dim=-1)

        # Broadcast to batch size and flatten all spatial/temporal dimensions into one sequence.
        # Final Shape: (batch_size, 3, num_patches, 2)
        latent_coords = einops.repeat(
            latent_coords,
            "c f h w bounds -> b c (f h w) bounds",
            b=batch_size,
            bounds=2,
        )

        return latent_coords
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def get_pixel_coords(
|
| 138 |
+
latent_coords: torch.Tensor,
|
| 139 |
+
scale_factors: SpatioTemporalScaleFactors,
|
| 140 |
+
causal_fix: bool = False,
|
| 141 |
+
) -> torch.Tensor:
|
| 142 |
+
"""
|
| 143 |
+
Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling
|
| 144 |
+
each axis (frame/time, height, width) with the corresponding VAE downsampling factors.
|
| 145 |
+
Optionally compensate for causal encoding that keeps the first frame at unit temporal scale.
|
| 146 |
+
Args:
|
| 147 |
+
latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`.
|
| 148 |
+
scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)` with integer scale factors applied
|
| 149 |
+
per axis.
|
| 150 |
+
causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs
|
| 151 |
+
that treat frame zero differently still yield non-negative timestamps.
|
| 152 |
+
"""
|
| 153 |
+
# Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout.
|
| 154 |
+
broadcast_shape = [1] * latent_coords.ndim
|
| 155 |
+
broadcast_shape[1] = -1 # axis dimension corresponds to (frame/time, height, width)
|
| 156 |
+
scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape)
|
| 157 |
+
|
| 158 |
+
# Apply per-axis scaling to convert latent bounds into pixel-space coordinates.
|
| 159 |
+
pixel_coords = latent_coords * scale_tensor
|
| 160 |
+
|
| 161 |
+
if causal_fix:
|
| 162 |
+
# VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`.
|
| 163 |
+
# Shift and clamp to keep the first-frame timestamps causal and non-negative.
|
| 164 |
+
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
|
| 165 |
+
|
| 166 |
+
return pixel_coords
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class AudioPatchifier(Patchifier):
    """Patchifier for spectrogram/audio latents.

    Tokens are whole latent frames along the time axis; frequency and channel
    content of each frame is flattened into the feature dimension.
    """

    def __init__(
        self,
        patch_size: int,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
        is_causal: bool = True,
        shift: int = 0,
    ):
        """
        Patchifier tailored for spectrogram/audio latents.

        Args:
            patch_size: Number of mel bins combined into a single patch. This
                controls the resolution along the frequency axis.
            sample_rate: Original waveform sampling rate. Used to map latent
                indices back to seconds so downstream consumers can align audio
                and video cues.
            hop_length: Window hop length used for the spectrogram. Determines
                how many real-time samples separate two consecutive latent frames.
            audio_latent_downsample_factor: Ratio between spectrogram frames and
                latent frames; compensates for additional downsampling inside the
                VAE encoder.
            is_causal: When True, timing is shifted to account for causal
                receptive fields so timestamps do not peek into the future.
            shift: Integer offset applied to the latent indices. Enables
                constructing overlapping windows from the same latent sequence.
        """
        self.hop_length = hop_length
        self.sample_rate = sample_rate
        self.audio_latent_downsample_factor = audio_latent_downsample_factor
        self.is_causal = is_causal
        self.shift = shift
        # Kept as a 3-tuple for interface parity with video patchifiers.
        self._patch_size = (1, patch_size, patch_size)

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """Patch size as (temporal, height, width)."""
        return self._patch_size

    def get_token_count(self, tgt_shape: AudioLatentShape) -> int:
        """Number of patch tokens: one token per latent time step."""
        return tgt_shape.frames

    def _get_audio_latent_time_in_sec(
        self,
        start_latent: int,
        end_latent: int,
        dtype: torch.dtype,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Converts latent indices into real-time seconds while honoring causal
        offsets and the configured hop length.

        Args:
            start_latent: Inclusive start index inside the latent sequence. This
                sets the first timestamp returned.
            end_latent: Exclusive end index. Determines how many timestamps get
                generated.
            dtype: Floating-point dtype used for the returned tensor, allowing
                callers to control precision.
            device: Target device for the timestamp tensor. When omitted the
                computation occurs on CPU to avoid surprising GPU allocations.
        """
        if device is None:
            device = torch.device("cpu")

        audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)

        # Latent frames -> spectrogram (mel) frames.
        audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor

        if self.is_causal:
            # Frame offset for causal alignment.
            # The "+1" ensures the timestamp corresponds to the first sample that is fully available.
            causal_offset = 1
            audio_mel_frame = (audio_mel_frame + causal_offset - self.audio_latent_downsample_factor).clip(min=0)

        # Mel frames -> seconds via the spectrogram hop length.
        return audio_mel_frame * self.hop_length / self.sample_rate

    def _compute_audio_timings(
        self,
        batch_size: int,
        num_steps: int,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Builds a `(B, 1, T, 2)` tensor containing timestamps for each latent frame.
        This helper method underpins `get_patch_grid_bounds` for the audio patchifier.

        Args:
            batch_size: Number of sequences to broadcast the timings over.
            num_steps: Number of latent frames (time steps) to convert into timestamps.
            device: Device on which the resulting tensor should reside.
        """
        resolved_device = device
        if resolved_device is None:
            resolved_device = torch.device("cpu")

        # Start timestamps of frames [shift, shift + num_steps).
        start_timings = self._get_audio_latent_time_in_sec(
            self.shift,
            num_steps + self.shift,
            torch.float32,
            resolved_device,
        )
        start_timings = start_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)

        # End timestamps are the start timestamps of the next frame.
        end_timings = self._get_audio_latent_time_in_sec(
            self.shift + 1,
            num_steps + self.shift + 1,
            torch.float32,
            resolved_device,
        )
        end_timings = end_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)

        return torch.stack([start_timings, end_timings], dim=-1)

    def patchify(
        self,
        audio_latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Flattens the audio latent tensor along time. Use `get_patch_grid_bounds`
        to derive timestamps for each latent frame based on the configured hop
        length and downsampling.

        Args:
            audio_latents: Latent tensor to patchify.

        Returns:
            Flattened patch tokens tensor. Use `get_patch_grid_bounds` to compute the
            corresponding timing metadata when needed.
        """
        audio_latents = einops.rearrange(
            audio_latents,
            "b c t f -> b t (c f)",
        )

        return audio_latents

    def unpatchify(
        self,
        audio_latents: torch.Tensor,
        output_shape: AudioLatentShape,
    ) -> torch.Tensor:
        """
        Restores the `(B, C, T, F)` spectrogram tensor from flattened patches.
        Use `get_patch_grid_bounds` to recompute the timestamps that describe each
        frame's position in real time.

        Args:
            audio_latents: Latent tensor to unpatchify.
            output_shape: Shape of the unpatched output tensor.

        Returns:
            Unpatched latent tensor. Use `get_patch_grid_bounds` to compute the timing
            metadata associated with the restored latents.
        """
        # audio_latents shape: (batch, time, freq * channels)
        audio_latents = einops.rearrange(
            audio_latents,
            "b t (c f) -> b c t f",
            c=output_shape.channels,
            f=output_shape.mel_bins,
        )

        return audio_latents

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return the temporal bounds `[inclusive start, exclusive end)` for every
        patch emitted by `patchify`. For audio this corresponds to timestamps in
        seconds aligned with the original spectrogram grid.

        The returned tensor has shape `[batch_size, 1, time_steps, 2]`, where:
        - axis 1 (size 1) represents the temporal dimension
        - axis 3 (size 2) stores the `[start, end)` timestamps per patch

        Args:
            output_shape: Audio grid specification describing the number of time steps.
            device: Target device for the returned tensor.
        """
        if not isinstance(output_shape, AudioLatentShape):
            raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates")

        return self._compute_audio_timings(output_shape.batch, output_shape.frames, device)
|
packages/ltx-core/src/ltx_core/components/protocols.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Protocol, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ltx_core.types import AudioLatentShape, VideoLatentShape
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Patchifier(Protocol):
    """
    Protocol for patchifiers that convert latent tensors into patches and assemble them back.
    """

    def patchify(
        self,
        latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Convert latent tensors into flattened patch tokens.

        Args:
            latents: Latent tensor to patchify.

        Returns:
            Flattened patch tokens tensor.
        """
        # Docstring placed before the ellipsis so it is a real docstring
        # (previously it followed `...` and was a dead string statement).
        ...

    def unpatchify(
        self,
        latents: torch.Tensor,
        output_shape: AudioLatentShape | VideoLatentShape,
    ) -> torch.Tensor:
        """
        Converts latent tensors between spatio-temporal formats and flattened sequence representations.

        Args:
            latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.
            output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or
                VideoLatentShape.

        Returns:
            Dense latent tensor restored from the flattened representation.
        """
        ...

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """
        Returns the patch size as a tuple of (temporal, height, width) dimensions
        """
        ...

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: torch.device | None = None,
    ) -> torch.Tensor:
        """
        Compute metadata describing where each latent patch resides within the
        grid specified by `output_shape`.

        Args:
            output_shape: Target grid layout for the patches.
            device: Target device for the returned tensor.

        Returns:
            Tensor containing patch coordinate metadata such as spatial or temporal intervals.
        """
        ...
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class SchedulerProtocol(Protocol):
    """
    Protocol for schedulers that provide a sigmas schedule tensor for a
    given number of steps. Device is cpu.
    """

    # Returns a 1-D tensor of sigmas for the requested number of steps
    # (implementations in this package typically return steps + 1 values,
    # descending and ending at 0.0; some may deduplicate entries).
    def execute(self, steps: int, **kwargs) -> torch.FloatTensor: ...
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class GuiderProtocol(Protocol):
    """
    Protocol for guiders that compute a delta tensor given conditioning inputs.
    The returned delta should be added to the conditional output (cond), enabling
    multiple guiders to be chained together by accumulating their deltas.
    """

    # Guidance strength; the neutral value depends on the guider (see `enabled`).
    scale: float

    # Computes the guidance delta from conditioned and unconditioned model outputs.
    def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor: ...

    def enabled(self) -> bool:
        """
        Returns whether the corresponding perturbation is enabled. E.g. for CFG, this should return False if the scale
        is 1.0.
        """
        ...
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class DiffusionStepProtocol(Protocol):
    """
    Protocol for diffusion steps that provide a next sample tensor for a given current sample tensor,
    current denoised sample tensor, and sigmas tensor.
    """

    # Produces the next sample from the current sample, its denoised estimate,
    # the sigma schedule, and the index of the current step within it.
    def step(
        self, sample: torch.Tensor, denoised_sample: torch.Tensor, sigmas: torch.Tensor, step_index: int
    ) -> torch.Tensor: ...
|
packages/ltx-core/src/ltx_core/components/schedulers.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from functools import lru_cache
|
| 3 |
+
|
| 4 |
+
import numpy
|
| 5 |
+
import scipy
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from ltx_core.components.protocols import SchedulerProtocol
|
| 9 |
+
|
| 10 |
+
BASE_SHIFT_ANCHOR = 1024
|
| 11 |
+
MAX_SHIFT_ANCHOR = 4096
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class LTX2Scheduler(SchedulerProtocol):
    """
    Default scheduler for LTX-2 diffusion sampling.

    Produces ``steps + 1`` sigmas from 1.0 down to 0.0, shifted by a factor
    that grows linearly with the latent's token count, and optionally
    stretched so the last non-zero sigma equals ``terminal``.
    """

    def execute(
        self,
        steps: int,
        latent: torch.Tensor | None = None,
        max_shift: float = 2.05,
        base_shift: float = 0.95,
        stretch: bool = True,
        terminal: float = 0.1,
        **_kwargs,
    ) -> torch.FloatTensor:
        # Token count drives the shift; without a latent, assume the max anchor.
        token_count = MAX_SHIFT_ANCHOR if latent is None else math.prod(latent.shape[2:])

        sigmas = torch.linspace(1.0, 0.0, steps + 1)

        # Interpolate the shift linearly between the two token-count anchors.
        slope = (max_shift - base_shift) / (MAX_SHIFT_ANCHOR - BASE_SHIFT_ANCHOR)
        intercept = base_shift - slope * BASE_SHIFT_ANCHOR
        sigma_shift = token_count * slope + intercept

        # Flux-style time shift; the final zero sigma is left untouched so the
        # 1/sigma term never contributes a division-by-zero result.
        exp_shift = math.exp(sigma_shift)
        sigmas = torch.where(
            sigmas != 0,
            exp_shift / (exp_shift + (1 / sigmas - 1) ** 1),
            0,
        )

        # Stretch sigmas so that the last non-zero value matches `terminal`.
        if stretch:
            non_zero_mask = sigmas != 0
            one_minus = 1.0 - sigmas[non_zero_mask]
            rescale = one_minus[-1] / (1.0 - terminal)
            sigmas[non_zero_mask] = 1.0 - one_minus / rescale

        return sigmas.to(torch.float32)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class LinearQuadraticScheduler(SchedulerProtocol):
    """
    Scheduler with a linear segment followed by a quadratic segment.

    The noise level rises linearly toward ``threshold_noise`` over
    ``linear_steps`` steps, then follows a quadratic curve for the remaining
    steps; the curve is finally flipped into descending sigmas ending at 0.
    """

    def execute(
        self, steps: int, threshold_noise: float = 0.025, linear_steps: int | None = None, **_kwargs
    ) -> torch.FloatTensor:
        # Degenerate single-step schedule.
        if steps == 1:
            return torch.FloatTensor([1.0, 0.0])

        if linear_steps is None:
            linear_steps = steps // 2

        # Linear ramp from 0 up to (just below) threshold_noise.
        linear_part = [i * threshold_noise / linear_steps for i in range(linear_steps)]

        quadratic_steps = steps - linear_steps
        quadratic_part = []
        if quadratic_steps > 0:
            # Coefficients chosen so the quadratic continues from the linear
            # segment and the assembled curve ends at 1.0.
            threshold_noise_step_diff = linear_steps - threshold_noise * steps
            quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
            linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
            const = quadratic_coef * (linear_steps**2)
            quadratic_part = [
                quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, steps)
            ]

        # Assemble the ascending noise curve and flip it into descending sigmas.
        ascending = linear_part + quadratic_part + [1.0]
        return torch.FloatTensor([1.0 - value for value in ascending])
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class BetaScheduler(SchedulerProtocol):
    """
    Scheduler that samples timesteps from a beta distribution.

    Based on: https://arxiv.org/abs/2407.12173
    """

    shift = 2.37
    timesteps_length = 10000

    def execute(self, steps: int, alpha: float = 0.6, beta: float = 0.6) -> torch.FloatTensor:
        """
        Run the beta scheduler.

        Args:
            steps: Number of steps to execute the scheduler for.
            alpha: Alpha parameter of the beta distribution.
            beta: Beta parameter of the beta distribution.

        Warnings:
            The number of sigmas returned may be smaller than ``steps + 1``,
            because identical (rounded) timesteps are deduplicated.

        Returns:
            A tensor of sigmas, ending with 0.0.
        """
        sigmas_table = _precalculate_model_sampling_sigmas(self.shift, self.timesteps_length)
        last_index = len(sigmas_table) - 1

        # Evenly spaced quantiles in (0, 1], mapped through the beta PPF onto
        # table indices; dict.fromkeys performs an order-preserving dedup.
        quantiles = 1 - numpy.linspace(0, 1, steps, endpoint=False)
        indices = numpy.rint(scipy.stats.beta.ppf(quantiles, alpha, beta) * last_index).tolist()
        indices = list(dict.fromkeys(indices))

        return torch.FloatTensor([float(sigmas_table[int(i)]) for i in indices] + [0.0])
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@lru_cache(maxsize=5)
def _precalculate_model_sampling_sigmas(shift: float, timesteps_length: int) -> torch.Tensor:
    """
    Precompute (and cache) the sigma lookup table for normalized timesteps.

    Timesteps are 1/N, 2/N, ..., 1.0 (N = ``timesteps_length``); each is mapped
    through ``flux_time_shift`` with the given shift and sigma fixed at 1.0.
    """
    # NOTE: the division yields a float32 tensor, so each t passed below is a
    # float32 scalar tensor (not a Python float) — this affects rounding.
    timesteps = torch.arange(1, timesteps_length + 1, 1) / timesteps_length
    return torch.Tensor([flux_time_shift(shift, 1.0, t) for t in timesteps])
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def flux_time_shift(mu: float, sigma: float, t: float) -> float:
    """Flux-style time shift: ``exp(mu) / (exp(mu) + (1/t - 1) ** sigma)``."""
    scale = math.exp(mu)
    return scale / (scale + (1 / t - 1) ** sigma)
|
packages/ltx-core/src/ltx_core/conditioning/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Conditioning utilities: latent state, tools, and conditioning types."""
|
| 2 |
+
|
| 3 |
+
from ltx_core.conditioning.exceptions import ConditioningError
|
| 4 |
+
from ltx_core.conditioning.item import ConditioningItem
|
| 5 |
+
from ltx_core.conditioning.types import VideoConditionByKeyframeIndex, VideoConditionByLatentIndex
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"ConditioningError",
|
| 9 |
+
"ConditioningItem",
|
| 10 |
+
"VideoConditionByKeyframeIndex",
|
| 11 |
+
"VideoConditionByLatentIndex",
|
| 12 |
+
]
|
packages/ltx-core/src/ltx_core/conditioning/exceptions.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class ConditioningError(Exception):
    """Raised when a conditioning item cannot be applied to a latent state."""
|
packages/ltx-core/src/ltx_core/conditioning/item.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Protocol
|
| 2 |
+
|
| 3 |
+
from ltx_core.tools import LatentTools
|
| 4 |
+
from ltx_core.types import LatentState
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ConditioningItem(Protocol):
    """Protocol for conditioning items that modify the latent state during diffusion."""

    def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState:
        """
        Apply the conditioning to the latent state.

        Args:
            latent_state: The latent state to apply the conditioning to. This state is always patchified.
            latent_tools: Latent-space helpers (patchifier, shapes, etc.) used to apply the conditioning.

        Returns:
            The latent state after the conditioning has been applied.

        IMPORTANT: If the conditioning needs to add extra tokens to the latent, it should add them to the end of the
        latent.
        """
        ...
|
packages/ltx-core/src/ltx_core/conditioning/types/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Conditioning type implementations."""
|
| 2 |
+
|
| 3 |
+
from ltx_core.conditioning.types.keyframe_cond import VideoConditionByKeyframeIndex
|
| 4 |
+
from ltx_core.conditioning.types.latent_cond import VideoConditionByLatentIndex
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"VideoConditionByKeyframeIndex",
|
| 8 |
+
"VideoConditionByLatentIndex",
|
| 9 |
+
]
|
packages/ltx-core/src/ltx_core/conditioning/types/keyframe_cond.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from ltx_core.components.patchifiers import get_pixel_coords
|
| 4 |
+
from ltx_core.conditioning.item import ConditioningItem
|
| 5 |
+
from ltx_core.tools import VideoLatentTools
|
| 6 |
+
from ltx_core.types import LatentState, VideoLatentShape
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class VideoConditionByKeyframeIndex(ConditioningItem):
    """
    Conditions video generation on keyframe latents at a specific frame index.

    Appends keyframe tokens to the latent state with positions offset by frame_idx,
    and sets denoise strength according to the strength parameter.
    """

    def __init__(self, keyframes: torch.Tensor, frame_idx: int, strength: float):
        # Keyframe latent tensor; patchified into tokens inside apply_to.
        self.keyframes = keyframes
        # Pixel-frame offset added to the tokens' temporal coordinate.
        self.frame_idx = frame_idx
        # The appended tokens' denoise mask is filled with 1.0 - strength.
        self.strength = strength

    def apply_to(
        self,
        latent_state: LatentState,
        latent_tools: VideoLatentTools,
    ) -> LatentState:
        """
        Append the patchified keyframe tokens, their positions, and their
        denoise mask to the END of the latent state (per the ConditioningItem
        contract) and return the extended state.
        """
        tokens = latent_tools.patchifier.patchify(self.keyframes)
        latent_coords = latent_tools.patchifier.get_patch_grid_bounds(
            output_shape=VideoLatentShape.from_torch_shape(self.keyframes.shape),
            device=self.keyframes.device,
        )
        # Causal fix only applies when the keyframe sits at the very start of the video.
        positions = get_pixel_coords(
            latent_coords=latent_coords,
            scale_factors=latent_tools.scale_factors,
            causal_fix=latent_tools.causal_fix if self.frame_idx == 0 else False,
        )

        # Index 0 of dim 1 is shifted by frame_idx and divided by fps —
        # presumably the temporal axis, expressed in seconds after this; confirm upstream.
        positions[:, 0, ...] += self.frame_idx
        positions = positions.to(dtype=torch.float32)
        positions[:, 0, ...] /= latent_tools.fps

        # One mask value per token: 1.0 - strength (strength 1.0 -> mask 0.0).
        denoise_mask = torch.full(
            size=(*tokens.shape[:2], 1),
            fill_value=1.0 - self.strength,
            device=self.keyframes.device,
            dtype=self.keyframes.dtype,
        )

        # Tokens/masks concatenate along dim=1; positions along dim=2 (their layout differs).
        return LatentState(
            latent=torch.cat([latent_state.latent, tokens], dim=1),
            denoise_mask=torch.cat([latent_state.denoise_mask, denoise_mask], dim=1),
            positions=torch.cat([latent_state.positions, positions], dim=2),
            clean_latent=torch.cat([latent_state.clean_latent, tokens], dim=1),
        )
|
packages/ltx-core/src/ltx_core/conditioning/types/latent_cond.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from ltx_core.conditioning.exceptions import ConditioningError
|
| 4 |
+
from ltx_core.conditioning.item import ConditioningItem
|
| 5 |
+
from ltx_core.tools import LatentTools
|
| 6 |
+
from ltx_core.types import LatentState
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class VideoConditionByLatentIndex(ConditioningItem):
    """
    Conditions video generation by injecting latents at a specific latent frame index.

    Replaces tokens in the latent state at positions corresponding to latent_idx,
    and sets denoise strength according to the strength parameter.
    """

    def __init__(self, latent: torch.Tensor, strength: float, latent_idx: int):
        # Conditioning latent of shape (batch, channels, frames, height, width).
        self.latent = latent
        # The replaced tokens' denoise mask is set to 1.0 - strength.
        self.strength = strength
        # Latent-frame index at which the tokens are injected.
        self.latent_idx = latent_idx

    def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState:
        """
        Overwrite the token span corresponding to ``latent_idx`` with the
        patchified conditioning latent and return a cloned, updated state.

        Raises:
            ConditioningError: When the conditioning latent's batch, channel,
                or spatial dims don't match the target latent shape.
        """
        cond_batch, cond_channels, _, cond_height, cond_width = self.latent.shape
        tgt_batch, tgt_channels, tgt_frames, tgt_height, tgt_width = latent_tools.target_shape.to_torch_shape()

        # Frame count may differ (the conditioning spans its own frames); every other dim must match.
        if (cond_batch, cond_channels, cond_height, cond_width) != (tgt_batch, tgt_channels, tgt_height, tgt_width):
            raise ConditioningError(
                f"Can't apply image conditioning item to latent with shape {latent_tools.target_shape}, expected "
                f"shape is ({tgt_batch}, {tgt_channels}, {tgt_frames}, {tgt_height}, {tgt_width}). Make sure "
                "the image and latent have the same spatial shape."
            )

        tokens = latent_tools.patchifier.patchify(self.latent)
        # Token offset of frame latent_idx = token count of a latent truncated to that many frames.
        start_token = latent_tools.patchifier.get_token_count(
            latent_tools.target_shape._replace(frames=self.latent_idx)
        )
        stop_token = start_token + tokens.shape[1]

        # Clone so the caller's state is not mutated in place.
        latent_state = latent_state.clone()

        latent_state.latent[:, start_token:stop_token] = tokens
        latent_state.clean_latent[:, start_token:stop_token] = tokens
        latent_state.denoise_mask[:, start_token:stop_token] = 1.0 - self.strength

        return latent_state
|
packages/ltx-core/src/ltx_core/guidance/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Guidance and perturbation utilities for attention manipulation."""
|
| 2 |
+
|
| 3 |
+
from ltx_core.guidance.perturbations import (
|
| 4 |
+
BatchedPerturbationConfig,
|
| 5 |
+
Perturbation,
|
| 6 |
+
PerturbationConfig,
|
| 7 |
+
PerturbationType,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"BatchedPerturbationConfig",
|
| 12 |
+
"Perturbation",
|
| 13 |
+
"PerturbationConfig",
|
| 14 |
+
"PerturbationType",
|
| 15 |
+
]
|
packages/ltx-core/src/ltx_core/guidance/perturbations.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from enum import Enum
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch._prims_common import DeviceLikeType
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class PerturbationType(Enum):
    """
    Kinds of attention perturbation used for STG (Spatio-Temporal Guidance).

    Each member names an attention sub-module that can be skipped.
    """

    SKIP_A2V_CROSS_ATTN = "skip_a2v_cross_attn"
    SKIP_V2A_CROSS_ATTN = "skip_v2a_cross_attn"
    SKIP_VIDEO_SELF_ATTN = "skip_video_self_attn"
    SKIP_AUDIO_SELF_ATTN = "skip_audio_self_attn"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass(frozen=True)
class Perturbation:
    """One perturbation: the attention type to skip and the blocks it affects."""

    type: PerturbationType
    blocks: list[int] | None  # None applies the perturbation to every block

    def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
        """Return True when this perturbation matches the given type and block."""
        if self.type != perturbation_type:
            return False
        return self.blocks is None or block in self.blocks
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass(frozen=True)
class PerturbationConfig:
    """Per-sample collection of perturbations."""

    perturbations: list[Perturbation] | None

    def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
        """True when at least one contained perturbation matches (type, block)."""
        candidates = self.perturbations or []
        return any(item.is_perturbed(perturbation_type, block) for item in candidates)

    @staticmethod
    def empty() -> "PerturbationConfig":
        """A config holding no perturbations (an empty list, not None)."""
        return PerturbationConfig([])
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@dataclass(frozen=True)
class BatchedPerturbationConfig:
    """Batch of per-sample perturbation configs, with attention-mask helpers."""

    perturbations: list[PerturbationConfig]

    def mask(
        self, perturbation_type: PerturbationType, block: int, device: DeviceLikeType, dtype: torch.dtype
    ) -> torch.Tensor:
        """1-D per-sample mask: 0 where the sample is perturbed, 1 elsewhere."""
        keep = [0.0 if cfg.is_perturbed(perturbation_type, block) else 1.0 for cfg in self.perturbations]
        return torch.tensor(keep, device=device, dtype=dtype)

    def mask_like(self, perturbation_type: PerturbationType, block: int, values: torch.Tensor) -> torch.Tensor:
        """Mask broadcastable against ``values``: batch dim first, singletons elsewhere."""
        flat = self.mask(perturbation_type, block, values.device, values.dtype)
        trailing = [1] * len(values.shape[1:])
        return flat.view(flat.numel(), *trailing)

    def any_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool:
        """True when at least one sample in the batch is perturbed."""
        return any(cfg.is_perturbed(perturbation_type, block) for cfg in self.perturbations)

    def all_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool:
        """True when every sample in the batch is perturbed."""
        return all(cfg.is_perturbed(perturbation_type, block) for cfg in self.perturbations)

    @staticmethod
    def empty(batch_size: int) -> "BatchedPerturbationConfig":
        """A batch of ``batch_size`` empty per-sample configs."""
        return BatchedPerturbationConfig([PerturbationConfig.empty() for _ in range(batch_size)])
|
packages/ltx-core/src/ltx_core/loader/__init__.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Loader utilities for model weights, LoRAs, and safetensor operations."""
|
| 2 |
+
|
| 3 |
+
from ltx_core.loader.fuse_loras import apply_loras
|
| 4 |
+
from ltx_core.loader.module_ops import ModuleOps
|
| 5 |
+
from ltx_core.loader.primitives import (
|
| 6 |
+
LoRAAdaptableProtocol,
|
| 7 |
+
LoraPathStrengthAndSDOps,
|
| 8 |
+
LoraStateDictWithStrength,
|
| 9 |
+
ModelBuilderProtocol,
|
| 10 |
+
StateDict,
|
| 11 |
+
StateDictLoader,
|
| 12 |
+
)
|
| 13 |
+
from ltx_core.loader.registry import DummyRegistry, Registry, StateDictRegistry
|
| 14 |
+
from ltx_core.loader.sd_ops import (
|
| 15 |
+
LTXV_LORA_COMFY_RENAMING_MAP,
|
| 16 |
+
ContentMatching,
|
| 17 |
+
ContentReplacement,
|
| 18 |
+
KeyValueOperation,
|
| 19 |
+
KeyValueOperationResult,
|
| 20 |
+
SDKeyValueOperation,
|
| 21 |
+
SDOps,
|
| 22 |
+
)
|
| 23 |
+
from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader, SafetensorsStateDictLoader
|
| 24 |
+
from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder
|
| 25 |
+
|
| 26 |
+
__all__ = [
|
| 27 |
+
"LTXV_LORA_COMFY_RENAMING_MAP",
|
| 28 |
+
"ContentMatching",
|
| 29 |
+
"ContentReplacement",
|
| 30 |
+
"DummyRegistry",
|
| 31 |
+
"KeyValueOperation",
|
| 32 |
+
"KeyValueOperationResult",
|
| 33 |
+
"LoRAAdaptableProtocol",
|
| 34 |
+
"LoraPathStrengthAndSDOps",
|
| 35 |
+
"LoraStateDictWithStrength",
|
| 36 |
+
"ModelBuilderProtocol",
|
| 37 |
+
"ModuleOps",
|
| 38 |
+
"Registry",
|
| 39 |
+
"SDKeyValueOperation",
|
| 40 |
+
"SDOps",
|
| 41 |
+
"SafetensorsModelStateDictLoader",
|
| 42 |
+
"SafetensorsStateDictLoader",
|
| 43 |
+
"SingleGPUModelBuilder",
|
| 44 |
+
"StateDict",
|
| 45 |
+
"StateDictLoader",
|
| 46 |
+
"StateDictRegistry",
|
| 47 |
+
"apply_loras",
|
| 48 |
+
]
|
packages/ltx-core/src/ltx_core/loader/fuse_loras.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import triton
|
| 3 |
+
|
| 4 |
+
from ltx_core.loader.kernels import fused_add_round_kernel
|
| 5 |
+
from ltx_core.loader.primitives import LoraStateDictWithStrength, StateDict
|
| 6 |
+
|
| 7 |
+
BLOCK_SIZE = 1024
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def fused_add_round_launch(target_weight: torch.Tensor, original_weight: torch.Tensor, seed: int) -> torch.Tensor:
    """
    Launch the Triton kernel that upcasts fp8 ``original_weight`` and adds it
    into bf16 ``target_weight`` in place, with stochastic rounding.

    Args:
        target_weight: bfloat16 tensor; written in place by the kernel.
        original_weight: float8 tensor (e4m3fn or e5m2) to upcast and add.
        seed: RNG seed for the kernel's stochastic rounding.

    Returns:
        ``target_weight`` (the same tensor, mutated in place).

    Raises:
        ValueError: If ``original_weight`` is not a float8 dtype or
            ``target_weight`` is not bfloat16.
    """
    # Per-format fp8 parameters: e4m3fn has exponent bias 7, e5m2 has bias 15.
    # NOTE(review): exponent_bits is never used below (hence the noqa) —
    # kept for documentation of the float8 layouts.
    if original_weight.dtype == torch.float8_e4m3fn:
        exponent_bits, mantissa_bits, exponent_bias = 4, 3, 7
    elif original_weight.dtype == torch.float8_e5m2:
        exponent_bits, mantissa_bits, exponent_bias = 5, 2, 15  # noqa: F841
    else:
        raise ValueError("Unsupported dtype")

    if target_weight.dtype != torch.bfloat16:
        raise ValueError("target_weight dtype must be bfloat16")

    # Calculate grid and block sizes: one program instance per BLOCK_SIZE elements.
    n_elements = original_weight.numel()
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)

    # Launch kernel
    fused_add_round_kernel[grid](
        original_weight,
        target_weight,
        seed,
        n_elements,
        exponent_bias,
        mantissa_bits,
        BLOCK_SIZE,
    )
    return target_weight
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def calculate_weight_float8_(target_weights: torch.Tensor, original_weights: torch.Tensor) -> torch.Tensor:
    """
    Fuse fp8 ``original_weights`` into bf16 ``target_weights`` in place
    (trailing underscore follows the torch in-place naming convention).

    NOTE(review): ``fused_add_round_launch`` already writes into
    ``target_weights`` and returns that same tensor, so the ``.to(...)`` cast
    is a no-op and the ``copy_`` copies the tensor onto itself — presumably
    defensive; confirm before simplifying.
    """
    result = fused_add_round_launch(target_weights, original_weights, seed=0).to(target_weights.dtype)
    target_weights.copy_(result, non_blocking=True)
    return target_weights
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _prepare_deltas(
|
| 45 |
+
lora_sd_and_strengths: list[LoraStateDictWithStrength], key: str, dtype: torch.dtype, device: torch.device
|
| 46 |
+
) -> torch.Tensor | None:
|
| 47 |
+
deltas = []
|
| 48 |
+
prefix = key[: -len(".weight")]
|
| 49 |
+
key_a = f"{prefix}.lora_A.weight"
|
| 50 |
+
key_b = f"{prefix}.lora_B.weight"
|
| 51 |
+
for lsd, coef in lora_sd_and_strengths:
|
| 52 |
+
if key_a not in lsd.sd or key_b not in lsd.sd:
|
| 53 |
+
continue
|
| 54 |
+
product = torch.matmul(lsd.sd[key_b] * coef, lsd.sd[key_a])
|
| 55 |
+
deltas.append(product.to(dtype=dtype, device=device))
|
| 56 |
+
if len(deltas) == 0:
|
| 57 |
+
return None
|
| 58 |
+
elif len(deltas) == 1:
|
| 59 |
+
return deltas[0]
|
| 60 |
+
return torch.sum(torch.stack(deltas, dim=0), dim=0)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def apply_loras(
    model_sd: StateDict,
    lora_sd_and_strengths: list[LoraStateDictWithStrength],
    dtype: torch.dtype,
    destination_sd: StateDict | None = None,
) -> StateDict:
    """
    Merge LoRA deltas into a model state dict, key by key.

    Args:
        model_sd: Base model weights.
        lora_sd_and_strengths: LoRA state dicts paired with their strengths.
        dtype: Target dtype for merged weights; when None, each weight keeps its own dtype.
        destination_sd: Optional state dict whose inner dict is filled in place;
            when given it is returned as-is (its size/dtype metadata is NOT updated).

    Returns:
        ``destination_sd`` when provided, otherwise a new ``StateDict``.

    Raises:
        ValueError: For base weights that are neither float8_e4m3fn nor bfloat16.
    """
    sd = {}
    if destination_sd is not None:
        sd = destination_sd.sd
    size = 0
    device = torch.device("meta")  # overwritten by the first non-None weight's device
    inner_dtypes = set()
    for key, weight in model_sd.sd.items():
        if weight is None:
            continue
        device = weight.device
        target_dtype = dtype if dtype is not None else weight.dtype
        # Deltas are accumulated in bfloat16 when the target is float8, since
        # float8 cannot represent intermediate sums accurately.
        deltas_dtype = target_dtype if target_dtype not in [torch.float8_e4m3fn, torch.float8_e5m2] else torch.bfloat16
        deltas = _prepare_deltas(lora_sd_and_strengths, key, deltas_dtype, device)
        if deltas is None:
            # No LoRA touches this key: keep an existing destination entry, or copy the base weight.
            if key in sd:
                continue
            deltas = weight.clone().to(dtype=target_dtype, device=device)
        elif weight.dtype == torch.float8_e4m3fn:
            if str(device).startswith("cuda"):
                # Fused Triton path: upcast the fp8 base weight with stochastic rounding and add in place.
                deltas = calculate_weight_float8_(deltas, weight)
            else:
                deltas.add_(weight.to(dtype=deltas.dtype, device=device))
        elif weight.dtype == torch.bfloat16:
            deltas.add_(weight)
        else:
            # NOTE(review): float8_e5m2 base weights fall through to this error even
            # though the fused kernel supports e5m2 — confirm this is intended.
            raise ValueError(f"Unsupported dtype: {weight.dtype}")
        sd[key] = deltas.to(dtype=target_dtype)
        inner_dtypes.add(target_dtype)
        size += deltas.nbytes
    if destination_sd is not None:
        return destination_sd
    return StateDict(sd, device, size, inner_dtypes)
|
packages/ltx-core/src/ltx_core/loader/kernels.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ruff: noqa: ANN001, ANN201, ERA001, N803, N806
|
| 2 |
+
import triton
|
| 3 |
+
import triton.language as tl
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@triton.jit
def fused_add_round_kernel(
    x_ptr,
    output_ptr,  # contents will be added to the output
    seed,
    n_elements,
    EXPONENT_BIAS,
    MANTISSA_BITS,
    BLOCK_SIZE: tl.constexpr,
):
    """
    A kernel to upcast 8bit quantized weights to bfloat16 with stochastic rounding
    and add them to bfloat16 output weights. Might be used to upcast original model weights
    and to further add them to precalculated deltas coming from LoRAs.

    EXPONENT_BIAS / MANTISSA_BITS describe the fp8 source format
    (e4m3fn: 7/3, e5m2: 15/2).
    """
    # Get program ID and compute offsets
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Load data; rand_vals is uniform in [-0.5, 0.5) for stochastic rounding.
    x = tl.load(x_ptr + offsets, mask=mask)
    rand_vals = tl.rand(seed, offsets) - 0.5

    # Accumulate the sum in float16 before rounding down to bfloat16.
    x = tl.cast(x, tl.float16)
    delta = tl.load(output_ptr + offsets, mask=mask)
    delta = tl.cast(delta, tl.float16)
    x = x + delta

    x_bits = tl.cast(x, tl.int16, bitcast=True)

    # Calculate the exponent. Unbiased fp16 exponent is ((x_bits & 0x7C00) >> 10) - 15 for
    # normal numbers and -14 for subnormals.
    fp16_exponent_bits = (x_bits & 0x7C00) >> 10
    fp16_normals = fp16_exponent_bits > 0
    fp16_exponent = tl.where(fp16_normals, fp16_exponent_bits - 15, -14)

    # Add the target dtype's exponent bias and clamp to the target dtype's exponent range.
    exponent = fp16_exponent + EXPONENT_BIAS
    MAX_EXPONENT = 2 * EXPONENT_BIAS + 1
    exponent = tl.where(exponent > MAX_EXPONENT, MAX_EXPONENT, exponent)
    exponent = tl.where(exponent < 0, 0, exponent)

    # Normal ULP exponent, expressed as an fp16 exponent field:
    # (exponent - EXPONENT_BIAS - MANTISSA_BITS) + 15
    # Simplifies to: fp16_exponent - MANTISSA_BITS + 15
    # See https://en.wikipedia.org/wiki/Unit_in_the_last_place
    eps_exp = tl.maximum(0, tl.minimum(31, exponent - EXPONENT_BIAS - MANTISSA_BITS + 15))

    # Calculate epsilon in the target dtype
    eps_normal = tl.cast(tl.cast(eps_exp << 10, tl.int16), tl.float16, bitcast=True)

    # Subnormal ULP: 2^(1 - EXPONENT_BIAS - MANTISSA_BITS) ->
    # fp16 exponent bits: (1 - EXPONENT_BIAS - MANTISSA_BITS) + 15 =
    # 16 - EXPONENT_BIAS - MANTISSA_BITS
    eps_subnormal = tl.cast(tl.cast((16 - EXPONENT_BIAS - MANTISSA_BITS) << 10, tl.int16), tl.float16, bitcast=True)
    eps = tl.where(exponent > 0, eps_normal, eps_subnormal)

    # Apply zero mask to epsilon: exact zeros must stay exact.
    eps = tl.where(x == 0, 0.0, eps)

    # Apply stochastic rounding
    output = tl.cast(x + rand_vals * eps, tl.bfloat16)

    # Store the result
    tl.store(output_ptr + offsets, output, mask=mask)
|
packages/ltx-core/src/ltx_core/loader/module_ops.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, NamedTuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ModuleOps(NamedTuple):
    """
    A named matcher/mutator pair for transforming PyTorch modules.

    Used to selectively rewrite modules in a model (e.g., swapping layers for
    quantized variants): ``matcher`` decides whether a module applies,
    ``mutator`` produces its replacement.
    """

    name: str
    matcher: Callable[[torch.nn.Module], bool]
    mutator: Callable[[torch.nn.Module], torch.nn.Module]
|
packages/ltx-core/src/ltx_core/loader/primitives.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import NamedTuple, Protocol
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from ltx_core.loader.module_ops import ModuleOps
|
| 7 |
+
from ltx_core.loader.sd_ops import SDOps
|
| 8 |
+
from ltx_core.model.model_protocol import ModelType
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass(frozen=True)
class StateDict:
    """
    Immutable wrapper around a PyTorch state dictionary.

    Contains:
    - sd: Dictionary of tensors (weights, buffers, etc.)
    - device: Device where tensors are stored
    - size: Total memory footprint in bytes
    - dtype: Set of tensor dtypes present
    """

    sd: dict
    device: torch.device
    size: int
    dtype: set[torch.dtype]

    def footprint(self) -> tuple[int, torch.device]:
        """Return ``(size in bytes, device)`` as a compact memory summary."""
        return (self.size, self.device)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class StateDictLoader(Protocol):
    """
    Protocol for loading state dictionaries from storage.

    Implementations supply ``metadata`` (read model metadata for a path) and
    ``load`` (read one or more shards and apply optional SDOps transforms).
    """

    def metadata(self, path: str) -> dict:
        """Return the metadata stored at ``path``."""

    def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
        """Load a state dict from ``path`` (or shard paths), applying ``sd_ops`` when given."""
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class ModelBuilderProtocol(Protocol[ModelType]):
    """
    Protocol for constructing PyTorch models from configuration dictionaries.

    Implementations supply ``meta_model`` (instantiate the architecture on the
    meta device, optionally transformed by module ops) and ``build`` (produce a
    fully initialized model, optionally cast to a dtype).
    """

    def meta_model(self, config: dict, module_ops: list[ModuleOps] | None = None) -> ModelType:
        """
        Create a model on the meta device from a configuration dictionary.

        Creating on meta decouples architecture instantiation from weight
        loading, so no parameter memory is allocated yet.

        Args:
            config: Model configuration dictionary.
            module_ops: Optional module transformations to apply (e.g., quantization).

        Returns:
            Model instance on the meta device (no parameter memory allocated).
        """
        ...

    def build(self, dtype: torch.dtype | None = None) -> ModelType:
        """
        Build and return the model.

        Args:
            dtype: Target dtype; when None, the dtype of the source weights is kept.

        Returns:
            Model instance.
        """
        ...
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class LoRAAdaptableProtocol(Protocol):
    """
    Protocol for models that accept LoRA adapters.

    Implementations must provide ``lora``, which registers the LoRA found at a
    path with a given strength and returns the adapted model.
    """

    def lora(self, lora_path: str, strength: float) -> "LoRAAdaptableProtocol":
        """Add the LoRA at ``lora_path`` with the given ``strength``."""
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class LoraPathStrengthAndSDOps(NamedTuple):
    """A LoRA file path, its merge strength, and the SDOps applied to its state dict."""

    path: str
    strength: float
    sd_ops: SDOps
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class LoraStateDictWithStrength(NamedTuple):
    """A loaded LoRA state dict paired with the strength used when merging it."""

    state_dict: StateDict
    strength: float
|
packages/ltx-core/src/ltx_core/loader/registry.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hashlib
|
| 2 |
+
import threading
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Protocol
|
| 6 |
+
|
| 7 |
+
from ltx_core.loader.primitives import StateDict
|
| 8 |
+
from ltx_core.loader.sd_ops import SDOps
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Registry(Protocol):
    """
    Protocol for a cache of loaded state dictionaries.

    Entries are keyed by their source paths and the SDOps applied to them, so
    repeated loads of the same weights can be skipped.
    """

    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None:
        """Store a state dict under the (paths, sd_ops) key."""

    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """Remove and return the cached state dict, or None when absent."""

    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """Return the cached state dict without removing it, or None when absent."""

    def clear(self) -> None:
        """Drop every cached state dict."""
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class DummyRegistry(Registry):
    """
    No-op registry: nothing is ever stored and every lookup misses.

    Use this as the default when state dicts should not be cached between loads.
    """

    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None:
        return None

    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        return None

    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        return None

    def clear(self) -> None:
        return None
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dataclass
class StateDictRegistry(Registry):
    """
    Thread-safe in-memory registry mapping (paths, sd_ops) to loaded state dicts.

    Entries are keyed by a SHA-256 digest of the resolved file paths plus the
    SDOps name, so the same files loaded with the same ops share one entry.
    """

    # digest -> cached state dict
    _state_dicts: dict[str, StateDict] = field(default_factory=dict)
    # guards all access to _state_dicts
    _lock: threading.Lock = field(default_factory=threading.Lock)

    def _generate_id(self, paths: list[str], sd_ops: SDOps | None) -> str:
        """Return a stable digest for the given paths and optional SDOps.

        Annotation fixed to ``SDOps | None``: add/pop/get all forward a
        possibly-None sd_ops and the body already handles None.
        """
        m = hashlib.sha256()
        # resolve() so relative and absolute spellings of the same file collide
        parts = [str(Path(p).resolve()) for p in paths]
        if sd_ops is not None:
            parts.append(sd_ops.name)
        m.update("\0".join(parts).encode("utf-8"))
        return m.hexdigest()

    def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> str:
        """Store ``state_dict`` under the digest of (paths, sd_ops) and return the digest.

        Raises:
            ValueError: if an entry for the same (paths, sd_ops) already exists.
        """
        sd_id = self._generate_id(paths, sd_ops)
        with self._lock:
            if sd_id in self._state_dicts:
                raise ValueError(f"State dict retrieved from {paths} with {sd_ops} already added, check with get first")
            self._state_dicts[sd_id] = state_dict
        return sd_id

    def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """Remove and return the cached entry, or None if absent."""
        with self._lock:
            return self._state_dicts.pop(self._generate_id(paths, sd_ops), None)

    def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None:
        """Return the cached entry, or None if absent."""
        with self._lock:
            return self._state_dicts.get(self._generate_id(paths, sd_ops), None)

    def clear(self) -> None:
        """Drop all cached entries."""
        with self._lock:
            self._state_dicts.clear()
|
packages/ltx-core/src/ltx_core/loader/sd_ops.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, replace
|
| 2 |
+
from typing import NamedTuple, Protocol
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass(frozen=True, slots=True)
class ContentReplacement:
    """
    Represents a content replacement operation.
    Used to replace a specific content with a replacement in a state dict key.
    """

    content: str  # substring searched for in the key
    replacement: str  # text substituted for every occurrence of `content`
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass(frozen=True, slots=True)
class ContentMatching:
    """
    Represents a content matching operation.
    Used to match a specific prefix and suffix in a state dict key.
    The empty-string defaults match every key (any string starts and ends with "").
    """

    prefix: str = ""
    suffix: str = ""
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class KeyValueOperationResult(NamedTuple):
    """
    Represents the result of a key-value operation.
    Contains the new key and value after the operation has been applied.
    """

    new_key: str  # key under which the tensor will be stored
    new_value: torch.Tensor  # transformed tensor
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class KeyValueOperation(Protocol):
    """
    Protocol for key-value operations.
    Used to apply operations to a specific key and value in a state dict.

    A callable mapping one (key, tensor) pair to a list of resulting pairs,
    which allows an operation to rename, rewrite, or split a tensor.
    """

    def __call__(self, tensor_key: str, tensor_value: torch.Tensor) -> list[KeyValueOperationResult]: ...
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@dataclass(frozen=True, slots=True)
class SDKeyValueOperation:
    """
    Represents a key-value operation.
    Used to apply operations to a specific key and value in a state dict.
    """

    key_matcher: ContentMatching  # gate: the operation runs only on keys matching this prefix/suffix
    kv_operation: KeyValueOperation  # transformation applied to matching (key, tensor) pairs
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@dataclass(frozen=True, slots=True)
class SDOps:
    """Immutable description of the key/value transformations for a state dict.

    ``mapping`` holds, in insertion order, three kinds of entries:
      - ContentMatching: a key is accepted only if at least one matcher fits
        (with no matchers at all, every key is rejected);
      - ContentReplacement: substring rewrites applied to accepted keys;
      - SDKeyValueOperation: per-tensor transformations gated by a key
        matcher; the first matching operation wins.
    """

    name: str
    # Ordered tuple of matching / replacement / key-value operation entries.
    mapping: tuple[ContentReplacement | ContentMatching | SDKeyValueOperation, ...] = ()

    def with_replacement(self, content: str, replacement: str) -> "SDOps":
        """Return a copy with an additional substring replacement appended."""
        return replace(self, mapping=(*self.mapping, ContentReplacement(content, replacement)))

    def with_matching(self, prefix: str = "", suffix: str = "") -> "SDOps":
        """Return a copy with an additional prefix/suffix key matcher appended."""
        return replace(self, mapping=(*self.mapping, ContentMatching(prefix, suffix)))

    def with_kv_operation(
        self,
        operation: KeyValueOperation,
        key_prefix: str = "",
        key_suffix: str = "",
    ) -> "SDOps":
        """Return a copy with a key-value operation gated by a prefix/suffix matcher."""
        gate = ContentMatching(key_prefix, key_suffix)
        return replace(self, mapping=(*self.mapping, SDKeyValueOperation(gate, operation)))

    def apply_to_key(self, key: str) -> str | None:
        """Map ``key`` through the matchers and replacements.

        Returns the (possibly rewritten) key, or None when no matcher accepts it.
        """
        accepted = False
        for entry in self.mapping:
            if isinstance(entry, ContentMatching) and key.startswith(entry.prefix) and key.endswith(entry.suffix):
                accepted = True
                break
        if not accepted:
            return None

        for entry in self.mapping:
            if isinstance(entry, ContentReplacement) and entry.content in key:
                key = key.replace(entry.content, entry.replacement)
        return key

    def apply_to_key_value(self, key: str, value: torch.Tensor) -> list[KeyValueOperationResult]:
        """Apply the first key-value operation whose matcher fits ``key``.

        Falls back to an identity result when no operation matches.
        """
        for entry in self.mapping:
            if not isinstance(entry, SDKeyValueOperation):
                continue
            matcher = entry.key_matcher
            if key.startswith(matcher.prefix) and key.endswith(matcher.suffix):
                return entry.kv_operation(key, value)
        return [KeyValueOperationResult(key, value)]
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# Predefined SDOps instances for ComfyUI-style LTXV LoRA checkpoints.

# Accepts every key (empty matcher) and strips the "diffusion_model." prefix.
# NOTE(review): the registered name says "PREFIX_MAP" while the variable says
# "RENAMING_MAP" — SDOps.name feeds registry cache keys, so confirm intent
# before renaming either side.
LTXV_LORA_COMFY_RENAMING_MAP = (
    SDOps("LTXV_LORA_COMFY_PREFIX_MAP").with_matching().with_replacement("diffusion_model.", "")
)

# Accepts every key, strips the "diffusion_model." prefix, and collapses the
# LoRA A/B weight names onto the base parameter name (".weight").
LTXV_LORA_COMFY_TARGET_MAP = (
    SDOps("LTXV_LORA_COMFY_TARGET_MAP")
    .with_matching()
    .with_replacement("diffusion_model.", "")
    .with_replacement(".lora_A.weight", ".weight")
    .with_replacement(".lora_B.weight", ".weight")
)
|
packages/ltx-core/src/ltx_core/loader/sft_loader.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
import safetensors
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from ltx_core.loader.primitives import StateDict, StateDictLoader
|
| 7 |
+
from ltx_core.loader.sd_ops import SDOps
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SafetensorsStateDictLoader(StateDictLoader):
    """
    Loads weights from safetensors files without metadata support.

    Use this for loading raw weight files. For model files that include
    configuration metadata, use SafetensorsModelStateDictLoader instead.
    """

    def metadata(self, path: str) -> dict:
        # Raw weight files carry no config payload; use SafetensorsModelStateDictLoader.
        raise NotImplementedError("Not implemented")

    def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
        """
        Load state dict from path or paths (for sharded model storage) and apply sd_ops.

        ``sd_ops`` annotation fixed to ``SDOps | None = None``: the body already
        handles None, callers pass None, and this matches the signature of
        SafetensorsModelStateDictLoader.load (backward compatible — adds a default).

        Keys rejected by the sd_ops matchers are skipped; accepted keys are
        renamed and their tensors run through any key-value operations.

        Returns:
            StateDict with the merged tensors, their total byte size, and the
            set of dtypes encountered.
        """
        sd = {}
        size = 0  # running total of tensor bytes kept
        dtypes = set()  # every dtype seen across kept tensors
        device = device or torch.device("cpu")
        model_paths = path if isinstance(path, list) else [path]
        for shard_path in model_paths:
            with safetensors.safe_open(shard_path, framework="pt", device=str(device)) as f:
                for name in f.keys():
                    # None from apply_to_key means no matcher accepted the key -> drop it.
                    expected_name = name if sd_ops is None else sd_ops.apply_to_key(name)
                    if expected_name is None:
                        continue
                    value = f.get_tensor(name).to(device=device, non_blocking=True, copy=False)
                    key_value_pairs = ((expected_name, value),)
                    if sd_ops is not None:
                        key_value_pairs = sd_ops.apply_to_key_value(expected_name, value)
                    for key, tensor in key_value_pairs:
                        size += tensor.nbytes
                        dtypes.add(tensor.dtype)
                        sd[key] = tensor

        return StateDict(sd=sd, device=device, size=size, dtype=dtypes)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class SafetensorsModelStateDictLoader(StateDictLoader):
    """
    Loads weights and configuration metadata from safetensors model files.

    Unlike SafetensorsStateDictLoader, this loader can read model configuration
    from the safetensors file metadata via the metadata() method.
    """

    def __init__(self, weight_loader: SafetensorsStateDictLoader | None = None):
        # Tensor loading is delegated; default to the plain safetensors loader.
        if weight_loader is None:
            weight_loader = SafetensorsStateDictLoader()
        self.weight_loader = weight_loader

    def metadata(self, path: str) -> dict:
        """Return the model config parsed from the file's safetensors metadata."""
        with safetensors.safe_open(path, framework="pt") as f:
            raw_config = f.metadata()["config"]
        return json.loads(raw_config)

    def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict:
        """Load the state dict through the wrapped weight loader."""
        return self.weight_loader.load(path, sd_ops, device)
|
packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from dataclasses import dataclass, field, replace
|
| 3 |
+
from typing import Generic
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from ltx_core.loader.fuse_loras import apply_loras
|
| 8 |
+
from ltx_core.loader.module_ops import ModuleOps
|
| 9 |
+
from ltx_core.loader.primitives import (
|
| 10 |
+
LoRAAdaptableProtocol,
|
| 11 |
+
LoraPathStrengthAndSDOps,
|
| 12 |
+
LoraStateDictWithStrength,
|
| 13 |
+
ModelBuilderProtocol,
|
| 14 |
+
StateDict,
|
| 15 |
+
StateDictLoader,
|
| 16 |
+
)
|
| 17 |
+
from ltx_core.loader.registry import DummyRegistry, Registry
|
| 18 |
+
from ltx_core.loader.sd_ops import SDOps
|
| 19 |
+
from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader
|
| 20 |
+
from ltx_core.model.model_protocol import ModelConfigurator, ModelType
|
| 21 |
+
|
| 22 |
+
logger: logging.Logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass(frozen=True)
class SingleGPUModelBuilder(Generic[ModelType], ModelBuilderProtocol[ModelType], LoRAAdaptableProtocol):
    """
    Builder for PyTorch models residing on a single GPU.

    The model is first instantiated on the "meta" device from its checkpoint
    config, then materialized by loading (and optionally LoRA-fusing) a state
    dict onto the target device. As a frozen dataclass, configuration methods
    such as ``lora`` return new builder instances rather than mutating self.
    """

    # Configurator class used to instantiate the model from its config dict.
    model_class_configurator: type[ModelConfigurator[ModelType]]
    # One checkpoint path, or a tuple of shard paths.
    model_path: str | tuple[str, ...]
    # Optional key/value transformations applied to the model state dict on load.
    model_sd_ops: SDOps | None = None
    # Structural mutations applied to the meta model before weights are loaded.
    module_ops: tuple[ModuleOps, ...] = field(default_factory=tuple)
    # LoRAs (path, strength, sd_ops) scheduled for fusion at build time.
    loras: tuple[LoraPathStrengthAndSDOps, ...] = field(default_factory=tuple)
    model_loader: StateDictLoader = field(default_factory=SafetensorsModelStateDictLoader)
    # State-dict cache; the default DummyRegistry never caches.
    registry: Registry = field(default_factory=DummyRegistry)

    def lora(self, lora_path: str, strength: float = 1.0, sd_ops: SDOps | None = None) -> "SingleGPUModelBuilder":
        """Return a new builder with an additional LoRA scheduled for fusion."""
        return replace(self, loras=(*self.loras, LoraPathStrengthAndSDOps(lora_path, strength, sd_ops)))

    def model_config(self) -> dict:
        """Read the model config from the (first shard of the) checkpoint metadata."""
        first_shard_path = self.model_path[0] if isinstance(self.model_path, tuple) else self.model_path
        return self.model_loader.metadata(first_shard_path)

    def meta_model(self, config: dict, module_ops: tuple[ModuleOps, ...]) -> ModelType:
        """Instantiate the model on the meta device and apply matching module ops."""
        with torch.device("meta"):
            model = self.model_class_configurator.from_config(config)
            for module_op in module_ops:
                if module_op.matcher(model):
                    model = module_op.mutator(model)
        return model

    def load_sd(
        self, paths: list[str], registry: Registry, device: torch.device | None, sd_ops: SDOps | None = None
    ) -> StateDict:
        """Load a state dict, consulting the registry cache first and populating it on miss."""
        state_dict = registry.get(paths, sd_ops)
        if state_dict is None:
            state_dict = self.model_loader.load(paths, sd_ops=sd_ops, device=device)
            registry.add(paths, sd_ops=sd_ops, state_dict=state_dict)
        return state_dict

    def _return_model(self, meta_model: ModelType, device: torch.device) -> ModelType:
        """Move the materialized model to ``device``.

        If any parameters or buffers are still on the meta device (i.e. they
        were missing from the loaded state dict), the model is returned as-is
        after a warning, since ``.to(device)`` cannot materialize meta tensors.
        """
        uninitialized_params = [name for name, param in meta_model.named_parameters() if str(param.device) == "meta"]
        uninitialized_buffers = [name for name, buffer in meta_model.named_buffers() if str(buffer.device) == "meta"]
        if uninitialized_params or uninitialized_buffers:
            logger.warning(f"Uninitialized parameters or buffers: {uninitialized_params + uninitialized_buffers}")
            return meta_model
        retval = meta_model.to(device)
        return retval

    def build(self, device: torch.device | None = None, dtype: torch.dtype | None = None) -> ModelType:
        """Build the model: instantiate on meta, load weights, fuse LoRAs, move to device.

        Args:
            device: Target device; defaults to CUDA.
            dtype: Optional dtype the weights are cast to before loading.
        """
        device = torch.device("cuda") if device is None else device
        config = self.model_config()
        meta_model = self.meta_model(config, self.module_ops)
        # NOTE(review): a tuple model_path is forwarded as a tuple although
        # load_sd annotates list[str]; downstream only iterates/hashes it.
        model_paths = self.model_path if isinstance(self.model_path, tuple) else [self.model_path]
        model_state_dict = self.load_sd(model_paths, sd_ops=self.model_sd_ops, registry=self.registry, device=device)

        lora_strengths = [lora.strength for lora in self.loras]
        if not lora_strengths or (min(lora_strengths) == 0 and max(lora_strengths) == 0):
            # No effective LoRAs: load the (optionally dtype-cast) weights directly.
            sd = model_state_dict.sd
            if dtype is not None:
                sd = {key: value.to(dtype=dtype) for key, value in model_state_dict.sd.items()}
            meta_model.load_state_dict(sd, strict=False, assign=True)
            return self._return_model(meta_model, device)

        lora_state_dicts = [
            self.load_sd([lora.path], sd_ops=lora.sd_ops, registry=self.registry, device=device) for lora in self.loras
        ]
        lora_sd_and_strengths = [
            LoraStateDictWithStrength(sd, strength)
            for sd, strength in zip(lora_state_dicts, lora_strengths, strict=True)
        ]
        # destination_sd reuses the loaded state dict only when nothing caches it
        # (DummyRegistry) — presumably to let apply_loras fuse in place without
        # corrupting a cached copy; confirm against apply_loras.
        final_sd = apply_loras(
            model_sd=model_state_dict,
            lora_sd_and_strengths=lora_sd_and_strengths,
            dtype=dtype,
            destination_sd=model_state_dict if isinstance(self.registry, DummyRegistry) else None,
        )
        meta_model.load_state_dict(final_sd.sd, strict=False, assign=True)
        return self._return_model(meta_model, device)
|
packages/ltx-core/src/ltx_core/model/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model definitions for LTX-2."""
|
| 2 |
+
|
| 3 |
+
from ltx_core.model.model_protocol import ModelConfigurator, ModelType
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"ModelConfigurator",
|
| 7 |
+
"ModelType",
|
| 8 |
+
]
|
packages/ltx-core/src/ltx_core/model/audio_vae/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Audio VAE model components."""
|
| 2 |
+
|
| 3 |
+
from ltx_core.model.audio_vae.audio_vae import AudioDecoder, AudioEncoder, decode_audio
|
| 4 |
+
from ltx_core.model.audio_vae.model_configurator import (
|
| 5 |
+
AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
|
| 6 |
+
AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
|
| 7 |
+
VOCODER_COMFY_KEYS_FILTER,
|
| 8 |
+
AudioDecoderConfigurator,
|
| 9 |
+
AudioEncoderConfigurator,
|
| 10 |
+
VocoderConfigurator,
|
| 11 |
+
)
|
| 12 |
+
from ltx_core.model.audio_vae.ops import AudioProcessor
|
| 13 |
+
from ltx_core.model.audio_vae.vocoder import Vocoder
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"AUDIO_VAE_DECODER_COMFY_KEYS_FILTER",
|
| 17 |
+
"AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER",
|
| 18 |
+
"VOCODER_COMFY_KEYS_FILTER",
|
| 19 |
+
"AudioDecoder",
|
| 20 |
+
"AudioDecoderConfigurator",
|
| 21 |
+
"AudioEncoder",
|
| 22 |
+
"AudioEncoderConfigurator",
|
| 23 |
+
"AudioProcessor",
|
| 24 |
+
"Vocoder",
|
| 25 |
+
"VocoderConfigurator",
|
| 26 |
+
"decode_audio",
|
| 27 |
+
]
|
packages/ltx-core/src/ltx_core/model/audio_vae/attention.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ltx_core.model.common.normalization import NormType, build_normalization_layer
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class AttentionType(Enum):
    """Enum for specifying the attention mechanism type."""

    VANILLA = "vanilla"  # explicit single-head self-attention (AttnBlock)
    LINEAR = "linear"  # declared but not implemented yet (make_attn raises)
    NONE = "none"  # no attention; make_attn returns nn.Identity
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class AttnBlock(torch.nn.Module):
    """Single-head self-attention over 2-D feature maps with a residual connection.

    Attention is computed explicitly with batched matmuls rather than a fused
    kernel; 1x1 convolutions serve as per-position linear projections.
    """

    def __init__(
        self,
        in_channels: int,
        norm_type: NormType = NormType.GROUP,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels

        self.norm = build_normalization_layer(in_channels, normtype=norm_type)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.norm(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        batch, channels, height, width = query.shape
        tokens = height * width

        # scores[b, i, j] = <q_i, k_j>: (b, hw, c) @ (b, c, hw) -> (b, hw, hw)
        query = query.reshape(batch, channels, tokens).contiguous()
        query = query.permute(0, 2, 1).contiguous()
        key = key.reshape(batch, channels, tokens).contiguous()
        scores = torch.bmm(query, key).contiguous()
        scores = scores * (int(channels) ** (-0.5))
        scores = torch.nn.functional.softmax(scores, dim=2)

        # out[b, c, j] = sum_i v[b, c, i] * w[b, i, j]: (b, c, hw) @ (b, hw, hw)
        value = value.reshape(batch, channels, tokens).contiguous()
        scores = scores.permute(0, 2, 1).contiguous()
        attended = torch.bmm(value, scores).contiguous()
        attended = attended.reshape(batch, channels, height, width).contiguous()

        # Residual connection around the projected attention output.
        return x + self.proj_out(attended)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def make_attn(
    in_channels: int,
    attn_type: AttentionType = AttentionType.VANILLA,
    norm_type: NormType = NormType.GROUP,
) -> torch.nn.Module:
    """Factory for attention modules.

    VANILLA yields an AttnBlock, NONE an identity pass-through; LINEAR is
    declared but not implemented, and anything else is rejected.
    """
    if attn_type is AttentionType.VANILLA:
        return AttnBlock(in_channels, norm_type=norm_type)
    if attn_type is AttentionType.NONE:
        return torch.nn.Identity()
    if attn_type is AttentionType.LINEAR:
        raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
    raise ValueError(f"Unknown attention type: {attn_type}")
|
packages/ltx-core/src/ltx_core/model/audio_vae/audio_vae.py
ADDED
|
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Set, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
from ltx_core.components.patchifiers import AudioPatchifier
|
| 7 |
+
from ltx_core.model.audio_vae.attention import AttentionType, make_attn
|
| 8 |
+
from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
|
| 9 |
+
from ltx_core.model.audio_vae.causality_axis import CausalityAxis
|
| 10 |
+
from ltx_core.model.audio_vae.downsample import build_downsampling_path
|
| 11 |
+
from ltx_core.model.audio_vae.ops import PerChannelStatistics
|
| 12 |
+
from ltx_core.model.audio_vae.resnet import ResnetBlock
|
| 13 |
+
from ltx_core.model.audio_vae.upsample import build_upsampling_path
|
| 14 |
+
from ltx_core.model.audio_vae.vocoder import Vocoder
|
| 15 |
+
from ltx_core.model.common.normalization import NormType, build_normalization_layer
|
| 16 |
+
from ltx_core.types import AudioLatentShape
|
| 17 |
+
|
| 18 |
+
# Downsampling factor handed to AudioPatchifier as audio_latent_downsample_factor.
LATENT_DOWNSAMPLE_FACTOR = 4
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def build_mid_block(
    channels: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    add_attention: bool,
) -> torch.nn.Module:
    """Build the middle block: ResNet -> (optional attention) -> ResNet."""

    def _resnet() -> ResnetBlock:
        # Both mid ResNet blocks share identical channel counts and settings.
        return ResnetBlock(
            in_channels=channels,
            out_channels=channels,
            temb_channels=temb_channels,
            dropout=dropout,
            norm_type=norm_type,
            causality_axis=causality_axis,
        )

    # Submodules are created in the same order as before (block_1, attn_1,
    # block_2) so parameter initialization draws from the RNG identically.
    mid = torch.nn.Module()
    mid.block_1 = _resnet()
    if add_attention:
        mid.attn_1 = make_attn(channels, attn_type=attn_type, norm_type=norm_type)
    else:
        mid.attn_1 = torch.nn.Identity()
    mid.block_2 = _resnet()
    return mid
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def run_mid_block(mid: torch.nn.Module, features: torch.Tensor) -> torch.Tensor:
    """Push features through the middle block: ResNet, attention, ResNet."""
    out = mid.block_1(features, temb=None)
    out = mid.attn_1(out)
    out = mid.block_2(out, temb=None)
    return out
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class AudioEncoder(torch.nn.Module):
|
| 60 |
+
"""
|
| 61 |
+
Encoder that compresses audio spectrograms into latent representations.
|
| 62 |
+
The encoder uses a series of downsampling blocks with residual connections,
|
| 63 |
+
attention mechanisms, and configurable causal convolutions.
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
def __init__( # noqa: PLR0913
|
| 67 |
+
self,
|
| 68 |
+
*,
|
| 69 |
+
ch: int,
|
| 70 |
+
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
|
| 71 |
+
num_res_blocks: int,
|
| 72 |
+
attn_resolutions: Set[int],
|
| 73 |
+
dropout: float = 0.0,
|
| 74 |
+
resamp_with_conv: bool = True,
|
| 75 |
+
in_channels: int,
|
| 76 |
+
resolution: int,
|
| 77 |
+
z_channels: int,
|
| 78 |
+
double_z: bool = True,
|
| 79 |
+
attn_type: AttentionType = AttentionType.VANILLA,
|
| 80 |
+
mid_block_add_attention: bool = True,
|
| 81 |
+
norm_type: NormType = NormType.GROUP,
|
| 82 |
+
causality_axis: CausalityAxis = CausalityAxis.WIDTH,
|
| 83 |
+
sample_rate: int = 16000,
|
| 84 |
+
mel_hop_length: int = 160,
|
| 85 |
+
n_fft: int = 1024,
|
| 86 |
+
is_causal: bool = True,
|
| 87 |
+
mel_bins: int = 64,
|
| 88 |
+
**_ignore_kwargs,
|
| 89 |
+
) -> None:
|
| 90 |
+
"""
|
| 91 |
+
Initialize the Encoder.
|
| 92 |
+
Args:
|
| 93 |
+
Arguments are configuration parameters, loaded from the audio VAE checkpoint config
|
| 94 |
+
(audio_vae.model.params.ddconfig):
|
| 95 |
+
ch: Base number of feature channels used in the first convolution layer.
|
| 96 |
+
ch_mult: Multiplicative factors for the number of channels at each resolution level.
|
| 97 |
+
num_res_blocks: Number of residual blocks to use at each resolution level.
|
| 98 |
+
attn_resolutions: Spatial resolutions (e.g., in time/frequency) at which to apply attention.
|
| 99 |
+
resolution: Input spatial resolution of the spectrogram (height, width).
|
| 100 |
+
z_channels: Number of channels in the latent representation.
|
| 101 |
+
norm_type: Normalization layer type to use within the network (e.g., group, batch).
|
| 102 |
+
causality_axis: Axis along which convolutions should be causal (e.g., time axis).
|
| 103 |
+
sample_rate: Audio sample rate in Hz for the input signals.
|
| 104 |
+
mel_hop_length: Hop length used when computing the mel spectrogram.
|
| 105 |
+
n_fft: FFT size used to compute the spectrogram.
|
| 106 |
+
mel_bins: Number of mel-frequency bins in the input spectrogram.
|
| 107 |
+
in_channels: Number of channels in the input spectrogram tensor.
|
| 108 |
+
double_z: If True, predict both mean and log-variance (doubling latent channels).
|
| 109 |
+
is_causal: If True, use causal convolutions suitable for streaming setups.
|
| 110 |
+
dropout: Dropout probability used in residual and mid blocks.
|
| 111 |
+
attn_type: Type of attention mechanism to use in attention blocks.
|
| 112 |
+
resamp_with_conv: If True, perform resolution changes using strided convolutions.
|
| 113 |
+
mid_block_add_attention: If True, add an attention block in the mid-level of the encoder.
|
| 114 |
+
"""
|
| 115 |
+
super().__init__()
|
| 116 |
+
|
| 117 |
+
self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
|
| 118 |
+
self.sample_rate = sample_rate
|
| 119 |
+
self.mel_hop_length = mel_hop_length
|
| 120 |
+
self.n_fft = n_fft
|
| 121 |
+
self.is_causal = is_causal
|
| 122 |
+
self.mel_bins = mel_bins
|
| 123 |
+
|
| 124 |
+
self.patchifier = AudioPatchifier(
|
| 125 |
+
patch_size=1,
|
| 126 |
+
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
| 127 |
+
sample_rate=sample_rate,
|
| 128 |
+
hop_length=mel_hop_length,
|
| 129 |
+
is_causal=is_causal,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
self.ch = ch
|
| 133 |
+
self.temb_ch = 0
|
| 134 |
+
self.num_resolutions = len(ch_mult)
|
| 135 |
+
self.num_res_blocks = num_res_blocks
|
| 136 |
+
self.resolution = resolution
|
| 137 |
+
self.in_channels = in_channels
|
| 138 |
+
self.z_channels = z_channels
|
| 139 |
+
self.double_z = double_z
|
| 140 |
+
self.norm_type = norm_type
|
| 141 |
+
self.causality_axis = causality_axis
|
| 142 |
+
self.attn_type = attn_type
|
| 143 |
+
|
| 144 |
+
# downsampling
|
| 145 |
+
self.conv_in = make_conv2d(
|
| 146 |
+
in_channels,
|
| 147 |
+
self.ch,
|
| 148 |
+
kernel_size=3,
|
| 149 |
+
stride=1,
|
| 150 |
+
causality_axis=self.causality_axis,
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
self.non_linearity = torch.nn.SiLU()
|
| 154 |
+
|
| 155 |
+
self.down, block_in = build_downsampling_path(
|
| 156 |
+
ch=ch,
|
| 157 |
+
ch_mult=ch_mult,
|
| 158 |
+
num_resolutions=self.num_resolutions,
|
| 159 |
+
num_res_blocks=num_res_blocks,
|
| 160 |
+
resolution=resolution,
|
| 161 |
+
temb_channels=self.temb_ch,
|
| 162 |
+
dropout=dropout,
|
| 163 |
+
norm_type=self.norm_type,
|
| 164 |
+
causality_axis=self.causality_axis,
|
| 165 |
+
attn_type=self.attn_type,
|
| 166 |
+
attn_resolutions=attn_resolutions,
|
| 167 |
+
resamp_with_conv=resamp_with_conv,
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
self.mid = build_mid_block(
|
| 171 |
+
channels=block_in,
|
| 172 |
+
temb_channels=self.temb_ch,
|
| 173 |
+
dropout=dropout,
|
| 174 |
+
norm_type=self.norm_type,
|
| 175 |
+
causality_axis=self.causality_axis,
|
| 176 |
+
attn_type=self.attn_type,
|
| 177 |
+
add_attention=mid_block_add_attention,
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
|
| 181 |
+
self.conv_out = make_conv2d(
|
| 182 |
+
block_in,
|
| 183 |
+
2 * z_channels if double_z else z_channels,
|
| 184 |
+
kernel_size=3,
|
| 185 |
+
stride=1,
|
| 186 |
+
causality_axis=self.causality_axis,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
|
| 190 |
+
"""
|
| 191 |
+
Encode audio spectrogram into latent representations.
|
| 192 |
+
Args:
|
| 193 |
+
spectrogram: Input spectrogram of shape (batch, channels, time, frequency)
|
| 194 |
+
Returns:
|
| 195 |
+
Encoded latent representation of shape (batch, channels, frames, mel_bins)
|
| 196 |
+
"""
|
| 197 |
+
h = self.conv_in(spectrogram)
|
| 198 |
+
h = self._run_downsampling_path(h)
|
| 199 |
+
h = run_mid_block(self.mid, h)
|
| 200 |
+
h = self._finalize_output(h)
|
| 201 |
+
|
| 202 |
+
return self._normalize_latents(h)
|
| 203 |
+
|
| 204 |
+
def _run_downsampling_path(self, h: torch.Tensor) -> torch.Tensor:
|
| 205 |
+
for level in range(self.num_resolutions):
|
| 206 |
+
stage = self.down[level]
|
| 207 |
+
for block_idx in range(self.num_res_blocks):
|
| 208 |
+
h = stage.block[block_idx](h, temb=None)
|
| 209 |
+
if stage.attn:
|
| 210 |
+
h = stage.attn[block_idx](h)
|
| 211 |
+
|
| 212 |
+
if level != self.num_resolutions - 1:
|
| 213 |
+
h = stage.downsample(h)
|
| 214 |
+
|
| 215 |
+
return h
|
| 216 |
+
|
| 217 |
+
def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
|
| 218 |
+
h = self.norm_out(h)
|
| 219 |
+
h = self.non_linearity(h)
|
| 220 |
+
return self.conv_out(h)
|
| 221 |
+
|
| 222 |
+
def _normalize_latents(self, latent_output: torch.Tensor) -> torch.Tensor:
|
| 223 |
+
"""
|
| 224 |
+
Normalize encoder latents using per-channel statistics.
|
| 225 |
+
When the encoder is configured with ``double_z=True``, the final
|
| 226 |
+
convolution produces twice the number of latent channels, typically
|
| 227 |
+
interpreted as two concatenated tensors along the channel dimension
|
| 228 |
+
(e.g., mean and variance or other auxiliary parameters).
|
| 229 |
+
This method intentionally uses only the first half of the channels
|
| 230 |
+
(the "mean" component) as input to the patchifier and normalization
|
| 231 |
+
logic. The remaining channels are left unchanged by this method and
|
| 232 |
+
are expected to be consumed elsewhere in the VAE pipeline.
|
| 233 |
+
If ``double_z=False``, the encoder output already contains only the
|
| 234 |
+
mean latents and the chunking operation simply returns that tensor.
|
| 235 |
+
"""
|
| 236 |
+
means = torch.chunk(latent_output, 2, dim=1)[0]
|
| 237 |
+
latent_shape = AudioLatentShape(
|
| 238 |
+
batch=means.shape[0],
|
| 239 |
+
channels=means.shape[1],
|
| 240 |
+
frames=means.shape[2],
|
| 241 |
+
mel_bins=means.shape[3],
|
| 242 |
+
)
|
| 243 |
+
latent_patched = self.patchifier.patchify(means)
|
| 244 |
+
latent_normalized = self.per_channel_statistics.normalize(latent_patched)
|
| 245 |
+
return self.patchifier.unpatchify(latent_normalized, latent_shape)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class AudioDecoder(torch.nn.Module):
    """
    Symmetric decoder that reconstructs audio spectrograms from latent features.

    The decoder mirrors the encoder structure with configurable channel multipliers,
    attention resolutions, and causal convolutions. Incoming latents are first
    denormalized with dataset-level per-channel statistics, then decoded through
    a conv-in / mid-block / upsampling-path / conv-out pipeline, and finally
    cropped/padded to the expected spectrogram shape.
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        ch: int,
        out_ch: int,
        ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
        num_res_blocks: int,
        attn_resolutions: Set[int],
        resolution: int,
        z_channels: int,
        norm_type: NormType = NormType.GROUP,
        causality_axis: CausalityAxis = CausalityAxis.WIDTH,
        dropout: float = 0.0,
        mid_block_add_attention: bool = True,
        sample_rate: int = 16000,
        mel_hop_length: int = 160,
        is_causal: bool = True,
        mel_bins: int | None = None,
    ) -> None:
        """
        Initialize the Decoder.
        Args:
            Arguments are configuration parameters, loaded from the audio VAE checkpoint config
            (audio_vae.model.params.ddconfig):
            - ch, out_ch, ch_mult, num_res_blocks, attn_resolutions
            - resolution, z_channels
            - norm_type, causality_axis
        """
        super().__init__()

        # Internal behavioural defaults that are not driven by the checkpoint.
        resamp_with_conv = True
        attn_type = AttentionType.VANILLA

        # Per-channel statistics for denormalizing latents
        self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
        self.sample_rate = sample_rate
        self.mel_hop_length = mel_hop_length
        self.is_causal = is_causal
        self.mel_bins = mel_bins
        # Patchifier used to flatten latents before applying per-channel statistics.
        self.patchifier = AudioPatchifier(
            patch_size=1,
            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
            sample_rate=sample_rate,
            hop_length=mel_hop_length,
            is_causal=is_causal,
        )

        self.ch = ch
        self.temb_ch = 0  # no timestep embedding is used (blocks receive temb=None)
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.out_ch = out_ch
        self.give_pre_end = False  # when True, _finalize_output returns pre-norm features
        self.tanh_out = False  # when True, output is squashed with tanh
        self.norm_type = norm_type
        self.z_channels = z_channels
        self.channel_multipliers = ch_mult
        self.attn_resolutions = attn_resolutions
        self.causality_axis = causality_axis
        self.attn_type = attn_type

        # Channel count and spatial resolution at the deepest (bottleneck) level.
        base_block_channels = ch * self.channel_multipliers[-1]
        base_resolution = resolution // (2 ** (self.num_resolutions - 1))
        self.z_shape = (1, z_channels, base_resolution, base_resolution)

        self.conv_in = make_conv2d(
            z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
        )
        self.non_linearity = torch.nn.SiLU()
        self.mid = build_mid_block(
            channels=base_block_channels,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            add_attention=mid_block_add_attention,
        )
        self.up, final_block_channels = build_upsampling_path(
            ch=ch,
            ch_mult=ch_mult,
            num_resolutions=self.num_resolutions,
            num_res_blocks=num_res_blocks,
            resolution=resolution,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            attn_resolutions=attn_resolutions,
            resamp_with_conv=resamp_with_conv,
            initial_block_channels=base_block_channels,
        )

        self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
        self.conv_out = make_conv2d(
            final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
        )

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        """
        Decode latent features back to audio spectrograms.
        Args:
            sample: Encoded latent representation of shape (batch, channels, frames, mel_bins)
        Returns:
            Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
        """
        sample, target_shape = self._denormalize_latents(sample)

        h = self.conv_in(sample)
        h = run_mid_block(self.mid, h)
        h = self._run_upsampling_path(h)
        h = self._finalize_output(h)

        return self._adjust_output_shape(h, target_shape)

    def _denormalize_latents(self, sample: torch.Tensor) -> tuple[torch.Tensor, AudioLatentShape]:
        # Undo the encoder's per-channel normalization (in patchified space)
        # and compute the spectrogram shape the decoded output must match.
        latent_shape = AudioLatentShape(
            batch=sample.shape[0],
            channels=sample.shape[1],
            frames=sample.shape[2],
            mel_bins=sample.shape[3],
        )

        sample_patched = self.patchifier.patchify(sample)
        sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched)
        sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)

        # Each latent frame expands to LATENT_DOWNSAMPLE_FACTOR output frames;
        # causal modes drop (factor - 1) frames, clamped to at least one frame.
        target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR
        if self.causality_axis != CausalityAxis.NONE:
            target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)

        target_shape = AudioLatentShape(
            batch=latent_shape.batch,
            channels=self.out_ch,
            frames=target_frames,
            mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
        )

        return sample, target_shape

    def _adjust_output_shape(
        self,
        decoded_output: torch.Tensor,
        target_shape: AudioLatentShape,
    ) -> torch.Tensor:
        """
        Adjust output shape to match target dimensions for variable-length audio.
        This function handles the common case where decoded audio spectrograms need to be
        resized to match a specific target shape.
        Args:
            decoded_output: Tensor of shape (batch, channels, time, frequency)
            target_shape: AudioLatentShape describing (batch, channels, time, mel bins)
        Returns:
            Tensor adjusted to match target_shape exactly
        """
        # Current output shape: (batch, channels, time, frequency)
        _, _, current_time, current_freq = decoded_output.shape
        target_channels = target_shape.channels
        target_time = target_shape.frames
        target_freq = target_shape.mel_bins

        # Step 1: Crop first to avoid exceeding target dimensions
        decoded_output = decoded_output[
            :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
        ]

        # Step 2: Calculate padding needed for time and frequency dimensions
        time_padding_needed = target_time - decoded_output.shape[2]
        freq_padding_needed = target_freq - decoded_output.shape[3]

        # Step 3: Apply padding if needed
        if time_padding_needed > 0 or freq_padding_needed > 0:
            # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
            # For audio: pad_left/right = frequency, pad_top/bottom = time
            padding = (
                0,
                max(freq_padding_needed, 0),  # frequency padding (left, right)
                0,
                max(time_padding_needed, 0),  # time padding (top, bottom)
            )
            decoded_output = F.pad(decoded_output, padding)

        # Step 4: Final safety crop to ensure exact target shape
        decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]

        return decoded_output

    def _run_upsampling_path(self, h: torch.Tensor) -> torch.Tensor:
        # Walk levels from deepest to shallowest, mirroring the encoder path.
        for level in reversed(range(self.num_resolutions)):
            stage = self.up[level]
            for block_idx, block in enumerate(stage.block):
                h = block(h, temb=None)
                if stage.attn:
                    h = stage.attn[block_idx](h)

            # The shallowest level has no upsample module.
            if level != 0 and hasattr(stage, "upsample"):
                h = stage.upsample(h)

        return h

    def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
        # give_pre_end short-circuits the final norm/activation/projection.
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = self.non_linearity(h)
        h = self.conv_out(h)
        return torch.tanh(h) if self.tanh_out else h
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
def decode_audio(latent: torch.Tensor, audio_decoder: "AudioDecoder", vocoder: "Vocoder") -> torch.Tensor:
    """
    Decode an audio latent representation using the provided audio decoder and vocoder.

    Args:
        latent: Input audio latent tensor.
        audio_decoder: Model to decode the latent to waveform features.
        vocoder: Model to convert decoded features to audio waveform.

    Returns:
        Decoded audio as a float tensor (leading batch dimension squeezed away).
    """
    spectrogram = audio_decoder(latent)
    waveform = vocoder(spectrogram)
    return waveform.squeeze(0).float()
|
packages/ltx-core/src/ltx_core/model/audio_vae/causal_conv_2d.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
from ltx_core.model.audio_vae.causality_axis import CausalityAxis
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class CausalConv2d(torch.nn.Module):
    """
    A 2D convolution with causal padding along one spatial axis.

    The output at position ``t`` along the causal axis depends only on inputs
    at ``t`` and earlier. This is achieved by applying asymmetric zero padding
    to the input before an unpadded convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int, int],
        stride: int = 1,
        dilation: int | tuple[int, int] = 1,
        groups: int = 1,
        bias: bool = True,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        super().__init__()

        self.causality_axis = causality_axis

        # Normalize scalar arguments to (height, width) pairs.
        kernel_size = torch.nn.modules.utils._pair(kernel_size)
        dilation = torch.nn.modules.utils._pair(dilation)

        # Total padding required along each spatial dimension.
        total_h = (kernel_size[0] - 1) * dilation[0]
        total_w = (kernel_size[1] - 1) * dilation[1]

        # F.pad ordering: (pad_left, pad_right, pad_top, pad_bottom).
        # The causal axis receives all of its padding on one side.
        if self.causality_axis is CausalityAxis.NONE:
            self.padding = (total_w // 2, total_w - total_w // 2, total_h // 2, total_h - total_h // 2)
        elif self.causality_axis in (CausalityAxis.WIDTH, CausalityAxis.WIDTH_COMPATIBILITY):
            self.padding = (total_w, 0, total_h // 2, total_h - total_h // 2)
        elif self.causality_axis is CausalityAxis.HEIGHT:
            self.padding = (total_w // 2, total_w - total_w // 2, total_h, 0)
        else:
            raise ValueError(f"Invalid causality_axis: {causality_axis}")

        # The conv itself pads nothing; padding is applied manually in forward().
        self.conv = torch.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal padding, then convolve."""
        return self.conv(F.pad(x, self.padding))
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def make_conv2d(
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    stride: int = 1,
    padding: tuple[int, int, int, int] | None = None,
    dilation: int = 1,
    groups: int = 1,
    bias: bool = True,
    causality_axis: CausalityAxis | None = None,
) -> torch.nn.Module:
    """
    Create a 2D convolution layer that can be either causal or non-causal.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Size of the convolution kernel.
        stride: Convolution stride.
        padding: Explicit padding for the non-causal case; when None, a
            symmetric kernel_size // 2 padding is derived.
        dilation: Dilation rate.
        groups: Number of groups for grouped convolution.
        bias: Whether to use bias.
        causality_axis: Dimension along which to apply causality.

    Returns:
        A CausalConv2d when causality_axis is given, else a plain Conv2d.
    """
    if causality_axis is not None:
        # CausalConv2d computes and applies its own asymmetric padding.
        return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis)

    effective_padding = padding
    if effective_padding is None:
        if isinstance(kernel_size, int):
            effective_padding = kernel_size // 2
        else:
            effective_padding = tuple(k // 2 for k in kernel_size)

    return torch.nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        effective_padding,
        dilation,
        groups,
        bias,
    )
|
packages/ltx-core/src/ltx_core/model/audio_vae/causality_axis.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class CausalityAxis(Enum):
    """Enum for specifying the causality axis in causal convolutions."""

    # No causal constraint; convolutions use symmetric padding on both axes.
    NONE = None
    # Causal along the width (last) axis: all padding is applied on the left.
    WIDTH = "width"
    # Causal along the height axis: all padding is applied on the top.
    HEIGHT = "height"
    # Width-causal variant kept for compatibility; downsampling layers use a
    # different left padding for this mode than for WIDTH (see Downsample).
    WIDTH_COMPATIBILITY = "width-compatibility"
|
packages/ltx-core/src/ltx_core/model/audio_vae/downsample.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Set, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ltx_core.model.audio_vae.attention import AttentionType, make_attn
|
| 6 |
+
from ltx_core.model.audio_vae.causality_axis import CausalityAxis
|
| 7 |
+
from ltx_core.model.audio_vae.resnet import ResnetBlock
|
| 8 |
+
from ltx_core.model.common.normalization import NormType
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Downsample(torch.nn.Module):
    """
    2x spatial downsampling via a strided convolution or average pooling.

    In convolutional mode, asymmetric (optionally causal) padding is applied
    manually before the stride-2 convolution, since torch convolutions only
    support symmetric padding.
    """

    def __init__(
        self,
        in_channels: int,
        with_conv: bool,
        causality_axis: CausalityAxis = CausalityAxis.WIDTH,
    ) -> None:
        super().__init__()
        self.with_conv = with_conv
        self.causality_axis = causality_axis

        # Pooling mode has no padding control, so causality requires a conv.
        if not self.with_conv and self.causality_axis != CausalityAxis.NONE:
            raise ValueError("causality is only supported when `with_conv=True`.")

        if self.with_conv:
            # Padding is applied manually in forward(); the conv pads nothing.
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.with_conv:
            # Only reachable with causality_axis == NONE (enforced in __init__).
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)

        # Padding tuple ordering: (left, right, top, bottom).
        if self.causality_axis is CausalityAxis.NONE:
            pad = (0, 1, 0, 1)
        elif self.causality_axis is CausalityAxis.WIDTH:
            pad = (2, 0, 0, 1)
        elif self.causality_axis is CausalityAxis.HEIGHT:
            pad = (0, 1, 2, 0)
        elif self.causality_axis is CausalityAxis.WIDTH_COMPATIBILITY:
            pad = (1, 0, 0, 1)
        else:
            raise ValueError(f"Invalid causality_axis: {self.causality_axis}")

        padded = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        return self.conv(padded)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def build_downsampling_path(  # noqa: PLR0913
    *,
    ch: int,
    ch_mult: Tuple[int, ...],
    num_resolutions: int,
    num_res_blocks: int,
    resolution: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    attn_resolutions: Set[int],
    resamp_with_conv: bool,
) -> tuple[torch.nn.ModuleList, int]:
    """
    Build the downsampling path with residual blocks, attention, and downsampling layers.

    Each resolution level holds ``num_res_blocks`` ResnetBlocks, per-block
    attention at the configured resolutions, and — on every level but the
    deepest — a 2x Downsample.

    Returns:
        The ModuleList of per-level stage modules and the final channel count.
    """
    stages = torch.nn.ModuleList()
    current_resolution = resolution
    input_mults = (1, *tuple(ch_mult))
    channels = ch

    for level in range(num_resolutions):
        res_blocks = torch.nn.ModuleList()
        attn_blocks = torch.nn.ModuleList()
        channels = ch * input_mults[level]
        level_out_channels = ch * ch_mult[level]

        for _ in range(num_res_blocks):
            res_blocks.append(
                ResnetBlock(
                    in_channels=channels,
                    out_channels=level_out_channels,
                    temb_channels=temb_channels,
                    dropout=dropout,
                    norm_type=norm_type,
                    causality_axis=causality_axis,
                )
            )
            channels = level_out_channels
            # One attention module per residual block at attention resolutions.
            if current_resolution in attn_resolutions:
                attn_blocks.append(make_attn(channels, attn_type=attn_type, norm_type=norm_type))

        stage = torch.nn.Module()
        stage.block = res_blocks
        stage.attn = attn_blocks
        if level != num_resolutions - 1:
            stage.downsample = Downsample(channels, resamp_with_conv, causality_axis=causality_axis)
            current_resolution = current_resolution // 2
        stages.append(stage)

    return stages, channels
|
packages/ltx-core/src/ltx_core/model/audio_vae/model_configurator.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ltx_core.loader.sd_ops import SDOps
|
| 2 |
+
from ltx_core.model.audio_vae.attention import AttentionType
|
| 3 |
+
from ltx_core.model.audio_vae.audio_vae import AudioDecoder, AudioEncoder
|
| 4 |
+
from ltx_core.model.audio_vae.causality_axis import CausalityAxis
|
| 5 |
+
from ltx_core.model.audio_vae.vocoder import Vocoder
|
| 6 |
+
from ltx_core.model.common.normalization import NormType
|
| 7 |
+
from ltx_core.model.model_protocol import ModelConfigurator
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class VocoderConfigurator(ModelConfigurator[Vocoder]):
    @classmethod
    def from_config(cls: type[Vocoder], config: dict) -> Vocoder:
        """Build a Vocoder from the `vocoder` section of a checkpoint config, with defaults."""
        vocoder_cfg = config.get("vocoder", {})
        return Vocoder(
            resblock_kernel_sizes=vocoder_cfg.get("resblock_kernel_sizes", [3, 7, 11]),
            upsample_rates=vocoder_cfg.get("upsample_rates", [6, 5, 2, 2, 2]),
            upsample_kernel_sizes=vocoder_cfg.get("upsample_kernel_sizes", [16, 15, 8, 4, 4]),
            resblock_dilation_sizes=vocoder_cfg.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]]),
            upsample_initial_channel=vocoder_cfg.get("upsample_initial_channel", 1024),
            stereo=vocoder_cfg.get("stereo", True),
            resblock=vocoder_cfg.get("resblock", "1"),
            output_sample_rate=vocoder_cfg.get("output_sample_rate", 24000),
        )
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# State-dict key ops for loading the vocoder from a combined (ComfyUI-style)
# checkpoint: match keys under the "vocoder." prefix and strip that prefix.
VOCODER_COMFY_KEYS_FILTER = (
    SDOps("VOCODER_COMFY_KEYS_FILTER").with_matching(prefix="vocoder.").with_replacement("vocoder.", "")
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class AudioDecoderConfigurator(ModelConfigurator[AudioDecoder]):
    @classmethod
    def from_config(cls: type[AudioDecoder], config: dict) -> AudioDecoder:
        """Build an AudioDecoder from the audio VAE section of a checkpoint config, with defaults."""
        audio_vae_cfg = config.get("audio_vae", {})
        model_params = audio_vae_cfg.get("model", {}).get("params", {})
        ddconfig = model_params.get("ddconfig", {})
        preprocessing_cfg = audio_vae_cfg.get("preprocessing", {})
        stft_cfg = preprocessing_cfg.get("stft", {})
        mel_cfg = preprocessing_cfg.get("mel", {})
        variables_cfg = audio_vae_cfg.get("variables", {})

        # mel_bins may live in several config locations depending on checkpoint vintage.
        mel_bins = ddconfig.get("mel_bins") or mel_cfg.get("n_mel_channels") or variables_cfg.get("mel_bins")

        return AudioDecoder(
            ch=ddconfig.get("ch", 128),
            out_ch=ddconfig.get("out_ch", 2),
            ch_mult=tuple(ddconfig.get("ch_mult", (1, 2, 4))),
            num_res_blocks=ddconfig.get("num_res_blocks", 2),
            attn_resolutions=ddconfig.get("attn_resolutions", {8, 16, 32}),
            resolution=ddconfig.get("resolution", 256),
            z_channels=ddconfig.get("z_channels", 8),
            norm_type=NormType(ddconfig.get("norm_type", "pixel")),
            causality_axis=CausalityAxis(ddconfig.get("causality_axis", "height")),
            dropout=ddconfig.get("dropout", 0.0),
            mid_block_add_attention=ddconfig.get("mid_block_add_attention", True),
            sample_rate=model_params.get("sampling_rate", 16000),
            mel_hop_length=stft_cfg.get("hop_length", 160),
            is_causal=stft_cfg.get("causal", True),
            mel_bins=mel_bins,
        )
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class AudioEncoderConfigurator(ModelConfigurator[AudioEncoder]):
    @classmethod
    def from_config(cls: type[AudioEncoder], config: dict) -> AudioEncoder:
        """Build an AudioEncoder from the audio VAE section of a checkpoint config, with defaults."""
        audio_vae_cfg = config.get("audio_vae", {})
        model_params = audio_vae_cfg.get("model", {}).get("params", {})
        ddconfig = model_params.get("ddconfig", {})
        preprocessing_cfg = audio_vae_cfg.get("preprocessing", {})
        stft_cfg = preprocessing_cfg.get("stft", {})
        mel_cfg = preprocessing_cfg.get("mel", {})
        variables_cfg = audio_vae_cfg.get("variables", {})

        # mel_bins may live in several config locations depending on checkpoint vintage.
        mel_bins = ddconfig.get("mel_bins") or mel_cfg.get("n_mel_channels") or variables_cfg.get("mel_bins")

        return AudioEncoder(
            ch=ddconfig.get("ch", 128),
            ch_mult=tuple(ddconfig.get("ch_mult", (1, 2, 4))),
            num_res_blocks=ddconfig.get("num_res_blocks", 2),
            attn_resolutions=ddconfig.get("attn_resolutions", {8, 16, 32}),
            resolution=ddconfig.get("resolution", 256),
            z_channels=ddconfig.get("z_channels", 8),
            double_z=ddconfig.get("double_z", True),
            dropout=ddconfig.get("dropout", 0.0),
            resamp_with_conv=ddconfig.get("resamp_with_conv", True),
            in_channels=ddconfig.get("in_channels", 2),
            attn_type=AttentionType(ddconfig.get("attn_type", "vanilla")),
            mid_block_add_attention=ddconfig.get("mid_block_add_attention", True),
            norm_type=NormType(ddconfig.get("norm_type", "pixel")),
            causality_axis=CausalityAxis(ddconfig.get("causality_axis", "height")),
            sample_rate=model_params.get("sampling_rate", 16000),
            mel_hop_length=stft_cfg.get("hop_length", 160),
            n_fft=stft_cfg.get("filter_length", 1024),
            is_causal=stft_cfg.get("causal", True),
            mel_bins=mel_bins,
        )
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# State-dict key ops for loading the audio VAE decoder from a combined
# (ComfyUI-style) checkpoint: match decoder and per-channel-statistics keys,
# strip the decoder prefix, and remap the statistics keys to the local
# "per_channel_statistics." namespace expected by AudioDecoder.
AUDIO_VAE_DECODER_COMFY_KEYS_FILTER = (
    SDOps("AUDIO_VAE_DECODER_COMFY_KEYS_FILTER")
    .with_matching(prefix="audio_vae.decoder.")
    .with_matching(prefix="audio_vae.per_channel_statistics.")
    .with_replacement("audio_vae.decoder.", "")
    .with_replacement("audio_vae.per_channel_statistics.", "per_channel_statistics.")
)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# State-dict key ops for loading the audio VAE encoder from a combined
# (ComfyUI-style) checkpoint: match encoder and per-channel-statistics keys,
# strip the encoder prefix, and remap the statistics keys to the local
# "per_channel_statistics." namespace expected by AudioEncoder.
AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER = (
    SDOps("AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER")
    .with_matching(prefix="audio_vae.encoder.")
    .with_matching(prefix="audio_vae.per_channel_statistics.")
    .with_replacement("audio_vae.encoder.", "")
    .with_replacement("audio_vae.per_channel_statistics.", "per_channel_statistics.")
)
|
packages/ltx-core/src/ltx_core/model/audio_vae/ops.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchaudio
|
| 3 |
+
from torch import nn
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class AudioProcessor(nn.Module):
    """Converts audio waveforms to log-mel spectrograms with optional resampling."""

    def __init__(
        self,
        sample_rate: int,
        mel_bins: int,
        mel_hop_length: int,
        n_fft: int,
    ) -> None:
        """
        Args:
            sample_rate: Sample rate the mel transform operates at.
            mel_bins: Number of mel filterbank channels.
            mel_hop_length: Hop length (in samples) between STFT frames.
            n_fft: FFT size; also used as the analysis window length.
        """
        super().__init__()
        self.sample_rate = sample_rate
        # Magnitude (power=1.0) mel spectrogram with Slaney scale and norm.
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=n_fft,
            hop_length=mel_hop_length,
            f_min=0.0,
            f_max=sample_rate / 2.0,
            n_mels=mel_bins,
            window_fn=torch.hann_window,
            center=True,
            pad_mode="reflect",
            power=1.0,
            mel_scale="slaney",
            norm="slaney",
        )

    def resample_waveform(
        self,
        waveform: torch.Tensor,
        source_rate: int,
        target_rate: int,
    ) -> torch.Tensor:
        """Resample waveform to target sample rate; no-op when rates already match."""
        if source_rate == target_rate:
            return waveform
        converted = torchaudio.functional.resample(waveform, source_rate, target_rate)
        # Keep the result on the input's device/dtype.
        return converted.to(device=waveform.device, dtype=waveform.dtype)

    def waveform_to_mel(
        self,
        waveform: torch.Tensor,
        waveform_sample_rate: int,
    ) -> torch.Tensor:
        """Convert waveform to log-mel spectrogram [batch, channels, time, n_mels]."""
        waveform = self.resample_waveform(waveform, waveform_sample_rate, self.sample_rate)

        # Clamp before log to avoid -inf on silent frames.
        log_mel = torch.log(torch.clamp(self.mel_transform(waveform), min=1e-5))
        log_mel = log_mel.to(device=waveform.device, dtype=waveform.dtype)

        # (batch, channels, n_mels, time) -> (batch, channels, time, n_mels)
        return log_mel.permute(0, 1, 3, 2).contiguous()
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class PerChannelStatistics(nn.Module):
    """
    Per-channel statistics for normalizing and denormalizing the latent representation.
    These statistics are computed over the entire dataset and stored in the model's
    checkpoint under the AudioVAE state_dict.
    """

    def __init__(self, latent_channels: int = 128) -> None:
        super().__init__()
        # The hyphenated buffer names mirror the checkpoint keys exactly; they
        # are not valid Python identifiers, hence get_buffer() access below.
        self.register_buffer("std-of-means", torch.empty(latent_channels))
        self.register_buffer("mean-of-means", torch.empty(latent_channels))

    def un_normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Map a normalized latent back to the raw latent distribution."""
        std = self.get_buffer("std-of-means").to(x)
        mean = self.get_buffer("mean-of-means").to(x)
        return x * std + mean

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Map a raw latent to zero-mean / unit-std using the stored statistics."""
        std = self.get_buffer("std-of-means").to(x)
        mean = self.get_buffer("mean-of-means").to(x)
        return (x - mean) / std
|
packages/ltx-core/src/ltx_core/model/audio_vae/resnet.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
|
| 6 |
+
from ltx_core.model.audio_vae.causality_axis import CausalityAxis
|
| 7 |
+
from ltx_core.model.common.normalization import NormType, build_normalization_layer
|
| 8 |
+
|
| 9 |
+
LRELU_SLOPE = 0.1
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ResBlock1(torch.nn.Module):
|
| 13 |
+
def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5)):
|
| 14 |
+
super(ResBlock1, self).__init__()
|
| 15 |
+
self.convs1 = torch.nn.ModuleList(
|
| 16 |
+
[
|
| 17 |
+
torch.nn.Conv1d(
|
| 18 |
+
channels,
|
| 19 |
+
channels,
|
| 20 |
+
kernel_size,
|
| 21 |
+
1,
|
| 22 |
+
dilation=dilation[0],
|
| 23 |
+
padding="same",
|
| 24 |
+
),
|
| 25 |
+
torch.nn.Conv1d(
|
| 26 |
+
channels,
|
| 27 |
+
channels,
|
| 28 |
+
kernel_size,
|
| 29 |
+
1,
|
| 30 |
+
dilation=dilation[1],
|
| 31 |
+
padding="same",
|
| 32 |
+
),
|
| 33 |
+
torch.nn.Conv1d(
|
| 34 |
+
channels,
|
| 35 |
+
channels,
|
| 36 |
+
kernel_size,
|
| 37 |
+
1,
|
| 38 |
+
dilation=dilation[2],
|
| 39 |
+
padding="same",
|
| 40 |
+
),
|
| 41 |
+
]
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
self.convs2 = torch.nn.ModuleList(
|
| 45 |
+
[
|
| 46 |
+
torch.nn.Conv1d(
|
| 47 |
+
channels,
|
| 48 |
+
channels,
|
| 49 |
+
kernel_size,
|
| 50 |
+
1,
|
| 51 |
+
dilation=1,
|
| 52 |
+
padding="same",
|
| 53 |
+
),
|
| 54 |
+
torch.nn.Conv1d(
|
| 55 |
+
channels,
|
| 56 |
+
channels,
|
| 57 |
+
kernel_size,
|
| 58 |
+
1,
|
| 59 |
+
dilation=1,
|
| 60 |
+
padding="same",
|
| 61 |
+
),
|
| 62 |
+
torch.nn.Conv1d(
|
| 63 |
+
channels,
|
| 64 |
+
channels,
|
| 65 |
+
kernel_size,
|
| 66 |
+
1,
|
| 67 |
+
dilation=1,
|
| 68 |
+
padding="same",
|
| 69 |
+
),
|
| 70 |
+
]
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 74 |
+
for conv1, conv2 in zip(self.convs1, self.convs2, strict=True):
|
| 75 |
+
xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
|
| 76 |
+
xt = conv1(xt)
|
| 77 |
+
xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE)
|
| 78 |
+
xt = conv2(xt)
|
| 79 |
+
x = xt + x
|
| 80 |
+
return x
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class ResBlock2(torch.nn.Module):
    """Residual block with two dilated convolutions and leaky-ReLU pre-activation."""

    def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int] = (1, 3)):
        super().__init__()
        # One conv per dilation entry; "same" padding preserves temporal length.
        self.convs = torch.nn.ModuleList(
            torch.nn.Conv1d(
                channels,
                channels,
                kernel_size,
                1,
                dilation=d,
                padding="same",
            )
            for d in dilation
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply each conv as a residual update: x <- x + conv(lrelu(x))."""
        for conv in self.convs:
            x = x + conv(torch.nn.functional.leaky_relu(x, LRELU_SLOPE))
        return x
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class ResnetBlock(torch.nn.Module):
    """Pre-activation residual block: norm -> SiLU -> conv, twice, with an optional
    timestep-embedding injection between the two conv paths and a learned shortcut
    when the channel count changes.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor; defaults to ``in_channels``.
        conv_shortcut: Use a 3x3 conv (``conv_shortcut``) instead of a 1x1 conv
            (``nin_shortcut``) for the channel-changing skip path. NOTE: the
            attribute names are part of the checkpoint state_dict keys.
        dropout: Dropout probability applied before the second convolution.
        temb_channels: Size of the timestep embedding; ``temb_proj`` is only
            created when this is > 0.
        norm_type: Normalization flavor; GroupNorm is rejected for causal blocks.
        causality_axis: Which spatial axis (if any) the convolutions are causal in.

    Raises:
        ValueError: If a causal axis is combined with GroupNorm (GroupNorm mixes
            information across the whole feature map, breaking causality).
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: int | None = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        norm_type: NormType = NormType.GROUP,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        super().__init__()
        self.causality_axis = causality_axis

        if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP:
            raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
        self.non_linearity = torch.nn.SiLU()
        self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
        # temb_proj only exists when temb_channels > 0; forward() will raise
        # AttributeError if a temb is passed to a block built with temb_channels=0.
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
        # Learned shortcut only when channel counts differ; otherwise the skip
        # connection is the identity.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
                )
            else:
                self.nin_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
                )

    def forward(
        self,
        x: torch.Tensor,
        temb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the residual block; ``temb`` (if given) is projected and added per-channel."""
        h = x
        h = self.norm1(h)
        h = self.non_linearity(h)
        h = self.conv1(h)

        if temb is not None:
            # Broadcast the projected embedding over the two spatial dims.
            h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = self.non_linearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)

        return x + h
|
packages/ltx-core/src/ltx_core/model/audio_vae/upsample.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Set, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ltx_core.model.audio_vae.attention import AttentionType, make_attn
|
| 6 |
+
from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d
|
| 7 |
+
from ltx_core.model.audio_vae.causality_axis import CausalityAxis
|
| 8 |
+
from ltx_core.model.audio_vae.resnet import ResnetBlock
|
| 9 |
+
from ltx_core.model.common.normalization import NormType
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Upsample(torch.nn.Module):
    """2x nearest-neighbor upsampling, optionally followed by a (causal) 3x3 convolution."""

    def __init__(
        self,
        in_channels: int,
        with_conv: bool,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        super().__init__()
        self.with_conv = with_conv
        self.causality_axis = causality_axis
        if self.with_conv:
            self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        # Undo the encoder's causal padding by dropping the FIRST element along
        # the causal axis (keeping a 1 + 2n length). After 2x nearest upsampling
        # of [0, 1, 2] -> [0, 0, 1, 1, 2, 2], a causal 3-wide conv pads on the
        # left, so the first two outputs depend only on input element 0 while
        # every later output mixes two inputs — hence the first element (not the
        # last) is the redundant one. Non-causal and width-compatibility modes
        # leave x unchanged.
        axis = self.causality_axis
        if axis in (CausalityAxis.NONE, CausalityAxis.WIDTH_COMPATIBILITY):
            return x
        if axis == CausalityAxis.HEIGHT:
            return x[:, :, 1:, :]
        if axis == CausalityAxis.WIDTH:
            return x[:, :, :, 1:]
        raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def build_upsampling_path(  # noqa: PLR0913
    *,
    ch: int,
    ch_mult: Tuple[int, ...],
    num_resolutions: int,
    num_res_blocks: int,
    resolution: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    attn_resolutions: Set[int],
    resamp_with_conv: bool,
    initial_block_channels: int,
) -> tuple[torch.nn.ModuleList, int]:
    """Build the upsampling path with residual blocks, attention, and upsampling layers.

    Levels are constructed from the deepest (lowest resolution) up; each level
    holds ``num_res_blocks + 1`` ResnetBlocks, optional attention (when the
    level's current resolution is in ``attn_resolutions``), and — for every
    level but the last-built level 0 — an Upsample module. NOTE: the stage
    attribute names ``block``/``attn``/``upsample`` become state_dict keys.

    Returns:
        A tuple of (module list ordered from level 0 upward, channel count of
        the final block), the latter so the caller can wire up the output head.
    """
    up_modules = torch.nn.ModuleList()
    block_in = initial_block_channels
    # Resolution at the deepest level (halved once per downsampling stage).
    curr_res = resolution // (2 ** (num_resolutions - 1))

    for level in reversed(range(num_resolutions)):
        stage = torch.nn.Module()
        stage.block = torch.nn.ModuleList()
        stage.attn = torch.nn.ModuleList()
        block_out = ch * ch_mult[level]

        for _ in range(num_res_blocks + 1):
            stage.block.append(
                ResnetBlock(
                    in_channels=block_in,
                    out_channels=block_out,
                    temb_channels=temb_channels,
                    dropout=dropout,
                    norm_type=norm_type,
                    causality_axis=causality_axis,
                )
            )
            block_in = block_out
            # Attention is checked per res-block, at this level's resolution.
            if curr_res in attn_resolutions:
                stage.attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type))

        if level != 0:
            stage.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
            curr_res *= 2

        # Prepend so the returned list is ordered level 0 .. num_resolutions-1.
        up_modules.insert(0, stage)

    return up_modules, block_in
|
packages/ltx-core/src/ltx_core/model/audio_vae/vocoder.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import einops
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
from ltx_core.model.audio_vae.resnet import LRELU_SLOPE, ResBlock1, ResBlock2
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Vocoder(torch.nn.Module):
    """
    Vocoder model for synthesizing audio from Mel spectrograms.
    Args:
        resblock_kernel_sizes: List of kernel sizes for the residual blocks.
            This value is read from the checkpoint at `config.vocoder.resblock_kernel_sizes`.
        upsample_rates: List of upsampling rates.
            This value is read from the checkpoint at `config.vocoder.upsample_rates`.
        upsample_kernel_sizes: List of kernel sizes for the upsampling layers.
            This value is read from the checkpoint at `config.vocoder.upsample_kernel_sizes`.
        resblock_dilation_sizes: List of dilation sizes for the residual blocks.
            This value is read from the checkpoint at `config.vocoder.resblock_dilation_sizes`.
        upsample_initial_channel: Initial number of channels for the upsampling layers.
            This value is read from the checkpoint at `config.vocoder.upsample_initial_channel`.
        stereo: Whether to use stereo output.
            This value is read from the checkpoint at `config.vocoder.stereo`.
        resblock: Type of residual block to use ("1" selects ResBlock1, anything else ResBlock2).
            This value is read from the checkpoint at `config.vocoder.resblock`.
        output_sample_rate: Waveform sample rate.
            This value is read from the checkpoint at `config.vocoder.output_sample_rate`.
    """

    def __init__(
        self,
        resblock_kernel_sizes: List[int] | None = None,
        upsample_rates: List[int] | None = None,
        upsample_kernel_sizes: List[int] | None = None,
        resblock_dilation_sizes: List[List[int]] | None = None,
        upsample_initial_channel: int = 1024,
        stereo: bool = True,
        resblock: str = "1",
        output_sample_rate: int = 24000,
    ):
        super().__init__()

        # Initialize default values if not provided (mutable defaults are not
        # allowed as parameter defaults).
        if resblock_kernel_sizes is None:
            resblock_kernel_sizes = [3, 7, 11]
        if upsample_rates is None:
            upsample_rates = [6, 5, 2, 2, 2]
        if upsample_kernel_sizes is None:
            upsample_kernel_sizes = [16, 15, 8, 4, 4]
        if resblock_dilation_sizes is None:
            resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]

        self.output_sample_rate = output_sample_rate
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        # For stereo, the two channels' mel bins are concatenated in forward(),
        # doubling the input channel count.
        in_channels = 128 if stereo else 64
        self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
        resblock_class = ResBlock1 if resblock == "1" else ResBlock2

        # Channel count halves at every upsampling stage.
        self.ups = nn.ModuleList()
        for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True)):
            self.ups.append(
                nn.ConvTranspose1d(
                    upsample_initial_channel // (2**i),
                    upsample_initial_channel // (2 ** (i + 1)),
                    kernel_size,
                    stride,
                    padding=(kernel_size - stride) // 2,
                )
            )

        # One group of `num_kernels` parallel residual blocks per upsampling stage.
        self.resblocks = nn.ModuleList()
        for i, _ in enumerate(self.ups):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes, strict=True):
                self.resblocks.append(resblock_class(ch, kernel_size, dilations))

        out_channels = 2 if stereo else 1
        final_channels = upsample_initial_channel // (2**self.num_upsamples)
        self.conv_post = nn.Conv1d(final_channels, out_channels, 7, 1, padding=3)

        # Total time-domain upsampling factor: product of the transposed-conv strides.
        self.upsample_factor = math.prod(layer.stride[0] for layer in self.ups)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the vocoder.
        Args:
            x: Input Mel spectrogram tensor. Can be either:
                - 3D: (batch_size, time, mel_bins) for mono
                - 4D: (batch_size, 2, time, mel_bins) for stereo
        Returns:
            Audio waveform tensor of shape (batch_size, out_channels, audio_length)
        """
        # Swap the trailing (time, mel_bins) axes so convolutions run along time.
        # FIX: the previous `x.transpose(2, 3)` raised "dimension out of range"
        # for the documented 3D mono input; transpose(-2, -1) handles both ranks.
        x = x.transpose(-2, -1)

        if x.dim() == 4:  # stereo
            assert x.shape[1] == 2, "Input must have 2 channels for stereo"
            # (batch, 2, mel_bins, time) -> (batch, 2 * mel_bins, time);
            # equivalent to einops.rearrange(x, "b s c t -> b (s c) t") but
            # without the einops dependency.
            x = x.flatten(1, 2)

        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            start = i * self.num_kernels
            end = start + self.num_kernels

            # Evaluate all resblocks with the same input tensor so they can run
            # independently (and thus in parallel on accelerator hardware) before
            # aggregating their outputs via mean.
            block_outputs = torch.stack(
                [self.resblocks[idx](x) for idx in range(start, end)],
                dim=0,
            )

            x = block_outputs.mean(dim=0)

        # NOTE(review): the final activation uses leaky_relu's default slope
        # (0.01), not LRELU_SLOPE — presumably mirroring the reference
        # implementation; confirm before changing.
        x = self.conv_post(F.leaky_relu(x))
        return torch.tanh(x)
|
packages/ltx-core/src/ltx_core/model/common/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Common model utilities."""
|
| 2 |
+
|
| 3 |
+
from ltx_core.model.common.normalization import NormType, PixelNorm, build_normalization_layer
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"NormType",
|
| 7 |
+
"PixelNorm",
|
| 8 |
+
"build_normalization_layer",
|
| 9 |
+
]
|
packages/ltx-core/src/ltx_core/model/common/normalization.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class NormType(Enum):
    """Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm)."""

    # String values allow members to be constructed from config strings,
    # e.g. NormType("group").
    GROUP = "group"
    PIXEL = "pixel"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class PixelNorm(nn.Module):
    """
    Per-pixel (per-location) RMS normalization.

    Each location is divided by the root-mean-square of its values taken
    across ``dim``:  y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps).
    """

    def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
        """
        Args:
            dim: Dimension along which to compute the RMS (typically channels).
            eps: Small constant added for numerical stability.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply RMS normalization along the configured dimension."""
        # keepdim=True lets the per-location RMS broadcast back over `dim`.
        return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def build_normalization_layer(
|
| 44 |
+
in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
|
| 45 |
+
) -> nn.Module:
|
| 46 |
+
"""
|
| 47 |
+
Create a normalization layer based on the normalization type.
|
| 48 |
+
Args:
|
| 49 |
+
in_channels: Number of input channels
|
| 50 |
+
num_groups: Number of groups for group normalization
|
| 51 |
+
normtype: Type of normalization: "group" or "pixel"
|
| 52 |
+
Returns:
|
| 53 |
+
A normalization layer
|
| 54 |
+
"""
|
| 55 |
+
if normtype == NormType.GROUP:
|
| 56 |
+
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
| 57 |
+
if normtype == NormType.PIXEL:
|
| 58 |
+
return PixelNorm(dim=1, eps=1e-6)
|
| 59 |
+
raise ValueError(f"Invalid normalization type: {normtype}")
|
packages/ltx-core/src/ltx_core/model/model_protocol.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Protocol, TypeVar

# Generic placeholder for the concrete model class a configurator produces.
ModelType = TypeVar("ModelType")


class ModelConfigurator(Protocol[ModelType]):
    """Protocol for model loader classes that instantiate models from a configuration dictionary."""

    @classmethod
    def from_config(cls, config: dict) -> ModelType:
        """Build and return a model instance from the given configuration dict."""
        ...
|
packages/ltx-core/src/ltx_core/model/transformer/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Transformer model components."""
|
| 2 |
+
|
| 3 |
+
from ltx_core.model.transformer.modality import Modality
|
| 4 |
+
from ltx_core.model.transformer.model import LTXModel, X0Model
|
| 5 |
+
from ltx_core.model.transformer.model_configurator import (
|
| 6 |
+
LTXV_MODEL_COMFY_RENAMING_MAP,
|
| 7 |
+
LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP,
|
| 8 |
+
UPCAST_DURING_INFERENCE,
|
| 9 |
+
LTXModelConfigurator,
|
| 10 |
+
LTXVideoOnlyModelConfigurator,
|
| 11 |
+
UpcastWithStochasticRounding,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
"LTXV_MODEL_COMFY_RENAMING_MAP",
|
| 16 |
+
"LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP",
|
| 17 |
+
"UPCAST_DURING_INFERENCE",
|
| 18 |
+
"LTXModel",
|
| 19 |
+
"LTXModelConfigurator",
|
| 20 |
+
"LTXVideoOnlyModelConfigurator",
|
| 21 |
+
"Modality",
|
| 22 |
+
"UpcastWithStochasticRounding",
|
| 23 |
+
"X0Model",
|
| 24 |
+
]
|
packages/ltx-core/src/ltx_core/model/transformer/adaln.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ltx_core.model.transformer.timestep_embedding import PixArtAlphaCombinedTimestepSizeEmbeddings
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class AdaLayerNormSingle(torch.nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).
    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        embedding_coefficient (`int`): Number of modulation vectors produced —
            the projection outputs ``embedding_coefficient * embedding_dim`` features.
    """

    def __init__(self, embedding_dim: int, embedding_coefficient: int = 6):
        super().__init__()

        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim,
            size_emb_dim=embedding_dim // 3,
        )

        self.silu = torch.nn.SiLU()
        # Projects the timestep embedding into the modulation parameters.
        self.linear = torch.nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True)

    def forward(
        self,
        timestep: torch.Tensor,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (modulation parameters, raw embedded timestep) for the given timesteps."""
        embedded_timestep = self.emb(timestep, hidden_dtype=hidden_dtype)
        return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
packages/ltx-core/src/ltx_core/model/transformer/attention.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
from typing import Protocol
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from ltx_core.model.transformer.rope import LTXRopeType, apply_rotary_emb
|
| 7 |
+
|
| 8 |
+
# Optional acceleration backends. Each name resolves to None when the
# corresponding package is not installed; the backend classes below check the
# sentinel and raise a clear RuntimeError if selected anyway.
memory_efficient_attention = None
flash_attn_interface = None
try:
    from xformers.ops import memory_efficient_attention
except ImportError:
    memory_efficient_attention = None

# FIX: this import was previously unguarded, so merely importing this module
# crashed when flash-attn was absent — defeating the None sentinel above.
try:
    import flash_attn_interface
except ImportError:
    flash_attn_interface = None
|
| 17 |
+
class AttentionCallable(Protocol):
    """Structural interface for attention backends.

    Implementations take packed q/k/v of shape [batch, seq, heads * dim_head],
    the head count, and an optional additive mask, and return attended values
    in the same packed layout.
    """

    def __call__(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
    ) -> torch.Tensor: ...
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class PytorchAttention(AttentionCallable):
    """Attention backend using torch.nn.functional.scaled_dot_product_attention."""

    def __call__(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        batch, _, joint_dim = q.shape
        head_dim = joint_dim // heads
        # [B, N, H*D] -> [B, H, N, D] as expected by SDPA.
        q, k, v = (t.view(batch, -1, heads, head_dim).transpose(1, 2) for t in (q, k, v))

        if mask is not None:
            # Raise the mask to rank 4 ([B, H, Nq, Nk]) by inserting a batch
            # dim (when rank 2) and/or a heads dim (when rank 3).
            if mask.ndim == 2:
                mask = mask.unsqueeze(0)
            if mask.ndim == 3:
                mask = mask.unsqueeze(1)

        attended = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
        )
        return attended.transpose(1, 2).reshape(batch, -1, heads * head_dim)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class XFormersAttention(AttentionCallable):
    """Attention backend using xformers' memory_efficient_attention.

    Raises:
        RuntimeError: If xformers is not installed (module-level sentinel is None).
    """

    def __call__(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        heads: int,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if memory_efficient_attention is None:
            raise RuntimeError("XFormersAttention was selected but `xformers` is not installed.")

        b, _, dim_head = q.shape
        dim_head //= heads

        # xformers expects [B, M, H, K]
        q, k, v = (t.view(b, -1, heads, dim_head) for t in (q, k, v))

        if mask is not None:
            # add a singleton batch dimension
            if mask.ndim == 2:
                mask = mask.unsqueeze(0)
            # add a singleton heads dimension
            if mask.ndim == 3:
                mask = mask.unsqueeze(1)
            # pad to a multiple of 8
            pad = 8 - mask.shape[-1] % 8
            # the xformers docs says that it's allowed to have a mask of shape (1, Nq, Nk)
            # but when using separated heads, the shape has to be (B, H, Nq, Nk)
            # in flux, this matrix ends up being over 1GB
            # here, we create a mask with the same batch/head size as the input mask (potentially singleton or full)
            mask_out = torch.empty(
                [mask.shape[0], mask.shape[1], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device
            )

            # Broadcast-copy the mask into the leading Nk columns; the padded
            # tail of the torch.empty allocation is never read.
            mask_out[..., : mask.shape[-1]] = mask
            # NOTE(review): slicing back to the original width looks like it
            # undoes the padding, but it returns a VIEW into the padded
            # allocation — so the last-dim storage stays padded to a multiple
            # of 8, presumably to satisfy xformers' attn_bias alignment
            # requirement. Confirm against the xformers docs before changing.
            mask = mask_out[..., : mask.shape[-1]]
            mask = mask.expand(b, heads, -1, -1)

        out = memory_efficient_attention(q.to(v.dtype), k.to(v.dtype), v, attn_bias=mask, p=0.0)
        out = out.reshape(b, -1, heads * dim_head)
        return out
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class FlashAttention3(AttentionCallable):
    """Attention backend that dispatches to the FlashAttention-3 fused kernel.

    Attention masks are not supported by this backend.
    """

    def __call__(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        heads: int,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if flash_attn_interface is None:
            raise RuntimeError("FlashAttention3 was selected but `FlashAttention3` is not installed.")

        batch, _, channels = q.shape
        head_dim = channels // heads

        # Split the fused head dimension: the kernel expects [B, M, H, K].
        q = q.view(batch, -1, heads, head_dim)
        k = k.view(batch, -1, heads, head_dim)
        v = v.view(batch, -1, heads, head_dim)

        if mask is not None:
            raise NotImplementedError("Mask is not supported for FlashAttention3")

        # Align q/k dtypes with v before invoking the fused kernel.
        attn = flash_attn_interface.flash_attn_func(q.to(v.dtype), k.to(v.dtype), v)
        return attn.reshape(batch, -1, heads * head_dim)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class AttentionFunction(Enum):
    """Selector for the attention backend used by `Attention` layers.

    Each named member dispatches to its matching backend; DEFAULT picks the
    best available one at call time.
    """

    PYTORCH = "pytorch"
    XFORMERS = "xformers"
    FLASH_ATTENTION_3 = "flash_attention_3"
    DEFAULT = "default"

    def __call__(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        # Honor an explicit backend choice. (Previously the member was ignored
        # and the auto-selection logic ran unconditionally.)
        if self is AttentionFunction.PYTORCH:
            return PytorchAttention()(q, k, v, heads, mask)
        if self is AttentionFunction.XFORMERS:
            return XFormersAttention()(q, k, v, heads, mask)
        if self is AttentionFunction.FLASH_ATTENTION_3:
            return FlashAttention3()(q, k, v, heads, mask)

        # DEFAULT: FlashAttention-3 is preferred but supports no mask and may be
        # missing; fall back to xformers, then to the PyTorch implementation.
        if mask is None and flash_attn_interface is not None:
            return FlashAttention3()(q, k, v, heads, mask)
        if memory_efficient_attention is not None:
            return XFormersAttention()(q, k, v, heads, mask)
        return PytorchAttention()(q, k, v, heads, mask)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class Attention(torch.nn.Module):
    """Multi-head attention with RMS-normalized q/k and optional rotary embeddings.

    Acts as self-attention when no context is supplied and as cross-attention
    otherwise. The actual attention math is delegated to `attention_function`,
    which may be an `AttentionFunction` enum member or any custom callable.
    """

    def __init__(
        self,
        query_dim: int,
        context_dim: int | None = None,
        heads: int = 8,
        dim_head: int = 64,
        norm_eps: float = 1e-6,
        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
        attention_function: AttentionCallable | AttentionFunction = AttentionFunction.DEFAULT,
    ) -> None:
        super().__init__()
        self.rope_type = rope_type
        self.attention_function = attention_function

        hidden = heads * dim_head
        kv_dim = context_dim if context_dim is not None else query_dim

        self.heads = heads
        self.dim_head = dim_head

        # Queries and keys are RMS-normalized after projection, before RoPE.
        self.q_norm = torch.nn.RMSNorm(hidden, eps=norm_eps)
        self.k_norm = torch.nn.RMSNorm(hidden, eps=norm_eps)

        self.to_q = torch.nn.Linear(query_dim, hidden, bias=True)
        self.to_k = torch.nn.Linear(kv_dim, hidden, bias=True)
        self.to_v = torch.nn.Linear(kv_dim, hidden, bias=True)

        # Sequential with a trailing Identity — presumably kept so the output
        # linear stays at key `to_out.0` for checkpoint compatibility (confirm).
        self.to_out = torch.nn.Sequential(torch.nn.Linear(hidden, query_dim, bias=True), torch.nn.Identity())

    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor | None = None,
        mask: torch.Tensor | None = None,
        pe: torch.Tensor | None = None,
        k_pe: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Attend `x` to `context` (or to itself when `context` is None).

        `pe` supplies rotary embeddings for q (and for k unless `k_pe` is given).
        """
        q = self.to_q(x)
        kv_source = context if context is not None else x
        k = self.to_k(kv_source)
        v = self.to_v(kv_source)

        q = self.q_norm(q)
        k = self.k_norm(k)

        if pe is not None:
            q = apply_rotary_emb(q, pe, self.rope_type)
            key_positions = k_pe if k_pe is not None else pe
            k = apply_rotary_emb(k, key_positions, self.rope_type)

        # attention_function may be an enum *or* a custom callable.
        attended = self.attention_function(q, k, v, self.heads, mask)
        return self.to_out(attended)
|
packages/ltx-core/src/ltx_core/model/transformer/feed_forward.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from ltx_core.model.transformer.gelu_approx import GELUApprox
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class FeedForward(torch.nn.Module):
    """Transformer MLP: tanh-GELU expansion by `mult`, then projection to `dim_out`."""

    def __init__(self, dim: int, dim_out: int, mult: int = 4) -> None:
        super().__init__()
        hidden_dim = int(dim * mult)
        # The Identity placeholder keeps submodule indices at `net.0` / `net.2` —
        # NOTE(review): presumably to match the original checkpoint layout; confirm.
        self.net = torch.nn.Sequential(
            GELUApprox(dim, hidden_dim),
            torch.nn.Identity(),
            torch.nn.Linear(hidden_dim, dim_out),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
|
packages/ltx-core/src/ltx_core/model/transformer/gelu_approx.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class GELUApprox(torch.nn.Module):
    """Linear projection followed by the tanh-approximated GELU activation."""

    def __init__(self, dim_in: int, dim_out: int) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(dim_in, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.proj(x)
        return torch.nn.functional.gelu(projected, approximate="tanh")
|
packages/ltx-core/src/ltx_core/model/transformer/modality.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@dataclass(frozen=True)
|
| 7 |
+
class Modality:
|
| 8 |
+
"""
|
| 9 |
+
Input data for a single modality (video or audio) in the transformer.
|
| 10 |
+
Bundles the latent tokens, timestep embeddings, positional information,
|
| 11 |
+
and text conditioning context for processing by the diffusion transformer.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
latent: (
|
| 15 |
+
torch.Tensor
|
| 16 |
+
) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension
|
| 17 |
+
timesteps: torch.Tensor # Shape: (B, T) where T is the number of timesteps
|
| 18 |
+
positions: (
|
| 19 |
+
torch.Tensor
|
| 20 |
+
) # Shape: (B, 3, T) for video, where 3 is the number of dimensions and T is the number of tokens
|
| 21 |
+
context: torch.Tensor
|
| 22 |
+
enabled: bool = True
|
| 23 |
+
context_mask: torch.Tensor | None = None
|
packages/ltx-core/src/ltx_core/model/transformer/model.py
ADDED
|
@@ -0,0 +1,468 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ltx_core.guidance.perturbations import BatchedPerturbationConfig
|
| 6 |
+
from ltx_core.model.transformer.adaln import AdaLayerNormSingle
|
| 7 |
+
from ltx_core.model.transformer.attention import AttentionCallable, AttentionFunction
|
| 8 |
+
from ltx_core.model.transformer.modality import Modality
|
| 9 |
+
from ltx_core.model.transformer.rope import LTXRopeType
|
| 10 |
+
from ltx_core.model.transformer.text_projection import PixArtAlphaTextProjection
|
| 11 |
+
from ltx_core.model.transformer.transformer import BasicAVTransformerBlock, TransformerConfig
|
| 12 |
+
from ltx_core.model.transformer.transformer_args import (
|
| 13 |
+
MultiModalTransformerArgsPreprocessor,
|
| 14 |
+
TransformerArgs,
|
| 15 |
+
TransformerArgsPreprocessor,
|
| 16 |
+
)
|
| 17 |
+
from ltx_core.utils import to_denoised
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class LTXModelType(Enum):
    """Which modalities an LTX model instance produces."""

    AudioVideo = "ltx av model"
    VideoOnly = "ltx video only model"
    AudioOnly = "ltx audio only model"

    def is_video_enabled(self) -> bool:
        """True when the model emits video latents (every type except AudioOnly)."""
        return self is not LTXModelType.AudioOnly

    def is_audio_enabled(self) -> bool:
        """True when the model emits audio latents (every type except VideoOnly)."""
        return self is not LTXModelType.VideoOnly
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class LTXModel(torch.nn.Module):
    """
    LTX model transformer implementation.

    This class implements the transformer blocks for the LTX model. Depending on
    `model_type`, it instantiates video components, audio components, or both;
    when both are enabled, additional audio-video cross-attention AdaLN modules
    are created and multi-modal preprocessors are used.
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        model_type: LTXModelType = LTXModelType.AudioVideo,
        num_attention_heads: int = 32,
        attention_head_dim: int = 128,
        in_channels: int = 128,
        out_channels: int = 128,
        num_layers: int = 48,
        cross_attention_dim: int = 4096,
        norm_eps: float = 1e-06,
        attention_type: AttentionFunction | AttentionCallable = AttentionFunction.DEFAULT,
        caption_channels: int = 3840,
        positional_embedding_theta: float = 10000.0,
        positional_embedding_max_pos: list[int] | None = None,
        timestep_scale_multiplier: int = 1000,
        use_middle_indices_grid: bool = True,
        audio_num_attention_heads: int = 32,
        audio_attention_head_dim: int = 64,
        audio_in_channels: int = 128,
        audio_out_channels: int = 128,
        audio_cross_attention_dim: int = 2048,
        audio_positional_embedding_max_pos: list[int] | None = None,
        av_ca_timestep_scale_multiplier: int = 1,
        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
        double_precision_rope: bool = False,
    ):
        """Build the model; which sub-components exist depends on `model_type`.

        Note: attributes such as `self.inner_dim` / `self.audio_inner_dim` are
        only set when the corresponding modality is enabled.
        """
        super().__init__()
        self._enable_gradient_checkpointing = False
        self.use_middle_indices_grid = use_middle_indices_grid
        self.rope_type = rope_type
        self.double_precision_rope = double_precision_rope
        self.timestep_scale_multiplier = timestep_scale_multiplier
        self.positional_embedding_theta = positional_embedding_theta
        self.model_type = model_type
        # Only set when both modalities are enabled (shared temporal PE range).
        cross_pe_max_pos = None
        if model_type.is_video_enabled():
            if positional_embedding_max_pos is None:
                positional_embedding_max_pos = [20, 2048, 2048]
            self.positional_embedding_max_pos = positional_embedding_max_pos
            self.num_attention_heads = num_attention_heads
            self.inner_dim = num_attention_heads * attention_head_dim
            self._init_video(
                in_channels=in_channels,
                out_channels=out_channels,
                caption_channels=caption_channels,
                norm_eps=norm_eps,
            )

        if model_type.is_audio_enabled():
            if audio_positional_embedding_max_pos is None:
                audio_positional_embedding_max_pos = [20]
            self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
            self.audio_num_attention_heads = audio_num_attention_heads
            self.audio_inner_dim = self.audio_num_attention_heads * audio_attention_head_dim
            self._init_audio(
                in_channels=audio_in_channels,
                out_channels=audio_out_channels,
                caption_channels=caption_channels,
                norm_eps=norm_eps,
            )

        if model_type.is_video_enabled() and model_type.is_audio_enabled():
            # Cross-modal PE must cover the longer of the two temporal ranges.
            cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0])
            self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
            self.audio_cross_attention_dim = audio_cross_attention_dim
            self._init_audio_video(num_scale_shift_values=4)

        self._init_preprocessors(cross_pe_max_pos)
        # Initialize transformer blocks
        self._init_transformer_blocks(
            num_layers=num_layers,
            attention_head_dim=attention_head_dim if model_type.is_video_enabled() else 0,
            cross_attention_dim=cross_attention_dim,
            audio_attention_head_dim=audio_attention_head_dim if model_type.is_audio_enabled() else 0,
            audio_cross_attention_dim=audio_cross_attention_dim,
            norm_eps=norm_eps,
            attention_type=attention_type,
        )

    def _init_video(
        self,
        in_channels: int,
        out_channels: int,
        caption_channels: int,
        norm_eps: float,
    ) -> None:
        """Initialize video-specific components (input proj, AdaLN, caption proj, output head)."""
        # Video input components
        self.patchify_proj = torch.nn.Linear(in_channels, self.inner_dim, bias=True)

        self.adaln_single = AdaLayerNormSingle(self.inner_dim)

        # Video caption projection
        self.caption_projection = PixArtAlphaTextProjection(
            in_features=caption_channels,
            hidden_size=self.inner_dim,
        )

        # Video output components.
        # scale_shift_table is allocated with torch.empty — NOTE(review): values are
        # presumably loaded from a checkpoint before use; confirm against the loader.
        self.scale_shift_table = torch.nn.Parameter(torch.empty(2, self.inner_dim))
        self.norm_out = torch.nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=norm_eps)
        self.proj_out = torch.nn.Linear(self.inner_dim, out_channels)

    def _init_audio(
        self,
        in_channels: int,
        out_channels: int,
        caption_channels: int,
        norm_eps: float,
    ) -> None:
        """Initialize audio-specific components (mirrors `_init_video` with audio dims)."""

        # Audio input components
        self.audio_patchify_proj = torch.nn.Linear(in_channels, self.audio_inner_dim, bias=True)

        self.audio_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
        )

        # Audio caption projection
        self.audio_caption_projection = PixArtAlphaTextProjection(
            in_features=caption_channels,
            hidden_size=self.audio_inner_dim,
        )

        # Audio output components (scale/shift table loaded from checkpoint, as above).
        self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(2, self.audio_inner_dim))
        self.audio_norm_out = torch.nn.LayerNorm(self.audio_inner_dim, elementwise_affine=False, eps=norm_eps)
        self.audio_proj_out = torch.nn.Linear(self.audio_inner_dim, out_channels)

    def _init_audio_video(
        self,
        num_scale_shift_values: int,
    ) -> None:
        """Initialize audio-video cross-attention AdaLN components.

        Creates per-direction scale/shift and gate modulation modules:
        a2v = audio-to-video, v2a = video-to-audio.
        """
        self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
            self.inner_dim,
            embedding_coefficient=num_scale_shift_values,
        )

        self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
            embedding_coefficient=num_scale_shift_values,
        )

        self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
            self.inner_dim,
            embedding_coefficient=1,
        )

        self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
            self.audio_inner_dim,
            embedding_coefficient=1,
        )

    def _init_preprocessors(
        self,
        cross_pe_max_pos: int | None = None,
    ) -> None:
        """Initialize argument preprocessors for LTX.

        Multi-modal preprocessors (with cross-attention AdaLN wiring) are used
        when both modalities are enabled; otherwise a single-modality
        preprocessor is built for whichever modality exists.
        """

        if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled():
            self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
                patchify_proj=self.patchify_proj,
                adaln=self.adaln_single,
                caption_projection=self.caption_projection,
                cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single,
                cross_gate_adaln=self.av_ca_a2v_gate_adaln_single,
                inner_dim=self.inner_dim,
                max_pos=self.positional_embedding_max_pos,
                num_attention_heads=self.num_attention_heads,
                cross_pe_max_pos=cross_pe_max_pos,
                use_middle_indices_grid=self.use_middle_indices_grid,
                audio_cross_attention_dim=self.audio_cross_attention_dim,
                timestep_scale_multiplier=self.timestep_scale_multiplier,
                double_precision_rope=self.double_precision_rope,
                positional_embedding_theta=self.positional_embedding_theta,
                rope_type=self.rope_type,
                av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier,
            )
            self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor(
                patchify_proj=self.audio_patchify_proj,
                adaln=self.audio_adaln_single,
                caption_projection=self.audio_caption_projection,
                cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single,
                cross_gate_adaln=self.av_ca_v2a_gate_adaln_single,
                inner_dim=self.audio_inner_dim,
                max_pos=self.audio_positional_embedding_max_pos,
                num_attention_heads=self.audio_num_attention_heads,
                cross_pe_max_pos=cross_pe_max_pos,
                use_middle_indices_grid=self.use_middle_indices_grid,
                audio_cross_attention_dim=self.audio_cross_attention_dim,
                timestep_scale_multiplier=self.timestep_scale_multiplier,
                double_precision_rope=self.double_precision_rope,
                positional_embedding_theta=self.positional_embedding_theta,
                rope_type=self.rope_type,
                av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier,
            )
        elif self.model_type.is_video_enabled():
            self.video_args_preprocessor = TransformerArgsPreprocessor(
                patchify_proj=self.patchify_proj,
                adaln=self.adaln_single,
                caption_projection=self.caption_projection,
                inner_dim=self.inner_dim,
                max_pos=self.positional_embedding_max_pos,
                num_attention_heads=self.num_attention_heads,
                use_middle_indices_grid=self.use_middle_indices_grid,
                timestep_scale_multiplier=self.timestep_scale_multiplier,
                double_precision_rope=self.double_precision_rope,
                positional_embedding_theta=self.positional_embedding_theta,
                rope_type=self.rope_type,
            )
        elif self.model_type.is_audio_enabled():
            self.audio_args_preprocessor = TransformerArgsPreprocessor(
                patchify_proj=self.audio_patchify_proj,
                adaln=self.audio_adaln_single,
                caption_projection=self.audio_caption_projection,
                inner_dim=self.audio_inner_dim,
                max_pos=self.audio_positional_embedding_max_pos,
                num_attention_heads=self.audio_num_attention_heads,
                use_middle_indices_grid=self.use_middle_indices_grid,
                timestep_scale_multiplier=self.timestep_scale_multiplier,
                double_precision_rope=self.double_precision_rope,
                positional_embedding_theta=self.positional_embedding_theta,
                rope_type=self.rope_type,
            )

    def _init_transformer_blocks(
        self,
        num_layers: int,
        attention_head_dim: int,
        cross_attention_dim: int,
        audio_attention_head_dim: int,
        audio_cross_attention_dim: int,
        norm_eps: float,
        attention_type: AttentionFunction | AttentionCallable,
    ) -> None:
        """Initialize transformer blocks for LTX.

        Builds a per-modality TransformerConfig (None for a disabled modality)
        and stacks `num_layers` BasicAVTransformerBlock modules.
        """
        video_config = (
            TransformerConfig(
                dim=self.inner_dim,
                heads=self.num_attention_heads,
                d_head=attention_head_dim,
                context_dim=cross_attention_dim,
            )
            if self.model_type.is_video_enabled()
            else None
        )
        audio_config = (
            TransformerConfig(
                dim=self.audio_inner_dim,
                heads=self.audio_num_attention_heads,
                d_head=audio_attention_head_dim,
                context_dim=audio_cross_attention_dim,
            )
            if self.model_type.is_audio_enabled()
            else None
        )
        self.transformer_blocks = torch.nn.ModuleList(
            [
                BasicAVTransformerBlock(
                    idx=idx,
                    video=video_config,
                    audio=audio_config,
                    rope_type=self.rope_type,
                    norm_eps=norm_eps,
                    attention_function=attention_type,
                )
                for idx in range(num_layers)
            ]
        )

    def set_gradient_checkpointing(self, enable: bool) -> None:
        """Enable or disable gradient checkpointing for transformer blocks.

        Gradient checkpointing trades compute for memory by recomputing activations
        during the backward pass instead of storing them. This can significantly
        reduce memory usage at the cost of ~20-30% slower training.

        Args:
            enable: Whether to enable gradient checkpointing
        """
        self._enable_gradient_checkpointing = enable

    def _process_transformer_blocks(
        self,
        video: TransformerArgs | None,
        audio: TransformerArgs | None,
        perturbations: BatchedPerturbationConfig,
    ) -> tuple[TransformerArgs | None, TransformerArgs | None]:
        """Run all transformer blocks over the (optional) video/audio args.

        A modality passed in as None stays None through the whole stack.
        """

        # Process transformer blocks
        for block in self.transformer_blocks:
            if self._enable_gradient_checkpointing and self.training:
                # Use gradient checkpointing to save memory during training.
                # With use_reentrant=False, we can pass dataclasses directly -
                # PyTorch will track all tensor leaves in the computation graph.
                video, audio = torch.utils.checkpoint.checkpoint(
                    block,
                    video,
                    audio,
                    perturbations,
                    use_reentrant=False,
                )
            else:
                video, audio = block(
                    video=video,
                    audio=audio,
                    perturbations=perturbations,
                )

        return video, audio

    def _process_output(
        self,
        scale_shift_table: torch.Tensor,
        norm_out: torch.nn.LayerNorm,
        proj_out: torch.nn.Linear,
        x: torch.Tensor,
        embedded_timestep: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the output head: timestep-conditioned scale/shift, norm, projection."""
        # Apply scale-shift modulation
        scale_shift_values = (
            scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
        )
        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]

        x = norm_out(x)
        x = x * (1 + scale) + shift
        x = proj_out(x)
        return x

    def forward(
        self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """
        Forward pass for LTX models.

        Either modality may be None; the corresponding output is then None.
        Passing a modality that the model type does not support raises ValueError.

        Returns:
            Processed output tensors (video, audio); None for absent modalities.
        """
        if not self.model_type.is_video_enabled() and video is not None:
            raise ValueError("Video is not enabled for this model")
        if not self.model_type.is_audio_enabled() and audio is not None:
            raise ValueError("Audio is not enabled for this model")

        video_args = self.video_args_preprocessor.prepare(video) if video is not None else None
        audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None
        # Process transformer blocks
        video_out, audio_out = self._process_transformer_blocks(
            video=video_args,
            audio=audio_args,
            perturbations=perturbations,
        )

        # Process output
        vx = (
            self._process_output(
                self.scale_shift_table, self.norm_out, self.proj_out, video_out.x, video_out.embedded_timestep
            )
            if video_out is not None
            else None
        )
        ax = (
            self._process_output(
                self.audio_scale_shift_table,
                self.audio_norm_out,
                self.audio_proj_out,
                audio_out.x,
                audio_out.embedded_timestep,
            )
            if audio_out is not None
            else None
        )
        return vx, ax
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
class LegacyX0Model(torch.nn.Module):
    """
    Legacy X0 wrapper around a velocity-predicting LTX model.

    Converts the base model's predicted velocities into fully denoised
    outputs using a single scalar sigma for all tokens.
    """

    def __init__(self, velocity_model: LTXModel):
        super().__init__()
        self.velocity_model = velocity_model

    def forward(
        self,
        video: Modality | None,
        audio: Modality | None,
        perturbations: BatchedPerturbationConfig,
        sigma: float,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """
        Denoise the video and audio according to the sigma.

        Returns:
            Denoised video and audio (None where the modality is absent)
        """
        video_velocity, audio_velocity = self.velocity_model(video, audio, perturbations)
        video_x0 = None if video_velocity is None else to_denoised(video.latent, video_velocity, sigma)
        audio_x0 = None if audio_velocity is None else to_denoised(audio.latent, audio_velocity, sigma)
        return video_x0, audio_x0
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
class X0Model(torch.nn.Module):
    """
    X0 wrapper around a velocity-predicting LTX model.

    Converts predicted velocities into fully denoised outputs, using each
    modality's own per-token timesteps (= sigma * denoising_mask) rather
    than a single scalar sigma.
    """

    def __init__(self, velocity_model: LTXModel):
        super().__init__()
        self.velocity_model = velocity_model

    def forward(
        self,
        video: Modality | None,
        audio: Modality | None,
        perturbations: BatchedPerturbationConfig,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """
        Denoise the video and audio according to their per-token timesteps.

        Returns:
            Denoised video and audio (None where the modality is absent)
        """
        video_velocity, audio_velocity = self.velocity_model(video, audio, perturbations)
        video_x0 = None if video_velocity is None else to_denoised(video.latent, video_velocity, video.timesteps)
        audio_x0 = None if audio_velocity is None else to_denoised(audio.latent, audio_velocity, audio.timesteps)
        return video_x0, audio_x0
|
packages/ltx-core/src/ltx_core/model/transformer/model_configurator.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from ltx_core.loader.fuse_loras import fused_add_round_launch
|
| 4 |
+
from ltx_core.loader.module_ops import ModuleOps
|
| 5 |
+
from ltx_core.loader.sd_ops import KeyValueOperationResult, SDOps
|
| 6 |
+
from ltx_core.model.model_protocol import ModelConfigurator
|
| 7 |
+
from ltx_core.model.transformer.attention import AttentionFunction
|
| 8 |
+
from ltx_core.model.transformer.model import LTXModel, LTXModelType
|
| 9 |
+
from ltx_core.model.transformer.rope import LTXRopeType
|
| 10 |
+
from ltx_core.utils import check_config_value
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class LTXModelConfigurator(ModelConfigurator[LTXModel]):
    """
    Configurator for the audio-video LTX model.

    Builds an ``LTXModel`` from a configuration dictionary, after validating
    that options this builder does not support still hold their expected
    fixed values (via ``check_config_value``).
    """

    # Config keys this builder does not parameterize: each must hold exactly
    # this value. Insertion order preserves the original validation order.
    _EXPECTED_FIXED_CONFIG: dict = {
        "dropout": 0.0,
        "attention_bias": True,
        "num_vector_embeds": None,
        "activation_fn": "gelu-approximate",
        "num_embeds_ada_norm": 1000,
        "use_linear_projection": False,
        "only_cross_attention": False,
        "cross_attention_norm": True,
        "double_self_attention": False,
        "upcast_attention": False,
        "standardization_norm": "rms_norm",
        "norm_elementwise_affine": False,
        "qk_norm": "rms_norm",
        "positional_embedding_type": "rope",
        "use_audio_video_cross_attention": True,
        "share_ff": False,
        "av_cross_ada_norm": True,
        "use_middle_indices_grid": True,
    }

    @classmethod
    def from_config(cls, config: dict) -> LTXModel:
        """
        Create an audio-video ``LTXModel`` from a configuration dictionary.

        The relevant options are read from the ``"transformer"`` sub-dict;
        missing keys fall back to the LTX defaults below.
        """
        # Note: `cls` was previously mis-annotated as `type[LTXModel]`; it is
        # the configurator class, so we leave it unannotated.
        config = config.get("transformer", {})

        for key, expected in cls._EXPECTED_FIXED_CONFIG.items():
            check_config_value(config, key, expected)

        return LTXModel(
            model_type=LTXModelType.AudioVideo,
            num_attention_heads=config.get("num_attention_heads", 32),
            attention_head_dim=config.get("attention_head_dim", 128),
            in_channels=config.get("in_channels", 128),
            out_channels=config.get("out_channels", 128),
            num_layers=config.get("num_layers", 48),
            cross_attention_dim=config.get("cross_attention_dim", 4096),
            norm_eps=config.get("norm_eps", 1e-06),
            attention_type=AttentionFunction(config.get("attention_type", "default")),
            caption_channels=config.get("caption_channels", 3840),
            positional_embedding_theta=config.get("positional_embedding_theta", 10000.0),
            positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]),
            timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000),
            use_middle_indices_grid=config.get("use_middle_indices_grid", True),
            audio_num_attention_heads=config.get("audio_num_attention_heads", 32),
            audio_attention_head_dim=config.get("audio_attention_head_dim", 64),
            audio_in_channels=config.get("audio_in_channels", 128),
            audio_out_channels=config.get("audio_out_channels", 128),
            audio_cross_attention_dim=config.get("audio_cross_attention_dim", 2048),
            audio_positional_embedding_max_pos=config.get("audio_positional_embedding_max_pos", [20]),
            av_ca_timestep_scale_multiplier=config.get("av_ca_timestep_scale_multiplier", 1),
            rope_type=LTXRopeType(config.get("rope_type", "interleaved")),
            # Only the literal string "float64" enables double-precision RoPE;
            # an absent key (None) compares unequal just like the old `False`.
            double_precision_rope=config.get("frequencies_precision") == "float64",
        )
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class LTXVideoOnlyModelConfigurator(ModelConfigurator[LTXModel]):
    """
    Configurator for the video-only LTX model.

    Builds a video-only ``LTXModel`` from a configuration dictionary, after
    validating that options this builder does not support still hold their
    expected fixed values (via ``check_config_value``).
    """

    # Config keys this builder does not parameterize: each must hold exactly
    # this value. Insertion order preserves the original validation order.
    _EXPECTED_FIXED_CONFIG: dict = {
        "dropout": 0.0,
        "attention_bias": True,
        "num_vector_embeds": None,
        "activation_fn": "gelu-approximate",
        "num_embeds_ada_norm": 1000,
        "use_linear_projection": False,
        "only_cross_attention": False,
        "cross_attention_norm": True,
        "double_self_attention": False,
        "upcast_attention": False,
        "standardization_norm": "rms_norm",
        "norm_elementwise_affine": False,
        "qk_norm": "rms_norm",
        "positional_embedding_type": "rope",
        "use_middle_indices_grid": True,
    }

    @classmethod
    def from_config(cls, config: dict) -> LTXModel:
        """
        Create a video-only ``LTXModel`` from a configuration dictionary.

        The relevant options are read from the ``"transformer"`` sub-dict;
        missing keys fall back to the LTX defaults below.
        """
        # Note: `cls` was previously mis-annotated as `type[LTXModel]`; it is
        # the configurator class, so we leave it unannotated.
        config = config.get("transformer", {})

        for key, expected in cls._EXPECTED_FIXED_CONFIG.items():
            check_config_value(config, key, expected)

        return LTXModel(
            model_type=LTXModelType.VideoOnly,
            num_attention_heads=config.get("num_attention_heads", 32),
            attention_head_dim=config.get("attention_head_dim", 128),
            in_channels=config.get("in_channels", 128),
            out_channels=config.get("out_channels", 128),
            num_layers=config.get("num_layers", 48),
            cross_attention_dim=config.get("cross_attention_dim", 4096),
            norm_eps=config.get("norm_eps", 1e-06),
            attention_type=AttentionFunction(config.get("attention_type", "default")),
            caption_channels=config.get("caption_channels", 3840),
            positional_embedding_theta=config.get("positional_embedding_theta", 10000.0),
            positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]),
            timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000),
            use_middle_indices_grid=config.get("use_middle_indices_grid", True),
            rope_type=LTXRopeType(config.get("rope_type", "interleaved")),
            # Only the literal string "float64" enables double-precision RoPE;
            # an absent key (None) compares unequal just like the old `False`.
            double_precision_rope=config.get("frequencies_precision") == "float64",
        )
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _naive_weight_or_bias_downcast(key: str, value: torch.Tensor) -> list[KeyValueOperationResult]:
    """
    Cast *value* down to ``torch.float8_e4m3fn`` and return it as a
    single-entry list of key/value operation results.
    """
    downcast = value.to(dtype=torch.float8_e4m3fn)
    return [KeyValueOperationResult(key, downcast)]
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def _upcast_and_round(
|
| 123 |
+
weight: torch.Tensor, dtype: torch.dtype, with_stochastic_rounding: bool = False, seed: int = 0
|
| 124 |
+
) -> torch.Tensor:
|
| 125 |
+
"""
|
| 126 |
+
Upcast the weight to the given dtype and optionally apply stochastic rounding.
|
| 127 |
+
Input weight needs to have float8_e4m3fn or float8_e5m2 dtype.
|
| 128 |
+
"""
|
| 129 |
+
if not with_stochastic_rounding:
|
| 130 |
+
return weight.to(dtype)
|
| 131 |
+
return fused_add_round_launch(torch.zeros_like(weight, dtype=dtype), weight, seed)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def replace_fwd_with_upcast(layer: torch.nn.Linear, with_stochastic_rounding: bool = False, seed: int = 0) -> None:
    """
    Monkey-patch ``layer.forward`` so that the weight (and bias, if present)
    are first upcast to the input's dtype — optionally with stochastic
    rounding — and ``F.linear`` runs entirely in that dtype.

    The original forward is preserved on ``layer.original_forward``.
    """
    layer.original_forward = layer.forward

    def forward_with_upcast(*args, **_kwargs) -> torch.Tensor:
        # By convention the first positional argument is the input tensor.
        x = args[0]
        weight = _upcast_and_round(layer.weight, x.dtype, with_stochastic_rounding, seed)
        bias = (
            _upcast_and_round(layer.bias, x.dtype, with_stochastic_rounding, seed)
            if layer.bias is not None
            else None
        )
        return torch.nn.functional.linear(x, weight, bias)

    layer.forward = forward_with_upcast
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def amend_forward_with_upcast(
    model: torch.nn.Module, with_stochastic_rounding: bool = False, seed: int = 0
) -> torch.nn.Module:
    """
    Patch every ``torch.nn.Linear`` submodule of *model* so its forward
    upcasts the (fp8) parameters to the activation dtype, with optional
    stochastic rounding. Returns the (mutated) model for chaining.
    """
    linear_layers = (m for m in model.modules() if isinstance(m, torch.nn.Linear))
    for layer in linear_layers:
        replace_fwd_with_upcast(layer, with_stochastic_rounding, seed)
    return model
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# ComfyUI checkpoints prefix every transformer key with this; stripping it
# makes the keys match this package's module paths.
_COMFY_DIFFUSION_PREFIX = "model.diffusion_model."

LTXV_MODEL_COMFY_RENAMING_MAP = (
    SDOps("LTXV_MODEL_COMFY_PREFIX_MAP")
    .with_matching(prefix=_COMFY_DIFFUSION_PREFIX)
    .with_replacement(_COMFY_DIFFUSION_PREFIX, "")
)
|
| 175 |
+
|
| 176 |
+
# Parameter suffixes (per transformer block) whose weights/biases are stored
# downcast to float8_e4m3fn: the attention projections and the feed-forward
# linears. They are upcast again at inference (see amend_forward_with_upcast).
_FP8_DOWNCAST_PARAM_SUFFIXES = (
    ".to_q.weight",
    ".to_q.bias",
    ".to_k.weight",
    ".to_k.bias",
    ".to_v.weight",
    ".to_v.bias",
    ".to_out.0.weight",
    ".to_out.0.bias",
    ".ff.net.0.proj.weight",
    ".ff.net.0.proj.bias",
    ".ff.net.2.weight",
    ".ff.net.2.bias",
)


def _build_comfy_renaming_with_downcast_map() -> SDOps:
    """
    Build the ComfyUI renaming map that additionally downcasts every
    transformer block's attention/feed-forward linear parameters to
    float8_e4m3fn. Replaces twelve near-identical builder calls with a loop.
    """
    sd_ops = (
        SDOps("LTXV_MODEL_COMFY_PREFIX_MAP")
        .with_matching(prefix="model.diffusion_model.")
        .with_replacement("model.diffusion_model.", "")
    )
    for suffix in _FP8_DOWNCAST_PARAM_SUFFIXES:
        sd_ops = sd_ops.with_kv_operation(
            key_prefix="transformer_blocks.", key_suffix=suffix, operation=_naive_weight_or_bias_downcast
        )
    return sd_ops


LTXV_MODEL_COMFY_RENAMING_WITH_TRANSFORMER_LINEAR_DOWNCAST_MAP = _build_comfy_renaming_with_downcast_map()
|
| 217 |
+
|
| 218 |
+
# Inference-time ModuleOps: patches the linear forwards of LTX models so fp8
# weights are upcast on the fly, without stochastic rounding.
UPCAST_DURING_INFERENCE = ModuleOps(
    name="upcast_fp8_during_linear_forward",
    matcher=lambda model: isinstance(model, LTXModel),
    mutator=lambda model: amend_forward_with_upcast(model, with_stochastic_rounding=False),
)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class UpcastWithStochasticRounding(ModuleOps):
    """
    ModuleOps that patches an LTX model's linear forwards to upcast its
    float8_e4m3fn weights and biases to the activation dtype, applying
    stochastic rounding during the upcast.
    """

    def __new__(cls, seed: int = 0):
        # NOTE(review): construction goes through __new__ rather than
        # __init__, which suggests ModuleOps is an immutable type (e.g. a
        # NamedTuple) whose fields must be supplied at allocation time —
        # confirm against ltx_core.loader.module_ops.
        return super().__new__(
            cls,
            name="upcast_fp8_during_linear_forward_with_stochastic_rounding",
            # Only LTXModel instances are matched; other modules are untouched.
            matcher=lambda model: isinstance(model, LTXModel),
            # `True` enables stochastic rounding in the patched forwards;
            # `seed` is forwarded to the rounding kernel.
            mutator=lambda model: amend_forward_with_upcast(model, True, seed),
        )
|