# Source path: vae/packages/ltx-core/src/ltx_core/utils.py
# Uploaded by: jiuhai
# Commit message: Add files using upload-large-folder tool
# Commit: 4cd1d55 (verified)
from pathlib import Path
from typing import Any
import torch
def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor:
"""Root-mean-square (RMS) normalize `x` over its last dimension.
Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized
shape and forwards `weight` and `eps`.
"""
return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps)
def check_config_value(config: dict, key: str, expected: Any) -> None:  # noqa: ANN401
    """Verify that `config[key]` holds the expected value.

    A missing key is looked up with `dict.get`, so it is compared as `None`.

    Raises:
        ValueError: If the stored value differs from `expected`.
    """
    found = config.get(key)
    mismatch = found != expected
    if mismatch:
        raise ValueError(f"Config value {key} is {found}, expected {expected}")
def to_velocity(
sample: torch.Tensor,
sigma: float | torch.Tensor,
denoised_sample: torch.Tensor,
calc_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Convert the sample and its denoised version to velocity.
Returns:
Velocity
"""
if isinstance(sigma, torch.Tensor):
sigma = sigma.to(calc_dtype).item()
if sigma == 0:
raise ValueError("Sigma can't be 0.0")
return ((sample.to(calc_dtype) - denoised_sample.to(calc_dtype)) / sigma).to(sample.dtype)
def to_denoised(
sample: torch.Tensor,
velocity: torch.Tensor,
sigma: float | torch.Tensor,
calc_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Convert the sample and its denoising velocity to denoised sample.
Returns:
Denoised sample
"""
if isinstance(sigma, torch.Tensor):
sigma = sigma.to(calc_dtype)
return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype)
def find_matching_file(root_path: str, pattern: str) -> Path:
    """
    Recursively search for files matching a glob pattern and return the first match.

    The search stops as soon as one match is produced, so the rest of the tree
    is not traversed. When several paths match, the one returned is whichever
    `Path.rglob` yields first.

    Args:
        root_path: Directory under which to search.
        pattern: Glob pattern handed to `Path.rglob`.

    Returns:
        The first matching path.

    Raises:
        FileNotFoundError: If nothing under `root_path` matches `pattern`.
    """
    # Pull only the first item from the lazy iterator instead of collecting
    # every match into a list.
    first_match = next(Path(root_path).rglob(pattern), None)
    if first_match is None:
        raise FileNotFoundError(f"No files matching pattern '{pattern}' found under {root_path}")
    return first_match