| | import argparse |
| | import datetime |
| | import json |
| | from pathlib import Path |
| |
|
| | from helpers import save_useful_info |
| | from train import train |
| |
|
| |
|
| | def train_zeggs(): |
| | |
| | parser = argparse.ArgumentParser(description="Train ZEGGS Network.") |
| |
|
| | |
| | parser.add_argument( |
| | "-o", |
| | "--options", |
| | type=str, |
| | help="Options filename", |
| | ) |
| | parser.add_argument('-n', '--name', type=str, help="Name", required=False) |
| |
|
| | args = parser.parse_args() |
| |
|
| | with open(args.options, "r") as f: |
| | options = json.load(f) |
| | if args.name: |
| | options["name"] = args.name |
| |
|
| | train_options = options["train_opt"] |
| | network_options = options["net_opt"] |
| | paths = options["paths"] |
| |
|
| | base_path = Path(paths["base_path"]) |
| | path_processed_data = base_path / paths["path_processed_data"] / "processed_data.npz" |
| | path_data_definition = base_path / paths["path_processed_data"] / "data_definition.json" |
| |
|
| | |
| | if paths["output_dir"] is None: |
| | output_dir = (base_path / "outputs") / datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S") |
| | output_dir.mkdir(exist_ok=True, parents=True) |
| | paths["output_dir"] = str(output_dir) |
| | else: |
| | output_dir = Path(paths["output_dir"]) |
| |
|
| | |
| | if paths["models_dir"] is None and not train_options["resume"]: |
| | models_dir = output_dir / "saved_models" |
| | models_dir.mkdir(exist_ok=True) |
| | paths["models_dir"] = str(models_dir) |
| | else: |
| | models_dir = Path(paths["models_dir"]) |
| |
|
| | |
| | logs_dir = output_dir / "logs" |
| | logs_dir.mkdir(exist_ok=True) |
| |
|
| | options["paths"] = paths |
| | with open(output_dir / 'options.json', 'w') as fp: |
| | json.dump(options, fp, indent=4) |
| |
|
| | save_useful_info(output_dir) |
| |
|
| | train( |
| | models_dir=models_dir, |
| | logs_dir=logs_dir, |
| | path_processed_data=path_processed_data, |
| | path_data_definition=path_data_definition, |
| | train_options=train_options, |
| | network_options=network_options, |
| | ) |
| |
|
| |
|
| | if __name__ == "__main__": |
| | train_zeggs() |
| |
|
| | |
| |
|