| | from main import * |
| |
|
| |
|
def default_run():
    """Run the experiment configured by ``configs/main.yaml``.

    The config is loaded with ``workspace.load_config``. When the
    ``LOCAL_RANK`` environment variable is unset or equal to ``'0'``, the
    config is replaced by the result of ``workspace.create_workspace``
    before being handed to ``run_experiment``.
    """
    cfg = workspace.load_config("configs/main.yaml", None)

    # NOTE(review): presumably restricts workspace creation to the first
    # local process of a distributed launch -- confirm against the launcher.
    on_first_local_rank = os.getenv("LOCAL_RANK", '0') == '0'
    if on_first_local_rank:
        cfg = workspace.create_workspace(cfg)

    run_experiment(cfg)
| |
|
| |
|
def with_mast3r_loss():
    """Run the ablation configured by ``configs/with_mast3r_loss.yaml``.

    The config is loaded with ``workspace.load_config``. When the
    ``LOCAL_RANK`` environment variable is unset or equal to ``'0'``, the
    config is replaced by the result of ``workspace.create_workspace``
    before being handed to ``run_experiment``.
    """
    cfg = workspace.load_config("configs/with_mast3r_loss.yaml", None)

    # NOTE(review): presumably restricts workspace creation to the first
    # local process of a distributed launch -- confirm against the launcher.
    on_first_local_rank = os.getenv("LOCAL_RANK", '0') == '0'
    if on_first_local_rank:
        cfg = workspace.create_workspace(cfg)

    run_experiment(cfg)
| |
|
| |
|
def without_masking():
    """Run the ablation configured by ``configs/without_masking.yaml``.

    The config is loaded with ``workspace.load_config``. When the
    ``LOCAL_RANK`` environment variable is unset or equal to ``'0'``, the
    config is replaced by the result of ``workspace.create_workspace``
    before being handed to ``run_experiment``.
    """
    cfg = workspace.load_config("configs/without_masking.yaml", None)

    # NOTE(review): presumably restricts workspace creation to the first
    # local process of a distributed launch -- confirm against the launcher.
    on_first_local_rank = os.getenv("LOCAL_RANK", '0') == '0'
    if on_first_local_rank:
        cfg = workspace.create_workspace(cfg)

    run_experiment(cfg)
| |
|
| |
|
def without_lpips_loss():
    """Run the ablation configured by ``configs/without_lpips_loss.yaml``.

    The config is loaded with ``workspace.load_config``. When the
    ``LOCAL_RANK`` environment variable is unset or equal to ``'0'``, the
    config is replaced by the result of ``workspace.create_workspace``
    before being handed to ``run_experiment``.
    """
    cfg = workspace.load_config("configs/without_lpips_loss.yaml", None)

    # NOTE(review): presumably restricts workspace creation to the first
    # local process of a distributed launch -- confirm against the launcher.
    on_first_local_rank = os.getenv("LOCAL_RANK", '0') == '0'
    if on_first_local_rank:
        cfg = workspace.create_workspace(cfg)

    run_experiment(cfg)
| |
|
| |
|
def without_offset():
    """Run the ablation configured by ``configs/without_offset.yaml``.

    The config is loaded with ``workspace.load_config``. When the
    ``LOCAL_RANK`` environment variable is unset or equal to ``'0'``, the
    config is replaced by the result of ``workspace.create_workspace``
    before being handed to ``run_experiment``.
    """
    cfg = workspace.load_config("configs/without_offset.yaml", None)

    # NOTE(review): presumably restricts workspace creation to the first
    # local process of a distributed launch -- confirm against the launcher.
    on_first_local_rank = os.getenv("LOCAL_RANK", '0') == '0'
    if on_first_local_rank:
        cfg = workspace.create_workspace(cfg)

    run_experiment(cfg)
| |
|
| |
|
def _select_ablation(argv, namespace, module_name):
    """Return the ablation callable named by the first command-line argument.

    Args:
        argv: Command-line vector in ``sys.argv`` layout (program name first).
        namespace: Mapping of names to objects to search, e.g. ``globals()``.
        module_name: ``__module__`` value an object must carry to count as an
            ablation, so names that merely arrived through an import are
            never dispatched.

    Returns:
        The matching callable.

    Raises:
        SystemExit: If no ablation name was supplied.
        NotImplementedError: If the name does not refer to a public callable
            defined in ``module_name``.
    """
    if len(argv) < 2:
        # Exit with a usage hint instead of an IndexError traceback.
        program = argv[0] if argv else "ablations"
        raise SystemExit(f"usage: {program} <ablation_name>")

    ablation_name = argv[1]
    candidate = namespace.get(ablation_name)

    # A plain truthiness check would also accept modules, constants and
    # helpers imported via ``from main import *``; only public callables
    # that were defined in this module are valid ablations.
    is_ablation = (
        not ablation_name.startswith("_")
        and callable(candidate)
        and getattr(candidate, "__module__", None) == module_name
    )
    if not is_ablation:
        raise NotImplementedError(
            f"Ablation name '{ablation_name}' not recognised")
    return candidate


if __name__ == "__main__":
    # Dispatch to the ablation function whose name was given on the command
    # line, e.g. ``python <this file> without_masking``.
    _select_ablation(sys.argv, globals(), __name__)()
| |
|