|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import contextlib |
|
|
|
|
|
class _SimpleVAEManager:
    """Thin wrapper that decodes a whole 5D latent block through a pipeline's VAE.

    The manager only holds a reference to a pipeline object plus the target
    device and autocast dtype; it does not load or own any model. Attach a
    pipeline (constructor argument or ``attach_pipeline``) before ``decode``.
    """

    def __init__(self, pipeline=None, device=None, autocast_dtype=torch.float32):
        """Create the manager.

        Args:
            pipeline: LTX pipeline object exposing ``decode_latents(...)`` or
                ``.vae.decode(...)``. May be ``None`` and attached later.
            device: device where decoding runs, e.g. ``"cuda"``, ``"cuda:0"``,
                ``"cpu"`` or a ``torch.device``. When falsy, defaults to
                ``"cuda"`` if CUDA is available, otherwise ``"cpu"``.
            autocast_dtype: dtype passed to ``torch.autocast`` when decoding on
                a CUDA device (e.g. bf16 / fp16 / fp32).
        """
        self.pipeline = pipeline
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.autocast_dtype = autocast_dtype

    def attach_pipeline(self, pipeline, device=None, autocast_dtype=None):
        """Attach (or replace) the pipeline, optionally updating device/dtype.

        ``device`` and ``autocast_dtype`` are only overwritten when they are
        not ``None``; otherwise the current values are kept.
        """
        self.pipeline = pipeline
        if device is not None:
            self.device = device
        if autocast_dtype is not None:
            self.autocast_dtype = autocast_dtype

    def _is_cuda_device(self):
        """Return True when ``self.device`` refers to any CUDA device.

        Compares the device *type*, so ``"cuda"``, ``"cuda:1"`` and
        ``torch.device("cuda", 0)`` all match; a plain ``== "cuda"`` string
        comparison would miss the last two.
        """
        device = self.device if isinstance(self.device, torch.device) else torch.device(self.device)
        return device.type == "cuda"

    @staticmethod
    def _unwrap_decoder_output(output):
        """Return the pixel tensor carried by a decoder result.

        Tensors pass through unchanged. Objects with a tensor ``sample``
        attribute (diffusers-style ``DecoderOutput``) and tuples/lists whose
        first item is a tensor (``return_dict=False`` style) are unwrapped.
        Anything else is returned as-is.
        """
        if isinstance(output, torch.Tensor):
            return output
        sample = getattr(output, "sample", None)
        if isinstance(sample, torch.Tensor):
            return sample
        if isinstance(output, (tuple, list)) and output and isinstance(output[0], torch.Tensor):
            return output[0]
        return output

    @torch.no_grad()
    def decode(self, latents_5d: torch.Tensor) -> torch.Tensor:
        """Decode the whole 5D latent block in a single call.

        Mirrors the simple decode flow of deformes4D. It moves the latents to
        the target device and runs the pipeline's decoder once (under
        ``torch.autocast`` on CUDA devices). It then normalizes the result to
        ``[0, 1]``.

        Args:
            latents_5d: latent tensor, presumably (B, C, T, H, W). The shape is
                not checked here; it is passed to the pipeline unchanged.

        Returns:
            Pixel tensor with values in ``[0, 1]``. The expected shape is
            (B, C, T, H', W') as produced by the pipeline's decoder (not
            verified here).

        Raises:
            RuntimeError: if no pipeline is attached, or the pipeline exposes
                neither ``decode_latents`` nor ``vae.decode``.
        """
        if self.pipeline is None:
            raise RuntimeError("VAE Manager sem pipeline. Chame attach_pipeline primeiro.")

        latents_5d = latents_5d.to(self.device, non_blocking=True)

        # Check the device type (not string equality) so indexed CUDA devices
        # and torch.device objects also get autocast.
        if self._is_cuda_device():
            ctx = torch.autocast(device_type="cuda", dtype=self.autocast_dtype)
        else:
            ctx = contextlib.nullcontext()

        with ctx:
            if hasattr(self.pipeline, "decode_latents"):
                pixels_5d = self.pipeline.decode_latents(latents_5d)
            elif hasattr(self.pipeline, "vae") and hasattr(self.pipeline.vae, "decode"):
                pixels_5d = self.pipeline.vae.decode(latents_5d)
            else:
                raise RuntimeError("Pipeline não expõe decode_latents nem vae.decode.")

        # VAE decoders may wrap the tensor (e.g. DecoderOutput.sample).
        pixels_5d = self._unwrap_decoder_output(pixels_5d)

        # Heuristic range detection: any negative value is taken to mean the
        # decoder produced [-1, 1] output, which is rescaled to [0, 1].
        # Otherwise the output is assumed to already be in [0, 1] and is only
        # clamped.
        # NOTE(review): a [-1, 1] output that happens to contain no negative
        # values is clamped without rescaling; confirm the decoder's convention.
        if pixels_5d.min() < 0:
            pixels_5d = (pixels_5d.clamp(-1, 1) + 1.0) / 2.0
        else:
            pixels_5d = pixels_5d.clamp(0, 1)

        return pixels_5d
|
|
|
|
|
|
|
|
# Shared module-level manager; no pipeline yet, so callers must use
# attach_pipeline() before decode().
vae_manager_singleton = _SimpleVAEManager(pipeline=None, device=None)
|
|
|