| | import os |
| | import argparse |
| | import warnings |
| | from train import Trainer |
| | import sklearn.exceptions |
| | warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning) |
| |
|
# Command-line entry point: declares the training options and, when this file
# is executed as a script, hands the parsed namespace to Trainer.
#
# The parser object is built at import time (cheap, no side effects), but the
# arguments are parsed only under the __main__ guard so that importing this
# module never reads sys.argv or exits the importing process.
parser = argparse.ArgumentParser()

# ---- Where experiment outputs are organised ----
parser.add_argument('--save_dir', default='experiments_logs', type=str,
                    help='Directory containing all experiments')
parser.add_argument('--experiment_description', default='Exp1', type=str,
                    help='experiment name')
parser.add_argument('--run_description', default='run1', type=str,
                    help='run name')

# ---- Dataset selection and seeding ----
parser.add_argument('--dataset', default='mit', type=str, help='mit, ptb')
# The seed is deliberately kept as a string here; how it is converted is up
# to the consumer of the namespace (not visible in this file).
parser.add_argument('--seed_id', default='0', type=str,
                    help='to fix a seed while training')

# ---- Data location ----
# NOTE(review): the default is an absolute path on one developer's machine;
# pass --data_path explicitly when running anywhere else.
parser.add_argument('--data_path',
                    default=r'/Users/splendor1811/datn/ECGTransForm/datasets',
                    type=str, help='Path containing dataset')

# ---- Run control ----
parser.add_argument('--num_runs', default=1, type=int,
                    help='Number of consecutive run with different seeds')
# The help text lists 'mps' as well, since that is the default value.
parser.add_argument('--device', default='mps', type=str,
                    help='cpu, cuda or mps')


if __name__ == "__main__":
    # Parse only when run as a script; `args` is still a module-level global
    # in that case, exactly as before.
    args = parser.parse_args()
    trainer = Trainer(args)
    trainer.train()
| |
|