STLDM SEVIR VAE
A checkpoint containing only the VAE extracted from STLDM's latent diffusion components. It can be loaded as-is in setups that reuse the VAE as a frozen tokenizer, such as JEPACast Phase 1.
Source
- Original STLDM end-to-end checkpoint: 20260406181528_1gpu_reimp/stldm_..._final.pt
- From that checkpoint, only the backbone.vae.* keys were extracted (38 keys)
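The extraction step above can be sketched as a prefix filter over the full checkpoint's state dict. This is a minimal sketch, not the script that was actually used: the helper name `extract_submodule_state` is hypothetical, and stripping the `backbone.vae.` prefix from the keys is an assumption (inferred from the fact that the file loads directly into a bare `VAE`). The function works on any mapping, so it would apply equally to a dict returned by `torch.load`.

```python
from collections import OrderedDict


def extract_submodule_state(full_state, prefix="backbone.vae."):
    """Keep only entries under `prefix` and strip the prefix from their keys.

    Hypothetical helper; stripping the prefix is an assumption.
    """
    sub_state = OrderedDict()
    for key, value in full_state.items():
        if key.startswith(prefix):
            sub_state[key[len(prefix):]] = value
    return sub_state


# Toy stand-in for a real state dict (values would be tensors in practice).
toy_state = {
    "backbone.vae.enc.0.weight": 1,
    "backbone.vae.dec.0.weight": 2,
    "backbone.unet.in.weight": 3,
}
vae_state = extract_submodule_state(toy_state)
print(list(vae_state))
```

Applied to the real end-to-end checkpoint, the result should contain 38 entries, which could then be written out with `torch.save`.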
Architecture
- VAE encoder + decoder (4-stage ConvSC, hid_S=32, N_S=4)
- Input: (B, 1, 128, 128) SEVIR VIL frame, values in [0, 1]
- Latent: mean and log_var, each (B, 32, 32, 32) (spatial resolution downsampled to 1/4)
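As a sanity check on the shapes above, the following sketch derives the latent shape from the input shape. It assumes a SimVP-style stride pattern in which the ConvSC stages alternate stride 1 and stride 2, so that N_S=4 gives two stride-2 stages and a 1/4 spatial resolution. The function names are hypothetical, and the actual STLDM code may organize this differently.

```python
def conv_sc_strides(n_stages):
    """Assumed stride pattern for the encoder stages: 1, 2, 1, 2, ..."""
    return [1 if i % 2 == 0 else 2 for i in range(n_stages)]


def latent_shape(input_shape, hid_s=32, n_s=4):
    """Shape of the latent mean (and of log_var) for a (B, C, H, W) input."""
    batch, _channels, height, width = input_shape
    factor = 1
    for stride in conv_sc_strides(n_s):
        factor *= stride  # total spatial downsampling factor
    if height % factor or width % factor:
        raise ValueError(f"H and W must be divisible by {factor}")
    return (batch, hid_s, height // factor, width // factor)


print(latent_shape((8, 1, 128, 128)))  # (8, 32, 32, 32)
```

Under these assumptions, each 128 x 128 frame maps to a 32-channel 32 x 32 latent grid, matching the shapes listed above.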
Loading example (based on the STLDM codebase this repo assumes)
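from stldm.modules import VAE  # VAE class from the STLDM codebase
import torch
from huggingface_hub import hf_hub_download

# Download the VAE-only checkpoint from the Hugging Face Hub (cached locally).
vae_path = hf_hub_download(
    repo_id="KyleBae1017/stldm-sevir-vae",
    filename="vae_only.pt",
)

# Build the VAE with the hyperparameters listed under Architecture.
vae = VAE(C_in=1, hid_S=32, N_S=4)
state = torch.load(vae_path, map_location="cpu")

# strict=True raises on any key mismatch; the assert double-checks the result.
missing, unexpected = vae.load_state_dict(state, strict=True)
assert not missing and not unexpected

# Freeze the VAE so it acts as a fixed tokenizer.
vae.eval()
for p in vae.parameters():
    p.requires_grad_(False)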
Using it directly in train.py
First download the checkpoint to a local path with hf_hub_download as above, then:
python train.py ... --ae_ckpt <local_path> --ae_eval --freeze_vae ...
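The two steps above (download, then launch) could be scripted roughly as follows. This is a sketch under assumptions: `download_vae`, `train_command`, and the `checkpoints` directory are hypothetical names, and the arguments elided as `...` above are not filled in here; they would be passed through `extra_args`. Only the three VAE flags shown above are taken from this card.

```python
import shlex


def download_vae(local_dir="checkpoints"):
    """Download vae_only.pt into `local_dir` and return its local path."""
    # Imported lazily so the command builder below works without the package.
    from huggingface_hub import hf_hub_download

    return hf_hub_download(
        repo_id="KyleBae1017/stldm-sevir-vae",
        filename="vae_only.pt",
        local_dir=local_dir,
    )


def train_command(ckpt_path, extra_args=()):
    """Build the train.py command line with the VAE flags from this card."""
    args = ["python", "train.py", *extra_args,
            "--ae_ckpt", str(ckpt_path), "--ae_eval", "--freeze_vae"]
    return shlex.join(args)  # quotes paths that contain spaces


# Example with a placeholder path; in practice pass download_vae() instead.
print(train_command("checkpoints/vae_only.pt"))
```

Calling `train_command(download_vae())` would fetch the checkpoint first and then print a command that can be pasted into a shell.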