Fabrice-TIERCELIN committed on
Commit
2bbc972
·
verified ·
1 Parent(s): b9fe529

Upload 4 files

Browse files
packages/ltx-core/src/ltx_core/tools.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, replace
2
+ from typing import Protocol
3
+
4
+ import torch
5
+ from torch._prims_common import DeviceLikeType
6
+
7
+ from ltx_core.components.patchifiers import (
8
+ AudioLatentShape,
9
+ AudioPatchifier,
10
+ VideoLatentPatchifier,
11
+ VideoLatentShape,
12
+ get_pixel_coords,
13
+ )
14
+ from ltx_core.components.protocols import Patchifier
15
+ from ltx_core.types import LatentState, SpatioTemporalScaleFactors
16
+
17
+ DEFAULT_SCALE_FACTORS = SpatioTemporalScaleFactors.default()
18
+
19
+
20
class LatentTools(Protocol):
    """
    Protocol for helpers that build and transform latent states.
    """

    patchifier: Patchifier
    target_shape: VideoLatentShape | AudioLatentShape

    def create_initial_state(
        self,
        device: DeviceLikeType,
        dtype: torch.dtype,
        initial_latent: torch.Tensor | None = None,
    ) -> LatentState:
        """
        Create an initial latent state. If initial_latent is provided, it will be used to create the latent state.
        """
        ...

    def patchify(self, latent_state: LatentState) -> LatentState:
        """
        Convert the dense latent state into its patch-token representation.

        Raises:
            ValueError: If the latent does not match ``target_shape``.
        """
        if latent_state.latent.shape != self.target_shape.to_torch_shape():
            raise ValueError(
                f"Latent state has shape {latent_state.latent.shape}, expected shape is "
                f"{self.target_shape.to_torch_shape()}"
            )
        state = latent_state.clone()
        patched = self.patchifier.patchify(state.latent)
        patched_clean = self.patchifier.patchify(state.clean_latent)
        patched_mask = self.patchifier.patchify(state.denoise_mask)
        return replace(state, latent=patched, denoise_mask=patched_mask, clean_latent=patched_clean)

    def unpatchify(self, latent_state: LatentState) -> LatentState:
        """
        Convert the patch-token latent state back into its dense representation.
        """
        state = latent_state.clone()
        dense = self.patchifier.unpatchify(state.latent, output_shape=self.target_shape)
        dense_clean = self.patchifier.unpatchify(state.clean_latent, output_shape=self.target_shape)
        dense_mask = self.patchifier.unpatchify(
            state.denoise_mask, output_shape=self.target_shape.mask_shape()
        )
        return replace(state, latent=dense, denoise_mask=dense_mask, clean_latent=dense_clean)

    def clear_conditioning(self, latent_state: LatentState) -> LatentState:
        """
        Clear the conditioning from the latent state. This method removes extra tokens from the end of the latent.
        Therefore, conditioning items should add extra tokens ONLY to the end of the latent.
        """
        state = latent_state.clone()
        # Tokens belonging to the base target shape come first; anything beyond
        # them was appended by conditioning and is dropped here.
        num_tokens = self.patchifier.get_token_count(self.target_shape)
        return LatentState(
            latent=state.latent[:, :num_tokens],
            denoise_mask=torch.ones_like(state.denoise_mask)[:, :num_tokens],
            positions=state.positions[:, :, :num_tokens],
            clean_latent=state.clean_latent[:, :num_tokens],
        )
