"""Launch SageMaker training jobs for Whisper fine-tuning."""
|
|
| import logging |
| import os |
| from pprint import pprint |
|
|
| import boto3 |
| import sagemaker |
| from sagemaker.huggingface import HuggingFace |
|
|
|
|
# Switch between the small smoke-test configuration (True) and the full run (False).
TEST = True

# Candidate instance types keyed by SageMaker instance-type name. The launch
# loop at the end of the script tries them in insertion order; "num_gpus" is
# read there but not passed on to the estimator.
test_sm_instances = {
    "ml.g4dn.xlarge": dict(num_instances=1, num_gpus=1),
}

full_sm_instances = {
    "ml.g4dn.xlarge": dict(num_instances=1, num_gpus=1),
}

# Pick the instance table and the shell script that go with the selected mode.
if TEST:
    sm_instances = test_sm_instances
    RUN_SCRIPT = "test_run.sh"
else:
    sm_instances = full_sm_instances
    RUN_SCRIPT = "run.sh"

# Script handed to the HuggingFace estimator as its `entry_point`.
ENTRY_POINT = "run_sm.py"

# Custom training image hosted in ECR (eu-west-1). NOTE(review): the None
# check presumably exists for a template version of this script in which
# IMAGE_URI starts out unset — it cannot trigger with the literal below.
IMAGE_URI = (
    "116817510867.dkr.ecr.eu-west-1.amazonaws.com/"
    "huggingface-pytorch-training:"
    "whisper-finetuning-0223e276db78adf4ea4dc5f874793cb2"
)
if IMAGE_URI is None:
    raise ValueError("IMAGE_URI variable not set, please update script.")
|
|
# Pin the AWS region before the first client or session is created, so every
# client below is built for eu-west-1 — the region of the ECR image in
# IMAGE_URI. (The region used to be set only after the IAM client existed.)
os.environ["AWS_DEFAULT_REGION"] = "eu-west-1"
iam = boto3.client("iam")
# ARN of the IAM role that is passed as `role` to the training estimator.
role = iam.get_role(RoleName="whisper-sagemaker-role")["Role"]["Arn"]
# NOTE(review): the Session object is discarded; presumably it is created only
# for its initialisation side effects — confirm whether it is still needed.
_ = sagemaker.Session()
sm_client = boto3.client("sagemaker")
|
|
|
|
def set_creds(path="creds.txt"):
    """Load ``KEY=VALUE`` credentials from a text file into ``os.environ``.

    Each line is split on its FIRST ``=``, so values may themselves contain
    ``=`` (e.g. base64 tokens). The key is whitespace-stripped; the value is
    kept verbatim apart from the line ending. Blank lines and lines starting
    with ``#`` are ignored.

    Args:
        path: Credentials file to read. Defaults to ``creds.txt`` in the
            current working directory.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a non-blank, non-comment line contains no ``=``.
    """
    with open(path, encoding="utf-8") as f:
        for lineno, raw_line in enumerate(f, start=1):
            line = raw_line.rstrip("\r\n")
            stripped = line.strip()
            if not stripped or stripped.startswith("#"):
                continue
            key, sep, value = line.partition("=")
            if not sep:
                # Do not echo the line itself: it may contain a secret.
                raise ValueError(f"{path}:{lineno}: expected a KEY=VALUE line")
            os.environ[key.strip()] = value
|
|
|
|
def parse_run_script(path=None):
    """Parse a shell run script into a hyperparameter dict.

    Each line is expected to hold one ``--name=value`` or ``--name`` argument,
    optionally ending in a ``\\`` line continuation. Blank lines, lines
    starting with ``#`` (comments, shebang) and lines starting with
    ``python`` (the command itself) are skipped. Leading tabs or spaces are
    ignored.

    Args:
        path: Script to read. Defaults to the module-level ``RUN_SCRIPT``.

    Returns:
        A dict mapping argument names (leading ``--`` removed) to string
        values. Double quotes are stripped from every line, and valueless
        flags map to ``"True"``. ``model_index_name``, if present, is
        re-wrapped in double quotes (presumably so a value containing spaces
        stays a single argument downstream — TODO confirm).
    """
    if path is None:
        path = RUN_SCRIPT
    hyperparameters = {}
    with open(path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith(("#", "python")):
                continue
            # Drop only the trailing continuation backslash, not backslashes in values.
            if line.endswith("\\"):
                line = line[:-1].rstrip()
            line = line.replace('"', "")
            # Remove only the leading option marker so "--" inside a value survives.
            if line.startswith("--"):
                line = line[2:]
            # Split on the first "=" so values that contain "=" stay whole.
            key, sep, value = line.partition("=")
            hyperparameters[key] = value if sep else "True"
    if "model_index_name" in hyperparameters:
        hyperparameters["model_index_name"] = f'"{hyperparameters["model_index_name"]}"'
    return hyperparameters
|
|
|
|
# Load KEY=VALUE credentials from creds.txt into os.environ.
set_creds()

# HF_TOKEN is mandatory; an empty value is treated the same as a missing one.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise ValueError("HF_TOKEN environment variable not set")

# Environment forwarded to the training container. Optional variables that are
# unset are left out (with a warning) instead of being sent as None: the
# SageMaker training-job Environment map only accepts string values.
env_vars = {"HF_TOKEN": hf_token}
for _env_name in ("EMAIL_ADDRESS", "EMAIL_PASSWORD", "WANDB_TOKEN"):
    _env_value = os.environ.get(_env_name)
    if _env_value is None:
        logging.warning(
            "%s not set; it will not be passed to the training job", _env_name
        )
    else:
        env_vars[_env_name] = _env_value
# Show which variables are forwarded without printing the secret values.
pprint({name: "<set>" for name in env_vars})

# Hugging Face repo named after the current working directory. It is used as
# the estimator's git_config source and passed to the entry point as a
# hyperparameter.
repo = f"https://huggingface.co/marinone94/{os.path.basename(os.getcwd())}"
hyperparameters = {
    "repo": repo,
    "entrypoint": RUN_SCRIPT,
}
# Launch the training job on the first candidate instance type SageMaker
# accepts. A ResourceLimitExceeded rejection (a SageMaker resource limit such
# as an instance quota) moves on to the next candidate; any other error aborts.
for sm_instance_name, sm_instance_values in sm_instances.items():
    num_instances: int = int(sm_instance_values["num_instances"])
    num_gpus: int = int(sm_instance_values["num_gpus"])  # not used by the launch below
    hf_estimator = HuggingFace(
        entry_point=ENTRY_POINT,
        instance_type=sm_instance_name,
        instance_count=num_instances,
        role=role,
        py_version="py38",
        image_uri=IMAGE_URI,
        hyperparameters=hyperparameters,
        environment=env_vars,
        git_config={"repo": repo, "branch": "main"},
    )
    try:
        hf_estimator.fit()
    except sm_client.exceptions.ClientError as e_0:
        # Match on the error code rather than on sm_client's modeled
        # ResourceLimitExceeded class. The estimator is not built with
        # sm_client, and modeled exception classes are generated per botocore
        # session, so that class is not guaranteed to match. Every modeled
        # service error subclasses ClientError.
        error_code = e_0.response.get("Error", {}).get("Code")
        if error_code != "ResourceLimitExceeded":
            raise
        logging.warning("Instance error %s\nRetrying with new instance", e_0)
    else:
        break
else:
    # Reached only when no candidate succeeded, i.e. no training job was started.
    raise RuntimeError(
        "Could not start the training job on any instance type in sm_instances"
    )
|
|