Spaces:
Sleeping
Sleeping
| import kaggle | |
| from pathlib import Path, PurePosixPath | |
| import json | |
# Kaggle credentials are loaded from a local ``__kaggle_login`` module that must
# expose a ``kaggle_users`` mapping; fail early with setup instructions if that
# module cannot be imported.
try:
    from __kaggle_login import kaggle_users
except ImportError as err:
    # The two message fragments need a separating space, otherwise the text
    # reads "kaggle_usersdict".
    raise ImportError(
        "Please create a __kaggle_login.py file with a kaggle_users "
        "dict containing your Kaggle credentials."
    ) from err
| import argparse | |
| import sys | |
| import subprocess | |
| from configuration import ROOT_DIR, OUTPUT_FOLDER_NAME | |
| from train import get_parser as get_train_parser | |
| from typing import Optional | |
| from configuration import KAGGLE_DATASET_LIST, NB_ID, GIT_USER, GIT_REPO, TRAIN_SCRIPT | |
def get_git_branch_name():
    """Return the name of the currently checked-out Git branch.

    Runs ``git branch --show-current`` in the current working directory.

    Returns:
        The command output decoded to ``str`` with surrounding whitespace
        stripped. If the command exits with a non-zero status, or the ``git``
        executable cannot be launched at all, a fixed error string is returned
        instead of raising.
    """
    try:
        raw_output = subprocess.check_output(["git", "branch", "--show-current"])
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: git ran but reported a failure.
        # OSError (e.g. FileNotFoundError): git could not be started, typically
        # because it is not installed. This helper is best-effort, so both
        # cases yield the same fallback value rather than crashing the caller.
        return "Error: Could not determine the Git branch name."
    return raw_output.strip().decode()
def prepare_notebook(
    output_nb_path: Path,
    exp: int,
    branch: str,
    git_user: Optional[str] = None,
    git_repo: Optional[str] = None,
    template_nb_path: Path = Path(__file__).parent/"remote_training_template.ipynb",
    wandb_flag: bool = False,
    output_dir: Path = "scripts/"+OUTPUT_FOLDER_NAME,
    dataset_files: Optional[list] = None,
    train_script: str = TRAIN_SCRIPT
):
    """Render the notebook template by filling in its ``!!!name!!!`` placeholders.

    The template is read line by line; every occurrence of ``!!!exp!!!``,
    ``!!!branch!!!``, ``!!!git_user!!!``, ``!!!git_repo!!!``,
    ``!!!wandb_flag!!!``, ``!!!output_dir!!!``, ``!!!dataset_files!!!`` and
    ``!!!train_script!!!`` is replaced by the text form of the matching
    argument, and the result is written to ``output_nb_path``.

    Args:
        output_nb_path: Destination file for the rendered notebook.
        exp: Experiment identifier, inserted via its ``str()`` form.
            NOTE(review): ``main`` passes the parsed ``--exp`` value, which it
            also iterates over, so this may be a list rather than an int.
        branch: Git branch name, inserted wrapped in single quotes.
        git_user: Git user name, inserted wrapped in single quotes. Required.
        git_repo: Git repository name, inserted wrapped in single quotes.
            Required.
        template_nb_path: Template notebook containing the placeholders.
        wandb_flag: Inserted as the text ``True`` or ``False``.
        output_dir: Inserted wrapped in single quotes, or as the text ``None``.
        dataset_files: Inserted via its ``str()`` form, or as the text ``None``.
        train_script: Inserted wrapped in single quotes.

    Raises:
        AssertionError: If ``git_user`` or ``git_repo`` is ``None``.
    """
    assert git_user is not None, "Please provide a git username for the repo"
    assert git_repo is not None, "Please provide a git repo name for the repo"

    # (placeholder name, replacement text) pairs. Values are inserted verbatim;
    # no escaping is applied to quotes inside them.
    substitutions = [
        ("exp", f"{exp}"),
        ("branch", f"'{branch}'"),
        ("git_user", f"'{git_user}'"),
        ("git_repo", f"'{git_repo}'"),
        ("wandb_flag", "True" if wandb_flag else "False"),
        ("output_dir", "None" if output_dir is None else f"'{output_dir}'"),
        ("dataset_files", "None" if dataset_files is None else f"{dataset_files}"),
        ("train_script", "'" + train_script + "'"),
    ]

    # Read and write explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(template_nb_path, encoding="utf-8") as template_file:
        template_lines = template_file.readlines()

    rendered_lines = []
    for line in template_lines:
        rendered = line
        for name, value in substitutions:
            placeholder = f"!!!{name}!!!"
            # Membership is tested on the untouched template line, so text
            # introduced by an earlier substitution is never treated as a
            # placeholder of its own.
            if placeholder in line:
                rendered = rendered.replace(placeholder, value)
        rendered_lines.append(rendered)

    with open(output_nb_path, "w", encoding="utf-8") as output_file:
        output_file.write("".join(rendered_lines))
