| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import datetime |
|
|
| from google.cloud import storage |
|
|
| from transformers import AutoTokenizer |
| from datasets import load_dataset, load_metric, ReadInstruction |
| from trainer import metadata |
|
|
|
|
def preprocess_function(examples):
    """Tokenize a batch of examples and map string labels to integer ids.

    Intended to be passed to ``datasets.Dataset.map(batched=True)``.

    Args:
        examples: A batch dict with a ``'text'`` column (list of strings)
            and optionally a ``'label'`` column.

    Returns:
        A dict of tokenizer outputs (``input_ids``, ``attention_mask``, ...),
        plus ``'label'`` converted to integer ids when a label mapping is
        configured in ``metadata.TARGET_LABELS``.
    """
    # The tokenizer is expensive to build (loads vocab/model files), and
    # Dataset.map(batched=True) invokes this function once per batch.
    # Cache it on the function object so it is constructed only once
    # per process instead of once per batch.
    tokenizer = getattr(preprocess_function, "_tokenizer", None)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(
            metadata.PRETRAINED_MODEL_NAME,
            use_fast=True,
        )
        preprocess_function._tokenizer = tokenizer

    result = tokenizer(
        examples["text"],
        padding="max_length",
        max_length=metadata.MAX_SEQ_LENGTH,
        truncation=True,
    )

    # Map string labels to integer ids when a mapping is defined and the
    # batch actually carries labels (e.g. absent for unlabeled test data).
    label_to_id = metadata.TARGET_LABELS
    if label_to_id is not None and "label" in examples:
        result["label"] = [label_to_id[l] for l in examples["label"]]

    return result
|
|
|
|
def load_data(args):
    """Load and tokenize the dataset, returning the train and test splits.

    Args:
        args: arguments passed to the python script (not used here; kept
            for a consistent trainer entry-point signature).

    Returns:
        A ``(train_dataset, test_dataset)`` tuple of tokenized splits.
    """
    raw = load_dataset(metadata.DATASET_NAME)

    # Tokenize every split up front; reuse the on-disk cache when present.
    tokenized = raw.map(
        preprocess_function,
        batched=True,
        load_from_cache_file=True,
    )

    return tokenized["train"], tokenized["test"]
|
|
|
|
def save_model(args):
    """Saves the model to Google Cloud Storage or local file system

    Args:
        args: contains name for saved model.
    """
    gcs_scheme = 'gs://'

    # Non-GCS job_dir: the model files already live under /tmp, so there
    # is nothing to upload — just tell the user where they are.
    if not args.job_dir.startswith(gcs_scheme):
        print(f"Saved model files at {os.path.join('/tmp', args.model_name)}")
        print(f"To save model files in GCS bucket, please specify job_dir starting with gs://")
        return

    # gs://<bucket>/<prefix...> -> bucket name and object prefix.
    parts = args.job_dir.split("/")
    bucket_name = parts[2]
    object_prefix = "/".join(parts[3:]).rstrip("/")
    model_path = (
        f"{object_prefix}/{args.model_name}" if object_prefix
        else f"{args.model_name}"
    )

    bucket = storage.Client().bucket(bucket_name)
    local_path = os.path.join("/tmp", args.model_name)

    # Upload every regular file in the local model directory.
    for entry in os.listdir(local_path):
        local_file = os.path.join(local_path, entry)
        if not os.path.isfile(local_file):
            continue
        bucket.blob("/".join([model_path, entry])).upload_from_filename(local_file)
    print(f"Saved model files in gs://{bucket_name}/{model_path}")
| |
|
|