80
+
81
+
82
@dataclass(frozen=True)
class VideoLatentTools(LatentTools):
    """
    Tools for building video latent states.

    Attributes:
        patchifier: Converts video latents to/from patch-token sequences.
        target_shape: Expected dense latent shape (batch, channels, frames, height, width).
        fps: Frames per second of the source video; used to convert the temporal
            position channel from frame indices into seconds.
        scale_factors: Spatiotemporal downscaling between pixel and latent space.
        causal_fix: Forwarded to `get_pixel_coords`; presumably adjusts the
            temporal coordinate of the first (causal) frame — confirm in patchifiers.
    """

    patchifier: VideoLatentPatchifier
    target_shape: VideoLatentShape
    fps: float
    scale_factors: SpatioTemporalScaleFactors = DEFAULT_SCALE_FACTORS
    causal_fix: bool = True

    def create_initial_state(
        self,
        device: DeviceLikeType,
        dtype: torch.dtype,
        initial_latent: torch.Tensor | None = None,
    ) -> LatentState:
        """
        Create a patchified initial latent state for video denoising.

        Args:
            device: Device on which to allocate tensors.
            dtype: Dtype for the latent and positions tensors.
            initial_latent: Optional starting latent matching `target_shape`;
                when omitted, an all-zeros latent is allocated.

        Returns:
            A patchified LatentState with an all-ones denoise mask and per-token
            positions whose temporal channel is expressed in seconds.

        Raises:
            ValueError: If `initial_latent` does not match `target_shape`.
        """
        expected_shape = self.target_shape.to_torch_shape()
        if initial_latent is None:
            initial_latent = torch.zeros(*expected_shape, device=device, dtype=dtype)
        elif initial_latent.shape != expected_shape:
            # ValueError instead of assert: validation must survive `python -O`
            # and stay consistent with LatentTools.patchify.
            raise ValueError(
                f"Latent shape {initial_latent.shape} does not match target shape {expected_shape}"
            )

        clean_latent = initial_latent.clone()

        # Mask value 1 = full denoising for every token by default.
        denoise_mask = torch.ones(
            *self.target_shape.mask_shape().to_torch_shape(),
            device=device,
            dtype=torch.float32,
        )

        latent_coords = self.patchifier.get_patch_grid_bounds(
            output_shape=self.target_shape,
            device=device,
        )

        positions = get_pixel_coords(
            latent_coords=latent_coords,
            scale_factors=self.scale_factors,
            causal_fix=self.causal_fix,
        ).float()
        # Channel 0 of positions carries frame indices; divide by fps to get seconds.
        positions[:, 0, ...] = positions[:, 0, ...] / self.fps

        return self.patchify(
            LatentState(
                latent=initial_latent,
                denoise_mask=denoise_mask,
                positions=positions.to(dtype),
                clean_latent=clean_latent,
            )
        )
139
+
140
+
141
@dataclass(frozen=True)
class AudioLatentTools(LatentTools):
    """
    Tools for building audio latent states.

    Attributes:
        patchifier: Converts audio latents to/from patch-token sequences.
        target_shape: Expected dense latent shape (batch, channels, frames, mel_bins).
    """

    patchifier: AudioPatchifier
    target_shape: AudioLatentShape

    def create_initial_state(
        self,
        device: DeviceLikeType,
        dtype: torch.dtype,
        initial_latent: torch.Tensor | None = None,
    ) -> LatentState:
        """
        Create a patchified initial latent state for audio denoising.

        Args:
            device: Device on which to allocate tensors.
            dtype: Dtype for the latent tensor.
            initial_latent: Optional starting latent matching `target_shape`;
                when omitted, an all-zeros latent is allocated.

        Returns:
            A patchified LatentState with an all-ones denoise mask. Unlike the
            video variant, positions are the raw latent-grid coordinates (no
            pixel/fps mapping).

        Raises:
            ValueError: If `initial_latent` does not match `target_shape`.
        """
        expected_shape = self.target_shape.to_torch_shape()
        if initial_latent is None:
            initial_latent = torch.zeros(*expected_shape, device=device, dtype=dtype)
        elif initial_latent.shape != expected_shape:
            # ValueError instead of assert: validation must survive `python -O`
            # and stay consistent with LatentTools.patchify.
            raise ValueError(
                f"Latent shape {initial_latent.shape} does not match target shape {expected_shape}"
            )

        clean_latent = initial_latent.clone()

        # Mask value 1 = full denoising for every token by default.
        denoise_mask = torch.ones(
            *self.target_shape.mask_shape().to_torch_shape(),
            device=device,
            dtype=torch.float32,
        )

        latent_coords = self.patchifier.get_patch_grid_bounds(
            output_shape=self.target_shape,
            device=device,
        )

        return self.patchify(
            LatentState(
                latent=initial_latent,
                denoise_mask=denoise_mask,
                positions=latent_coords,
                clean_latent=clean_latent,
            )
        )
