diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..b0d79e2e318e8931fbdad112136372ec28603d2d --- /dev/null +++ b/LICENSE @@ -0,0 +1,381 @@ + LTX-2 Community License Agreement + License date: January 5, 2026 + + +By using or distributing any portion or element of LTX-2, you agree +to be bound by this Agreement. + + 1. Definitions. + + "Agreement" means the terms and conditions for the license, use, + reproduction, and distribution of LTX-2 and the Complementary + Materials, as specified in this document. + + "Control" means the direct or indirect ownership of more than + fifty percent (50%) of the voting securities or other ownership + interests, or the power to direct the management and policies of + such Entity through voting rights, contract, or otherwise. + + "Data" means a collection of information and/or content extracted + from the dataset used with LTX-2, including to train, pretrain, + or otherwise evaluate LTX-2. The Data is not licensed under this + Agreement. + + "Derivatives of LTX-2" means all modifications to LTX-2, works + based on LTX-2, or any other model which is created or initialized + by transfer of patterns of the weights, parameters, activations or + output of LTX-2, to the other model, in order to cause the other + model to perform similarly to LTX-2, including – but not limited + to - distillation methods entailing the use of intermediate data + representations or methods based on the generation of synthetic + data by LTX-2 for training the other model. For clarity, Derivatives + of LTX-2 include: (i) any fine-tuned or adapted weights, parameters, + or checkpoints derived from LTX-2; (ii) derivative model architectures + that incorporate or are based upon LTX-2's architecture; and + (iii) any modified or extended versions of the Complementary + Materials. All intellectual property rights in Derivatives of LTX-2 + shall be subject to the terms of this Agreement, and you may not + claim exclusive ownership rights in any Derivatives of LTX-2 that + would restrict the rights granted herein. + + "Entity" means any individual, corporation, partnership, limited + liability company, or other legal entity. For purposes of this + Agreement, an Entity shall be deemed to include, on an aggregative + basis, all subsidiaries, affiliates, and other companies under + common Control with such Entity. When determining whether an Entity + meets any threshold under this Agreement (including revenue + thresholds), all subsidiaries, affiliates, and companies under + common Control shall be considered collectively. + + "Harm" includes but is not limited to physical, mental, + psychological, financial and reputational damage, pain, or loss. + + "Licensor" or "Lightricks" means the owner that is granting the + license under this Agreement. For the purposes of this Agreement, + the Licensor is Lightricks Ltd. + + "LTX-2" means the large language models, text/image/video/audio/3D + generation models, and multimodal large language models and their + software and algorithms, including trained model weights, parameters + (including optimizer states), machine-learning model code, + inference-enabling code, training-enabling code, fine-tuning + enabling code, accompanying source code, scripts, documentation, + tutorials, examples, and all other elements of the foregoing + distributed and made publicly available by Lightricks (including, + for example, at https://github.com/Lightricks/LTX-2) for the LTX-2 + model released on January 5, 2026. This license is applicable to + all LTX-2 versions released since January 5, 2026, and all future + releases of LTX-2 under this license. + + "Output" means the results of operating LTX-2 as embodied in + informational content resulting therefrom. + + "you" (or "your") means an individual or legal Entity licensing + LTX-2 in accordance with this Agreement and/or making use of LTX-2 + for whichever purpose and in any field of use, including usage of + LTX-2 in an end-use application - e.g. chatbot, translator, image + generator. + + 2. Grant of License. Subject to the terms and conditions of this + Agreement, you are granted a non-exclusive, worldwide, + non-transferable and royalty-free limited license under Licensor's + intellectual property or other rights owned by Licensor embodied + in LTX-2 to use, reproduce, prepare, distribute, publicly display, + publicly perform, sublicense, copy, create derivative works of, + and make modifications to LTX-2, for any purpose, subject to the + restrictions set forth in Attachment A; provided however, that + Entities with annual revenues of at least $10,000,000 (the + "Commercial Entities") are required to obtain a paid commercial + use license in order to use LTX-2 and Derivatives of LTX-2, + subject to the terms and provisions of a different license (the + "Commercial Use Agreement"), as will be provided by the Licensor. + Commercial Entities interested in such a commercial license are + required to [contact Licensor](https://ltx.io/model/licensing). + Any commercial use of LTX-2 or Derivatives of LTX-2 by the + Commercial Entities not in accordance with this Agreement and/or + the Commercial Use Agreement is strictly prohibited and shall be + deemed a material breach of this Agreement. Such material breach + will be subject, in addition to any license fees owed to Licensor + for the period such Commercial Entity used LTX-2 (as will be + determined by Licensor), to liquidated damages, which will be paid + to Licensor immediately upon demand, in an amount equal to double + the amount that would otherwise have been paid by you for the + relevant period of time. Such amount reflects a reasonable estimation + of the losses and administrative costs incurred due to such breach. + You agree and understand that this remedy does not limit the Licensor's + right to pursue other remedies available at law or equity. + + 3. Distribution and Redistribution. You may host for third parties + remote access purposes (e.g. software-as-a-service), reproduce + and distribute copies of LTX-2 or Derivatives of LTX-2 thereof in + any medium, with or without modifications, provided that you meet + the following conditions: + + (a) Use-based restrictions as referenced in paragraph 4 and all + provisions of Attachment A MUST be included as an enforceable + provision by you in any type of legal agreement (e.g. a + license) governing the use and/or distribution of LTX-2 or + Derivatives of LTX-2, and you shall give notice to subsequent + users you distribute to, that LTX-2 or Derivatives of LTX-2 + are subject to paragraph 4 and Attachment A in their entirety, + including all use restrictions and acceptable use policies; + + (b) You must provide any third party recipients of LTX-2 or + Derivatives of LTX-2 a copy of this Agreement, including all + attachments and use policies. Any Derivative of LTX-2 (as + defined in Section 1, including but not limited to fine-tuned + weights, modified training code, models trained on Outputs, or + any other derivative) must be distributed exclusively under + the terms of this Agreement with a complete copy of this + license included; + + (c) You must cause any modified files to carry prominent notices + stating that you changed the files; + + (d) You must retain all copyright, patent, trademark, and + attribution notices excluding those notices that do not + pertain to any part of LTX-2, Derivatives of LTX-2. + + You may add your own copyright statement to your modifications and + may provide additional or different license terms and conditions - + respecting paragraph 3(a) - for use, reproduction, or distribution + of your modifications, or for any such Derivatives of LTX-2 as a + whole, provided your use, reproduction, and distribution of LTX-2 + otherwise complies with the conditions stated in this Agreement, + and you provide a complete copy of this Agreement with any such + use, reproduction and distribution of LTX-2 and any Derivatives + thereof. + + 4. Use-based restrictions. The restrictions set forth in Attachment A + are considered Use-based restrictions. Therefore, you cannot use + LTX-2 and the Derivatives of LTX-2 in violation of the specified + restricted uses. You may use LTX-2 subject to this Agreement, + including only for lawful purposes and in accordance with the + Agreement. "Use" may include creating any content with, fine-tuning, + updating, running, training, evaluating and/or re-parametrizing + LTX-2. You shall require all of your users who use LTX-2 or a + Derivative of LTX-2 to comply with the terms of this paragraph 4. + + 5. The Output You Generate. Except as set forth herein, Licensor + claims no rights in the Output you generate using LTX-2. You are + accountable for input you insert into LTX-2, the Output you + generate and its subsequent uses. No use of the Output can + contravene any provision as stated in the Agreement. + + 6. Updates and Runtime Restrictions. To the maximum extent permitted + by law, Licensor reserves the right to restrict (remotely or + otherwise) usage of LTX-2 in violation of this Agreement, update + LTX-2 through electronic means, or modify the Output of LTX-2 + based on updates. You shall undertake reasonable efforts to use + the latest version of LTX-2. Any use of the non-current version + of LTX-2 is done solely at your risk. + + 7. Export Controls and Sanctions Compliance. You acknowledge that + LTX-2, Derivatives of LTX-2 may be subject to export control laws + and regulations, including but not limited to the U.S. Export + Administration Regulations and sanctions programs administered by + the Office of Foreign Assets Control (OFAC). You represent and + warrant that you and any users of LTX-2 are not (i) located in, + organized under the laws of, or ordinarily resident in any country + or territory subject to comprehensive sanctions; (ii) identified + on any U.S. government restricted party list, including the + Specially Designated Nationals and Blocked Persons List; or + (iii) otherwise prohibited from receiving LTX-2 under applicable + law. You shall not export, re-export, or transfer LTX-2, directly + or indirectly, in violation of any applicable export control or + sanctions laws or regulations. You agree to comply with all + applicable trade control laws and shall indemnify and hold + Licensor harmless from any claims arising from your failure to + comply with such laws. + + 8. Trademarks and related. Nothing in this Agreement permits you to + make use of Licensor's trademarks, trade names, logos or to + otherwise suggest endorsement or misrepresent the relationship + between the parties; and any rights not expressly granted herein + are reserved by the Licensor. + + 9. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides LTX-2 on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or + conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS + FOR A PARTICULAR PURPOSE. You are solely responsible for + determining the appropriateness of using or redistributing LTX-2 + and Derivatives of LTX-2 and assume any risks associated with + your exercise of permissions under this Agreement. + + 10. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall Licensor be liable + to you for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as + a result of this Agreement or out of the use or inability to use + LTX-2 (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if Licensor has been + advised of the possibility of such damages. + + 11. Accepting Warranty or Additional Liability. While redistributing + LTX-2 and Derivatives of LTX-2, you may, provided you do not + violate the terms of this Agreement, choose to offer and charge + a fee for, acceptance of support, warranty, indemnity, or other + liability obligations. However, in accepting such obligations, + you may act only on your own behalf and on your sole + responsibility, not on behalf of Licensor, and only if you agree + to indemnify, defend, and hold Licensor harmless for any liability + incurred by, or claims asserted against Licensor, by reason of + your accepting any such warranty or additional liability. + + 12. Governing Law. This Agreement and all relations, disputes, claims + and other matters arising hereunder (including non-contractual + disputes or claims) will be governed exclusively by, and construed + exclusively in accordance with, the laws of the State of New York. + To the extent permitted by law, choice of laws rules and the + United Nations Convention on Contracts for the International Sale + of Goods will not apply. For the purposes of adjudicating any + action or proceeding to enforce the terms of this Agreement, you + hereby irrevocably consent to the exclusive jurisdiction of, and + venue in, the federal and state courts located in the County of + New York within the State of New York. The prevailing party in + any claim or dispute between the parties under this Agreement + will be entitled to reimbursement of its reasonable attorneys' + fees and costs. You hereby waive the right to a trial by jury, + to participate in a class or representative action (including in + arbitration), or to combine individual proceedings in court or + in arbitration without the consent of all parties. + + 13. Term and Termination. This Agreement is effective upon your + acceptance and continues until terminated. Licensor may terminate + this Agreement immediately upon written notice to you if you + breach any provision of this Agreement, including but not limited + to violations of the use restrictions in Attachment A or + unauthorized commercial use. Upon termination: (a) all rights + granted to you under this Agreement will immediately cease; + (b) you must immediately cease all use of LTX-2 and Derivatives + of LTX-2; (c) you must delete or destroy all copies of LTX-2 + and Derivatives of LTX-2 in your possession or control; and + (d) you must notify any third parties to whom you distributed + LTX-2 or Derivatives of LTX-2 of the termination. Sections 8-13, + and Section 15 shall survive termination of this Agreement. + Termination does not relieve you of any obligations incurred + prior to termination, including payment obligations under + Section 2. In addition, if You commence a lawsuit or other + proceedings (including a cross-claim or counterclaim in a lawsuit) + against Licensor or any person or entity alleging that LTX-2 or + any Output, or any portion of any of the foregoing, infringe any + intellectual property or other right owned or licensable by you, + then all licenses granted to you under this Agreement shall + terminate as of the date such lawsuit or other proceeding is filed. + + 14. Disputes and Arbitration. All disputes arising in connection with + this Agreement shall be finally settled by arbitration under the + Rules of Arbitration of the International Chamber of Commerce + ("ICC Rules"), by one (1) arbitrator appointed in accordance with + the ICC Rules. The seat of arbitration shall be New York, NY, USA, + and the proceedings shall be conducted in English. The arbitrator + shall be empowered to grant any relief that a court could grant. + Judgment on the arbitration award may be entered by any court + having jurisdiction thereof. Each party waives its right to a + trial by jury and to participate in any class or representative + action. + + 15. If any provision of this Agreement is held to be + invalid, illegal + or unenforceable, the remaining provisions shall be unaffected + thereby and remain valid as if such provision had not been set + forth herein. + + END OF TERMS AND CONDITIONS + + ATTACHMENT A: Use Restrictions + + When using the Outputs, LTX-2 and any Derivatives thereof, you + will comply with the Acceptable Use Policy. In addition, you + agree not to use the Outputs, LTX-2 or its Derivatives in any + of the following ways: + + 1. In any way that violates any applicable national, federal, + state, local or international law or regulation; + + 2. For the purpose of exploiting, Harming or attempting to + exploit or Harm minors in any way; + + 3. To generate or disseminate false information and/or content + with the purpose of Harming others; + + 4. To generate or disseminate personal identifiable information + that can be used to Harm an individual; + + 5. To generate or disseminate information and/or content (e.g. + images, code, posts, articles), and place the information + and/or content in any context (e.g. bot generating tweets) + without expressly and intelligibly disclaiming that the + information and/or content is machine generated; + + 6. To defame, disparage or otherwise harass others; + + 7. To impersonate or attempt to impersonate (e.g. deepfakes) + others without their consent; + + 8. For fully automated decision making that adversely impacts an + individual's legal rights or otherwise creates or modifies a + binding, enforceable obligation; + + 9. For any use intended to or which has the effect of + discriminating against or Harming individuals or groups based + on online or offline social behavior or known or predicted + personal or personality characteristics; + + 10. To exploit any of the vulnerabilities of a specific group of + persons based on their age, social, physical or mental + characteristics, in order to materially distort the behavior + of a person pertaining to that group in a manner that causes + or is likely to cause that person or another person physical + or psychological Harm; + + 11. For any use intended to or which has the effect of + discriminating against individuals or groups based on legally + protected characteristics or categories; + + 12. To provide medical advice and medical results interpretation; + + 13. To generate or disseminate information for the purpose to be + used for administration of justice, law enforcement, + immigration or asylum processes, such as predicting an + individual will commit fraud/crime commitment (e.g. by text + profiling, drawing causal relationships between assertions + made in documents, indiscriminate and arbitrarily-targeted use); + + 14. To generate and/or disseminate malware (including – but not + limited to – ransomware) or any other content to be used for + the purpose of harming electronic systems; + + 15. To engage in, promote, incite, or facilitate discrimination + or other unlawful or harmful conduct in the provision of + employment, employment benefits, credit, housing, or other + essential goods and services; + + 16. To engage in, promote, incite, or facilitate the harassment, + abuse, threatening, or bullying of individuals or groups of + individuals; + + 17. For military, warfare, nuclear industries or applications, + weapons development, or any use in connection with activities + that may cause death, personal injury, or severe physical or + environmental damage; + + 18. For commercial use only: To train, improve, or fine-tune any + other machine learning model, artificial intelligence system, + or competing model, except for Derivatives of LTX-2 as + expressly permitted under this Agreement; + + 19. To circumvent, disable, or interfere with any technical + limitations, safety features, content filters, or use + restrictions implemented in LTX-2 by Licensor; + + 20. To use LTX-2 or Derivatives of LTX-2 in any product, service, + or application that directly competes with Licensor's + commercial products or services, or is designed to replace or + substitute Licensor's offerings in the market, without + obtaining a separate commercial license from Licensor. diff --git a/README.md b/README.md index 26e9c70d8e8cc910da142d71e05929fe36474ed8..0d46dc98497c665d58fcb0fea95d52e6adbd95f1 100644 --- a/README.md +++ b/README.md @@ -1,42 +1,162 @@ ---- -title: DramaBox -emoji: 🎭 -colorFrom: red -colorTo: indigo -sdk: gradio -sdk_version: 4.44.1 -app_file: app.py -pinned: true -license: other -license_name: ltx-2-community -license_link: https://huggingface.co/ResembleAI/Dramabox/blob/main/LICENSE -hardware: l40s -short_description: Expressive TTS with voice cloning β€” DramaBox demo ---- +# Dramabox - Expressive TTS with Voice Cloning -# DramaBox β€” Expressive TTS Demo +Prompt-driven TTS with voice cloning built on a 3.3B Diffusion Transformer with flow matching. -Live demo of [`ResembleAI/Dramabox`](https://huggingface.co/ResembleAI/Dramabox). Write a scene prompt, optionally upload a 10-second voice reference, and generate. Audio is automatically watermarked with [Resemble Perth](https://github.com/resemble-ai/Perth). +## Folder Structure -The model checkpoints download automatically on first launch. +``` +DramaBox/ +β”œβ”€β”€ src/ +β”‚ β”œβ”€β”€ inference.py # TTS inference with voice cloning +β”‚ β”œβ”€β”€ inference_server.py # Warm server (~2.5s per generation) +β”‚ β”œβ”€β”€ audio_conditioning.py # Reference audio conditioning +β”‚ └── model_downloader.py # Auto-download models from HuggingFace +β”œβ”€β”€ patches/ +β”‚ β”œβ”€β”€ attention.py # dtype fix for mask allocation +β”‚ └── guiders.py # Per-token CFG clamping +β”œβ”€β”€ assets/ +β”‚ └── silence_latent_frame.pt +β”œβ”€β”€ evals/ +β”‚ β”œβ”€β”€ eval_short.txt # 30 short prompts (~5-15s) +β”‚ β”œβ”€β”€ eval_long.txt # 15 long prompts (~20-37s) +β”‚ └── eval_expressive.txt # 15 expressive prompts (laughs, sighs, stammers) +β”œβ”€β”€ scripts/ +β”‚ β”œβ”€β”€ inference.sh # Inference wrapper +β”‚ └── eval.sh # Evaluation runner +β”œβ”€β”€ app.py # Gradio demo app +β”œβ”€β”€ ltx2/ # LTX-2 dependency packages +└── README.md +``` + +## Models + +Models auto-download from [ResembleAI/Dramabox](https://huggingface.co/ResembleAI/Dramabox) on HuggingFace. + +| Model | Size | Description | +|-------|------|-------------| +| `dramabox-dit-v1.safetensors` | 6.6 GB | DiT transformer | +| `dramabox-audio-components.safetensors` | 2.7 GB | Audio VAE + vocoder + text projection | +| [unsloth/gemma-3-12b-it-bnb-4bit](https://huggingface.co/unsloth/gemma-3-12b-it-bnb-4bit) | ~8 GB | Text encoder (auto-downloaded) | + +**VRAM**: ~24 GB peak | **Speed**: ~2.5s per generation (warm server, H100) + +## Quick Start + +### Warm Server (recommended, ~2.5s per request) + +```python +from src.inference_server import TTSServer + +server = TTSServer(device="cuda") + +server.generate_to_file( + prompt='A woman speaks warmly, "Hello, how are you today?" She laughs, "Hahaha, it is so good to see you!"', + output="output.wav", + voice_ref="reference.wav", # optional, 10+ seconds +) +``` + +### Gradio App + +```bash +GEMINI_API_KEY=your_key CUDA_VISIBLE_DEVICES=4 python app.py +``` + +### CLI Inference + +```bash +python src/inference.py \ + --voice-sample reference.wav \ + --prompt 'A woman speaks warmly, "Hello, how are you today?"' \ + --output output.wav \ + --cfg-scale 2.5 --stg-scale 1.5 +``` + +### Evaluation -## Prompt format +```bash +bash scripts/eval.sh --eval expressive --output eval_results/ +``` + +## Inference Settings + +| Parameter | Default | Notes | +|-----------|---------|-------| +| cfg-scale | 2.5 | Lower = more natural, higher = more text following | +| stg-scale | 1.5 | Skip-token guidance | +| rescale | 0 | No rescaling | +| modality | 1 | No modality guidance | +| duration-multiplier | 1.1 | 10% breathing room | +| steps | 30 | Euler flow matching | + +## Prompt Writing Guide + +**Structure:** `, "" ""` + +### What works inside quotes (model produces actual sounds) +- Laughs: `"Hahaha"` `"Hehehe"` (always one word, never separated) +- Sounds: `"Mmmmm"` `"Ugh"` `"Argh"` `"Ahhh"` `"Hmm"` + +### What goes outside quotes (stage directions) +- `She sighs deeply.` `He gulps nervously.` `A long pause.` +- `Her voice cracks.` `He clears his throat.` `She scoffs.` + +### Never inside quotes (model speaks them literally) +- Ahem, Pfft, Sigh, Gasp, Cough + +### Tips +- Match gender/age in speaker description to voice reference +- Break long dialogue into segments with acting directions between them +- End prompt at the last closing quote mark (no trailing descriptions) + +## Watermarking +Every audio output from `inference.py` and `inference_server.TTSServer.generate_to_file` is automatically watermarked with [Resemble Perth](https://github.com/resemble-ai/Perth) β€” an imperceptible neural watermark that survives MP3 compression, audio editing, and common manipulations while maintaining nearly 100% detection accuracy. + +```python +import perth, librosa +wav, sr = librosa.load("output.wav", sr=None, mono=True) +detector = perth.PerthImplicitWatermarker() +print(detector.get_watermark(wav, sample_rate=sr)) # confidence β‰ˆ 1.0 for our outputs ``` -, "" "" + +Pass `--no-watermark` to `inference.py` (or `watermark=False` to `generate_to_file`) to disable for debugging. + +## Training + +DramaBox is an IC-LoRA fine-tune of the LTX-2.3 22B audio-only branch. To train your own: + +```bash +# 1. Preprocess raw (audio, transcript) pairs β†’ audio_latents/ + conditions/ +python src/preprocess.py \ + --dataset-type manifest \ + --index your_data.jsonl \ + --output-dir /path/to/preprocessed/ \ + --checkpoint dramabox-audio-components.safetensors \ + --gemma-root /path/to/gemma-3-12b-it-bnb-4bit/ + +# 2. Edit configs/training_args.example.yaml β†’ your data paths + +# 3. Launch (uses HuggingFace accelerate) +bash scripts/train.sh \ + --config configs/training_args.example.yaml \ + --gpus 0,1,2,3,4,5,6 \ + --train-val-gpu 7 ``` -- **Inside double quotes**: dialogue and phonetic sounds (`"Hahaha"`, `"Mmmmm"`, `"Ugh"`) -- **Outside quotes**: stage directions (`She sighs.`, `He clears his throat.`) -- **Avoid inside quotes**: `Ahem`, `Pfft`, `Sigh`, `Gasp`, `Cough` β€” the model will speak them literally. +| Script | Purpose | +|---|---| +| `src/preprocess.py` | Encode audio (Audio VAE) + text (Gemma) into training-ready `.pt` files | +| `src/train.py` | IC-LoRA training loop with peft, accelerate multi-GPU, periodic validation | +| `src/validate.py` | Spawned by `train.py` at each save step; runs the warm validator on a held-out prompt set | +| `scripts/train.sh` | YAML-config wrapper around `accelerate launch src/train.py` | + +LoRA targets the audio branch only: `audio_attn1.{to_q,to_k,to_v,to_out.0}` + `audio_ff.{net.0.proj,net.2}` Γ— 48 transformer blocks (288 LoRA pairs total). Default rank 128 / alpha 128 / dropout 0.1, cosine LR schedule from 1e-4 with 500-step warmup over 10k steps. + +## Language -See the **Load an example prompt** dropdown for ready-made scene templates. +English. -## Files +## License -- `app.py` β€” Gradio UI -- `src/inference_server.py` β€” warm `TTSServer` (single load, ~2.5s/request) -- `src/inference.py` β€” CLI inference -- `src/model_downloader.py` β€” auto-fetches model from HuggingFace -- `ltx2/` β€” vendored LTX-2 pipelines -- `requirements.txt` β€” Python deps (includes `resemble-perth`) +Built on [LTX-2](https://github.com/Lightricks/LTX-2) by Lightricks. Distributed under the LTX-2 Community License Agreement β€” see [`LICENSE`](LICENSE). diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..e710ce8b4d4c30b258fd47710b166d2fe4b11758 --- /dev/null +++ b/app.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +"""DramaBox β€” Gradio demo (warm server). + +Loads the warm TTSServer once, then handles requests at ~2.5 s each. All +generated audio is invisibly watermarked with Resemble Perth before being +returned to the user. +""" +import logging +import os +import sys +import tempfile +import time + +import gradio as gr + +# Local src import. +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "src")) +from inference_server import TTSServer # noqa: E402 + + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logging.info("Loading DramaBox warm server (Gemma + DiT + VAE + Decoder)...") +tts = TTSServer( + device="cuda", + dtype=os.environ.get("LTX_DTYPE", "bf16"), + compile_model=os.environ.get("LTX_COMPILE", "0") == "1", + bnb_4bit=True, # default Gemma is unsloth pre-quantized +) +logging.info("Server ready.") + + +# ── Example prompts (shown as click-to-fill chips in the UI) ───────────────── +EXAMPLES: list[tuple[str, str]] = [ + ( + "Villain monologue", + 'A shadowy villain speaks with cold menace, "You have entered my domain, mortal." ' + 'He chuckles darkly, "Such arrogance will be your undoing." ' + 'His voice rises with fury, "Kneel, or be destroyed where you stand!"' + ), + ( + "Talk-show host wheeze-laugh", + 'A talk show host gasps with shock, "No! You did NOT just say that!" ' + 'He bursts into uncontrollable laughter, "Hahaha! Oh my god, oh my god!" ' + 'He wheezes, "I cannot, I literally cannot breathe right now!"' + ), + ( + "Tender goodnight whisper", + 'A woman speaks tenderly, "It has been a long day, my love." ' + 'She whispers, "Close your eyes. I am right here." ' + 'She hums quietly, "Mmmm-mmm. Sleep now."' + ), + ( + "Old-school radio anchor", + 'A radio host clears his throat, "Excuse me, pardon that." ' + 'He settles into a warm, professional tone, "Good evening everyone, ' + 'and welcome back to the show. We have got a wonderful lineup tonight."' + ), + ( + "Catgirl uncontrollable giggling", + 'A playful girl already mid-giggle, "Hehehe, oh my gosh you should see your face!" ' + 'She gasps for air between giggles, "Oh my, hehe, oh my, I cannot stop!" ' + 'She tries to compose herself, "Ahhhhh okay okay okay, I will stop, I promise."' + ), + ( + "Hero stammering courage", + 'A young warrior speaks with a trembling voice, "I... I do not know if I can do this." ' + 'He takes a shaky breath, "But someone has to try." ' + 'His voice steadies with growing fire, "No more running. I WILL fight!"' + ), + ( + "Exhausted dad, fraying patience", + 'An exhausted father speaks with fraying patience, "Sweetie, daddy is asking very nicely." ' + 'He sighs deeply, "Ohhhh my goodness." ' + 'He puts on an overly cheerful voice, "Hey buddy! Look at the shiny thing!" ' + 'Then he laughs helplessly, "Hahaha, I am losing my mind."' + ), + ( + "Smug-confident announcer", + 'A confident announcer speaks proudly, "And now, the moment you have all been waiting for." ' + 'He chuckles knowingly, "Heheh, trust me, this one is going to blow you away."' + ), +] + + +def on_generate(prompt: str, audio_ref, cfg: float, stg: float, dur_mult: float, seed: int): + if not prompt or not prompt.strip(): + raise gr.Error("Prompt is empty.") + t0 = time.time() + ref_path = audio_ref if audio_ref and os.path.exists(str(audio_ref)) else None + output = tempfile.mktemp(suffix=".wav", prefix="dramabox_") + tts.generate_to_file( + prompt=prompt, + output=output, + voice_ref=ref_path, + cfg_scale=cfg, stg_scale=stg, + duration_multiplier=dur_mult, seed=int(seed), + ) + elapsed = time.time() - t0 + logging.info(f"Generated in {elapsed:.2f}s -> {output}") + return output + + +# ── UI ────────────────────────────────────────────────────────────────────── +with gr.Blocks( + title="DramaBox β€” Expressive TTS", + theme=gr.themes.Default(), + css=".prompt-box textarea { font-size: 14px !important; line-height: 1.5 !important; }", +) as app: + gr.Markdown("# 🎭 DramaBox β€” Expressive TTS with Voice Cloning") + gr.Markdown( + "Write a scene prompt, optionally upload a 10-second voice reference, " + "and generate. Audio is automatically watermarked with " + "[Resemble Perth](https://github.com/resemble-ai/Perth).\n\n" + "**Tips:** put dialogue inside `\"double quotes\"`, scene directions outside. " + "Phonetic sounds (`\"Hahaha\"`, `\"Mmmm\"`, `\"Ugh\"`) go inside quotes; named " + "actions (`She sighs.`, `He clears his throat.`) go outside." + ) + + with gr.Row(): + with gr.Column(scale=3): + prompt_box = gr.Textbox( + label="Scene prompt", + placeholder=EXAMPLES[0][1], + lines=6, elem_classes=["prompt-box"], + ) + example_chooser = gr.Dropdown( + choices=[e[0] for e in EXAMPLES], + label="Load an example prompt", interactive=True, value=None, + ) + audio_ref = gr.Audio( + label="Voice reference (optional, 10+ seconds)", + type="filepath", + ) + gen_btn = gr.Button("Generate", variant="primary", size="lg") + + with gr.Column(scale=2): + with gr.Accordion("Inference settings", open=False): + cfg_slider = gr.Slider(1.0, 10.0, value=2.5, step=0.5, label="CFG scale") + stg_slider = gr.Slider(0.0, 5.0, value=1.5, step=0.5, label="STG scale") + dur_slider = gr.Slider(0.8, 2.0, value=1.1, step=0.05, label="Duration Γ—") + seed_input = gr.Number(value=42, label="Seed", precision=0) + audio_out = gr.Audio(label="Generated audio", type="filepath") + with gr.Accordion("Prompt writing guide", open=False): + gr.Markdown( + "**Structure:** `, \"\" \"\"`\n\n" + "**Inside quotes** (model speaks them):\n" + "- Dialogue: `\"Hello, how are you?\"`\n" + "- Phonetic sounds: `\"Hahaha\"`, `\"Hehehe\"`, `\"Mmmmm\"`, `\"Ugh\"`, `\"Argh\"`\n\n" + "**Outside quotes** (stage directions):\n" + "- `She sighs deeply.`, `He gulps nervously.`, `A long pause.`\n" + "- `Her voice cracks.`, `He clears his throat.`\n\n" + "**Avoid inside quotes:** Ahem, Pfft, Sigh, Gasp, Cough β€” the model speaks them literally." + ) + + def _load_example(choice: str): + if not choice: + return gr.update() + for name, prompt in EXAMPLES: + if name == choice: + return prompt + return gr.update() + + example_chooser.change(_load_example, inputs=[example_chooser], outputs=[prompt_box]) + gen_btn.click( + on_generate, + inputs=[prompt_box, audio_ref, cfg_slider, stg_slider, dur_slider, seed_input], + outputs=[audio_out], + ) + + +if __name__ == "__main__": + port = int(os.environ.get("GRADIO_SERVER_PORT", "7861")) + app.queue(max_size=10).launch( + server_name="0.0.0.0", server_port=port, + share=os.environ.get("GRADIO_SHARE", "0") == "1", + ) diff --git a/assets/silence_latent_frame.pt b/assets/silence_latent_frame.pt new file mode 100644 index 0000000000000000000000000000000000000000..fe6c1e78e5d58899b6fa2e1547b1f739fb4cd30e --- /dev/null +++ b/assets/silence_latent_frame.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f73746d2163f8f1742c5de89005404ccaeeff05154bbb10a3337bf9bd13f161c +size 1501 diff --git a/configs/training_args.example.yaml b/configs/training_args.example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d2cbfcd1621015e58711fc1d71db2b2d3863251d --- /dev/null +++ b/configs/training_args.example.yaml @@ -0,0 +1,53 @@ +# Example DramaBox IC-LoRA training config. Used by scripts/train.sh. + +# Where to load preprocessed `audio_latents/` + `conditions/` shards from. +data_dir: +- /path/to/preprocessed_dataset_a/ +- /path/to/preprocessed_dataset_b/ + +# One index file per data_dir entry. Each line: +# ~~~~~~~ +speaker_index: +- /path/to/preprocessed_dataset_a/index.txt +- /path/to/preprocessed_dataset_b/index.txt + +# Output directory (relative is fine β€” resolved against the repo root). +output_dir: tts_iclora_v1 + +# LTX-2.3 22B base. Same file is used for the transformer + the aux stack +# (PromptEncoder, AudioVAE, AudioDecoder). +checkpoint: ltx-2.3-22b-dev.safetensors +full_checkpoint: ltx-2.3-22b-dev.safetensors +base_model: dev + +# LoRA hyperparams. rank == alpha is the simplest setup (scale = 1.0). +lora_rank: 128 +lora_alpha: 128 +lora_dropout: 0.1 + +# Voice-cloning ref-token settings. +ref_ratio: 0.3 # fraction of training samples that get a ref token +max_ref_tokens: 200 # max ref-token positions appended to target + +text_dropout: 0.4 # CFG training: drop the text prompt with prob 0.4 + +# Schedule. Use lr_scheduler=constant with a small lr (1e-5) for a "fine-tune" +# resume; cosine + larger lr (1e-4) for from-scratch. +steps: 10000 +lr: 1.0e-04 +lr_scheduler: cosine +warmup_steps: 500 + +batch_size: 1 +grad_accum: 4 +max_grad_norm: 1.0 + +save_every: 500 +log_every: 50 +seed: 53 + +# (Optional) per-checkpoint validation eval β€” see configs/val_config.example.yaml +# val_config: val_config.example.yaml + +# (Optional) resume from a previous LoRA adapter file: +# resume_lora: tts_iclora_v0/lora_step_05000.safetensors diff --git a/ltx2/ltx_core/__init__.py b/ltx2/ltx_core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ltx2/ltx_core/batch_split.py b/ltx2/ltx_core/batch_split.py new file mode 100644 index 0000000000000000000000000000000000000000..014ca5a449ba654a4f6b7ebd5ba307d510f8b34d --- /dev/null +++ b/ltx2/ltx_core/batch_split.py @@ -0,0 +1,95 @@ +"""Batch-splitting adapter for the transformer. +Wraps an ``X0Model`` (or ``LayerStreamingWrapper``) and splits batched inputs +into smaller chunks before forwarding, then concatenates the results. This +controls peak activation memory at the cost of more forward passes. +The adapter is transparent β€” it has the same ``forward`` signature as +``X0Model`` and proxies attribute access to the wrapped model. +Example +------- +>>> from ltx_core.batch_split import BatchSplitAdapter +>>> adapter = BatchSplitAdapter(model, max_batch_size=1) +>>> # Receives B=4, runs 4xB=1 internally, returns B=4 +>>> denoised_video, denoised_audio = adapter(video=v_b4, audio=a_b4, perturbations=ptb) +""" + +from __future__ import annotations + +from typing import Any + +import torch +from torch import nn + +from ltx_core.guidance.perturbations import BatchedPerturbationConfig +from ltx_core.model.transformer.modality import Modality + + +def _split_perturbations(config: BatchedPerturbationConfig, sizes: list[int]) -> list[BatchedPerturbationConfig]: + """Split a ``BatchedPerturbationConfig`` along the batch dimension.""" + it = iter(config.perturbations) + return [BatchedPerturbationConfig([next(it) for _ in range(s)]) for s in sizes] + + +def _merge_tensors(tensors: list[torch.Tensor | None]) -> torch.Tensor | None: + """Concatenate tensors along batch dim, or return None if all are None.""" + non_none = [t for t in tensors if t is not None] + if not non_none: + return None + return torch.cat(non_none, dim=0) + + +class BatchSplitAdapter(nn.Module): + """Wraps a model and splits batched forward calls into smaller chunks. + Has the same ``forward`` signature as ``X0Model``: + ``(video, audio, perturbations) -> (denoised_video, denoised_audio)``. + Args: + model: The model to wrap (``X0Model``, ``LayerStreamingWrapper``, etc.). + max_batch_size: Maximum batch size per forward pass. Input batches + larger than this are split into sequential chunks. + """ + + def __init__(self, model: nn.Module, max_batch_size: int) -> None: + if max_batch_size < 1: + raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}") + super().__init__() + self._model = model + self._max_batch_size = max_batch_size + + def _get_chunk_sizes(self, batch_size: int) -> list[int]: + full, remainder = divmod(batch_size, self._max_batch_size) + sizes = [self._max_batch_size] * full + if remainder: + sizes.append(remainder) + return sizes + + def forward( + self, + video: Modality | None, + audio: Modality | None, + perturbations: BatchedPerturbationConfig, + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + batch_size = (video or audio).latent.shape[0] + + if batch_size <= self._max_batch_size: + return self._model(video=video, audio=audio, perturbations=perturbations) + + sizes = self._get_chunk_sizes(batch_size) + n = len(sizes) + + v_chunks = video.split(sizes) if video is not None else [None] * n + a_chunks = audio.split(sizes) if audio is not None else [None] * n + p_chunks = _split_perturbations(perturbations, sizes) + + chunk_results = [ + self._model(video=vc, audio=ac, perturbations=pc) + for vc, ac, pc in zip(v_chunks, a_chunks, p_chunks, strict=True) + ] + + results_v, results_a = zip(*chunk_results, strict=True) + return _merge_tensors(list(results_v)), _merge_tensors(list(results_a)) + + def __getattr__(self, name: str) -> Any: # noqa: ANN401 + """Proxy attribute access to the wrapped model.""" + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self._model, name) diff --git a/ltx2/ltx_core/components/__init__.py b/ltx2/ltx_core/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c1cb638b4d43e6d8a03685bb6f30992a0496bba3 --- /dev/null +++ b/ltx2/ltx_core/components/__init__.py @@ -0,0 +1,10 @@ +""" +Diffusion pipeline components. +Submodules: + diffusion_steps - Diffusion stepping algorithms (EulerDiffusionStep) + guiders - Guidance strategies (CFGGuider, STGGuider, APG variants) + noisers - Noise samplers (GaussianNoiser) + patchifiers - Latent patchification (VideoLatentPatchifier, AudioPatchifier) + protocols - Protocol definitions (Patchifier, etc.) + schedulers - Sigma schedulers (LTX2Scheduler, LinearQuadraticScheduler) +""" diff --git a/ltx2/ltx_core/components/diffusion_steps.py b/ltx2/ltx_core/components/diffusion_steps.py new file mode 100644 index 0000000000000000000000000000000000000000..d4908cb6404312e1f2c7faf4ae49200034dee16a --- /dev/null +++ b/ltx2/ltx_core/components/diffusion_steps.py @@ -0,0 +1,106 @@ +import torch + +from ltx_core.components.protocols import DiffusionStepProtocol +from ltx_core.utils import to_velocity + + +class EulerDiffusionStep(DiffusionStepProtocol): + """ + First-order Euler method for diffusion sampling. + Takes a single step from the current noise level (sigma) to the next by + computing velocity from the denoised prediction and applying: sample + velocity * dt. + """ + + def step( + self, sample: torch.Tensor, denoised_sample: torch.Tensor, sigmas: torch.Tensor, step_index: int, **_kwargs + ) -> torch.Tensor: + sigma = sigmas[step_index] + sigma_next = sigmas[step_index + 1] + dt = sigma_next - sigma + velocity = to_velocity(sample, sigma, denoised_sample) + + return (sample.to(torch.float32) + velocity.to(torch.float32) * dt).to(sample.dtype) + + +class Res2sDiffusionStep(DiffusionStepProtocol): + """ + Second-order diffusion step for res_2s sampling with SDE noise injection. + Used by the res_2s denoising loop. Advances the sample from the current + sigma to the next by mixing a deterministic update (from the denoised + prediction) with injected noise via ``get_sde_coeff``, producing + variance-preserving transitions. + """ + + @staticmethod + def get_sde_coeff( + sigma_next: torch.Tensor, + sigma_up: torch.Tensor | None = None, + sigma_down: torch.Tensor | None = None, + sigma_max: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute SDE coefficients (alpha_ratio, sigma_down, sigma_up) for the step. + Given either ``sigma_down`` or ``sigma_up``, returns the mixing + coefficients used for variance-preserving noise injection. If + ``sigma_up`` is provided, ``sigma_down`` and ``alpha_ratio`` are + derived; if ``sigma_down`` is provided, ``sigma_up`` and + ``alpha_ratio`` are derived. + """ + if sigma_down is not None: + alpha_ratio = (1 - sigma_next) / (1 - sigma_down) + sigma_up = (sigma_next**2 - sigma_down**2 * alpha_ratio**2).clamp(min=0) ** 0.5 + elif sigma_up is not None: + # Fallback to avoid sqrt(neg_num) + sigma_up.clamp_(max=sigma_next * 0.9999) + sigmax = sigma_max if sigma_max is not None else torch.ones_like(sigma_next) + sigma_signal = sigmax - sigma_next + sigma_residual = (sigma_next**2 - sigma_up**2).clamp(min=0) ** 0.5 + alpha_ratio = sigma_signal + sigma_residual + sigma_down = sigma_residual / alpha_ratio + else: + alpha_ratio = torch.ones_like(sigma_next) + sigma_down = sigma_next + sigma_up = torch.zeros_like(sigma_next) + + sigma_up = torch.nan_to_num(sigma_up if sigma_up is not None else torch.zeros_like(sigma_next), 0.0) + # Replace NaNs in sigma_down with corresponding sigma_next elements (float32) + nan_mask = torch.isnan(sigma_down) + sigma_down[nan_mask] = sigma_next[nan_mask].to(sigma_down.dtype) + alpha_ratio = torch.nan_to_num(alpha_ratio, 1.0) + + return alpha_ratio, sigma_down, sigma_up + + def step( + self, + sample: torch.Tensor, + denoised_sample: torch.Tensor, + sigmas: torch.Tensor, + step_index: int, + noise: torch.Tensor, + eta: float = 0.5, + ) -> torch.Tensor: + """Advance one step with SDE noise injection via get_sde_coeff. + Args: + sample: Current noisy sample. + denoised_sample: Denoised prediction from the model. + sigmas: Noise schedule tensor. + step_index: Current step index in the schedule. + noise: Random noise tensor for stochastic injection. + eta: Controls stochastic noise injection strength (0=deterministic, 1=maximum). Default 0.5. + Returns: + Next sample with SDE noise injection applied. + """ + sigma = sigmas[step_index] + sigma_next = sigmas[step_index + 1] + alpha_ratio, sigma_down, sigma_up = self.get_sde_coeff(sigma_next, sigma_up=sigma_next * eta) + output_dtype = denoised_sample.dtype + if torch.any(sigma_up == 0) or torch.any(sigma_next == 0): + return denoised_sample + + # Extract epsilon prediction + eps_next = (sample - denoised_sample) / (sigma - sigma_next) + denoised_next = sample - sigma * eps_next + + # Mix deterministic and stochastic components + x_noised = alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise + return x_noised.to(output_dtype) diff --git a/ltx2/ltx_core/components/guiders.py b/ltx2/ltx_core/components/guiders.py new file mode 100644 index 0000000000000000000000000000000000000000..88df758ab382c02cb8ff4221b48d4af79289f94c --- /dev/null +++ b/ltx2/ltx_core/components/guiders.py @@ -0,0 +1,383 @@ +import math +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field + +import torch + +from ltx_core.components.protocols import GuiderProtocol + + +@dataclass(frozen=True) +class CFGGuider(GuiderProtocol): + """ + Classifier-free guidance (CFG) guider. + Computes the guidance delta as (scale - 1) * (cond - uncond), steering the + denoising process toward the conditioned prediction. + Attributes: + scale: Guidance strength. 1.0 means no guidance, higher values increase + adherence to the conditioning. + """ + + scale: float + + def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor: + return (self.scale - 1) * (cond - uncond) + + def enabled(self) -> bool: + return self.scale != 1.0 + + +@dataclass(frozen=True) +class CFGStarRescalingGuider(GuiderProtocol): + """ + Calculates the CFG delta between conditioned and unconditioned samples. + To minimize offset in the denoising direction and move mostly along the + conditioning axis within the distribution, the unconditioned sample is + rescaled in accordance with the norm of the conditioned sample. + Attributes: + scale (float): + Global guidance strength. A value of 1.0 corresponds to no extra + guidance beyond the base model prediction. Values > 1.0 increase + the influence of the conditioned sample relative to the + unconditioned one. + """ + + scale: float + + def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor: + rescaled_neg = projection_coef(cond, uncond) * uncond + return (self.scale - 1) * (cond - rescaled_neg) + + def enabled(self) -> bool: + return self.scale != 1.0 + + +@dataclass(frozen=True) +class STGGuider(GuiderProtocol): + """ + Calculates the STG delta between conditioned and perturbed denoised samples. + Perturbed samples are the result of the denoising process with perturbations, + e.g. attentions acting as passthrough for certain layers and modalities. + Attributes: + scale (float): + Global strength of the STG guidance. A value of 0.0 disables the + guidance. Larger values increase the correction applied in the + direction of (pos_denoised - perturbed_denoised). + """ + + scale: float + + def delta(self, pos_denoised: torch.Tensor, perturbed_denoised: torch.Tensor) -> torch.Tensor: + return self.scale * (pos_denoised - perturbed_denoised) + + def enabled(self) -> bool: + return self.scale != 0.0 + + +@dataclass(frozen=True) +class LtxAPGGuider(GuiderProtocol): + """ + Calculates the APG (adaptive projected guidance) delta between conditioned + and unconditioned samples. + To minimize offset in the denoising direction and move mostly along the + conditioning axis within the distribution, the (cond - uncond) delta is + decomposed into components parallel and orthogonal to the conditioned + sample. The `eta` parameter weights the parallel component, while `scale` + is applied to the orthogonal component. Optionally, a norm threshold can + be used to suppress guidance when the magnitude of the correction is small. + Attributes: + scale (float): + Strength applied to the component of the guidance that is orthogonal + to the conditioned sample. Controls how aggressively we move in + directions that change semantics but stay consistent with the + conditioning manifold. + eta (float): + Weight of the component of the guidance that is parallel to the + conditioned sample. A value of 1.0 keeps the full parallel + component; values in [0, 1] attenuate it, and values > 1.0 amplify + motion along the conditioning direction. + norm_threshold (float): + Minimum L2 norm of the guidance delta below which the guidance + can be reduced or ignored (depending on implementation). + This is useful for avoiding noisy or unstable updates when the + guidance signal is very small. + """ + + scale: float + eta: float = 1.0 + norm_threshold: float = 0.0 + + def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor: + guidance = cond - uncond + if self.norm_threshold > 0: + ones = torch.ones_like(guidance) + guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True) + scale_factor = torch.minimum(ones, self.norm_threshold / guidance_norm) + guidance = guidance * scale_factor + proj_coeff = projection_coef(guidance, cond) + g_parallel = proj_coeff * cond + g_orth = guidance - g_parallel + g_apg = g_parallel * self.eta + g_orth + + return g_apg * (self.scale - 1) + + def enabled(self) -> bool: + return self.scale != 1.0 + + +@dataclass(frozen=False) +class LegacyStatefulAPGGuider(GuiderProtocol): + """ + Calculates the APG (adaptive projected guidance) delta between conditioned + and unconditioned samples. + To minimize offset in the denoising direction and move mostly along the + conditioning axis within the distribution, the (cond - uncond) delta is + decomposed into components parallel and orthogonal to the conditioned + sample. The `eta` parameter weights the parallel component, while `scale` + is applied to the orthogonal component. Optionally, a norm threshold can + be used to suppress guidance when the magnitude of the correction is small. + Attributes: + scale (float): + Strength applied to the component of the guidance that is orthogonal + to the conditioned sample. Controls how aggressively we move in + directions that change semantics but stay consistent with the + conditioning manifold. + eta (float): + Weight of the component of the guidance that is parallel to the + conditioned sample. A value of 1.0 keeps the full parallel + component; values in [0, 1] attenuate it, and values > 1.0 amplify + motion along the conditioning direction. + norm_threshold (float): + Minimum L2 norm of the guidance delta below which the guidance + can be reduced or ignored (depending on implementation). + This is useful for avoiding noisy or unstable updates when the + guidance signal is very small. + momentum (float): + Exponential moving-average coefficient for accumulating guidance + over time. running_avg = momentum * running_avg + guidance + """ + + scale: float + eta: float + norm_threshold: float = 5.0 + momentum: float = 0.0 + # it is user's responsibility not to use same APGGuider for several denoisings or different modalities + # in order not to share accumulated average across different denoisings or modalities + running_avg: torch.Tensor | None = None + + def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor: + guidance = cond - uncond + if self.momentum != 0: + if self.running_avg is None: + self.running_avg = guidance.clone() + else: + self.running_avg = self.momentum * self.running_avg + guidance + guidance = self.running_avg + + if self.norm_threshold > 0: + ones = torch.ones_like(guidance) + guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True) + scale_factor = torch.minimum(ones, self.norm_threshold / guidance_norm) + guidance = guidance * scale_factor + + proj_coeff = projection_coef(guidance, cond) + g_parallel = proj_coeff * cond + g_orth = guidance - g_parallel + g_apg = g_parallel * self.eta + g_orth + + return g_apg * self.scale + + def enabled(self) -> bool: + return self.scale != 0.0 + + +@dataclass(frozen=True) +class MultiModalGuiderParams: + """ + Parameters for the multi-modal guider. + """ + + cfg_scale: float = 1.0 + "CFG (Classifier-free guidance) scale controlling how strongly the model adheres to the prompt." + stg_scale: float = 0.0 + "STG (Spatio-Temporal Guidance) scale controls how strongly the model reacts to the perturbation of the modality." + stg_blocks: list[int] | None = field(default_factory=list) + "Which transformer blocks to perturb for STG." + rescale_scale: float = 0.0 + "Rescale scale controlling how strongly the model rescales the modality after applying other guidance." + modality_scale: float = 1.0 + "Modality scale controlling how strongly the model reacts to the perturbation of the modality." + cfg_clamp_scale: float = 0.0 + "Clamp guided prediction std to this multiple of conditioned prediction std. 0 = disabled." + skip_step: int = 0 + "Skip step controlling how often the model skips the step." + + +def _params_for_sigma_from_sorted_dict( + sigma: float, params_by_sigma: Sequence[tuple[float, MultiModalGuiderParams]] +) -> MultiModalGuiderParams: + """ + Return params for the given sigma from a sorted (sigma_upper_bound -> params) structure. + Keys are sorted descending (bin upper bounds). Bin i is (key_{i+1}, key_i]. + Get all keys >= sigma; use last in list (smallest such key = upper bound of bin containing sigma), + or last entry in the sequence if list is empty (sigma above max key). + """ + if not params_by_sigma: + raise ValueError("params_by_sigma must be non-empty") + sigma = float(sigma) + keys_desc = [k for k, _ in params_by_sigma] + keys_ge_sigma = [k for k in keys_desc if k >= sigma] + # sigma above all keys: use first bin (max key) + key = keys_ge_sigma[-1] if keys_ge_sigma else keys_desc[0] + return next(p for k, p in params_by_sigma if k == key) + + +@dataclass(frozen=True) +class MultiModalGuider: + """ + Multi-modal guider with constant params per instance. + For sigma-dependent params, use MultiModalGuiderFactory.build_from_sigma(sigma) to + obtain a guider for each step. + """ + + params: MultiModalGuiderParams + negative_context: torch.Tensor | None = None + + def calculate( + self, + cond: torch.Tensor, + uncond_text: torch.Tensor | float, + uncond_perturbed: torch.Tensor | float, + uncond_modality: torch.Tensor | float, + ) -> torch.Tensor: + """ + The guider calculates the guidance delta as (scale - 1) * (cond - uncond) for cfg and modality cfg, + and as scale * (cond - uncond) for stg, steering the denoising process away from the unconditioned + prediction. + """ + pred = ( + cond + + (self.params.cfg_scale - 1) * (cond - uncond_text) + + self.params.stg_scale * (cond - uncond_perturbed) + + (self.params.modality_scale - 1) * (cond - uncond_modality) + ) + + if self.params.rescale_scale != 0: + factor = cond.std() / pred.std() + factor = self.params.rescale_scale * factor + (1 - self.params.rescale_scale) + pred = pred * factor + + # Clamp guided prediction to prevent trajectory overshoot. + # Instead of global std (which averages over all tokens), clamp per-token. + # This catches individual tokens that overshoot even if the global std looks fine. + if self.params.cfg_clamp_scale > 0: + cfg_delta = pred - cond + # Per-token magnitude clamping + delta_norm = cfg_delta.norm(dim=-1, keepdim=True) # [B, T, 1] + cond_norm = cond.norm(dim=-1, keepdim=True) + max_norm = cond_norm * self.params.cfg_clamp_scale + # Clamp tokens where delta exceeds max + scale = torch.where( + delta_norm > max_norm, + max_norm / delta_norm.clamp(min=1e-8), + torch.ones_like(delta_norm), + ) + pred = cond + cfg_delta * scale + + return pred + + def do_unconditional_generation(self) -> bool: + """Returns True if the guider is doing unconditional generation.""" + return not math.isclose(self.params.cfg_scale, 1.0) + + def do_perturbed_generation(self) -> bool: + """Returns True if the guider is doing perturbed generation.""" + return not math.isclose(self.params.stg_scale, 0.0) + + def do_isolated_modality_generation(self) -> bool: + """Returns True if the guider is doing isolated modality generation.""" + return not math.isclose(self.params.modality_scale, 1.0) + + def should_skip_step(self, step: int) -> bool: + """Returns True if the guider should skip the step.""" + if self.params.skip_step == 0: + return False + return step % (self.params.skip_step + 1) != 0 + + +@dataclass(frozen=True) +class MultiModalGuiderFactory: + """ + Factory that creates a MultiModalGuider for a given sigma. + Single source of truth: _params_by_sigma (schedule). Use constant() for + one params for all sigma, from_dict() for sigma-binned params. + """ + + negative_context: torch.Tensor | None = None + _params_by_sigma: tuple[tuple[float, MultiModalGuiderParams], ...] = () + + @classmethod + def constant( + cls, + params: MultiModalGuiderParams, + negative_context: torch.Tensor | None = None, + ) -> "MultiModalGuiderFactory": + """Build a factory with constant params (same guider for all sigma).""" + return cls( + negative_context=negative_context, + _params_by_sigma=((float("inf"), params),), + ) + + @classmethod + def from_dict( + cls, + sigma_to_params: Mapping[float, MultiModalGuiderParams], + negative_context: torch.Tensor | None = None, + ) -> "MultiModalGuiderFactory": + """ + Build a factory from a dict of sigma_value -> MultiModalGuiderParams. + Keys are sorted descending and used for bin lookup in params(sigma). + """ + if not sigma_to_params: + raise ValueError("sigma_to_params must be non-empty") + sorted_items = tuple(sorted(sigma_to_params.items(), key=lambda x: x[0], reverse=True)) + return cls(negative_context=negative_context, _params_by_sigma=sorted_items) + + def params(self, sigma: float | torch.Tensor) -> MultiModalGuiderParams: + """Return params effective for the given sigma (getter; single source of truth).""" + sigma_val = float(sigma.item() if isinstance(sigma, torch.Tensor) else sigma) + return _params_for_sigma_from_sorted_dict(sigma_val, self._params_by_sigma) + + def build_from_sigma(self, sigma: float | torch.Tensor) -> MultiModalGuider: + """Return a MultiModalGuider with params effective for the given sigma.""" + return MultiModalGuider( + params=self.params(sigma), + negative_context=self.negative_context, + ) + + +def create_multimodal_guider_factory( + params: MultiModalGuiderParams | MultiModalGuiderFactory, + negative_context: torch.Tensor | None = None, +) -> MultiModalGuiderFactory: + """ + Create or return a MultiModalGuiderFactory. Pass constant params for a + single-params factory (uses MultiModalGuiderFactory.constant), or an existing + MultiModalGuiderFactory. When given a factory, returns it as-is unless + negative_context is provided. For sigma-dependent params use + MultiModalGuiderFactory.from_dict(...) and pass that as params. + """ + if isinstance(params, MultiModalGuiderFactory): + if negative_context is not None and params.negative_context is not negative_context: + return MultiModalGuiderFactory.from_dict(dict(params._params_by_sigma), negative_context=negative_context) + return params + return MultiModalGuiderFactory.constant(params, negative_context=negative_context) + + +def projection_coef(to_project: torch.Tensor, project_onto: torch.Tensor) -> torch.Tensor: + batch_size = to_project.shape[0] + positive_flat = to_project.reshape(batch_size, -1) + negative_flat = project_onto.reshape(batch_size, -1) + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8 + return dot_product / squared_norm diff --git a/ltx2/ltx_core/components/noisers.py b/ltx2/ltx_core/components/noisers.py new file mode 100644 index 0000000000000000000000000000000000000000..2db66d89af3f7ad2820e0083aa2b1c2b276fcaaf --- /dev/null +++ b/ltx2/ltx_core/components/noisers.py @@ -0,0 +1,35 @@ +from dataclasses import replace +from typing import Protocol + +import torch + +from ltx_core.types import LatentState + + +class Noiser(Protocol): + """Protocol for adding noise to a latent state during diffusion.""" + + def __call__(self, latent_state: LatentState, noise_scale: float) -> LatentState: ... + + +class GaussianNoiser(Noiser): + """Adds Gaussian noise to a latent state, scaled by the denoise mask.""" + + def __init__(self, generator: torch.Generator): + super().__init__() + + self.generator = generator + + def __call__(self, latent_state: LatentState, noise_scale: float = 1.0) -> LatentState: + noise = torch.randn( + *latent_state.latent.shape, + device=latent_state.latent.device, + dtype=latent_state.latent.dtype, + generator=self.generator, + ) + scaled_mask = latent_state.denoise_mask * noise_scale + latent = noise * scaled_mask + latent_state.latent * (1 - scaled_mask) + return replace( + latent_state, + latent=latent.to(latent_state.latent.dtype), + ) diff --git a/ltx2/ltx_core/components/patchifiers.py b/ltx2/ltx_core/components/patchifiers.py new file mode 100644 index 0000000000000000000000000000000000000000..f9580d52710b4ce17459a37131fb605c874d0437 --- /dev/null +++ b/ltx2/ltx_core/components/patchifiers.py @@ -0,0 +1,348 @@ +import math +from typing import Optional, Tuple + +import einops +import torch + +from ltx_core.components.protocols import Patchifier +from ltx_core.types import AudioLatentShape, SpatioTemporalScaleFactors, VideoLatentShape + + +class VideoLatentPatchifier(Patchifier): + def __init__(self, patch_size: int): + # Patch sizes for video latents. + self._patch_size = ( + 1, # temporal dimension + patch_size, # height dimension + patch_size, # width dimension + ) + + @property + def patch_size(self) -> Tuple[int, int, int]: + return self._patch_size + + def get_token_count(self, tgt_shape: VideoLatentShape) -> int: + return math.prod(tgt_shape.to_torch_shape()[2:]) // math.prod(self._patch_size) + + def patchify( + self, + latents: torch.Tensor, + ) -> torch.Tensor: + latents = einops.rearrange( + latents, + "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", + p1=self._patch_size[0], + p2=self._patch_size[1], + p3=self._patch_size[2], + ) + + return latents + + def unpatchify( + self, + latents: torch.Tensor, + output_shape: VideoLatentShape, + ) -> torch.Tensor: + assert self._patch_size[0] == 1, "Temporal patch size must be 1 for symmetric patchifier" + + patch_grid_frames = output_shape.frames // self._patch_size[0] + patch_grid_height = output_shape.height // self._patch_size[1] + patch_grid_width = output_shape.width // self._patch_size[2] + + latents = einops.rearrange( + latents, + "b (f h w) (c p q) -> b c f (h p) (w q)", + f=patch_grid_frames, + h=patch_grid_height, + w=patch_grid_width, + p=self._patch_size[1], + q=self._patch_size[2], + ) + + return latents + + def get_patch_grid_bounds( + self, + output_shape: AudioLatentShape | VideoLatentShape, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """ + Return the per-dimension bounds [inclusive start, exclusive end) for every + patch produced by `patchify`. The bounds are expressed in the original + video grid coordinates: frame/time, height, and width. + The resulting tensor is shaped `[batch_size, 3, num_patches, 2]`, where: + - axis 1 (size 3) enumerates (frame/time, height, width) dimensions + - axis 3 (size 2) stores `[start, end)` indices within each dimension + Args: + output_shape: Video grid description containing frames, height, and width. + device: Device of the latent tensor. + """ + if not isinstance(output_shape, VideoLatentShape): + raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates") + + frames = output_shape.frames + height = output_shape.height + width = output_shape.width + batch_size = output_shape.batch + + # Validate inputs to ensure positive dimensions + assert frames > 0, f"frames must be positive, got {frames}" + assert height > 0, f"height must be positive, got {height}" + assert width > 0, f"width must be positive, got {width}" + assert batch_size > 0, f"batch_size must be positive, got {batch_size}" + + # Generate grid coordinates for each dimension (frame, height, width) + # We use torch.arange to create the starting coordinates for each patch. + # indexing='ij' ensures the dimensions are in the order (frame, height, width). + grid_coords = torch.meshgrid( + torch.arange(start=0, end=frames, step=self._patch_size[0], device=device), + torch.arange(start=0, end=height, step=self._patch_size[1], device=device), + torch.arange(start=0, end=width, step=self._patch_size[2], device=device), + indexing="ij", + ) + + # Stack the grid coordinates to create the start coordinates tensor. + # Shape becomes (3, grid_f, grid_h, grid_w) + patch_starts = torch.stack(grid_coords, dim=0) + + # Create a tensor containing the size of a single patch: + # (frame_patch_size, height_patch_size, width_patch_size). + # Reshape to (3, 1, 1, 1) to enable broadcasting when adding to the start coordinates. + patch_size_delta = torch.tensor( + self._patch_size, + device=patch_starts.device, + dtype=patch_starts.dtype, + ).view(3, 1, 1, 1) + + # Calculate end coordinates: start + patch_size + # Shape becomes (3, grid_f, grid_h, grid_w) + patch_ends = patch_starts + patch_size_delta + + # Stack start and end coordinates together along the last dimension + # Shape becomes (3, grid_f, grid_h, grid_w, 2), where the last dimension is [start, end] + latent_coords = torch.stack((patch_starts, patch_ends), dim=-1) + + # Broadcast to batch size and flatten all spatial/temporal dimensions into one sequence. + # Final Shape: (batch_size, 3, num_patches, 2) + latent_coords = einops.repeat( + latent_coords, + "c f h w bounds -> b c (f h w) bounds", + b=batch_size, + bounds=2, + ) + + return latent_coords + + +def get_pixel_coords( + latent_coords: torch.Tensor, + scale_factors: SpatioTemporalScaleFactors, + causal_fix: bool = False, +) -> torch.Tensor: + """ + Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling + each axis (frame/time, height, width) with the corresponding VAE downsampling factors. + Optionally compensate for causal encoding that keeps the first frame at unit temporal scale. + Args: + latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`. + scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)` with integer scale factors applied + per axis. + causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs + that treat frame zero differently still yield non-negative timestamps. + """ + # Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout. + broadcast_shape = [1] * latent_coords.ndim + broadcast_shape[1] = -1 # axis dimension corresponds to (frame/time, height, width) + scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape) + + # Apply per-axis scaling to convert latent bounds into pixel-space coordinates. + pixel_coords = latent_coords * scale_tensor + + if causal_fix: + # VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`. + # Shift and clamp to keep the first-frame timestamps causal and non-negative. + pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0) + + return pixel_coords + + +class AudioPatchifier(Patchifier): + def __init__( + self, + patch_size: int, + sample_rate: int = 16000, + hop_length: int = 160, + audio_latent_downsample_factor: int = 4, + is_causal: bool = True, + shift: int = 0, + ): + """ + Patchifier tailored for spectrogram/audio latents. + Args: + patch_size: Number of mel bins combined into a single patch. This + controls the resolution along the frequency axis. + sample_rate: Original waveform sampling rate. Used to map latent + indices back to seconds so downstream consumers can align audio + and video cues. + hop_length: Window hop length used for the spectrogram. Determines + how many real-time samples separate two consecutive latent frames. + audio_latent_downsample_factor: Ratio between spectrogram frames and + latent frames; compensates for additional downsampling inside the + VAE encoder. + is_causal: When True, timing is shifted to account for causal + receptive fields so timestamps do not peek into the future. + shift: Integer offset applied to the latent indices. Enables + constructing overlapping windows from the same latent sequence. + """ + self.hop_length = hop_length + self.sample_rate = sample_rate + self.audio_latent_downsample_factor = audio_latent_downsample_factor + self.is_causal = is_causal + self.shift = shift + self._patch_size = (1, patch_size, patch_size) + + @property + def patch_size(self) -> Tuple[int, int, int]: + return self._patch_size + + def get_token_count(self, tgt_shape: AudioLatentShape) -> int: + return tgt_shape.frames + + def _get_audio_latent_time_in_sec( + self, + start_latent: int, + end_latent: int, + dtype: torch.dtype, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """ + Converts latent indices into real-time seconds while honoring causal + offsets and the configured hop length. + Args: + start_latent: Inclusive start index inside the latent sequence. This + sets the first timestamp returned. + end_latent: Exclusive end index. Determines how many timestamps get + generated. + dtype: Floating-point dtype used for the returned tensor, allowing + callers to control precision. + device: Target device for the timestamp tensor. When omitted the + computation occurs on CPU to avoid surprising GPU allocations. + """ + if device is None: + device = torch.device("cpu") + + audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device) + + audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor + + if self.is_causal: + # Frame offset for causal alignment. + # The "+1" ensures the timestamp corresponds to the first sample that is fully available. + causal_offset = 1 + audio_mel_frame = (audio_mel_frame + causal_offset - self.audio_latent_downsample_factor).clip(min=0) + + return audio_mel_frame * self.hop_length / self.sample_rate + + def _compute_audio_timings( + self, + batch_size: int, + num_steps: int, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """ + Builds a `(B, 1, T, 2)` tensor containing timestamps for each latent frame. + This helper method underpins `get_patch_grid_bounds` for the audio patchifier. + Args: + batch_size: Number of sequences to broadcast the timings over. + num_steps: Number of latent frames (time steps) to convert into timestamps. + device: Device on which the resulting tensor should reside. + """ + resolved_device = device + if resolved_device is None: + resolved_device = torch.device("cpu") + + start_timings = self._get_audio_latent_time_in_sec( + self.shift, + num_steps + self.shift, + torch.float32, + resolved_device, + ) + start_timings = start_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1) + + end_timings = self._get_audio_latent_time_in_sec( + self.shift + 1, + num_steps + self.shift + 1, + torch.float32, + resolved_device, + ) + end_timings = end_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1) + + return torch.stack([start_timings, end_timings], dim=-1) + + def patchify( + self, + audio_latents: torch.Tensor, + ) -> torch.Tensor: + """ + Flattens the audio latent tensor along time. Use `get_patch_grid_bounds` + to derive timestamps for each latent frame based on the configured hop + length and downsampling. + Args: + audio_latents: Latent tensor to patchify. + Returns: + Flattened patch tokens tensor. Use `get_patch_grid_bounds` to compute the + corresponding timing metadata when needed. + """ + audio_latents = einops.rearrange( + audio_latents, + "b c t f -> b t (c f)", + ) + + return audio_latents + + def unpatchify( + self, + audio_latents: torch.Tensor, + output_shape: AudioLatentShape, + ) -> torch.Tensor: + """ + Restores the `(B, C, T, F)` spectrogram tensor from flattened patches. + Use `get_patch_grid_bounds` to recompute the timestamps that describe each + frame's position in real time. + Args: + audio_latents: Latent tensor to unpatchify. + output_shape: Shape of the unpatched output tensor. + Returns: + Unpatched latent tensor. Use `get_patch_grid_bounds` to compute the timing + metadata associated with the restored latents. + """ + # audio_latents shape: (batch, time, freq * channels) + audio_latents = einops.rearrange( + audio_latents, + "b t (c f) -> b c t f", + c=output_shape.channels, + f=output_shape.mel_bins, + ) + + return audio_latents + + def get_patch_grid_bounds( + self, + output_shape: AudioLatentShape | VideoLatentShape, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """ + Return the temporal bounds `[inclusive start, exclusive end)` for every + patch emitted by `patchify`. For audio this corresponds to timestamps in + seconds aligned with the original spectrogram grid. + The returned tensor has shape `[batch_size, 1, time_steps, 2]`, where: + - axis 1 (size 1) represents the temporal dimension + - axis 3 (size 2) stores the `[start, end)` timestamps per patch + Args: + output_shape: Audio grid specification describing the number of time steps. + device: Target device for the returned tensor. + """ + if not isinstance(output_shape, AudioLatentShape): + raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates") + + return self._compute_audio_timings(output_shape.batch, output_shape.frames, device) diff --git a/ltx2/ltx_core/components/protocols.py b/ltx2/ltx_core/components/protocols.py new file mode 100644 index 0000000000000000000000000000000000000000..18e5836a38edda9f2fe6830edc5362e81e975891 --- /dev/null +++ b/ltx2/ltx_core/components/protocols.py @@ -0,0 +1,101 @@ +from typing import Protocol, Tuple + +import torch + +from ltx_core.types import AudioLatentShape, VideoLatentShape + + +class Patchifier(Protocol): + """ + Protocol for patchifiers that convert latent tensors into patches and assemble them back. + """ + + def patchify( + self, + latents: torch.Tensor, + ) -> torch.Tensor: + ... + """ + Convert latent tensors into flattened patch tokens. + Args: + latents: Latent tensor to patchify. + Returns: + Flattened patch tokens tensor. + """ + + def unpatchify( + self, + latents: torch.Tensor, + output_shape: AudioLatentShape | VideoLatentShape, + ) -> torch.Tensor: + """ + Converts latent tensors between spatio-temporal formats and flattened sequence representations. + Args: + latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`. + output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or + VideoLatentShape. + Returns: + Dense latent tensor restored from the flattened representation. + """ + + @property + def patch_size(self) -> Tuple[int, int, int]: + ... + """ + Returns the patch size as a tuple of (temporal, height, width) dimensions + """ + + def get_patch_grid_bounds( + self, + output_shape: AudioLatentShape | VideoLatentShape, + device: torch.device | None = None, + ) -> torch.Tensor: + ... + """ + Compute metadata describing where each latent patch resides within the + grid specified by `output_shape`. + Args: + output_shape: Target grid layout for the patches. + device: Target device for the returned tensor. + Returns: + Tensor containing patch coordinate metadata such as spatial or temporal intervals. + """ + + +class SchedulerProtocol(Protocol): + """ + Protocol for schedulers that provide a sigmas schedule tensor for a + given number of steps. Device is cpu. + """ + + def execute(self, steps: int, **kwargs) -> torch.FloatTensor: ... + + +class GuiderProtocol(Protocol): + """ + Protocol for guiders that compute a delta tensor given conditioning inputs. + The returned delta should be added to the conditional output (cond), enabling + multiple guiders to be chained together by accumulating their deltas. + """ + + scale: float + + def delta(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor: ... + + def enabled(self) -> bool: + """ + Returns whether the corresponding perturbation is enabled. E.g. for CFG, this should return False if the scale + is 1.0. + """ + ... + + +class DiffusionStepProtocol(Protocol): + """ + Protocol for diffusion steps that provide a next sample tensor for a given current sample tensor, + current denoised sample tensor, and sigmas tensor. + """ + + def step( + self, sample: torch.Tensor, denoised_sample: torch.Tensor, sigmas: torch.Tensor, step_index: int, **kwargs + ) -> torch.Tensor: ... diff --git a/ltx2/ltx_core/components/schedulers.py b/ltx2/ltx_core/components/schedulers.py new file mode 100644 index 0000000000000000000000000000000000000000..4bf7ad93b0595085c40000fcf4ba802e6ae0129e --- /dev/null +++ b/ltx2/ltx_core/components/schedulers.py @@ -0,0 +1,130 @@ +import math +from functools import lru_cache + +import numpy +import scipy +import torch + +from ltx_core.components.protocols import SchedulerProtocol + +BASE_SHIFT_ANCHOR = 1024 +MAX_SHIFT_ANCHOR = 4096 + + +class LTX2Scheduler(SchedulerProtocol): + """ + Default scheduler for LTX-2 diffusion sampling. + Generates a sigma schedule with token-count-dependent shifting and optional + stretching to a terminal value. + """ + + def execute( + self, + steps: int, + latent: torch.Tensor | None = None, + max_shift: float = 2.05, + base_shift: float = 0.95, + stretch: bool = True, + terminal: float = 0.1, + default_number_of_tokens: int = MAX_SHIFT_ANCHOR, + **_kwargs, + ) -> torch.FloatTensor: + tokens = math.prod(latent.shape[2:]) if latent is not None else default_number_of_tokens + sigmas = torch.linspace(1.0, 0.0, steps + 1) + + x1 = BASE_SHIFT_ANCHOR + x2 = MAX_SHIFT_ANCHOR + mm = (max_shift - base_shift) / (x2 - x1) + b = base_shift - mm * x1 + sigma_shift = (tokens) * mm + b + + power = 1 + sigmas = torch.where( + sigmas != 0, + math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power), + 0, + ) + + # Stretch sigmas so that its final value matches the given terminal value. + if stretch: + non_zero_mask = sigmas != 0 + non_zero_sigmas = sigmas[non_zero_mask] + one_minus_z = 1.0 - non_zero_sigmas + scale_factor = one_minus_z[-1] / (1.0 - terminal) + stretched = 1.0 - (one_minus_z / scale_factor) + sigmas[non_zero_mask] = stretched + + return sigmas.to(torch.float32) + + +class LinearQuadraticScheduler(SchedulerProtocol): + """ + Scheduler with linear steps followed by quadratic steps. + Produces a sigma schedule that transitions linearly up to a threshold, + then follows a quadratic curve for the remaining steps. + """ + + def execute( + self, steps: int, threshold_noise: float = 0.025, linear_steps: int | None = None, **_kwargs + ) -> torch.FloatTensor: + if steps == 1: + return torch.FloatTensor([1.0, 0.0]) + + if linear_steps is None: + linear_steps = steps // 2 + linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)] + threshold_noise_step_diff = linear_steps - threshold_noise * steps + quadratic_steps = steps - linear_steps + quadratic_sigma_schedule = [] + if quadratic_steps > 0: + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) + const = quadratic_coef * (linear_steps**2) + quadratic_sigma_schedule = [ + quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, steps) + ] + sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0] + sigma_schedule = [1.0 - x for x in sigma_schedule] + return torch.FloatTensor(sigma_schedule) + + +class BetaScheduler(SchedulerProtocol): + """ + Scheduler using a beta distribution to sample timesteps. + Based on: https://arxiv.org/abs/2407.12173 + """ + + shift = 2.37 + timesteps_length = 10000 + + def execute(self, steps: int, alpha: float = 0.6, beta: float = 0.6) -> torch.FloatTensor: + """ + Execute the beta scheduler. + Args: + steps: The number of steps to execute the scheduler for. + alpha: The alpha parameter for the beta distribution. + beta: The beta parameter for the beta distribution. + Warnings: + The number of steps within `sigmas` theoretically might be less than `steps+1`, + because of the deduplication of the identical timesteps + Returns: + A tensor of sigmas. + """ + model_sampling_sigmas = _precalculate_model_sampling_sigmas(self.shift, self.timesteps_length) + total_timesteps = len(model_sampling_sigmas) - 1 + ts = 1 - numpy.linspace(0, 1, steps, endpoint=False) + ts = numpy.rint(scipy.stats.beta.ppf(ts, alpha, beta) * total_timesteps).tolist() + ts = list(dict.fromkeys(ts)) + + sigmas = [float(model_sampling_sigmas[int(t)]) for t in ts] + [0.0] + return torch.FloatTensor(sigmas) + + +@lru_cache(maxsize=5) +def _precalculate_model_sampling_sigmas(shift: float, timesteps_length: int) -> torch.Tensor: + timesteps = torch.arange(1, timesteps_length + 1, 1) / timesteps_length + return torch.Tensor([flux_time_shift(shift, 1.0, t) for t in timesteps]) + + +def flux_time_shift(mu: float, sigma: float, t: float) -> float: + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) diff --git a/ltx2/ltx_core/conditioning/__init__.py b/ltx2/ltx_core/conditioning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..002e91f7f96059283b98c6e3a2691072d5972ff6 --- /dev/null +++ b/ltx2/ltx_core/conditioning/__init__.py @@ -0,0 +1,19 @@ +"""Conditioning utilities: latent state, tools, and conditioning types.""" + +from ltx_core.conditioning.exceptions import ConditioningError +from ltx_core.conditioning.item import ConditioningItem +from ltx_core.conditioning.types import ( + ConditioningItemAttentionStrengthWrapper, + VideoConditionByKeyframeIndex, + VideoConditionByLatentIndex, + VideoConditionByReferenceLatent, +) + +__all__ = [ + "ConditioningError", + "ConditioningItem", + "ConditioningItemAttentionStrengthWrapper", + "VideoConditionByKeyframeIndex", + "VideoConditionByLatentIndex", + "VideoConditionByReferenceLatent", +] diff --git a/ltx2/ltx_core/conditioning/exceptions.py b/ltx2/ltx_core/conditioning/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..458aa3b47072a11824fda58de7e5c5afaa4e094f --- /dev/null +++ b/ltx2/ltx_core/conditioning/exceptions.py @@ -0,0 +1,4 @@ +class ConditioningError(Exception): + """ + Class for conditioning-related errors. + """ diff --git a/ltx2/ltx_core/conditioning/item.py b/ltx2/ltx_core/conditioning/item.py new file mode 100644 index 0000000000000000000000000000000000000000..e6eef52d1983c1e2e433c27dc1e4735662696092 --- /dev/null +++ b/ltx2/ltx_core/conditioning/item.py @@ -0,0 +1,20 @@ +from typing import Protocol + +from ltx_core.tools import LatentTools +from ltx_core.types import LatentState + + +class ConditioningItem(Protocol): + """Protocol for conditioning items that modify latent state during diffusion.""" + + def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState: + """ + Apply the conditioning to the latent state. + Args: + latent_state: The latent state to apply the conditioning to. This is state always patchified. + Returns: + The latent state after the conditioning has been applied. + IMPORTANT: If the conditioning needs to add extra tokens to the latent, it should add them to the end of the + latent. + """ + ... diff --git a/ltx2/ltx_core/conditioning/mask_utils.py b/ltx2/ltx_core/conditioning/mask_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..57945f48a04d1b7fb2bd4966aca514bd27a9d41b --- /dev/null +++ b/ltx2/ltx_core/conditioning/mask_utils.py @@ -0,0 +1,210 @@ +"""Utilities for building 2D self-attention masks for conditioning items.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +if TYPE_CHECKING: + from ltx_core.types import LatentState + + +def resolve_cross_mask( + attention_mask: float | int | torch.Tensor, + num_new_tokens: int, + batch_size: int, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + """Convert an attention_mask (scalar or tensor) to a (B, M) cross_mask tensor. + Args: + attention_mask: Scalar value applied uniformly, 1D tensor of shape (M,) + broadcast across batch, or 2D tensor of shape (B, M). + num_new_tokens: Number of new conditioning tokens M. + batch_size: Batch size B. + device: Device for the output tensor. + dtype: Data type for the output tensor. + Returns: + Cross-mask tensor of shape (B, M). + """ + if isinstance(attention_mask, (int, float)): + return torch.full( + (batch_size, num_new_tokens), + fill_value=float(attention_mask), + device=device, + dtype=dtype, + ) + mask = attention_mask.to(device=device, dtype=dtype) + + # Handle scalar (0-D) tensor like a Python scalar. + if mask.dim() == 0: + return torch.full( + (batch_size, num_new_tokens), + fill_value=float(mask.item()), + device=device, + dtype=dtype, + ) + + if mask.dim() == 1: + if mask.shape[0] != num_new_tokens: + raise ValueError( + f"1-D attention_mask length must equal num_new_tokens ({num_new_tokens}), got shape {tuple(mask.shape)}" + ) + mask = mask.unsqueeze(0).expand(batch_size, -1) + elif mask.dim() == 2: + b, m = mask.shape + if m != num_new_tokens: + raise ValueError( + f"2-D attention_mask second dimension must equal num_new_tokens ({num_new_tokens}), " + f"got shape {tuple(mask.shape)}" + ) + if b not in (batch_size, 1): + raise ValueError( + f"2-D attention_mask batch dimension must equal batch_size ({batch_size}) or 1, " + f"got shape {tuple(mask.shape)}" + ) + if b == 1 and batch_size > 1: + mask = mask.expand(batch_size, -1) + else: + raise ValueError( + f"attention_mask tensor must be 0-D, 1-D, or 2-D, got {mask.dim()}-D with shape {tuple(mask.shape)}" + ) + return mask + + +def update_attention_mask( + latent_state: LatentState, + attention_mask: float | torch.Tensor | None, + num_noisy_tokens: int, + num_new_tokens: int, + batch_size: int, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor | None: + """Build or update the self-attention mask for newly appended conditioning tokens. + If *attention_mask* is ``None`` and no existing mask is present, returns + ``None``. If *attention_mask* is ``None`` but an existing mask is present, + the mask is expanded with full attention (1s) for the new tokens so that + its dimensions stay consistent with the growing latent sequence. Otherwise, + resolves *attention_mask* to a per-token cross-mask and expands the 2-D + attention mask via :func:`build_attention_mask`. + Args: + latent_state: Current latent state (provides the existing mask and total + existing-token count). + attention_mask: Per-token attention weight. Scalar, 1-D ``(M,)``, 2-D + ``(B, M)`` tensor, or ``None`` (no-op). + num_noisy_tokens: Number of original noisy tokens (from + ``latent_tools.target_shape.token_count()``). + num_new_tokens: Number of new conditioning tokens being appended. + batch_size: Batch size. + device: Device for the output tensor. + dtype: Data type for the output tensor. + Returns: + Updated attention mask of shape ``(B, N+M, N+M)``, or ``None`` if no + masking is needed. + """ + if attention_mask is None: + if latent_state.attention_mask is None: + return None + # Existing mask present but no new mask requested: pad with 1s (full + # attention) so the mask dimensions stay consistent with the growing + # latent sequence. + cross_mask = torch.ones(batch_size, num_new_tokens, device=device, dtype=dtype) + return build_attention_mask( + existing_mask=latent_state.attention_mask, + num_noisy_tokens=num_noisy_tokens, + num_new_tokens=num_new_tokens, + num_existing_tokens=latent_state.latent.shape[1], + cross_mask=cross_mask, + device=device, + dtype=dtype, + ) + + cross_mask = resolve_cross_mask(attention_mask, num_new_tokens, batch_size, device, dtype) + return build_attention_mask( + existing_mask=latent_state.attention_mask, + num_noisy_tokens=num_noisy_tokens, + num_new_tokens=num_new_tokens, + num_existing_tokens=latent_state.latent.shape[1], + cross_mask=cross_mask, + device=device, + dtype=dtype, + ) + + +def build_attention_mask( + existing_mask: torch.Tensor | None, + num_noisy_tokens: int, + num_new_tokens: int, + num_existing_tokens: int, + cross_mask: torch.Tensor, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + """ + Expand the attention mask to include newly appended conditioning tokens. + Each conditioning item appends M new reference tokens to the sequence. This function + builds a (B, N+M, N+M) attention mask with the following block structure: + noisy prev_ref new_ref + (N_noisy) (N-N_noisy) (M) + β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + noisy β”‚ β”‚ β”‚ β”‚ + (N_noisy) β”‚ existing β”‚ existing β”‚ cross β”‚ + β”‚ β”‚ β”‚ β”‚ + β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ + prev_ref β”‚ β”‚ β”‚ β”‚ + (N-N_noisy)β”‚ existing β”‚ existing β”‚ 0 β”‚ + β”‚ β”‚ β”‚ β”‚ + β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ + new_ref β”‚ β”‚ β”‚ β”‚ + (M) β”‚ cross β”‚ 0 β”‚ 1 β”‚ + β”‚ β”‚ β”‚ β”‚ + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + Where: + - **existing**: preserved from the previous mask (or 1.0 if first conditioning) + - **cross**: values from *cross_mask* (shape B, M), in [0, 1] + - **0**: no attention between different reference groups + Args: + existing_mask: Current attention mask of shape (B, N, N), or None if no mask exists yet. + When None, the top-left NxN block is filled with 1s (full attention between all + existing tokens including any prior reference tokens that had no mask). + num_noisy_tokens: Number of original noisy tokens (always at positions [0:num_noisy_tokens]). + num_new_tokens: Number of new conditioning tokens M being appended. + num_existing_tokens: Total number of current tokens N (noisy + any prior conditioning tokens). + cross_mask: Per-token attention weight of shape (B, M) controlling attention between + new reference tokens and noisy tokens. Values in [0, 1]. + device: Device for the output tensor. + dtype: Data type for the output tensor. + Returns: + Attention mask of shape (B, N+M, N+M) with values in [0, 1]. + """ + batch_size = cross_mask.shape[0] + total = num_existing_tokens + num_new_tokens + + # Start with zeros + mask = torch.zeros((batch_size, total, total), device=device, dtype=dtype) + + # Top-left: preserve existing mask or fill with 1s for noisy tokens + if existing_mask is not None: + mask[:, :num_existing_tokens, :num_existing_tokens] = existing_mask + else: + mask[:, :num_existing_tokens, :num_existing_tokens] = 1.0 + + # Bottom-right: new reference tokens fully attend to themselves + mask[:, num_existing_tokens:, num_existing_tokens:] = 1.0 + + # Cross-attention between noisy tokens and new reference tokens + # cross_mask shape: (B, M) -> broadcast to (B, N_noisy, M) and (B, M, N_noisy) + + # Noisy tokens attending to new reference tokens: [0:N_noisy, N:N+M] + # Each column j in this block gets cross_mask[:, j] + mask[:, :num_noisy_tokens, num_existing_tokens:] = cross_mask.unsqueeze(1) + + # New reference tokens attending to noisy tokens: [N:N+M, 0:N_noisy] + # Each row i in this block gets cross_mask[:, i] + mask[:, num_existing_tokens:, :num_noisy_tokens] = cross_mask.unsqueeze(2) + + # [N_noisy:N, N:N+M] and [N:N+M, N_noisy:N] remain 0 (no cross-ref attention) + + return mask diff --git a/ltx2/ltx_core/conditioning/types/__init__.py b/ltx2/ltx_core/conditioning/types/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..44bb92048b55480899d4f70b3975ba41391916a2 --- /dev/null +++ b/ltx2/ltx_core/conditioning/types/__init__.py @@ -0,0 +1,13 @@ +"""Conditioning type implementations.""" + +from ltx_core.conditioning.types.attention_strength_wrapper import ConditioningItemAttentionStrengthWrapper +from ltx_core.conditioning.types.keyframe_cond import VideoConditionByKeyframeIndex +from ltx_core.conditioning.types.latent_cond import VideoConditionByLatentIndex +from ltx_core.conditioning.types.reference_video_cond import VideoConditionByReferenceLatent + +__all__ = [ + "ConditioningItemAttentionStrengthWrapper", + "VideoConditionByKeyframeIndex", + "VideoConditionByLatentIndex", + "VideoConditionByReferenceLatent", +] diff --git a/ltx2/ltx_core/conditioning/types/attention_strength_wrapper.py b/ltx2/ltx_core/conditioning/types/attention_strength_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..0327bfa409e1afb3a5e65331a90c5a91b7a5462b --- /dev/null +++ b/ltx2/ltx_core/conditioning/types/attention_strength_wrapper.py @@ -0,0 +1,71 @@ +"""Wrapper conditioning item that adds attention masking to any inner conditioning.""" + +from dataclasses import replace + +import torch + +from ltx_core.conditioning.item import ConditioningItem +from ltx_core.conditioning.mask_utils import update_attention_mask +from ltx_core.tools import LatentTools +from ltx_core.types import LatentState + + +class ConditioningItemAttentionStrengthWrapper(ConditioningItem): + """Wraps a conditioning item to add an attention mask for its tokens. + Separates the *attention-masking* concern from the underlying conditioning + logic (token layout, positional encoding, denoise strength). The inner + conditioning item appends tokens to the latent sequence as usual, and this + wrapper then builds or updates the self-attention mask so that the newly + added tokens interact with the noisy tokens according to *attention_mask*. + Args: + conditioning: Any conditioning item that appends tokens to the latent. + attention_mask: Per-token attention weight controlling how strongly the + new conditioning tokens attend to/from noisy tokens. Can be a + scalar (float) applied uniformly, or a tensor of shape ``(B, M)`` + for spatial control, where ``M = F * H * W`` is the number of + patchified conditioning tokens. Values in ``[0, 1]``. + Example:: + cond = ConditioningItemAttentionStrengthWrapper( + VideoConditionByReferenceLatent(latent=ref, strength=1.0), + attention_mask=0.5, + ) + state = cond.apply_to(latent_state, latent_tools) + """ + + def __init__( + self, + conditioning: ConditioningItem, + attention_mask: float | torch.Tensor, + ): + self.conditioning = conditioning + self.attention_mask = attention_mask + + def apply_to( + self, + latent_state: LatentState, + latent_tools: LatentTools, + ) -> LatentState: + """Apply inner conditioning, then build the attention mask for its tokens.""" + # Snapshot the original state for mask building + original_state = latent_state + + # Inner conditioning appends tokens (positions, denoise mask, etc.) + new_state = self.conditioning.apply_to(latent_state, latent_tools) + + num_new_tokens = new_state.latent.shape[1] - original_state.latent.shape[1] + if num_new_tokens == 0: + return new_state + + # Build the attention mask using the *original* state as the reference + # so that the block structure is computed correctly. + new_attention_mask = update_attention_mask( + latent_state=original_state, + attention_mask=self.attention_mask, + num_noisy_tokens=latent_tools.target_shape.token_count(), + num_new_tokens=num_new_tokens, + batch_size=new_state.latent.shape[0], + device=new_state.latent.device, + dtype=new_state.latent.dtype, + ) + + return replace(new_state, attention_mask=new_attention_mask) diff --git a/ltx2/ltx_core/conditioning/types/keyframe_cond.py b/ltx2/ltx_core/conditioning/types/keyframe_cond.py new file mode 100644 index 0000000000000000000000000000000000000000..d4af3f67ddced88e32d9b8b230689feb94eaeaeb --- /dev/null +++ b/ltx2/ltx_core/conditioning/types/keyframe_cond.py @@ -0,0 +1,70 @@ +import torch + +from ltx_core.components.patchifiers import get_pixel_coords +from ltx_core.conditioning.item import ConditioningItem +from ltx_core.conditioning.mask_utils import update_attention_mask +from ltx_core.tools import VideoLatentTools +from ltx_core.types import LatentState, VideoLatentShape + + +class VideoConditionByKeyframeIndex(ConditioningItem): + """ + Conditions video generation on keyframe latents at a specific frame index. + Appends keyframe tokens to the latent state with positions offset by frame_idx, + and sets denoise strength according to the strength parameter. + To add attention masking, wrap with :class:`ConditioningItemAttentionStrengthWrapper`. + Args: + keyframes: Keyframe latents [B, C, F, H, W]. + frame_idx: Frame index offset for positional encoding. + strength: Conditioning strength (1.0 = clean, 0.0 = fully denoised). + """ + + def __init__(self, keyframes: torch.Tensor, frame_idx: int, strength: float): + self.keyframes = keyframes + self.frame_idx = frame_idx + self.strength = strength + + def apply_to( + self, + latent_state: LatentState, + latent_tools: VideoLatentTools, + ) -> LatentState: + tokens = latent_tools.patchifier.patchify(self.keyframes) + latent_coords = latent_tools.patchifier.get_patch_grid_bounds( + output_shape=VideoLatentShape.from_torch_shape(self.keyframes.shape), + device=self.keyframes.device, + ) + positions = get_pixel_coords( + latent_coords=latent_coords, + scale_factors=latent_tools.scale_factors, + causal_fix=latent_tools.causal_fix if self.frame_idx == 0 else False, + ) + + positions[:, 0, ...] += self.frame_idx + positions = positions.to(dtype=torch.float32) + positions[:, 0, ...] /= latent_tools.fps + + denoise_mask = torch.full( + size=(*tokens.shape[:2], 1), + fill_value=1.0 - self.strength, + device=self.keyframes.device, + dtype=self.keyframes.dtype, + ) + + new_attention_mask = update_attention_mask( + latent_state=latent_state, + attention_mask=None, + num_noisy_tokens=latent_tools.target_shape.token_count(), + num_new_tokens=tokens.shape[1], + batch_size=tokens.shape[0], + device=self.keyframes.device, + dtype=self.keyframes.dtype, + ) + + return LatentState( + latent=torch.cat([latent_state.latent, tokens], dim=1), + denoise_mask=torch.cat([latent_state.denoise_mask, denoise_mask], dim=1), + positions=torch.cat([latent_state.positions, positions], dim=2), + clean_latent=torch.cat([latent_state.clean_latent, tokens], dim=1), + attention_mask=new_attention_mask, + ) diff --git a/ltx2/ltx_core/conditioning/types/latent_cond.py b/ltx2/ltx_core/conditioning/types/latent_cond.py new file mode 100644 index 0000000000000000000000000000000000000000..c0362733d5bbd2514d4a62aea95a8a43e978053c --- /dev/null +++ b/ltx2/ltx_core/conditioning/types/latent_cond.py @@ -0,0 +1,44 @@ +import torch + +from ltx_core.conditioning.exceptions import ConditioningError +from ltx_core.conditioning.item import ConditioningItem +from ltx_core.tools import LatentTools +from ltx_core.types import LatentState + + +class VideoConditionByLatentIndex(ConditioningItem): + """ + Conditions video generation by injecting latents at a specific latent frame index. + Replaces tokens in the latent state at positions corresponding to latent_idx, + and sets denoise strength according to the strength parameter. + """ + + def __init__(self, latent: torch.Tensor, strength: float, latent_idx: int): + self.latent = latent + self.strength = strength + self.latent_idx = latent_idx + + def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState: + cond_batch, cond_channels, _, cond_height, cond_width = self.latent.shape + tgt_batch, tgt_channels, tgt_frames, tgt_height, tgt_width = latent_tools.target_shape.to_torch_shape() + + if (cond_batch, cond_channels, cond_height, cond_width) != (tgt_batch, tgt_channels, tgt_height, tgt_width): + raise ConditioningError( + f"Can't apply image conditioning item to latent with shape {latent_tools.target_shape}, expected " + f"shape is ({tgt_batch}, {tgt_channels}, {tgt_frames}, {tgt_height}, {tgt_width}). Make sure " + "the image and latent have the same spatial shape." + ) + + tokens = latent_tools.patchifier.patchify(self.latent) + start_token = latent_tools.patchifier.get_token_count( + latent_tools.target_shape._replace(frames=self.latent_idx) + ) + stop_token = start_token + tokens.shape[1] + + latent_state = latent_state.clone() + + latent_state.latent[:, start_token:stop_token] = tokens + latent_state.clean_latent[:, start_token:stop_token] = tokens + latent_state.denoise_mask[:, start_token:stop_token] = 1.0 - self.strength + + return latent_state diff --git a/ltx2/ltx_core/conditioning/types/noise_mask_cond.py b/ltx2/ltx_core/conditioning/types/noise_mask_cond.py new file mode 100644 index 0000000000000000000000000000000000000000..186ac4e0fa53ff034821f504575d391b9de268ae --- /dev/null +++ b/ltx2/ltx_core/conditioning/types/noise_mask_cond.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass + +from ltx_core.components.patchifiers import get_pixel_coords +from ltx_core.conditioning.item import ConditioningItem +from ltx_core.tools import LatentTools, SpatioTemporalScaleFactors +from ltx_core.types import AudioLatentShape, LatentState, VideoLatentShape + + +@dataclass(frozen=True) +class TemporalRegionMask(ConditioningItem): + """Conditioning item that sets ``denoise_mask = 0`` outside a time range + and ``1`` inside, so only the specified temporal region is regenerated. + Uses ``start_time`` and ``end_time`` in seconds. Works in *patchified* + (token) space using the patchifier's ``get_patch_grid_bounds``: for video + coords are latent frame indices (converted from seconds via ``fps``), for + audio coords are already in seconds. + """ + + start_time: float # seconds, inclusive + end_time: float # seconds, exclusive + fps: float + + def apply_to(self, latent_state: LatentState, latent_tools: LatentTools) -> LatentState: + coords = latent_tools.patchifier.get_patch_grid_bounds( + latent_tools.target_shape, device=latent_state.denoise_mask.device + ) + if isinstance(latent_tools.target_shape, AudioLatentShape): + # Audio: patchifier get_patch_grid_bounds returns seconds + t_boundaries = coords[:, 0] + elif isinstance(latent_tools.target_shape, VideoLatentShape): + # Video: patchifier get_patch_grid_bounds returns latent bounds, converting to frame numbers & pixel bounds + scale_factors = getattr(latent_tools, "scale_factors", SpatioTemporalScaleFactors.default()) + pixel_bounds = get_pixel_coords(coords, scale_factors, causal_fix=getattr(latent_tools, "causal_fix", True)) + # converting frame numbers to seconds + t_boundaries = pixel_bounds[:, 0] / self.fps + else: + raise ValueError("Unsupported LatentShape type, expected AudioLatentShape or VideoLatentShape") + t_start, t_end = t_boundaries.unbind(dim=-1) # [B, N] + in_region = (t_end > self.start_time) & (t_start < self.end_time) + state = latent_state.clone() + mask_val = in_region.to(state.denoise_mask.dtype) + if state.denoise_mask.dim() == 3: + mask_val = mask_val.unsqueeze(-1) + state.denoise_mask.copy_(mask_val) + return state diff --git a/ltx2/ltx_core/conditioning/types/reference_video_cond.py b/ltx2/ltx_core/conditioning/types/reference_video_cond.py new file mode 100644 index 0000000000000000000000000000000000000000..9b4068e44fe4677167be8c89a97ae9b2dacd8672 --- /dev/null +++ b/ltx2/ltx_core/conditioning/types/reference_video_cond.py @@ -0,0 +1,91 @@ +"""Reference video conditioning for IC-LoRA inference.""" + +import torch + +from ltx_core.components.patchifiers import get_pixel_coords +from ltx_core.conditioning.item import ConditioningItem +from ltx_core.conditioning.mask_utils import update_attention_mask +from ltx_core.tools import VideoLatentTools +from ltx_core.types import LatentState, VideoLatentShape + + +class VideoConditionByReferenceLatent(ConditioningItem): + """ + Conditions video generation on a reference video latent for IC-LoRA inference. + IC-LoRAs are trained by concatenating reference (control signal) and target tokens, + learning to attend across both. This class replicates that setup at inference by + appending reference tokens to the latent sequence. + IC-LoRAs can be trained with lower-resolution references than the target (e.g., 384px + reference for 768px output) for efficiency and better generalization. The + `downscale_factor` scales reference positions to match target coordinates, preserving + the learned positional relationships. This must match the factor used during training + (stored in LoRA metadata). + To add attention masking, wrap with :class:`ConditioningItemAttentionStrengthWrapper`. + Args: + latent: Reference video latents [B, C, F, H, W] + downscale_factor: Target/reference resolution ratio (e.g., 2 = half-resolution + reference). Spatial positions are scaled by this factor. + strength: Conditioning strength. 1.0 = full (reference kept clean), + 0.0 = none (reference denoised). Default 1.0. + """ + + def __init__( + self, + latent: torch.Tensor, + downscale_factor: int = 1, + strength: float = 1.0, + ): + self.latent = latent + self.downscale_factor = downscale_factor + self.strength = strength + + def apply_to( + self, + latent_state: LatentState, + latent_tools: VideoLatentTools, + ) -> LatentState: + """Append reference video tokens with scaled positions.""" + tokens = latent_tools.patchifier.patchify(self.latent) + + # Compute positions for the reference video's actual dimensions + latent_coords = latent_tools.patchifier.get_patch_grid_bounds( + output_shape=VideoLatentShape.from_torch_shape(self.latent.shape), + device=self.latent.device, + ) + positions = get_pixel_coords( + latent_coords=latent_coords, + scale_factors=latent_tools.scale_factors, + causal_fix=latent_tools.causal_fix, + ) + positions = positions.to(dtype=torch.float32) + positions[:, 0, ...] /= latent_tools.fps + + # Scale spatial positions to match target coordinate space + if self.downscale_factor != 1: + positions[:, 1, ...] *= self.downscale_factor # height axis + positions[:, 2, ...] *= self.downscale_factor # width axis + + denoise_mask = torch.full( + size=(*tokens.shape[:2], 1), + fill_value=1.0 - self.strength, + device=self.latent.device, + dtype=self.latent.dtype, + ) + + new_attention_mask = update_attention_mask( + latent_state=latent_state, + attention_mask=None, + num_noisy_tokens=latent_tools.target_shape.token_count(), + num_new_tokens=tokens.shape[1], + batch_size=tokens.shape[0], + device=self.latent.device, + dtype=self.latent.dtype, + ) + + return LatentState( + latent=torch.cat([latent_state.latent, tokens], dim=1), + denoise_mask=torch.cat([latent_state.denoise_mask, denoise_mask], dim=1), + positions=torch.cat([latent_state.positions, positions], dim=2), + clean_latent=torch.cat([latent_state.clean_latent, tokens], dim=1), + attention_mask=new_attention_mask, + ) diff --git a/ltx2/ltx_core/guidance/__init__.py b/ltx2/ltx_core/guidance/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8214c5c70987d09f7cbc95dfbd9c652a75300fb5 --- /dev/null +++ b/ltx2/ltx_core/guidance/__init__.py @@ -0,0 +1,15 @@ +"""Guidance and perturbation utilities for attention manipulation.""" + +from ltx_core.guidance.perturbations import ( + BatchedPerturbationConfig, + Perturbation, + PerturbationConfig, + PerturbationType, +) + +__all__ = [ + "BatchedPerturbationConfig", + "Perturbation", + "PerturbationConfig", + "PerturbationType", +] diff --git a/ltx2/ltx_core/guidance/perturbations.py b/ltx2/ltx_core/guidance/perturbations.py new file mode 100644 index 0000000000000000000000000000000000000000..40ba4a772121cc1bf1849c7075b46718514da2e7 --- /dev/null +++ b/ltx2/ltx_core/guidance/perturbations.py @@ -0,0 +1,79 @@ +from dataclasses import dataclass +from enum import Enum + +import torch +from torch._prims_common import DeviceLikeType + + +class PerturbationType(Enum): + """Types of attention perturbations for STG (Spatio-Temporal Guidance).""" + + SKIP_A2V_CROSS_ATTN = "skip_a2v_cross_attn" + SKIP_V2A_CROSS_ATTN = "skip_v2a_cross_attn" + SKIP_VIDEO_SELF_ATTN = "skip_video_self_attn" + SKIP_AUDIO_SELF_ATTN = "skip_audio_self_attn" + + +@dataclass(frozen=True) +class Perturbation: + """A single perturbation specifying which attention type to skip and in which blocks.""" + + type: PerturbationType + blocks: list[int] | None # None means all blocks + + def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool: + if self.type != perturbation_type: + return False + + if self.blocks is None: + return True + + return block in self.blocks + + +@dataclass(frozen=True) +class PerturbationConfig: + """Configuration holding a list of perturbations for a single sample.""" + + perturbations: list[Perturbation] | None + + def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool: + if self.perturbations is None: + return False + + return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) + + @staticmethod + def empty() -> "PerturbationConfig": + return PerturbationConfig([]) + + +@dataclass(frozen=True) +class BatchedPerturbationConfig: + """Perturbation configurations for a batch, with utilities for generating attention masks.""" + + perturbations: list[PerturbationConfig] + + def mask( + self, perturbation_type: PerturbationType, block: int, device: DeviceLikeType, dtype: torch.dtype + ) -> torch.Tensor: + mask = torch.ones((len(self.perturbations),), device=device, dtype=dtype) + for batch_idx, perturbation in enumerate(self.perturbations): + if perturbation.is_perturbed(perturbation_type, block): + mask[batch_idx] = 0 + + return mask + + def mask_like(self, perturbation_type: PerturbationType, block: int, values: torch.Tensor) -> torch.Tensor: + mask = self.mask(perturbation_type, block, values.device, values.dtype) + return mask.view(mask.numel(), *([1] * len(values.shape[1:]))) + + def any_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool: + return any(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) + + def all_in_batch(self, perturbation_type: PerturbationType, block: int) -> bool: + return all(perturbation.is_perturbed(perturbation_type, block) for perturbation in self.perturbations) + + @staticmethod + def empty(batch_size: int) -> "BatchedPerturbationConfig": + return BatchedPerturbationConfig([PerturbationConfig.empty() for _ in range(batch_size)]) diff --git a/ltx2/ltx_core/layer_streaming.py b/ltx2/ltx_core/layer_streaming.py new file mode 100644 index 0000000000000000000000000000000000000000..05145afbc4e51dc80ff08f0fcc2029d652ef2ba7 --- /dev/null +++ b/ltx2/ltx_core/layer_streaming.py @@ -0,0 +1,324 @@ +"""Layer streaming wrapper for memory-efficient inference. +Keeps most transformer/decoder layers on CPU pinned memory and streams them +to GPU on demand, using a secondary CUDA stream to prefetch upcoming layers +so that data transfer overlaps with compute. +General-purpose: works with any ``nn.Module`` whose forward iterates over a +``nn.ModuleList`` attribute (e.g. ``transformer_blocks``, ``layers``). +Each layer is evicted back to CPU immediately after its forward completes, +and prefetch uses modular indexing so the last layer's prefetch wraps around +to prepare early layers for the next forward pass. +Example +------- +>>> model = build_my_model(device=torch.device("cpu")) +>>> model = LayerStreamingWrapper( +... model, +... layers_attr="transformer_blocks", +... target_device=torch.device("cuda:0"), +... prefetch_count=2, +... ) +>>> out = model(inputs) # hooks handle layer streaming +>>> model.teardown() # move everything back to CPU +""" + +from __future__ import annotations + +import functools +import itertools +import logging +from typing import Any + +import torch +from torch import nn + +logger = logging.getLogger(__name__) + + +def _resolve_attr(module: nn.Module, dotted_path: str) -> nn.ModuleList: + """Resolve a dotted attribute path like ``'model.language_model.layers'``.""" + obj: Any = module + for part in dotted_path.split("."): + obj = getattr(obj, part) + if not isinstance(obj, nn.ModuleList): + raise TypeError(f"Expected nn.ModuleList at '{dotted_path}', got {type(obj).__name__}") + return obj + + +class _LayerStore: + """Manages on-demand pinning of layer parameters for GPU streaming. + Stores references to each layer's source data (which may be file-backed + mmap views or in-memory tensors). When a layer needs to be transferred + to GPU, its source data is pinned on demand and copied; on eviction the + pinned copy is freed and the source data is restored. + """ + + def __init__(self, layers: nn.ModuleList, target_device: torch.device) -> None: + self.target_device = target_device + self.num_layers = len(layers) + self._on_gpu: set[int] = set() + + # Keep a reference to the source data for each layer so we can pin it + # on demand and restore it after eviction. + self._source_data: list[dict[str, torch.Tensor]] = [] + for layer in layers: + source: dict[str, torch.Tensor] = {} + for name, tensor in itertools.chain(layer.named_parameters(), layer.named_buffers()): + source[name] = tensor.data + self._source_data.append(source) + + # Hold pinned tensors alive until the H2D transfer completes. + # Without this, the CachingHostAllocator can reclaim a pinned tensor + # as soon as its Python reference is dropped, even if an async H2D + # transfer is still reading from it. + self._pinned_in_flight: dict[int, list[torch.Tensor]] = {} + + def _check_idx(self, idx: int) -> None: + if idx < 0 or idx >= self.num_layers: + raise IndexError(f"Layer index {idx} out of range [0, {self.num_layers})") + + def is_on_gpu(self, idx: int) -> bool: + return idx in self._on_gpu + + def move_to_gpu(self, idx: int, layer: nn.Module, *, non_blocking: bool = False) -> None: + """Pin layer *idx* on demand, then transfer to GPU.""" + self._check_idx(idx) + if idx in self._on_gpu: + return + source = self._source_data[idx] + pinned_refs: list[torch.Tensor] = [] + for name, param in itertools.chain(layer.named_parameters(), layer.named_buffers()): + pinned = source[name].pin_memory() + param.data = pinned.to(self.target_device, non_blocking=non_blocking) + pinned_refs.append(pinned) + # Keep pinned tensors alive until eviction β€” the async H2D transfer + # may still be reading from them. + self._pinned_in_flight[idx] = pinned_refs + self._on_gpu.add(idx) + + def evict_to_cpu(self, idx: int, layer: nn.Module) -> None: + """Restore source data, freeing the GPU and pinned copies.""" + self._check_idx(idx) + if idx not in self._on_gpu: + return + source = self._source_data[idx] + for name, param in itertools.chain(layer.named_parameters(), layer.named_buffers()): + param.data = source[name] + # Release pinned tensors β€” the H2D transfer is complete by now + # (the compute stream waited on the prefetch event before using + # the layer, and we only evict after compute finishes). + self._pinned_in_flight.pop(idx, None) + self._on_gpu.discard(idx) + + def cleanup(self) -> None: + """Release all source data and in-flight pinned references. + After this call, the source tensors can be garbage-collected once + the layer parameters (which still reference them via ``.data``) are + also released (e.g. via ``.to("meta")``). + """ + for source_dict in self._source_data: + source_dict.clear() + self._source_data.clear() + self._pinned_in_flight.clear() + + +class _AsyncPrefetcher: + """Issues H2D transfers on a dedicated CUDA stream. + Uses per-layer CUDA events so that the compute stream only waits for the + specific layer it needs, not all pending transfers. + """ + + def __init__(self, store: _LayerStore, layers: nn.ModuleList) -> None: + self._store = store + self._layers = layers + self._stream = torch.cuda.Stream(device=store.target_device) + self._events: dict[int, torch.cuda.Event] = {} + + def prefetch(self, idx: int) -> None: + """Begin async transfer of layer *idx* to GPU (no-op if already there).""" + if self._store.is_on_gpu(idx) or idx in self._events: + return + with torch.cuda.stream(self._stream): + self._store.move_to_gpu(idx, self._layers[idx], non_blocking=True) + event = torch.cuda.Event() + event.record(self._stream) + self._events[idx] = event + + def wait(self, idx: int) -> None: + """Block the compute stream until layer *idx* transfer is complete.""" + event = self._events.pop(idx, None) + if event is not None: + torch.cuda.current_stream(self._store.target_device).wait_event(event) + + def cleanup(self) -> None: + """Drain pending work and release CUDA stream/event resources.""" + self._events.clear() + self._stream = None + self._layers = None + self._store = None + + +class LayerStreamingWrapper(nn.Module): + """Wraps a model to stream its sequential layers between CPU and GPU. + Each layer is evicted immediately after its forward completes, and + prefetch wraps around using modular indexing so the end of one forward + pass prepares early layers for the next. + Parameters + ---------- + model: + The model to wrap, with all parameters on **CPU**. + layers_attr: + Dotted attribute path to the ``nn.ModuleList`` of sequential layers + (e.g. ``"transformer_blocks"`` or ``"model.language_model.layers"``). + target_device: + The GPU device to use for compute. + prefetch_count: + How many layers ahead to prefetch. The maximum number of layers on + GPU at once is ``1 + prefetch_count``. Must be >= 1. + """ + + def __init__( + self, + model: nn.Module, + layers_attr: str, + target_device: torch.device, + prefetch_count: int = 2, + ) -> None: + if prefetch_count < 1: + raise ValueError("prefetch_count must be >= 1") + super().__init__() + # Store the wrapped model as a submodule so parameters are discoverable. + self._model = model + self._layers = _resolve_attr(model, layers_attr) + self._target_device = target_device + # Clamp: no point prefetching more than num_layers - 1 (the rest are evicted). + self._prefetch_count = min(prefetch_count, len(self._layers) - 1) + self._hooks: list[torch.utils.hooks.RemovableHandle] = [] + + self._setup() + + # ------------------------------------------------------------------ + # Setup / teardown + # ------------------------------------------------------------------ + + def _setup(self) -> None: + # 1. Build the pinned CPU store (copies all layer tensors to pinned memory). + self._store = _LayerStore(self._layers, self._target_device) + + # 2. Move all NON-layer params/buffers to GPU. + layer_tensor_ids: set[int] = set() + for layer in self._layers: + for t in itertools.chain(layer.parameters(), layer.buffers()): + layer_tensor_ids.add(id(t)) + + for p in self._model.parameters(): + if id(p) not in layer_tensor_ids: + p.data = p.data.to(self._target_device) + for b in self._model.buffers(): + if id(b) not in layer_tensor_ids: + b.data = b.data.to(self._target_device) + + # 3. Pre-load the first (1 + prefetch_count) layers synchronously. + for idx in range(min(self._prefetch_count + 1, len(self._layers))): + self._store.move_to_gpu(idx, self._layers[idx]) + + # 4. Create the async prefetcher and register hooks. + self._prefetcher = _AsyncPrefetcher(self._store, self._layers) + self._register_hooks() + + def _register_hooks(self) -> None: + idx_map: dict[int, int] = {id(layer): idx for idx, layer in enumerate(self._layers)} + num_layers = len(self._layers) + + compute_stream = torch.cuda.current_stream(self._target_device) + + def _pre_hook( + module: nn.Module, + _args: Any, # noqa: ANN401 + *, + idx: int, + ) -> None: + # Wait only for THIS layer's H2D transfer (not all pending ones). + self._prefetcher.wait(idx) + if not self._store.is_on_gpu(idx): + self._store.move_to_gpu(idx, module) + + # Record that the compute stream will read these weight tensors. + # They were allocated on the prefetch stream, so without this the + # caching allocator would allow the prefetch stream to reuse their + # memory immediately after eviction β€” even if the compute kernel + # that reads them hasn't finished yet. + for param in itertools.chain(module.parameters(), module.buffers()): + param.data.record_stream(compute_stream) + + # Kick off prefetch for upcoming layers (wraps around for next pass). + for offset in range(1, self._prefetch_count + 1): + self._prefetcher.prefetch((idx + offset) % num_layers) + + def _post_hook( + module: nn.Module, + _args: Any, # noqa: ANN401 + _output: Any, # noqa: ANN401 + *, + idx: int, + ) -> None: + # Evict this layer immediately β€” its computation is done. + self._store.evict_to_cpu(idx, module) + + for layer in self._layers: + idx = idx_map[id(layer)] + h1 = layer.register_forward_pre_hook(functools.partial(_pre_hook, idx=idx)) + h2 = layer.register_forward_hook(functools.partial(_post_hook, idx=idx)) + self._hooks.extend([h1, h2]) + + def teardown(self) -> None: + """Remove hooks, release resources, and move parameters back to CPU. + After this call the wrapper is inert: hooks are removed, the prefetch + stream is drained and destroyed, all parameters reside on CPU, and the + ``_LayerStore`` source data references are cleared. Callers should + still follow up with ``.to("meta")`` to release the CPU copies if the + model is no longer needed. + """ + for h in self._hooks: + h.remove() + self._hooks.clear() + + # Drain all in-flight async H2D copies, then release stream resources. + # Without the synchronize, clearing the stream/events can trigger + # use-after-free at the CUDA driver level. + torch.cuda.synchronize(device=self._target_device) + if self._prefetcher is not None: + self._prefetcher.cleanup() + self._prefetcher = None + + # Move everything to CPU. + for idx, layer in enumerate(self._layers): + self._store.evict_to_cpu(idx, layer) + + for p in self._model.parameters(): + p.data = p.data.to("cpu") + for b in self._model.buffers(): + b.data = b.data.to("cpu") + + # Release source data references. After evict_to_cpu() the layer + # params point to the source data. The caller is expected to follow + # up with .to("meta") to drop the param refs; cleanup() drops the + # store's refs. + self._store.cleanup() + + # ------------------------------------------------------------------ + # Forward and attribute delegation + # ------------------------------------------------------------------ + + def forward(self, *args: Any, **kwargs: Any) -> Any: # noqa: ANN401 + return self._model(*args, **kwargs) + + def __getattr__(self, name: str) -> Any: # noqa: ANN401 + """Proxy attribute access to the wrapped model. + This allows calling methods like ``encode()`` on a wrapped + GemmaTextEncoder without the caller needing to know about the wrapper. + ``nn.Module.__getattr__`` is only called when normal attribute lookup + fails, so ``_model``, ``_store``, etc. are found first via ``__dict__``. + """ + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self._model, name) diff --git a/ltx2/ltx_core/loader/__init__.py b/ltx2/ltx_core/loader/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3aaa01e6a70e6cf688ccd4148702ee3ac6109813 --- /dev/null +++ b/ltx2/ltx_core/loader/__init__.py @@ -0,0 +1,48 @@ +"""Loader utilities for model weights, LoRAs, and safetensor operations.""" + +from ltx_core.loader.fuse_loras import apply_loras +from ltx_core.loader.module_ops import ModuleOps +from ltx_core.loader.primitives import ( + LoRAAdaptableProtocol, + LoraPathStrengthAndSDOps, + LoraStateDictWithStrength, + ModelBuilderProtocol, + StateDict, + StateDictLoader, +) +from ltx_core.loader.registry import DummyRegistry, Registry, StateDictRegistry +from ltx_core.loader.sd_ops import ( + LTXV_LORA_COMFY_RENAMING_MAP, + ContentMatching, + ContentReplacement, + KeyValueOperation, + KeyValueOperationResult, + SDKeyValueOperation, + SDOps, +) +from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader, SafetensorsStateDictLoader +from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder + +__all__ = [ + "LTXV_LORA_COMFY_RENAMING_MAP", + "ContentMatching", + "ContentReplacement", + "DummyRegistry", + "KeyValueOperation", + "KeyValueOperationResult", + "LoRAAdaptableProtocol", + "LoraPathStrengthAndSDOps", + "LoraStateDictWithStrength", + "ModelBuilderProtocol", + "ModuleOps", + "Registry", + "SDKeyValueOperation", + "SDOps", + "SafetensorsModelStateDictLoader", + "SafetensorsStateDictLoader", + "SingleGPUModelBuilder", + "StateDict", + "StateDictLoader", + "StateDictRegistry", + "apply_loras", +] diff --git a/ltx2/ltx_core/loader/fuse_loras.py b/ltx2/ltx_core/loader/fuse_loras.py new file mode 100644 index 0000000000000000000000000000000000000000..00eecf64424d79e2cd6a3ee91c961787f33ee348 --- /dev/null +++ b/ltx2/ltx_core/loader/fuse_loras.py @@ -0,0 +1,133 @@ +from collections.abc import Iterator + +import torch + +from ltx_core.loader.primitives import LoraStateDictWithStrength, StateDict +from ltx_core.quantization.fp8_cast import _fused_add_round_launch +from ltx_core.quantization.fp8_scaled_mm import quantize_weight_to_fp8_per_tensor + + +def _get_device() -> torch.device: + if torch.cuda.is_available(): + return torch.device("cuda", torch.cuda.current_device()) + return torch.device("cpu") + + +def fuse_lora_weights( + model_sd: StateDict, + lora_sd_and_strengths: list[LoraStateDictWithStrength], + dtype: torch.dtype | None = None, +) -> Iterator[tuple[str, torch.Tensor]]: + """Yield ``(key, fused_tensor)`` for each weight modified by at least one LoRA. + For scaled-FP8 weights, this includes both the updated ``.weight`` tensor + and its corresponding ``.weight_scale`` tensor. + """ + for key, original_weight in model_sd.sd.items(): + if original_weight is None or key.endswith(".weight_scale"): + continue + original_device = original_weight.device + weight = original_weight.to(device=_get_device()) + target_dtype = dtype if dtype is not None else weight.dtype + deltas_dtype = target_dtype if target_dtype not in [torch.float8_e4m3fn, torch.float8_e5m2] else torch.bfloat16 + + deltas = _prepare_deltas(lora_sd_and_strengths, key, deltas_dtype, weight.device) + if deltas is None: + continue + + scale_key = key.replace(".weight", ".weight_scale") if key.endswith(".weight") else None + is_scaled_fp8 = scale_key is not None and scale_key in model_sd.sd + + if weight.dtype == torch.float8_e4m3fn: + if is_scaled_fp8: + fused = _fuse_delta_with_scaled_fp8(deltas, weight, key, scale_key, model_sd) + else: + fused = _fuse_delta_with_cast_fp8(deltas, weight, key, target_dtype) + elif weight.dtype == torch.bfloat16: + fused = _fuse_delta_with_bfloat16(deltas, weight, key, target_dtype) + else: + raise ValueError(f"Unsupported dtype: {weight.dtype}") + + for k, v in fused.items(): + yield k, v.to(device=original_device) + + +def apply_loras( + model_sd: StateDict, + lora_sd_and_strengths: list[LoraStateDictWithStrength], + dtype: torch.dtype | None = None, + destination_sd: StateDict | None = None, +) -> StateDict: + if destination_sd is not None: + sd = destination_sd.sd + for key, tensor in fuse_lora_weights(model_sd, lora_sd_and_strengths, dtype): + sd[key] = tensor + return destination_sd + + fused = dict(fuse_lora_weights(model_sd, lora_sd_and_strengths, dtype)) + sd = {k: (fused[k] if k in fused else v.clone()) for k, v in model_sd.sd.items()} + return StateDict(sd, model_sd.device, model_sd.size, model_sd.dtype) + + +def _prepare_deltas( + lora_sd_and_strengths: list[LoraStateDictWithStrength], key: str, dtype: torch.dtype, device: torch.device +) -> torch.Tensor | None: + deltas = [] + prefix = key[: -len(".weight")] + key_a = f"{prefix}.lora_A.weight" + key_b = f"{prefix}.lora_B.weight" + for lsd, coef in lora_sd_and_strengths: + if key_a not in lsd.sd or key_b not in lsd.sd: + continue + a = lsd.sd[key_a].to(device=device) + b = lsd.sd[key_b].to(device=device) + product = torch.matmul(b * coef, a) + del a, b + deltas.append(product.to(dtype=dtype)) + if len(deltas) == 0: + return None + elif len(deltas) == 1: + return deltas[0] + return torch.sum(torch.stack(deltas, dim=0), dim=0) + + +def _fuse_delta_with_scaled_fp8( + deltas: torch.Tensor, + weight: torch.Tensor, + key: str, + scale_key: str, + model_sd: StateDict, +) -> dict[str, torch.Tensor]: + """Dequantize scaled FP8 weight, add LoRA delta, and re-quantize.""" + weight_scale = model_sd.sd[scale_key] + + original_weight = weight.t().to(torch.float32) * weight_scale + + new_weight = original_weight + deltas.to(torch.float32) + + new_fp8_weight, new_weight_scale = quantize_weight_to_fp8_per_tensor(new_weight) + return {key: new_fp8_weight, scale_key: new_weight_scale} + + +def _fuse_delta_with_cast_fp8( + deltas: torch.Tensor, + weight: torch.Tensor, + key: str, + target_dtype: torch.dtype, +) -> dict[str, torch.Tensor]: + """Fuse LoRA delta with cast-only FP8 weight (no scale factor).""" + if str(weight.device).startswith("cuda"): + _fused_add_round_launch(deltas, weight, seed=0) + else: + deltas.add_(weight.to(dtype=deltas.dtype)) + return {key: deltas.to(dtype=target_dtype)} + + +def _fuse_delta_with_bfloat16( + deltas: torch.Tensor, + weight: torch.Tensor, + key: str, + target_dtype: torch.dtype, +) -> dict[str, torch.Tensor]: + """Fuse LoRA delta with bfloat16 weight.""" + deltas.add_(weight) + return {key: deltas.to(dtype=target_dtype)} diff --git a/ltx2/ltx_core/loader/kernels.py b/ltx2/ltx_core/loader/kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..ee4cefbe6cab9bb7633fa69604b8b9ea1e7891f4 --- /dev/null +++ b/ltx2/ltx_core/loader/kernels.py @@ -0,0 +1,72 @@ +# ruff: noqa: ANN001, ANN201, ERA001, N803, N806 +import triton +import triton.language as tl + + +@triton.jit +def fused_add_round_kernel( + x_ptr, + output_ptr, # contents will be added to the output + seed, + n_elements, + EXPONENT_BIAS, + MANTISSA_BITS, + BLOCK_SIZE: tl.constexpr, +): + """ + A kernel to upcast 8bit quantized weights to bfloat16 with stochastic rounding + and add them to bfloat16 output weights. Might be used to upcast original model weights + and to further add them to precalculated deltas coming from LoRAs. + """ + # Get program ID and compute offsets + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + # Load data + x = tl.load(x_ptr + offsets, mask=mask) + rand_vals = tl.rand(seed, offsets) - 0.5 + + x = tl.cast(x, tl.float16) + delta = tl.load(output_ptr + offsets, mask=mask) + delta = tl.cast(delta, tl.float16) + x = x + delta + + x_bits = tl.cast(x, tl.int16, bitcast=True) + + # Calculate the exponent. Unbiased fp16 exponent is ((x_bits & 0x7C00) >> 10) - 15 for + # normal numbers and -14 for subnormals. + fp16_exponent_bits = (x_bits & 0x7C00) >> 10 + fp16_normals = fp16_exponent_bits > 0 + fp16_exponent = tl.where(fp16_normals, fp16_exponent_bits - 15, -14) + + # Add the target dtype's exponent bias and clamp to the target dtype's exponent range. + exponent = fp16_exponent + EXPONENT_BIAS + MAX_EXPONENT = 2 * EXPONENT_BIAS + 1 + exponent = tl.where(exponent > MAX_EXPONENT, MAX_EXPONENT, exponent) + exponent = tl.where(exponent < 0, 0, exponent) + + # Normal ULP exponent, expressed as an fp16 exponent field: + # (exponent - EXPONENT_BIAS - MANTISSA_BITS) + 15 + # Simplifies to: fp16_exponent - MANTISSA_BITS + 15 + # See https://en.wikipedia.org/wiki/Unit_in_the_last_place + eps_exp = tl.maximum(0, tl.minimum(31, exponent - EXPONENT_BIAS - MANTISSA_BITS + 15)) + + # Calculate epsilon in the target dtype + eps_normal = tl.cast(tl.cast(eps_exp << 10, tl.int16), tl.float16, bitcast=True) + + # Subnormal ULP: 2^(1 - EXPONENT_BIAS - MANTISSA_BITS) -> + # fp16 exponent bits: (1 - EXPONENT_BIAS - MANTISSA_BITS) + 15 = + # 16 - EXPONENT_BIAS - MANTISSA_BITS + eps_subnormal = tl.cast(tl.cast((16 - EXPONENT_BIAS - MANTISSA_BITS) << 10, tl.int16), tl.float16, bitcast=True) + eps = tl.where(exponent > 0, eps_normal, eps_subnormal) + + # Apply zero mask to epsilon + eps = tl.where(x == 0, 0.0, eps) + + # Apply stochastic rounding + output = tl.cast(x + rand_vals * eps, tl.bfloat16) + + # Store the result + tl.store(output_ptr + offsets, output, mask=mask) diff --git a/ltx2/ltx_core/loader/module_ops.py b/ltx2/ltx_core/loader/module_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..0c3ed2f2bad4de5102db03071042e33fc3aaa554 --- /dev/null +++ b/ltx2/ltx_core/loader/module_ops.py @@ -0,0 +1,14 @@ +from typing import Callable, NamedTuple + +import torch + + +class ModuleOps(NamedTuple): + """ + Defines a named operation for matching and mutating PyTorch modules. + Used to selectively transform modules in a model (e.g., replacing layers with quantized versions). + """ + + name: str + matcher: Callable[[torch.nn.Module], bool] + mutator: Callable[[torch.nn.Module], torch.nn.Module] diff --git a/ltx2/ltx_core/loader/primitives.py b/ltx2/ltx_core/loader/primitives.py new file mode 100644 index 0000000000000000000000000000000000000000..8a918a1bef5cb678851eacade3747c915e19f240 --- /dev/null +++ b/ltx2/ltx_core/loader/primitives.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, NamedTuple, Protocol + +import torch + +from ltx_core.loader.module_ops import ModuleOps +from ltx_core.loader.sd_ops import SDOps +from ltx_core.model.model_protocol import ModelType + +if TYPE_CHECKING: + from ltx_core.loader.registry import Registry + + +@dataclass(frozen=True) +class StateDict: + """ + Immutable container for a PyTorch state dictionary. + Contains: + - sd: Dictionary of tensors (weights, buffers, etc.) + - device: Device where tensors are stored + - size: Total memory footprint in bytes + - dtype: Set of tensor dtypes present + """ + + sd: dict + device: torch.device + size: int + dtype: set[torch.dtype] + + def footprint(self) -> tuple[int, torch.device]: + return self.size, self.device + + +class StateDictLoader(Protocol): + """ + Protocol for loading state dictionaries from various sources. + Implementations must provide: + - metadata: Extract model metadata from a single path + - load: Load state dict from path(s) and apply SDOps transformations + """ + + def metadata(self, path: str) -> dict: + """ + Load metadata from path + """ + + def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict: + """ + Load state dict from path or paths (for sharded model storage) and apply sd_ops + """ + + +class ModelBuilderProtocol(Protocol[ModelType]): + """ + Protocol for building PyTorch models from configuration dictionaries. + Implementations must provide: + - meta_model: Create a model from configuration dictionary and apply module operations + - build: Create and initialize a model from state dictionary and apply dtype transformations + """ + + model_sd_ops: SDOps | None + module_ops: tuple[ModuleOps, ...] + loras: tuple["LoraPathStrengthAndSDOps", ...] + registry: "Registry" + + def meta_model(self, config: dict, module_ops: list[ModuleOps] | None = None) -> ModelType: + """ + Create a model on the meta device from a configuration dictionary. + This decouples model creation from weight loading, allowing the model + architecture to be instantiated without allocating memory for parameters. + Args: + config: Model configuration dictionary. + module_ops: Optional list of module operations to apply (e.g., quantization). + Returns: + Model instance on meta device (no actual memory allocated for parameters). + """ + ... + + def with_sd_ops(self, sd_ops: SDOps | None) -> "ModelBuilderProtocol[ModelType]": + """Return a copy of this builder with the given state-dict key remapping ops.""" + ... + + def with_module_ops(self, module_ops: tuple[ModuleOps, ...]) -> "ModelBuilderProtocol[ModelType]": + """Return a copy of this builder with the given module operations (e.g. quantization).""" + ... + + def with_loras(self, loras: tuple["LoraPathStrengthAndSDOps", ...]) -> "ModelBuilderProtocol[ModelType]": + """Return a copy of this builder with the given LoRAs to fuse at build time.""" + ... + + def with_registry(self, registry: "Registry") -> "ModelBuilderProtocol[ModelType]": + """Return a copy of this builder using the given weight registry for allocation.""" + ... + + def with_lora_load_device(self, device: torch.device) -> "ModelBuilderProtocol[ModelType]": + """Return a copy of this builder that loads LoRA weights onto the given device.""" + ... + + def build( + self, device: torch.device | None = None, dtype: torch.dtype | None = None, **kwargs: object + ) -> ModelType: + """ + Build the model + Args: + device: Target device for the model + dtype: Target dtype for the model, if None, uses the dtype of the model_path model + Returns: + Model instance + """ + ... + + def model_config(self) -> dict: + """Return the model configuration dictionary extracted from the checkpoint metadata.""" + ... + + +class LoRAAdaptableProtocol(Protocol): + """ + Protocol for models that can be adapted with LoRAs. + Implementations must provide: + - lora: Add a LoRA to the model + """ + + def lora(self, lora_path: str, strength: float) -> "LoRAAdaptableProtocol": + pass + + +class LoraPathStrengthAndSDOps(NamedTuple): + """ + Tuple containing a LoRA path, strength, and SDOps for applying to the LoRA state dict. + """ + + path: str + strength: float + sd_ops: SDOps + + +class LoraStateDictWithStrength(NamedTuple): + """ + Tuple containing a LoRA state dict and strength for applying to the model. + """ + + state_dict: StateDict + strength: float diff --git a/ltx2/ltx_core/loader/registry.py b/ltx2/ltx_core/loader/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..9fd5f9dc159d079dcb7e1f189f618d7c24a07f8b --- /dev/null +++ b/ltx2/ltx_core/loader/registry.py @@ -0,0 +1,84 @@ +import hashlib +import threading +from dataclasses import dataclass, field +from pathlib import Path +from typing import Protocol + +from ltx_core.loader.primitives import StateDict +from ltx_core.loader.sd_ops import SDOps + + +class Registry(Protocol): + """ + Protocol for managing state dictionaries in a registry. + It is used to store state dictionaries and reuse them later without loading them again. + Implementations must provide: + - add: Add a state dictionary to the registry + - pop: Remove a state dictionary from the registry + - get: Retrieve a state dictionary from the registry + - clear: Clear all state dictionaries from the registry + """ + + def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None: ... + + def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: ... + + def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: ... + + def clear(self) -> None: ... + + +class DummyRegistry(Registry): + """ + Dummy registry that does not store state dictionaries. + """ + + def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> None: + pass + + def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: + pass + + def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: + pass + + def clear(self) -> None: + pass + + +@dataclass +class StateDictRegistry(Registry): + """ + Registry that stores state dictionaries in a dictionary. + """ + + _state_dicts: dict[str, StateDict] = field(default_factory=dict) + _lock: threading.Lock = field(default_factory=threading.Lock) + + def _generate_id(self, paths: list[str], sd_ops: SDOps) -> str: + m = hashlib.sha256() + parts = [str(Path(p).resolve()) for p in paths] + if sd_ops is not None: + parts.append(sd_ops.name) + m.update("\0".join(parts).encode("utf-8")) + return m.hexdigest() + + def add(self, paths: list[str], sd_ops: SDOps | None, state_dict: StateDict) -> str: + sd_id = self._generate_id(paths, sd_ops) + with self._lock: + if sd_id in self._state_dicts: + raise ValueError(f"State dict retrieved from {paths} with {sd_ops} already added, check with get first") + self._state_dicts[sd_id] = state_dict + return sd_id + + def pop(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: + with self._lock: + return self._state_dicts.pop(self._generate_id(paths, sd_ops), None) + + def get(self, paths: list[str], sd_ops: SDOps | None) -> StateDict | None: + with self._lock: + return self._state_dicts.get(self._generate_id(paths, sd_ops), None) + + def clear(self) -> None: + with self._lock: + self._state_dicts.clear() diff --git a/ltx2/ltx_core/loader/sd_ops.py b/ltx2/ltx_core/loader/sd_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c61a794cea15c09d8474d0a6eccb27b0003fc8f8 --- /dev/null +++ b/ltx2/ltx_core/loader/sd_ops.py @@ -0,0 +1,139 @@ +from dataclasses import dataclass, replace +from typing import NamedTuple, Protocol + +import torch + + +@dataclass(frozen=True, slots=True) +class ContentReplacement: + """ + Represents a content replacement operation. + Used to replace a specific content with a replacement in a state dict key. + """ + + content: str + replacement: str + + +@dataclass(frozen=True, slots=True) +class ContentMatching: + """ + Represents a content matching operation. + Used to match a specific prefix and suffix in a state dict key. + """ + + prefix: str = "" + suffix: str = "" + + +class KeyValueOperationResult(NamedTuple): + """ + Represents the result of a key-value operation. + Contains the new key and value after the operation has been applied. + """ + + new_key: str + new_value: torch.Tensor + + +class KeyValueOperation(Protocol): + """ + Protocol for key-value operations. + Used to apply operations to a specific key and value in a state dict. + """ + + def __call__(self, tensor_key: str, tensor_value: torch.Tensor) -> list[KeyValueOperationResult]: ... + + +@dataclass(frozen=True, slots=True) +class SDKeyValueOperation: + """ + Represents a key-value operation. + Used to apply operations to a specific key and value in a state dict. + """ + + key_matcher: ContentMatching + kv_operation: KeyValueOperation + + +@dataclass(frozen=True, slots=True) +class SDOps: + """Immutable class representing state dict key operations.""" + + name: str + mapping: tuple[ + ContentReplacement | ContentMatching | SDKeyValueOperation, ... + ] = () # Immutable tuple of (key, value) pairs + allowed_keys: frozenset[str] | None = None + + def with_replacement(self, content: str, replacement: str) -> "SDOps": + """Create a new SDOps instance with the specified replacement added to the mapping.""" + + new_mapping = (*self.mapping, ContentReplacement(content, replacement)) + return replace(self, mapping=new_mapping) + + def with_matching(self, prefix: str = "", suffix: str = "") -> "SDOps": + """Create a new SDOps instance with the specified prefix and suffix matching added to the mapping.""" + + new_mapping = (*self.mapping, ContentMatching(prefix, suffix)) + return replace(self, mapping=new_mapping) + + def with_additional_allowed_keys(self, keys: frozenset[str]) -> "SDOps": + """Create a new SDOps instance that only passes keys present in *keys* (post-replacement). + If allowed_keys already exists, the sets are merged via union. + """ + merged = frozenset(keys) | self.allowed_keys if self.allowed_keys is not None else frozenset(keys) + return replace(self, allowed_keys=merged) + + def with_kv_operation( + self, + operation: KeyValueOperation, + key_prefix: str = "", + key_suffix: str = "", + ) -> "SDOps": + """Create a new SDOps instance with the specified value operation added to the mapping.""" + key_matcher = ContentMatching(key_prefix, key_suffix) + sd_kv_operation = SDKeyValueOperation(key_matcher, operation) + new_mapping = (*self.mapping, sd_kv_operation) + return replace(self, mapping=new_mapping) + + def apply_to_key(self, key: str) -> str | None: + """Apply the mapping to the given name.""" + matchers = [content for content in self.mapping if isinstance(content, ContentMatching)] + valid = any(key.startswith(f.prefix) and key.endswith(f.suffix) for f in matchers) + if not valid: + return None + + for replacement in self.mapping: + if not isinstance(replacement, ContentReplacement): + continue + if replacement.content in key: + key = key.replace(replacement.content, replacement.replacement) + + if self.allowed_keys is not None and key not in self.allowed_keys: + return None + + return key + + def apply_to_key_value(self, key: str, value: torch.Tensor) -> list[KeyValueOperationResult]: + """Apply the value operation to the given name and associated value.""" + for operation in self.mapping: + if not isinstance(operation, SDKeyValueOperation): + continue + if key.startswith(operation.key_matcher.prefix) and key.endswith(operation.key_matcher.suffix): + return operation.kv_operation(key, value) + return [KeyValueOperationResult(key, value)] + + +# Predefined SDOps instances +LTXV_LORA_COMFY_RENAMING_MAP = ( + SDOps("LTXV_LORA_COMFY_PREFIX_MAP").with_matching().with_replacement("diffusion_model.", "") +) + +LTXV_LORA_COMFY_TARGET_MAP = ( + SDOps("LTXV_LORA_COMFY_TARGET_MAP") + .with_matching() + .with_replacement("diffusion_model.", "") + .with_replacement(".lora_A.weight", ".weight") + .with_replacement(".lora_B.weight", ".weight") +) diff --git a/ltx2/ltx_core/loader/sft_loader.py b/ltx2/ltx_core/loader/sft_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..859da8098ff018e6bfa8ab1d3299a4431e604568 --- /dev/null +++ b/ltx2/ltx_core/loader/sft_loader.py @@ -0,0 +1,66 @@ +import json + +import safetensors +import torch + +from ltx_core.loader.primitives import StateDict, StateDictLoader +from ltx_core.loader.sd_ops import SDOps + + +class SafetensorsStateDictLoader(StateDictLoader): + """ + Loads weights from safetensors files without metadata support. + Use this for loading raw weight files. For model files that include + configuration metadata, use SafetensorsModelStateDictLoader instead. + """ + + def metadata(self, path: str) -> dict: + raise NotImplementedError("Not implemented") + + def load(self, path: str | list[str], sd_ops: SDOps, device: torch.device | None = None) -> StateDict: + """ + Load state dict from path or paths (for sharded model storage) and apply sd_ops + """ + sd = {} + size = 0 + dtype = set() + device = device or torch.device("cpu") + model_paths = path if isinstance(path, list) else [path] + for shard_path in model_paths: + with safetensors.safe_open(shard_path, framework="pt", device=str(device)) as f: + safetensor_keys = f.keys() + for name in safetensor_keys: + expected_name = name if sd_ops is None else sd_ops.apply_to_key(name) + if expected_name is None: + continue + value = f.get_tensor(name).to(device=device, non_blocking=True, copy=False) + key_value_pairs = ((expected_name, value),) + if sd_ops is not None: + key_value_pairs = sd_ops.apply_to_key_value(expected_name, value) + for key, value in key_value_pairs: + size += value.nbytes + dtype.add(value.dtype) + sd[key] = value + + return StateDict(sd=sd, device=device, size=size, dtype=dtype) + + +class SafetensorsModelStateDictLoader(StateDictLoader): + """ + Loads weights and configuration metadata from safetensors model files. + Unlike SafetensorsStateDictLoader, this loader can read model configuration + from the safetensors file metadata via the metadata() method. + """ + + def __init__(self, weight_loader: SafetensorsStateDictLoader | None = None): + self.weight_loader = weight_loader if weight_loader is not None else SafetensorsStateDictLoader() + + def metadata(self, path: str) -> dict: + with safetensors.safe_open(path, framework="pt") as f: + meta = f.metadata() + if meta is None or "config" not in meta: + return {} + return json.loads(meta["config"]) + + def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict: + return self.weight_loader.load(path, sd_ops, device) diff --git a/ltx2/ltx_core/loader/single_gpu_model_builder.py b/ltx2/ltx_core/loader/single_gpu_model_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..d439d3d7ee5cfcd9b1789e99d8a71f8131ffe690 --- /dev/null +++ b/ltx2/ltx_core/loader/single_gpu_model_builder.py @@ -0,0 +1,136 @@ +import logging +from dataclasses import dataclass, field, replace +from typing import Generic + +import torch + +from ltx_core.loader.fuse_loras import apply_loras +from ltx_core.loader.module_ops import ModuleOps +from ltx_core.loader.primitives import ( + LoRAAdaptableProtocol, + LoraPathStrengthAndSDOps, + LoraStateDictWithStrength, + ModelBuilderProtocol, + StateDict, + StateDictLoader, +) +from ltx_core.loader.registry import DummyRegistry, Registry +from ltx_core.loader.sd_ops import SDOps +from ltx_core.loader.sft_loader import SafetensorsModelStateDictLoader +from ltx_core.model.model_protocol import ModelConfigurator, ModelType + +logger: logging.Logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class SingleGPUModelBuilder(Generic[ModelType], ModelBuilderProtocol[ModelType], LoRAAdaptableProtocol): + """ + Builder for PyTorch models residing on a single GPU. + Attributes: + model_class_configurator: Class responsible for constructing the model from a config dict. + model_path: Path (or tuple of shard paths) to the model's `.safetensors` checkpoint(s). + model_sd_ops: Optional state-dict operations applied when loading the model weights. + module_ops: Sequence of module-level mutations applied to the meta model before weight loading. + loras: Sequence of LoRA adapters (path, strength, optional sd_ops) to fuse into the model. + model_loader: Strategy for loading state dicts from disk. Defaults to + :class:`SafetensorsModelStateDictLoader`. + registry: Cache for already-loaded state dicts. Defaults to :class:`DummyRegistry` (no caching). + lora_load_device: Device used when loading LoRA weight tensors from disk. Defaults to + ``torch.device("cpu")``, which keeps LoRA weights in CPU memory and transfers them to + the target GPU sequentially during fusion, reducing peak GPU memory usage compared to + loading all LoRA weights directly onto the GPU at once. + """ + + model_class_configurator: type[ModelConfigurator[ModelType]] + model_path: str | tuple[str, ...] + model_sd_ops: SDOps | None = None + module_ops: tuple[ModuleOps, ...] = field(default_factory=tuple) + loras: tuple[LoraPathStrengthAndSDOps, ...] = field(default_factory=tuple) + model_loader: StateDictLoader = field(default_factory=SafetensorsModelStateDictLoader) + registry: Registry = field(default_factory=DummyRegistry) + lora_load_device: torch.device = field(default_factory=lambda: torch.device("cpu")) + + def lora(self, lora_path: str, strength: float = 1.0, sd_ops: SDOps | None = None) -> "SingleGPUModelBuilder": + return replace(self, loras=(*self.loras, LoraPathStrengthAndSDOps(lora_path, strength, sd_ops))) + + def with_sd_ops(self, sd_ops: SDOps | None) -> "SingleGPUModelBuilder": + return replace(self, model_sd_ops=sd_ops) + + def with_module_ops(self, module_ops: tuple[ModuleOps, ...]) -> "SingleGPUModelBuilder": + return replace(self, module_ops=module_ops) + + def with_loras(self, loras: tuple[LoraPathStrengthAndSDOps, ...]) -> "SingleGPUModelBuilder": + return replace(self, loras=loras) + + def with_registry(self, registry: Registry) -> "SingleGPUModelBuilder": + return replace(self, registry=registry) + + def with_lora_load_device(self, device: torch.device) -> "SingleGPUModelBuilder": + return replace(self, lora_load_device=device) + + def model_config(self) -> dict: + first_shard_path = self.model_path[0] if isinstance(self.model_path, tuple) else self.model_path + return self.model_loader.metadata(first_shard_path) + + def meta_model(self, config: dict, module_ops: tuple[ModuleOps, ...]) -> ModelType: + with torch.device("meta"): + model = self.model_class_configurator.from_config(config) + for module_op in module_ops: + if module_op.matcher(model): + model = module_op.mutator(model) + return model + + def load_sd( + self, paths: list[str], registry: Registry, device: torch.device | None, sd_ops: SDOps | None = None + ) -> StateDict: + state_dict = registry.get(paths, sd_ops) + if state_dict is None: + state_dict = self.model_loader.load(paths, sd_ops=sd_ops, device=device) + registry.add(paths, sd_ops=sd_ops, state_dict=state_dict) + return state_dict + + def _return_model(self, meta_model: ModelType, device: torch.device) -> ModelType: + uninitialized_params = [name for name, param in meta_model.named_parameters() if str(param.device) == "meta"] + uninitialized_buffers = [name for name, buffer in meta_model.named_buffers() if str(buffer.device) == "meta"] + if uninitialized_params or uninitialized_buffers: + logger.warning(f"Uninitialized parameters or buffers: {uninitialized_params + uninitialized_buffers}") + return meta_model + retval = meta_model.to(device) + return retval + + def build( + self, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + **kwargs: object, # noqa: ARG002 + ) -> ModelType: + device = torch.device("cuda") if device is None else device + config = self.model_config() + meta_model = self.meta_model(config, self.module_ops) + model_paths = list(self.model_path) if isinstance(self.model_path, tuple) else [self.model_path] + model_state_dict = self.load_sd(model_paths, sd_ops=self.model_sd_ops, registry=self.registry, device=device) + + lora_strengths = [lora.strength for lora in self.loras] + if not lora_strengths or (min(lora_strengths) == 0 and max(lora_strengths) == 0): + sd = model_state_dict.sd + if dtype is not None: + sd = {key: value.to(dtype=dtype) for key, value in model_state_dict.sd.items()} + meta_model.load_state_dict(sd, strict=False, assign=True) + return self._return_model(meta_model, device) + + lora_state_dicts = [ + self.load_sd([lora.path], sd_ops=lora.sd_ops, registry=self.registry, device=self.lora_load_device) + for lora in self.loras + ] + lora_sd_and_strengths = [ + LoraStateDictWithStrength(sd, strength) + for sd, strength in zip(lora_state_dicts, lora_strengths, strict=True) + ] + final_sd = apply_loras( + model_sd=model_state_dict, + lora_sd_and_strengths=lora_sd_and_strengths, + dtype=dtype, + destination_sd=model_state_dict if isinstance(self.registry, DummyRegistry) else None, + ) + meta_model.load_state_dict(final_sd.sd, strict=False, assign=True) + return self._return_model(meta_model, device) diff --git a/ltx2/ltx_core/modality_tiling.py b/ltx2/ltx_core/modality_tiling.py new file mode 100644 index 0000000000000000000000000000000000000000..e8d45c392376a95b562095300155a421d486e1f1 --- /dev/null +++ b/ltx2/ltx_core/modality_tiling.py @@ -0,0 +1,222 @@ +"""Video modality tiling helpers. +Provides :class:`VideoModalityTilingHelper` β€” a stateless helper that +tiles and blends video :class:`Modality` token sequences by +spatial/temporal region. Tile geometry is represented by the existing +:class:`Tile` NamedTuple from :mod:`ltx_core.tiling`; no distributed +primitives are required. +""" + +from __future__ import annotations + +from dataclasses import dataclass, replace + +import torch + +from ltx_core.model.transformer.modality import Modality +from ltx_core.tiling import Tile, TileCountConfig, create_tiles, identity_mapping_operation, split_by_count +from ltx_core.tools import VideoLatentTools +from ltx_core.types import VideoLatentShape + + +@dataclass(frozen=True) +class TilingContext: + """Opaque context produced by :meth:`VideoModalityTilingHelper.tile_modality`. + Carries the token-level keep mask and per-conditioning-token blend + weights needed by :meth:`~VideoModalityTilingHelper.blend`. + """ + + keep_mask: torch.Tensor + cond_blend_weights: torch.Tensor | None + """``(num_kept_cond,)`` β€” weight for each kept conditioning token, + equal to ``1 / num_tiles_that_keep_this_token``. ``None`` when + there are no conditioning tokens.""" + + +class VideoModalityTilingHelper: + """Stateless helper that tiles and blends video :class:`Modality` sequences. + Constructed once with a :class:`TileCountConfig` and + :class:`VideoLatentTools`. Tiles are computed at construction and + available via the :attr:`tiles` property. Use :meth:`tile_modality` + and :meth:`blend` with any tile from that list. + Usage:: + helper = VideoModalityTilingHelper(tiling, video_tools) + for tile in helper.tiles: + tiled_mod, ctx = helper.tile_modality(modality, tile) + result = run_model(tiled_mod) + helper.blend(result, tile, ctx, output=output) + """ + + def __init__(self, tiling: TileCountConfig, video_tools: VideoLatentTools) -> None: + self._patchifier = video_tools.patchifier + self._latent_shape = video_tools.target_shape + self._num_generated_tokens = self._patchifier.get_token_count(self._latent_shape) + self._tiles = create_tiles( + torch.Size([self._latent_shape.frames, self._latent_shape.height, self._latent_shape.width]), + splitters=[ + split_by_count(tiling.frames.num_tiles, tiling.frames.overlap), + split_by_count(tiling.height.num_tiles, tiling.height.overlap), + split_by_count(tiling.width.num_tiles, tiling.width.overlap), + ], + mappers=[identity_mapping_operation] * 3, + ) + + @property + def tiles(self) -> list[Tile]: + """All tiles for the configured tiling layout.""" + return self._tiles + + # -- tile modality ----------------------------------------------------- + + def tile_modality(self, modality: Modality, tile: Tile) -> tuple[Modality, TilingContext]: + """Slice *modality* to the tokens covered by *tile*. + Selects generated tokens belonging to the tile's spatial region + and conditioning tokens that overlap with the tile (or have + negative time coordinates). + Returns: + A ``(tiled_modality, context)`` tuple. Pass *context* to + :meth:`blend` together with the model output. + """ + keep_mask = self._keep_mask(modality, tile) + + tile_attention_mask = None + if modality.attention_mask is not None: + keep_indices = keep_mask.nonzero(as_tuple=False).squeeze(1) + tile_attention_mask = modality.attention_mask[:, keep_indices, :][:, :, keep_indices] + + tiled = replace( + modality, + latent=modality.latent[:, keep_mask, :], + timesteps=modality.timesteps[:, keep_mask], + positions=modality.positions[:, :, keep_mask, :], + attention_mask=tile_attention_mask, + ) + + cond_blend_weights = None + num_total = modality.latent.shape[1] + if num_total > self._num_generated_tokens: + cond_keep = keep_mask[self._num_generated_tokens :] + # Count how many tiles keep each conditioning token. + cond_counts = torch.zeros(cond_keep.sum(), dtype=torch.float32) + for t in self._tiles: + other_mask = self._keep_mask(modality, t) + other_cond = other_mask[self._num_generated_tokens :] + # Map other tile's kept cond tokens into this tile's kept subset. + cond_counts += other_cond[cond_keep].float() + cond_blend_weights = 1.0 / cond_counts + + return tiled, TilingContext(keep_mask=keep_mask, cond_blend_weights=cond_blend_weights) + + # -- blend ------------------------------------------------------------- + + def blend( + self, + tile_to_blend: torch.Tensor, + tile: Tile, + context: TilingContext, + output: torch.Tensor | None = None, + ) -> torch.Tensor: + """Blend-weight tile results and accumulate into the full token space. + Premultiplied (blend-weighted) data is **added** to *output*, + allowing multiple tiles to be accumulated into the same buffer. + Args: + tile_to_blend: Denoised tile tensor ``(B, num_tile_tokens, D)``, + where the first ``_tile_generated_token_count(tile)`` + entries are generated tokens and the remainder are + conditioning tokens. + tile: The :class:`Tile` that was used in :meth:`tile_modality`. + context: The :class:`TilingContext` returned by :meth:`tile_modality`. + output: Optional pre-allocated output tensor. When provided + its shape must be ``(B, num_total_tokens, D)`` and the + blended tile is **added** into it. When ``None`` a new + zero-filled tensor is created. + Returns: + The output tensor with the blended tile added at the correct + positions. + """ + batch, _, dim = tile_to_blend.shape + num_tile_gen = self._tile_generated_token_count(tile) + gen_indices = self._generated_token_indices(tile) + + num_total_tokens = context.keep_mask.shape[0] + expected_shape = (batch, num_total_tokens, dim) + + if output is not None: + if output.shape != expected_shape: + raise ValueError(f"Expected output shape {expected_shape}, got {output.shape}") + result = output + else: + result = torch.zeros(*expected_shape, device=tile_to_blend.device, dtype=tile_to_blend.dtype) + + # Blend mask is (tile_F, tile_H, tile_W) β€” one weight per token in row-major order. + blend_weights = tile.blend_mask.reshape(-1).to(device=tile_to_blend.device, dtype=tile_to_blend.dtype) + tile_gen = tile_to_blend[:, :num_tile_gen, :] * blend_weights[None, :, None] + + result[:, gen_indices, :] += tile_gen + + # Scatter kept conditioning tokens, weighted by 1/N where N is + # the number of tiles that keep each token (so they sum to 1). + if num_total_tokens > self._num_generated_tokens and context.cond_blend_weights is not None: + cond_keep = context.keep_mask[self._num_generated_tokens :] + cond_indices = self._num_generated_tokens + cond_keep.nonzero(as_tuple=False).squeeze(1) + weights = context.cond_blend_weights.to(device=tile_to_blend.device, dtype=tile_to_blend.dtype) + result[:, cond_indices, :] += tile_to_blend[:, num_tile_gen:, :] * weights[None, :, None] + + return result + + # -- private ----------------------------------------------------------- + + def _tile_generated_token_count(self, tile: Tile) -> int: + """Number of generated tokens in *tile*.""" + frame_slice, height_slice, width_slice = tile.in_coords + tile_shape = VideoLatentShape( + batch=self._latent_shape.batch, + channels=self._latent_shape.channels, + frames=frame_slice.stop - frame_slice.start, + height=height_slice.stop - height_slice.start, + width=width_slice.stop - width_slice.start, + ) + return self._patchifier.get_token_count(tile_shape) + + def _generated_token_indices(self, tile: Tile) -> torch.Tensor: + """Flat token indices of *tile*'s generated tokens in the full sequence.""" + frame_slice, height_slice, width_slice = tile.in_coords + f = torch.arange(frame_slice.start, frame_slice.stop) + h = torch.arange(height_slice.start, height_slice.stop) + w = torch.arange(width_slice.start, width_slice.stop) + return ( + f[:, None, None] * self._latent_shape.height * self._latent_shape.width + + h[None, :, None] * self._latent_shape.width + + w[None, None, :] + ).reshape(-1) + + def _keep_mask(self, modality: Modality, tile: Tile) -> torch.Tensor: + """Boolean mask ``(num_total_tokens,)`` β€” True for tokens the tile processes. + Generated tokens are selected by grid position. Conditioning + tokens are kept when their ``[start, end)`` intervals overlap + the tile in all three dimensions, or when they have a negative + time coordinate (reference tokens). + """ + num_total = modality.latent.shape[1] + mask = torch.zeros(num_total, dtype=torch.bool) + + gen_indices = self._generated_token_indices(tile) + mask[gen_indices] = True + + if num_total > self._num_generated_tokens: + gen_positions = modality.positions[:, :, gen_indices, :] # (B, 3, num_tile_gen, 2) + tile_start = gen_positions[..., 0].amin(dim=2) # (B, 3) + tile_end = gen_positions[..., 1].amax(dim=2) # (B, 3) + + cond_positions = modality.positions[:, :, self._num_generated_tokens :, :] # (B, 3, num_cond, 2) + + overlaps = (cond_positions[..., 0] < tile_end.unsqueeze(2)) & ( + cond_positions[..., 1] > tile_start.unsqueeze(2) + ) # (B, 3, num_cond) + overlaps_all_dims = overlaps.all(dim=1) # (B, num_cond) + + has_negative_time = cond_positions[:, 0, :, 0] < 0 # (B, num_cond) + + keep_cond = (overlaps_all_dims | has_negative_time).any(dim=0) # (num_cond,) + mask[self._num_generated_tokens :] = keep_cond + + return mask diff --git a/ltx2/ltx_core/model/__init__.py b/ltx2/ltx_core/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e8a5b27abc167b8245b5b68c876fe2127bd40a4f --- /dev/null +++ b/ltx2/ltx_core/model/__init__.py @@ -0,0 +1,8 @@ +"""Model definitions for LTX-2.""" + +from ltx_core.model.model_protocol import ModelConfigurator, ModelType + +__all__ = [ + "ModelConfigurator", + "ModelType", +] diff --git a/ltx2/ltx_core/model/audio_vae/__init__.py b/ltx2/ltx_core/model/audio_vae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c1cd3ac31e9f71e8dca2f1e786894dbab1ceae9c --- /dev/null +++ b/ltx2/ltx_core/model/audio_vae/__init__.py @@ -0,0 +1,29 @@ +"""Audio VAE model components.""" + +from ltx_core.model.audio_vae.audio_vae import AudioDecoder, AudioEncoder, decode_audio, encode_audio +from ltx_core.model.audio_vae.model_configurator import ( + AUDIO_VAE_DECODER_COMFY_KEYS_FILTER, + AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER, + VOCODER_COMFY_KEYS_FILTER, + AudioDecoderConfigurator, + AudioEncoderConfigurator, + VocoderConfigurator, +) +from ltx_core.model.audio_vae.ops import AudioProcessor +from ltx_core.model.audio_vae.vocoder import Vocoder, VocoderWithBWE + +__all__ = [ + "AUDIO_VAE_DECODER_COMFY_KEYS_FILTER", + "AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER", + "VOCODER_COMFY_KEYS_FILTER", + "AudioDecoder", + "AudioDecoderConfigurator", + "AudioEncoder", + "AudioEncoderConfigurator", + "AudioProcessor", + "Vocoder", + "VocoderConfigurator", + "VocoderWithBWE", + "decode_audio", + "encode_audio", +] diff --git a/ltx2/ltx_core/model/audio_vae/attention.py b/ltx2/ltx_core/model/audio_vae/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..46d5ebb29d340d75cd1907ce4e4bd7e14e80f394 --- /dev/null +++ b/ltx2/ltx_core/model/audio_vae/attention.py @@ -0,0 +1,71 @@ +from enum import Enum + +import torch + +from ltx_core.model.common.normalization import NormType, build_normalization_layer + + +class AttentionType(Enum): + """Enum for specifying the attention mechanism type.""" + + VANILLA = "vanilla" + LINEAR = "linear" + NONE = "none" + + +class AttnBlock(torch.nn.Module): + def __init__( + self, + in_channels: int, + norm_type: NormType = NormType.GROUP, + ) -> None: + super().__init__() + self.in_channels = in_channels + + self.norm = build_normalization_layer(in_channels, normtype=norm_type) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w).contiguous() + q = q.permute(0, 2, 1).contiguous() # b,hw,c + k = k.reshape(b, c, h * w).contiguous() # b,c,hw + w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w).contiguous() + w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w).contiguous() + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn( + in_channels: int, + attn_type: AttentionType = AttentionType.VANILLA, + norm_type: NormType = NormType.GROUP, +) -> torch.nn.Module: + match attn_type: + case AttentionType.VANILLA: + return AttnBlock(in_channels, norm_type=norm_type) + case AttentionType.NONE: + return torch.nn.Identity() + case AttentionType.LINEAR: + raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.") + case _: + raise ValueError(f"Unknown attention type: {attn_type}") diff --git a/ltx2/ltx_core/model/audio_vae/audio_vae.py b/ltx2/ltx_core/model/audio_vae/audio_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..7b2cfe20a1739578cc740513c25f901cf009070f --- /dev/null +++ b/ltx2/ltx_core/model/audio_vae/audio_vae.py @@ -0,0 +1,508 @@ +from typing import Set, Tuple + +import torch +import torch.nn.functional as F + +from ltx_core.components.patchifiers import AudioPatchifier +from ltx_core.model.audio_vae.attention import AttentionType, make_attn +from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d +from ltx_core.model.audio_vae.causality_axis import CausalityAxis +from ltx_core.model.audio_vae.downsample import build_downsampling_path +from ltx_core.model.audio_vae.ops import AudioProcessor, PerChannelStatistics +from ltx_core.model.audio_vae.resnet import ResnetBlock +from ltx_core.model.audio_vae.upsample import build_upsampling_path +from ltx_core.model.audio_vae.vocoder import Vocoder +from ltx_core.model.common.normalization import NormType, build_normalization_layer +from ltx_core.types import Audio, AudioLatentShape + +LATENT_DOWNSAMPLE_FACTOR = 4 + + +def build_mid_block( + channels: int, + temb_channels: int, + dropout: float, + norm_type: NormType, + causality_axis: CausalityAxis, + attn_type: AttentionType, + add_attention: bool, +) -> torch.nn.Module: + """Build the middle block with two ResNet blocks and optional attention.""" + mid = torch.nn.Module() + mid.block_1 = ResnetBlock( + in_channels=channels, + out_channels=channels, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + mid.attn_1 = make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else torch.nn.Identity() + mid.block_2 = ResnetBlock( + in_channels=channels, + out_channels=channels, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + return mid + + +def run_mid_block(mid: torch.nn.Module, features: torch.Tensor) -> torch.Tensor: + """Run features through the middle block.""" + features = mid.block_1(features, temb=None) + features = mid.attn_1(features) + return mid.block_2(features, temb=None) + + +class AudioEncoder(torch.nn.Module): + """ + Encoder that compresses audio spectrograms into latent representations. + The encoder uses a series of downsampling blocks with residual connections, + attention mechanisms, and configurable causal convolutions. + """ + + def __init__( # noqa: PLR0913 + self, + *, + ch: int, + ch_mult: Tuple[int, ...] = (1, 2, 4, 8), + num_res_blocks: int, + attn_resolutions: Set[int], + dropout: float = 0.0, + resamp_with_conv: bool = True, + in_channels: int, + resolution: int, + z_channels: int, + double_z: bool = True, + attn_type: AttentionType = AttentionType.VANILLA, + mid_block_add_attention: bool = True, + norm_type: NormType = NormType.GROUP, + causality_axis: CausalityAxis = CausalityAxis.WIDTH, + sample_rate: int = 16000, + mel_hop_length: int = 160, + n_fft: int = 1024, + is_causal: bool = True, + mel_bins: int = 64, + **_ignore_kwargs, + ) -> None: + """ + Initialize the Encoder. + Args: + Arguments are configuration parameters, loaded from the audio VAE checkpoint config + (audio_vae.model.params.ddconfig): + ch: Base number of feature channels used in the first convolution layer. + ch_mult: Multiplicative factors for the number of channels at each resolution level. + num_res_blocks: Number of residual blocks to use at each resolution level. + attn_resolutions: Spatial resolutions (e.g., in time/frequency) at which to apply attention. + resolution: Input spatial resolution of the spectrogram (height, width). + z_channels: Number of channels in the latent representation. + norm_type: Normalization layer type to use within the network (e.g., group, batch). + causality_axis: Axis along which convolutions should be causal (e.g., time axis). + sample_rate: Audio sample rate in Hz for the input signals. + mel_hop_length: Hop length used when computing the mel spectrogram. + n_fft: FFT size used to compute the spectrogram. + mel_bins: Number of mel-frequency bins in the input spectrogram. + in_channels: Number of channels in the input spectrogram tensor. + double_z: If True, predict both mean and log-variance (doubling latent channels). + is_causal: If True, use causal convolutions suitable for streaming setups. + dropout: Dropout probability used in residual and mid blocks. + attn_type: Type of attention mechanism to use in attention blocks. + resamp_with_conv: If True, perform resolution changes using strided convolutions. + mid_block_add_attention: If True, add an attention block in the mid-level of the encoder. + """ + super().__init__() + + self.per_channel_statistics = PerChannelStatistics(latent_channels=ch) + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.n_fft = n_fft + self.is_causal = is_causal + self.mel_bins = mel_bins + + self.patchifier = AudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=sample_rate, + hop_length=mel_hop_length, + is_causal=is_causal, + ) + + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.z_channels = z_channels + self.double_z = double_z + self.norm_type = norm_type + self.causality_axis = causality_axis + self.attn_type = attn_type + + # downsampling + self.conv_in = make_conv2d( + in_channels, + self.ch, + kernel_size=3, + stride=1, + causality_axis=self.causality_axis, + ) + + self.non_linearity = torch.nn.SiLU() + + self.down, block_in = build_downsampling_path( + ch=ch, + ch_mult=ch_mult, + num_resolutions=self.num_resolutions, + num_res_blocks=num_res_blocks, + resolution=resolution, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + attn_resolutions=attn_resolutions, + resamp_with_conv=resamp_with_conv, + ) + + self.mid = build_mid_block( + channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + add_attention=mid_block_add_attention, + ) + + self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type) + self.conv_out = make_conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + causality_axis=self.causality_axis, + ) + + def forward(self, spectrogram: torch.Tensor) -> torch.Tensor: + """ + Encode audio spectrogram into latent representations. + Args: + spectrogram: Input spectrogram of shape (batch, channels, time, frequency) + Returns: + Encoded latent representation of shape (batch, channels, frames, mel_bins) + """ + h = self.conv_in(spectrogram) + h = self._run_downsampling_path(h) + h = run_mid_block(self.mid, h) + h = self._finalize_output(h) + + return self._normalize_latents(h) + + def _run_downsampling_path(self, h: torch.Tensor) -> torch.Tensor: + for level in range(self.num_resolutions): + stage = self.down[level] + for block_idx in range(self.num_res_blocks): + h = stage.block[block_idx](h, temb=None) + if stage.attn: + h = stage.attn[block_idx](h) + + if level != self.num_resolutions - 1: + h = stage.downsample(h) + + return h + + def _finalize_output(self, h: torch.Tensor) -> torch.Tensor: + h = self.norm_out(h) + h = self.non_linearity(h) + return self.conv_out(h) + + def _normalize_latents(self, latent_output: torch.Tensor) -> torch.Tensor: + """ + Normalize encoder latents using per-channel statistics. + When the encoder is configured with ``double_z=True``, the final + convolution produces twice the number of latent channels, typically + interpreted as two concatenated tensors along the channel dimension + (e.g., mean and variance or other auxiliary parameters). + This method intentionally uses only the first half of the channels + (the "mean" component) as input to the patchifier and normalization + logic. The remaining channels are left unchanged by this method and + are expected to be consumed elsewhere in the VAE pipeline. + If ``double_z=False``, the encoder output already contains only the + mean latents and the chunking operation simply returns that tensor. + """ + means = torch.chunk(latent_output, 2, dim=1)[0] + latent_shape = AudioLatentShape( + batch=means.shape[0], + channels=means.shape[1], + frames=means.shape[2], + mel_bins=means.shape[3], + ) + latent_patched = self.patchifier.patchify(means) + latent_normalized = self.per_channel_statistics.normalize(latent_patched) + return self.patchifier.unpatchify(latent_normalized, latent_shape) + + +def encode_audio( + audio: Audio, + audio_encoder: AudioEncoder, + audio_processor: AudioProcessor | None = None, +) -> torch.Tensor: + """Encode audio waveform into latent representation. + Args: + audio: Audio container with waveform tensor of shape (batch, channels, samples) and sampling rate. + audio_encoder: Audio encoder model + audio_processor: Audio processor model (optional, if not provided, it will be created from the audio encoder) + """ + dtype = next(audio_encoder.parameters()).dtype + device = next(audio_encoder.parameters()).device + + if audio_processor is None: + audio_processor = AudioProcessor( + target_sample_rate=audio_encoder.sample_rate, + mel_bins=audio_encoder.mel_bins, + mel_hop_length=audio_encoder.mel_hop_length, + n_fft=audio_encoder.n_fft, + ).to(device=device) + + mel_spectrogram = audio_processor.waveform_to_mel(audio.to(device=device)) + + latent = audio_encoder(mel_spectrogram.to(dtype=dtype)) + return latent + + +class AudioDecoder(torch.nn.Module): + """ + Symmetric decoder that reconstructs audio spectrograms from latent features. + The decoder mirrors the encoder structure with configurable channel multipliers, + attention resolutions, and causal convolutions. + """ + + def __init__( # noqa: PLR0913 + self, + *, + ch: int, + out_ch: int, + ch_mult: Tuple[int, ...] = (1, 2, 4, 8), + num_res_blocks: int, + attn_resolutions: Set[int], + resolution: int, + z_channels: int, + norm_type: NormType = NormType.GROUP, + causality_axis: CausalityAxis = CausalityAxis.WIDTH, + dropout: float = 0.0, + mid_block_add_attention: bool = True, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: int | None = None, + ) -> None: + """ + Initialize the Decoder. + Args: + Arguments are configuration parameters, loaded from the audio VAE checkpoint config + (audio_vae.model.params.ddconfig): + - ch, out_ch, ch_mult, num_res_blocks, attn_resolutions + - resolution, z_channels + - norm_type, causality_axis + """ + super().__init__() + + # Internal behavioural defaults that are not driven by the checkpoint. + resamp_with_conv = True + attn_type = AttentionType.VANILLA + + # Per-channel statistics for denormalizing latents + self.per_channel_statistics = PerChannelStatistics(latent_channels=ch) + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.is_causal = is_causal + self.mel_bins = mel_bins + self.patchifier = AudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=sample_rate, + hop_length=mel_hop_length, + is_causal=is_causal, + ) + + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.out_ch = out_ch + self.give_pre_end = False + self.tanh_out = False + self.norm_type = norm_type + self.z_channels = z_channels + self.channel_multipliers = ch_mult + self.attn_resolutions = attn_resolutions + self.causality_axis = causality_axis + self.attn_type = attn_type + + base_block_channels = ch * self.channel_multipliers[-1] + base_resolution = resolution // (2 ** (self.num_resolutions - 1)) + self.z_shape = (1, z_channels, base_resolution, base_resolution) + + self.conv_in = make_conv2d( + z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + self.non_linearity = torch.nn.SiLU() + self.mid = build_mid_block( + channels=base_block_channels, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + add_attention=mid_block_add_attention, + ) + self.up, final_block_channels = build_upsampling_path( + ch=ch, + ch_mult=ch_mult, + num_resolutions=self.num_resolutions, + num_res_blocks=num_res_blocks, + resolution=resolution, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + attn_type=self.attn_type, + attn_resolutions=attn_resolutions, + resamp_with_conv=resamp_with_conv, + initial_block_channels=base_block_channels, + ) + + self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type) + self.conv_out = make_conv2d( + final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + """ + Decode latent features back to audio spectrograms. + Args: + sample: Encoded latent representation of shape (batch, channels, frames, mel_bins) + Returns: + Reconstructed audio spectrogram of shape (batch, channels, time, frequency) + """ + sample, target_shape = self._denormalize_latents(sample) + + h = self.conv_in(sample) + h = run_mid_block(self.mid, h) + h = self._run_upsampling_path(h) + h = self._finalize_output(h) + + return self._adjust_output_shape(h, target_shape) + + def _denormalize_latents(self, sample: torch.Tensor) -> tuple[torch.Tensor, AudioLatentShape]: + latent_shape = AudioLatentShape( + batch=sample.shape[0], + channels=sample.shape[1], + frames=sample.shape[2], + mel_bins=sample.shape[3], + ) + + sample_patched = self.patchifier.patchify(sample) + sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched) + sample = self.patchifier.unpatchify(sample_denormalized, latent_shape) + + target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR + if self.causality_axis != CausalityAxis.NONE: + target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1) + + target_shape = AudioLatentShape( + batch=latent_shape.batch, + channels=self.out_ch, + frames=target_frames, + mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins, + ) + + return sample, target_shape + + def _adjust_output_shape( + self, + decoded_output: torch.Tensor, + target_shape: AudioLatentShape, + ) -> torch.Tensor: + """ + Adjust output shape to match target dimensions for variable-length audio. + This function handles the common case where decoded audio spectrograms need to be + resized to match a specific target shape. + Args: + decoded_output: Tensor of shape (batch, channels, time, frequency) + target_shape: AudioLatentShape describing (batch, channels, time, mel bins) + Returns: + Tensor adjusted to match target_shape exactly + """ + # Current output shape: (batch, channels, time, frequency) + _, _, current_time, current_freq = decoded_output.shape + target_channels = target_shape.channels + target_time = target_shape.frames + target_freq = target_shape.mel_bins + + # Step 1: Crop first to avoid exceeding target dimensions + decoded_output = decoded_output[ + :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq) + ] + + # Step 2: Calculate padding needed for time and frequency dimensions + time_padding_needed = target_time - decoded_output.shape[2] + freq_padding_needed = target_freq - decoded_output.shape[3] + + # Step 3: Apply padding if needed + if time_padding_needed > 0 or freq_padding_needed > 0: + # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom) + # For audio: pad_left/right = frequency, pad_top/bottom = time + padding = ( + 0, + max(freq_padding_needed, 0), # frequency padding (left, right) + 0, + max(time_padding_needed, 0), # time padding (top, bottom) + ) + decoded_output = F.pad(decoded_output, padding) + + # Step 4: Final safety crop to ensure exact target shape + decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq] + + return decoded_output + + def _run_upsampling_path(self, h: torch.Tensor) -> torch.Tensor: + for level in reversed(range(self.num_resolutions)): + stage = self.up[level] + for block_idx, block in enumerate(stage.block): + h = block(h, temb=None) + if stage.attn: + h = stage.attn[block_idx](h) + + if level != 0 and hasattr(stage, "upsample"): + h = stage.upsample(h) + + return h + + def _finalize_output(self, h: torch.Tensor) -> torch.Tensor: + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = self.non_linearity(h) + h = self.conv_out(h) + return torch.tanh(h) if self.tanh_out else h + + +def decode_audio(latent: torch.Tensor, audio_decoder: "AudioDecoder", vocoder: "Vocoder") -> Audio: + """ + Decode an audio latent representation using the provided audio decoder and vocoder. + Args: + latent: Input audio latent tensor. + audio_decoder: Model to decode the latent to waveform features. + vocoder: Model to convert decoded features to audio waveform. + Returns: + Decoded audio with waveform and sampling rate. + """ + decoded_audio = audio_decoder(latent) + waveform = vocoder(decoded_audio).squeeze(0).float() + return Audio(waveform=waveform, sampling_rate=vocoder.output_sampling_rate) diff --git a/ltx2/ltx_core/model/audio_vae/causal_conv_2d.py b/ltx2/ltx_core/model/audio_vae/causal_conv_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..efb63861cd26ab6130ab88a0aeb7836d8e63c420 --- /dev/null +++ b/ltx2/ltx_core/model/audio_vae/causal_conv_2d.py @@ -0,0 +1,110 @@ +import torch +import torch.nn.functional as F + +from ltx_core.model.audio_vae.causality_axis import CausalityAxis + + +class CausalConv2d(torch.nn.Module): + """ + A causal 2D convolution. + This layer ensures that the output at time `t` only depends on inputs + at time `t` and earlier. It achieves this by applying asymmetric padding + to the time dimension (width) before the convolution. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int], + stride: int = 1, + dilation: int | tuple[int, int] = 1, + groups: int = 1, + bias: bool = True, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ) -> None: + super().__init__() + + self.causality_axis = causality_axis + + # Ensure kernel_size and dilation are tuples + kernel_size = torch.nn.modules.utils._pair(kernel_size) + dilation = torch.nn.modules.utils._pair(dilation) + + # Calculate padding dimensions + pad_h = (kernel_size[0] - 1) * dilation[0] + pad_w = (kernel_size[1] - 1) * dilation[1] + + # The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom) + match self.causality_axis: + case CausalityAxis.NONE: + self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY: + self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2) + case CausalityAxis.HEIGHT: + self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0) + case _: + raise ValueError(f"Invalid causality_axis: {causality_axis}") + + # The internal convolution layer uses no padding, as we handle it manually + self.conv = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Apply causal padding before convolution + x = F.pad(x, self.padding) + return self.conv(x) + + +def make_conv2d( + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int], + stride: int = 1, + padding: tuple[int, int, int, int] | None = None, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causality_axis: CausalityAxis | None = None, +) -> torch.nn.Module: + """ + Create a 2D convolution layer that can be either causal or non-causal. + Args: + in_channels: Number of input channels + out_channels: Number of output channels + kernel_size: Size of the convolution kernel + stride: Convolution stride + padding: Padding (if None, will be calculated based on causal flag) + dilation: Dilation rate + groups: Number of groups for grouped convolution + bias: Whether to use bias + causality_axis: Dimension along which to apply causality. + Returns: + Either a regular Conv2d or CausalConv2d layer + """ + if causality_axis is not None: + # For causal convolution, padding is handled internally by CausalConv2d + return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis) + else: + # For non-causal convolution, use symmetric padding if not specified + if padding is None: + padding = kernel_size // 2 if isinstance(kernel_size, int) else tuple(k // 2 for k in kernel_size) + + return torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) diff --git a/ltx2/ltx_core/model/audio_vae/causality_axis.py b/ltx2/ltx_core/model/audio_vae/causality_axis.py new file mode 100644 index 0000000000000000000000000000000000000000..b99f83550f3e73658b05b4c467d78ecb330b1822 --- /dev/null +++ b/ltx2/ltx_core/model/audio_vae/causality_axis.py @@ -0,0 +1,10 @@ +from enum import Enum + + +class CausalityAxis(Enum): + """Enum for specifying the causality axis in causal convolutions.""" + + NONE = None + WIDTH = "width" + HEIGHT = "height" + WIDTH_COMPATIBILITY = "width-compatibility" diff --git a/ltx2/ltx_core/model/audio_vae/downsample.py b/ltx2/ltx_core/model/audio_vae/downsample.py new file mode 100644 index 0000000000000000000000000000000000000000..336735bcb4392bbce38015c2ed51a935b03c6260 --- /dev/null +++ b/ltx2/ltx_core/model/audio_vae/downsample.py @@ -0,0 +1,110 @@ +from typing import Set, Tuple + +import torch + +from ltx_core.model.audio_vae.attention import AttentionType, make_attn +from ltx_core.model.audio_vae.causality_axis import CausalityAxis +from ltx_core.model.audio_vae.resnet import ResnetBlock +from ltx_core.model.common.normalization import NormType + + +class Downsample(torch.nn.Module): + """ + A downsampling layer that can use either a strided convolution + or average pooling. Supports standard and causal padding for the + convolutional mode. + """ + + def __init__( + self, + in_channels: int, + with_conv: bool, + causality_axis: CausalityAxis = CausalityAxis.WIDTH, + ) -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and not self.with_conv: + raise ValueError("causality is only supported when `with_conv=True`.") + + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.with_conv: + # Padding tuple is in the order: (left, right, top, bottom). + match self.causality_axis: + case CausalityAxis.NONE: + pad = (0, 1, 0, 1) + case CausalityAxis.WIDTH: + pad = (2, 0, 0, 1) + case CausalityAxis.HEIGHT: + pad = (0, 1, 2, 0) + case CausalityAxis.WIDTH_COMPATIBILITY: + pad = (1, 0, 0, 1) + case _: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + # This branch is only taken if with_conv=False, which implies causality_axis is NONE. + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + + return x + + +def build_downsampling_path( # noqa: PLR0913 + *, + ch: int, + ch_mult: Tuple[int, ...], + num_resolutions: int, + num_res_blocks: int, + resolution: int, + temb_channels: int, + dropout: float, + norm_type: NormType, + causality_axis: CausalityAxis, + attn_type: AttentionType, + attn_resolutions: Set[int], + resamp_with_conv: bool, +) -> tuple[torch.nn.ModuleList, int]: + """Build the downsampling path with residual blocks, attention, and downsampling layers.""" + down_modules = torch.nn.ModuleList() + curr_res = resolution + in_ch_mult = (1, *tuple(ch_mult)) + block_in = ch + + for i_level in range(num_resolutions): + block = torch.nn.ModuleList() + attn = torch.nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + + for _ in range(num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type)) + + down = torch.nn.Module() + down.block = block + down.attn = attn + if i_level != num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis) + curr_res = curr_res // 2 + down_modules.append(down) + + return down_modules, block_in diff --git a/ltx2/ltx_core/model/audio_vae/model_configurator.py b/ltx2/ltx_core/model/audio_vae/model_configurator.py new file mode 100644 index 0000000000000000000000000000000000000000..90f9b06dfa4ef0c30edff6d4ce798c57270905d2 --- /dev/null +++ b/ltx2/ltx_core/model/audio_vae/model_configurator.py @@ -0,0 +1,200 @@ +import torch + +from ltx_core.loader.sd_ops import KeyValueOperationResult, SDOps +from ltx_core.model.audio_vae.attention import AttentionType +from ltx_core.model.audio_vae.audio_vae import AudioDecoder, AudioEncoder +from ltx_core.model.audio_vae.causality_axis import CausalityAxis +from ltx_core.model.audio_vae.vocoder import MelSTFT, Vocoder, VocoderWithBWE +from ltx_core.model.common.normalization import NormType +from ltx_core.model.model_protocol import ModelConfigurator +from ltx_core.utils import check_config_value + + +def _vocoder_from_config( + cfg: dict, + apply_final_activation: bool = True, + output_sampling_rate: int | None = None, +) -> Vocoder: + """Instantiate a Vocoder from a flat config dict. + Args: + cfg: Vocoder config dict (keys match Vocoder constructor args). + apply_final_activation: Whether to apply tanh/clamp at the output. + output_sampling_rate: Explicit override for the output sample rate. + When None, reads from cfg["output_sampling_rate"] (default 24000). + """ + return Vocoder( + resblock_kernel_sizes=cfg.get("resblock_kernel_sizes", [3, 7, 11]), + upsample_rates=cfg.get("upsample_rates", [6, 5, 2, 2, 2]), + upsample_kernel_sizes=cfg.get("upsample_kernel_sizes", [16, 15, 8, 4, 4]), + resblock_dilation_sizes=cfg.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]]), + upsample_initial_channel=cfg.get("upsample_initial_channel", 1024), + resblock=cfg.get("resblock", "1"), + output_sampling_rate=( + output_sampling_rate if output_sampling_rate is not None else cfg.get("output_sampling_rate", 24000) + ), + activation=cfg.get("activation", "snake"), + use_tanh_at_final=cfg.get("use_tanh_at_final", True), + apply_final_activation=apply_final_activation, + use_bias_at_final=cfg.get("use_bias_at_final", True), + ) + + +class VocoderConfigurator(ModelConfigurator[Vocoder]): + """Configurator that auto-detects the checkpoint format. + Returns a plain Vocoder for pre-ltx-2.3 checkpoints (flat config) or a + VocoderWithBWE for ltx-2.3+ checkpoints (nested "vocoder" + "bwe" config). + """ + + @classmethod + def from_config(cls: type[Vocoder], config: dict) -> Vocoder | VocoderWithBWE: + cfg = config.get("vocoder", {}) + + if "bwe" not in cfg: + check_config_value(cfg, "resblock", "1") + check_config_value(cfg, "stereo", True) + return _vocoder_from_config(cfg) + + vocoder_cfg = cfg.get("vocoder", {}) + bwe_cfg = cfg["bwe"] + + check_config_value(vocoder_cfg, "resblock", "AMP1") + check_config_value(vocoder_cfg, "stereo", True) + check_config_value(vocoder_cfg, "activation", "snakebeta") + check_config_value(bwe_cfg, "resblock", "AMP1") + check_config_value(bwe_cfg, "stereo", True) + check_config_value(bwe_cfg, "activation", "snakebeta") + + vocoder = _vocoder_from_config( + vocoder_cfg, + output_sampling_rate=bwe_cfg["input_sampling_rate"], + ) + bwe_generator = _vocoder_from_config( + bwe_cfg, + apply_final_activation=False, + output_sampling_rate=bwe_cfg["output_sampling_rate"], + ) + mel_stft = MelSTFT( + filter_length=bwe_cfg["n_fft"], + hop_length=bwe_cfg["hop_length"], + win_length=bwe_cfg["n_fft"], + n_mel_channels=bwe_cfg["num_mels"], + ) + return VocoderWithBWE( + vocoder=vocoder, + bwe_generator=bwe_generator, + mel_stft=mel_stft, + input_sampling_rate=bwe_cfg["input_sampling_rate"], + output_sampling_rate=bwe_cfg["output_sampling_rate"], + hop_length=bwe_cfg["hop_length"], + ) + + +def _strip_vocoder_prefix(key: str, value: torch.Tensor) -> list[KeyValueOperationResult]: + """Strip the leading 'vocoder.' prefix exactly once. + Uses removeprefix instead of str.replace so that BWE keys like + 'vocoder.vocoder.conv_pre' become 'vocoder.conv_pre' (not 'conv_pre'). + Works identically for legacy keys like 'vocoder.conv_pre' β†’ 'conv_pre'. + """ + return [KeyValueOperationResult(key.removeprefix("vocoder."), value)] + + +VOCODER_COMFY_KEYS_FILTER = ( + SDOps("VOCODER_COMFY_KEYS_FILTER") + .with_matching(prefix="vocoder.") + .with_kv_operation(operation=_strip_vocoder_prefix, key_prefix="vocoder.") +) + + +class AudioDecoderConfigurator(ModelConfigurator[AudioDecoder]): + @classmethod + def from_config(cls: type[AudioDecoder], config: dict) -> AudioDecoder: + audio_vae_cfg = config.get("audio_vae", {}) + model_cfg = audio_vae_cfg.get("model", {}) + model_params = model_cfg.get("params", {}) + ddconfig = model_params.get("ddconfig", {}) + preprocessing_cfg = audio_vae_cfg.get("preprocessing", {}) + stft_cfg = preprocessing_cfg.get("stft", {}) + mel_cfg = preprocessing_cfg.get("mel", {}) + variables_cfg = audio_vae_cfg.get("variables", {}) + + sample_rate = model_params.get("sampling_rate", 16000) + mel_hop_length = stft_cfg.get("hop_length", 160) + is_causal = stft_cfg.get("causal", True) + mel_bins = ddconfig.get("mel_bins") or mel_cfg.get("n_mel_channels") or variables_cfg.get("mel_bins") + + return AudioDecoder( + ch=ddconfig.get("ch", 128), + out_ch=ddconfig.get("out_ch", 2), + ch_mult=tuple(ddconfig.get("ch_mult", (1, 2, 4))), + num_res_blocks=ddconfig.get("num_res_blocks", 2), + attn_resolutions=ddconfig.get("attn_resolutions", {8, 16, 32}), + resolution=ddconfig.get("resolution", 256), + z_channels=ddconfig.get("z_channels", 8), + norm_type=NormType(ddconfig.get("norm_type", "pixel")), + causality_axis=CausalityAxis(ddconfig.get("causality_axis", "height")), + dropout=ddconfig.get("dropout", 0.0), + mid_block_add_attention=ddconfig.get("mid_block_add_attention", True), + sample_rate=sample_rate, + mel_hop_length=mel_hop_length, + is_causal=is_causal, + mel_bins=mel_bins, + ) + + +class AudioEncoderConfigurator(ModelConfigurator[AudioEncoder]): + @classmethod + def from_config(cls: type[AudioEncoder], config: dict) -> AudioEncoder: + audio_vae_cfg = config.get("audio_vae", {}) + model_cfg = audio_vae_cfg.get("model", {}) + model_params = model_cfg.get("params", {}) + ddconfig = model_params.get("ddconfig", {}) + preprocessing_cfg = audio_vae_cfg.get("preprocessing", {}) + stft_cfg = preprocessing_cfg.get("stft", {}) + mel_cfg = preprocessing_cfg.get("mel", {}) + variables_cfg = audio_vae_cfg.get("variables", {}) + + sample_rate = model_params.get("sampling_rate", 16000) + mel_hop_length = stft_cfg.get("hop_length", 160) + n_fft = stft_cfg.get("filter_length", 1024) + is_causal = stft_cfg.get("causal", True) + mel_bins = ddconfig.get("mel_bins") or mel_cfg.get("n_mel_channels") or variables_cfg.get("mel_bins") + + return AudioEncoder( + ch=ddconfig.get("ch", 128), + ch_mult=tuple(ddconfig.get("ch_mult", (1, 2, 4))), + num_res_blocks=ddconfig.get("num_res_blocks", 2), + attn_resolutions=ddconfig.get("attn_resolutions", {8, 16, 32}), + resolution=ddconfig.get("resolution", 256), + z_channels=ddconfig.get("z_channels", 8), + double_z=ddconfig.get("double_z", True), + dropout=ddconfig.get("dropout", 0.0), + resamp_with_conv=ddconfig.get("resamp_with_conv", True), + in_channels=ddconfig.get("in_channels", 2), + attn_type=AttentionType(ddconfig.get("attn_type", "vanilla")), + mid_block_add_attention=ddconfig.get("mid_block_add_attention", True), + norm_type=NormType(ddconfig.get("norm_type", "pixel")), + causality_axis=CausalityAxis(ddconfig.get("causality_axis", "height")), + sample_rate=sample_rate, + mel_hop_length=mel_hop_length, + n_fft=n_fft, + is_causal=is_causal, + mel_bins=mel_bins, + ) + + +AUDIO_VAE_DECODER_COMFY_KEYS_FILTER = ( + SDOps("AUDIO_VAE_DECODER_COMFY_KEYS_FILTER") + .with_matching(prefix="audio_vae.decoder.") + .with_matching(prefix="audio_vae.per_channel_statistics.") + .with_replacement("audio_vae.decoder.", "") + .with_replacement("audio_vae.per_channel_statistics.", "per_channel_statistics.") +) + + +AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER = ( + SDOps("AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER") + .with_matching(prefix="audio_vae.encoder.") + .with_matching(prefix="audio_vae.per_channel_statistics.") + .with_replacement("audio_vae.encoder.", "") + .with_replacement("audio_vae.per_channel_statistics.", "per_channel_statistics.") +) diff --git a/ltx2/ltx_core/model/audio_vae/ops.py b/ltx2/ltx_core/model/audio_vae/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..58fcbecbc9f7413744bdb76ca170b0b5afce4efc --- /dev/null +++ b/ltx2/ltx_core/model/audio_vae/ops.py @@ -0,0 +1,73 @@ +import torch +import torchaudio +from torch import nn + +from ltx_core.types import Audio + + +class AudioProcessor(nn.Module): + """Converts audio waveforms to log-mel spectrograms with optional resampling.""" + + def __init__( + self, + target_sample_rate: int, + mel_bins: int, + mel_hop_length: int, + n_fft: int, + ) -> None: + super().__init__() + self.target_sample_rate = target_sample_rate + self.mel_transform = torchaudio.transforms.MelSpectrogram( + sample_rate=target_sample_rate, + n_fft=n_fft, + win_length=n_fft, + hop_length=mel_hop_length, + f_min=0.0, + f_max=target_sample_rate / 2.0, + n_mels=mel_bins, + window_fn=torch.hann_window, + center=True, + pad_mode="reflect", + power=1.0, + mel_scale="slaney", + norm="slaney", + ) + + def resample_audio(self, audio: Audio) -> Audio: + """Resample audio to the processor's target sample rate if needed.""" + if audio.sampling_rate == self.target_sample_rate: + return audio + resampled = torchaudio.functional.resample(audio.waveform, audio.sampling_rate, self.target_sample_rate) + resampled = resampled.to(device=audio.waveform.device, dtype=audio.waveform.dtype) + return Audio(waveform=resampled, sampling_rate=self.target_sample_rate) + + def waveform_to_mel( + self, + audio: Audio, + ) -> torch.Tensor: + """Convert waveform to log-mel spectrogram [batch, channels, time, n_mels].""" + waveform = self.resample_audio(audio).waveform + + mel = self.mel_transform(waveform) + mel = torch.log(torch.clamp(mel, min=1e-5)) + + mel = mel.to(device=waveform.device, dtype=waveform.dtype) + return mel.permute(0, 1, 3, 2).contiguous() + + +class PerChannelStatistics(nn.Module): + """ + Per-channel statistics for normalizing and denormalizing the latent representation. + This statics is computed over the entire dataset and stored in model's checkpoint under AudioVAE state_dict. + """ + + def __init__(self, latent_channels: int = 128) -> None: + super().__init__() + self.register_buffer("std-of-means", torch.empty(latent_channels)) + self.register_buffer("mean-of-means", torch.empty(latent_channels)) + + def un_normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x) + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x) diff --git a/ltx2/ltx_core/model/audio_vae/resnet.py b/ltx2/ltx_core/model/audio_vae/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a529d6da8853daa1d0e2bb85e8fa2432dff5723e --- /dev/null +++ b/ltx2/ltx_core/model/audio_vae/resnet.py @@ -0,0 +1,176 @@ +from typing import Tuple + +import torch + +from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d +from ltx_core.model.audio_vae.causality_axis import CausalityAxis +from ltx_core.model.common.normalization import NormType, build_normalization_layer + +LRELU_SLOPE = 0.1 + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = torch.nn.ModuleList( + [ + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding="same", + ), + ] + ) + + self.convs2 = torch.nn.ModuleList( + [ + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding="same", + ), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for conv1, conv2 in zip(self.convs1, self.convs2, strict=True): + xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE) + xt = conv1(xt) + xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE) + xt = conv2(xt) + x = xt + x + return x + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int] = (1, 3)): + super(ResBlock2, self).__init__() + self.convs = torch.nn.ModuleList( + [ + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding="same", + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding="same", + ), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for conv in self.convs: + xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE) + xt = conv(xt) + x = xt + x + return x + + +class ResnetBlock(torch.nn.Module): + def __init__( + self, + *, + in_channels: int, + out_channels: int | None = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + norm_type: NormType = NormType.GROUP, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ) -> None: + super().__init__() + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP: + raise ValueError("Causal ResnetBlock with GroupNorm is not supported.") + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = build_normalization_layer(in_channels, normtype=norm_type) + self.non_linearity = torch.nn.SiLU() + self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = build_normalization_layer(out_channels, normtype=norm_type) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.nin_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis + ) + + def forward( + self, + x: torch.Tensor, + temb: torch.Tensor | None = None, + ) -> torch.Tensor: + h = x + h = self.norm1(h) + h = self.non_linearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = self.non_linearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x) + + return x + h diff --git a/ltx2/ltx_core/model/audio_vae/upsample.py b/ltx2/ltx_core/model/audio_vae/upsample.py new file mode 100644 index 0000000000000000000000000000000000000000..3046e210ec62fef1118f88126b1e64ed7149a15f --- /dev/null +++ b/ltx2/ltx_core/model/audio_vae/upsample.py @@ -0,0 +1,106 @@ +from typing import Set, Tuple + +import torch + +from ltx_core.model.audio_vae.attention import AttentionType, make_attn +from ltx_core.model.audio_vae.causal_conv_2d import make_conv2d +from ltx_core.model.audio_vae.causality_axis import CausalityAxis +from ltx_core.model.audio_vae.resnet import ResnetBlock +from ltx_core.model.common.normalization import NormType + + +class Upsample(torch.nn.Module): + def __init__( + self, + in_channels: int, + with_conv: bool, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ) -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + if self.with_conv: + self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + # Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n. + # For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2]. + # The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2], + # So the output elements rely on the following windows: + # 0: [-,-,0] + # 1: [-,0,0] + # 2: [0,0,1] + # 3: [0,1,1] + # 4: [1,1,2] + # 5: [1,2,2] + # Notice that the first and second elements in the output rely only on the first element in the input, + # while all other elements rely on two elements in the input. + # So we can drop the first element to undo the padding (rather than the last element). + # This is a no-op for non-causal convolutions. + match self.causality_axis: + case CausalityAxis.NONE: + pass # x remains unchanged + case CausalityAxis.HEIGHT: + x = x[:, :, 1:, :] + case CausalityAxis.WIDTH: + x = x[:, :, :, 1:] + case CausalityAxis.WIDTH_COMPATIBILITY: + pass # x remains unchanged + case _: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + return x + + +def build_upsampling_path( # noqa: PLR0913 + *, + ch: int, + ch_mult: Tuple[int, ...], + num_resolutions: int, + num_res_blocks: int, + resolution: int, + temb_channels: int, + dropout: float, + norm_type: NormType, + causality_axis: CausalityAxis, + attn_type: AttentionType, + attn_resolutions: Set[int], + resamp_with_conv: bool, + initial_block_channels: int, +) -> tuple[torch.nn.ModuleList, int]: + """Build the upsampling path with residual blocks, attention, and upsampling layers.""" + up_modules = torch.nn.ModuleList() + block_in = initial_block_channels + curr_res = resolution // (2 ** (num_resolutions - 1)) + + for level in reversed(range(num_resolutions)): + stage = torch.nn.Module() + stage.block = torch.nn.ModuleList() + stage.attn = torch.nn.ModuleList() + block_out = ch * ch_mult[level] + + for _ in range(num_res_blocks + 1): + stage.block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=temb_channels, + dropout=dropout, + norm_type=norm_type, + causality_axis=causality_axis, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + stage.attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type)) + + if level != 0: + stage.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis) + curr_res *= 2 + + up_modules.insert(0, stage) + + return up_modules, block_in diff --git a/ltx2/ltx_core/model/audio_vae/vocoder.py b/ltx2/ltx_core/model/audio_vae/vocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f78e3e86ad9fac1620e9b8e05a8b8f8ae52609c2 --- /dev/null +++ b/ltx2/ltx_core/model/audio_vae/vocoder.py @@ -0,0 +1,594 @@ +import math +from typing import List + +import einops +import torch +import torch.nn.functional as F +from torch import nn + +from ltx_core.model.audio_vae.resnet import LRELU_SLOPE, ResBlock1 + + +def get_padding(kernel_size: int, dilation: int = 1) -> int: + return int((kernel_size * dilation - dilation) / 2) + + +# --------------------------------------------------------------------------- +# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2 +# Adopted from https://github.com/NVIDIA/BigVGAN +# --------------------------------------------------------------------------- + + +def _sinc(x: torch.Tensor) -> torch.Tensor: + return torch.where( + x == 0, + torch.tensor(1.0, device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x, + ) + + +def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor: + even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + delta_f = 4 * half_width + amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if amplitude > 50.0: + beta = 0.1102 * (amplitude - 8.7) + elif amplitude >= 21.0: + beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0) + else: + beta = 0.0 + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + time = torch.arange(-half_size, half_size) + 0.5 if even else torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time) + filter_ /= filter_.sum() + return filter_.view(1, 1, kernel_size) + + +class LowPassFilter1d(nn.Module): + def __init__( + self, + cutoff: float = 0.5, + half_width: float = 0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = "replicate", + kernel_size: int = 12, + ) -> None: + super().__init__() + if cutoff < -0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + self.register_buffer("filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + _, n_channels, _ = x.shape + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) + return F.conv1d(x, self.filter.expand(n_channels, -1, -1), stride=self.stride, groups=n_channels) + + +class UpSample1d(nn.Module): + def __init__( + self, + ratio: int = 2, + kernel_size: int | None = None, + persistent: bool = True, + window_type: str = "kaiser", + ) -> None: + super().__init__() + self.ratio = ratio + self.stride = ratio + + if window_type == "hann": + # Hann-windowed sinc filter equivalent to torchaudio.functional.resample + rolloff = 0.99 + lowpass_filter_width = 6 + width = math.ceil(lowpass_filter_width / rolloff) + self.kernel_size = 2 * width * ratio + 1 + self.pad = width + self.pad_left = 2 * width * ratio + self.pad_right = self.kernel_size - ratio + time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff + time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width) + window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2 + sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1) + else: + # Kaiser-windowed sinc filter (BigVGAN default). + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + sinc_filter = kaiser_sinc_filter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + kernel_size=self.kernel_size, + ) + + self.register_buffer("filter", sinc_filter, persistent=persistent) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + _, n_channels, _ = x.shape + x = F.pad(x, (self.pad, self.pad), mode="replicate") + filt = self.filter.to(dtype=x.dtype, device=x.device).expand(n_channels, -1, -1) + x = self.ratio * F.conv_transpose1d(x, filt, stride=self.stride, groups=n_channels) + return x[..., self.pad_left : -self.pad_right] + + +class DownSample1d(nn.Module): + def __init__(self, ratio: int = 2, kernel_size: int | None = None) -> None: + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.lowpass = LowPassFilter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lowpass(x) + + +class Activation1d(nn.Module): + def __init__( + self, + activation: nn.Module, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + ) -> None: + super().__init__() + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.upsample(x) + x = self.act(x) + return self.downsample(x) + + +class Snake(nn.Module): + def __init__( + self, + in_features: int, + alpha: float = 1.0, + alpha_trainable: bool = True, + alpha_logscale: bool = True, + ) -> None: + super().__init__() + self.alpha_logscale = alpha_logscale + self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha) + self.alpha.requires_grad = alpha_trainable + self.eps = 1e-9 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + return x + (1.0 / (alpha + self.eps)) * torch.sin(x * alpha).pow(2) + + +class SnakeBeta(nn.Module): + def __init__( + self, + in_features: int, + alpha: float = 1.0, + alpha_trainable: bool = True, + alpha_logscale: bool = True, + ) -> None: + super().__init__() + self.alpha_logscale = alpha_logscale + self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha) + self.alpha.requires_grad = alpha_trainable + self.beta = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha) + self.beta.requires_grad = alpha_trainable + self.eps = 1e-9 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + return x + (1.0 / (beta + self.eps)) * torch.sin(x * alpha).pow(2) + + +class AMPBlock1(nn.Module): + def __init__( + self, + channels: int, + kernel_size: int = 3, + dilation: tuple[int, int, int] = (1, 3, 5), + activation: str = "snake", + ) -> None: + super().__init__() + act_cls = SnakeBeta if activation == "snakebeta" else Snake + self.convs1 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)), + nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)), + nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)), + ] + ) + + self.acts1 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs1))]) + self.acts2 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs2))]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2, strict=True): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = x + xt + return x + + +class Vocoder(torch.nn.Module): + """ + Vocoder model for synthesizing audio from Mel spectrograms. + Args: + resblock_kernel_sizes: List of kernel sizes for the residual blocks. + This value is read from the checkpoint at `config.vocoder.resblock_kernel_sizes`. + upsample_rates: List of upsampling rates. + This value is read from the checkpoint at `config.vocoder.upsample_rates`. + upsample_kernel_sizes: List of kernel sizes for the upsampling layers. + This value is read from the checkpoint at `config.vocoder.upsample_kernel_sizes`. + resblock_dilation_sizes: List of dilation sizes for the residual blocks. + This value is read from the checkpoint at `config.vocoder.resblock_dilation_sizes`. + upsample_initial_channel: Initial number of channels for the upsampling layers. + This value is read from the checkpoint at `config.vocoder.upsample_initial_channel`. + resblock: Type of residual block to use ("1", "2", or "AMP1"). + This value is read from the checkpoint at `config.vocoder.resblock`. + output_sampling_rate: Waveform sample rate. + This value is read from the checkpoint at `config.vocoder.output_sampling_rate`. + activation: Activation type for BigVGAN v2 ("snake" or "snakebeta"). Only used when resblock="AMP1". + use_tanh_at_final: Apply tanh at the output (when apply_final_activation=True). + apply_final_activation: Whether to apply the final tanh/clamp activation. + use_bias_at_final: Whether to use bias in the final conv layer. + """ + + def __init__( # noqa: PLR0913 + self, + resblock_kernel_sizes: List[int] | None = None, + upsample_rates: List[int] | None = None, + upsample_kernel_sizes: List[int] | None = None, + resblock_dilation_sizes: List[List[int]] | None = None, + upsample_initial_channel: int = 1024, + resblock: str = "1", + output_sampling_rate: int = 24000, + activation: str = "snake", + use_tanh_at_final: bool = True, + apply_final_activation: bool = True, + use_bias_at_final: bool = True, + ) -> None: + super().__init__() + + # Mutable default values are not supported as default arguments. + if resblock_kernel_sizes is None: + resblock_kernel_sizes = [3, 7, 11] + if upsample_rates is None: + upsample_rates = [6, 5, 2, 2, 2] + if upsample_kernel_sizes is None: + upsample_kernel_sizes = [16, 15, 8, 4, 4] + if resblock_dilation_sizes is None: + resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + + self.output_sampling_rate = output_sampling_rate + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.use_tanh_at_final = use_tanh_at_final + self.apply_final_activation = apply_final_activation + self.is_amp = resblock == "AMP1" + + # All production checkpoints are stereo: 128 input channels (2 stereo channels x 64 mel + # bins each), 2 output channels. + self.conv_pre = nn.Conv1d( + in_channels=128, + out_channels=upsample_initial_channel, + kernel_size=7, + stride=1, + padding=3, + ) + resblock_cls = ResBlock1 if resblock == "1" else AMPBlock1 + + self.ups = nn.ModuleList( + nn.ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + kernel_size, + stride, + padding=(kernel_size - stride) // 2, + ) + for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True)) + ) + + final_channels = upsample_initial_channel // (2 ** len(upsample_rates)) + self.resblocks = nn.ModuleList() + + for i in range(len(upsample_rates)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes, strict=True): + if self.is_amp: + self.resblocks.append(resblock_cls(ch, kernel_size, dilations, activation=activation)) + else: + self.resblocks.append(resblock_cls(ch, kernel_size, dilations)) + + if self.is_amp: + self.act_post: nn.Module = Activation1d(SnakeBeta(final_channels)) + else: + self.act_post = nn.LeakyReLU() + + # All production checkpoints are stereo: this final conv maps `final_channels` to 2 output channels (stereo). + self.conv_post = nn.Conv1d( + in_channels=final_channels, + out_channels=2, + kernel_size=7, + stride=1, + padding=3, + bias=use_bias_at_final, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the vocoder. + Args: + x: Input Mel spectrogram tensor. Can be either: + - 3D: (batch_size, time, mel_bins) for mono + - 4D: (batch_size, 2, time, mel_bins) for stereo + Returns: + Audio waveform tensor of shape (batch_size, out_channels, audio_length) + """ + x = x.transpose(2, 3) # (batch, channels, time, mel_bins) -> (batch, channels, mel_bins, time) + + if x.dim() == 4: # stereo + assert x.shape[1] == 2, "Input must have 2 channels for stereo" + x = einops.rearrange(x, "b s c t -> b (s c) t") + + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + if not self.is_amp: + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + start = i * self.num_kernels + end = start + self.num_kernels + + # Evaluate all resblocks with the same input tensor so they can run + # independently (and thus in parallel on accelerator hardware) before + # aggregating their outputs via mean. + block_outputs = torch.stack( + [self.resblocks[idx](x) for idx in range(start, end)], + dim=0, + ) + x = block_outputs.mean(dim=0) + + x = self.act_post(x) + x = self.conv_post(x) + + if self.apply_final_activation: + x = torch.tanh(x) if self.use_tanh_at_final else torch.clamp(x, -1, 1) + + return x + + +class _STFTFn(nn.Module): + """Implements STFT as a convolution with precomputed DFT x Hann-window bases. + The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal + Hann window are stored as buffers and loaded from the checkpoint. Using the exact + bfloat16 bases from training ensures the mel values fed to the BWE generator are + bit-identical to what it was trained on. + """ + + def __init__(self, filter_length: int, hop_length: int, win_length: int) -> None: + super().__init__() + self.hop_length = hop_length + self.win_length = win_length + n_freqs = filter_length // 2 + 1 + self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length)) + self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length)) + + def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Compute magnitude and phase spectrogram from a batch of waveforms. + Applies causal (left-only) padding of win_length - hop_length samples so that + each output frame depends only on past and present input β€” no lookahead. + Args: + y: Waveform tensor of shape (B, T). + Returns: + magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames). + phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames). + """ + if y.dim() == 2: + y = y.unsqueeze(1) # (B, 1, T) + left_pad = max(0, self.win_length - self.hop_length) # causal: left-only + y = F.pad(y, (left_pad, 0)) + spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0) + n_freqs = spec.shape[1] // 2 + real, imag = spec[:, :n_freqs], spec[:, n_freqs:] + magnitude = torch.sqrt(real**2 + imag**2) + phase = torch.atan2(imag.float(), real.float()).to(real.dtype) + return magnitude, phase + + +class MelSTFT(nn.Module): + """Causal log-mel spectrogram module whose buffers are loaded from the checkpoint. + Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input + waveform and projecting the linear magnitude spectrum onto the mel filterbank. + The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint + (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis). + """ + + def __init__( + self, + filter_length: int, + hop_length: int, + win_length: int, + n_mel_channels: int, + ) -> None: + super().__init__() + self.stft_fn = _STFTFn(filter_length, hop_length, win_length) + + # Initialized to zeros; load_state_dict overwrites with the checkpoint's + # exact bfloat16 filterbank (vocoder.mel_stft.mel_basis, shape [n_mels, n_freqs]). + n_freqs = filter_length // 2 + 1 + self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs)) + + def mel_spectrogram(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute log-mel spectrogram and auxiliary spectral quantities. + Args: + y: Waveform tensor of shape (B, T). + Returns: + log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames). + magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames). + phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames). + energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames). + """ + magnitude, phase = self.stft_fn(y) + energy = torch.norm(magnitude, dim=1) + mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude) + log_mel = torch.log(torch.clamp(mel, min=1e-5)) + return log_mel, magnitude, phase, energy + + +class VocoderWithBWE(nn.Module): + """Vocoder with bandwidth extension (BWE) upsampling. + Chains a mel-to-wav vocoder with a BWE module that upsamples the output + to a higher sample rate. The BWE computes a mel spectrogram from the + vocoder output, runs it through a second generator to predict a residual, + and adds it to a sinc-resampled skip connection. + The forward pass runs in fp32 via autocast to avoid bfloat16 accumulation + errors that degrade spectral metrics by 40-90%. + """ + + def __init__( + self, + vocoder: Vocoder, + bwe_generator: Vocoder, + mel_stft: MelSTFT, + input_sampling_rate: int, + output_sampling_rate: int, + hop_length: int, + ) -> None: + super().__init__() + self.vocoder = vocoder + self.bwe_generator = bwe_generator + self.mel_stft = mel_stft + self.input_sampling_rate = input_sampling_rate + self.output_sampling_rate = output_sampling_rate + self.hop_length = hop_length + # Compute the resampler on CPU so the sinc filter is materialized even when + # the model is constructed on meta device (SingleGPUModelBuilder pattern). + # The filter is not stored in the checkpoint (persistent=False). + with torch.device("cpu"): + self.resampler = UpSample1d( + ratio=output_sampling_rate // input_sampling_rate, persistent=False, window_type="hann" + ) + + @property + def conv_pre(self) -> nn.Conv1d: + return self.vocoder.conv_pre + + @property + def conv_post(self) -> nn.Conv1d: + return self.vocoder.conv_post + + def _compute_mel(self, audio: torch.Tensor) -> torch.Tensor: + """Compute log-mel spectrogram from waveform using causal STFT bases. + Args: + audio: Waveform tensor of shape (B, C, T). + Returns: + mel: Log-mel spectrogram of shape (B, C, n_mels, T_frames). + """ + batch, n_channels, _ = audio.shape + flat = audio.reshape(batch * n_channels, -1) # (B*C, T) + mel, _, _, _ = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames) + return mel.reshape(batch, n_channels, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames) + + def forward(self, mel_spec: torch.Tensor) -> torch.Tensor: + """Run the full vocoder + BWE forward pass. + Runs in float32 regardless of weight or input dtype. bfloat16 arithmetic + causes 40-90% spectral metric degradation due to accumulation errors + compounding through 108 sequential convolutions in the BigVGAN v2 architecture. + Args: + mel_spec: Mel spectrogram of shape (B, 2, T, mel_bins) for stereo + or (B, T, mel_bins) for mono. Same format as Vocoder.forward. + Returns: + Waveform tensor of shape (B, out_channels, T_out) clipped to [-1, 1]. + """ + input_dtype = mel_spec.dtype + # Run the entire forward pass in fp32. bfloat16 accumulation errors + # compound through 108 sequential convolutions and degrade spectral + # metrics (mel_l1, MRSTFT) by 40-90% while perceptual quality (CDPAM) + # is unaffected. fp32 eliminates this degradation. + # We use autocast(dtype=float32) rather than self.float() because it + # upcasts bf16 weights per-op at kernel level, avoiding the temporary + # memory spike of self.float() / self.to(original_dtype). + # Benchmarked on H100 (128.5M-param model): + # autocast fp32: +70 MB peak VRAM, 123 ms (vs 482 MB / 95 ms for bf16) + # model.float(): +324 MB peak VRAM, 149 ms + # Tested: both approaches produce bit-identical output. + + with torch.autocast(device_type=mel_spec.device.type, dtype=torch.float32): + x = self.vocoder(mel_spec.float()) + _, _, length_low_rate = x.shape + output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate + + # Pad to multiple of hop_length for exact mel frame count + remainder = length_low_rate % self.hop_length + if remainder != 0: + x = F.pad(x, (0, self.hop_length - remainder)) + + # Compute mel spectrogram from vocoder output: (B, C, n_mels, T_frames) + mel = self._compute_mel(x) + + # Vocoder.forward expects (B, C, T, mel_bins) β€” transpose before calling bwe_generator + mel_for_bwe = mel.transpose(2, 3) # (B, C, T_frames, mel_bins) + residual = self.bwe_generator(mel_for_bwe) + skip = self.resampler(x) + assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}" + + return torch.clamp(residual + skip, -1, 1)[..., :output_length].to(input_dtype) diff --git a/ltx2/ltx_core/model/common/__init__.py b/ltx2/ltx_core/model/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..42713f7d0f4f42c3d9f7f01db55b24c5e3e2a165 --- /dev/null +++ b/ltx2/ltx_core/model/common/__init__.py @@ -0,0 +1,9 @@ +"""Common model utilities.""" + +from ltx_core.model.common.normalization import NormType, PixelNorm, build_normalization_layer + +__all__ = [ + "NormType", + "PixelNorm", + "build_normalization_layer", +] diff --git a/ltx2/ltx_core/model/common/normalization.py b/ltx2/ltx_core/model/common/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..9877db1178286dca93df9c51f8929d9a915fe853 --- /dev/null +++ b/ltx2/ltx_core/model/common/normalization.py @@ -0,0 +1,59 @@ +from enum import Enum + +import torch +from torch import nn + + +class NormType(Enum): + """Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm).""" + + GROUP = "group" + PIXEL = "pixel" + + +class PixelNorm(nn.Module): + """ + Per-pixel (per-location) RMS normalization layer. + For each element along the chosen dimension, this layer normalizes the tensor + by the root-mean-square of its values across that dimension: + y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps) + """ + + def __init__(self, dim: int = 1, eps: float = 1e-8) -> None: + """ + Args: + dim: Dimension along which to compute the RMS (typically channels). + eps: Small constant added for numerical stability. + """ + super().__init__() + self.dim = dim + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Apply RMS normalization along the configured dimension. + """ + # Compute mean of squared values along `dim`, keep dimensions for broadcasting. + mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True) + # Normalize by the root-mean-square (RMS). + rms = torch.sqrt(mean_sq + self.eps) + return x / rms + + +def build_normalization_layer( + in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP +) -> nn.Module: + """ + Create a normalization layer based on the normalization type. + Args: + in_channels: Number of input channels + num_groups: Number of groups for group normalization + normtype: Type of normalization: "group" or "pixel" + Returns: + A normalization layer + """ + if normtype == NormType.GROUP: + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if normtype == NormType.PIXEL: + return PixelNorm(dim=1, eps=1e-6) + raise ValueError(f"Invalid normalization type: {normtype}") diff --git a/ltx2/ltx_core/model/model_protocol.py b/ltx2/ltx_core/model/model_protocol.py new file mode 100644 index 0000000000000000000000000000000000000000..37a07f781523ec1132381415fec9f454f548b570 --- /dev/null +++ b/ltx2/ltx_core/model/model_protocol.py @@ -0,0 +1,10 @@ +from typing import Protocol, TypeVar + +ModelType = TypeVar("ModelType") + + +class ModelConfigurator(Protocol[ModelType]): + """Protocol for model loader classes that instantiates models from a configuration dictionary.""" + + @classmethod + def from_config(cls, config: dict) -> ModelType: ... diff --git a/ltx2/ltx_core/model/transformer/__init__.py b/ltx2/ltx_core/model/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..43874b8bf20f1dedc8c427703d592dcc5f9ce7f6 --- /dev/null +++ b/ltx2/ltx_core/model/transformer/__init__.py @@ -0,0 +1,18 @@ +"""Transformer model components.""" + +from ltx_core.model.transformer.modality import Modality +from ltx_core.model.transformer.model import LTXModel, X0Model +from ltx_core.model.transformer.model_configurator import ( + LTXV_MODEL_COMFY_RENAMING_MAP, + LTXModelConfigurator, + LTXVideoOnlyModelConfigurator, +) + +__all__ = [ + "LTXV_MODEL_COMFY_RENAMING_MAP", + "LTXModel", + "LTXModelConfigurator", + "LTXVideoOnlyModelConfigurator", + "Modality", + "X0Model", +] diff --git a/ltx2/ltx_core/model/transformer/adaln.py b/ltx2/ltx_core/model/transformer/adaln.py new file mode 100644 index 0000000000000000000000000000000000000000..63200f899299220e822fefdfa5b2088f79da791e --- /dev/null +++ b/ltx2/ltx_core/model/transformer/adaln.py @@ -0,0 +1,45 @@ +from typing import Optional, Tuple + +import torch + +from ltx_core.model.transformer.timestep_embedding import PixArtAlphaCombinedTimestepSizeEmbeddings + +# Number of AdaLN modulation parameters per transformer block. +# Base: 2 params (shift + scale) x 3 norms (self-attn, feed-forward, output). +ADALN_NUM_BASE_PARAMS = 6 +# Cross-attention AdaLN adds 3 more (scale, shift, gate) for the CA norm. +ADALN_NUM_CROSS_ATTN_PARAMS = 3 + + +def adaln_embedding_coefficient(cross_attention_adaln: bool) -> int: + """Total number of AdaLN parameters per block.""" + return ADALN_NUM_BASE_PARAMS + (ADALN_NUM_CROSS_ATTN_PARAMS if cross_attention_adaln else 0) + + +class AdaLayerNormSingle(torch.nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, embedding_coefficient: int = 6): + super().__init__() + + self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( + embedding_dim, + size_emb_dim=embedding_dim // 3, + ) + + self.silu = torch.nn.SiLU() + self.linear = torch.nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + embedded_timestep = self.emb(timestep, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep diff --git a/ltx2/ltx_core/model/transformer/attention.py b/ltx2/ltx_core/model/transformer/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..1a7b2d73f5e1ed1e8063cd8cc4c9a97758b0faa7 --- /dev/null +++ b/ltx2/ltx_core/model/transformer/attention.py @@ -0,0 +1,252 @@ +from enum import Enum +from typing import Protocol + +import torch + +from ltx_core.model.transformer.rope import LTXRopeType, apply_rotary_emb + +memory_efficient_attention = None +flash_attn_interface = None +try: + from xformers.ops import memory_efficient_attention +except ImportError: + memory_efficient_attention = None +try: + # FlashAttention3 and XFormersAttention cannot be used together + if memory_efficient_attention is None: + import flash_attn_interface +except ImportError: + flash_attn_interface = None + + +class AttentionCallable(Protocol): + def __call__( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None + ) -> torch.Tensor: ... + + +class PytorchAttention(AttentionCallable): + def __call__( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None + ) -> torch.Tensor: + b, _, dim_head = q.shape + dim_head //= heads + q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)) + + if mask is not None: + # add a batch dimension if there isn't already one + if mask.ndim == 2: + mask = mask.unsqueeze(0) + # add a heads dimension if there isn't already one + if mask.ndim == 3: + mask = mask.unsqueeze(1) + + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + out = out.transpose(1, 2).reshape(b, -1, heads * dim_head) + return out + + +class XFormersAttention(AttentionCallable): + def __call__( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + heads: int, + mask: torch.Tensor | None = None, + ) -> torch.Tensor: + if memory_efficient_attention is None: + raise RuntimeError("XFormersAttention was selected but `xformers` is not installed.") + + b, _, dim_head = q.shape + dim_head //= heads + + # xformers expects [B, M, H, K] + q, k, v = (t.view(b, -1, heads, dim_head) for t in (q, k, v)) + + # Use v.dtype as the target since q/k get cast to v.dtype for xformers + target_dtype = v.dtype + + if mask is not None: + # add a singleton batch dimension + if mask.ndim == 2: + mask = mask.unsqueeze(0) + # add a singleton heads dimension + if mask.ndim == 3: + mask = mask.unsqueeze(1) + # pad to a multiple of 8 + pad = 8 - mask.shape[-1] % 8 + # the xformers docs says that it's allowed to have a mask of shape (1, Nq, Nk) + # but when using separated heads, the shape has to be (B, H, Nq, Nk) + # in flux, this matrix ends up being over 1GB + # here, we create a mask with the same batch/head size as the input mask (potentially singleton or full) + mask_out = torch.empty( + [mask.shape[0], mask.shape[1], q.shape[1], mask.shape[-1] + pad], dtype=target_dtype, device=q.device + ) + + mask_out[..., : mask.shape[-1]] = mask + # doesn't this remove the padding again?? + mask = mask_out[..., : mask.shape[-1]] + mask = mask.expand(b, heads, -1, -1) + + out = memory_efficient_attention(q.to(target_dtype), k.to(target_dtype), v, attn_bias=mask, p=0.0) + out = out.reshape(b, -1, heads * dim_head) + return out + + +class FlashAttention3(AttentionCallable): + def __call__( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + heads: int, + mask: torch.Tensor | None = None, + ) -> torch.Tensor: + if flash_attn_interface is None: + raise RuntimeError("FlashAttention3 was selected but `FlashAttention3` is not installed.") + + b, _, dim_head = q.shape + dim_head //= heads + + q, k, v = (t.view(b, -1, heads, dim_head) for t in (q, k, v)) + + if mask is not None: + raise NotImplementedError("Mask is not supported for FlashAttention3") + + out = flash_attn_interface.flash_attn_func(q.to(v.dtype), k.to(v.dtype), v) + out = out.reshape(b, -1, heads * dim_head) + return out + + +class AttentionFunction(Enum): + PYTORCH = "pytorch" + XFORMERS = "xformers" + FLASH_ATTENTION_3 = "flash_attention_3" + DEFAULT = "default" + + def to_callable(self) -> AttentionCallable: + """Resolve to a concrete callable. Use this at module init time so that + torch.compile can trace through the attention call without graph breaks.""" + if self is AttentionFunction.PYTORCH: + return PytorchAttention() + elif self is AttentionFunction.XFORMERS: + return XFormersAttention() + elif self is AttentionFunction.FLASH_ATTENTION_3: + return FlashAttention3() + else: + # Default behavior: XFormers if installed else - PyTorch + return XFormersAttention() if memory_efficient_attention is not None else PytorchAttention() + + +class Attention(torch.nn.Module): + def __init__( + self, + query_dim: int, + context_dim: int | None = None, + heads: int = 8, + dim_head: int = 64, + norm_eps: float = 1e-6, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + attention_function: AttentionCallable | AttentionFunction = AttentionFunction.DEFAULT, + apply_gated_attention: bool = False, + ) -> None: + super().__init__() + self.rope_type = rope_type + self.attention_function = ( + attention_function.to_callable() + if isinstance(attention_function, AttentionFunction) + else attention_function + ) + + inner_dim = dim_head * heads + context_dim = query_dim if context_dim is None else context_dim + + self.heads = heads + self.dim_head = dim_head + + self.q_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps) + self.k_norm = torch.nn.RMSNorm(inner_dim, eps=norm_eps) + + self.to_q = torch.nn.Linear(query_dim, inner_dim, bias=True) + self.to_k = torch.nn.Linear(context_dim, inner_dim, bias=True) + self.to_v = torch.nn.Linear(context_dim, inner_dim, bias=True) + + # Optional per-head gating + if apply_gated_attention: + self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True) + else: + self.to_gate_logits = None + + self.to_out = torch.nn.Sequential(torch.nn.Linear(inner_dim, query_dim, bias=True), torch.nn.Identity()) + + def forward( + self, + x: torch.Tensor, + context: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + pe: torch.Tensor | None = None, + k_pe: torch.Tensor | None = None, + perturbation_mask: torch.Tensor | None = None, + all_perturbed: bool = False, + ) -> torch.Tensor: + """Multi-head attention with optional RoPE, perturbation masking, and per-head gating. + When ``perturbation_mask`` is all zeros, the expensive query/key path + (linear projections, RMSNorm, RoPE) is skipped entirely and only the + value projection is used as a pass-through. + Args: + x: Query input tensor of shape ``(B, T, query_dim)``. + context: Key/value context tensor of shape ``(B, S, context_dim)``. + Falls back to ``x`` (self-attention) when *None*. + mask: Optional attention mask. Interpretation depends on the attention + backend (additive bias for xformers/PyTorch SDPA). + pe: Rotary positional embeddings applied to both ``q`` and ``k``. + k_pe: Separate rotary positional embeddings for ``k`` only. When + *None*, ``pe`` is reused for keys. + perturbation_mask: Optional mask in ``[0, 1]`` that + blends the attention output with the raw value projection: + ``out = attn_out * mask + v * (1 - mask)``. + **1** keeps the full attention output, **0** bypasses attention + and passes the value projection through unchanged. + *None* or all-ones means standard attention; all-zeros skips + the query/key path entirely for efficiency. + all_perturbed: Whether all perturbations are active for this block. + Returns: + Output tensor of shape ``(B, T, query_dim)``. + """ + context = x if context is None else context + use_attention = not all_perturbed + + v = self.to_v(context) + + if not use_attention: + out = v + else: + q = self.to_q(x) + k = self.to_k(context) + + q = self.q_norm(q) + k = self.k_norm(k) + + if pe is not None: + q = apply_rotary_emb(q, pe, self.rope_type) + k = apply_rotary_emb(k, pe if k_pe is None else k_pe, self.rope_type) + + out = self.attention_function(q, k, v, self.heads, mask) # (B, T, H*D) + + if perturbation_mask is not None: + out = out * perturbation_mask + v * (1 - perturbation_mask) + + # Apply per-head gating if enabled + if self.to_gate_logits is not None: + gate_logits = self.to_gate_logits(x) # (B, T, H) + b, t, _ = out.shape + # Reshape to (B, T, H, D) for per-head gating + out = out.view(b, t, self.heads, self.dim_head) + # Apply gating: 2 * sigmoid(x) so that zero-init gives identity (2 * 0.5 = 1.0) + gates = 2.0 * torch.sigmoid(gate_logits) # (B, T, H) + out = out * gates.unsqueeze(-1) # (B, T, H, D) * (B, T, H, 1) + # Reshape back to (B, T, H*D) + out = out.view(b, t, self.heads * self.dim_head) + + return self.to_out(out) diff --git a/ltx2/ltx_core/model/transformer/compiling.py b/ltx2/ltx_core/model/transformer/compiling.py new file mode 100644 index 0000000000000000000000000000000000000000..aa75838cb4cb785013a048fd3272189b644d5633 --- /dev/null +++ b/ltx2/ltx_core/model/transformer/compiling.py @@ -0,0 +1,37 @@ +import torch + +from ltx_core.loader.module_ops import ModuleOps +from ltx_core.loader.sd_ops import SDOps +from ltx_core.model.transformer.model import LTXModel + + +def compile_transformer(model: LTXModel) -> LTXModel: + model.transformer_blocks = torch.nn.ModuleList(torch.compile(m) for m in model.transformer_blocks) + + def patched_dynamo_forward(*args, **kwargs) -> tuple[torch.Tensor, torch.Tensor]: + with ( + torch._inductor.config.patch(unsafe_skip_cache_dynamic_shape_guards=True), + torch._dynamo.config.patch( # type: ignore[attr-defined] + inline_inbuilt_nn_modules=True, cache_size_limit=256, allow_unspec_int_on_nn_module=True + ), + ): + return model.forward_without_compilation(*args, **kwargs) + + model.forward_without_compilation = model.forward + model.forward = patched_dynamo_forward + return model + + +COMPILE_TRANSFORMER = ModuleOps( + name="compile_transformer", + matcher=lambda model: isinstance(model, LTXModel), + mutator=lambda model: compile_transformer(model), +) + + +def modify_sd_ops_for_compilation(original_sd_ops: SDOps, number_of_blocks: int = 48) -> SDOps: + for i in range(number_of_blocks): + original_sd_ops = original_sd_ops.with_replacement( + f"transformer_blocks.{i}.", f"transformer_blocks.{i}._orig_mod." + ) + return original_sd_ops diff --git a/ltx2/ltx_core/model/transformer/feed_forward.py b/ltx2/ltx_core/model/transformer/feed_forward.py new file mode 100644 index 0000000000000000000000000000000000000000..a55b5a73384e4fb56051388bbe2b86d69650bff0 --- /dev/null +++ b/ltx2/ltx_core/model/transformer/feed_forward.py @@ -0,0 +1,15 @@ +import torch + +from ltx_core.model.transformer.gelu_approx import GELUApprox + + +class FeedForward(torch.nn.Module): + def __init__(self, dim: int, dim_out: int, mult: int = 4) -> None: + super().__init__() + inner_dim = int(dim * mult) + project_in = GELUApprox(dim, inner_dim) + + self.net = torch.nn.Sequential(project_in, torch.nn.Identity(), torch.nn.Linear(inner_dim, dim_out)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) diff --git a/ltx2/ltx_core/model/transformer/gelu_approx.py b/ltx2/ltx_core/model/transformer/gelu_approx.py new file mode 100644 index 0000000000000000000000000000000000000000..1923e6e22cf03bac8c5d9449ff761af3109215ee --- /dev/null +++ b/ltx2/ltx_core/model/transformer/gelu_approx.py @@ -0,0 +1,10 @@ +import torch + + +class GELUApprox(torch.nn.Module): + def __init__(self, dim_in: int, dim_out: int) -> None: + super().__init__() + self.proj = torch.nn.Linear(dim_in, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(self.proj(x), approximate="tanh") diff --git a/ltx2/ltx_core/model/transformer/modality.py b/ltx2/ltx_core/model/transformer/modality.py new file mode 100644 index 0000000000000000000000000000000000000000..dc2316aaf24ba70d329c55790ed448e7b6054ffd --- /dev/null +++ b/ltx2/ltx_core/model/transformer/modality.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import dataclasses +from dataclasses import dataclass + +import torch + + +@dataclass(frozen=True) +class Modality: + """ + Input data for a single modality (video or audio) in the transformer. + Bundles the latent tokens, timestep embeddings, positional information, + and text conditioning context for processing by the diffusion transformer. + Attributes: + latent: Patchified latent tokens, shape ``(B, T, D)`` where *B* is + the batch size, *T* is the total number of tokens (noisy + + conditioning), and *D* is the input dimension. + timesteps: Per-token timestep embeddings, shape ``(B, T)``. + positions: Positional coordinates, shape ``(B, 3, T)`` for video + (time, height, width) or ``(B, 1, T)`` for audio. + context: Text conditioning embeddings from the prompt encoder. + enabled: Whether this modality is active in the current forward pass. + context_mask: Optional mask for the text context tokens. + attention_mask: Optional 2-D self-attention mask, shape ``(B, T, T)``. + Values in ``[0, 1]`` where ``1`` = full attention and ``0`` = no + attention. ``None`` means unrestricted (full) attention between + all tokens. Built incrementally by conditioning items; see + :class:`~ltx_core.conditioning.types.attention_strength_wrapper.ConditioningItemAttentionStrengthWrapper`. + """ + + latent: ( + torch.Tensor + ) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension + sigma: torch.Tensor # Shape: (B,). Current sigma value, used for cross-attention timestep calculation. + timesteps: torch.Tensor # Shape: (B, T) where T is the number of timesteps + positions: ( + torch.Tensor + ) # Shape: (B, 3, T) for video, where 3 is the number of dimensions and T is the number of tokens + context: torch.Tensor + enabled: bool = True + context_mask: torch.Tensor | None = None + attention_mask: torch.Tensor | None = None + + def split(self, sizes: list[int]) -> list[Modality]: + """Split along the batch dimension into chunks of the given sizes.""" + n = len(sizes) + split_fields: dict[str, list[torch.Tensor | None] | list[bool]] = {} + for f in dataclasses.fields(self): + value = getattr(self, f.name) + if isinstance(value, torch.Tensor): + split_fields[f.name] = list(value.split(sizes, dim=0)) + elif value is None or isinstance(value, bool): + split_fields[f.name] = [value] * n + else: + raise TypeError(f"Cannot split field {f.name!r}: unsupported type {type(value)}") + return [Modality(**{name: parts[i] for name, parts in split_fields.items()}) for i in range(n)] diff --git a/ltx2/ltx_core/model/transformer/model.py b/ltx2/ltx_core/model/transformer/model.py new file mode 100644 index 0000000000000000000000000000000000000000..4233bce8e8460e3e2d7ce2ab7da1565745ff10a4 --- /dev/null +++ b/ltx2/ltx_core/model/transformer/model.py @@ -0,0 +1,486 @@ +from enum import Enum + +import torch + +from ltx_core.guidance.perturbations import BatchedPerturbationConfig +from ltx_core.model.transformer.adaln import AdaLayerNormSingle, adaln_embedding_coefficient +from ltx_core.model.transformer.attention import AttentionCallable, AttentionFunction +from ltx_core.model.transformer.modality import Modality +from ltx_core.model.transformer.rope import LTXRopeType +from ltx_core.model.transformer.transformer import BasicAVTransformerBlock, TransformerConfig +from ltx_core.model.transformer.transformer_args import ( + MultiModalTransformerArgsPreprocessor, + TransformerArgs, + TransformerArgsPreprocessor, +) +from ltx_core.utils import to_denoised + + +class LTXModelType(Enum): + AudioVideo = "ltx av model" + VideoOnly = "ltx video only model" + AudioOnly = "ltx audio only model" + + def is_video_enabled(self) -> bool: + return self in (LTXModelType.AudioVideo, LTXModelType.VideoOnly) + + def is_audio_enabled(self) -> bool: + return self in (LTXModelType.AudioVideo, LTXModelType.AudioOnly) + + +class LTXModel(torch.nn.Module): + """ + LTX model transformer implementation. + This class implements the transformer blocks for the LTX model. + """ + + def __init__( # noqa: PLR0913 + self, + *, + model_type: LTXModelType = LTXModelType.AudioVideo, + num_attention_heads: int = 32, + attention_head_dim: int = 128, + in_channels: int = 128, + out_channels: int = 128, + num_layers: int = 48, + cross_attention_dim: int = 4096, + norm_eps: float = 1e-06, + attention_type: AttentionFunction | AttentionCallable = AttentionFunction.DEFAULT, + positional_embedding_theta: float = 10000.0, + positional_embedding_max_pos: list[int] | None = None, + timestep_scale_multiplier: int = 1000, + use_middle_indices_grid: bool = True, + audio_num_attention_heads: int = 32, + audio_attention_head_dim: int = 64, + audio_in_channels: int = 128, + audio_out_channels: int = 128, + audio_cross_attention_dim: int = 2048, + audio_positional_embedding_max_pos: list[int] | None = None, + av_ca_timestep_scale_multiplier: int = 1, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + double_precision_rope: bool = False, + apply_gated_attention: bool = False, + caption_projection: torch.nn.Module | None = None, + audio_caption_projection: torch.nn.Module | None = None, + cross_attention_adaln: bool = False, + ): + super().__init__() + self._enable_gradient_checkpointing = False + self.cross_attention_adaln = cross_attention_adaln + self.use_middle_indices_grid = use_middle_indices_grid + self.rope_type = rope_type + self.double_precision_rope = double_precision_rope + self.timestep_scale_multiplier = timestep_scale_multiplier + self.positional_embedding_theta = positional_embedding_theta + self.model_type = model_type + cross_pe_max_pos = None + if model_type.is_video_enabled(): + if positional_embedding_max_pos is None: + positional_embedding_max_pos = [20, 2048, 2048] + self.positional_embedding_max_pos = positional_embedding_max_pos + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self._init_video( + in_channels=in_channels, + out_channels=out_channels, + norm_eps=norm_eps, + caption_projection=caption_projection, + ) + + if model_type.is_audio_enabled(): + if audio_positional_embedding_max_pos is None: + audio_positional_embedding_max_pos = [20] + self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos + self.audio_num_attention_heads = audio_num_attention_heads + self.audio_inner_dim = self.audio_num_attention_heads * audio_attention_head_dim + self._init_audio( + in_channels=audio_in_channels, + out_channels=audio_out_channels, + norm_eps=norm_eps, + caption_projection=audio_caption_projection, + ) + + if model_type.is_video_enabled() and model_type.is_audio_enabled(): + cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]) + self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier + self.audio_cross_attention_dim = audio_cross_attention_dim + self._init_audio_video(num_scale_shift_values=4) + + self._init_preprocessors(cross_pe_max_pos) + # Initialize transformer blocks + self._init_transformer_blocks( + num_layers=num_layers, + attention_head_dim=attention_head_dim if model_type.is_video_enabled() else 0, + cross_attention_dim=cross_attention_dim, + audio_attention_head_dim=audio_attention_head_dim if model_type.is_audio_enabled() else 0, + audio_cross_attention_dim=audio_cross_attention_dim, + norm_eps=norm_eps, + attention_type=attention_type, + apply_gated_attention=apply_gated_attention, + ) + + @property + def _adaln_embedding_coefficient(self) -> int: + return adaln_embedding_coefficient(self.cross_attention_adaln) + + def _init_video( + self, + in_channels: int, + out_channels: int, + norm_eps: float, + caption_projection: torch.nn.Module | None = None, + ) -> None: + """Initialize video-specific components.""" + # Video input components + self.patchify_proj = torch.nn.Linear(in_channels, self.inner_dim, bias=True) + if caption_projection is not None: + self.caption_projection = caption_projection + + self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=self._adaln_embedding_coefficient) + + self.prompt_adaln_single = ( + AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None + ) + + # Video output components + self.scale_shift_table = torch.nn.Parameter(torch.empty(2, self.inner_dim)) + self.norm_out = torch.nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=norm_eps) + self.proj_out = torch.nn.Linear(self.inner_dim, out_channels) + + def _init_audio( + self, + in_channels: int, + out_channels: int, + norm_eps: float, + caption_projection: torch.nn.Module | None = None, + ) -> None: + """Initialize audio-specific components.""" + + # Audio input components + self.audio_patchify_proj = torch.nn.Linear(in_channels, self.audio_inner_dim, bias=True) + if caption_projection is not None: + self.audio_caption_projection = caption_projection + + self.audio_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + embedding_coefficient=self._adaln_embedding_coefficient, + ) + + self.audio_prompt_adaln_single = ( + AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2) if self.cross_attention_adaln else None + ) + + # Audio output components + self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(2, self.audio_inner_dim)) + self.audio_norm_out = torch.nn.LayerNorm(self.audio_inner_dim, elementwise_affine=False, eps=norm_eps) + self.audio_proj_out = torch.nn.Linear(self.audio_inner_dim, out_channels) + + def _init_audio_video( + self, + num_scale_shift_values: int, + ) -> None: + """Initialize audio-video cross-attention components.""" + self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=num_scale_shift_values, + ) + + self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + embedding_coefficient=num_scale_shift_values, + ) + + self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle( + self.inner_dim, + embedding_coefficient=1, + ) + + self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + embedding_coefficient=1, + ) + + def _init_preprocessors( + self, + cross_pe_max_pos: int | None = None, + ) -> None: + """Initialize preprocessors for LTX.""" + + if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled(): + self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor( + patchify_proj=self.patchify_proj, + adaln=self.adaln_single, + cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single, + cross_gate_adaln=self.av_ca_a2v_gate_adaln_single, + inner_dim=self.inner_dim, + max_pos=self.positional_embedding_max_pos, + num_attention_heads=self.num_attention_heads, + cross_pe_max_pos=cross_pe_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + audio_cross_attention_dim=self.audio_cross_attention_dim, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier, + caption_projection=getattr(self, "caption_projection", None), + prompt_adaln=getattr(self, "prompt_adaln_single", None), + ) + self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor( + patchify_proj=self.audio_patchify_proj, + adaln=self.audio_adaln_single, + cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single, + cross_gate_adaln=self.av_ca_v2a_gate_adaln_single, + inner_dim=self.audio_inner_dim, + max_pos=self.audio_positional_embedding_max_pos, + num_attention_heads=self.audio_num_attention_heads, + cross_pe_max_pos=cross_pe_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + audio_cross_attention_dim=self.audio_cross_attention_dim, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + av_ca_timestep_scale_multiplier=self.av_ca_timestep_scale_multiplier, + caption_projection=getattr(self, "audio_caption_projection", None), + prompt_adaln=getattr(self, "audio_prompt_adaln_single", None), + ) + elif self.model_type.is_video_enabled(): + self.video_args_preprocessor = TransformerArgsPreprocessor( + patchify_proj=self.patchify_proj, + adaln=self.adaln_single, + inner_dim=self.inner_dim, + max_pos=self.positional_embedding_max_pos, + num_attention_heads=self.num_attention_heads, + use_middle_indices_grid=self.use_middle_indices_grid, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + caption_projection=getattr(self, "caption_projection", None), + prompt_adaln=getattr(self, "prompt_adaln_single", None), + ) + elif self.model_type.is_audio_enabled(): + self.audio_args_preprocessor = TransformerArgsPreprocessor( + patchify_proj=self.audio_patchify_proj, + adaln=self.audio_adaln_single, + inner_dim=self.audio_inner_dim, + max_pos=self.audio_positional_embedding_max_pos, + num_attention_heads=self.audio_num_attention_heads, + use_middle_indices_grid=self.use_middle_indices_grid, + timestep_scale_multiplier=self.timestep_scale_multiplier, + double_precision_rope=self.double_precision_rope, + positional_embedding_theta=self.positional_embedding_theta, + rope_type=self.rope_type, + caption_projection=getattr(self, "audio_caption_projection", None), + prompt_adaln=getattr(self, "audio_prompt_adaln_single", None), + ) + + def _init_transformer_blocks( + self, + num_layers: int, + attention_head_dim: int, + cross_attention_dim: int, + audio_attention_head_dim: int, + audio_cross_attention_dim: int, + norm_eps: float, + attention_type: AttentionFunction | AttentionCallable, + apply_gated_attention: bool, + ) -> None: + """Initialize transformer blocks for LTX.""" + video_config = ( + TransformerConfig( + dim=self.inner_dim, + heads=self.num_attention_heads, + d_head=attention_head_dim, + context_dim=cross_attention_dim, + apply_gated_attention=apply_gated_attention, + cross_attention_adaln=self.cross_attention_adaln, + ) + if self.model_type.is_video_enabled() + else None + ) + audio_config = ( + TransformerConfig( + dim=self.audio_inner_dim, + heads=self.audio_num_attention_heads, + d_head=audio_attention_head_dim, + context_dim=audio_cross_attention_dim, + apply_gated_attention=apply_gated_attention, + cross_attention_adaln=self.cross_attention_adaln, + ) + if self.model_type.is_audio_enabled() + else None + ) + self.transformer_blocks = torch.nn.ModuleList( + [ + BasicAVTransformerBlock( + idx=idx, + video=video_config, + audio=audio_config, + rope_type=self.rope_type, + norm_eps=norm_eps, + attention_function=attention_type, + ) + for idx in range(num_layers) + ] + ) + + def set_gradient_checkpointing(self, enable: bool) -> None: + """Enable or disable gradient checkpointing for transformer blocks. + Gradient checkpointing trades compute for memory by recomputing activations + during the backward pass instead of storing them. This can significantly + reduce memory usage at the cost of ~20-30% slower training. + Args: + enable: Whether to enable gradient checkpointing + """ + self._enable_gradient_checkpointing = enable + + def _process_transformer_blocks( + self, + video: TransformerArgs | None, + audio: TransformerArgs | None, + perturbations: BatchedPerturbationConfig, + ) -> tuple[TransformerArgs, TransformerArgs]: + """Process transformer blocks for LTXAV.""" + + # Process transformer blocks + for block in self.transformer_blocks: + if self._enable_gradient_checkpointing and self.training: + # Use gradient checkpointing to save memory during training. + # With use_reentrant=False, we can pass dataclasses directly - + # PyTorch will track all tensor leaves in the computation graph. + video, audio = torch.utils.checkpoint.checkpoint( + block, + video, + audio, + perturbations, + use_reentrant=False, + ) + else: + video, audio = block( + video=video, + audio=audio, + perturbations=perturbations, + ) + + return video, audio + + def _process_output( + self, + scale_shift_table: torch.Tensor, + norm_out: torch.nn.LayerNorm, + proj_out: torch.nn.Linear, + x: torch.Tensor, + embedded_timestep: torch.Tensor, + ) -> torch.Tensor: + """Process output for LTXV.""" + # Apply scale-shift modulation + scale_shift_values = ( + scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + + x = norm_out(x) + x = x * (1 + scale) + shift + x = proj_out(x) + return x + + def forward( + self, video: Modality | None, audio: Modality | None, perturbations: BatchedPerturbationConfig + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for LTX models. + Returns: + Processed output tensors + """ + if not self.model_type.is_video_enabled() and video is not None: + raise ValueError("Video is not enabled for this model") + if not self.model_type.is_audio_enabled() and audio is not None: + raise ValueError("Audio is not enabled for this model") + + video_args = self.video_args_preprocessor.prepare(video, audio) if video is not None else None + audio_args = self.audio_args_preprocessor.prepare(audio, video) if audio is not None else None + # Process transformer blocks + video_out, audio_out = self._process_transformer_blocks( + video=video_args, + audio=audio_args, + perturbations=perturbations, + ) + + # Process output + vx = ( + self._process_output( + self.scale_shift_table, self.norm_out, self.proj_out, video_out.x, video_out.embedded_timestep + ) + if video_out is not None + else None + ) + ax = ( + self._process_output( + self.audio_scale_shift_table, + self.audio_norm_out, + self.audio_proj_out, + audio_out.x, + audio_out.embedded_timestep, + ) + if audio_out is not None + else None + ) + return vx, ax + + +class LegacyX0Model(torch.nn.Module): + """ + Legacy X0 model implementation. + Returns fully denoised output based on the velocities produced by the base model. + """ + + def __init__(self, velocity_model: LTXModel): + super().__init__() + self.velocity_model = velocity_model + + def forward( + self, + video: Modality | None, + audio: Modality | None, + perturbations: BatchedPerturbationConfig, + sigma: float, + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + """ + Denoise the video and audio according to the sigma. + Returns: + Denoised video and audio + """ + vx, ax = self.velocity_model(video, audio, perturbations) + denoised_video = to_denoised(video.latent, vx, sigma) if vx is not None else None + denoised_audio = to_denoised(audio.latent, ax, sigma) if ax is not None else None + return denoised_video, denoised_audio + + +class X0Model(torch.nn.Module): + """ + X0 model implementation. + Returns fully denoised outputs based on the velocities produced by the base model. + Applies scaled denoising to the video and audio according to the timesteps = sigma * denoising_mask. + """ + + def __init__(self, velocity_model: LTXModel): + super().__init__() + self.velocity_model = velocity_model + + def forward( + self, + video: Modality | None, + audio: Modality | None, + perturbations: BatchedPerturbationConfig, + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + """ + Denoise the video and audio according to the sigma. + Returns: + Denoised video and audio + """ + vx, ax = self.velocity_model(video, audio, perturbations) + denoised_video = to_denoised(video.latent, vx, video.timesteps) if vx is not None else None + denoised_audio = to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None + return denoised_video, denoised_audio diff --git a/ltx2/ltx_core/model/transformer/model_configurator.py b/ltx2/ltx_core/model/transformer/model_configurator.py new file mode 100644 index 0000000000000000000000000000000000000000..adbbce50841c0c830056ff14428164a6bea71211 --- /dev/null +++ b/ltx2/ltx_core/model/transformer/model_configurator.py @@ -0,0 +1,152 @@ +import torch + +from ltx_core.loader.sd_ops import SDOps +from ltx_core.model.model_protocol import ModelConfigurator +from ltx_core.model.transformer.attention import AttentionFunction +from ltx_core.model.transformer.model import LTXModel, LTXModelType +from ltx_core.model.transformer.rope import LTXRopeType +from ltx_core.model.transformer.text_projection import create_caption_projection +from ltx_core.utils import check_config_value + + +class LTXModelConfigurator(ModelConfigurator[LTXModel]): + """ + Configurator for LTX model. + Used to create an LTX model from a configuration dictionary. + """ + + @classmethod + def from_config(cls: type[LTXModel], config: dict) -> LTXModel: + # Build caption projections for 19B models (projection handled in transformer). + caption_projection, audio_caption_projection = _build_caption_projections(config, is_av=True) + + config = config.get("transformer", {}) + + check_config_value(config, "dropout", 0.0) + check_config_value(config, "attention_bias", True) + check_config_value(config, "num_vector_embeds", None) + check_config_value(config, "activation_fn", "gelu-approximate") + check_config_value(config, "num_embeds_ada_norm", 1000) + check_config_value(config, "use_linear_projection", False) + check_config_value(config, "only_cross_attention", False) + check_config_value(config, "cross_attention_norm", True) + check_config_value(config, "double_self_attention", False) + check_config_value(config, "upcast_attention", False) + check_config_value(config, "standardization_norm", "rms_norm") + check_config_value(config, "norm_elementwise_affine", False) + check_config_value(config, "qk_norm", "rms_norm") + check_config_value(config, "positional_embedding_type", "rope") + check_config_value(config, "use_audio_video_cross_attention", True) + check_config_value(config, "share_ff", False) + check_config_value(config, "av_cross_ada_norm", True) + check_config_value(config, "use_middle_indices_grid", True) + + return LTXModel( + model_type=LTXModelType.AudioVideo, + num_attention_heads=config.get("num_attention_heads", 32), + attention_head_dim=config.get("attention_head_dim", 128), + in_channels=config.get("in_channels", 128), + out_channels=config.get("out_channels", 128), + num_layers=config.get("num_layers", 48), + cross_attention_dim=config.get("cross_attention_dim", 4096), + norm_eps=config.get("norm_eps", 1e-06), + attention_type=AttentionFunction(config.get("attention_type", "default")), + positional_embedding_theta=config.get("positional_embedding_theta", 10000.0), + positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]), + timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000), + use_middle_indices_grid=config.get("use_middle_indices_grid", True), + audio_num_attention_heads=config.get("audio_num_attention_heads", 32), + audio_attention_head_dim=config.get("audio_attention_head_dim", 64), + audio_in_channels=config.get("audio_in_channels", 128), + audio_out_channels=config.get("audio_out_channels", 128), + audio_cross_attention_dim=config.get("audio_cross_attention_dim", 2048), + audio_positional_embedding_max_pos=config.get("audio_positional_embedding_max_pos", [20]), + av_ca_timestep_scale_multiplier=config.get("av_ca_timestep_scale_multiplier", 1), + rope_type=LTXRopeType(config.get("rope_type", "interleaved")), + double_precision_rope=config.get("frequencies_precision", False) == "float64", + apply_gated_attention=config.get("apply_gated_attention", False), + caption_projection=caption_projection, + audio_caption_projection=audio_caption_projection, + cross_attention_adaln=config.get("cross_attention_adaln", False), + ) + + +class LTXVideoOnlyModelConfigurator(ModelConfigurator[LTXModel]): + """ + Configurator for LTX video only model. + Used to create an LTX video only model from a configuration dictionary. + """ + + @classmethod + def from_config(cls: type[LTXModel], config: dict) -> LTXModel: + # Build caption projection for 19B model (projection handled in transformer). + caption_projection, _ = _build_caption_projections(config, is_av=False) + + config = config.get("transformer", {}) + + check_config_value(config, "dropout", 0.0) + check_config_value(config, "attention_bias", True) + check_config_value(config, "num_vector_embeds", None) + check_config_value(config, "activation_fn", "gelu-approximate") + check_config_value(config, "num_embeds_ada_norm", 1000) + check_config_value(config, "use_linear_projection", False) + check_config_value(config, "only_cross_attention", False) + check_config_value(config, "cross_attention_norm", True) + check_config_value(config, "double_self_attention", False) + check_config_value(config, "upcast_attention", False) + check_config_value(config, "standardization_norm", "rms_norm") + check_config_value(config, "norm_elementwise_affine", False) + check_config_value(config, "qk_norm", "rms_norm") + check_config_value(config, "positional_embedding_type", "rope") + check_config_value(config, "use_middle_indices_grid", True) + + return LTXModel( + model_type=LTXModelType.VideoOnly, + num_attention_heads=config.get("num_attention_heads", 32), + attention_head_dim=config.get("attention_head_dim", 128), + in_channels=config.get("in_channels", 128), + out_channels=config.get("out_channels", 128), + num_layers=config.get("num_layers", 48), + cross_attention_dim=config.get("cross_attention_dim", 4096), + norm_eps=config.get("norm_eps", 1e-06), + attention_type=AttentionFunction(config.get("attention_type", "default")), + positional_embedding_theta=config.get("positional_embedding_theta", 10000.0), + positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]), + timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000), + use_middle_indices_grid=config.get("use_middle_indices_grid", True), + rope_type=LTXRopeType(config.get("rope_type", "interleaved")), + double_precision_rope=config.get("frequencies_precision", False) == "float64", + apply_gated_attention=config.get("apply_gated_attention", False), + caption_projection=caption_projection, + cross_attention_adaln=config.get("cross_attention_adaln", False), + ) + + +def _build_caption_projections( + config: dict, + is_av: bool, +) -> tuple[torch.nn.Module | None, torch.nn.Module | None]: + """Build caption projections for the transformer when projection is NOT in the text encoder. + 19B models: projection is in the transformer (caption_proj_before_connector=False). + 22B models: projection is in the text encoder, so no projections are created here. + Args: + config: Full model config dict (must contain "transformer" key). + is_av: Whether this is an audio-video model. When False, audio projection is skipped. + Returns: + Tuple of (video_caption_projection, audio_caption_projection), both None for 22B models. + """ + transformer_config = config.get("transformer", {}) + if transformer_config.get("caption_proj_before_connector", False): + return None, None + + with torch.device("meta"): + caption_projection = create_caption_projection(transformer_config) + audio_caption_projection = create_caption_projection(transformer_config, audio=True) if is_av else None + return caption_projection, audio_caption_projection + + +LTXV_MODEL_COMFY_RENAMING_MAP = ( + SDOps("LTXV_MODEL_COMFY_PREFIX_MAP") + .with_matching(prefix="model.diffusion_model.") + .with_replacement("model.diffusion_model.", "") +) diff --git a/ltx2/ltx_core/model/transformer/rope.py b/ltx2/ltx_core/model/transformer/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..2ce58d90184f4cd4da136a3045b05d04d5a89e44 --- /dev/null +++ b/ltx2/ltx_core/model/transformer/rope.py @@ -0,0 +1,204 @@ +import functools +import math +from enum import Enum +from typing import Callable, Tuple + +import numpy as np +import torch +from einops import rearrange + + +class LTXRopeType(Enum): + INTERLEAVED = "interleaved" + SPLIT = "split" + + +def apply_rotary_emb( + input_tensor: torch.Tensor, + freqs_cis: Tuple[torch.Tensor, torch.Tensor], + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, +) -> torch.Tensor: + if rope_type == LTXRopeType.INTERLEAVED: + return apply_interleaved_rotary_emb(input_tensor, *freqs_cis) + elif rope_type == LTXRopeType.SPLIT: + return apply_split_rotary_emb(input_tensor, *freqs_cis) + else: + raise ValueError(f"Invalid rope type: {rope_type}") + + +def apply_interleaved_rotary_emb( + input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor +) -> torch.Tensor: + t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) + t1, t2 = t_dup.unbind(dim=-1) + t_dup = torch.stack((-t2, t1), dim=-1) + input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") + + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out + + +def apply_split_rotary_emb( + input_tensor: torch.Tensor, cos_freqs: torch.Tensor, sin_freqs: torch.Tensor +) -> torch.Tensor: + needs_reshape = False + if input_tensor.ndim != 4 and cos_freqs.ndim == 4: + b, h, t, _ = cos_freqs.shape + input_tensor = input_tensor.reshape(b, t, h, -1).swapaxes(1, 2) + needs_reshape = True + + split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2) + first_half_input = split_input[..., :1, :] + second_half_input = split_input[..., 1:, :] + + output = split_input * cos_freqs.unsqueeze(-2) + first_half_output = output[..., :1, :] + second_half_output = output[..., 1:, :] + + first_half_output.addcmul_(-sin_freqs.unsqueeze(-2), second_half_input) + second_half_output.addcmul_(sin_freqs.unsqueeze(-2), first_half_input) + + output = rearrange(output, "... d r -> ... (d r)") + if needs_reshape: + output = output.swapaxes(1, 2).reshape(b, t, -1) + + return output + + +@functools.lru_cache(maxsize=5) +def generate_freq_grid_np( + positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int +) -> torch.Tensor: + theta = positional_embedding_theta + start = 1 + end = theta + + n_elem = 2 * positional_embedding_max_pos_count + pow_indices = np.power( + theta, + np.linspace( + np.log(start) / np.log(theta), + np.log(end) / np.log(theta), + inner_dim // n_elem, + dtype=np.float64, + ), + ) + return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32) + + +@functools.lru_cache(maxsize=5) +def generate_freq_grid_pytorch( + positional_embedding_theta: float, positional_embedding_max_pos_count: int, inner_dim: int +) -> torch.Tensor: + theta = positional_embedding_theta + start = 1 + end = theta + n_elem = 2 * positional_embedding_max_pos_count + + indices = theta ** ( + torch.linspace( + math.log(start, theta), + math.log(end, theta), + inner_dim // n_elem, + dtype=torch.float32, + ) + ) + indices = indices.to(dtype=torch.float32) + + indices = indices * math.pi / 2 + + return indices + + +def get_fractional_positions(indices_grid: torch.Tensor, max_pos: list[int]) -> torch.Tensor: + n_pos_dims = indices_grid.shape[1] + assert n_pos_dims == len(max_pos), ( + f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})" + ) + fractional_positions = torch.stack( + [indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)], + dim=-1, + ) + return fractional_positions + + +def generate_freqs( + indices: torch.Tensor, indices_grid: torch.Tensor, max_pos: list[int], use_middle_indices_grid: bool +) -> torch.Tensor: + if use_middle_indices_grid: + assert len(indices_grid.shape) == 4 + assert indices_grid.shape[-1] == 2 + indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1] + indices_grid = (indices_grid_start + indices_grid_end) / 2.0 + elif len(indices_grid.shape) == 4: + indices_grid = indices_grid[..., 0] + + fractional_positions = get_fractional_positions(indices_grid, max_pos) + indices = indices.to(device=fractional_positions.device) + + freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2) + return freqs + + +def split_freqs_cis(freqs: torch.Tensor, pad_size: int, num_attention_heads: int) -> tuple[torch.Tensor, torch.Tensor]: + cos_freq = freqs.cos() + sin_freq = freqs.sin() + + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) + + cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) + + # Reshape freqs to be compatible with multi-head attention + b = cos_freq.shape[0] + t = cos_freq.shape[1] + + cos_freq = cos_freq.reshape(b, t, num_attention_heads, -1) + sin_freq = sin_freq.reshape(b, t, num_attention_heads, -1) + + cos_freq = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2) + sin_freq = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2) + return cos_freq, sin_freq + + +def interleaved_freqs_cis(freqs: torch.Tensor, pad_size: int) -> tuple[torch.Tensor, torch.Tensor]: + cos_freq = freqs.cos().repeat_interleave(2, dim=-1) + sin_freq = freqs.sin().repeat_interleave(2, dim=-1) + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(cos_freq[:, :, :pad_size]) + cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) + sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) + return cos_freq, sin_freq + + +def precompute_freqs_cis( + indices_grid: torch.Tensor, + dim: int, + out_dtype: torch.dtype, + theta: float = 10000.0, + max_pos: list[int] | None = None, + use_middle_indices_grid: bool = False, + num_attention_heads: int = 32, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + freq_grid_generator: Callable[[float, int, int, torch.device], torch.Tensor] = generate_freq_grid_pytorch, +) -> tuple[torch.Tensor, torch.Tensor]: + if max_pos is None: + max_pos = [20, 2048, 2048] + + indices = freq_grid_generator(theta, indices_grid.shape[1], dim) + freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid) + + if rope_type == LTXRopeType.SPLIT: + expected_freqs = dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads) + else: + # 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only + n_elem = 2 * indices_grid.shape[1] + cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) + return cos_freq.to(out_dtype), sin_freq.to(out_dtype) diff --git a/ltx2/ltx_core/model/transformer/text_projection.py b/ltx2/ltx_core/model/transformer/text_projection.py new file mode 100644 index 0000000000000000000000000000000000000000..5046755d248cdfa53985659ecb5e45d5e9da9ec4 --- /dev/null +++ b/ltx2/ltx_core/model/transformer/text_projection.py @@ -0,0 +1,38 @@ +import torch + + +class PixArtAlphaTextProjection(torch.nn.Module): + """ + Projects caption embeddings using dual linear layers. + Flow: linear_1 β†’ activation β†’ linear_2 + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_features: int, hidden_size: int, out_features: int | None = None, act_fn: str = "gelu_tanh"): + super().__init__() + if out_features is None: + out_features = hidden_size + self.linear_1 = torch.nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) + if act_fn == "gelu_tanh": + self.act_1 = torch.nn.GELU(approximate="tanh") + elif act_fn == "silu": + self.act_1 = torch.nn.SiLU() + else: + raise ValueError(f"Unknown activation function: {act_fn}") + self.linear_2 = torch.nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) + + def forward(self, caption: torch.Tensor) -> torch.Tensor: + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +def create_caption_projection(transformer_config: dict, audio: bool = False) -> PixArtAlphaTextProjection: + """Create a caption projection for the transformer (V1/19B only).""" + caption_channels = transformer_config["caption_channels"] + if audio: + inner_dim = transformer_config["audio_num_attention_heads"] * transformer_config["audio_attention_head_dim"] + else: + inner_dim = transformer_config["num_attention_heads"] * transformer_config["attention_head_dim"] + return PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) diff --git a/ltx2/ltx_core/model/transformer/timestep_embedding.py b/ltx2/ltx_core/model/transformer/timestep_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..820a955894aed2b5509b876e7d06c6d3d9ab1d26 --- /dev/null +++ b/ltx2/ltx_core/model/transformer/timestep_embedding.py @@ -0,0 +1,143 @@ +import math + +import torch + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> torch.Tensor: + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class TimestepEmbedding(torch.nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + out_dim: int | None = None, + post_act_fn: str | None = None, + cond_proj_dim: int | None = None, + sample_proj_bias: bool = True, + ): + super().__init__() + + self.linear_1 = torch.nn.Linear(in_channels, time_embed_dim, sample_proj_bias) + + if cond_proj_dim is not None: + self.cond_proj = torch.nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = torch.nn.SiLU() + time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim + + self.linear_2 = torch.nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) + + if post_act_fn is None: + self.post_act = None + + def forward(self, sample: torch.Tensor, condition: torch.Tensor | None = None) -> torch.Tensor: + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Timesteps(torch.nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb + + +class PixArtAlphaCombinedTimestepSizeEmbeddings(torch.nn.Module): + """ + For PixArt-Alpha. + Reference: + https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 + """ + + def __init__( + self, + embedding_dim: int, + size_emb_dim: int, + ): + super().__init__() + + self.outdim = size_emb_dim + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward( + self, + timestep: torch.Tensor, + hidden_dtype: torch.dtype, + ) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + return timesteps_emb diff --git a/ltx2/ltx_core/model/transformer/transformer.py b/ltx2/ltx_core/model/transformer/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..af8b606d477d045556741310d4d2acffc44a9d3f --- /dev/null +++ b/ltx2/ltx_core/model/transformer/transformer.py @@ -0,0 +1,398 @@ +from dataclasses import dataclass, replace + +import torch + +from ltx_core.guidance.perturbations import BatchedPerturbationConfig, PerturbationType +from ltx_core.model.transformer.adaln import adaln_embedding_coefficient +from ltx_core.model.transformer.attention import Attention, AttentionCallable, AttentionFunction +from ltx_core.model.transformer.feed_forward import FeedForward +from ltx_core.model.transformer.rope import LTXRopeType +from ltx_core.model.transformer.transformer_args import TransformerArgs +from ltx_core.utils import rms_norm + + +@dataclass +class TransformerConfig: + dim: int + heads: int + d_head: int + context_dim: int + apply_gated_attention: bool = False + cross_attention_adaln: bool = False + + +class BasicAVTransformerBlock(torch.nn.Module): + def __init__( + self, + idx: int, + video: TransformerConfig | None = None, + audio: TransformerConfig | None = None, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + norm_eps: float = 1e-6, + attention_function: AttentionFunction | AttentionCallable = AttentionFunction.DEFAULT, + ): + super().__init__() + + self.idx = idx + if video is not None: + self.attn1 = Attention( + query_dim=video.dim, + heads=video.heads, + dim_head=video.d_head, + context_dim=None, + rope_type=rope_type, + norm_eps=norm_eps, + attention_function=attention_function, + apply_gated_attention=video.apply_gated_attention, + ) + self.attn2 = Attention( + query_dim=video.dim, + context_dim=video.context_dim, + heads=video.heads, + dim_head=video.d_head, + rope_type=rope_type, + norm_eps=norm_eps, + attention_function=attention_function, + apply_gated_attention=video.apply_gated_attention, + ) + self.ff = FeedForward(video.dim, dim_out=video.dim) + video_sst_size = adaln_embedding_coefficient(video.cross_attention_adaln) + self.scale_shift_table = torch.nn.Parameter(torch.empty(video_sst_size, video.dim)) + + if audio is not None: + self.audio_attn1 = Attention( + query_dim=audio.dim, + heads=audio.heads, + dim_head=audio.d_head, + context_dim=None, + rope_type=rope_type, + norm_eps=norm_eps, + attention_function=attention_function, + apply_gated_attention=audio.apply_gated_attention, + ) + self.audio_attn2 = Attention( + query_dim=audio.dim, + context_dim=audio.context_dim, + heads=audio.heads, + dim_head=audio.d_head, + rope_type=rope_type, + norm_eps=norm_eps, + attention_function=attention_function, + apply_gated_attention=audio.apply_gated_attention, + ) + self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim) + audio_sst_size = adaln_embedding_coefficient(audio.cross_attention_adaln) + self.audio_scale_shift_table = torch.nn.Parameter(torch.empty(audio_sst_size, audio.dim)) + + if audio is not None and video is not None: + # Q: Video, K,V: Audio + self.audio_to_video_attn = Attention( + query_dim=video.dim, + context_dim=audio.dim, + heads=audio.heads, + dim_head=audio.d_head, + rope_type=rope_type, + norm_eps=norm_eps, + attention_function=attention_function, + apply_gated_attention=video.apply_gated_attention, + ) + + # Q: Audio, K,V: Video + self.video_to_audio_attn = Attention( + query_dim=audio.dim, + context_dim=video.dim, + heads=audio.heads, + dim_head=audio.d_head, + rope_type=rope_type, + norm_eps=norm_eps, + attention_function=attention_function, + apply_gated_attention=audio.apply_gated_attention, + ) + + self.scale_shift_table_a2v_ca_audio = torch.nn.Parameter(torch.empty(5, audio.dim)) + self.scale_shift_table_a2v_ca_video = torch.nn.Parameter(torch.empty(5, video.dim)) + + self.cross_attention_adaln = (video is not None and video.cross_attention_adaln) or ( + audio is not None and audio.cross_attention_adaln + ) + + if self.cross_attention_adaln and video is not None: + self.prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, video.dim)) + if self.cross_attention_adaln and audio is not None: + self.audio_prompt_scale_shift_table = torch.nn.Parameter(torch.empty(2, audio.dim)) + + self.norm_eps = norm_eps + + def get_ada_values( + self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice + ) -> tuple[torch.Tensor, ...]: + num_ada_params = scale_shift_table.shape[0] + + ada_values = ( + scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype) + + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :] + ).unbind(dim=2) + return ada_values + + def get_av_ca_ada_values( + self, + scale_shift_table: torch.Tensor, + batch_size: int, + scale_shift_timestep: torch.Tensor, + gate_timestep: torch.Tensor, + scale_shift_indices: slice, + num_scale_shift_values: int = 4, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + scale_shift_ada_values = self.get_ada_values( + scale_shift_table[:num_scale_shift_values, :], batch_size, scale_shift_timestep, scale_shift_indices + ) + gate_ada_values = self.get_ada_values( + scale_shift_table[num_scale_shift_values:, :], batch_size, gate_timestep, slice(None, None) + ) + + scale, shift = (t.squeeze(2) for t in scale_shift_ada_values) + (gate,) = (t.squeeze(2) for t in gate_ada_values) + + return scale, shift, gate + + def _apply_text_cross_attention( + self, + x: torch.Tensor, + context: torch.Tensor, + attn: AttentionCallable, + scale_shift_table: torch.Tensor, + prompt_scale_shift_table: torch.Tensor | None, + timestep: torch.Tensor, + prompt_timestep: torch.Tensor | None, + context_mask: torch.Tensor | None, + cross_attention_adaln: bool = False, + ) -> torch.Tensor: + """Apply text cross-attention, with optional AdaLN modulation.""" + if cross_attention_adaln: + shift_q, scale_q, gate = self.get_ada_values(scale_shift_table, x.shape[0], timestep, slice(6, 9)) + return apply_cross_attention_adaln( + x, + context, + attn, + shift_q, + scale_q, + gate, + prompt_scale_shift_table, + prompt_timestep, + context_mask, + self.norm_eps, + ) + return attn(rms_norm(x, eps=self.norm_eps), context=context, mask=context_mask) + + def forward( # noqa: PLR0915 + self, + video: TransformerArgs | None, + audio: TransformerArgs | None, + perturbations: BatchedPerturbationConfig | None = None, + ) -> tuple[TransformerArgs | None, TransformerArgs | None]: + if video is None and audio is None: + raise ValueError("At least one of video or audio must be provided") + + batch_size = (video or audio).x.shape[0] + + if perturbations is None: + perturbations = BatchedPerturbationConfig.empty(batch_size) + + vx = video.x if video is not None else None + ax = audio.x if audio is not None else None + + run_vx = video is not None and video.enabled and vx.numel() > 0 + run_ax = audio is not None and audio.enabled and ax.numel() > 0 + + run_a2v = run_vx and (audio is not None and ax.numel() > 0) + run_v2a = run_ax and (video is not None and vx.numel() > 0) + + if run_vx: + vshift_msa, vscale_msa, vgate_msa = self.get_ada_values( + self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3) + ) + norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa + del vshift_msa, vscale_msa + + all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx) + none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx) + v_mask = ( + perturbations.mask_like(PerturbationType.SKIP_VIDEO_SELF_ATTN, self.idx, vx) + if not all_perturbed and not none_perturbed + else None + ) + vx = ( + vx + + self.attn1( + norm_vx, + pe=video.positional_embeddings, + mask=video.self_attention_mask, + perturbation_mask=v_mask, + all_perturbed=all_perturbed, + ) + * vgate_msa + ) + del vgate_msa, norm_vx, v_mask + vx = vx + self._apply_text_cross_attention( + vx, + video.context, + self.attn2, + self.scale_shift_table, + getattr(self, "prompt_scale_shift_table", None), + video.timesteps, + video.prompt_timestep, + video.context_mask, + cross_attention_adaln=self.cross_attention_adaln, + ) + + if run_ax: + ashift_msa, ascale_msa, agate_msa = self.get_ada_values( + self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3) + ) + + norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa + del ashift_msa, ascale_msa + all_perturbed = perturbations.all_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx) + none_perturbed = not perturbations.any_in_batch(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx) + a_mask = ( + perturbations.mask_like(PerturbationType.SKIP_AUDIO_SELF_ATTN, self.idx, ax) + if not all_perturbed and not none_perturbed + else None + ) + ax = ( + ax + + self.audio_attn1( + norm_ax, + pe=audio.positional_embeddings, + mask=audio.self_attention_mask, + perturbation_mask=a_mask, + all_perturbed=all_perturbed, + ) + * agate_msa + ) + del agate_msa, norm_ax, a_mask + ax = ax + self._apply_text_cross_attention( + ax, + audio.context, + self.audio_attn2, + self.audio_scale_shift_table, + getattr(self, "audio_prompt_scale_shift_table", None), + audio.timesteps, + audio.prompt_timestep, + audio.context_mask, + cross_attention_adaln=self.cross_attention_adaln, + ) + + # Audio - Video cross attention. + if run_a2v or run_v2a: + vx_norm3 = rms_norm(vx, eps=self.norm_eps) + ax_norm3 = rms_norm(ax, eps=self.norm_eps) + + if run_a2v and not perturbations.all_in_batch(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx): + scale_ca_video_a2v, shift_ca_video_a2v, gate_out_a2v = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_video, + vx.shape[0], + video.cross_scale_shift_timestep, + video.cross_gate_timestep, + slice(0, 2), + ) + vx_scaled = vx_norm3 * (1 + scale_ca_video_a2v) + shift_ca_video_a2v + del scale_ca_video_a2v, shift_ca_video_a2v + + scale_ca_audio_a2v, shift_ca_audio_a2v, _ = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_audio, + ax.shape[0], + audio.cross_scale_shift_timestep, + audio.cross_gate_timestep, + slice(0, 2), + ) + ax_scaled = ax_norm3 * (1 + scale_ca_audio_a2v) + shift_ca_audio_a2v + del scale_ca_audio_a2v, shift_ca_audio_a2v + a2v_mask = perturbations.mask_like(PerturbationType.SKIP_A2V_CROSS_ATTN, self.idx, vx) + vx = vx + ( + self.audio_to_video_attn( + vx_scaled, + context=ax_scaled, + pe=video.cross_positional_embeddings, + k_pe=audio.cross_positional_embeddings, + ) + * gate_out_a2v + * a2v_mask + ) + del gate_out_a2v, a2v_mask, vx_scaled, ax_scaled + + if run_v2a and not perturbations.all_in_batch(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx): + scale_ca_audio_v2a, shift_ca_audio_v2a, gate_out_v2a = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_audio, + ax.shape[0], + audio.cross_scale_shift_timestep, + audio.cross_gate_timestep, + slice(2, 4), + ) + ax_scaled = ax_norm3 * (1 + scale_ca_audio_v2a) + shift_ca_audio_v2a + del scale_ca_audio_v2a, shift_ca_audio_v2a + scale_ca_video_v2a, shift_ca_video_v2a, _ = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_video, + vx.shape[0], + video.cross_scale_shift_timestep, + video.cross_gate_timestep, + slice(2, 4), + ) + vx_scaled = vx_norm3 * (1 + scale_ca_video_v2a) + shift_ca_video_v2a + del scale_ca_video_v2a, shift_ca_video_v2a + v2a_mask = perturbations.mask_like(PerturbationType.SKIP_V2A_CROSS_ATTN, self.idx, ax) + ax = ax + ( + self.video_to_audio_attn( + ax_scaled, + context=vx_scaled, + pe=audio.cross_positional_embeddings, + k_pe=video.cross_positional_embeddings, + ) + * gate_out_v2a + * v2a_mask + ) + del gate_out_v2a, v2a_mask, ax_scaled, vx_scaled + + del vx_norm3, ax_norm3 + + if run_vx: + vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values( + self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, 6) + ) + vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp + vx = vx + self.ff(vx_scaled) * vgate_mlp + + del vshift_mlp, vscale_mlp, vgate_mlp, vx_scaled + + if run_ax: + ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values( + self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, 6) + ) + ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp + ax = ax + self.audio_ff(ax_scaled) * agate_mlp + + del ashift_mlp, ascale_mlp, agate_mlp, ax_scaled + + return replace(video, x=vx) if video is not None else None, replace(audio, x=ax) if audio is not None else None + + +def apply_cross_attention_adaln( + x: torch.Tensor, + context: torch.Tensor, + attn: AttentionCallable, + q_shift: torch.Tensor, + q_scale: torch.Tensor, + q_gate: torch.Tensor, + prompt_scale_shift_table: torch.Tensor, + prompt_timestep: torch.Tensor, + context_mask: torch.Tensor | None = None, + norm_eps: float = 1e-6, +) -> torch.Tensor: + batch_size = x.shape[0] + shift_kv, scale_kv = ( + prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + + prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1) + ).unbind(dim=2) + attn_input = rms_norm(x, eps=norm_eps) * (1 + q_scale) + q_shift + encoder_hidden_states = context * (1 + scale_kv) + shift_kv + return attn(attn_input, context=encoder_hidden_states, mask=context_mask) * q_gate diff --git a/ltx2/ltx_core/model/transformer/transformer_args.py b/ltx2/ltx_core/model/transformer/transformer_args.py new file mode 100644 index 0000000000000000000000000000000000000000..ccd05664cf5eed7733beaf8aade5e13b76d4a236 --- /dev/null +++ b/ltx2/ltx_core/model/transformer/transformer_args.py @@ -0,0 +1,297 @@ +from dataclasses import dataclass, replace + +import torch + +from ltx_core.model.transformer.adaln import AdaLayerNormSingle +from ltx_core.model.transformer.modality import Modality +from ltx_core.model.transformer.rope import ( + LTXRopeType, + generate_freq_grid_np, + generate_freq_grid_pytorch, + precompute_freqs_cis, +) + + +@dataclass(frozen=True) +class TransformerArgs: + x: torch.Tensor + context: torch.Tensor + context_mask: torch.Tensor + timesteps: torch.Tensor + embedded_timestep: torch.Tensor + positional_embeddings: torch.Tensor + cross_positional_embeddings: torch.Tensor | None + cross_scale_shift_timestep: torch.Tensor | None + cross_gate_timestep: torch.Tensor | None + enabled: bool + prompt_timestep: torch.Tensor | None = None + self_attention_mask: torch.Tensor | None = ( + None # Additive log-space self-attention bias (B, 1, T, T), None = full attention + ) + + +class TransformerArgsPreprocessor: + def __init__( # noqa: PLR0913 + self, + patchify_proj: torch.nn.Linear, + adaln: AdaLayerNormSingle, + inner_dim: int, + max_pos: list[int], + num_attention_heads: int, + use_middle_indices_grid: bool, + timestep_scale_multiplier: int, + double_precision_rope: bool, + positional_embedding_theta: float, + rope_type: LTXRopeType, + caption_projection: torch.nn.Module | None = None, + prompt_adaln: AdaLayerNormSingle | None = None, + ) -> None: + self.patchify_proj = patchify_proj + self.adaln = adaln + self.inner_dim = inner_dim + self.max_pos = max_pos + self.num_attention_heads = num_attention_heads + self.use_middle_indices_grid = use_middle_indices_grid + self.timestep_scale_multiplier = timestep_scale_multiplier + self.double_precision_rope = double_precision_rope + self.positional_embedding_theta = positional_embedding_theta + self.rope_type = rope_type + self.caption_projection = caption_projection + self.prompt_adaln = prompt_adaln + + def _prepare_timestep( + self, timestep: torch.Tensor, adaln: AdaLayerNormSingle, batch_size: int, hidden_dtype: torch.dtype + ) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare timestep embeddings.""" + timestep_scaled = timestep * self.timestep_scale_multiplier + timestep, embedded_timestep = adaln( + timestep_scaled.flatten(), + hidden_dtype=hidden_dtype, + ) + # Second dimension is 1 or number of tokens (if timestep_per_token) + timestep = timestep.view(batch_size, -1, timestep.shape[-1]) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1]) + + return timestep, embedded_timestep + + def _prepare_context( + self, + context: torch.Tensor, + x: torch.Tensor, + ) -> torch.Tensor: + """Prepare context for transformer blocks.""" + if self.caption_projection is not None: + context = self.caption_projection(context) + batch_size = x.shape[0] + return context.view(batch_size, -1, x.shape[-1]) + + def _prepare_attention_mask(self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype) -> torch.Tensor | None: + """Prepare attention mask.""" + if attention_mask is None or torch.is_floating_point(attention_mask): + return attention_mask + + return (attention_mask - 1).to(x_dtype).reshape( + (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + ) * torch.finfo(x_dtype).max + + def _prepare_self_attention_mask( + self, attention_mask: torch.Tensor | None, x_dtype: torch.dtype + ) -> torch.Tensor | None: + """Prepare self-attention mask by converting [0,1] values to additive log-space bias. + Input shape: (B, T, T) with values in [0, 1]. + Output shape: (B, 1, T, T) with 0.0 for full attention and a large negative value + for masked positions. + Positions with attention_mask <= 0 are fully masked (mapped to the dtype's minimum + representable value). Strictly positive entries are converted via log-space for + smooth attenuation, with small values clamped for numerical stability. + Returns None if input is None (no masking). + """ + if attention_mask is None: + return None + + # Convert [0, 1] attention mask to additive log-space bias: + # 1.0 -> log(1.0) = 0.0 (no bias, full attention) + # 0.0 -> finfo.min (fully masked) + finfo = torch.finfo(x_dtype) + eps = finfo.tiny + + bias = torch.full_like(attention_mask, finfo.min, dtype=x_dtype) + positive = attention_mask > 0 + if positive.any(): + bias[positive] = torch.log(attention_mask[positive].clamp(min=eps)).to(x_dtype) + + return bias.unsqueeze(1) # (B, 1, T, T) for head broadcast + + def _prepare_positional_embeddings( + self, + positions: torch.Tensor, + inner_dim: int, + max_pos: list[int], + use_middle_indices_grid: bool, + num_attention_heads: int, + x_dtype: torch.dtype, + ) -> torch.Tensor: + """Prepare positional embeddings.""" + freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch + pe = precompute_freqs_cis( + positions, + dim=inner_dim, + out_dtype=x_dtype, + theta=self.positional_embedding_theta, + max_pos=max_pos, + use_middle_indices_grid=use_middle_indices_grid, + num_attention_heads=num_attention_heads, + rope_type=self.rope_type, + freq_grid_generator=freq_grid_generator, + ) + return pe + + def prepare( + self, + modality: Modality, + cross_modality: Modality | None = None, # noqa: ARG002 + ) -> TransformerArgs: + x = self.patchify_proj(modality.latent) + batch_size = x.shape[0] + timestep, embedded_timestep = self._prepare_timestep( + modality.timesteps, self.adaln, batch_size, modality.latent.dtype + ) + prompt_timestep = None + if self.prompt_adaln is not None: + prompt_timestep, _ = self._prepare_timestep( + modality.sigma, self.prompt_adaln, batch_size, modality.latent.dtype + ) + context = self._prepare_context(modality.context, x) + attention_mask = self._prepare_attention_mask(modality.context_mask, modality.latent.dtype) + pe = self._prepare_positional_embeddings( + positions=modality.positions, + inner_dim=self.inner_dim, + max_pos=self.max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + num_attention_heads=self.num_attention_heads, + x_dtype=modality.latent.dtype, + ) + self_attention_mask = self._prepare_self_attention_mask(modality.attention_mask, modality.latent.dtype) + return TransformerArgs( + x=x, + context=context, + context_mask=attention_mask, + timesteps=timestep, + embedded_timestep=embedded_timestep, + positional_embeddings=pe, + cross_positional_embeddings=None, + cross_scale_shift_timestep=None, + cross_gate_timestep=None, + enabled=modality.enabled, + prompt_timestep=prompt_timestep, + self_attention_mask=self_attention_mask, + ) + + +class MultiModalTransformerArgsPreprocessor: + def __init__( # noqa: PLR0913 + self, + patchify_proj: torch.nn.Linear, + adaln: AdaLayerNormSingle, + cross_scale_shift_adaln: AdaLayerNormSingle, + cross_gate_adaln: AdaLayerNormSingle, + inner_dim: int, + max_pos: list[int], + num_attention_heads: int, + cross_pe_max_pos: int, + use_middle_indices_grid: bool, + audio_cross_attention_dim: int, + timestep_scale_multiplier: int, + double_precision_rope: bool, + positional_embedding_theta: float, + rope_type: LTXRopeType, + av_ca_timestep_scale_multiplier: int, + caption_projection: torch.nn.Module | None = None, + prompt_adaln: AdaLayerNormSingle | None = None, + ) -> None: + self.simple_preprocessor = TransformerArgsPreprocessor( + patchify_proj=patchify_proj, + adaln=adaln, + inner_dim=inner_dim, + max_pos=max_pos, + num_attention_heads=num_attention_heads, + use_middle_indices_grid=use_middle_indices_grid, + timestep_scale_multiplier=timestep_scale_multiplier, + double_precision_rope=double_precision_rope, + positional_embedding_theta=positional_embedding_theta, + rope_type=rope_type, + caption_projection=caption_projection, + prompt_adaln=prompt_adaln, + ) + self.cross_scale_shift_adaln = cross_scale_shift_adaln + self.cross_gate_adaln = cross_gate_adaln + self.cross_pe_max_pos = cross_pe_max_pos + self.audio_cross_attention_dim = audio_cross_attention_dim + self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier + + def prepare( + self, + modality: Modality, + cross_modality: Modality | None = None, + ) -> TransformerArgs: + transformer_args = self.simple_preprocessor.prepare(modality) + if cross_modality is None: + return transformer_args + + if cross_modality.sigma.numel() > 1: + if cross_modality.sigma.shape[0] != modality.timesteps.shape[0]: + raise ValueError("Cross modality sigma must have the same batch size as the modality") + if cross_modality.sigma.ndim != 1: + raise ValueError("Cross modality sigma must be a 1D tensor") + + cross_timestep = cross_modality.sigma.view( + modality.timesteps.shape[0], 1, *[1] * len(modality.timesteps.shape[2:]) + ) + + cross_pe = self.simple_preprocessor._prepare_positional_embeddings( + positions=modality.positions[:, 0:1, :], + inner_dim=self.audio_cross_attention_dim, + max_pos=[self.cross_pe_max_pos], + use_middle_indices_grid=True, + num_attention_heads=self.simple_preprocessor.num_attention_heads, + x_dtype=modality.latent.dtype, + ) + + cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep( + timestep=cross_timestep, + timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier, + batch_size=transformer_args.x.shape[0], + hidden_dtype=modality.latent.dtype, + ) + + return replace( + transformer_args, + cross_positional_embeddings=cross_pe, + cross_scale_shift_timestep=cross_scale_shift_timestep, + cross_gate_timestep=cross_gate_timestep, + ) + + def _prepare_cross_attention_timestep( + self, + timestep: torch.Tensor | None, + timestep_scale_multiplier: int, + batch_size: int, + hidden_dtype: torch.dtype, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare cross attention timestep embeddings.""" + timestep = timestep * timestep_scale_multiplier + + av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier + + scale_shift_timestep, _ = self.cross_scale_shift_adaln( + timestep.flatten(), + hidden_dtype=hidden_dtype, + ) + scale_shift_timestep = scale_shift_timestep.view(batch_size, -1, scale_shift_timestep.shape[-1]) + gate_noise_timestep, _ = self.cross_gate_adaln( + timestep.flatten() * av_ca_factor, + hidden_dtype=hidden_dtype, + ) + gate_noise_timestep = gate_noise_timestep.view(batch_size, -1, gate_noise_timestep.shape[-1]) + + return scale_shift_timestep, gate_noise_timestep diff --git a/ltx2/ltx_core/model/upsampler/__init__.py b/ltx2/ltx_core/model/upsampler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4836bfc9f61368d82ff781a9c1ced835920914fb --- /dev/null +++ b/ltx2/ltx_core/model/upsampler/__init__.py @@ -0,0 +1,10 @@ +"""Latent upsampler model components.""" + +from ltx_core.model.upsampler.model import LatentUpsampler, upsample_video +from ltx_core.model.upsampler.model_configurator import LatentUpsamplerConfigurator + +__all__ = [ + "LatentUpsampler", + "LatentUpsamplerConfigurator", + "upsample_video", +] diff --git a/ltx2/ltx_core/model/upsampler/blur_downsample.py b/ltx2/ltx_core/model/upsampler/blur_downsample.py new file mode 100644 index 0000000000000000000000000000000000000000..ccc0149730432954eef015a8ac0073177e3cea35 --- /dev/null +++ b/ltx2/ltx_core/model/upsampler/blur_downsample.py @@ -0,0 +1,53 @@ +import math + +import torch +import torch.nn.functional as F +from einops import rearrange + + +class BlurDownsample(torch.nn.Module): + """ + Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. + Applies only on H,W. Works for dims=2 or dims=3 (per-frame). + """ + + def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None: + super().__init__() + assert dims in (2, 3) + assert isinstance(stride, int) + assert stride >= 1 + assert kernel_size >= 3 + assert kernel_size % 2 == 1 + self.dims = dims + self.stride = stride + self.kernel_size = kernel_size + + # 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from + # the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and + # provides a smooth approximation of a Gaussian filter (often called a "binomial filter"). + # The 2D kernel is constructed as the outer product and normalized. + k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)]) + k2d = k[:, None] @ k[None, :] + k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size) + self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.stride == 1: + return x + + if self.dims == 2: + return self._apply_2d(x) + else: + # dims == 3: apply per-frame on H,W + b, _, f, _, _ = x.shape + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self._apply_2d(x) + h2, w2 = x.shape[-2:] + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2) + return x + + def _apply_2d(self, x2d: torch.Tensor) -> torch.Tensor: + c = x2d.shape[1] + weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise + x2d = F.conv2d(x2d, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c) + return x2d diff --git a/ltx2/ltx_core/model/upsampler/model.py b/ltx2/ltx_core/model/upsampler/model.py new file mode 100644 index 0000000000000000000000000000000000000000..10d65edab03b4fada16fa5e873d03e3d51bddba5 --- /dev/null +++ b/ltx2/ltx_core/model/upsampler/model.py @@ -0,0 +1,142 @@ +import torch +from einops import rearrange + +from ltx_core.model.upsampler.pixel_shuffle import PixelShuffleND +from ltx_core.model.upsampler.res_block import ResBlock +from ltx_core.model.upsampler.spatial_rational_resampler import SpatialRationalResampler +from ltx_core.model.video_vae import VideoEncoder + + +class LatentUpsampler(torch.nn.Module): + """ + Model to upsample VAE latents spatially and/or temporally. + Args: + in_channels (`int`): Number of channels in the input latent + mid_channels (`int`): Number of channels in the middle layers + num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`): Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`): Whether to spatially upsample the latent + temporal_upsample (`bool`): Whether to temporally upsample the latent + spatial_scale (`float`): Scale factor for spatial upsampling + rational_resampler (`bool`): Whether to use a rational resampler for spatial upsampling + """ + + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 512, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + spatial_scale: float = 2.0, + rational_resampler: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + self.spatial_scale = float(spatial_scale) + self.rational_resampler = rational_resampler + + conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.initial_conv = conv(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = torch.nn.GroupNorm(32, mid_channels) + self.initial_activation = torch.nn.SiLU() + + self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]) + + if spatial_upsample and temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + if rational_resampler: + self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=self.spatial_scale) + else: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError("Either spatial_upsample or temporal_upsample must be True") + + self.post_upsample_res_blocks = torch.nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = conv(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, latent: torch.Tensor) -> torch.Tensor: + b, _, f, _, _ = latent.shape + + if self.dims == 2: + x = rearrange(latent, "b c f h w -> (b f) c h w") + x = self.initial_conv(x) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + x = self.upsampler(x) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + else: + x = self.initial_conv(latent) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + if self.temporal_upsample: + x = self.upsampler(x) + # remove the first frame after upsampling. + # This is done because the first frame encodes one pixel frame. + x = x[:, :, 1:, :, :] + elif isinstance(self.upsampler, SpatialRationalResampler): + x = self.upsampler(x) + else: + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.upsampler(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + + return x + + +def upsample_video(latent: torch.Tensor, video_encoder: VideoEncoder, upsampler: "LatentUpsampler") -> torch.Tensor: + """ + Apply upsampling to the latent representation using the provided upsampler, + with normalization and un-normalization based on the video encoder's per-channel statistics. + Args: + latent: Input latent tensor of shape [B, C, F, H, W]. + video_encoder: VideoEncoder with per_channel_statistics for normalization. + upsampler: LatentUpsampler module to perform upsampling. + Returns: + torch.Tensor: Upsampled and re-normalized latent tensor. + """ + latent = video_encoder.per_channel_statistics.un_normalize(latent) + latent = upsampler(latent) + latent = video_encoder.per_channel_statistics.normalize(latent) + return latent diff --git a/ltx2/ltx_core/model/upsampler/model_configurator.py b/ltx2/ltx_core/model/upsampler/model_configurator.py new file mode 100644 index 0000000000000000000000000000000000000000..2a714a41706652eb8d6160a9165310fce5f0f7ba --- /dev/null +++ b/ltx2/ltx_core/model/upsampler/model_configurator.py @@ -0,0 +1,30 @@ +from ltx_core.model.model_protocol import ModelConfigurator +from ltx_core.model.upsampler.model import LatentUpsampler + + +class LatentUpsamplerConfigurator(ModelConfigurator[LatentUpsampler]): + """ + Configurator for LatentUpsampler model. + Used to create a LatentUpsampler model from a configuration dictionary. + """ + + @classmethod + def from_config(cls: type[LatentUpsampler], config: dict) -> LatentUpsampler: + in_channels = config.get("in_channels", 128) + mid_channels = config.get("mid_channels", 512) + num_blocks_per_stage = config.get("num_blocks_per_stage", 4) + dims = config.get("dims", 3) + spatial_upsample = config.get("spatial_upsample", True) + temporal_upsample = config.get("temporal_upsample", False) + spatial_scale = config.get("spatial_scale", 2.0) + rational_resampler = config.get("rational_resampler", False) + return LatentUpsampler( + in_channels=in_channels, + mid_channels=mid_channels, + num_blocks_per_stage=num_blocks_per_stage, + dims=dims, + spatial_upsample=spatial_upsample, + temporal_upsample=temporal_upsample, + spatial_scale=spatial_scale, + rational_resampler=rational_resampler, + ) diff --git a/ltx2/ltx_core/model/upsampler/pixel_shuffle.py b/ltx2/ltx_core/model/upsampler/pixel_shuffle.py new file mode 100644 index 0000000000000000000000000000000000000000..3c78f3bb4cb6569391c83c0c25af198efceafd30 --- /dev/null +++ b/ltx2/ltx_core/model/upsampler/pixel_shuffle.py @@ -0,0 +1,54 @@ +import torch +from einops import rearrange + + +class PixelShuffleND(torch.nn.Module): + """ + N-dimensional pixel shuffle operation for upsampling tensors. + Args: + dims (int): Number of dimensions to apply pixel shuffle to. + - 1: Temporal (e.g., frames) + - 2: Spatial (e.g., height and width) + - 3: Spatiotemporal (e.g., depth, height, width) + upscale_factors (tuple[int, int, int], optional): Upscaling factors for each dimension. + For dims=1, only the first value is used. + For dims=2, the first two values are used. + For dims=3, all three values are used. + The input tensor is rearranged so that the channel dimension is split into + smaller channels and upscaling factors, and the upscaling factors are moved + into the corresponding spatial/temporal dimensions. + Note: + This operation is equivalent to the patchifier operation in for the models. Consider + using this class instead. + """ + + def __init__(self, dims: int, upscale_factors: tuple[int, int, int] = (2, 2, 2)): + super().__init__() + assert dims in [1, 2, 3], "dims must be 1, 2, or 3" + self.dims = dims + self.upscale_factors = upscale_factors + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.dims == 3: + return rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + p3=self.upscale_factors[2], + ) + elif self.dims == 2: + return rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + ) + elif self.dims == 1: + return rearrange( + x, + "b (c p1) f h w -> b c (f p1) h w", + p1=self.upscale_factors[0], + ) + else: + raise ValueError(f"Unsupported dims: {self.dims}") diff --git a/ltx2/ltx_core/model/upsampler/res_block.py b/ltx2/ltx_core/model/upsampler/res_block.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7603527191e28eabd8f012ef361d72d63a6d4e --- /dev/null +++ b/ltx2/ltx_core/model/upsampler/res_block.py @@ -0,0 +1,37 @@ +from typing import Optional + +import torch + + +class ResBlock(torch.nn.Module): + """ + Residual block with two convolutional layers, group normalization, and SiLU activation. + Args: + channels (int): Number of input and output channels. + mid_channels (Optional[int]): Number of channels in the intermediate convolution layer. Defaults to `channels` + if not specified. + dims (int): Dimensionality of the convolution (2 for Conv2d, 3 for Conv3d). Defaults to 3. + """ + + def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3): + super().__init__() + if mid_channels is None: + mid_channels = channels + + conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.conv1 = conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = torch.nn.GroupNorm(32, mid_channels) + self.conv2 = conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = torch.nn.GroupNorm(32, channels) + self.activation = torch.nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = self.activation(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.activation(x + residual) + return x diff --git a/ltx2/ltx_core/model/upsampler/spatial_rational_resampler.py b/ltx2/ltx_core/model/upsampler/spatial_rational_resampler.py new file mode 100644 index 0000000000000000000000000000000000000000..a4e055615d6635b4b0630b55d5e91968e2145d12 --- /dev/null +++ b/ltx2/ltx_core/model/upsampler/spatial_rational_resampler.py @@ -0,0 +1,47 @@ +from typing import Tuple + +import torch +from einops import rearrange + +from ltx_core.model.upsampler.blur_downsample import BlurDownsample +from ltx_core.model.upsampler.pixel_shuffle import PixelShuffleND + + +def _rational_for_scale(scale: float) -> Tuple[int, int]: + mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)} + if float(scale) not in mapping: + raise ValueError(f"Unsupported scale {scale}. Choose from {list(mapping.keys())}") + return mapping[float(scale)] + + +class SpatialRationalResampler(torch.nn.Module): + """ + Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased + downsample by 'den' using fixed blur + stride. Operates on H,W only. + For dims==3, work per-frame for spatial scaling (temporal axis untouched). + Args: + mid_channels (`int`): Number of intermediate channels for the convolution layer + scale (`float`): Spatial scaling factor. Supported values are: + - 0.75: Downsample by 3/4 (reduce spatial size) + - 1.5: Upsample by 3/2 (increase spatial size) + - 2.0: Upsample by 2x (double spatial size) + - 4.0: Upsample by 4x (quadruple spatial size) + Any other value will raise a ValueError. + """ + + def __init__(self, mid_channels: int, scale: float): + super().__init__() + self.scale = float(scale) + self.num, self.den = _rational_for_scale(self.scale) + self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1) + self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num)) + self.blur_down = BlurDownsample(dims=2, stride=self.den) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, _, f, _, _ = x.shape + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.conv(x) + x = self.pixel_shuffle(x) + x = self.blur_down(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + return x diff --git a/ltx2/ltx_core/model/video_vae/__init__.py b/ltx2/ltx_core/model/video_vae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..122bb8fd1af3273e967865137380d308fd311e49 --- /dev/null +++ b/ltx2/ltx_core/model/video_vae/__init__.py @@ -0,0 +1,23 @@ +"""Video VAE package.""" + +from ltx_core.model.video_vae.model_configurator import ( + VAE_DECODER_COMFY_KEYS_FILTER, + VAE_ENCODER_COMFY_KEYS_FILTER, + VideoDecoderConfigurator, + VideoEncoderConfigurator, +) +from ltx_core.model.video_vae.tiling import SpatialTilingConfig, TemporalTilingConfig, TilingConfig +from ltx_core.model.video_vae.video_vae import VideoDecoder, VideoEncoder, get_video_chunks_number + +__all__ = [ + "VAE_DECODER_COMFY_KEYS_FILTER", + "VAE_ENCODER_COMFY_KEYS_FILTER", + "SpatialTilingConfig", + "TemporalTilingConfig", + "TilingConfig", + "VideoDecoder", + "VideoDecoderConfigurator", + "VideoEncoder", + "VideoEncoderConfigurator", + "get_video_chunks_number", +] diff --git a/ltx2/ltx_core/model/video_vae/convolution.py b/ltx2/ltx_core/model/video_vae/convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..7c3957add65c70ee7dfb0e1d8c482f5e3172aa87 --- /dev/null +++ b/ltx2/ltx_core/model/video_vae/convolution.py @@ -0,0 +1,317 @@ +from typing import Tuple, Union + +import torch +from einops import rearrange +from torch import nn +from torch.nn import functional as F + +from ltx_core.model.video_vae.enums import PaddingModeType + + +def make_conv_nd( # noqa: PLR0913 + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causal: bool = False, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + temporal_padding_mode: PaddingModeType = PaddingModeType.ZEROS, +) -> nn.Module: + if not (spatial_padding_mode == temporal_padding_mode or causal): + raise NotImplementedError("spatial and temporal padding modes must be equal") + if dims == 2: + return nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=spatial_padding_mode.value, + ) + elif dims == 3: + if causal: + return CausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + bias=bias, + spatial_padding_mode=spatial_padding_mode, + ) + return nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=spatial_padding_mode.value, + ) + elif dims == (2, 1): + return DualConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=bias, + padding_mode=spatial_padding_mode.value, + ) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +def make_linear_nd( + dims: int, + in_channels: int, + out_channels: int, + bias: bool = True, +) -> nn.Module: + if dims == 2: + return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias) + elif dims in (3, (2, 1)): + return nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +class DualConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + dilation: Union[int, Tuple[int, int, int]] = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + ) -> None: + super(DualConv3d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.padding_mode = padding_mode + # Ensure kernel_size, stride, padding, and dilation are tuples of length 3 + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + if kernel_size == (1, 1, 1): + raise ValueError("kernel_size must be greater than 1. Use make_linear_nd instead.") + if isinstance(stride, int): + stride = (stride, stride, stride) + if isinstance(padding, int): + padding = (padding, padding, padding) + if isinstance(dilation, int): + dilation = (dilation, dilation, dilation) + + # Set parameters for convolutions + self.groups = groups + self.bias = bias + + # Define the size of the channels after the first convolution + intermediate_channels = out_channels if in_channels < out_channels else in_channels + + # Define parameters for the first convolution + self.weight1 = nn.Parameter( + torch.Tensor( + intermediate_channels, + in_channels // groups, + 1, + kernel_size[1], + kernel_size[2], + ) + ) + self.stride1 = (1, stride[1], stride[2]) + self.padding1 = (0, padding[1], padding[2]) + self.dilation1 = (1, dilation[1], dilation[2]) + if bias: + self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels)) + else: + self.register_parameter("bias1", None) + + # Define parameters for the second convolution + self.weight2 = nn.Parameter(torch.Tensor(out_channels, intermediate_channels // groups, kernel_size[0], 1, 1)) + self.stride2 = (stride[0], 1, 1) + self.padding2 = (padding[0], 0, 0) + self.dilation2 = (dilation[0], 1, 1) + if bias: + self.bias2 = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter("bias2", None) + + # Initialize weights and biases + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.kaiming_uniform_(self.weight1, a=torch.sqrt(5)) + nn.init.kaiming_uniform_(self.weight2, a=torch.sqrt(5)) + if self.bias: + fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1) + bound1 = 1 / torch.sqrt(fan_in1) + nn.init.uniform_(self.bias1, -bound1, bound1) + fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2) + bound2 = 1 / torch.sqrt(fan_in2) + nn.init.uniform_(self.bias2, -bound2, bound2) + + def forward( + self, + x: torch.Tensor, + use_conv3d: bool = False, + skip_time_conv: bool = False, + ) -> torch.Tensor: + if use_conv3d: + return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv) + else: + return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv) + + def forward_with_3d(self, x: torch.Tensor, skip_time_conv: bool = False) -> torch.Tensor: + # First convolution + x = F.conv3d( + x, + self.weight1, + self.bias1, + self.stride1, + self.padding1, + self.dilation1, + self.groups, + padding_mode=self.padding_mode, + ) + + if skip_time_conv: + return x + + # Second convolution + x = F.conv3d( + x, + self.weight2, + self.bias2, + self.stride2, + self.padding2, + self.dilation2, + self.groups, + padding_mode=self.padding_mode, + ) + + return x + + def forward_with_2d(self, x: torch.Tensor, skip_time_conv: bool = False) -> torch.Tensor: + b, _, _, h, w = x.shape + + # First 2D convolution + x = rearrange(x, "b c d h w -> (b d) c h w") + # Squeeze the depth dimension out of weight1 since it's 1 + weight1 = self.weight1.squeeze(2) + # Select stride, padding, and dilation for the 2D convolution + stride1 = (self.stride1[1], self.stride1[2]) + padding1 = (self.padding1[1], self.padding1[2]) + dilation1 = (self.dilation1[1], self.dilation1[2]) + x = F.conv2d( + x, + weight1, + self.bias1, + stride1, + padding1, + dilation1, + self.groups, + padding_mode=self.padding_mode, + ) + + _, _, h, w = x.shape + + if skip_time_conv: + x = rearrange(x, "(b d) c h w -> b c d h w", b=b) + return x + + # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension + x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b) + + # Reshape weight2 to match the expected dimensions for conv1d + weight2 = self.weight2.squeeze(-1).squeeze(-1) + # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution + stride2 = self.stride2[0] + padding2 = self.padding2[0] + dilation2 = self.dilation2[0] + x = F.conv1d( + x, + weight2, + self.bias2, + stride2, + padding2, + dilation2, + self.groups, + padding_mode=self.padding_mode, + ) + x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w) + + return x + + @property + def weight(self) -> torch.Tensor: + return self.weight2 + + +class CausalConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: Union[int, Tuple[int]] = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ) -> None: + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + kernel_size = (kernel_size, kernel_size, kernel_size) + self.time_kernel_size = kernel_size[0] + + dilation = (dilation, 1, 1) + + height_pad = kernel_size[1] // 2 + width_pad = kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + self.conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + padding_mode=spatial_padding_mode.value, + groups=groups, + bias=bias, + ) + + def forward(self, x: torch.Tensor, causal: bool = True) -> torch.Tensor: + if causal: + first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1)) + x = torch.concatenate((first_frame_pad, x), dim=2) + else: + first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)) + last_frame_pad = x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)) + x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) + x = self.conv(x) + return x + + @property + def weight(self) -> torch.Tensor: + return self.conv.weight diff --git a/ltx2/ltx_core/model/video_vae/enums.py b/ltx2/ltx_core/model/video_vae/enums.py new file mode 100644 index 0000000000000000000000000000000000000000..edb2eace10d724d9b7c91fcb9ca97b475b9ae578 --- /dev/null +++ b/ltx2/ltx_core/model/video_vae/enums.py @@ -0,0 +1,20 @@ +from enum import Enum + + +class NormLayerType(Enum): + GROUP_NORM = "group_norm" + PIXEL_NORM = "pixel_norm" + + +class LogVarianceType(Enum): + PER_CHANNEL = "per_channel" + UNIFORM = "uniform" + CONSTANT = "constant" + NONE = "none" + + +class PaddingModeType(Enum): + ZEROS = "zeros" + REFLECT = "reflect" + REPLICATE = "replicate" + CIRCULAR = "circular" diff --git a/ltx2/ltx_core/model/video_vae/model_configurator.py b/ltx2/ltx_core/model/video_vae/model_configurator.py new file mode 100644 index 0000000000000000000000000000000000000000..8e64ce5c8333bb02544335afd561999272ea79d1 --- /dev/null +++ b/ltx2/ltx_core/model/video_vae/model_configurator.py @@ -0,0 +1,79 @@ +from ltx_core.loader.sd_ops import SDOps +from ltx_core.model.model_protocol import ModelConfigurator +from ltx_core.model.video_vae.enums import LogVarianceType, NormLayerType, PaddingModeType +from ltx_core.model.video_vae.video_vae import VideoDecoder, VideoEncoder + + +class VideoEncoderConfigurator(ModelConfigurator[VideoEncoder]): + """Configurator for creating a video VAE Encoder from a configuration dictionary.""" + + @classmethod + def from_config(cls: type[VideoEncoder], config: dict) -> VideoEncoder: + config = config.get("vae", {}) + convolution_dimensions = config.get("dims", 3) + in_channels = config.get("in_channels", 3) + latent_channels = config.get("latent_channels", 128) + spatial_padding_mode = PaddingModeType(config.get("spatial_padding_mode", "zeros")) + encoder_blocks = config.get("encoder_blocks", []) + patch_size = config.get("patch_size", 4) + norm_layer_str = config.get("norm_layer", "pixel_norm") + latent_log_var_str = config.get("latent_log_var", "uniform") + + return VideoEncoder( + convolution_dimensions=convolution_dimensions, + in_channels=in_channels, + out_channels=latent_channels, + encoder_blocks=encoder_blocks, + patch_size=patch_size, + norm_layer=NormLayerType(norm_layer_str), + latent_log_var=LogVarianceType(latent_log_var_str), + encoder_spatial_padding_mode=spatial_padding_mode, + ) + + +class VideoDecoderConfigurator(ModelConfigurator[VideoDecoder]): + """Configurator for creating a video VAE Decoder from a configuration dictionary.""" + + @classmethod + def from_config(cls: type[VideoDecoder], config: dict) -> VideoDecoder: + config = config.get("vae", {}) + convolution_dimensions = config.get("dims", 3) + latent_channels = config.get("latent_channels", 128) + spatial_padding_mode = PaddingModeType(config.get("spatial_padding_mode", "reflect")) + out_channels = config.get("out_channels", 3) + decoder_blocks = config.get("decoder_blocks", []) + patch_size = config.get("patch_size", 4) + norm_layer_str = config.get("norm_layer", "pixel_norm") + causal = config.get("causal_decoder", False) + timestep_conditioning = config.get("timestep_conditioning", True) + base_channels = config.get("decoder_base_channels", 128) + + return VideoDecoder( + convolution_dimensions=convolution_dimensions, + in_channels=latent_channels, + out_channels=out_channels, + decoder_blocks=decoder_blocks, + patch_size=patch_size, + norm_layer=NormLayerType(norm_layer_str), + causal=causal, + timestep_conditioning=timestep_conditioning, + decoder_spatial_padding_mode=spatial_padding_mode, + base_channels=base_channels, + ) + + +VAE_DECODER_COMFY_KEYS_FILTER = ( + SDOps("VAE_DECODER_COMFY_KEYS_FILTER") + .with_matching(prefix="vae.decoder.") + .with_matching(prefix="vae.per_channel_statistics.") + .with_replacement("vae.decoder.", "") + .with_replacement("vae.per_channel_statistics.", "per_channel_statistics.") +) + +VAE_ENCODER_COMFY_KEYS_FILTER = ( + SDOps("VAE_ENCODER_COMFY_KEYS_FILTER") + .with_matching(prefix="vae.encoder.") + .with_matching(prefix="vae.per_channel_statistics.") + .with_replacement("vae.encoder.", "") + .with_replacement("vae.per_channel_statistics.", "per_channel_statistics.") +) diff --git a/ltx2/ltx_core/model/video_vae/normalization.py b/ltx2/ltx_core/model/video_vae/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..45c003533a329d571e54cc483ccdcd2821883056 --- /dev/null +++ b/ltx2/ltx_core/model/video_vae/normalization.py @@ -0,0 +1,3 @@ +from ltx_core.model.common.normalization import PixelNorm, build_normalization_layer + +__all__ = ["PixelNorm", "build_normalization_layer"] diff --git a/ltx2/ltx_core/model/video_vae/ops.py b/ltx2/ltx_core/model/video_vae/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..8912491cd9b8ce2deddc6837818c2354d68f0166 --- /dev/null +++ b/ltx2/ltx_core/model/video_vae/ops.py @@ -0,0 +1,82 @@ +import torch +from einops import rearrange +from torch import nn + + +def patchify(x: torch.Tensor, patch_size_hw: int, patch_size_t: int = 1) -> torch.Tensor: + """ + Rearrange spatial dimensions into channels. Divides image into patch_size x patch_size blocks + and moves pixels from each block into separate channels (space-to-depth). + Args: + x: Input tensor (4D or 5D) + patch_size_hw: Spatial patch size for height and width. With patch_size_hw=4, divides HxW into 4x4 blocks. + patch_size_t: Temporal patch size for frames. Default=1 (no temporal patching). + For 5D: (B, C, F, H, W) -> (B, Cx(patch_size_hw^2)x(patch_size_t), F/patch_size_t, H/patch_size_hw, W/patch_size_hw) + Example: (B, 3, 33, 512, 512) with patch_size_hw=4, patch_size_t=1 -> (B, 48, 33, 128, 128) + """ + if patch_size_hw == 1 and patch_size_t == 1: + return x + if x.dim() == 4: + x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw) + elif x.dim() == 5: + x = rearrange( + x, + "b c (f p) (h q) (w r) -> b (c p r q) f h w", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x + + +def unpatchify(x: torch.Tensor, patch_size_hw: int, patch_size_t: int = 1) -> torch.Tensor: + """ + Rearrange channels back into spatial dimensions. Inverse of patchify - moves pixels from + channels back into patch_size x patch_size blocks (depth-to-space). + Args: + x: Input tensor (4D or 5D) + patch_size_hw: Spatial patch size for height and width. With patch_size_hw=4, expands HxW by 4x. + patch_size_t: Temporal patch size for frames. Default=1 (no temporal expansion). + For 5D: (B, Cx(patch_size_hw^2)x(patch_size_t), F, H, W) -> (B, C, Fxpatch_size_t, Hxpatch_size_hw, Wxpatch_size_hw) + Example: (B, 48, 33, 128, 128) with patch_size_hw=4, patch_size_t=1 -> (B, 3, 33, 512, 512) + """ + if patch_size_hw == 1 and patch_size_t == 1: + return x + + if x.dim() == 4: + x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw) + elif x.dim() == 5: + x = rearrange( + x, + "b (c p r q) f h w -> b c (f p) (h q) (w r)", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + + return x + + +class PerChannelStatistics(nn.Module): + """ + Per-channel statistics for normalizing and denormalizing the latent representation. + This statics is computed over the entire dataset and stored in model's checkpoint under VAE state_dict. + """ + + def __init__(self, latent_channels: int = 128): + super().__init__() + self.register_buffer("std-of-means", torch.empty(latent_channels)) + self.register_buffer("mean-of-means", torch.empty(latent_channels)) + + def un_normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view( + 1, -1, 1, 1, 1 + ).to(x) + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view( + 1, -1, 1, 1, 1 + ).to(x) diff --git a/ltx2/ltx_core/model/video_vae/resnet.py b/ltx2/ltx_core/model/video_vae/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..1423f2f2b62c602dfdbd31357520adcb87f604b1 --- /dev/null +++ b/ltx2/ltx_core/model/video_vae/resnet.py @@ -0,0 +1,277 @@ +from typing import Optional, Tuple, Union + +import torch +from torch import nn + +from ltx_core.model.common.normalization import PixelNorm +from ltx_core.model.transformer.timestep_embedding import PixArtAlphaCombinedTimestepSizeEmbeddings +from ltx_core.model.video_vae.convolution import make_conv_nd, make_linear_nd +from ltx_core.model.video_vae.enums import NormLayerType, PaddingModeType + + +class ResnetBlock3D(nn.Module): + r""" + A Resnet block. + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + norm_layer: NormLayerType = NormLayerType.PIXEL_NORM, + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.inject_noise = inject_noise + + if norm_layer == NormLayerType.GROUP_NORM: + self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.norm1 = PixelNorm() + + self.non_linearity = nn.SiLU() + + self.conv1 = make_conv_nd( + dims, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + if inject_noise: + self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + + if norm_layer == NormLayerType.GROUP_NORM: + self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.norm2 = PixelNorm() + + self.dropout = torch.nn.Dropout(dropout) + + self.conv2 = make_conv_nd( + dims, + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + if inject_noise: + self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + + self.conv_shortcut = ( + make_linear_nd(dims=dims, in_channels=in_channels, out_channels=out_channels) + if in_channels != out_channels + else nn.Identity() + ) + + # Using GroupNorm with 1 group is equivalent to LayerNorm but works with (B, C, ...) layout + # avoiding the need for dimension rearrangement used in standard nn.LayerNorm + self.norm3 = ( + nn.GroupNorm(num_groups=1, num_channels=in_channels, eps=eps, affine=True) + if in_channels != out_channels + else nn.Identity() + ) + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.scale_shift_table = nn.Parameter(torch.zeros(4, in_channels)) + + def _feed_spatial_noise( + self, + hidden_states: torch.Tensor, + per_channel_scale: torch.Tensor, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + spatial_shape = hidden_states.shape[-2:] + device = hidden_states.device + dtype = hidden_states.dtype + + # similar to the "explicit noise inputs" method in style-gan + spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype, generator=generator)[None] + scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...] + hidden_states = hidden_states + scaled_noise + + return hidden_states + + def forward( + self, + input_tensor: torch.Tensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + hidden_states = input_tensor + batch_size = hidden_states.shape[0] + + hidden_states = self.norm1(hidden_states) + if self.timestep_conditioning: + if timestep is None: + raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True") + ada_values = self.scale_shift_table[None, ..., None, None, None].to( + device=hidden_states.device, dtype=hidden_states.dtype + ) + timestep.reshape( + batch_size, + 4, + -1, + timestep.shape[-3], + timestep.shape[-2], + timestep.shape[-1], + ) + shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1) + + hidden_states = hidden_states * (1 + scale1) + shift1 + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.conv1(hidden_states, causal=causal) + + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, + self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype), + generator=generator, + ) + + hidden_states = self.norm2(hidden_states) + + if self.timestep_conditioning: + hidden_states = hidden_states * (1 + scale2) + shift2 + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + + hidden_states = self.conv2(hidden_states, causal=causal) + + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, + self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype), + generator=generator, + ) + + input_tensor = self.norm3(input_tensor) + + batch_size = input_tensor.shape[0] + + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +class UNetMidBlock3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. + Args: + in_channels (`int`): The number of input channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + inject_noise (`bool`, *optional*, defaults to `False`): + Whether to inject noise into the hidden states. + timestep_conditioning (`bool`, *optional*, defaults to `False`): + Whether to condition the hidden states on the timestep. + Returns: + `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + norm_layer: NormLayerType = NormLayerType.GROUP_NORM, + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings( + embedding_dim=in_channels * 4, size_emb_dim=0 + ) + + self.res_blocks = nn.ModuleList( + [ + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: + timestep_embed = None + if self.timestep_conditioning: + if timestep is None: + raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True") + batch_size = hidden_states.shape[0] + timestep_embed = self.time_embedder( + timestep=timestep.flatten(), + hidden_dtype=hidden_states.dtype, + ) + timestep_embed = timestep_embed.view(batch_size, timestep_embed.shape[-1], 1, 1, 1) + + for resnet in self.res_blocks: + hidden_states = resnet( + hidden_states, + causal=causal, + timestep=timestep_embed, + generator=generator, + ) + + return hidden_states diff --git a/ltx2/ltx_core/model/video_vae/sampling.py b/ltx2/ltx_core/model/video_vae/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..5d8f7427b9c7eedf38ed11b4b2637c39cc06c1a6 --- /dev/null +++ b/ltx2/ltx_core/model/video_vae/sampling.py @@ -0,0 +1,123 @@ +import math +from typing import Tuple, Union + +import torch +from einops import rearrange +from torch import nn + +from .convolution import make_conv_nd +from .enums import PaddingModeType + + +class SpaceToDepthDownsample(nn.Module): + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + stride: Tuple[int, int, int], + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + self.stride = stride + self.group_size = in_channels * math.prod(stride) // out_channels + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=out_channels // math.prod(stride), + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + + def forward( + self, + x: torch.Tensor, + causal: bool = True, + ) -> torch.Tensor: + if self.stride[0] == 2: + x = torch.cat([x[:, :, :1, :, :], x], dim=2) # duplicate first frames for padding + + # skip connection + x_in = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size) + x_in = x_in.mean(dim=2) + + # conv + x = self.conv(x, causal=causal) + x = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + + x = x + x_in + + return x + + +class DepthToSpaceUpsample(nn.Module): + def __init__( + self, + dims: int | Tuple[int, int], + in_channels: int, + stride: Tuple[int, int, int], + residual: bool = False, + out_channels_reduction_factor: int = 1, + spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + self.stride = stride + self.out_channels = math.prod(stride) * in_channels // out_channels_reduction_factor + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + self.residual = residual + self.out_channels_reduction_factor = out_channels_reduction_factor + + def forward( + self, + x: torch.Tensor, + causal: bool = True, + ) -> torch.Tensor: + if self.residual: + # Reshape and duplicate the input to match the output shape + x_in = rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor + x_in = x_in.repeat(1, num_repeat, 1, 1, 1) + if self.stride[0] == 2: + x_in = x_in[:, :, 1:, :, :] + x = self.conv(x, causal=causal) + x = rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + if self.stride[0] == 2: + x = x[:, :, 1:, :, :] + if self.residual: + x = x + x_in + return x diff --git a/ltx2/ltx_core/model/video_vae/tiling.py b/ltx2/ltx_core/model/video_vae/tiling.py new file mode 100644 index 0000000000000000000000000000000000000000..b32fc599bb75cebfa846caaa94863afbdf7ea19f --- /dev/null +++ b/ltx2/ltx_core/model/video_vae/tiling.py @@ -0,0 +1,69 @@ +from dataclasses import dataclass + + +@dataclass(frozen=True) +class SpatialTilingConfig: + """Configuration for dividing each frame into spatial tiles with optional overlap. + Args: + tile_size_in_pixels (int): Size of each tile in pixels. Must be at least 64 and divisible by 32. + tile_overlap_in_pixels (int, optional): Overlap between tiles in pixels. Must be divisible by 32. Defaults to 0. + """ + + tile_size_in_pixels: int + tile_overlap_in_pixels: int = 0 + + def __post_init__(self) -> None: + if self.tile_size_in_pixels < 64: + raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}") + if self.tile_size_in_pixels % 32 != 0: + raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}") + if self.tile_overlap_in_pixels % 32 != 0: + raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}") + if self.tile_overlap_in_pixels >= self.tile_size_in_pixels: + raise ValueError( + f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}" + ) + + +@dataclass(frozen=True) +class TemporalTilingConfig: + """Configuration for dividing a video into temporal tiles (chunks of frames) with optional overlap. + Args: + tile_size_in_frames (int): Number of frames in each tile. Must be at least 16 and divisible by 8. + tile_overlap_in_frames (int, optional): Number of overlapping frames between consecutive tiles. + Must be divisible by 8. Defaults to 0. + """ + + tile_size_in_frames: int + tile_overlap_in_frames: int = 0 + + def __post_init__(self) -> None: + if self.tile_size_in_frames < 16: + raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}") + if self.tile_size_in_frames % 8 != 0: + raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}") + if self.tile_overlap_in_frames % 8 != 0: + raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}") + if self.tile_overlap_in_frames >= self.tile_size_in_frames: + raise ValueError( + f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}" + ) + + +@dataclass(frozen=True) +class TilingConfig: + """Configuration for splitting video into tiles with optional overlap. + Attributes: + spatial_config: Configuration for splitting spatial dimensions into tiles. + temporal_config: Configuration for splitting temporal dimension into tiles. + """ + + spatial_config: SpatialTilingConfig | None = None + temporal_config: TemporalTilingConfig | None = None + + @classmethod + def default(cls) -> "TilingConfig": + return cls( + spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64), + temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24), + ) diff --git a/ltx2/ltx_core/model/video_vae/video_vae.py b/ltx2/ltx_core/model/video_vae/video_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..7a1040961625038aa8bb55d159c77342e50fd77e --- /dev/null +++ b/ltx2/ltx_core/model/video_vae/video_vae.py @@ -0,0 +1,1109 @@ +import logging +from typing import Any, Callable, Iterator, List, Tuple + +import torch +from einops import rearrange +from torch import nn + +from ltx_core.model.common.normalization import PixelNorm +from ltx_core.model.transformer.timestep_embedding import PixArtAlphaCombinedTimestepSizeEmbeddings +from ltx_core.model.video_vae.convolution import make_conv_nd +from ltx_core.model.video_vae.enums import LogVarianceType, NormLayerType, PaddingModeType +from ltx_core.model.video_vae.ops import PerChannelStatistics, patchify, unpatchify +from ltx_core.model.video_vae.resnet import ResnetBlock3D, UNetMidBlock3D +from ltx_core.model.video_vae.sampling import DepthToSpaceUpsample, SpaceToDepthDownsample +from ltx_core.model.video_vae.tiling import TilingConfig +from ltx_core.tiling import ( + DEFAULT_MAPPING_OPERATION, + DEFAULT_SPLIT_OPERATION, + DimensionIntervals, + MappingOperation, + Tile, + compute_rectangular_mask_1d, + compute_trapezoidal_mask_1d, + create_tiles, + split_temporal, +) +from ltx_core.tiling import ( + split_by_size as split_in_spatial, +) +from ltx_core.tiling import ( + split_temporal_causal as split_in_temporal, +) +from ltx_core.types import VIDEO_SCALE_FACTORS, SpatioTemporalScaleFactors, VideoLatentShape + +logger: logging.Logger = logging.getLogger(__name__) + + +def _make_encoder_block( + block_name: str, + block_config: dict[str, Any], + in_channels: int, + convolution_dimensions: int, + norm_layer: NormLayerType, + norm_num_groups: int, + spatial_padding_mode: PaddingModeType, +) -> Tuple[nn.Module, int]: + out_channels = in_channels + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + num_layers=block_config["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + out_channels = in_channels * block_config.get("multiplier", 2) + block = ResnetBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(2, 1, 1), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space": + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(1, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all": + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_x_y": + out_channels = in_channels * block_config.get("multiplier", 2) + block = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all_res": + out_channels = in_channels * block_config.get("multiplier", 2) + block = SpaceToDepthDownsample( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space_res": + out_channels = in_channels * block_config.get("multiplier", 2) + block = SpaceToDepthDownsample( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time_res": + out_channels = in_channels * block_config.get("multiplier", 2) + block = SpaceToDepthDownsample( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown block: {block_name}") + + return block, out_channels + + +class VideoEncoder(nn.Module): + _DEFAULT_NORM_NUM_GROUPS = 32 + """ + Variational Autoencoder Encoder. Encodes video frames into a latent representation. + The encoder compresses the input video through a series of downsampling operations controlled by + patch_size and encoder_blocks. The output is a normalized latent tensor with shape (B, 128, F', H', W'). + Compression Behavior: + The total compression is determined by: + 1. Initial spatial compression via patchify: H -> H/4, W -> W/4 (patch_size=4) + 2. Sequential compression through encoder_blocks based on their stride patterns + Compression blocks apply 2x compression in specified dimensions: + - "compress_time" / "compress_time_res": temporal only + - "compress_space" / "compress_space_res": spatial only (H and W) + - "compress_all" / "compress_all_res": all dimensions (F, H, W) + - "res_x" / "res_x_y": no compression + Standard LTX Video configuration: + - patch_size=4 + - encoder_blocks: 1x compress_space_res, 1x compress_time_res, 2x compress_all_res + - Final dimensions: F' = 1 + (F-1)/8, H' = H/32, W' = W/32 + - Example: (B, 3, 33, 512, 512) -> (B, 128, 5, 16, 16) + - Note: Input must have 1 + 8*k frames (e.g., 1, 9, 17, 25, 33...) + Args: + convolution_dimensions: The number of dimensions to use in convolutions (2D or 3D). + in_channels: The number of input channels. For RGB images, this is 3. + out_channels: The number of output channels (latent channels). For latent channels, this is 128. + encoder_blocks: The list of blocks to construct the encoder. Each block is a tuple of (block_name, params) + where params is either an int (num_layers) or a dict with configuration. + patch_size: The patch size for initial spatial compression. Should be a power of 2. + norm_layer: The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + latent_log_var: The log variance mode. Can be either `per_channel`, `uniform`, `constant` or `none`. + """ + + def __init__( + self, + convolution_dimensions: int = 3, + in_channels: int = 3, + out_channels: int = 128, + encoder_blocks: List[Tuple[str, int]] | List[Tuple[str, dict[str, Any]]] = [], # noqa: B006 + patch_size: int = 4, + norm_layer: NormLayerType = NormLayerType.PIXEL_NORM, + latent_log_var: LogVarianceType = LogVarianceType.UNIFORM, + encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, + ): + super().__init__() + + self.patch_size = patch_size + self.norm_layer = norm_layer + self.latent_channels = out_channels + self.latent_log_var = latent_log_var + self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS + + # Per-channel statistics for normalizing latents + self.per_channel_statistics = PerChannelStatistics(latent_channels=out_channels) + + in_channels = in_channels * patch_size**2 + feature_channels = out_channels + + self.conv_in = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=feature_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=encoder_spatial_padding_mode, + ) + + self.down_blocks = nn.ModuleList([]) + + for block_name, block_params in encoder_blocks: + # Convert int to dict format for uniform handling + block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params + + block, feature_channels = _make_encoder_block( + block_name=block_name, + block_config=block_config, + in_channels=feature_channels, + convolution_dimensions=convolution_dimensions, + norm_layer=norm_layer, + norm_num_groups=self._norm_num_groups, + spatial_padding_mode=encoder_spatial_padding_mode, + ) + + self.down_blocks.append(block) + + # out + if norm_layer == NormLayerType.GROUP_NORM: + self.conv_norm_out = nn.GroupNorm(num_channels=feature_channels, num_groups=self._norm_num_groups, eps=1e-6) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.conv_norm_out = PixelNorm() + + self.conv_act = nn.SiLU() + + conv_out_channels = out_channels + if latent_log_var == LogVarianceType.PER_CHANNEL: + conv_out_channels *= 2 + elif latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}: + conv_out_channels += 1 + elif latent_log_var != LogVarianceType.NONE: + raise ValueError(f"Invalid latent_log_var: {latent_log_var}") + + self.conv_out = make_conv_nd( + dims=convolution_dimensions, + in_channels=feature_channels, + out_channels=conv_out_channels, + kernel_size=3, + padding=1, + causal=True, + spatial_padding_mode=encoder_spatial_padding_mode, + ) + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + r""" + Encode video frames into normalized latent representation. + Args: + sample: Input video (B, C, F, H, W). F should be 1 + 8*k (e.g., 1, 9, 17, 25, 33...). + If not, the encoder crops the last frames to the nearest valid length. + Returns: + Normalized latent means (B, 128, F', H', W') where F' = 1+(F-1)/8, H' = H/32, W' = W/32. + Example: (B, 3, 33, 512, 512) -> (B, 128, 5, 16, 16). + """ + # Validate frame count (crop to nearest valid length if needed) + frames_count = sample.shape[2] + if ((frames_count - 1) % 8) != 0: + frames_to_crop = (frames_count - 1) % 8 + logger.warning( + "Invalid number of frames %s for encode; cropping last %s frames to satisfy 1 + 8*k.", + frames_count, + frames_to_crop, + ) + sample = sample[:, :, :-frames_to_crop, ...] + + # Initial spatial compression: trade spatial resolution for channel depth + # This reduces H,W by patch_size and increases channels, making convolutions more efficient + # Example: (B, 3, F, 512, 512) -> (B, 48, F, 128, 128) with patch_size=4 + sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + sample = self.conv_in(sample) + + for down_block in self.down_blocks: + sample = down_block(sample) + + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if self.latent_log_var == LogVarianceType.UNIFORM: + # Uniform Variance: model outputs N means and 1 shared log-variance channel. + # We need to expand the single logvar to match the number of means channels + # to create a format compatible with PER_CHANNEL (means + logvar, each with N channels). + # Sample shape: (B, N+1, ...) where N = latent_channels (e.g., 128 means + 1 logvar = 129) + # Target shape: (B, 2*N, ...) where first N are means, last N are logvar + + if sample.shape[1] < 2: + raise ValueError( + f"Invalid channel count for UNIFORM mode: expected at least 2 channels " + f"(N means + 1 logvar), got {sample.shape[1]}" + ) + + # Extract means (first N channels) and logvar (last 1 channel) + means = sample[:, :-1, ...] # (B, N, ...) + logvar = sample[:, -1:, ...] # (B, 1, ...) + + # Repeat logvar N times to match means channels + # Use expand/repeat pattern that works for both 4D and 5D tensors + num_channels = means.shape[1] + repeat_shape = [1, num_channels] + [1] * (sample.ndim - 2) + repeated_logvar = logvar.repeat(*repeat_shape) # (B, N, ...) + + # Concatenate to create (B, 2*N, ...) format: [means, repeated_logvar] + sample = torch.cat([means, repeated_logvar], dim=1) + elif self.latent_log_var == LogVarianceType.CONSTANT: + sample = sample[:, :-1, ...] + approx_ln_0 = -30 # this is the minimal clamp value in DiagonalGaussianDistribution objects + sample = torch.cat( + [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0], + dim=1, + ) + + # Split into means and logvar, then normalize means + means, _ = torch.chunk(sample, 2, dim=1) + return self.per_channel_statistics.normalize(means) + + def tiled_encode( + self, + video: torch.Tensor, + tiling_config: TilingConfig | None = None, + ) -> torch.Tensor: + """Encode video to latent using tiled processing of the given video tensor. + Device Handling: + - Input video can be on CPU or GPU + - Accumulation buffers are created on model's device + - Each tile is automatically moved to model's device before encoding + - Output latent is returned on model's device + Args: + video: Input video tensor (B, 3, F, H, W) in range [-1, 1] + tiling_config: Tiling configuration for the video tensor + Returns: + Latent tensor (B, 128, F', H', W') on model's device + where F' = 1 + (F-1)/8, H' = H/32, W' = W/32 + """ + # Detect model device and dtype + model_device = next(self.parameters()).device + model_dtype = next(self.parameters()).dtype + + # Extract shape components + batch, _, frames, height, width = video.shape + + # Check frame count and crop if needed + if (frames - 1) % VIDEO_SCALE_FACTORS.time != 0: + frames_to_crop = (frames - 1) % VIDEO_SCALE_FACTORS.time + logger.warning( + f"Number of frames {frames} of input video is not ({VIDEO_SCALE_FACTORS.time} * k + 1), " + f"last {frames_to_crop} frames will be cropped" + ) + video = video[:, :, :-frames_to_crop, ...] + # Update frames after cropping + frames = video.shape[2] + + # Calculate output latent shape (inverse of upscale) + latent_shape = VideoLatentShape( + batch=batch, + channels=self.latent_channels, # 128 for standard VAE + frames=(frames - 1) // VIDEO_SCALE_FACTORS.time + 1, + height=height // VIDEO_SCALE_FACTORS.height, + width=width // VIDEO_SCALE_FACTORS.width, + ) + + # Prepare tiles (operates on VIDEO dimensions) + tiles = prepare_tiles_for_encoding(video, tiling_config) + + # Initialize accumulation buffers on model device + latent_buffer = torch.zeros( + latent_shape.to_torch_shape(), + device=model_device, + dtype=model_dtype, + ) + weights_buffer = torch.zeros_like(latent_buffer) + + # Process each tile + for tile in tiles: + # Extract video tile from input (may be on CPU) + video_tile = video[tile.in_coords] + + # Move tile to model device if needed + if video_tile.device != model_device or video_tile.dtype != model_dtype: + video_tile = video_tile.to(device=model_device, dtype=model_dtype) + + # Encode tile to latent (output on model device) + latent_tile = self.forward(video_tile) + + # Move blend mask to model device + mask = tile.blend_mask.to( + device=model_device, + dtype=model_dtype, + ) + + # Weighted accumulation in latent space + latent_buffer[tile.out_coords] += latent_tile * mask + weights_buffer[tile.out_coords] += mask + + del latent_tile, mask, video_tile + + # Normalize by accumulated weights + weights_buffer = weights_buffer.clamp(min=1e-8) + return latent_buffer / weights_buffer + + +def prepare_tiles_for_encoding( + video: torch.Tensor, + tiling_config: TilingConfig | None = None, +) -> List[Tile]: + """Prepare tiles for VAE encoding. + Args: + video: Input video tensor (B, 3, F, H, W) in range [-1, 1] + tiling_config: Tiling configuration for the video tensor + Returns: + List of tiles for the video tensor + """ + + splitters = [DEFAULT_SPLIT_OPERATION] * len(video.shape) + mappers = [DEFAULT_MAPPING_OPERATION] * len(video.shape) + minimum_spatial_overlap_px = 64 + minimum_temporal_overlap_frames = 16 + + if tiling_config is not None and tiling_config.spatial_config is not None: + cfg = tiling_config.spatial_config + + tile_size_px = cfg.tile_size_in_pixels + overlap_px = cfg.tile_overlap_in_pixels + + # Set minimum spatial overlap to 64 pixels in order to allow cutting padding from + # the front and back of the tiles and concatenate tiles without artifacts. + # The encoder uses symmetric padding (pad=1) in H and W at each conv layer. At tile + # boundaries, convs see padding (zeros/reflect) instead of real neighbor pixels, causing + # incorrect context near edges. + # For each overlap we discard 1 latent per edge (32px at scale 32) and concatenate tiles at a + # shared region with the next tile. + if overlap_px < minimum_spatial_overlap_px: + logger.warning( + f"Overlap pixels {overlap_px} in spatial tiling is less than \ + {minimum_spatial_overlap_px}, setting to minimum required {minimum_spatial_overlap_px}" + ) + overlap_px = minimum_spatial_overlap_px + + # Define split and map operations for the spatial dimensions + + # Height axis (H) + splitters[3] = split_in_spatial(tile_size_px, overlap_px) + mappers[3] = to_mapping_operation(map_spatial_interval_to_latent, scale=VIDEO_SCALE_FACTORS.height) + + # Width axis (W) + splitters[4] = split_in_spatial(tile_size_px, overlap_px) + mappers[4] = to_mapping_operation(map_spatial_interval_to_latent, scale=VIDEO_SCALE_FACTORS.width) + + if tiling_config is not None and tiling_config.temporal_config is not None: + cfg = tiling_config.temporal_config + tile_size_frames = cfg.tile_size_in_frames + overlap_frames = cfg.tile_overlap_in_frames + + if overlap_frames < minimum_temporal_overlap_frames: + logger.warning(f"Overlap frames {overlap_frames} is less than 16, setting to minimum required 16") + overlap_frames = minimum_temporal_overlap_frames + + splitters[2] = split_temporal(tile_size_frames, overlap_frames) + mappers[2] = to_mapping_operation(map_temporal_interval_to_latent, scale=VIDEO_SCALE_FACTORS.time) + + return create_tiles(video.shape, splitters, mappers) + + +def _make_decoder_block( + block_name: str, + block_config: dict[str, Any], + in_channels: int, + convolution_dimensions: int, + norm_layer: NormLayerType, + timestep_conditioning: bool, + norm_num_groups: int, + spatial_padding_mode: PaddingModeType, +) -> Tuple[nn.Module, int]: + out_channels = in_channels + if block_name == "res_x": + block = UNetMidBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + num_layers=block_config["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_config.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "attn_res_x": + block = UNetMidBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + num_layers=block_config["num_layers"], + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_config.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + attention_head_dim=block_config["attention_head_dim"], + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "res_x_y": + out_channels = in_channels // block_config.get("multiplier", 2) + block = ResnetBlock3D( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=out_channels, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_config.get("inject_noise", False), + timestep_conditioning=False, + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_time": + out_channels = in_channels // block_config.get("multiplier", 1) + block = DepthToSpaceUpsample( + dims=convolution_dimensions, + in_channels=in_channels, + stride=(2, 1, 1), + out_channels_reduction_factor=block_config.get("multiplier", 1), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_space": + out_channels = in_channels // block_config.get("multiplier", 1) + block = DepthToSpaceUpsample( + dims=convolution_dimensions, + in_channels=in_channels, + stride=(1, 2, 2), + out_channels_reduction_factor=block_config.get("multiplier", 1), + spatial_padding_mode=spatial_padding_mode, + ) + elif block_name == "compress_all": + out_channels = in_channels // block_config.get("multiplier", 1) + block = DepthToSpaceUpsample( + dims=convolution_dimensions, + in_channels=in_channels, + stride=(2, 2, 2), + residual=block_config.get("residual", False), + out_channels_reduction_factor=block_config.get("multiplier", 1), + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"unknown layer: {block_name}") + + return block, out_channels + + +class VideoDecoder(nn.Module): + _DEFAULT_NORM_NUM_GROUPS = 32 + """ + Variational Autoencoder Decoder. Decodes latent representation into video frames. + The decoder upsamples latents through a series of upsampling operations (inverse of encoder). + Output dimensions: F = 8x(F'-1) + 1, H = 32xH', W = 32xW' for standard LTX Video configuration. + Upsampling blocks expand dimensions by 2x in specified dimensions: + - "compress_time": temporal only + - "compress_space": spatial only (H and W) + - "compress_all": all dimensions (F, H, W) + - "res_x" / "res_x_y" / "attn_res_x": no upsampling + Causal Mode: + causal=False (standard): Symmetric padding, allows future frame dependencies. + causal=True: Causal padding, each frame depends only on past/current frames. + First frame removed after temporal upsampling in both modes. Output shape unchanged. + Example: (B, 128, 5, 16, 16) -> (B, 3, 33, 512, 512) for both modes. + Args: + convolution_dimensions: The number of dimensions to use in convolutions (2D or 3D). + in_channels: The number of input channels (latent channels). Default is 128. + out_channels: The number of output channels. For RGB images, this is 3. + decoder_blocks: The list of blocks to construct the decoder. Each block is a tuple of (block_name, params) + where params is either an int (num_layers) or a dict with configuration. + patch_size: Final spatial expansion factor. For standard LTX Video, use 4 for 4x spatial expansion: + H -> Hx4, W -> Wx4. Should be a power of 2. + norm_layer: The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + causal: Whether to use causal convolutions. For standard LTX Video, use False for symmetric padding. + When True, uses causal padding (past/current frames only). + timestep_conditioning: Whether to condition the decoder on timestep for denoising. + """ + + def __init__( + self, + convolution_dimensions: int = 3, + in_channels: int = 128, + out_channels: int = 3, + decoder_blocks: List[Tuple[str, int | dict]] = [], # noqa: B006 + patch_size: int = 4, + norm_layer: NormLayerType = NormLayerType.PIXEL_NORM, + causal: bool = False, + timestep_conditioning: bool = False, + decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT, + base_channels: int = 128, + ): + super().__init__() + + # Spatiotemporal downscaling between decoded video space and VAE latents. + # According to the LTXV paper, the standard configuration downsamples + # video inputs by a factor of 8 in the temporal dimension and 32 in + # each spatial dimension (height and width). This parameter determines how + # many video frames and pixels correspond to a single latent cell. + self.video_downscale_factors = SpatioTemporalScaleFactors( + time=8, + width=32, + height=32, + ) + + self.patch_size = patch_size + out_channels = out_channels * patch_size**2 + self.causal = causal + self.timestep_conditioning = timestep_conditioning + self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS + + # Per-channel statistics for denormalizing latents + self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels) + + # Noise and timestep parameters for decoder conditioning + self.decode_noise_scale = 0.025 + self.decode_timestep = 0.05 + + # LTX VAE decoder architecture uses 3 upsampler blocks with multiplier equals to 2. + # Hence the total feature_channels is multiplied by 8 (2^3). + feature_channels = base_channels * 8 + + self.conv_in = make_conv_nd( + dims=convolution_dimensions, + in_channels=in_channels, + out_channels=feature_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + spatial_padding_mode=decoder_spatial_padding_mode, + ) + + self.up_blocks = nn.ModuleList([]) + + for block_name, block_params in list(reversed(decoder_blocks)): + # Convert int to dict format for uniform handling + block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params + + block, feature_channels = _make_decoder_block( + block_name=block_name, + block_config=block_config, + in_channels=feature_channels, + convolution_dimensions=convolution_dimensions, + norm_layer=norm_layer, + timestep_conditioning=timestep_conditioning, + norm_num_groups=self._norm_num_groups, + spatial_padding_mode=decoder_spatial_padding_mode, + ) + + self.up_blocks.append(block) + + if norm_layer == NormLayerType.GROUP_NORM: + self.conv_norm_out = nn.GroupNorm(num_channels=feature_channels, num_groups=self._norm_num_groups, eps=1e-6) + elif norm_layer == NormLayerType.PIXEL_NORM: + self.conv_norm_out = PixelNorm() + + self.conv_act = nn.SiLU() + self.conv_out = make_conv_nd( + dims=convolution_dimensions, + in_channels=feature_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + causal=True, + spatial_padding_mode=decoder_spatial_padding_mode, + ) + + if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0)) + self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings( + embedding_dim=feature_channels * 2, size_emb_dim=0 + ) + self.last_scale_shift_table = nn.Parameter(torch.empty(2, feature_channels)) + + def forward( + self, + sample: torch.Tensor, + timestep: torch.Tensor | None = None, + generator: torch.Generator | None = None, + ) -> torch.Tensor: + r""" + Decode latent representation into video frames. + Args: + sample: Latent tensor (B, 128, F', H', W'). + timestep: Timestep for conditioning (if timestep_conditioning=True). Uses default 0.05 if None. + generator: Random generator for deterministic noise injection (if inject_noise=True in blocks). + Returns: + Decoded video (B, 3, F, H, W) where F = 8x(F'-1) + 1, H = 32xH', W = 32xW'. + Example: (B, 128, 5, 16, 16) -> (B, 3, 33, 512, 512). + Note: First frame is removed after temporal upsampling regardless of causal mode. + When causal=False, allows future frame dependencies in convolutions but maintains same output shape. + """ + batch_size = sample.shape[0] + + # Add noise if timestep conditioning is enabled + if self.timestep_conditioning: + noise = ( + torch.randn( + sample.size(), + generator=generator, + dtype=sample.dtype, + device=sample.device, + ) + * self.decode_noise_scale + ) + + sample = noise + (1.0 - self.decode_noise_scale) * sample + + # Denormalize latents + sample = self.per_channel_statistics.un_normalize(sample) + + # Use default decode_timestep if timestep not provided + if timestep is None and self.timestep_conditioning: + timestep = torch.full((batch_size,), self.decode_timestep, device=sample.device, dtype=sample.dtype) + + sample = self.conv_in(sample, causal=self.causal) + + scaled_timestep = None + if self.timestep_conditioning: + if timestep is None: + raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True") + scaled_timestep = timestep * self.timestep_scale_multiplier.to(sample) + + for up_block in self.up_blocks: + if isinstance(up_block, UNetMidBlock3D): + block_kwargs = { + "causal": self.causal, + "timestep": scaled_timestep if self.timestep_conditioning else None, + "generator": generator, + } + sample = up_block(sample, **block_kwargs) + elif isinstance(up_block, ResnetBlock3D): + sample = up_block(sample, causal=self.causal, generator=generator) + else: + sample = up_block(sample, causal=self.causal) + + sample = self.conv_norm_out(sample) + + if self.timestep_conditioning: + embedded_timestep = self.last_time_embedder( + timestep=scaled_timestep.flatten(), + hidden_dtype=sample.dtype, + ) + embedded_timestep = embedded_timestep.view(batch_size, embedded_timestep.shape[-1], 1, 1, 1) + ada_values = self.last_scale_shift_table[None, ..., None, None, None].to( + device=sample.device, dtype=sample.dtype + ) + embedded_timestep.reshape( + batch_size, + 2, + -1, + embedded_timestep.shape[-3], + embedded_timestep.shape[-2], + embedded_timestep.shape[-1], + ) + shift, scale = ada_values.unbind(dim=1) + sample = sample * (1 + scale) + shift + + sample = self.conv_act(sample) + sample = self.conv_out(sample, causal=self.causal) + + # Final spatial expansion: reverse the initial patchify from encoder + # Moves pixels from channels back to spatial dimensions + # Example: (B, 48, F, 128, 128) -> (B, 3, F, 512, 512) with patch_size=4 + sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + + return sample + + def _prepare_tiles( + self, + latent: torch.Tensor, + tiling_config: TilingConfig | None = None, + ) -> List[Tile]: + splitters = [DEFAULT_SPLIT_OPERATION] * len(latent.shape) + mappers = [DEFAULT_MAPPING_OPERATION] * len(latent.shape) + if tiling_config is not None and tiling_config.spatial_config is not None: + cfg = tiling_config.spatial_config + long_side = max(latent.shape[3], latent.shape[4]) + + def enable_on_axis(axis_idx: int, factor: int) -> None: + size = cfg.tile_size_in_pixels // factor + overlap = cfg.tile_overlap_in_pixels // factor + axis_length = latent.shape[axis_idx] + lower_threshold = max(2, overlap + 1) + tile_size = max(lower_threshold, round(size * axis_length / long_side)) + splitters[axis_idx] = split_in_spatial(tile_size, overlap) + mappers[axis_idx] = to_mapping_operation(map_spatial_slice, scale=factor) + + enable_on_axis(3, self.video_downscale_factors.height) + enable_on_axis(4, self.video_downscale_factors.width) + + if tiling_config is not None and tiling_config.temporal_config is not None: + cfg = tiling_config.temporal_config + tile_size = cfg.tile_size_in_frames // self.video_downscale_factors.time + overlap = cfg.tile_overlap_in_frames // self.video_downscale_factors.time + splitters[2] = split_in_temporal(tile_size, overlap) + mappers[2] = to_mapping_operation(map_temporal_slice, scale=self.video_downscale_factors.time) + + return create_tiles(latent.shape, splitters, mappers) + + def tiled_decode( + self, + latent: torch.Tensor, + tiling_config: TilingConfig | None = None, + timestep: torch.Tensor | None = None, + generator: torch.Generator | None = None, + ) -> Iterator[torch.Tensor]: + """ + Decode a latent tensor into video frames using tiled processing. + Splits the latent tensor into tiles, decodes each tile individually, + and yields video chunks as they become available. + Args: + latent: Input latent tensor (B, C, F', H', W'). + tiling_config: Tiling configuration for the latent tensor. + timestep: Optional timestep for decoder conditioning. + generator: Optional random generator for deterministic decoding. + Yields: + Video chunks (B, C, T, H, W) by temporal slices; + """ + + # Calculate full video shape from latent shape to get spatial dimensions + full_video_shape = VideoLatentShape.from_torch_shape(latent.shape).upscale(self.video_downscale_factors) + tiles = self._prepare_tiles(latent, tiling_config) + + temporal_groups = self._group_tiles_by_temporal_slice(tiles) + + # State for temporal overlap handling + previous_chunk = None + previous_weights = None + previous_temporal_slice = None + + for temporal_group_tiles in temporal_groups: + curr_temporal_slice = temporal_group_tiles[0].out_coords[2] + + # Calculate the shape of the temporal buffer for this group of tiles. + # The temporal length depends on whether this is the first tile (starts at 0) or not. + # - First tile: (frames - 1) * scale + 1 + # - Subsequent tiles: frames * scale + # This logic is handled by TemporalAxisMapping and reflected in out_coords. + temporal_tile_buffer_shape = full_video_shape._replace( + frames=curr_temporal_slice.stop - curr_temporal_slice.start, + ) + + buffer = torch.zeros( + temporal_tile_buffer_shape.to_torch_shape(), + device=latent.device, + dtype=latent.dtype, + ) + + curr_weights = self._accumulate_temporal_group_into_buffer( + group_tiles=temporal_group_tiles, + buffer=buffer, + latent=latent, + timestep=timestep, + generator=generator, + ) + + # Blend with previous temporal chunk if it exists + if previous_chunk is not None: + # Check if current temporal slice overlaps with previous temporal slice + if previous_temporal_slice.stop > curr_temporal_slice.start: + overlap_len = previous_temporal_slice.stop - curr_temporal_slice.start + temporal_overlap_slice = slice(curr_temporal_slice.start - previous_temporal_slice.start, None) + + # The overlap is already masked before it reaches this step. Each tile is accumulated into buffer + # with its trapezoidal mask, and curr_weights accumulates the same mask. In the overlap blend we add + # the masked values (buffer[...]) and the corresponding weights (curr_weights[...]) into the + # previous buffers, then later normalize by weights. + previous_chunk[:, :, temporal_overlap_slice, :, :] += buffer[:, :, slice(0, overlap_len), :, :] + previous_weights[:, :, temporal_overlap_slice, :, :] += curr_weights[ + :, :, slice(0, overlap_len), :, : + ] + + buffer[:, :, slice(0, overlap_len), :, :] = previous_chunk[:, :, temporal_overlap_slice, :, :] + curr_weights[:, :, slice(0, overlap_len), :, :] = previous_weights[ + :, :, temporal_overlap_slice, :, : + ] + + # Yield the non-overlapping part of the previous chunk + previous_weights = previous_weights.clamp(min=1e-8) + yield_len = curr_temporal_slice.start - previous_temporal_slice.start + yield (previous_chunk / previous_weights)[:, :, :yield_len, :, :] + + # Update state for next iteration + previous_chunk = buffer + previous_weights = curr_weights + previous_temporal_slice = curr_temporal_slice + + # Yield any remaining chunk + if previous_chunk is not None: + previous_weights = previous_weights.clamp(min=1e-8) + yield previous_chunk / previous_weights + + def decode_video( + self, + latent: torch.Tensor, + tiling_config: TilingConfig | None = None, + generator: torch.Generator | None = None, + ) -> Iterator[torch.Tensor]: + """Decode a video latent tensor, yielding uint8 chunks ``[f, h, w, c]``. + Subclasses (e.g. ``DistributedVideoDecoder``) may override this to + control eagerness or distribution across ranks. + """ + + def convert_to_uint8(frames: torch.Tensor) -> torch.Tensor: + frames = (((frames + 1.0) / 2.0).clamp(0.0, 1.0) * 255.0).to(torch.uint8) + frames = rearrange(frames[0], "c f h w -> f h w c") + return frames + + if tiling_config is not None: + for frames in self.tiled_decode(latent, tiling_config, generator=generator): + yield convert_to_uint8(frames) + else: + decoded = self(latent, generator=generator) + yield convert_to_uint8(decoded) + + def _group_tiles_by_temporal_slice(self, tiles: List[Tile]) -> List[List[Tile]]: + """Group tiles by their temporal output slice.""" + if not tiles: + return [] + + groups = [] + current_slice = tiles[0].out_coords[2] + current_group = [] + + for tile in tiles: + tile_slice = tile.out_coords[2] + if tile_slice == current_slice: + current_group.append(tile) + else: + groups.append(current_group) + current_slice = tile_slice + current_group = [tile] + + # Add the final group + if current_group: + groups.append(current_group) + + return groups + + def _accumulate_temporal_group_into_buffer( + self, + group_tiles: List[Tile], + buffer: torch.Tensor, + latent: torch.Tensor, + timestep: torch.Tensor | None, + generator: torch.Generator | None, + ) -> torch.Tensor: + """ + Decode and accumulate all tiles of a temporal group into a local buffer. + The buffer is local to the group and always starts at time 0; temporal coordinates + are rebased by subtracting temporal_slice.start. + """ + temporal_slice = group_tiles[0].out_coords[2] + + weights = torch.zeros_like(buffer) + + for tile in group_tiles: + decoded_tile = self.forward(latent[tile.in_coords], timestep, generator) + mask = tile.blend_mask.to(device=buffer.device, dtype=buffer.dtype) + temporal_offset = tile.out_coords[2].start - temporal_slice.start + # Use the tile's output coordinate length, not the decoded tile's length, + # as the decoder may produce a different number of frames than expected + expected_temporal_len = tile.out_coords[2].stop - tile.out_coords[2].start + decoded_temporal_len = decoded_tile.shape[2] + + # Ensure we don't exceed the buffer or decoded tile bounds + actual_temporal_len = min(expected_temporal_len, decoded_temporal_len, buffer.shape[2] - temporal_offset) + + chunk_coords = ( + slice(None), # batch + slice(None), # channels + slice(temporal_offset, temporal_offset + actual_temporal_len), + tile.out_coords[3], # height + tile.out_coords[4], # width + ) + + # Slice decoded_tile and mask to match the actual length we're writing + decoded_slice = decoded_tile[:, :, :actual_temporal_len, :, :] + mask_slice = mask[:, :, :actual_temporal_len, :, :] if mask.shape[2] > 1 else mask + + buffer[chunk_coords] += decoded_slice * mask_slice + weights[chunk_coords] += mask_slice + + return weights + + +def get_video_chunks_number(num_frames: int, tiling_config: TilingConfig | None = None) -> int: + """ + Get the number of video chunks for a given number of frames and tiling configuration. + Args: + num_frames: Number of frames in the video. + tiling_config: Tiling configuration. + Returns: + Number of video chunks. + """ + if not tiling_config or not tiling_config.temporal_config: + return 1 + cfg = tiling_config.temporal_config + frame_stride = cfg.tile_size_in_frames - cfg.tile_overlap_in_frames + return (num_frames - 1 + frame_stride - 1) // frame_stride + + +def to_mapping_operation( + map_func: Callable[[int, int, int, int, int], Tuple[slice, torch.Tensor | None]], + scale: int, +) -> MappingOperation: + """Create a mapping operation over a set of tiling intervals. + The given mapping function is applied to each interval in the input dimension. The result function is used for + creating tiles in the output dimension. + Args: + map_func: Mapping function to create the mapping operation from + scale: Scale factor for the transformation, used as an argument for the mapping function + Returns: + Mapping operation that takes a set of tiling intervals and returns a set of slices and masks in the output + dimension. + """ + + def map_op(intervals: DimensionIntervals) -> tuple[list[slice], list[torch.Tensor | None]]: + output_slices: list[slice] = [] + masks_1d: list[torch.Tensor | None] = [] + for interval in intervals.intervals: + output_slice, mask_1d = map_func( + interval.start, interval.end, interval.left_ramp, interval.right_ramp, scale + ) + output_slices.append(output_slice) + masks_1d.append(mask_1d) + return output_slices, masks_1d + + return map_op + + +def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, torch.Tensor]: + start = begin * scale + stop = 1 + (end - 1) * scale + left_ramp = 0 if left_ramp == 0 else 1 + (left_ramp - 1) * scale + right_ramp = right_ramp * scale + + return slice(start, stop), compute_trapezoidal_mask_1d(stop - start, left_ramp, right_ramp, True) + + +def map_temporal_interval_to_latent( + begin: int, end: int, left_ramp: int, right_ramp: int | None = None, scale: int = 1 +) -> Tuple[slice, torch.Tensor]: + """ + Map temporal interval in video frame space to latent space. + Args: + begin: Start position in video frame space + end: End position in video frame space + left_ramp: Left ramp size in video frame space + right_ramp: Right ramp size in video frame space + scale: Scale factor for transformation + Returns: + Tuple of (output_slice, blend_mask) + """ + start = begin // scale + stop = (end - 1) // scale + 1 + + left_ramp_latents = 0 if left_ramp == 0 else 1 + (left_ramp - 1) // scale + right_ramp_latents = right_ramp // scale + + if right_ramp_latents != 0: + raise ValueError("For tiled encoding, temporal tiles are expected to have a right ramp equal to 0") + + mask_1d = compute_rectangular_mask_1d(stop - start, left_ramp_latents, right_ramp_latents) + + return slice(start, stop), mask_1d + + +def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, torch.Tensor]: + start = begin * scale + stop = end * scale + left_ramp = left_ramp * scale + right_ramp = right_ramp * scale + + return slice(start, stop), compute_trapezoidal_mask_1d(stop - start, left_ramp, right_ramp, False) + + +def map_spatial_interval_to_latent( + begin: int, + end: int, + left_ramp: int, + right_ramp: int, + scale: int, +) -> Tuple[slice, torch.Tensor]: + """Map spatial interval in pixel space to latent space. + Args: + begin: Start position in pixel space + end: End position in pixel space + left_ramp: Left ramp size in pixel space + right_ramp: Right ramp size in pixel space + scale: Scale factor for transformation + Returns: + Tuple of (output_slice, blend_mask) + """ + start = begin // scale + stop = end // scale + left_ramp = max(0, left_ramp // scale - 1) + + right_ramp = 0 if right_ramp == 0 else 1 + + mask_1d = compute_rectangular_mask_1d(stop - start, left_ramp, right_ramp) + return slice(start, stop), mask_1d diff --git a/ltx2/ltx_core/quantization/__init__.py b/ltx2/ltx_core/quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..89103278c6ebdaac8b5864b761106ab649962a8f --- /dev/null +++ b/ltx2/ltx_core/quantization/__init__.py @@ -0,0 +1,16 @@ +from ltx_core.quantization.fp8_cast import ( + TRANSFORMER_LINEAR_DOWNCAST_MAP, + UPCAST_DURING_INFERENCE, + UpcastWithStochasticRounding, +) +from ltx_core.quantization.fp8_scaled_mm import FP8_PREPARE_MODULE_OPS, FP8_TRANSPOSE_SD_OPS +from ltx_core.quantization.policy import QuantizationPolicy + +__all__ = [ + "FP8_PREPARE_MODULE_OPS", + "FP8_TRANSPOSE_SD_OPS", + "TRANSFORMER_LINEAR_DOWNCAST_MAP", + "UPCAST_DURING_INFERENCE", + "QuantizationPolicy", + "UpcastWithStochasticRounding", +] diff --git a/ltx2/ltx_core/quantization/fp8_cast.py b/ltx2/ltx_core/quantization/fp8_cast.py new file mode 100644 index 0000000000000000000000000000000000000000..c48a8a8d64df599bcc84de99aa06e7cc089168e7 --- /dev/null +++ b/ltx2/ltx_core/quantization/fp8_cast.py @@ -0,0 +1,167 @@ +import torch + +from ltx_core.loader.module_ops import ModuleOps +from ltx_core.loader.sd_ops import KeyValueOperationResult, SDOps +from ltx_core.model.transformer.model import LTXModel + +BLOCK_SIZE = 1024 + + +def _fused_add_round_launch(target_weight: torch.Tensor, original_weight: torch.Tensor, seed: int) -> torch.Tensor: + # Lazy import triton - only available on CUDA platforms + import triton # noqa: PLC0415 + + from ltx_core.loader.kernels import fused_add_round_kernel # noqa: PLC0415 + + if original_weight.dtype == torch.float8_e4m3fn: + exponent_bits, mantissa_bits, exponent_bias = 4, 3, 7 + elif original_weight.dtype == torch.float8_e5m2: + exponent_bits, mantissa_bits, exponent_bias = 5, 2, 15 # noqa: F841 + else: + raise ValueError("Unsupported dtype") + + if target_weight.dtype != torch.bfloat16: + raise ValueError("target_weight dtype must be bfloat16") + + # Calculate grid and block sizes + n_elements = original_weight.numel() + grid = (triton.cdiv(n_elements, BLOCK_SIZE),) + + # Launch kernel + fused_add_round_kernel[grid]( + original_weight, + target_weight, + seed, + n_elements, + exponent_bias, + mantissa_bits, + BLOCK_SIZE, + ) + return target_weight + + +def _naive_weight_or_bias_downcast(key: str, value: torch.Tensor) -> list[KeyValueOperationResult]: + """ + Downcast the weight or bias to the float8_e4m3fn dtype. + """ + return [KeyValueOperationResult(key, value.to(dtype=torch.float8_e4m3fn))] + + +def _upcast_and_round( + weight: torch.Tensor, dtype: torch.dtype, with_stochastic_rounding: bool = False, seed: int = 0 +) -> torch.Tensor: + """ + Upcast the weight to the given dtype and optionally apply stochastic rounding. + Input weight needs to have float8_e4m3fn or float8_e5m2 dtype. + """ + if not with_stochastic_rounding: + return weight.to(dtype) + return _fused_add_round_launch(torch.zeros_like(weight, dtype=dtype), weight, seed) + + +class Fp8CastLinear(torch.nn.Linear): + """nn.Linear storing weights in fp8, upcasting to input dtype during forward. + Used via __class__ reassignment (not subclassing) so existing weight tensors + are preserved in-place. Class-level forward is required for torch.compile + compatibility β€” instance-level closure monkey-patches cause graph breaks. + """ + + _with_stochastic_rounding: bool + _seed: int + + def forward(self, input: torch.Tensor) -> torch.Tensor: # noqa: A002, type: ignore[override] + w_up = _upcast_and_round(self.weight, input.dtype, self._with_stochastic_rounding, self._seed) + b_up = ( + _upcast_and_round(self.bias, input.dtype, self._with_stochastic_rounding, self._seed) + if self.bias is not None + else None + ) + return torch.nn.functional.linear(input, w_up, b_up) + + +def _replace_fwd_with_upcast(layer: torch.nn.Linear, with_stochastic_rounding: bool = False, seed: int = 0) -> None: + """ + Intended to be applied via __class__ reassignment to existing nn.Linear + instances so that their parameter and buffer tensors are preserved in-place, + avoiding re-instantiation. Forward remains defined at the class level, which + is required for torch.compile compatibility β€” instance-level closure + monkey-patches cause graph breaks. + """ + layer.__class__ = Fp8CastLinear + layer._with_stochastic_rounding = with_stochastic_rounding + layer._seed = seed + + +def _amend_forward_with_upcast( + model: torch.nn.Module, with_stochastic_rounding: bool = False, seed: int = 0 +) -> torch.nn.Module: + """ + Replace the forward method of the model's Linear layers to forward + with upcast and optional stochastic rounding. + """ + for m in model.modules(): + if isinstance(m, (torch.nn.Linear)): + _replace_fwd_with_upcast(m, with_stochastic_rounding, seed) + return model + + +TRANSFORMER_LINEAR_DOWNCAST_MAP = ( + SDOps("TRANSFORMER_LINEAR_DOWNCAST_MAP") + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix=".to_q.weight", operation=_naive_weight_or_bias_downcast + ) + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix=".to_q.bias", operation=_naive_weight_or_bias_downcast + ) + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix=".to_k.weight", operation=_naive_weight_or_bias_downcast + ) + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix=".to_k.bias", operation=_naive_weight_or_bias_downcast + ) + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix=".to_v.weight", operation=_naive_weight_or_bias_downcast + ) + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix=".to_v.bias", operation=_naive_weight_or_bias_downcast + ) + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix=".to_out.0.weight", operation=_naive_weight_or_bias_downcast + ) + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix=".to_out.0.bias", operation=_naive_weight_or_bias_downcast + ) + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix="ff.net.0.proj.weight", operation=_naive_weight_or_bias_downcast + ) + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix="ff.net.0.proj.bias", operation=_naive_weight_or_bias_downcast + ) + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix="ff.net.2.weight", operation=_naive_weight_or_bias_downcast + ) + .with_kv_operation( + key_prefix="transformer_blocks.", key_suffix="ff.net.2.bias", operation=_naive_weight_or_bias_downcast + ) +) + +UPCAST_DURING_INFERENCE = ModuleOps( + name="upcast_fp8_during_linear_forward", + matcher=lambda model: isinstance(model, LTXModel), + mutator=lambda model: _amend_forward_with_upcast(model, False), +) + + +class UpcastWithStochasticRounding(ModuleOps): + """ + ModuleOps for upcasting the model's float8_e4m3fn weights and biases to the bfloat16 dtype + and applying stochastic rounding during linear forward. + """ + + def __new__(cls, seed: int = 0): + return super().__new__( + cls, + name="upcast_fp8_during_linear_forward_with_stochastic_rounding", + matcher=lambda model: isinstance(model, LTXModel), + mutator=lambda model: _amend_forward_with_upcast(model, True, seed), + ) diff --git a/ltx2/ltx_core/quantization/fp8_scaled_mm.py b/ltx2/ltx_core/quantization/fp8_scaled_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..5f53ac46297c15e29ffb423bbcef7ac2868ead4a --- /dev/null +++ b/ltx2/ltx_core/quantization/fp8_scaled_mm.py @@ -0,0 +1,207 @@ +from typing import Callable + +import torch +from torch import nn + +from ltx_core.loader.module_ops import ModuleOps +from ltx_core.loader.sd_ops import KeyValueOperationResult, SDOps +from ltx_core.model.transformer import LTXModel + + +class FP8Linear(nn.Module): + """Linear layer with FP8 weight storage for scaled matrix multiplication.""" + + in_features: int + out_features: int + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device: torch.device | str | None = None, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + fp8_shape = (in_features, out_features) + self.weight = nn.Parameter(torch.empty(fp8_shape, dtype=torch.float8_e4m3fn, device=device)) + # Weight scale for FP8 dequantization (shape matches checkpoint format) + self.weight_scale = nn.Parameter(torch.empty((), dtype=torch.float32, device=device)) + # Input scale for static quantization (pre-quantized checkpoints) + self.input_scale = nn.Parameter(torch.empty((), dtype=torch.float32, device=device)) + + if bias: + self.bias = nn.Parameter(torch.empty(out_features, device=device)) + else: + self.register_parameter("bias", None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + origin_shape = x.shape + + # Static quantization: use pre-computed scale + qinput, cur_input_scale = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(x, self.input_scale) + + # Flatten to 2D for matmul + if qinput.dim() == 3: + qinput = qinput.reshape(-1, qinput.shape[-1]) + + # FP8 scaled matmul + output = torch.ops.trtllm.cublas_scaled_mm( + qinput, + self.weight, + scale_a=cur_input_scale, + scale_b=self.weight_scale, + bias=None, + out_dtype=x.dtype, + ) + + # Add bias + if self.bias is not None: + bias = self.bias + if bias.dtype != output.dtype: + bias = bias.to(output.dtype) + output = output + bias + + # Restore original shape + if output.dim() != len(origin_shape): + output_shape = list(origin_shape) + output_shape[-1] = output.shape[-1] + output = output.reshape(output_shape) + + return output + + +def quantize_weight_to_fp8_per_tensor(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantize a weight tensor to FP8 (float8_e4m3fn) using per-tensor scaling. + Args: + weight: The weight tensor to quantize (any dtype, will be cast to float32) + Returns: + Tuple of (quantized_weight, weight_scale): + - quantized_weight: FP8 tensor, transposed for cublas_scaled_mm + - weight_scale: Per-tensor scale factor (reciprocal of quantization scale) + """ + weight_fp32 = weight.to(torch.float32) + + fp8_min = torch.finfo(torch.float8_e4m3fn).min + fp8_max = torch.finfo(torch.float8_e4m3fn).max + + max_abs = torch.amax(torch.abs(weight_fp32)) + scale = fp8_max / max_abs + + @torch.compiler.disable + def _quantize( + weight_fp32: torch.Tensor, scale: torch.Tensor, fp8_min: torch.Tensor, fp8_max: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + quantized_weight = torch.clamp(weight_fp32 * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + quantized_weight = quantized_weight.t() + weight_scale = scale.reciprocal() + return quantized_weight, weight_scale + + quantized_weight, weight_scale = _quantize(weight_fp32, scale, fp8_min, fp8_max) + return quantized_weight, weight_scale + + +def _should_skip_layer(layer_name: str, excluded_layer_substrings: tuple[str, ...]) -> bool: + return any(substring in layer_name for substring in excluded_layer_substrings) + + +EXCLUDED_LAYER_SUBSTRINGS = ( + "patchify_proj", + "adaln_single", + "av_ca_video_scale_shift_adaln_single", + "av_ca_a2v_gate_adaln_single", + "caption_projection", + "proj_out", + "audio_patchify_proj", + "audio_adaln_single", + "av_ca_audio_scale_shift_adaln_single", + "av_ca_v2a_gate_adaln_single", + "audio_caption_projection", + "audio_proj_out", + "transformer_blocks.0.", + *[f"transformer_blocks.{i}." for i in range(43, 48)], +) + + +def _linear_to_fp8linear(layer: nn.Linear) -> FP8Linear: + """ + Create an FP8Linear layer from an nn.Linear layer. + Args: + layer: The nn.Linear layer to convert (typically on meta device) + Returns: + A new FP8Linear with the same configuration + """ + return FP8Linear( + in_features=layer.in_features, + out_features=layer.out_features, + bias=layer.bias is not None, + device=layer.weight.device, + ) + + +def _apply_fp8_prepare_to_model(model: nn.Module, excluded_layer_substrings: tuple[str, ...]) -> nn.Module: + """Replace nn.Linear layers with FP8Linear in the module tree.""" + replacements: list[tuple[nn.Module, str, nn.Linear]] = [] + + for name, module in model.named_modules(): + if not isinstance(module, nn.Linear) or isinstance(module, FP8Linear): + continue + + if _should_skip_layer(name, excluded_layer_substrings): + continue + + if "." in name: + parent_name, attr_name = name.rsplit(".", 1) + parent = model.get_submodule(parent_name) + else: + parent = model + attr_name = name + + replacements.append((parent, attr_name, module)) + + for parent, attr_name, linear in replacements: + setattr(parent, attr_name, _linear_to_fp8linear(linear)) + + return model + + +def _create_transpose_kv_operation( + excluded_layer_substrings: tuple[str, ...], +) -> Callable[[str, torch.Tensor], list[KeyValueOperationResult]]: + def transpose_if_matches(key: str, value: torch.Tensor) -> list[KeyValueOperationResult]: + # Only process .weight keys + if not key.endswith(".weight"): + return [KeyValueOperationResult(key, value)] + + # Only transpose 2D FP8 tensors (Linear weights) + if value.dim() != 2 or value.dtype != torch.float8_e4m3fn: + return [KeyValueOperationResult(key, value)] + + # Check if the layer is excluded + layer_name = key.rsplit(".weight", 1)[0] + if _should_skip_layer(layer_name, excluded_layer_substrings): + return [KeyValueOperationResult(key, value)] + + # Transpose to cuBLAS layout (in, out) + transposed_weight = value.t() + + return [KeyValueOperationResult(key, transposed_weight)] + + return transpose_if_matches + + +FP8_TRANSPOSE_SD_OPS = SDOps("fp8_transpose_weights").with_kv_operation( + _create_transpose_kv_operation(EXCLUDED_LAYER_SUBSTRINGS), + key_prefix="transformer_blocks.", + key_suffix=".weight", +) + + +FP8_PREPARE_MODULE_OPS = ModuleOps( + name="fp8_prepare_for_loading", + matcher=lambda model: isinstance(model, LTXModel), + mutator=lambda model: _apply_fp8_prepare_to_model(model, EXCLUDED_LAYER_SUBSTRINGS), +) diff --git a/ltx2/ltx_core/quantization/policy.py b/ltx2/ltx_core/quantization/policy.py new file mode 100644 index 0000000000000000000000000000000000000000..56dad51bcbd2740afbca7c655e8e542e5ad420c3 --- /dev/null +++ b/ltx2/ltx_core/quantization/policy.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass + +from ltx_core.loader.module_ops import ModuleOps +from ltx_core.loader.sd_ops import SDOps +from ltx_core.quantization.fp8_cast import TRANSFORMER_LINEAR_DOWNCAST_MAP, UPCAST_DURING_INFERENCE +from ltx_core.quantization.fp8_scaled_mm import FP8_PREPARE_MODULE_OPS, FP8_TRANSPOSE_SD_OPS + + +@dataclass(frozen=True) +class QuantizationPolicy: + """Configuration for model quantization during loading. + Attributes: + sd_ops: State dict operations for weight transformation. + module_ops: Post-load module transformations. + """ + + sd_ops: SDOps | None = None + module_ops: tuple[ModuleOps, ...] = () + + @classmethod + def fp8_cast(cls) -> "QuantizationPolicy": + """Create policy using FP8 casting with upcasting during inference.""" + return cls( + sd_ops=TRANSFORMER_LINEAR_DOWNCAST_MAP, + module_ops=(UPCAST_DURING_INFERENCE,), + ) + + @classmethod + def fp8_scaled_mm(cls) -> "QuantizationPolicy": + """Create policy using FP8 scaled matrix multiplication.""" + try: + import tensorrt_llm # noqa: F401, PLC0415 + except ImportError as e: + raise ImportError("tensorrt_llm is not installed, skipping FP8 scaled MM quantization") from e + + return cls( + sd_ops=FP8_TRANSPOSE_SD_OPS, + module_ops=(FP8_PREPARE_MODULE_OPS,), + ) diff --git a/ltx2/ltx_core/text_encoders/__init__.py b/ltx2/ltx_core/text_encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b65ae143a9e9442e15ee801cee818dcf9d683d4a --- /dev/null +++ b/ltx2/ltx_core/text_encoders/__init__.py @@ -0,0 +1 @@ +"""CLIP/text encoder model components.""" diff --git a/ltx2/ltx_core/text_encoders/gemma/__init__.py b/ltx2/ltx_core/text_encoders/gemma/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..009a38d4b6691f769117a19320039df1ca8c410e --- /dev/null +++ b/ltx2/ltx_core/text_encoders/gemma/__init__.py @@ -0,0 +1,33 @@ +"""Gemma text encoder components.""" + +from ltx_core.text_encoders.gemma.embeddings_processor import ( + EmbeddingsProcessor, + EmbeddingsProcessorOutput, + convert_to_additive_mask, +) +from ltx_core.text_encoders.gemma.encoders.base_encoder import ( + GemmaTextEncoder, + module_ops_from_gemma_root, +) +from ltx_core.text_encoders.gemma.encoders.encoder_configurator import ( + EMBEDDINGS_PROCESSOR_KEY_OPS, + GEMMA_LLM_KEY_OPS, + GEMMA_MODEL_OPS, + VIDEO_ONLY_EMBEDDINGS_PROCESSOR_KEY_OPS, + EmbeddingsProcessorConfigurator, + GemmaTextEncoderConfigurator, +) + +__all__ = [ + "EMBEDDINGS_PROCESSOR_KEY_OPS", + "GEMMA_LLM_KEY_OPS", + "GEMMA_MODEL_OPS", + "VIDEO_ONLY_EMBEDDINGS_PROCESSOR_KEY_OPS", + "EmbeddingsProcessor", + "EmbeddingsProcessorConfigurator", + "EmbeddingsProcessorOutput", + "GemmaTextEncoder", + "GemmaTextEncoderConfigurator", + "convert_to_additive_mask", + "module_ops_from_gemma_root", +] diff --git a/ltx2/ltx_core/text_encoders/gemma/config.py b/ltx2/ltx_core/text_encoders/gemma/config.py new file mode 100644 index 0000000000000000000000000000000000000000..8b23e91b0fdd491404ffd33a756b3f11cb2821b3 --- /dev/null +++ b/ltx2/ltx_core/text_encoders/gemma/config.py @@ -0,0 +1,75 @@ +from dataclasses import asdict, dataclass, field + + +@dataclass +class Gemma3RopeScaling: + factor: float = 8.0 + rope_type: str = "linear" + + +@dataclass +class Gemma3TextConfig: + attention_bias: bool = False + attention_dropout: float = 0.0 + attn_logit_softcapping: float | None = None + cache_implementation: str = "hybrid" + final_logit_softcapping: float | None = None + head_dim: int = 256 + hidden_activation: str = "gelu_pytorch_tanh" + hidden_size: int = 3840 + initializer_range: float = 0.02 + intermediate_size: int = 15360 + max_position_embeddings: int = 131072 + model_type: str = "gemma3_text" + num_attention_heads: int = 16 + num_hidden_layers: int = 48 + num_key_value_heads: int = 8 + query_pre_attn_scalar: int = 256 + rms_norm_eps: float = 1e-06 + rope_local_base_freq: int = 10000 + rope_scaling: Gemma3RopeScaling = field(default_factory=Gemma3RopeScaling) + rope_theta: int = 1000000 + sliding_window: int = 1024 + sliding_window_pattern: int = 6 + torch_dtype: str = "float32" + use_cache: bool = True + vocab_size: int = 262208 + + +@dataclass +class Gemma3VisionConfig: + attention_dropout: float = 0.0 + hidden_act: str = "gelu_pytorch_tanh" + hidden_size: int = 1152 + image_size: int = 896 + intermediate_size: int = 4304 + layer_norm_eps: float = 1e-06 + model_type: str = "siglip_vision_model" + num_attention_heads: int = 16 + num_channels: int = 3 + num_hidden_layers: int = 27 + patch_size: int = 14 + torch_dtype: str = "float32" + vision_use_head: bool = False + + +@dataclass +class Gemma3ConfigData: + architectures: list[str] = field(default_factory=lambda: ["Gemma3ForConditionalGeneration"]) + boi_token_index: int = 255999 + eoi_token_index: int = 256000 + eos_token_id: list[int] = field(default_factory=lambda: [1, 106]) + image_token_index: int = 262144 + initializer_range: float = 0.02 + mm_tokens_per_image: int = 256 + model_type: str = "gemma3" + text_config: Gemma3TextConfig = field(default_factory=Gemma3TextConfig) + torch_dtype: str = "bfloat16" + transformers_version: str = "4.51.0" + vision_config: Gemma3VisionConfig = field(default_factory=Gemma3VisionConfig) + + def to_dict(self) -> dict: + return asdict(self) + + +GEMMA3_CONFIG_FOR_LTX = Gemma3ConfigData() diff --git a/ltx2/ltx_core/text_encoders/gemma/embeddings_connector.py b/ltx2/ltx_core/text_encoders/gemma/embeddings_connector.py new file mode 100644 index 0000000000000000000000000000000000000000..2e614d56554f367305e168d276a900c752b8b3b6 --- /dev/null +++ b/ltx2/ltx_core/text_encoders/gemma/embeddings_connector.py @@ -0,0 +1,261 @@ +import torch + +from ltx_core.model.model_protocol import ModelConfigurator +from ltx_core.model.transformer.attention import Attention +from ltx_core.model.transformer.feed_forward import FeedForward +from ltx_core.model.transformer.rope import ( + LTXRopeType, + generate_freq_grid_np, + generate_freq_grid_pytorch, + precompute_freqs_cis, +) +from ltx_core.utils import rms_norm + + +class _BasicTransformerBlock1D(torch.nn.Module): + def __init__( + self, + dim: int, + heads: int, + dim_head: int, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + apply_gated_attention: bool = False, + ): + super().__init__() + + self.attn1 = Attention( + query_dim=dim, + heads=heads, + dim_head=dim_head, + rope_type=rope_type, + apply_gated_attention=apply_gated_attention, + ) + + self.ff = FeedForward( + dim, + dim_out=dim, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + pe: torch.Tensor | None = None, + ) -> torch.Tensor: + # Notice that normalization is always applied before the real computation in the following blocks. + + # 1. Normalization Before Self-Attention + norm_hidden_states = rms_norm(hidden_states) + + norm_hidden_states = norm_hidden_states.squeeze(1) + + # 2. Self-Attention + attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe) + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 3. Normalization before Feed-Forward + norm_hidden_states = rms_norm(hidden_states) + + # 4. Feed-forward + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class Embeddings1DConnector(torch.nn.Module): + """ + Embeddings1DConnector applies a 1D transformer-based processing to sequential embeddings (e.g., for video, audio, or + other modalities). It supports rotary positional encoding (rope), optional causal temporal positioning, and can + substitute padded positions with learnable registers. The module is highly configurable for head size, number of + layers, and register usage. + Args: + attention_head_dim (int): Dimension of each attention head (default=128). + num_attention_heads (int): Number of attention heads (default=30). + num_layers (int): Number of transformer layers (default=2). + positional_embedding_theta (float): Scaling factor for position embedding (default=10000.0). + positional_embedding_max_pos (list[int] | None): Max positions for positional embeddings (default=[1]). + causal_temporal_positioning (bool): If True, uses causal attention (default=False). + num_learnable_registers (int | None): Number of learnable registers to replace padded tokens. If None, disables + register replacement. (default=128) + rope_type (LTXRopeType): The RoPE variant to use (default=DEFAULT_ROPE_TYPE). + double_precision_rope (bool): Use double precision rope calculation (default=False). + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + attention_head_dim: int = 128, + num_attention_heads: int = 30, + num_layers: int = 2, + positional_embedding_theta: float = 10000.0, + positional_embedding_max_pos: list[int] | None = None, + causal_temporal_positioning: bool = False, + num_learnable_registers: int | None = 128, + rope_type: LTXRopeType = LTXRopeType.INTERLEAVED, + double_precision_rope: bool = False, + apply_gated_attention: bool = False, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self.causal_temporal_positioning = causal_temporal_positioning + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = ( + positional_embedding_max_pos if positional_embedding_max_pos is not None else [1] + ) + self.rope_type = rope_type + self.double_precision_rope = double_precision_rope + self.transformer_1d_blocks = torch.nn.ModuleList( + [ + _BasicTransformerBlock1D( + dim=self.inner_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + rope_type=rope_type, + apply_gated_attention=apply_gated_attention, + ) + for _ in range(num_layers) + ] + ) + + self.num_learnable_registers = num_learnable_registers + if self.num_learnable_registers: + self.learnable_registers = torch.nn.Parameter( + torch.rand(self.num_learnable_registers, self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0 + ) + + def _replace_padded_with_learnable_registers( + self, hidden_states: torch.Tensor, attention_mask: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.shape[1] % self.num_learnable_registers == 0, ( + f"Hidden states sequence length {hidden_states.shape[1]} must be divisible by num_learnable_registers " + f"{self.num_learnable_registers}." + ) + + num_registers_duplications = hidden_states.shape[1] // self.num_learnable_registers + learnable_registers = torch.tile(self.learnable_registers, (num_registers_duplications, 1)) + attention_mask_binary = (attention_mask.squeeze(1).squeeze(1).unsqueeze(-1) >= -9000.0).int() + + non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :] + non_zero_nums = non_zero_hidden_states.shape[1] + pad_length = hidden_states.shape[1] - non_zero_nums + adjusted_hidden_states = torch.nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0) + flipped_mask = torch.flip(attention_mask_binary, dims=[1]) + hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers + + attention_mask = torch.full_like( + attention_mask, + 0.0, + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + + return hidden_states, attention_mask + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass of Embeddings1DConnector. + Args: + hidden_states (torch.Tensor): Input tensor of embeddings (shape [batch, seq_len, feature_dim]). + attention_mask (torch.Tensor|None): Optional mask for valid tokens (shape compatible with hidden_states). + Returns: + tuple[torch.Tensor, torch.Tensor]: Processed features and the corresponding (possibly modified) mask. + """ + if self.num_learnable_registers: + hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask) + + indices_grid = torch.arange(hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device) + indices_grid = indices_grid[None, None, :] + freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch + freqs_cis = precompute_freqs_cis( + indices_grid=indices_grid, + dim=self.inner_dim, + out_dtype=hidden_states.dtype, + theta=self.positional_embedding_theta, + max_pos=self.positional_embedding_max_pos, + num_attention_heads=self.num_attention_heads, + rope_type=self.rope_type, + freq_grid_generator=freq_grid_generator, + ) + + for block in self.transformer_1d_blocks: + hidden_states = block(hidden_states, attention_mask=attention_mask, pe=freqs_cis) + + hidden_states = rms_norm(hidden_states) + + return hidden_states, attention_mask + + +class Embeddings1DConnectorConfigurator(ModelConfigurator[Embeddings1DConnector]): + """Configurator for video embeddings connector.""" + + @classmethod + def from_config(cls: type[Embeddings1DConnector], config: dict) -> Embeddings1DConnector: + transformer_config = config.get("transformer", {}) + rope_type = LTXRopeType(transformer_config.get("rope_type", "interleaved")) + double_precision_rope = transformer_config.get("frequencies_precision", False) == "float64" + pe_max_pos = transformer_config.get("connector_positional_embedding_max_pos", [1]) + + # Video connector dimensions + num_attention_heads = transformer_config.get("connector_num_attention_heads", 30) + attention_head_dim = transformer_config.get("connector_attention_head_dim", 128) + num_layers = transformer_config.get("connector_num_layers", 2) + + connector = Embeddings1DConnector( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_layers=num_layers, + positional_embedding_max_pos=pe_max_pos, + rope_type=rope_type, + double_precision_rope=double_precision_rope, + apply_gated_attention=transformer_config.get("connector_apply_gated_attention", False), + ) + return connector + + +class AudioEmbeddings1DConnectorConfigurator(ModelConfigurator[Embeddings1DConnector]): + """Configurator for audio embeddings connector with separate dimension settings.""" + + @classmethod + def from_config(cls: type[Embeddings1DConnector], config: dict) -> Embeddings1DConnector: + transformer_config = config.get("transformer", {}) + rope_type = LTXRopeType(transformer_config.get("rope_type", "interleaved")) + double_precision_rope = transformer_config.get("frequencies_precision", False) == "float64" + pe_max_pos = transformer_config.get("connector_positional_embedding_max_pos", [1]) + + # Audio connector dimensions - fall back to video connector config for backwards compatibility + num_attention_heads = transformer_config.get( + "audio_connector_num_attention_heads", + transformer_config.get("connector_num_attention_heads", 30), + ) + attention_head_dim = transformer_config.get( + "audio_connector_attention_head_dim", + transformer_config.get("connector_attention_head_dim", 128), + ) + num_layers = transformer_config.get( + "audio_connector_num_layers", + transformer_config.get("connector_num_layers", 2), + ) + + connector = Embeddings1DConnector( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_layers=num_layers, + positional_embedding_max_pos=pe_max_pos, + rope_type=rope_type, + double_precision_rope=double_precision_rope, + apply_gated_attention=transformer_config.get("connector_apply_gated_attention", False), + ) + return connector diff --git a/ltx2/ltx_core/text_encoders/gemma/embeddings_processor.py b/ltx2/ltx_core/text_encoders/gemma/embeddings_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..d04f2d3c59cdcf1ee54ef5e2aa25c2dac730aad0 --- /dev/null +++ b/ltx2/ltx_core/text_encoders/gemma/embeddings_processor.py @@ -0,0 +1,89 @@ +from typing import NamedTuple + +import torch +from torch import nn + +from ltx_core.text_encoders.gemma.embeddings_connector import Embeddings1DConnector + + +class EmbeddingsProcessorOutput(NamedTuple): + video_encoding: torch.Tensor + audio_encoding: torch.Tensor | None + attention_mask: torch.Tensor + + +def convert_to_additive_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + """Convert binary attention mask to additive form for transformer masking.""" + return (attention_mask.to(torch.int64) - 1).to(dtype).reshape( + (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + ) * torch.finfo(dtype).max + + +def _to_binary_mask(encoded: torch.Tensor, encoded_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Convert connector output mask to binary mask and apply to encoded tensor.""" + binary_mask = (encoded_mask < 0.000001).to(torch.int64) + binary_mask = binary_mask.reshape([encoded.shape[0], encoded.shape[1], 1]) + encoded = encoded * binary_mask + return encoded, binary_mask + + +class EmbeddingsProcessor(nn.Module): + """Wraps feature extractor + video connector + optional audio connector. + Can operate in two modes: + 1. create_embeddings(): Takes pre-computed features + additive mask (backward compat, used by trainer) + 2. process_hidden_states(): Takes raw Gemma hidden states, runs feature extraction + connectors + """ + + def __init__( + self, + *, + feature_extractor: nn.Module | None = None, + video_connector: Embeddings1DConnector, + audio_connector: Embeddings1DConnector | None = None, + ): + super().__init__() + self.feature_extractor = feature_extractor + self.video_connector = video_connector + self.audio_connector = audio_connector + + def create_embeddings( + self, + video_features: torch.Tensor, + audio_features: torch.Tensor | None, + additive_attention_mask: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]: + if self.audio_connector is not None and audio_features is None: + raise ValueError("Audio connector is configured but no audio features were provided.") + if self.audio_connector is None and audio_features is not None: + raise ValueError("Audio features were provided but no audio connector is configured.") + + video_encoded, video_mask = self.video_connector(video_features, additive_attention_mask) + video_encoded, binary_mask = _to_binary_mask(video_encoded, video_mask) + + audio_encoded = None + if self.audio_connector is not None: + audio_encoded, _ = self.audio_connector(audio_features, additive_attention_mask) + + return video_encoded, audio_encoded, binary_mask.squeeze(-1) + + def process_hidden_states( + self, + hidden_states: tuple[torch.Tensor, ...], + attention_mask: torch.Tensor, + padding_side: str = "left", + ) -> EmbeddingsProcessorOutput: + """Full pipeline: feature extraction -> connectors -> final embeddings. + Args: + hidden_states: Raw Gemma hidden states (tuple of tensors per layer). + attention_mask: Binary attention mask [B, seq_len]. + padding_side: Padding side used during tokenization. + Returns: + EmbeddingsProcessorOutput with video_encoding, audio_encoding, and attention_mask. + """ + if self.feature_extractor is None: + raise ValueError("feature_extractor is required for process_hidden_states()") + + video_feats, audio_feats = self.feature_extractor(hidden_states, attention_mask, padding_side) + additive_mask = convert_to_additive_mask(attention_mask, video_feats.dtype) + video_enc, audio_enc, binary_mask = self.create_embeddings(video_feats, audio_feats, additive_mask) + return EmbeddingsProcessorOutput(video_enc, audio_enc, binary_mask) diff --git a/ltx2/ltx_core/text_encoders/gemma/encoders/base_encoder.py b/ltx2/ltx_core/text_encoders/gemma/encoders/base_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b0d9d24163c7a485a7e3d769d65d10f6270015a9 --- /dev/null +++ b/ltx2/ltx_core/text_encoders/gemma/encoders/base_encoder.py @@ -0,0 +1,202 @@ +import functools +from pathlib import Path + +import torch +from transformers import AutoImageProcessor, Gemma3ForConditionalGeneration, Gemma3Processor + +from ltx_core.loader.module_ops import ModuleOps +from ltx_core.text_encoders.gemma.tokenizer import LTXVGemmaTokenizer +from ltx_core.utils import find_matching_file + + +class GemmaTextEncoder(torch.nn.Module): + """Pure Gemma text encoder β€” runs the LLM and returns raw hidden states. + Prompt enhancement (generate) is also supported since the full + Gemma3ForConditionalGeneration model (including lm_head) is loaded. + """ + + def __init__( + self, + model: Gemma3ForConditionalGeneration | None = None, + tokenizer: LTXVGemmaTokenizer | None = None, + processor: Gemma3Processor | None = None, + dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.model = model + self.tokenizer = tokenizer + self.processor = processor + self._dtype = dtype + + def encode( + self, + text: str, + padding_side: str = "left", # noqa: ARG002 + ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]: + """Run Gemma LLM and return raw hidden states + attention mask. + Calls the inner model (self.model.model) to skip lm_head logits computation (~500 MiB saving). + Returns: + (hidden_states, attention_mask) where hidden_states is a tuple of per-layer tensors. + """ + token_pairs = self.tokenizer.tokenize_with_weights(text)["gemma"] + input_ids = torch.tensor([[t[0] for t in token_pairs]], device=self.model.device) + attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=self.model.device) + outputs = self.model.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True) + hidden_states = outputs.hidden_states + del outputs + return hidden_states, attention_mask + + # --- Prompt enhancement methods --- + + def _enhance( + self, + messages: list[dict[str, str]], + image: torch.Tensor | None = None, + max_new_tokens: int = 512, + seed: int = 10, + ) -> str: + text = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + model_inputs = self.processor( + text=text, + images=image, + return_tensors="pt", + ).to(self.model.device) + pad_token_id = self.processor.tokenizer.pad_token_id if self.processor.tokenizer.pad_token_id is not None else 0 + model_inputs = _pad_inputs_for_attention_alignment(model_inputs, pad_token_id=pad_token_id) + + with torch.inference_mode(), torch.random.fork_rng(devices=[self.model.device]): + torch.manual_seed(seed) + outputs = self.model.generate( + **model_inputs, + max_new_tokens=max_new_tokens, + do_sample=True, + temperature=0.7, + ) + generated_ids = outputs[0][len(model_inputs.input_ids[0]) :] + enhanced_prompt = self.processor.tokenizer.decode(generated_ids, skip_special_tokens=True) + + return enhanced_prompt + + def enhance_t2v( + self, + prompt: str, + max_new_tokens: int = 512, + system_prompt: str | None = None, + seed: int = 10, + ) -> str: + """Enhance a text prompt for T2V generation.""" + system_prompt = system_prompt or self.default_gemma_t2v_system_prompt + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"user prompt: {prompt}"}, + ] + + return self._enhance(messages, max_new_tokens=max_new_tokens, seed=seed) + + def enhance_i2v( + self, + prompt: str, + image: torch.Tensor, + max_new_tokens: int = 512, + system_prompt: str | None = None, + seed: int = 10, + ) -> str: + """Enhance a text prompt for I2V generation using a reference image.""" + system_prompt = system_prompt or self.default_gemma_i2v_system_prompt + messages = [ + {"role": "system", "content": system_prompt}, + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": f"User Raw Input Prompt: {prompt}."}, + ], + }, + ] + return self._enhance(messages, image=image, max_new_tokens=max_new_tokens, seed=seed) + + @functools.cached_property + def default_gemma_i2v_system_prompt(self) -> str: + return _load_system_prompt("gemma_i2v_system_prompt.txt") + + @functools.cached_property + def default_gemma_t2v_system_prompt(self) -> str: + return _load_system_prompt("gemma_t2v_system_prompt.txt") + + +# --- Standalone utility functions --- + + +@functools.lru_cache(maxsize=2) +def _load_system_prompt(prompt_name: str) -> str: + with open(Path(__file__).parent / "prompts" / f"{prompt_name}", "r") as f: + return f.read() + + +def _cat_with_padding( + tensor: torch.Tensor, + padding_length: int, + value: int | float, +) -> torch.Tensor: + """Concatenate a tensor with a padding tensor of the given value.""" + return torch.cat( + [ + tensor, + torch.full( + (1, padding_length), + value, + dtype=tensor.dtype, + device=tensor.device, + ), + ], + dim=1, + ) + + +def _pad_inputs_for_attention_alignment( + model_inputs: dict[str, torch.Tensor], + pad_token_id: int = 0, + alignment: int = 8, +) -> dict[str, torch.Tensor]: + """Pad sequence length to multiple of alignment for Flash Attention compatibility.""" + seq_len = model_inputs.input_ids.shape[1] + padded_len = ((seq_len + alignment - 1) // alignment) * alignment + padding_length = padded_len - seq_len + + if padding_length > 0: + model_inputs["input_ids"] = _cat_with_padding(model_inputs.input_ids, padding_length, pad_token_id) + model_inputs["attention_mask"] = _cat_with_padding(model_inputs.attention_mask, padding_length, 0) + if "token_type_ids" in model_inputs and model_inputs["token_type_ids"] is not None: + model_inputs["token_type_ids"] = _cat_with_padding(model_inputs["token_type_ids"], padding_length, 0) + + return model_inputs + + +def module_ops_from_gemma_root(gemma_root: str) -> tuple[ModuleOps, ...]: + tokenizer_root = str(find_matching_file(gemma_root, "tokenizer.model").parent) + processor_root = str(find_matching_file(gemma_root, "preprocessor_config.json").parent) + + def load_tokenizer(module: GemmaTextEncoder) -> GemmaTextEncoder: + module.tokenizer = LTXVGemmaTokenizer(tokenizer_root, 1024) + return module + + def load_processor(module: GemmaTextEncoder) -> GemmaTextEncoder: + image_processor = AutoImageProcessor.from_pretrained(processor_root, local_files_only=True) + if not module.tokenizer: + raise ValueError("Tokenizer model operation must be performed before processor model operation") + module.processor = Gemma3Processor(image_processor=image_processor, tokenizer=module.tokenizer.tokenizer) + return module + + tokenizer_load_ops = ModuleOps( + "TokenizerLoad", + matcher=lambda module: isinstance(module, GemmaTextEncoder) and module.tokenizer is None, + mutator=load_tokenizer, + ) + processor_load_ops = ModuleOps( + "ProcessorLoad", + matcher=lambda module: isinstance(module, GemmaTextEncoder) and module.processor is None, + mutator=load_processor, + ) + return (tokenizer_load_ops, processor_load_ops) diff --git a/ltx2/ltx_core/text_encoders/gemma/encoders/encoder_configurator.py b/ltx2/ltx_core/text_encoders/gemma/encoders/encoder_configurator.py new file mode 100644 index 0000000000000000000000000000000000000000..f7654dfb5fd203670f345f8ed71c19dbbc100b5f --- /dev/null +++ b/ltx2/ltx_core/text_encoders/gemma/encoders/encoder_configurator.py @@ -0,0 +1,181 @@ +import torch +from transformers import Gemma3Config +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.models.gemma3 import Gemma3ForConditionalGeneration + +from ltx_core.loader import KeyValueOperationResult +from ltx_core.loader.module_ops import ModuleOps +from ltx_core.loader.sd_ops import SDOps +from ltx_core.model.model_protocol import ModelConfigurator +from ltx_core.text_encoders.gemma.config import GEMMA3_CONFIG_FOR_LTX +from ltx_core.text_encoders.gemma.embeddings_connector import ( + AudioEmbeddings1DConnectorConfigurator, + Embeddings1DConnectorConfigurator, +) +from ltx_core.text_encoders.gemma.embeddings_processor import EmbeddingsProcessor +from ltx_core.text_encoders.gemma.encoders.base_encoder import GemmaTextEncoder +from ltx_core.text_encoders.gemma.feature_extractor import ( + FeatureExtractorV1, + FeatureExtractorV2, +) + + +class GemmaTextEncoderConfigurator(ModelConfigurator[GemmaTextEncoder]): + @classmethod + def from_config(cls, config: dict) -> GemmaTextEncoder: # noqa: ARG003 + gemma_config = Gemma3Config.from_dict(GEMMA3_CONFIG_FOR_LTX.to_dict()) + with torch.device("meta"): + model = Gemma3ForConditionalGeneration(gemma_config) + + return GemmaTextEncoder(model=model) + + +class EmbeddingsProcessorConfigurator(ModelConfigurator[EmbeddingsProcessor]): + @classmethod + def from_config(cls, config: dict) -> EmbeddingsProcessor: + transformer_config = config.get("transformer", {}) + + # Create video embeddings connector (always needed) + video_connector = Embeddings1DConnectorConfigurator.from_config(config) + + # Create audio embeddings connector + audio_connector = AudioEmbeddings1DConnectorConfigurator.from_config(config) + + # Create feature extractor + feature_extractor = _create_feature_extractor(transformer_config) + + return EmbeddingsProcessor( + video_connector=video_connector, + audio_connector=audio_connector, + feature_extractor=feature_extractor, + ) + + +_V2_EXPECTED_CONFIG = { + "caption_proj_before_connector": True, + "caption_projection_first_linear": False, + "caption_proj_input_norm": False, + "caption_projection_second_linear": False, +} + + +def _create_feature_extractor(transformer_config: dict) -> torch.nn.Module: + """Select and create the appropriate feature extractor based on config. + Detection logic: + - V1: V2 config keys absent β†’ projection lives in transformer + - V2: V2 config keys present with exact expected values β†’ per-token RMS norm with dual aggregate embeds + - Anything else: NotImplementedError (config drift) + """ + gemma_text_config = GEMMA3_CONFIG_FOR_LTX.text_config + embedding_dim = gemma_text_config.hidden_size + num_layers = gemma_text_config.num_hidden_layers + 1 # +1 for the embedding layer + flat_dim = embedding_dim * num_layers + + overlapping_keys = transformer_config.keys() & _V2_EXPECTED_CONFIG.keys() + if not overlapping_keys: + aggregate_embed = torch.nn.Linear(flat_dim, embedding_dim, bias=False) + return FeatureExtractorV1(aggregate_embed=aggregate_embed, is_av=True) + + missing_keys = _V2_EXPECTED_CONFIG.keys() - overlapping_keys + if missing_keys: + raise NotImplementedError("Partial V2 config β€” missing keys: " + ", ".join(sorted(missing_keys))) + + unexpected_value_keys = {k for k in overlapping_keys if transformer_config[k] != _V2_EXPECTED_CONFIG[k]} + if unexpected_value_keys: + raise NotImplementedError( + "Unknown config: " + + ", ".join( + f"{k}={transformer_config[k]!r} (expected {_V2_EXPECTED_CONFIG[k]!r})" for k in unexpected_value_keys + ) + ) + + video_inner_dim = transformer_config["num_attention_heads"] * transformer_config["attention_head_dim"] + audio_inner_dim = transformer_config["audio_num_attention_heads"] * transformer_config["audio_attention_head_dim"] + return FeatureExtractorV2( + video_aggregate_embed=torch.nn.Linear(flat_dim, video_inner_dim, bias=True), + embedding_dim=embedding_dim, + audio_aggregate_embed=torch.nn.Linear(flat_dim, audio_inner_dim, bias=True), + ) + + +# --- Split SDOps: Gemma LLM keys vs Embeddings Processor keys --- + +GEMMA_LLM_KEY_OPS = ( + SDOps("GEMMA_LLM_KEY_OPS") + # 1. Map language model layers (note the double .model prefix) + .with_matching(prefix="language_model.model.") + .with_replacement("language_model.model.", "model.model.language_model.") + # 2. Map the Vision Tower + .with_matching(prefix="vision_tower.") + .with_replacement("vision_tower.", "model.model.vision_tower.") + # 3. Map the Multi-Modal Projector + .with_matching(prefix="multi_modal_projector.") + .with_replacement("multi_modal_projector.", "model.model.multi_modal_projector.") + # 4. Duplicate embed_tokens to lm_head (needed for prompt enhancement via generate()) + .with_kv_operation( + operation=lambda key, value: [ + KeyValueOperationResult(key, value), + KeyValueOperationResult("model.lm_head.weight", value), + ], + key_prefix="model.model.language_model.embed_tokens.weight", + ) +) + +EMBEDDINGS_PROCESSOR_KEY_OPS = ( + SDOps("EMBEDDINGS_PROCESSOR_KEY_OPS") + # 1. Map the feature extractor (V1: aggregate_embed inside feature_extractor) + .with_matching(prefix="text_embedding_projection.aggregate_embed.") + .with_replacement("text_embedding_projection.aggregate_embed.", "feature_extractor.aggregate_embed.") + # V2 dual aggregate embeds + .with_matching(prefix="text_embedding_projection.video_aggregate_embed.") + .with_replacement("text_embedding_projection.video_aggregate_embed.", "feature_extractor.video_aggregate_embed.") + .with_matching(prefix="text_embedding_projection.audio_aggregate_embed.") + .with_replacement("text_embedding_projection.audio_aggregate_embed.", "feature_extractor.audio_aggregate_embed.") + # 2. Map the connectors + .with_matching(prefix="model.diffusion_model.video_embeddings_connector.") + .with_replacement("model.diffusion_model.video_embeddings_connector.", "video_connector.") + .with_matching(prefix="model.diffusion_model.audio_embeddings_connector.") + .with_replacement("model.diffusion_model.audio_embeddings_connector.", "audio_connector.") +) + +VIDEO_ONLY_EMBEDDINGS_PROCESSOR_KEY_OPS = ( + SDOps("VIDEO_ONLY_EMBEDDINGS_PROCESSOR_KEY_OPS") + # 1. Map the feature extractor (V1: aggregate_embed inside feature_extractor) + .with_matching(prefix="text_embedding_projection.aggregate_embed.") + .with_replacement("text_embedding_projection.aggregate_embed.", "feature_extractor.aggregate_embed.") + # V2 video aggregate embed + .with_matching(prefix="text_embedding_projection.video_aggregate_embed.") + .with_replacement("text_embedding_projection.video_aggregate_embed.", "feature_extractor.video_aggregate_embed.") + # 2. Map the connectors + .with_matching(prefix="model.diffusion_model.embeddings_connector.") + .with_replacement("model.diffusion_model.embeddings_connector.", "embeddings_processor.video_connector.") +) + + +def create_and_populate(module: GemmaTextEncoder) -> GemmaTextEncoder: + model = module.model + v_model = model.model.vision_tower.vision_model + l_model = model.model.language_model + + config = model.config.text_config + dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + base = config.rope_local_base_freq + local_rope_freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(dtype=torch.float) / dim)) + inv_freqs, _ = ROPE_INIT_FUNCTIONS[config.rope_scaling["rope_type"]](config) + + positions_length = len(v_model.embeddings.position_ids[0]) + position_ids = torch.arange(positions_length, dtype=torch.long, device="cpu").unsqueeze(0) + v_model.embeddings.register_buffer("position_ids", position_ids) + embed_scale = torch.tensor(model.config.text_config.hidden_size**0.5, device="cpu") + l_model.embed_tokens.register_buffer("embed_scale", embed_scale) + l_model.rotary_emb_local.register_buffer("inv_freq", local_rope_freqs) + l_model.rotary_emb.register_buffer("inv_freq", inv_freqs) + + return module + + +GEMMA_MODEL_OPS = ModuleOps( + name="GemmaModel", + matcher=lambda module: hasattr(module, "model") and isinstance(module.model, Gemma3ForConditionalGeneration), + mutator=create_and_populate, +) diff --git a/ltx2/ltx_core/text_encoders/gemma/encoders/prompts/gemma_i2v_system_prompt.txt b/ltx2/ltx_core/text_encoders/gemma/encoders/prompts/gemma_i2v_system_prompt.txt new file mode 100644 index 0000000000000000000000000000000000000000..0d677244f76bd0666c4b3a8c301779207e39e625 --- /dev/null +++ b/ltx2/ltx_core/text_encoders/gemma/encoders/prompts/gemma_i2v_system_prompt.txt @@ -0,0 +1,30 @@ +You are a Creative Assistant writing concise, action-focused image-to-video prompts. Given an image (first frame) and user Raw Input Prompt, generate a prompt to guide video generation from that image. + +#### Guidelines: +- Analyze the Image: Identify Subject, Setting, Elements, Style and Mood. +- Follow user Raw Input Prompt: Include all requested motion, actions, camera movements, audio, and details. If in conflict with the image, prioritize user request while maintaining visual consistency (describe transition from image to user's scene). +- Describe only changes from the image: Don't reiterate established visual details. Inaccurate descriptions may cause scene cuts. +- Active language: Use present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements. +- Chronological flow: Use temporal connectors ("as," "then," "while"). +- Audio layer: Describe complete soundscape throughout the prompt alongside actionsβ€”NOT at the end. Align audio intensity with action tempo. Include natural background audio, ambient sounds, effects, speech or music (when requested). Be specific (e.g., "soft footsteps on tile") not vague (e.g., "ambient sound"). +- Speech (only when requested): Provide exact words in quotes with character's visual/voice characteristics (e.g., "The tall man speaks in a low, gravelly voice"), language if not English and accent if relevant. If general conversation mentioned without text, generate contextual quoted dialogue. (i.e., "The man is talking" input -> the output should include exact spoken words, like: "The man is talking in an excited voice saying: 'You won't believe what I just saw!' His hands gesture expressively as he speaks, eyebrows raised with enthusiasm. The ambient sound of a quiet room underscores his animated speech.") +- Style: Include visual style at beginning: "Style: