| import json |
| import os |
|
|
| from huggingface_hub import snapshot_download |
| from safetensors import safe_open |
| from safetensors.torch import save_file |
|
|
| from sglang.multimodal_gen.runtime.utils.model_overlay import ( |
| _copytree_link_or_copy, |
| _ensure_dir, |
| _link_or_copy_file, |
| ) |
|
|
# Hub repo supplying the auxiliary components copied verbatim (audio VAE,
# scheduler, text encoder, tokenizer) plus the video VAE encoder weights.
AUXILIARY_MODEL_ID = "Lightricks/LTX-2"
# Diffusers-layout repo used as a donor of config.json files and of the
# LTX-2.3 VAE / vocoder weights.
CONFIG_DONOR_MODEL_ID = "FastVideo/LTX-2.3-Distilled-Diffusers"


# allow_patterns for the snapshot_download of AUXILIARY_MODEL_ID.
AUXILIARY_PATTERNS = [
    "audio_vae/**",
    "scheduler/**",
    "text_encoder/**",
    "tokenizer/**",
    "vae/config.json",
    "vae/diffusion_pytorch_model.safetensors",
]


# allow_patterns for the snapshot_download of CONFIG_DONOR_MODEL_ID.
CONFIG_DONOR_PATTERNS = [
    "transformer/config.json",
    "text_encoder/config.json",
    "vae/**",
    "vocoder/**",
]


# Key prefixes inside the monolithic ltx-2.3 safetensors checkpoint; used to
# route tensors into the transformer vs. connectors sub-checkpoints below.
MONOLITH_PREFIX = "model.diffusion_model."
VIDEO_CONNECTOR_PREFIX = f"{MONOLITH_PREFIX}video_embeddings_connector."
AUDIO_CONNECTOR_PREFIX = f"{MONOLITH_PREFIX}audio_embeddings_connector."
TEXT_PROJ_IN_PREFIX = f"{MONOLITH_PREFIX}text_proj_in."
# Aggregate-embedding keys live outside the diffusion_model namespace.
VIDEO_AGGREGATE_PREFIX = "text_embedding_projection.video_aggregate_embed."
AUDIO_AGGREGATE_PREFIX = "text_embedding_projection.audio_aggregate_embed."
|
|
|
|
| def _load_json(path: str) -> dict: |
| with open(path) as f: |
| return json.load(f) |
|
|
|
|
| def _write_json(path: str, payload: dict) -> None: |
| with open(path, "w") as f: |
| json.dump(payload, f, indent=2) |
| f.write("\n") |
|
|
|
|
def _normalize_connector_suffix(suffix: str) -> str:
    """Rewrite monolith connector key fragments into diffusers-style names."""
    suffix = suffix.replace("transformer_1d_blocks", "transformer_blocks")
    suffix = suffix.replace(".attn1.q_norm.", ".attn1.norm_q.")
    suffix = suffix.replace(".attn1.k_norm.", ".attn1.norm_k.")
    return suffix


def _rename_connector_key(key: str) -> str | None:
    """Map a monolithic-checkpoint key to its text-connectors module key.

    Returns the renamed key for connector / text-projection / aggregate-embed
    tensors, or ``None`` for keys that do not belong to the connectors module.

    (Refactor: the video and audio connector branches previously duplicated
    the same three string replacements; both now share
    ``_normalize_connector_suffix``.)
    """
    if key.startswith(VIDEO_CONNECTOR_PREFIX):
        suffix = _normalize_connector_suffix(key[len(VIDEO_CONNECTOR_PREFIX) :])
        return f"video_connector.{suffix}"
    if key.startswith(AUDIO_CONNECTOR_PREFIX):
        suffix = _normalize_connector_suffix(key[len(AUDIO_CONNECTOR_PREFIX) :])
        return f"audio_connector.{suffix}"
    if key.startswith(TEXT_PROJ_IN_PREFIX):
        # text_proj_in keeps its name, minus the monolith prefix.
        return key[len(MONOLITH_PREFIX) :]
    if key.startswith(VIDEO_AGGREGATE_PREFIX):
        return f"video_aggregate_embed.{key[len(VIDEO_AGGREGATE_PREFIX):]}"
    if key.startswith(AUDIO_AGGREGATE_PREFIX):
        return f"audio_aggregate_embed.{key[len(AUDIO_AGGREGATE_PREFIX):]}"
    return None
