STLDM SEVIR VAE

A checkpoint containing only the VAE extracted from the latent diffusion components of STLDM. It can be loaded as-is in setups that reuse the VAE as a frozen tokenizer, as in JEPACast Phase 1.

Source

  • Original STLDM end-to-end checkpoint: 20260406181528_1gpu_reimp/stldm_..._final.pt
  • Only the backbone.vae.* keys were extracted from it (38 keys)
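
The extraction step can be sketched as a prefix filter over the checkpoint's state dict. This is a minimal sketch, not the script actually used to build vae_only.pt: the helper name extract_prefixed is hypothetical, and stripping the backbone.vae. prefix from the kept keys is an assumption (it is what would let the result load directly into a standalone VAE). The function works on any mapping, so with PyTorch you would pass it the dict returned by torch.load.

```python
from typing import Any, Dict, Mapping


def extract_prefixed(
    state_dict: Mapping[str, Any], prefix: str = "backbone.vae."
) -> Dict[str, Any]:
    """Keep only entries whose key starts with `prefix`, with the prefix removed.

    Hypothetical helper; stripping the prefix is an assumption, not confirmed
    by the source.
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Toy stand-in for a full checkpoint; real values would be tensors.
full_state = {
    "backbone.vae.enc.0.weight": "w0",
    "backbone.vae.dec.0.bias": "b0",
    "backbone.unet.in.weight": "w1",
}
vae_state = extract_prefixed(full_state)
print(sorted(vae_state))  # ['dec.0.bias', 'enc.0.weight']
```

For the real checkpoint, len(vae_state) should come out as 38 if the filter matches the extraction described above.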

Architecture

  • VAE encoder + decoder (4-stage ConvSC, hid_S=32, N_S=4)
  • Input: (B, 1, 128, 128) SEVIR VIL frame, values in [0, 1]
  • Latent: mean and log_var, each (B, 32, 32, 32) (spatial resolution downsampled to 1/4)
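
The latent shape follows from simple arithmetic: each spatial side is divided by 4, so 128 / 4 = 32. The sketch below computes the expected shape as a sanity check before wiring the VAE into another model. The helper latent_shape is hypothetical, the channel count of 32 is read off the shape listed above rather than from the STLDM code, and the divisibility check is an assumption about how the downsampling behaves.

```python
from typing import Tuple


def latent_shape(
    batch: int,
    height: int = 128,
    width: int = 128,
    latent_channels: int = 32,  # taken from the listed (B, 32, 32, 32) shape
    downsample: int = 4,        # spatial 1/4 downsampling
) -> Tuple[int, int, int, int]:
    """Return the expected shape of mean (and of log_var) for a given input size."""
    if height % downsample or width % downsample:
        raise ValueError("height and width must be divisible by the downsample factor")
    return (batch, latent_channels, height // downsample, width // downsample)


print(latent_shape(8))  # (8, 32, 32, 32)
```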

Loading example (based on the STLDM codebase this repo assumes)

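import torch
from huggingface_hub import hf_hub_download
from stldm.modules import VAE  # VAE class from the STLDM codebase

# Download the VAE-only checkpoint from the Hub (cached after the first call).
vae_path = hf_hub_download(
    repo_id="KyleBae1017/stldm-sevir-vae",
    filename="vae_only.pt",
)

# Build the VAE with the hyperparameters of the extracted checkpoint.
vae = VAE(C_in=1, hid_S=32, N_S=4)
state = torch.load(vae_path, map_location="cpu")

# strict=True raises on any key mismatch; the assert documents that expectation.
missing, unexpected = vae.load_state_dict(state, strict=True)
assert not missing and not unexpected

# Freeze the VAE so it acts as a fixed tokenizer.
vae.eval()
for p in vae.parameters():
    p.requires_grad_(False)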

Using it directly in train.py

First download the checkpoint with hf_hub_download as shown above and keep it locally:

python train.py ... --ae_ckpt <local_path> --ae_eval --freeze_vae ...
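
The download and the command line can be combined in a short Python script. This is a sketch under assumptions: the helper names and the ./checkpoints directory are hypothetical, and the arguments elided as ... in the command above are left to the caller through extra_args rather than guessed. hf_hub_download accepts a local_dir argument and returns the path of the downloaded file.

```python
import shlex
from typing import List, Sequence


def download_vae(local_dir: str = "./checkpoints") -> str:
    """Download vae_only.pt into `local_dir` and return its path."""
    # Imported lazily so the command builder works without huggingface_hub installed.
    from huggingface_hub import hf_hub_download

    return hf_hub_download(
        repo_id="KyleBae1017/stldm-sevir-vae",
        filename="vae_only.pt",
        local_dir=local_dir,
    )


def build_train_command(ckpt_path: str, extra_args: Sequence[str] = ()) -> List[str]:
    """Assemble the train.py invocation with the VAE flags from the command above."""
    return [
        "python", "train.py",
        *extra_args,  # the arguments elided as "..." above; supply your own
        "--ae_ckpt", ckpt_path,
        "--ae_eval",
        "--freeze_vae",
    ]


# Example with a placeholder path; call download_vae() first to obtain the real one.
cmd = build_train_command("checkpoints/vae_only.pt")
print(shlex.join(cmd))
```

Returning the command as a list keeps it safe to pass to subprocess.run without shell quoting issues.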