"""Model loading helpers for SATA inference."""


def load_model_by_type(model_type, model_epoch, device):
    """Load a VAE or RVQ checkpoint by model type.

    Dispatches to the ``prepare_model_test`` helper of the matching SATA
    test module and returns its result unchanged.

    Args:
        model_type: Either ``"vae"`` (uses ``sata.test``) or ``"rvq"``
            (uses ``sata.test_vq``). Matching is exact and case-sensitive.
        model_epoch: Checkpoint identifier forwarded to
            ``prepare_model_test``. It is printed as the location the model
            is loaded from, so it is presumably a checkpoint path (confirm
            against ``prepare_model_test``).
        device: Forwarded unchanged to ``prepare_model_test``.

    Returns:
        The ``(model, cfg, ms_dict)`` triple produced by the selected
        ``prepare_model_test``. The exact types and meaning of ``cfg`` and
        ``ms_dict`` are defined by that helper, not here.

    Raises:
        ValueError: If ``model_type`` is neither ``"vae"`` nor ``"rvq"``.
            This is raised before any ``sata`` module is imported.
    """
    # The imports are deferred and branch-specific: only the module for the
    # requested model type is loaded, and both modules expose a function
    # with the same name, ``prepare_model_test``.
    if model_type == "vae":
        from sata.test import prepare_model_test

        print(f"[Model] Loading VAE model from: {model_epoch}")
    elif model_type == "rvq":
        from sata.test_vq import prepare_model_test

        print(f"[Model] Loading RVQ model from: {model_epoch}")
    else:
        raise ValueError(
            f"Unknown model_type: {model_type}. Must be 'vae' or 'rvq'"
        )

    model, cfg, ms_dict = prepare_model_test(model_epoch, device)
    return model, cfg, ms_dict