Upload 7 files
Browse files
- packages/ltx-core/src/ltx_core/model/upsampler/__init__.py +10 -0
- packages/ltx-core/src/ltx_core/model/upsampler/blur_downsample.py +0 -3
- packages/ltx-core/src/ltx_core/model/upsampler/model.py +19 -6
- packages/ltx-core/src/ltx_core/model/upsampler/model_configurator.py +5 -0
- packages/ltx-core/src/ltx_core/model/upsampler/pixel_shuffle.py +0 -6
- packages/ltx-core/src/ltx_core/model/upsampler/res_block.py +0 -3
- packages/ltx-core/src/ltx_core/model/upsampler/spatial_rational_resampler.py +0 -5
packages/ltx-core/src/ltx_core/model/upsampler/__init__.py
CHANGED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Latent upsampler model components."""
|
| 2 |
+
|
| 3 |
+
from ltx_core.model.upsampler.model import LatentUpsampler, upsample_video
|
| 4 |
+
from ltx_core.model.upsampler.model_configurator import LatentUpsamplerConfigurator
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"LatentUpsampler",
|
| 8 |
+
"LatentUpsamplerConfigurator",
|
| 9 |
+
"upsample_video",
|
| 10 |
+
]
|
packages/ltx-core/src/ltx_core/model/upsampler/blur_downsample.py
CHANGED
|
@@ -1,6 +1,3 @@
|
|
| 1 |
-
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
-
# Created by Andrew Kvochko
|
| 3 |
-
|
| 4 |
import math
|
| 5 |
|
| 6 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import math
|
| 2 |
|
| 3 |
import torch
|
packages/ltx-core/src/ltx_core/model/upsampler/model.py
CHANGED
|
@@ -1,19 +1,15 @@
|
|
| 1 |
-
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
-
# Created by Andrew Kvochko
|
| 3 |
-
|
| 4 |
-
|
| 5 |
import torch
|
| 6 |
from einops import rearrange
|
| 7 |
|
| 8 |
from ltx_core.model.upsampler.pixel_shuffle import PixelShuffleND
|
| 9 |
from ltx_core.model.upsampler.res_block import ResBlock
|
| 10 |
from ltx_core.model.upsampler.spatial_rational_resampler import SpatialRationalResampler
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
class LatentUpsampler(torch.nn.Module):
|
| 14 |
"""
|
| 15 |
-
Model to
|
| 16 |
-
|
| 17 |
Args:
|
| 18 |
in_channels (`int`): Number of channels in the input latent
|
| 19 |
mid_channels (`int`): Number of channels in the middle layers
|
|
@@ -127,3 +123,20 @@ class LatentUpsampler(torch.nn.Module):
|
|
| 127 |
x = self.final_conv(x)
|
| 128 |
|
| 129 |
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
from einops import rearrange
|
| 3 |
|
| 4 |
from ltx_core.model.upsampler.pixel_shuffle import PixelShuffleND
|
| 5 |
from ltx_core.model.upsampler.res_block import ResBlock
|
| 6 |
from ltx_core.model.upsampler.spatial_rational_resampler import SpatialRationalResampler
|
| 7 |
+
from ltx_core.model.video_vae import VideoEncoder
|
| 8 |
|
| 9 |
|
| 10 |
class LatentUpsampler(torch.nn.Module):
|
| 11 |
"""
|
| 12 |
+
Model to upsample VAE latents spatially and/or temporally.
|
|
|
|
| 13 |
Args:
|
| 14 |
in_channels (`int`): Number of channels in the input latent
|
| 15 |
mid_channels (`int`): Number of channels in the middle layers
|
|
|
|
| 123 |
x = self.final_conv(x)
|
| 124 |
|
| 125 |
return x
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def upsample_video(latent: torch.Tensor, video_encoder: VideoEncoder, upsampler: "LatentUpsampler") -> torch.Tensor:
|
| 129 |
+
"""
|
| 130 |
+
Apply upsampling to the latent representation using the provided upsampler,
|
| 131 |
+
un-normalizing it beforehand and re-normalizing it afterwards with the video encoder's per-channel statistics.
|
| 132 |
+
Args:
|
| 133 |
+
latent: Input latent tensor of shape [B, C, F, H, W].
|
| 134 |
+
video_encoder: VideoEncoder with per_channel_statistics for normalization.
|
| 135 |
+
upsampler: LatentUpsampler module to perform upsampling.
|
| 136 |
+
Returns:
|
| 137 |
+
torch.Tensor: Upsampled and re-normalized latent tensor.
|
| 138 |
+
"""
|
| 139 |
+
latent = video_encoder.per_channel_statistics.un_normalize(latent)
|
| 140 |
+
latent = upsampler(latent)
|
| 141 |
+
latent = video_encoder.per_channel_statistics.normalize(latent)
|
| 142 |
+
return latent
|
packages/ltx-core/src/ltx_core/model/upsampler/model_configurator.py
CHANGED
|
@@ -3,6 +3,11 @@ from ltx_core.model.upsampler.model import LatentUpsampler
|
|
| 3 |
|
| 4 |
|
| 5 |
class LatentUpsamplerConfigurator(ModelConfigurator[LatentUpsampler]):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
@classmethod
|
| 7 |
def from_config(cls: type[LatentUpsampler], config: dict) -> LatentUpsampler:
|
| 8 |
in_channels = config.get("in_channels", 128)
|
|
|
|
| 3 |
|
| 4 |
|
| 5 |
class LatentUpsamplerConfigurator(ModelConfigurator[LatentUpsampler]):
|
| 6 |
+
"""
|
| 7 |
+
Configurator for LatentUpsampler model.
|
| 8 |
+
Used to create a LatentUpsampler model from a configuration dictionary.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
@classmethod
|
| 12 |
def from_config(cls: type[LatentUpsampler], config: dict) -> LatentUpsampler:
|
| 13 |
in_channels = config.get("in_channels", 128)
|
packages/ltx-core/src/ltx_core/model/upsampler/pixel_shuffle.py
CHANGED
|
@@ -1,6 +1,3 @@
|
|
| 1 |
-
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
-
# Created by Andrew Kvochko
|
| 3 |
-
|
| 4 |
import torch
|
| 5 |
from einops import rearrange
|
| 6 |
|
|
@@ -8,7 +5,6 @@ from einops import rearrange
|
|
| 8 |
class PixelShuffleND(torch.nn.Module):
|
| 9 |
"""
|
| 10 |
N-dimensional pixel shuffle operation for upsampling tensors.
|
| 11 |
-
|
| 12 |
Args:
|
| 13 |
dims (int): Number of dimensions to apply pixel shuffle to.
|
| 14 |
- 1: Temporal (e.g., frames)
|
|
@@ -18,11 +14,9 @@ class PixelShuffleND(torch.nn.Module):
|
|
| 18 |
For dims=1, only the first value is used.
|
| 19 |
For dims=2, the first two values are used.
|
| 20 |
For dims=3, all three values are used.
|
| 21 |
-
|
| 22 |
The input tensor is rearranged so that the channel dimension is split into
|
| 23 |
smaller channels and upscaling factors, and the upscaling factors are moved
|
| 24 |
into the corresponding spatial/temporal dimensions.
|
| 25 |
-
|
| 26 |
Note:
|
| 27 |
This operation is equivalent to the patchifier operation used for the models. Consider
|
| 28 |
using this class instead.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
from einops import rearrange
|
| 3 |
|
|
|
|
| 5 |
class PixelShuffleND(torch.nn.Module):
|
| 6 |
"""
|
| 7 |
N-dimensional pixel shuffle operation for upsampling tensors.
|
|
|
|
| 8 |
Args:
|
| 9 |
dims (int): Number of dimensions to apply pixel shuffle to.
|
| 10 |
- 1: Temporal (e.g., frames)
|
|
|
|
| 14 |
For dims=1, only the first value is used.
|
| 15 |
For dims=2, the first two values are used.
|
| 16 |
For dims=3, all three values are used.
|
|
|
|
| 17 |
The input tensor is rearranged so that the channel dimension is split into
|
| 18 |
smaller channels and upscaling factors, and the upscaling factors are moved
|
| 19 |
into the corresponding spatial/temporal dimensions.
|
|
|
|
| 20 |
Note:
|
| 21 |
This operation is equivalent to the patchifier operation used for the models. Consider
|
| 22 |
using this class instead.
|
packages/ltx-core/src/ltx_core/model/upsampler/res_block.py
CHANGED
|
@@ -1,6 +1,3 @@
|
|
| 1 |
-
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
-
# Created by Andrew Kvochko
|
| 3 |
-
|
| 4 |
from typing import Optional
|
| 5 |
|
| 6 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from typing import Optional
|
| 2 |
|
| 3 |
import torch
|
packages/ltx-core/src/ltx_core/model/upsampler/spatial_rational_resampler.py
CHANGED
|
@@ -1,6 +1,3 @@
|
|
| 1 |
-
# Copyright (c) 2025 Lightricks. All rights reserved.
|
| 2 |
-
# Created by Andrew Kvochko
|
| 3 |
-
|
| 4 |
from typing import Tuple
|
| 5 |
|
| 6 |
import torch
|
|
@@ -21,9 +18,7 @@ class SpatialRationalResampler(torch.nn.Module):
|
|
| 21 |
"""
|
| 22 |
Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
|
| 23 |
downsample by 'den' using fixed blur + stride. Operates on H,W only.
|
| 24 |
-
|
| 25 |
For dims==3, work per-frame for spatial scaling (temporal axis untouched).
|
| 26 |
-
|
| 27 |
Args:
|
| 28 |
mid_channels (`int`): Number of intermediate channels for the convolution layer
|
| 29 |
scale (`float`): Spatial scaling factor. Supported values are:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from typing import Tuple
|
| 2 |
|
| 3 |
import torch
|
|
|
|
| 18 |
"""
|
| 19 |
Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
|
| 20 |
downsample by 'den' using fixed blur + stride. Operates on H,W only.
|
|
|
|
| 21 |
For dims==3, work per-frame for spatial scaling (temporal axis untouched).
|
|
|
|
| 22 |
Args:
|
| 23 |
mid_channels (`int`): Number of intermediate channels for the convolution layer
|
| 24 |
scale (`float`): Spatial scaling factor. Supported values are:
|