packages/ltx-core/src/ltx_core/types.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import NamedTuple
3
+
4
+ import torch
5
+
6
+
7
class VideoPixelShape(NamedTuple):
    """
    Shape of the tensor representing the video pixel array. Assumes BGR channel format.
    """

    batch: int
    frames: int
    height: int
    width: int
    fps: float  # frames per second; carried with the shape for time conversions


class SpatioTemporalScaleFactors(NamedTuple):
    """
    Describes the spatiotemporal downscaling between decoded video space and
    the corresponding VAE latent grid.

    NOTE: field order is (time, width, height) — width precedes height, unlike
    the (height, width) order used by the shape tuples. Prefer named access.
    """

    time: int
    width: int
    height: int

    @classmethod
    def default(cls) -> "SpatioTemporalScaleFactors":
        """Default VAE factors: 8x temporal, 32x spatial downscaling."""
        return cls(time=8, width=32, height=32)


VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default()


class VideoLatentShape(NamedTuple):
    """
    Shape of the tensor representing video in VAE latent space.
    The latent representation is a 5D tensor with dimensions ordered as
    (batch, channels, frames, height, width). Spatial and temporal dimensions
    are downscaled relative to pixel space according to the VAE's scale factors.
    """

    batch: int
    channels: int
    frames: int
    height: int
    width: int

    def to_torch_shape(self) -> torch.Size:
        """Return this shape as a torch.Size suitable for tensor allocation."""
        return torch.Size([self.batch, self.channels, self.frames, self.height, self.width])

    @staticmethod
    def from_torch_shape(shape: torch.Size) -> "VideoLatentShape":
        """Build a VideoLatentShape from a 5D size ordered (batch, channels, frames, height, width)."""
        return VideoLatentShape(
            batch=shape[0],
            channels=shape[1],
            frames=shape[2],
            height=shape[3],
            width=shape[4],
        )

    def mask_shape(self) -> "VideoLatentShape":
        """Shape of a mask over this latent: same grid with a single channel."""
        return self._replace(channels=1)

    @staticmethod
    def from_pixel_shape(
        shape: VideoPixelShape,
        latent_channels: int = 128,
        scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS,
    ) -> "VideoLatentShape":
        """
        Derive the latent shape produced by encoding a pixel-space video.

        The first pixel frame maps to its own latent frame (causal encoding),
        hence the `(frames - 1) // time + 1` formula.
        """
        # BUGFIX: use named fields. The previous positional indexing divided
        # height by scale_factors[1] (the *width* factor) and width by
        # scale_factors[2] (the *height* factor), silently swapping the two
        # whenever the spatial factors differ. `upscale` below already uses
        # named fields; this makes the pair consistent inverses.
        frames = (shape.frames - 1) // scale_factors.time + 1
        height = shape.height // scale_factors.height
        width = shape.width // scale_factors.width

        return VideoLatentShape(
            batch=shape.batch,
            channels=latent_channels,
            frames=frames,
            height=height,
            width=width,
        )

    def upscale(self, scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS) -> "VideoLatentShape":
        """Inverse of `from_pixel_shape`: the pixel-space shape (3 channels) this latent decodes to."""
        return self._replace(
            channels=3,
            frames=(self.frames - 1) * scale_factors.time + 1,
            height=self.height * scale_factors.height,
            width=self.width * scale_factors.width,
        )


