Spaces:
Runtime error
Runtime error
import os
import sys

from main import *  # noqa: F401,F403 -- expected to provide `workspace` and `run_experiment`
def _run_ablation(config_location):
    """Load a config, set up the workspace, and run the experiment.

    Shared implementation behind every ablation entry point below. Each entry
    point differs only in which YAML config file it runs.

    Args:
        config_location: Path to the YAML config describing the run.

    Workspace setup (e.g. creating a directory for results at
    ``config.save_dir``) only happens when the ``LOCAL_RANK`` environment
    variable is ``"0"``. That is also the default when the variable is unset.
    Presumably this keeps a multi-process launch from setting up the workspace
    more than once (TODO: confirm against the launcher). Every process runs
    the experiment regardless of rank.
    """
    config = workspace.load_config(config_location, None)
    if os.getenv("LOCAL_RANK", "0") == "0":
        config = workspace.create_workspace(config)
    run_experiment(config)


def default_run():
    """Run the default experiment configured by ``configs/main.yaml``."""
    _run_ablation("configs/main.yaml")


def with_mast3r_loss():
    """Run the ablation configured by ``configs/with_mast3r_loss.yaml``."""
    _run_ablation("configs/with_mast3r_loss.yaml")


def without_masking():
    """Run the ablation configured by ``configs/without_masking.yaml``."""
    _run_ablation("configs/without_masking.yaml")


def without_lpips_loss():
    """Run the ablation configured by ``configs/without_lpips_loss.yaml``."""
    _run_ablation("configs/without_lpips_loss.yaml")


def without_offset():
    """Run the ablation configured by ``configs/without_offset.yaml``."""
    _run_ablation("configs/without_offset.yaml")
if __name__ == "__main__":
    # Explicit registry of the ablations that may be launched from the command
    # line. The previous `locals().get(name)` lookup resolved *any* module-level
    # name, including everything brought in by `from main import *`. A typo or a
    # name like `os` would then be "found" and fail with a confusing TypeError.
    ablations = {
        "default_run": default_run,
        "with_mast3r_loss": with_mast3r_loss,
        "without_masking": without_masking,
        "without_lpips_loss": without_lpips_loss,
        "without_offset": without_offset,
    }

    # Fail with a usage message instead of an IndexError when no ablation
    # name is supplied.
    if len(sys.argv) < 2:
        sys.exit(
            f"Usage: {sys.argv[0]} <ablation_name>\n"
            f"Available ablations: {', '.join(ablations)}"
        )

    ablation_name = sys.argv[1]
    ablation_function = ablations.get(ablation_name)
    if ablation_function is None:
        raise NotImplementedError(
            f"Ablation name '{ablation_name}' not recognised. "
            f"Available ablations: {', '.join(ablations)}"
        )

    # Run the selected ablation.
    ablation_function()