Fabrice-TIERCELIN committed on
Commit 8788929 · verified · 1 Parent(s): f267cb0

Upload 4 files
packages/ltx-trainer/AGENTS.md ADDED
@@ -0,0 +1,352 @@
# AGENTS.md

This file provides guidance to AI coding assistants (Claude, Cursor, etc.) when working with code in this repository.

## Project Overview

**LTX-2 Trainer** is a training toolkit for fine-tuning the Lightricks LTX-2 audio-video generation model. It supports:

- **LoRA training** - Efficient fine-tuning with adapters
- **Full fine-tuning** - Complete model training
- **Audio-video training** - Joint audio and video generation
- **IC-LoRA training** - In-context control adapters for video-to-video transformations

**Key Dependencies:**

- **[`ltx-core`](../ltx-core/)** - Core model implementations (transformer, VAE, text encoder)
- **[`ltx-pipelines`](../ltx-pipelines/)** - Inference pipeline components

> **Important:** This trainer only supports **LTX-2** (the audio-video model). The older LTXV models are not supported.

## Architecture Overview

### Package Structure

```
packages/ltx-trainer/
├── src/ltx_trainer/              # Main training module
│   ├── config.py                 # Pydantic configuration models
│   ├── trainer.py                # Main training orchestration with Accelerate
│   ├── model_loader.py           # Model loading using ltx-core
│   ├── validation_sampler.py     # Inference for validation samples
│   ├── datasets.py               # PrecomputedDataset for latent-based training
│   ├── training_strategies/      # Strategy pattern for different training modes
│   │   ├── __init__.py           # Factory function: get_training_strategy()
│   │   ├── base_strategy.py      # TrainingStrategy ABC, ModelInputs, TrainingStrategyConfigBase
│   │   ├── text_to_video.py      # TextToVideoStrategy, TextToVideoConfig
│   │   └── video_to_video.py     # VideoToVideoStrategy, VideoToVideoConfig
│   ├── timestep_samplers.py      # Flow matching timestep sampling
│   ├── captioning.py             # Video captioning utilities
│   ├── video_utils.py            # Video processing utilities
│   └── hf_hub_utils.py           # HuggingFace Hub integration
├── scripts/                      # User-facing CLI tools
│   ├── train.py                  # Main training script
│   ├── process_dataset.py        # Dataset preprocessing
│   ├── process_videos.py         # Video latent encoding
│   ├── process_captions.py       # Text embedding computation
│   ├── caption_videos.py         # Automatic video captioning
│   ├── decode_latents.py         # Latent decoding for debugging
│   ├── inference.py              # Inference with trained models
│   ├── compute_reference.py      # Generate IC-LoRA reference videos
│   └── split_scenes.py           # Scene detection and splitting
├── configs/                      # Example training configurations
│   ├── ltx2_av_lora.yaml         # Audio-video LoRA training
│   ├── ltx2_v2v_ic_lora.yaml     # IC-LoRA video-to-video
│   └── accelerate/               # Accelerate configs for distributed training
└── docs/                         # Documentation
```
### Key Architectural Patterns

**Model Loading:**

- `ltx_trainer.model_loader` provides component loaders built on `ltx-core`
- Individual loaders: `load_transformer()`, `load_video_vae_encoder()`, `load_video_vae_decoder()`, `load_text_encoder()`, etc.
- Combined loader: `load_model()` returns an `LtxModelComponents` dataclass
- Uses `SingleGPUModelBuilder` from ltx-core internally
**Training Flow:**

1. Configuration is loaded via Pydantic models in `config.py`
2. The `Trainer` class orchestrates the training loop
3. Training strategies (`TextToVideoStrategy`, `VideoToVideoStrategy`) prepare inputs and compute the loss
4. Accelerate handles distributed training and device placement
5. Data flows as precomputed latents through `PrecomputedDataset`

**Model Interface (Modality-based):**

```python
from ltx_core.model.transformer.modality import Modality

# Create modality objects for video and audio
video = Modality(
    enabled=True,
    latent=video_latents,        # [B, seq_len, 128]
    timesteps=video_timesteps,   # [B, seq_len] per-token
    positions=video_positions,   # [B, 3, seq_len, 2]
    context=video_embeds,
    context_mask=None,
)
audio = Modality(
    enabled=True,
    latent=audio_latents,
    timesteps=audio_timesteps,
    positions=audio_positions,   # [B, 1, seq_len, 2]
    context=audio_embeds,
    context_mask=None,
)

# Forward pass returns predictions for both modalities
video_pred, audio_pred = model(video=video, audio=audio, perturbations=None)
```
> **Note:** `Modality` is immutable (a frozen dataclass). Use `dataclasses.replace()` to modify it.

**Configuration System:**

- All configuration lives in `src/ltx_trainer/config.py`
- Main class: `LtxTrainerConfig`
- Training strategy configs: `TextToVideoConfig`, `VideoToVideoConfig`
- Uses Pydantic field validators and model validators
- Config files live in the `configs/` directory
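The strictness of this Pydantic-based setup is worth seeing once. The sketch below uses an invented `MiniTrainerConfig` stand-in (its fields are illustrative assumptions, not the real `LtxTrainerConfig` schema) to show how `extra="forbid"` rejects unknown keys while validators and defaults still apply:

```python
# Illustrative sketch only: MiniTrainerConfig is a stand-in, NOT the real
# LtxTrainerConfig schema from src/ltx_trainer/config.py.
from pydantic import BaseModel, ConfigDict, ValidationError


class MiniTrainerConfig(BaseModel):
    model_config = ConfigDict(extra="forbid")  # unknown keys raise ValidationError

    model_path: str
    learning_rate: float = 1e-4  # defaults fill in missing keys


# Missing optional keys fall back to their defaults
cfg = MiniTrainerConfig.model_validate({"model_path": "/models/ltx2.safetensors"})
print(cfg.learning_rate)

# A typo'd key is rejected outright instead of being silently ignored
try:
    MiniTrainerConfig.model_validate({"model_path": "x", "learning_ratee": 1e-3})
except ValidationError:
    print("unknown field rejected")
```

This "fail loudly on unknown fields" behavior is what produces the validation errors mentioned in the Debugging Tips section when a YAML config contains a misspelled key.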
## Development Commands

### Setup and Installation

```bash
# From the repository root
uv sync
cd packages/ltx-trainer
```

### Code Quality

```bash
# Run ruff linting and formatting
uv run ruff check .
uv run ruff format .

# Run pre-commit checks
uv run pre-commit run --all-files
```

### Running Tests

```bash
cd packages/ltx-trainer
uv run pytest
```

### Running Training

```bash
# Single GPU
uv run python scripts/train.py configs/ltx2_av_lora.yaml

# Multi-GPU with Accelerate
uv run accelerate launch scripts/train.py configs/ltx2_av_lora.yaml
```
## Code Standards

### Type Hints

- **Always use type hints** for all function arguments and return values
- Use Python 3.10+ syntax: `list[str]` not `List[str]`, `str | Path` not `Union[str, Path]`
- Use `pathlib.Path` for file operations

### Class Methods

- Mark methods as `@staticmethod` if they don't access instance or class state
- Use `@classmethod` for alternative constructors

### AI/ML Specific

- Use `@torch.inference_mode()` for inference (preferred over `@torch.no_grad()`)
- Use `accelerator.device` for distributed compatibility
- Support mixed precision (`bfloat16` via dtype parameters)
- Use gradient checkpointing for memory-intensive training

### Logging

- Use `from ltx_trainer import logger` for all messages
- Avoid print statements in production code
## Important Files & Modules

### Configuration (CRITICAL)

**`src/ltx_trainer/config.py`** - Master config definitions

Key classes:

- `LtxTrainerConfig` - Main configuration container
- `ModelConfig` - Model paths and training mode
- `TrainingStrategyConfig` - Union of `TextToVideoConfig` | `VideoToVideoConfig`
- `LoraConfig` - LoRA hyperparameters
- `OptimizationConfig` - Learning rate, batch size, etc.
- `ValidationConfig` - Validation settings
- `WandbConfig` - W&B logging settings

**⚠️ When modifying config.py:**

1. Update ALL config files in `configs/`
2. Update `docs/configuration-reference.md`
3. Test that all configs remain valid
### Training Core

**`src/ltx_trainer/trainer.py`** - Main training loop

- Implements distributed training with Accelerate
- Handles mixed precision, gradient accumulation, and checkpointing
- Delegates mode-specific logic to training strategies

**`src/ltx_trainer/training_strategies/`** - Strategy pattern

- `base_strategy.py`: `TrainingStrategy` ABC, `ModelInputs` dataclass
- `text_to_video.py`: Standard text-to-video (with optional audio)
- `video_to_video.py`: IC-LoRA video-to-video transformations

Key methods each strategy implements:

- `get_data_sources()` - Required data directories
- `prepare_training_inputs()` - Convert a batch to `ModelInputs`
- `compute_loss()` - Calculate the training loss
- `requires_audio` property - Whether audio components are needed
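The contract above can be sketched as a toy strategy. This is a schematic mirror of the interface, not the real base class: the actual `TrainingStrategy` and `ModelInputs` live in `base_strategy.py` and operate on torch tensors, whereas the placeholder types and the mean-squared-error loss here are invented for illustration:

```python
# Schematic mirror of the TrainingStrategy contract. The real ABC is in
# src/ltx_trainer/training_strategies/base_strategy.py; names prefixed
# "Toy" are placeholders invented for this sketch.
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class ToyModelInputs:  # stand-in for the real ModelInputs dataclass
    latents: list[float]
    targets: list[float]


class ToyStrategy(ABC):
    @abstractmethod
    def get_data_sources(self) -> list[str]:
        """Names of the precomputed data directories this strategy needs."""

    @abstractmethod
    def prepare_training_inputs(self, batch: dict) -> ToyModelInputs:
        """Convert a raw batch into model inputs."""

    @abstractmethod
    def compute_loss(self, inputs: ToyModelInputs, preds: list[float]) -> float:
        """Compute the training loss from model predictions."""

    @property
    def requires_audio(self) -> bool:
        return False  # strategies override when audio components are needed


class ToyTextToVideo(ToyStrategy):
    def get_data_sources(self) -> list[str]:
        return ["latents", "conditions"]

    def prepare_training_inputs(self, batch: dict) -> ToyModelInputs:
        return ToyModelInputs(latents=batch["latents"], targets=batch["targets"])

    def compute_loss(self, inputs: ToyModelInputs, preds: list[float]) -> float:
        # mean squared error as a stand-in for the real flow-matching loss
        return sum((p - t) ** 2 for p, t in zip(preds, inputs.targets)) / len(preds)
```

The trainer only ever talks to a strategy through these methods, which is what keeps mode-specific logic out of `trainer.py`.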
**`src/ltx_trainer/model_loader.py`** - Model loading

Component loaders:

- `load_transformer()` → `LTXModel`
- `load_video_vae_encoder()` → `VideoVAEEncoder`
- `load_video_vae_decoder()` → `VideoVAEDecoder`
- `load_audio_vae_decoder()` → `AudioVAEDecoder`
- `load_vocoder()` → `Vocoder`
- `load_text_encoder()` → `AVGemmaTextEncoderModel`
- `load_model()` → `LtxModelComponents` (convenience wrapper)

**`src/ltx_trainer/validation_sampler.py`** - Inference for validation

Uses ltx-core components for denoising:

- `LTX2Scheduler` for sigma scheduling
- `EulerDiffusionStep` for diffusion steps
- `CFGGuider` for classifier-free guidance

### Data

**`src/ltx_trainer/datasets.py`** - Dataset handling

- `PrecomputedDataset` loads pre-computed VAE latents
- Supports video latents, audio latents, text embeddings, and reference latents
## Common Development Tasks

### Adding a New Configuration Parameter

1. Add the field to the appropriate config class in `src/ltx_trainer/config.py`
2. Add a validator if needed
3. Update ALL config files in `configs/`
4. Update `docs/configuration-reference.md`
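Steps 1-2 might look like the following. Both the `noise_scale` field and the `ExampleOptimizationConfig` class are hypothetical, invented purely to illustrate the field-plus-validator pattern; they are not part of the real config schema:

```python
# Hypothetical illustration of adding a validated field. "noise_scale" and
# ExampleOptimizationConfig are invented for this sketch and do NOT exist
# in src/ltx_trainer/config.py.
from pydantic import BaseModel, field_validator


class ExampleOptimizationConfig(BaseModel):
    learning_rate: float = 1e-4
    noise_scale: float = 1.0  # step 1: the new (hypothetical) field

    # step 2: a validator constraining the new field
    @field_validator("noise_scale")
    @classmethod
    def _check_noise_scale(cls, v: float) -> float:
        if not 0.0 < v <= 2.0:
            raise ValueError("noise_scale must be in (0, 2]")
        return v
```

Steps 3-4 are then a matter of adding the new key to every YAML in `configs/` and documenting it in `docs/configuration-reference.md`.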
### Implementing a New Training Strategy

1. Create a new file in `src/ltx_trainer/training_strategies/`
2. Create a config class inheriting `TrainingStrategyConfigBase`
3. Create a strategy class inheriting `TrainingStrategy`
4. Implement `get_data_sources()`, `prepare_training_inputs()`, and `compute_loss()`
5. Add it to `__init__.py`: import it, add it to the `TrainingStrategyConfig` union, and update the factory
6. Add a discriminator tag to `TrainingStrategyConfig` in `config.py`
7. Create an example config file in `configs/`
### Working with Modalities

```python
from dataclasses import replace

from ltx_core.model.transformer.modality import Modality

# Create a modality
video = Modality(
    enabled=True,
    latent=latents,
    timesteps=timesteps,
    positions=positions,
    context=context,
    context_mask=None,
)

# Update (immutable - must use replace)
video = replace(video, latent=new_latent, timesteps=new_timesteps)

# Disable a modality
audio = replace(audio, enabled=False)
```
## Debugging Tips

**Training Issues:**

- Check the logs first (the rich logger provides context)
- GPU memory: look for OOM errors; enable `enable_gradient_checkpointing: true`
- Distributed training: check `accelerator.state` and device placement

**Model Loading:**

- Ensure `model_path` points to a local `.safetensors` file
- Ensure `text_encoder_path` points to a Gemma model directory
- URLs are NOT supported for model paths

**Configuration:**

- Validation errors: check the validators in `config.py`
- Unknown fields: the config uses `extra="forbid"`, so every field must be defined
- Strategy validation: IC-LoRA requires `reference_videos` in the validation config
## Key Constraints

### LTX-2 Frame Requirements

Frame counts must satisfy `frames % 8 == 1`:

- ✅ Valid: 1, 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, 97, 121
- ❌ Invalid: 24, 32, 48, 64, 100

### Resolution Requirements

Width and height must each be divisible by 32.
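Both constraints are cheap to check before launching a long preprocessing or training run. The helper names below are illustrative, not part of the trainer's API:

```python
# Illustrative helpers for the two constraints above; these functions are
# not part of the trainer's public API.
def is_valid_frame_count(frames: int) -> bool:
    """LTX-2 requires frames % 8 == 1 (e.g. 1, 9, 17, 25, ...)."""
    return frames > 0 and frames % 8 == 1


def is_valid_resolution(width: int, height: int) -> bool:
    """Width and height must each be divisible by 32."""
    return width % 32 == 0 and height % 32 == 0


print(is_valid_frame_count(121))      # True  (121 = 15 * 8 + 1)
print(is_valid_frame_count(100))      # False (100 % 8 == 4)
print(is_valid_resolution(768, 512))  # True  (both divisible by 32)
```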
### Model Paths

- Must be local paths (URLs are not supported)
- `model_path`: path to a `.safetensors` checkpoint
- `text_encoder_path`: path to a Gemma model directory

### Platform Requirements

- Linux is required (the toolkit depends on `triton`, which is Linux-only)
- A CUDA GPU with 24GB+ VRAM is recommended

## Reference: ltx-core Key Components

```
packages/ltx-core/src/ltx_core/
├── model/
│   ├── transformer/
│   │   ├── model.py                # LTXModel
│   │   ├── modality.py             # Modality dataclass
│   │   └── transformer.py          # BasicAVTransformerBlock
│   ├── video_vae/
│   │   └── video_vae.py            # Encoder, Decoder
│   ├── audio_vae/
│   │   ├── audio_vae.py            # Decoder
│   │   └── vocoder.py              # Vocoder
│   └── clip/gemma/
│       └── encoders/av_encoder.py  # AVGemmaTextEncoderModel
├── pipeline/
│   ├── components/
│   │   ├── schedulers.py           # LTX2Scheduler
│   │   ├── diffusion_steps.py      # EulerDiffusionStep
│   │   ├── guiders.py              # CFGGuider
│   │   └── patchifiers.py          # VideoLatentPatchifier, AudioPatchifier
│   └── conditioning/               # VideoLatentTools, AudioLatentTools
└── loader/
    ├── single_gpu_model_builder.py # SingleGPUModelBuilder
    └── sd_ops.py                   # Key remapping (SDOps)
```
packages/ltx-trainer/CLAUDE.md ADDED
@@ -0,0 +1 @@
AGENTS.md
packages/ltx-trainer/README.md ADDED
@@ -0,0 +1,61 @@
# LTX-2 Trainer

This package provides tools and scripts for training and fine-tuning
Lightricks' **LTX-2** audio-video generation model. It enables LoRA training, full
fine-tuning, and training of video-to-video transformations (IC-LoRA) on custom datasets.

---

## 📖 Documentation

All detailed guides and technical documentation are in the [docs](./docs/) directory:

- [⚡ Quick Start Guide](docs/quick-start.md)
- [🎬 Dataset Preparation](docs/dataset-preparation.md)
- [🛠️ Training Modes](docs/training-modes.md)
- [⚙️ Configuration Reference](docs/configuration-reference.md)
- [🚀 Training Guide](docs/training-guide.md)
- [🔧 Utility Scripts](docs/utility-scripts.md)
- [📚 LTX-Core API Guide](docs/ltx-core-api-guide.md)
- [🛡️ Troubleshooting Guide](docs/troubleshooting.md)

---

## 🔧 Requirements

- **LTX-2 Model Checkpoint** - Local `.safetensors` file
- **Gemma Text Encoder** - Local Gemma model directory (required for LTX-2)
- **Linux with CUDA** - CUDA 13+ recommended for optimal performance
- **Nvidia GPU with 80GB+ VRAM** - Highly recommended; lower VRAM may work with gradient checkpointing and lower
  resolutions

---

## 🤝 Contributing

We welcome contributions from the community! Here's how you can help:

- **Share Your Work**: If you've trained interesting LoRAs or achieved cool results, please share them with the
  community.
- **Report Issues**: Found a bug or have a suggestion? Open an issue on GitHub.
- **Submit PRs**: Help improve the codebase with bug fixes or general improvements.
- **Feature Requests**: Have ideas for new features? Let us know through GitHub issues.

---

## 💬 Join the Community

Have questions, want to share your results, or need real-time help?

Join our [community Discord server](https://discord.gg/2mafsHjJ) to connect with other users and the development team!

- Get troubleshooting help
- Share your training results and workflows
- Collaborate on new ideas and features
- Stay up to date with announcements and updates

We look forward to seeing you there!

---

Happy training! 🎉
packages/ltx-trainer/pyproject.toml ADDED
@@ -0,0 +1,89 @@
[project]
name = "ltx-trainer"
version = "0.1.0"
description = "LTX-2 training, democratized."
readme = "README.md"
authors = [
    { name = "Matan Ben-Yosef", email = "mbyosef@lightricks.com" }
]
requires-python = ">=3.12"
dependencies = [
    "ltx-core",
    "accelerate>=1.2.1",
    "av>=14.2.1",
    "bitsandbytes>=0.45.2; sys_platform == 'linux'",
    "diffusers>=0.32.1",
    "huggingface-hub[hf-xet]>=0.31.4",
    "imageio>=2.37.0",
    "imageio-ffmpeg>=0.6.0",
    "opencv-python>=4.11.0.86",
    "optimum-quanto>=0.2.6",
    "pandas>=2.2.3",
    "peft>=0.14.0",
    "pillow-heif>=0.21.0",
    "pydantic>=2.10.4",
    "rich>=13.9.4",
    "safetensors>=0.5.0",
    "scenedetect>=0.6.5.2",
    "sentencepiece>=0.2.0",
    "torch>=2.6.0",
    "torchaudio>=2.9.0",
    "torchcodec>=0.8.1",
    "torchvision>=0.21.0",
    "typer>=0.15.1",
    "wandb>=0.19.11",
]

[dependency-groups]
dev = [
    "pre-commit>=4.0.1",
    "ruff>=0.8.6",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.ruff]
target-version = "py311"
line-length = 120

[tool.ruff.lint]
select = [
    "E",    # pycodestyle
    "F",    # pyflakes
    "W",    # pycodestyle (warnings)
    "I",    # isort
    "N",    # pep8-naming
    "ANN",  # flake8-annotations
    "B",    # flake8-bugbear
    "A",    # flake8-builtins
    "COM",  # flake8-commas
    "C4",   # flake8-comprehensions
    "DTZ",  # flake8-datetimez
    "EXE",  # flake8-executable
    "PIE",  # flake8-pie
    "T20",  # flake8-print
    "PT",   # flake8-pytest-style
    "SIM",  # flake8-simplify
    "ARG",  # flake8-unused-arguments
    "PTH",  # flake8-use-pathlib
    "ERA",  # eradicate (commented-out code)
    "RUF",  # Ruff-specific rules
    "PL",   # pylint
]
ignore = [
    "ANN002",   # Missing type annotation for *args
    "ANN003",   # Missing type annotation for **kwargs
    "ANN204",   # Missing type annotation for special method
    "COM812",   # Missing trailing comma
    "PTH123",   # `open()` should be replaced by `Path.open()`
    "PLR2004",  # Magic value used in comparison, consider replacing with a constant
]

[tool.ruff.lint.pylint]
max-args = 10

[tool.ruff.lint.isort]
known-first-party = ["ltx_trainer", "ltx_core", "ltx_pipelines"]