class AudioLatentShape(NamedTuple):
    """
    Shape of audio in VAE latent space: (batch, channels, frames, mel_bins).
    mel_bins is the number of frequency bins from the mel-spectrogram encoding.
    """

    batch: int
    channels: int
    frames: int
    mel_bins: int

    def to_torch_shape(self) -> torch.Size:
        """Return this shape as a torch.Size suitable for tensor allocation."""
        return torch.Size([self.batch, self.channels, self.frames, self.mel_bins])

    def mask_shape(self) -> "AudioLatentShape":
        """Shape of a mask over this latent: single channel, single mel bin per frame."""
        return self._replace(channels=1, mel_bins=1)

    @staticmethod
    def from_torch_shape(shape: torch.Size) -> "AudioLatentShape":
        """Build an AudioLatentShape from a 4D size ordered (batch, channels, frames, mel_bins)."""
        return AudioLatentShape(
            batch=shape[0],
            channels=shape[1],
            frames=shape[2],
            mel_bins=shape[3],
        )

    @staticmethod
    def from_duration(
        batch: int,
        duration: float,
        channels: int = 8,
        mel_bins: int = 16,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
    ) -> "AudioLatentShape":
        """
        Compute the latent shape covering `duration` seconds of audio.

        latents/second = sample_rate / hop_length / downsample_factor
        (25 latent frames per second with the defaults).
        """
        latents_per_second = float(sample_rate) / float(hop_length) / float(audio_latent_downsample_factor)

        return AudioLatentShape(
            batch=batch,
            channels=channels,
            frames=round(duration * latents_per_second),
            mel_bins=mel_bins,
        )

    @staticmethod
    def from_video_pixel_shape(
        shape: VideoPixelShape,
        channels: int = 8,
        mel_bins: int = 16,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
    ) -> "AudioLatentShape":
        """Latent shape for an audio track spanning the same duration as the given video."""
        return AudioLatentShape.from_duration(
            batch=shape.batch,
            duration=float(shape.frames) / float(shape.fps),
            channels=channels,
            mel_bins=mel_bins,
            sample_rate=sample_rate,
            hop_length=hop_length,
            audio_latent_downsample_factor=audio_latent_downsample_factor,
        )
157
+
158
+
159
@dataclass(frozen=True)
class LatentState:
    """
    State of latents during the diffusion denoising process.

    Attributes:
        latent: The current noisy latent tensor being denoised.
        denoise_mask: Mask encoding the denoising strength for each token (1 = full denoising, 0 = no denoising).
        positions: Positional indices for each latent element, used for positional embeddings.
        clean_latent: Initial state of the latent before denoising, may include conditioning latents.
    """

    latent: torch.Tensor
    denoise_mask: torch.Tensor
    positions: torch.Tensor
    clean_latent: torch.Tensor

    def clone(self) -> "LatentState":
        """Return a new LatentState holding independent copies of every tensor."""
        copies = {
            name: tensor.clone()
            for name, tensor in (
                ("latent", self.latent),
                ("denoise_mask", self.denoise_mask),
                ("positions", self.positions),
                ("clean_latent", self.clean_latent),
            )
        }
        return LatentState(**copies)
packages/ltx-core/src/ltx_core/utils.py CHANGED
@@ -1,6 +1,3 @@
1
- # Copyright (c) 2025 Lightricks. All rights reserved.
2
- # Created by Amit Pintz.
3
-
4
  from typing import Any
5
 
6
  import torch
@@ -8,7 +5,6 @@ import torch
8
 
9
  def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor:
10
  """Root-mean-square (RMS) normalize `x` over its last dimension.
11
-
12
  Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized
13
  shape and forwards `weight` and `eps`.
14
  """
@@ -29,7 +25,6 @@ def to_velocity(
29
  ) -> torch.Tensor:
30
  """
31
  Convert the sample and its denoised version to velocity.
32
-
33
  Returns:
34
  Velocity
35
  """
@@ -48,7 +43,6 @@ def to_denoised(
48
  ) -> torch.Tensor:
49
  """
50
  Convert the sample and its denoising velocity to denoised sample.
51
-
52
  Returns:
53
  Denoised sample
54
  """
 
 
 
 
1
  from typing import Any
2
 
3
  import torch
 
5
 
6
  def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor:
7
  """Root-mean-square (RMS) normalize `x` over its last dimension.
 
8
  Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized
9
  shape and forwards `weight` and `eps`.
10
  """
 
25
  ) -> torch.Tensor:
26
  """
27
  Convert the sample and its denoised version to velocity.
 
28
  Returns:
29
  Velocity
30
  """
 
43
  ) -> torch.Tensor:
44
  """
45
  Convert the sample and its denoising velocity to denoised sample.
 
46
  Returns:
47
  Denoised sample
48
  """