Upload 2 files
Browse files
packages/ltx-core/src/ltx_core/guidance/__init__.py
CHANGED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Guidance and perturbation utilities for attention manipulation."""
|
| 2 |
+
|
| 3 |
+
from ltx_core.guidance.perturbations import (
|
| 4 |
+
BatchedPerturbationConfig,
|
| 5 |
+
Perturbation,
|
| 6 |
+
PerturbationConfig,
|
| 7 |
+
PerturbationType,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"BatchedPerturbationConfig",
|
| 12 |
+
"Perturbation",
|
| 13 |
+
"PerturbationConfig",
|
| 14 |
+
"PerturbationType",
|
| 15 |
+
]
|
packages/ltx-core/src/ltx_core/guidance/perturbations.py
CHANGED
|
@@ -1,6 +1,3 @@
|
|
| 1 |
-
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
-
# Created by Andrew Kvochko
|
| 3 |
-
|
| 4 |
from dataclasses import dataclass
|
| 5 |
from enum import Enum
|
| 6 |
|
|
@@ -9,6 +6,8 @@ from torch._prims_common import DeviceLikeType
|
|
| 9 |
|
| 10 |
|
| 11 |
class PerturbationType(Enum):
|
|
|
|
|
|
|
| 12 |
SKIP_A2V_CROSS_ATTN = "skip_a2v_cross_attn"
|
| 13 |
SKIP_V2A_CROSS_ATTN = "skip_v2a_cross_attn"
|
| 14 |
SKIP_VIDEO_SELF_ATTN = "skip_video_self_attn"
|
|
@@ -17,6 +16,8 @@ class PerturbationType(Enum):
|
|
| 17 |
|
| 18 |
@dataclass(frozen=True)
|
| 19 |
class Perturbation:
|
|
|
|
|
|
|
| 20 |
type: PerturbationType
|
| 21 |
blocks: list[int] | None # None means all blocks
|
| 22 |
|
|
@@ -32,6 +33,8 @@ class Perturbation:
|
|
| 32 |
|
| 33 |
@dataclass(frozen=True)
|
| 34 |
class PerturbationConfig:
|
|
|
|
|
|
|
| 35 |
perturbations: list[Perturbation] | None
|
| 36 |
|
| 37 |
def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
|
|
@@ -47,6 +50,8 @@ class PerturbationConfig:
|
|
| 47 |
|
| 48 |
@dataclass(frozen=True)
|
| 49 |
class BatchedPerturbationConfig:
|
|
|
|
|
|
|
| 50 |
perturbations: list[PerturbationConfig]
|
| 51 |
|
| 52 |
def mask(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from dataclasses import dataclass
|
| 2 |
from enum import Enum
|
| 3 |
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
class PerturbationType(Enum):
|
| 9 |
+
"""Types of attention perturbations for STG (Spatio-Temporal Guidance)."""
|
| 10 |
+
|
| 11 |
SKIP_A2V_CROSS_ATTN = "skip_a2v_cross_attn"
|
| 12 |
SKIP_V2A_CROSS_ATTN = "skip_v2a_cross_attn"
|
| 13 |
SKIP_VIDEO_SELF_ATTN = "skip_video_self_attn"
|
|
|
|
| 16 |
|
| 17 |
@dataclass(frozen=True)
|
| 18 |
class Perturbation:
|
| 19 |
+
"""A single perturbation specifying which attention type to skip and in which blocks."""
|
| 20 |
+
|
| 21 |
type: PerturbationType
|
| 22 |
blocks: list[int] | None # None means all blocks
|
| 23 |
|
|
|
|
| 33 |
|
| 34 |
@dataclass(frozen=True)
|
| 35 |
class PerturbationConfig:
|
| 36 |
+
"""Configuration holding a list of perturbations for a single sample."""
|
| 37 |
+
|
| 38 |
perturbations: list[Perturbation] | None
|
| 39 |
|
| 40 |
def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
|
|
|
|
| 50 |
|
| 51 |
@dataclass(frozen=True)
|
| 52 |
class BatchedPerturbationConfig:
|
| 53 |
+
"""Perturbation configurations for a batch, with utilities for generating attention masks."""
|
| 54 |
+
|
| 55 |
perturbations: list[PerturbationConfig]
|
| 56 |
|
| 57 |
def mask(
|