| from utils.interpolation_models.EDEN import EDEN |
| from utils.interpolation_models.EDEN_VAE import VAE |
| from utils.interpolation_models.EDEN_DiT import DiT |
| from utils.interpolation_models.Discriminator import NLayerDiscriminator |
|
|
|
|
def load_model(model_name: str, **model_args):
    """Build and return the model selected by *model_name*.

    Args:
        model_name: One of ``"EDEN"``, ``"EDEN_VAE"``, ``"EDEN_DiT"`` or
            ``"Discriminator"``.
        **model_args: Keyword arguments forwarded unchanged to the selected
            model's constructor.

    Returns:
        A new instance of ``EDEN``, ``VAE``, ``DiT`` or
        ``NLayerDiscriminator`` respectively.

    Raises:
        ValueError: If *model_name* is not one of the supported names.
    """
    if model_name == "EDEN":
        return EDEN(**model_args)
    if model_name == "EDEN_VAE":
        return VAE(**model_args)
    if model_name == "EDEN_DiT":
        return DiT(**model_args)
    if model_name == "Discriminator":
        return NLayerDiscriminator(**model_args)
    # Raising a bare string is itself a TypeError in Python 3 ("exceptions
    # must derive from BaseException"), which would hide the real message,
    # so raise a proper exception instance instead.
    raise ValueError(f"No model named {model_name} in models!")
|
|