def main(argv):
    """Command-line entry point: prepare, push or download a Kaggle training notebook.

    Without ``--download``, a kernel folder ``ROOT_DIR/__nb_<username>/<exp_str>``
    is created containing the rendered notebook and a ``kernel-metadata.json``
    file; with ``--push`` that folder is then pushed to Kaggle. With
    ``--download``, the kernel output is fetched into a temporary folder,
    ``output.tgz`` is extracted with ``tar`` (no target directory is given, so
    into the current working directory), and the temporary folder is removed.

    Args:
        argv: Command-line arguments, without the program name.

    NOTE(review): ``args.exp``, ``args.cpu`` and ``args.no_wandb`` are not
    declared here; they are presumably registered by ``get_train_parser``.
    """
    parser = argparse.ArgumentParser(description="Train a model on Kaggle using a script")
    parser.add_argument("-n", "--nb_id", type=str, help="Notebook name in kaggle", default=NB_ID)
    parser.add_argument("-u", "--user", type=str, help="Kaggle user", choices=list(kaggle_users.keys()))
    parser.add_argument("--branch", type=str, help="Git branch name", default=get_git_branch_name())
    parser.add_argument("-p", "--push", action="store_true", help="Push")
    parser.add_argument("-d", "--download", action="store_true", help="Download results")
    get_train_parser(parser)
    args = parser.parse_args(argv)

    # --user has no default, so omitting it would otherwise surface as an
    # opaque KeyError on the credentials lookup below.
    if args.user not in kaggle_users:
        parser.error(
            "argument -u/--user is required (choose from "
            + ", ".join(str(name) for name in kaggle_users)
            + ")"
        )

    nb_id = args.nb_id
    exp_str = "_".join(f"{exp:04d}" for exp in args.exp)
    kaggle_user = kaggle_users[args.user]
    uname_kaggle = kaggle_user["username"]
    kaggle.api._load_config(kaggle_user)

    if args.download:
        import shutil
        tmp_dir = ROOT_DIR/f"__tmp_{exp_str}"
        tmp_dir.mkdir(exist_ok=True, parents=True)
        try:
            kaggle.api.kernels_output_cli(f"{uname_kaggle}/{nb_id}", path=str(tmp_dir))
            # FIXME: relies on an external `tar` executable, which may be
            # missing on some systems (notably Windows).
            subprocess.run(["tar", "-xzf", tmp_dir/"output.tgz"])
        finally:
            # Remove the scratch folder even when download or extraction fails.
            shutil.rmtree(tmp_dir, ignore_errors=True)
        return

    kernel_root = ROOT_DIR/f"__nb_{uname_kaggle}"
    kernel_root.mkdir(exist_ok=True, parents=True)
    kernel_path = kernel_root/exp_str
    kernel_path.mkdir(exist_ok=True, parents=True)
    branch = args.branch
    notebook_path = (kernel_path/nb_id).with_suffix(".ipynb")

    # Contents of kernel-metadata.json; boolean settings are written as the
    # strings "true"/"false".
    config = {
        "id": str(PurePosixPath(f"{uname_kaggle}")/nb_id),
        "title": nb_id.lower(),
        "code_file": f"{nb_id}.ipynb",
        "language": "python",
        "kernel_type": "notebook",
        "is_private": "true",
        "enable_gpu": "true" if not args.cpu else "false",
        "enable_tpu": "false",
        "enable_internet": "true",
        "dataset_sources": KAGGLE_DATASET_LIST,
        "competition_sources": [],
        "kernel_sources": [],
        "model_sources": []
    }
    prepare_notebook(notebook_path, args.exp, branch,
                     git_user=GIT_USER, git_repo=GIT_REPO, wandb_flag=not args.no_wandb)
    assert notebook_path.exists()
    with open(kernel_path/"kernel-metadata.json", "w") as f:
        json.dump(config, f, indent=4)
    if args.push:
        kaggle.api.kernels_push_cli(str(kernel_path))
# Script entry point: hand the command-line arguments (program name excluded)
# to main() when this file is executed directly.
if __name__ == "__main__":
    main(sys.argv[1:])