| """Model loading helpers for SATA inference.""" | |
def load_model_by_type(model_type, model_epoch, device):
    """Load an inference model for the requested architecture family.

    The matching ``prepare_model_test`` is imported inside the function, so
    only the selected backend's module is ever loaded.

    Args:
        model_type: ``"vae"`` selects ``sata.test.prepare_model_test``;
            ``"rvq"`` selects ``sata.test_vq.prepare_model_test``.
        model_epoch: Checkpoint identifier, forwarded unchanged to
            ``prepare_model_test`` and echoed in the log line.
        device: Forwarded unchanged to ``prepare_model_test``.

    Returns:
        The ``(model, cfg, ms_dict)`` triple produced by the selected
        ``prepare_model_test``.

    Raises:
        ValueError: If ``model_type`` is neither ``"vae"`` nor ``"rvq"``.
    """
    # Reject unsupported types before importing anything.
    if not (model_type == "vae" or model_type == "rvq"):
        raise ValueError(
            f"Unknown model_type: {model_type}. Must be 'vae' or 'rvq'"
        )

    if model_type == "vae":
        from sata.test import prepare_model_test
        family = "VAE"
    else:
        from sata.test_vq import prepare_model_test
        family = "RVQ"
    print(f"[Model] Loading {family} model from: {model_epoch}")

    # Unpack explicitly so a loader returning the wrong number of items fails here.
    loaded = prepare_model_test(model_epoch, device)
    model, cfg, ms_dict = loaded
    return model, cfg, ms_dict