Respair commited on
Commit
b2819f3
·
verified ·
1 Parent(s): 093e566

Update ddp_train.py

Browse files
Files changed (1) hide show
  1. ddp_train.py +2 -2
ddp_train.py CHANGED
@@ -190,14 +190,14 @@ def main(config_path):
190
  for batch in tqdm(val_dataloader, desc=f"Epoch {epoch}/{epochs} [Val]"):
191
  batch = [b.to(device) for b in batch]
192
 
193
- text_input, text_input_length, mel_input, mel_input_length, attn_prior = batch
194
 
195
  # Forward pass
196
  attn_soft, attn_logprob = aligner(spec=mel_input,
197
  spec_len=mel_input_length,
198
  text=text_input,
199
  text_len=text_input_length,
200
- attn_prior=attn_prior)
201
 
202
  # Calculate loss
203
  val_loss = forward_sum_loss(attn_logprob=attn_logprob,
 
190
  for batch in tqdm(val_dataloader, desc=f"Epoch {epoch}/{epochs} [Val]"):
191
  batch = [b.to(device) for b in batch]
192
 
193
+ text_input, text_input_length, mel_input, mel_input_length = batch
194
 
195
  # Forward pass
196
  attn_soft, attn_logprob = aligner(spec=mel_input,
197
  spec_len=mel_input_length,
198
  text=text_input,
199
  text_len=text_input_length,
200
+ attn_prior=None)
201
 
202
  # Calculate loss
203
  val_loss = forward_sum_loss(attn_logprob=attn_logprob,