| from main import * |
|
|
|
|
def _run_ablation(config_location):
    """Load an experiment config, prepare the workspace on rank 0, and run it.

    Shared body of every ablation entry point in this file; they differ only
    in which YAML config they point at.

    Args:
        config_location: Path string to the experiment config, e.g.
            ``"configs/main.yaml"``. How it is resolved is up to
            ``workspace.load_config``.
    """
    # Explicit import instead of relying on `os` leaking through
    # `from main import *`.
    import os

    # The second argument is passed as None exactly as before; its meaning
    # is defined by workspace.load_config.
    config = workspace.load_config(config_location, None)

    # Only the process whose LOCAL_RANK is "0" calls create_workspace.
    # LOCAL_RANK defaults to "0" when unset, so a single-process launch
    # also creates the workspace. Other ranks run with the config as loaded.
    # NOTE(review): presumably this avoids several processes of a
    # distributed launch creating the workspace concurrently. If
    # create_workspace changes fields the experiment depends on, non-zero
    # ranks will not see those changes — confirm that is intended.
    if os.getenv("LOCAL_RANK", "0") == "0":
        config = workspace.create_workspace(config)

    run_experiment(config)


def default_run():
    """Run the experiment defined by ``configs/main.yaml``."""
    _run_ablation("configs/main.yaml")


def with_mast3r_loss():
    """Run the ablation defined by ``configs/with_mast3r_loss.yaml``."""
    _run_ablation("configs/with_mast3r_loss.yaml")


def without_masking():
    """Run the ablation defined by ``configs/without_masking.yaml``."""
    _run_ablation("configs/without_masking.yaml")


def without_lpips_loss():
    """Run the ablation defined by ``configs/without_lpips_loss.yaml``."""
    _run_ablation("configs/without_lpips_loss.yaml")


def without_offset():
    """Run the ablation defined by ``configs/without_offset.yaml``."""
    _run_ablation("configs/without_offset.yaml")
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: `python <this file> <ablation_name>` runs the
    # ablation function of that name.
    import sys

    # Explicit registry of selectable ablations. The previous
    # `locals().get(name)` lookup also resolved any other module-level name
    # (everything brought in by `from main import *`, `ablation_name`
    # itself, ...) and would try to call it.
    ablations = {
        "default_run": default_run,
        "with_mast3r_loss": with_mast3r_loss,
        "without_masking": without_masking,
        "without_lpips_loss": without_lpips_loss,
        "without_offset": without_offset,
    }

    # Fail with a usage message instead of a bare IndexError when no
    # ablation name is given.
    if len(sys.argv) < 2:
        raise SystemExit(
            f"usage: {sys.argv[0]} <ablation_name>\n"
            f"available ablations: {', '.join(ablations)}")

    ablation_name = sys.argv[1]
    ablation_function = ablations.get(ablation_name)

    if ablation_function is None:
        # Same exception type and message as before for unknown names.
        raise NotImplementedError(
            f"Ablation name '{ablation_name}' not recognised")

    ablation_function()
|
|