|
|
|
|
def _repack_transformer_weights(source_path: str, output_path: str) -> None:
    """Extract transformer tensors from the monolithic checkpoint.

    Keeps only keys under ``MONOLITH_PREFIX`` that are not connector or
    text-projection weights, strips the prefix, and saves the result to
    *output_path*.
    """
    excluded = (VIDEO_CONNECTOR_PREFIX, AUDIO_CONNECTOR_PREFIX, TEXT_PROJ_IN_PREFIX)
    tensors = {}
    with safe_open(source_path, framework="pt") as checkpoint:
        for key in checkpoint.keys():
            # Skip everything outside the diffusion model, plus the pieces
            # that are repacked separately into the connectors module.
            if not key.startswith(MONOLITH_PREFIX) or key.startswith(excluded):
                continue
            tensors[key[len(MONOLITH_PREFIX) :]] = checkpoint.get_tensor(key)
    if not tensors:
        raise ValueError("No transformer tensors found in LTX-2.3 source checkpoint.")
    save_file(tensors, output_path)
|
|
|
|
def _repack_connectors_weights(source_path: str, output_path: str) -> None:
    """Extract and rename the text-connector tensors from the monolith.

    Keys are translated via ``_rename_connector_key``; keys it rejects
    (returns ``None`` for) are dropped.
    """
    with safe_open(source_path, framework="pt") as checkpoint:
        tensors = {
            renamed: checkpoint.get_tensor(key)
            for key in checkpoint.keys()
            if (renamed := _rename_connector_key(key)) is not None
        }
    if not tensors:
        raise ValueError("No connector tensors found in LTX-2.3 source checkpoint.")
    save_file(tensors, output_path)
|
|
|
|
def _build_transformer_config(config_donor_dir: str) -> dict:
    """Build the transformer config from the donor repo's config.json.

    Starts from the donor's ``transformer/config.json`` and overrides the
    class name plus two runtime flags.
    """
    donor_config_path = os.path.join(config_donor_dir, "transformer", "config.json")
    config = _load_json(donor_config_path)
    config.update(
        {
            "_class_name": "LTX2VideoTransformer3DModel",
            "force_sdpa_v2a_cross_attention": True,
            "quantize_video_rope_coords_to_hidden_dtype": True,
        }
    )
    return config
|
|
|
|
def _build_connectors_config(config_donor_dir: str) -> dict:
    """Derive the LTX2TextConnectors config from the donor text-encoder config.

    Every value is copied (and in a few cases renamed or lightly derived)
    from the donor's ``text_encoder/config.json``.
    """
    tec = _load_json(os.path.join(config_donor_dir, "text_encoder", "config.json"))
    return {
        "_class_name": "LTX2TextConnectors",
        "_diffusers_version": "0.37.0.dev0",
        "audio_connector_attention_head_dim": tec["audio_connector_attention_head_dim"],
        "audio_connector_num_attention_heads": tec["audio_connector_num_attention_heads"],
        "audio_connector_num_layers": tec["audio_connector_num_layers"],
        # Audio and video connectors share the same register count.
        "audio_connector_num_learnable_registers": tec["connector_num_learnable_registers"],
        "audio_feature_extractor_out_features": tec["audio_feature_extractor_out_features"],
        "caption_channels": tec["hidden_size"],
        "causal_temporal_positioning": False,
        "connector_apply_gated_attention": tec["connector_apply_gated_attention"],
        "feature_extractor_in_features": tec["feature_extractor_in_features"],
        # Only the first max-pos entry defines the RoPE base sequence length.
        "connector_rope_base_seq_len": tec["connector_positional_embedding_max_pos"][0],
        "rope_double_precision": tec["connector_double_precision_rope"],
        "rope_theta": tec["connector_positional_embedding_theta"],
        "rope_type": tec["connector_rope_type"],
        # Projection fan-in expressed as a multiple of the hidden size.
        "text_proj_in_factor": tec["feature_extractor_in_features"] // tec["hidden_size"],
        "video_feature_extractor_out_features": tec["video_feature_extractor_out_features"],
        "video_connector_attention_head_dim": tec["connector_attention_head_dim"],
        "video_connector_num_attention_heads": tec["connector_num_attention_heads"],
        "video_connector_num_layers": tec["connector_num_layers"],
        "video_connector_num_learnable_registers": tec["connector_num_learnable_registers"],
    }
