| import argparse | |
| import os | |
| import shutil | |
| from huggingface_hub import snapshot_download | |
# Model variants that can be requested with --model. Each entry is also the
# sub-folder path used to select and locate that variant's files after download.
available_models = [
    "marigold_appearance/finetuned",
    "marigold_appearance/pretrained",
    "marigold_lighting/finetuned",
    "marigold_lighting/pretrained",
    "rgbx/finetuned",
    "rgbx/pretrained",
]
def _is_protected_dir(path):
    """Return True when recursively deleting *path* would be destructive.

    A path is protected when, after resolving symlinks, it is the filesystem
    root, the user's home directory, the current working directory, or an
    ancestor of the home or working directory.
    """
    target = os.path.normcase(os.path.realpath(path))
    # The trailing separator keeps "/home/al" from matching "/home/alice".
    ancestor_prefix = target.rstrip(os.sep) + os.sep
    for critical in (os.path.expanduser("~"), os.getcwd()):
        resolved = os.path.normcase(os.path.realpath(critical))
        if resolved == target or resolved.startswith(ancestor_prefix):
            return True
    # A filesystem root is its own parent directory.
    return os.path.dirname(target) == target


def _flatten_model_dir(local_dir, model):
    """Move the files of *model* out of their nested folder into *local_dir*.

    Raises:
        FileNotFoundError: if the nested model folder is not present, i.e.
            the download produced no files for the requested model.
    """
    parts = model.split("/")
    nested = os.path.join(local_dir, *parts)
    if not os.path.isdir(nested):
        raise FileNotFoundError(
            f"Expected downloaded files in {nested!r}, but that directory "
            "does not exist."
        )
    for name in os.listdir(nested):
        shutil.move(os.path.join(nested, name), os.path.join(local_dir, name))
    # The now-empty "<family>/<variant>" folders are no longer needed.
    shutil.rmtree(os.path.join(local_dir, parts[0]), ignore_errors=True)


def main(argv=None):
    """Download one model variant and place its files directly in --local_dir.

    Any existing --local_dir is removed first, unless it is a protected
    location. The files of the selected model are then downloaded from the
    "GDAOSU/olbedo" repository, moved from "<local_dir>/<model>/" up into
    "<local_dir>/", and the leftover model folder and ".cache" folder are
    deleted.

    Args:
        argv: Optional list of command-line arguments. When None (the
            default), argparse reads the process arguments.

    Raises:
        ValueError: if --local_dir exists and is a protected directory
            (filesystem root, home, working directory, or an ancestor of
            the home or working directory).
        FileNotFoundError: if no files were downloaded for the model.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        default="rgbx/finetuned",
        choices=available_models,
        help="Select model to download (default: rgbx/finetuned)",
    )
    parser.add_argument(
        "--local_dir",
        type=str,
        default="checkpoint",
        help="Directory to save the model",
    )
    args = parser.parse_args(argv)

    local_dir = args.local_dir
    selected_model = args.model

    if os.path.exists(local_dir):
        # rmtree is irreversible, so refuse anything that would take the
        # user's home or working directory with it.
        if _is_protected_dir(local_dir):
            raise ValueError("Refusing to delete critical directory.")
        print(f"Removing existing directory: {local_dir}")
        shutil.rmtree(local_dir)

    print(f"Downloading model: {selected_model}")
    # NOTE(review): local_dir_use_symlinks appears to be deprecated in newer
    # huggingface_hub releases - confirm against the version in use.
    snapshot_download(
        repo_id="GDAOSU/olbedo",
        allow_patterns=f"{selected_model}/*",
        local_dir=local_dir,
        local_dir_use_symlinks=False,
    )

    _flatten_model_dir(local_dir, selected_model)
    shutil.rmtree(os.path.join(local_dir, ".cache"), ignore_errors=True)
# Run the downloader only when this file is executed as a script.
if __name__ == "__main__":
    main()