Spaces:
Sleeping
Sleeping
primepake
commited on
Commit
·
a80fff9
1
Parent(s):
9010215
add training code
Browse files- speech/train.py +16 -12
speech/train.py
CHANGED
|
@@ -31,11 +31,10 @@ from cosyvoice.utils.executor import Executor
|
|
| 31 |
from cosyvoice.utils.losses import DPOLoss
|
| 32 |
from cosyvoice.utils.train_utils import (check_modify_and_save_config,
|
| 33 |
init_dataset_and_dataloader,
|
| 34 |
-
init_distributed,
|
| 35 |
init_optimizer_and_scheduler,
|
| 36 |
-
|
| 37 |
-
|
| 38 |
|
|
|
|
| 39 |
def get_args():
|
| 40 |
parser = argparse.ArgumentParser(description="training your network")
|
| 41 |
parser.add_argument(
|
|
@@ -102,6 +101,20 @@ def get_args():
|
|
| 102 |
type=int,
|
| 103 |
help="timeout (in seconds) of cosyvoice_join.",
|
| 104 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
parser = deepspeed.add_config_arguments(parser)
|
| 106 |
args = parser.parse_args()
|
| 107 |
return args
|
|
@@ -115,17 +128,8 @@ def init_comet_experiment(args, configs):
|
|
| 115 |
if rank == 0 and not args.comet_disabled:
|
| 116 |
# Set up Comet ML experiment
|
| 117 |
experiment = Experiment(
|
| 118 |
-
api_key=args.comet_api_key,
|
| 119 |
project_name=args.comet_project,
|
| 120 |
-
workspace=args.comet_workspace,
|
| 121 |
experiment_name=args.comet_experiment_name,
|
| 122 |
-
disabled=args.comet_disabled,
|
| 123 |
-
offline=args.comet_offline,
|
| 124 |
-
auto_metric_logging=True,
|
| 125 |
-
auto_param_logging=True,
|
| 126 |
-
auto_histogram_weight_logging=True,
|
| 127 |
-
auto_histogram_gradient_logging=True,
|
| 128 |
-
auto_histogram_activation_logging=False,
|
| 129 |
)
|
| 130 |
|
| 131 |
# Log hyperparameters
|
|
|
|
| 31 |
from cosyvoice.utils.losses import DPOLoss
|
| 32 |
from cosyvoice.utils.train_utils import (check_modify_and_save_config,
|
| 33 |
init_dataset_and_dataloader,
|
|
|
|
| 34 |
init_optimizer_and_scheduler,
|
| 35 |
+
save_model)
|
|
|
|
| 36 |
|
| 37 |
+
os.environ["COMET_LOGGING_CONSOLE"] = "ERROR" # Only show errors
|
| 38 |
def get_args():
|
| 39 |
parser = argparse.ArgumentParser(description="training your network")
|
| 40 |
parser.add_argument(
|
|
|
|
| 101 |
type=int,
|
| 102 |
help="timeout (in seconds) of cosyvoice_join.",
|
| 103 |
)
|
| 104 |
+
parser.add_argument(
|
| 105 |
+
"--comet_disabled",
|
| 106 |
+
action="store_true",
|
| 107 |
+
default=False,
|
| 108 |
+
help="Disable comet ml experiment",
|
| 109 |
+
)
|
| 110 |
+
parser.add_argument(
|
| 111 |
+
"--comet_project",
|
| 112 |
+
default="speech"
|
| 113 |
+
)
|
| 114 |
+
parser.add_argument(
|
| 115 |
+
"--comet_experiment_name",
|
| 116 |
+
default="test"
|
| 117 |
+
)
|
| 118 |
parser = deepspeed.add_config_arguments(parser)
|
| 119 |
args = parser.parse_args()
|
| 120 |
return args
|
|
|
|
| 128 |
if rank == 0 and not args.comet_disabled:
|
| 129 |
# Set up Comet ML experiment
|
| 130 |
experiment = Experiment(
|
|
|
|
| 131 |
project_name=args.comet_project,
|
|
|
|
| 132 |
experiment_name=args.comet_experiment_name,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
)
|
| 134 |
|
| 135 |
# Log hyperparameters
|