Spaces:
Paused
Paused
Commit
·
fa9dd69
1
Parent(s):
0e6fd1f
feat: early return when trained 10 epoch
Browse files
infer/modules/train/train.py
CHANGED
|
@@ -248,8 +248,8 @@ def run(rank, n_gpus, hps, logger: logging.Logger, state):
|
|
| 248 |
scaler = GradScaler(enabled=hps.train.fp16_run)
|
| 249 |
|
| 250 |
cache = []
|
|
|
|
| 251 |
for epoch in range(epoch_str, hps.train.epochs + 1):
|
| 252 |
-
print("epoch", epoch)
|
| 253 |
if rank == 0:
|
| 254 |
train_and_evaluate(
|
| 255 |
rank,
|
|
@@ -283,6 +283,10 @@ def run(rank, n_gpus, hps, logger: logging.Logger, state):
|
|
| 283 |
scheduler_g.step()
|
| 284 |
scheduler_d.step()
|
| 285 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
def train_and_evaluate(
|
| 288 |
rank,
|
|
|
|
| 248 |
scaler = GradScaler(enabled=hps.train.fp16_run)
|
| 249 |
|
| 250 |
cache = []
|
| 251 |
+
trained = 0
|
| 252 |
for epoch in range(epoch_str, hps.train.epochs + 1):
|
|
|
|
| 253 |
if rank == 0:
|
| 254 |
train_and_evaluate(
|
| 255 |
rank,
|
|
|
|
| 283 |
scheduler_g.step()
|
| 284 |
scheduler_d.step()
|
| 285 |
|
| 286 |
+
trained += 1
|
| 287 |
+
if trained >= 10:
|
| 288 |
+
break
|
| 289 |
+
|
| 290 |
|
| 291 |
def train_and_evaluate(
|
| 292 |
rank,
|