Yoshitaka16 commited on
Commit
8f33e95
·
verified ·
1 Parent(s): 3f28f19

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +18 -0
train.py CHANGED
@@ -993,3 +993,21 @@ def train_and_evaluate(
993
  pid_data.pop("process_pids", None)
994
  json.dump(pid_data, pid_file, indent=4)
995
  os._exit(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
993
  pid_data.pop("process_pids", None)
994
  json.dump(pid_data, pid_file, indent=4)
995
  os._exit(0)
996
+
997
+
998
+ import argparse
999
+ parser = argparse.ArgumentParser()
1000
+ parser.add_argument("--model_type", type=str, default="rvc", choices=["rvc", "djcm"])
1001
+ parser.add_argument("--train_data", type=str, required=True)
1002
+ parser.add_argument("--output_path", type=str, required=True)
1003
+ args = parser.parse_args()
1004
+
1005
+ if args.model_type == "djcm":
1006
+ from models.djcm.djcm_model import DJCMModel
1007
+ model = DJCMModel()
1008
+ else:
1009
+ from models.rvc.rvc_model import RVCModel
1010
+ model = RVCModel()
1011
+
1012
+ print(f"Training {args.model_type} model with data at {args.train_data}")
1013
+