Fabrice-TIERCELIN committed on
Commit
60ef3c5
·
verified ·
1 Parent(s): d975146

Upload 7 files

Browse files
packages/ltx-core/src/ltx_core/model/upsampler/__init__.py CHANGED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """Latent upsampler model components."""
2
+
3
+ from ltx_core.model.upsampler.model import LatentUpsampler, upsample_video
4
+ from ltx_core.model.upsampler.model_configurator import LatentUpsamplerConfigurator
5
+
6
+ __all__ = [
7
+ "LatentUpsampler",
8
+ "LatentUpsamplerConfigurator",
9
+ "upsample_video",
10
+ ]
packages/ltx-core/src/ltx_core/model/upsampler/blur_downsample.py CHANGED
@@ -1,6 +1,3 @@
1
- # Copyright (c) 2025 Lightricks. All rights reserved.
2
- # Created by Andrew Kvochko
3
-
4
  import math
5
 
6
  import torch
 
 
 
 
1
  import math
2
 
3
  import torch
packages/ltx-core/src/ltx_core/model/upsampler/model.py CHANGED
@@ -1,19 +1,15 @@
1
- # Copyright (c) 2025 Lightricks. All rights reserved.
2
- # Created by Andrew Kvochko
3
-
4
-
5
  import torch
6
  from einops import rearrange
7
 
8
  from ltx_core.model.upsampler.pixel_shuffle import PixelShuffleND
9
  from ltx_core.model.upsampler.res_block import ResBlock
10
  from ltx_core.model.upsampler.spatial_rational_resampler import SpatialRationalResampler
 
11
 
12
 
13
  class LatentUpsampler(torch.nn.Module):
14
  """
15
- Model to spatially upsample VAE latents.
16
-
17
  Args:
18
  in_channels (`int`): Number of channels in the input latent
19
  mid_channels (`int`): Number of channels in the middle layers
@@ -127,3 +123,20 @@ class LatentUpsampler(torch.nn.Module):
127
  x = self.final_conv(x)
128
 
129
  return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  from einops import rearrange
3
 
4
  from ltx_core.model.upsampler.pixel_shuffle import PixelShuffleND
5
  from ltx_core.model.upsampler.res_block import ResBlock
6
  from ltx_core.model.upsampler.spatial_rational_resampler import SpatialRationalResampler
7
+ from ltx_core.model.video_vae import VideoEncoder
8
 
9
 
10
  class LatentUpsampler(torch.nn.Module):
11
  """
12
+ Model to upsample VAE latents spatially and/or temporally.
 
13
  Args:
14
  in_channels (`int`): Number of channels in the input latent
15
  mid_channels (`int`): Number of channels in the middle layers
 
123
  x = self.final_conv(x)
124
 
125
  return x
126
+
127
+
128
+ def upsample_video(latent: torch.Tensor, video_encoder: VideoEncoder, upsampler: "LatentUpsampler") -> torch.Tensor:
129
+ """
130
+ Apply upsampling to the latent representation using the provided upsampler,
131
+ un-normalizing the input beforehand and re-normalizing the output afterwards using the video encoder's per-channel statistics.
132
+ Args:
133
+ latent: Input latent tensor of shape [B, C, F, H, W].
134
+ video_encoder: VideoEncoder with per_channel_statistics for normalization.
135
+ upsampler: LatentUpsampler module to perform upsampling.
136
+ Returns:
137
+ torch.Tensor: Upsampled and re-normalized latent tensor.
138
+ """
139
+ latent = video_encoder.per_channel_statistics.un_normalize(latent)
140
+ latent = upsampler(latent)
141
+ latent = video_encoder.per_channel_statistics.normalize(latent)
142
+ return latent
packages/ltx-core/src/ltx_core/model/upsampler/model_configurator.py CHANGED
@@ -3,6 +3,11 @@ from ltx_core.model.upsampler.model import LatentUpsampler
3
 
4
 
5
  class LatentUpsamplerConfigurator(ModelConfigurator[LatentUpsampler]):
 
 
 
 
 
6
  @classmethod
7
  def from_config(cls: type[LatentUpsampler], config: dict) -> LatentUpsampler:
8
  in_channels = config.get("in_channels", 128)
 
3
 
4
 
5
  class LatentUpsamplerConfigurator(ModelConfigurator[LatentUpsampler]):
6
+ """
7
+ Configurator for LatentUpsampler model.
8
+ Used to create a LatentUpsampler model from a configuration dictionary.
9
+ """
10
+
11
  @classmethod
12
  def from_config(cls: type[LatentUpsampler], config: dict) -> LatentUpsampler:
13
  in_channels = config.get("in_channels", 128)
packages/ltx-core/src/ltx_core/model/upsampler/pixel_shuffle.py CHANGED
@@ -1,6 +1,3 @@
1
- # Copyright (c) 2025 Lightricks. All rights reserved.
2
- # Created by Andrew Kvochko
3
-
4
  import torch
5
  from einops import rearrange
6
 
@@ -8,7 +5,6 @@ from einops import rearrange
8
  class PixelShuffleND(torch.nn.Module):
9
  """
10
  N-dimensional pixel shuffle operation for upsampling tensors.
11
-
12
  Args:
13
  dims (int): Number of dimensions to apply pixel shuffle to.
14
  - 1: Temporal (e.g., frames)
@@ -18,11 +14,9 @@ class PixelShuffleND(torch.nn.Module):
18
  For dims=1, only the first value is used.
19
  For dims=2, the first two values are used.
20
  For dims=3, all three values are used.
21
-
22
  The input tensor is rearranged so that the channel dimension is split into
23
  smaller channels and upscaling factors, and the upscaling factors are moved
24
  into the corresponding spatial/temporal dimensions.
25
-
26
  Note:
27
  This operation is equivalent to the patchifier operation in the models. Consider
28
  using this class instead.
 
 
 
 
1
  import torch
2
  from einops import rearrange
3
 
 
5
  class PixelShuffleND(torch.nn.Module):
6
  """
7
  N-dimensional pixel shuffle operation for upsampling tensors.
 
8
  Args:
9
  dims (int): Number of dimensions to apply pixel shuffle to.
10
  - 1: Temporal (e.g., frames)
 
14
  For dims=1, only the first value is used.
15
  For dims=2, the first two values are used.
16
  For dims=3, all three values are used.
 
17
  The input tensor is rearranged so that the channel dimension is split into
18
  smaller channels and upscaling factors, and the upscaling factors are moved
19
  into the corresponding spatial/temporal dimensions.
 
20
  Note:
21
  This operation is equivalent to the patchifier operation in the models. Consider
22
  using this class instead.
packages/ltx-core/src/ltx_core/model/upsampler/res_block.py CHANGED
@@ -1,6 +1,3 @@
1
- # Copyright (c) 2025 Lightricks. All rights reserved.
2
- # Created by Andrew Kvochko
3
-
4
  from typing import Optional
5
 
6
  import torch
 
 
 
 
1
  from typing import Optional
2
 
3
  import torch
packages/ltx-core/src/ltx_core/model/upsampler/spatial_rational_resampler.py CHANGED
@@ -1,6 +1,3 @@
1
- # Copyright (c) 2025 Lightricks. All rights reserved.
2
- # Created by Andrew Kvochko
3
-
4
  from typing import Tuple
5
 
6
  import torch
@@ -21,9 +18,7 @@ class SpatialRationalResampler(torch.nn.Module):
21
  """
22
  Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
23
  downsample by 'den' using fixed blur + stride. Operates on H,W only.
24
-
25
  For dims==3, work per-frame for spatial scaling (temporal axis untouched).
26
-
27
  Args:
28
  mid_channels (`int`): Number of intermediate channels for the convolution layer
29
  scale (`float`): Spatial scaling factor. Supported values are:
 
 
 
 
1
  from typing import Tuple
2
 
3
  import torch
 
18
  """
19
  Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
20
  downsample by 'den' using fixed blur + stride. Operates on H,W only.
 
21
  For dims==3, work per-frame for spatial scaling (temporal axis untouched).
 
22
  Args:
23
  mid_channels (`int`): Number of intermediate channels for the convolution layer
24
  scale (`float`): Spatial scaling factor. Supported values are: