|
|
import streamlit as st |
|
|
from utils.load_dataset import load_datasets |
|
|
from utils.load_tasks import load_tasks |
|
|
from utils.load_models import load_models |
|
|
from trainer import train_estimtator |
|
|
from datetime import datetime |
|
|
import logging |
|
|
|
|
|
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
def main():
    """Render the Streamlit page for launching a Hugging Face training job on SageMaker.

    Widget state is mirrored into the URL query parameters after every input so
    a fully configured page can be shared or reloaded; on load, the query
    parameters seed the widget defaults.
    """
    # Query params come back as {key: [str, ...]}; seed any missing keys so
    # every lookup below can rely on a one-element list being present.
    parameter = st.experimental_get_query_params()
    parameter["model_name_or_path"] = parameter.get("model_name_or_path", ["none"])
    parameter["dataset"] = parameter.get("dataset", ["none"])
    parameter["task"] = parameter.get("task", ["none"])

    parameter["epochs"] = parameter.get("epochs", [3])
    parameter["learning_rate"] = parameter.get("learning_rate", [5e-5])
    parameter["per_device_train_batch_size"] = parameter.get("per_device_train_batch_size", [8])
    parameter["per_device_eval_batch_size"] = parameter.get("per_device_eval_batch_size", [8])
    st.experimental_set_query_params(**parameter)

    # Option lists for the selectboxes (project-local loaders).
    dataset_list = load_datasets()
    task_list = load_tasks()
    model_list = load_models()

    st.header("Hugging Face model & dataset")
    col1, col2 = st.beta_columns(2)
    # The current query-param value is prepended so it renders as the default
    # selection (selectbox defaults to index 0).
    parameter["model_name_or_path"] = col1.selectbox("Model ID:", parameter["model_name_or_path"] + model_list)
    st.experimental_set_query_params(**parameter)

    parameter["dataset"] = col2.selectbox("Dataset:", parameter["dataset"] + dataset_list)
    st.experimental_set_query_params(**parameter)

    parameter["task"] = col1.selectbox("Task:", parameter["task"] + task_list)
    st.experimental_set_query_params(**parameter)

    # Token is kept out of the query params; it is only attached to `parameter`
    # right before job submission (see below).
    use_auth_token = col2.text_input("HF auth token to upload your model:", help="api_xxxxx")

    my_expander = st.beta_expander("Hyperparameters")
    col1, col2 = my_expander.beta_columns(2)
    # FIX: number_input's second positional argument is min_value, not the
    # default value — passing 3/8 positionally silently forbade any smaller
    # entry. Use value= instead, seeded from the query parameter so a shared
    # URL restores the configured hyperparameters.
    parameter["epochs"] = col1.number_input("Epoch", value=int(parameter["epochs"][0]))
    st.experimental_set_query_params(**parameter)

    parameter["learning_rate"] = col2.text_input("Learning Rate", str(parameter["learning_rate"][0]))
    st.experimental_set_query_params(**parameter)

    parameter["per_device_train_batch_size"] = col1.number_input(
        "Training Batch Size", value=int(parameter["per_device_train_batch_size"][0])
    )
    st.experimental_set_query_params(**parameter)

    parameter["per_device_eval_batch_size"] = col2.number_input(
        "Eval Batch Size", value=int(parameter["per_device_eval_batch_size"][0])
    )
    st.experimental_set_query_params(**parameter)
    st.markdown("---")

    st.header("Amazon Sagemaker configuration")

    config = {}

    # Default job name: "<model>-job-YYYY-MM-DD". The query-param value may be
    # a list (fresh page load) or a scalar (after the selectbox above ran).
    model_name = (
        parameter["model_name_or_path"][0]
        if isinstance(parameter["model_name_or_path"], list)
        else parameter["model_name_or_path"]
    )
    config["job_name"] = st.text_input(
        "model name",
        f"{model_name}-job-{str(datetime.today()).split()[0]}",
    )
    col1, col2 = st.beta_columns(2)

    config["aws_sagemaker_role"] = col1.text_input("AWS IAM role for sagemaker job")
    config["instance_type"] = col2.selectbox(
        "Instance type",
        [
            "single-gpu | ml.p3.2xlarge",
            "multi-gpu | ml.p3.16xlarge",
        ],
    )
    # FIX: "us-east-1" was listed twice; duplicate removed.
    config["region"] = col1.selectbox(
        "AWS Region",
        ["eu-central-1", "eu-west-1", "us-east-1", "us-west-1", "us-west-2"],
    )
    # Here min_value=1 is intentional: a job needs at least one instance.
    config["instance_count"] = col2.number_input("Instance count", min_value=1, value=1)
    config["use_spot"] = col1.selectbox("use spot instances", [False, True])
    config["distributed"] = col2.selectbox("distributed training", [False, True])
    st.markdown("---")

    st.header("Credentials")

    col1, col2 = st.beta_columns(2)
    config["aws_access_key_id"] = col1.text_input("Aws Secret Key ID")
    config["aws_secret_accesskey"] = col2.text_input("Aws Secret Access Key")

    # Forward the HF token only when the user actually supplied one.
    if use_auth_token:
        parameter["use_auth_token"] = use_auth_token

    if st.button("Start training on SageMaker"):
        # NOTE(review): "train_estimtator" is the name exported by the trainer
        # module; the typo would have to be fixed there and here together.
        train_estimtator(parameter, config)
|
|
|
|
|
|
|
|
# Script entry point (launched via `streamlit run <this file>`).
if __name__ == "__main__":
    main()
|
|
|