Fabrice-TIERCELIN commited on
Commit
114d84a
·
verified ·
1 Parent(s): 0e52bb3

Upload 2 files

Browse files
packages/ltx-core/src/ltx_core/guidance/__init__.py CHANGED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Guidance and perturbation utilities for attention manipulation."""
2
+
3
+ from ltx_core.guidance.perturbations import (
4
+ BatchedPerturbationConfig,
5
+ Perturbation,
6
+ PerturbationConfig,
7
+ PerturbationType,
8
+ )
9
+
10
+ __all__ = [
11
+ "BatchedPerturbationConfig",
12
+ "Perturbation",
13
+ "PerturbationConfig",
14
+ "PerturbationType",
15
+ ]
packages/ltx-core/src/ltx_core/guidance/perturbations.py CHANGED
@@ -1,6 +1,3 @@
1
- # Copyright (c) 2025 Lightricks. All rights reserved.
2
- # Created by Andrew Kvochko
3
-
4
  from dataclasses import dataclass
5
  from enum import Enum
6
 
@@ -9,6 +6,8 @@ from torch._prims_common import DeviceLikeType
9
 
10
 
11
  class PerturbationType(Enum):
 
 
12
  SKIP_A2V_CROSS_ATTN = "skip_a2v_cross_attn"
13
  SKIP_V2A_CROSS_ATTN = "skip_v2a_cross_attn"
14
  SKIP_VIDEO_SELF_ATTN = "skip_video_self_attn"
@@ -17,6 +16,8 @@ class PerturbationType(Enum):
17
 
18
  @dataclass(frozen=True)
19
  class Perturbation:
 
 
20
  type: PerturbationType
21
  blocks: list[int] | None # None means all blocks
22
 
@@ -32,6 +33,8 @@ class Perturbation:
32
 
33
  @dataclass(frozen=True)
34
  class PerturbationConfig:
 
 
35
  perturbations: list[Perturbation] | None
36
 
37
  def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
@@ -47,6 +50,8 @@ class PerturbationConfig:
47
 
48
  @dataclass(frozen=True)
49
  class BatchedPerturbationConfig:
 
 
50
  perturbations: list[PerturbationConfig]
51
 
52
  def mask(
 
 
 
 
1
  from dataclasses import dataclass
2
  from enum import Enum
3
 
 
6
 
7
 
8
  class PerturbationType(Enum):
9
+ """Types of attention perturbations for STG (Spatio-Temporal Guidance)."""
10
+
11
  SKIP_A2V_CROSS_ATTN = "skip_a2v_cross_attn"
12
  SKIP_V2A_CROSS_ATTN = "skip_v2a_cross_attn"
13
  SKIP_VIDEO_SELF_ATTN = "skip_video_self_attn"
 
16
 
17
  @dataclass(frozen=True)
18
  class Perturbation:
19
+ """A single perturbation specifying which attention type to skip and in which blocks."""
20
+
21
  type: PerturbationType
22
  blocks: list[int] | None # None means all blocks
23
 
 
33
 
34
  @dataclass(frozen=True)
35
  class PerturbationConfig:
36
+ """Configuration holding a list of perturbations for a single sample."""
37
+
38
  perturbations: list[Perturbation] | None
39
 
40
  def is_perturbed(self, perturbation_type: PerturbationType, block: int) -> bool:
 
50
 
51
  @dataclass(frozen=True)
52
  class BatchedPerturbationConfig:
53
+ """Perturbation configurations for a batch, with utilities for generating attention masks."""
54
+
55
  perturbations: list[PerturbationConfig]
56
 
57
  def mask(