SATA / src /sata /utils /model_loading.py
zzysteve
Initial commit
5221c8c
Raw
History Blame Contribute Delete
642 Bytes
"""Model loading helpers for SATA inference."""
def load_model_by_type(model_type, model_epoch, device):
    """Load a VAE or RVQ checkpoint using the loader that matches *model_type*.

    Args:
        model_type: Either ``"vae"`` or ``"rvq"`` (exact, case-sensitive match).
            ``"vae"`` uses ``sata.test.prepare_model_test``; ``"rvq"`` uses
            ``sata.test_vq.prepare_model_test``.
        model_epoch: Checkpoint identifier passed unchanged to
            ``prepare_model_test`` and echoed in the printed log line.
        device: Passed unchanged to ``prepare_model_test``.

    Returns:
        A ``(model, cfg, ms_dict)`` tuple unpacked from the result of
        ``prepare_model_test``.

    Raises:
        ValueError: If *model_type* is neither ``"vae"`` nor ``"rvq"``.
    """
    # Reject unsupported types up front; nothing is imported or printed then.
    if not (model_type == "vae" or model_type == "rvq"):
        raise ValueError(f"Unknown model_type: {model_type}. Must be 'vae' or 'rvq'")

    # Import inside the branch so only the selected loader module is imported.
    if model_type == "vae":
        from sata.test import prepare_model_test
        print(f"[Model] Loading VAE model from: {model_epoch}")
    else:
        from sata.test_vq import prepare_model_test
        print(f"[Model] Loading RVQ model from: {model_epoch}")

    # Unpacking into three names requires the loader to return exactly three items.
    loaded_model, loaded_cfg, loaded_ms_dict = prepare_model_test(model_epoch, device)
    return loaded_model, loaded_cfg, loaded_ms_dict