|
|
|
|
def _build_vae_config(auxiliary_dir: str, config_donor_dir: str) -> dict:
    """Merge the auxiliary VAE config with the donor's decoder config.

    The base config comes from the auxiliary repo; the LTX-2.3 video-decoder
    sub-config is pulled from the donor repo's ``vae/config.json`` under its
    ``"vae"`` key.
    """
    config = _load_json(os.path.join(auxiliary_dir, "vae", "config.json"))
    donor_vae = _load_json(os.path.join(config_donor_dir, "vae", "config.json"))
    config.update(
        {
            "ltx_variant": "ltx_2_3",
            "condition_encoder_subdir": "ltx23_image_encoder",
            "video_decoder_variant": "ltx_2_3",
            "video_decoder_config": donor_vae["vae"],
        }
    )
    return config
|
|
|
|
def _repack_ltx23_image_encoder_weights(source_path: str, output_path: str) -> None:
    """Extract the image-encoder tensors from the donor VAE checkpoint.

    ``encoder.*`` keys are kept with the prefix stripped;
    ``per_channel_statistics.*`` keys are kept verbatim.
    """
    encoder_prefix = "encoder."
    tensors = {}
    with safe_open(source_path, framework="pt") as checkpoint:
        for key in checkpoint.keys():
            if key.startswith(encoder_prefix):
                tensors[key[len(encoder_prefix) :]] = checkpoint.get_tensor(key)
            elif key.startswith("per_channel_statistics."):
                tensors[key] = checkpoint.get_tensor(key)
    if not tensors:
        raise ValueError("No LTX-2.3 image-encoder tensors found in donor checkpoint.")
    save_file(tensors, output_path)
|
|
|
|
def _repack_ltx23_video_decoder_weights(
    auxiliary_encoder_path: str,
    donor_decoder_path: str,
    output_path: str,
) -> None:
    """Combine the auxiliary VAE encoder with the donor's decoder into one file.

    ``encoder.*`` tensors come from *auxiliary_encoder_path*; ``decoder.*``
    tensors and the per-channel statistics come from *donor_decoder_path*.
    The hyphenated statistics keys are renamed to underscored decoder keys and
    also duplicated as top-level ``latents_mean`` / ``latents_std`` tensors.
    """
    # Hyphenated donor stat key -> (renamed decoder key, top-level latents key).
    stats_key_map = {
        "per_channel_statistics.mean-of-means": (
            "decoder.per_channel_statistics.mean_of_means",
            "latents_mean",
        ),
        "per_channel_statistics.std-of-means": (
            "decoder.per_channel_statistics.std_of_means",
            "latents_std",
        ),
    }
    tensors = {}
    with safe_open(auxiliary_encoder_path, framework="pt") as encoder_ckpt:
        for key in encoder_ckpt.keys():
            if key.startswith("encoder."):
                tensors[key] = encoder_ckpt.get_tensor(key)
    with safe_open(donor_decoder_path, framework="pt") as decoder_ckpt:
        for key in decoder_ckpt.keys():
            if key.startswith("decoder."):
                tensors[key] = decoder_ckpt.get_tensor(key)
            elif key in stats_key_map:
                decoder_key, latents_key = stats_key_map[key]
                tensor = decoder_ckpt.get_tensor(key)
                tensors[decoder_key] = tensor
                tensors[latents_key] = tensor.clone()
    if not tensors:
        raise ValueError("No LTX-2.3 decoder tensors found in donor checkpoint.")
    save_file(tensors, output_path)
