|
|
from sagemaker.huggingface import HuggingFace |
|
|
import logging |
|
|
import sys |
|
|
from contextlib import contextmanager |
|
|
from io import StringIO |
|
|
from streamlit.report_thread import REPORT_CONTEXT_ATTR_NAME |
|
|
from threading import current_thread |
|
|
import streamlit as st |
|
|
import sys |
|
|
import sagemaker |
|
|
import boto3 |
|
|
|
|
|
|
|
|
@contextmanager
def st_redirect(src, dst):
    """Mirror writes on the stream *src* into a Streamlit widget.

    *dst* names a method on an ``st.empty()`` placeholder (e.g. ``"code"``,
    ``"text"``) used to render the accumulated output. On exit the stream's
    original ``write`` is restored.
    """
    slot = st.empty()
    render = getattr(slot, dst)

    with StringIO() as captured:
        original_write = src.write

        def patched_write(data):
            # Writes coming from a Streamlit script thread (detected via the
            # report-context attribute) are captured and re-rendered in full;
            # anything else falls through to the real stream untouched.
            if getattr(current_thread(), REPORT_CONTEXT_ATTR_NAME, None):
                captured.write(data)
                render(captured.getvalue())
            else:
                original_write(data)

        try:
            src.write = patched_write
            yield
        finally:
            # Always undo the monkey-patch, even if the body raised.
            src.write = original_write
|
|
|
|
|
|
|
|
@contextmanager
def st_stdout(dst):
    """Redirect ``sys.stdout`` into the Streamlit widget named by *dst*."""
    with st_redirect(sys.stdout, dst):
        yield
|
|
|
|
|
|
|
|
@contextmanager
def st_stderr(dst):
    """Redirect ``sys.stderr`` into the Streamlit widget named by *dst*."""
    with st_redirect(sys.stderr, dst):
        yield
|
|
|
|
|
|
|
|
# Maps a supported training task name to the Hugging Face `transformers`
# example script to run (`entry_point`) and the directory that contains it
# (`source_dir`), relative to the transformers repository checkout cloned
# by the SageMaker estimator (see `git_config` in `train_estimtator`).
task2script = {
    "text-classification": {
        "entry_point": "run_glue.py",
        "source_dir": "examples/text-classification",
    },
    "token-classification": {
        "entry_point": "run_ner.py",
        "source_dir": "examples/token-classification",
    },
    "question-answering": {
        "entry_point": "run_qa.py",
        "source_dir": "examples/question-answering",
    },
    "summarization": {
        "entry_point": "run_summarization.py",
        "source_dir": "examples/seq2seq",
    },
    "translation": {
        "entry_point": "run_translation.py",
        "source_dir": "examples/seq2seq",
    },
    "causal-language-modeling": {
        "entry_point": "run_clm.py",
        "source_dir": "examples/language-modeling",
    },
    "masked-language-modeling": {
        "entry_point": "run_mlm.py",
        "source_dir": "examples/language-modeling",
    },
}
|
|
|
|
|
|
|
|
# NOTE(review): the name typo ("estimtator") is kept for backward
# compatibility with existing callers.
def train_estimtator(parameter, config):
    """Launch a SageMaker Hugging Face training job, streaming logs to Streamlit.

    Args:
        parameter: Hyperparameters for the training script. Must contain a
            ``"task"`` key selecting an entry in ``task2script``; the key is
            removed before the rest is forwarded to the estimator.
        config: AWS and job settings — ``aws_access_key_id``,
            ``aws_secret_accesskey``, ``region``, ``aws_sagemaker_role``,
            ``instance_type`` (a "|"-separated string; the second field is
            the ML instance type), ``instance_count``, ``job_name``.

    Side effects: creates boto3/sagemaker sessions, resolves the IAM role,
    and blocks on ``HuggingFace.fit()`` until the training job finishes.
    """
    with st_stdout("code"):
        logger = logging.getLogger(__name__)

        logging.basicConfig(
            level=logging.getLevelName("INFO"),
            handlers=[logging.StreamHandler(sys.stdout)],
            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        )
        # HACK: route logger.info through print so st_stdout captures it.
        logger.info = print

        # Pin the examples to the release matching the training container
        # versions configured on the estimator below (transformers 4.4).
        git_config = {"repo": "https://github.com/huggingface/transformers.git", "branch": "v4.4.2"}

        # Hoist the repeated task lookup (original indexed task2script twice).
        script = task2script[parameter["task"]]
        entry_point = script["entry_point"]
        source_dir = script["source_dir"]

        session = boto3.session.Session(
            aws_access_key_id=config["aws_access_key_id"],
            aws_secret_access_key=config["aws_secret_accesskey"],
            region_name=config["region"],
        )
        sess = sagemaker.Session(boto_session=session)

        iam = session.client(
            "iam", aws_access_key_id=config["aws_access_key_id"], aws_secret_access_key=config["aws_secret_accesskey"]
        )
        role = iam.get_role(RoleName=config["aws_sagemaker_role"])["Role"]["Arn"]

        logger.info(f"role: {role}")
        # Take the second "|"-separated field and strip padding. The original
        # chained a second .split("|")[0], which is a no-op on an element
        # already produced by splitting on "|".
        instance_type = config["instance_type"].split("|")[1].strip()
        logger.info(f"instance_type: {instance_type}")

        hyperparameters = {
            "output_dir": "/opt/ml/model",
            "do_train": True,
            "do_eval": True,
            "do_predict": True,
            **parameter,
        }
        # "task" only selects the script; it is not a script hyperparameter.
        del hyperparameters["task"]

        huggingface_estimator = HuggingFace(
            entry_point=entry_point,
            source_dir=source_dir,
            git_config=git_config,
            base_job_name=config["job_name"],
            instance_type=instance_type,
            sagemaker_session=sess,
            instance_count=config["instance_count"],
            role=role,
            transformers_version="4.4",
            pytorch_version="1.6",
            py_version="py36",
            hyperparameters=hyperparameters,
        )

        # Blocks until the remote training job completes.
        huggingface_estimator.fit()
|
|
|