Update train.py
Browse files
train.py
CHANGED
|
@@ -993,3 +993,21 @@ def train_and_evaluate(
|
|
| 993 |
pid_data.pop("process_pids", None)
|
| 994 |
json.dump(pid_data, pid_file, indent=4)
|
| 995 |
os._exit(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 993 |
pid_data.pop("process_pids", None)
|
| 994 |
json.dump(pid_data, pid_file, indent=4)
|
| 995 |
os._exit(0)
|
| 996 |
+
|
| 997 |
+
|
| 998 |
+
import argparse
|
| 999 |
+
parser = argparse.ArgumentParser()
|
| 1000 |
+
parser.add_argument("--model_type", type=str, default="rvc", choices=["rvc", "djcm"])
|
| 1001 |
+
parser.add_argument("--train_data", type=str, required=True)
|
| 1002 |
+
parser.add_argument("--output_path", type=str, required=True)
|
| 1003 |
+
args = parser.parse_args()
|
| 1004 |
+
|
| 1005 |
+
if args.model_type == "djcm":
|
| 1006 |
+
from models.djcm.djcm_model import DJCMModel
|
| 1007 |
+
model = DJCMModel()
|
| 1008 |
+
else:
|
| 1009 |
+
from models.rvc.rvc_model import RVCModel
|
| 1010 |
+
model = RVCModel()
|
| 1011 |
+
|
| 1012 |
+
print(f"Training {args.model_type} model with data at {args.train_data}")
|
| 1013 |
+
|