|
|
|
|
def materialize(
    *,
    overlay_dir: str,
    source_dir: str,
    output_dir: str,
    manifest: dict,
) -> None:
    """Assemble a runnable LTX-2.3 model tree under *output_dir*.

    Downloads auxiliary components and donor configs from the Hub, copies the
    reusable components verbatim, and repacks the monolithic
    ``ltx-2.3-22b-dev.safetensors`` checkpoint from *source_dir* into
    transformer / connectors / vae sub-directories. The distilled LoRA and
    spatial-upscaler checkpoints are linked or copied through unchanged.
    """
    # overlay_dir and manifest are part of the materializer interface but
    # unused for this model.
    _ = overlay_dir, manifest

    aux_dir = snapshot_download(
        repo_id=AUXILIARY_MODEL_ID,
        allow_patterns=AUXILIARY_PATTERNS,
        max_workers=8,
    )
    donor_dir = snapshot_download(
        repo_id=CONFIG_DONOR_MODEL_ID,
        allow_patterns=CONFIG_DONOR_PATTERNS,
        max_workers=8,
    )

    # Components reused verbatim from the downloaded snapshots.
    for component in ("audio_vae", "scheduler", "text_encoder", "tokenizer"):
        _copytree_link_or_copy(
            os.path.join(aux_dir, component),
            os.path.join(output_dir, component),
        )
    _copytree_link_or_copy(
        os.path.join(donor_dir, "vocoder"),
        os.path.join(output_dir, "vocoder"),
    )

    monolith_path = os.path.join(source_dir, "ltx-2.3-22b-dev.safetensors")

    # Transformer: donor config with overrides + repacked monolith weights.
    transformer_dir = os.path.join(output_dir, "transformer")
    _ensure_dir(transformer_dir)
    _write_json(
        os.path.join(transformer_dir, "config.json"),
        _build_transformer_config(donor_dir),
    )
    _repack_transformer_weights(
        monolith_path, os.path.join(transformer_dir, "model.safetensors")
    )

    # Text connectors: config derived from donor text encoder + renamed weights.
    connectors_dir = os.path.join(output_dir, "connectors")
    _ensure_dir(connectors_dir)
    _write_json(
        os.path.join(connectors_dir, "config.json"),
        _build_connectors_config(donor_dir),
    )
    _repack_connectors_weights(
        monolith_path, os.path.join(connectors_dir, "model.safetensors")
    )

    # VAE: auxiliary encoder + donor decoder merged into one checkpoint.
    vae_dir = os.path.join(output_dir, "vae")
    _ensure_dir(vae_dir)
    _write_json(
        os.path.join(vae_dir, "config.json"),
        _build_vae_config(aux_dir, donor_dir),
    )
    _repack_ltx23_video_decoder_weights(
        os.path.join(aux_dir, "vae", "diffusion_pytorch_model.safetensors"),
        os.path.join(donor_dir, "vae", "model.safetensors"),
        os.path.join(vae_dir, "model.safetensors"),
    )

    # Nested LTX-2.3 image encoder: donor VAE config + encoder-only weights.
    image_encoder_dir = os.path.join(vae_dir, "ltx23_image_encoder")
    _ensure_dir(image_encoder_dir)
    _link_or_copy_file(
        os.path.join(donor_dir, "vae", "config.json"),
        os.path.join(image_encoder_dir, "config.json"),
    )
    _repack_ltx23_image_encoder_weights(
        os.path.join(donor_dir, "vae", "model.safetensors"),
        os.path.join(image_encoder_dir, "model.safetensors"),
    )

    # Auxiliary checkpoints passed through unchanged.
    for filename in (
        "ltx-2.3-22b-distilled-lora-384.safetensors",
        "ltx-2.3-spatial-upscaler-x2-1.1.safetensors",
    ):
        _link_or_copy_file(
            os.path.join(source_dir, filename),
            os.path.join(output_dir, filename),
        )
|
|