| from typing import Any |
|
|
| import torch |
|
|
|
|
| def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6, in_place = False) -> torch.Tensor: |
| |
| if in_place: |
| scale = torch.linalg.vector_norm(x, ord=2, dim=-1, keepdim=True, dtype=torch.float32) |
| scale.square_().div_(x.shape[-1]).add_(eps).rsqrt_() |
| x.mul_(scale) |
| if weight is not None: |
| x.mul_(weight) |
| return x |
| return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps) |
|
|
|
|
| def check_config_value(config: dict, key: str, expected: Any) -> None: |
| actual = config.get(key) |
| if actual != expected: |
| raise ValueError(f"Config value {key} is {actual}, expected {expected}") |
|
|
|
|
| def to_velocity( |
| sample: torch.Tensor, |
| sigma: float | torch.Tensor, |
| denoised_sample: torch.Tensor, |
| calc_dtype: torch.dtype = torch.float32, |
| ) -> torch.Tensor: |
| """ |
| Convert the sample and its denoised version to velocity. |
| Returns: |
| Velocity |
| """ |
| if isinstance(sigma, torch.Tensor): |
| sigma = sigma.to(calc_dtype).item() |
| if sigma == 0: |
| raise ValueError("Sigma can't be 0.0") |
| return ((sample.to(calc_dtype) - denoised_sample.to(calc_dtype)) / sigma).to(sample.dtype) |
|
|
|
|
| def to_denoised( |
| sample: torch.Tensor, |
| velocity: torch.Tensor, |
| sigma: float | torch.Tensor, |
| calc_dtype: torch.dtype = torch.float32, |
| ) -> torch.Tensor: |
| """ |
| Convert the sample and its denoising velocity to denoised sample. |
| Returns: |
| Denoised sample |
| """ |
| if isinstance(sigma, torch.Tensor): |
| sigma = sigma.to(calc_dtype) |
| return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype) |
|
|