flpelerin commited on
Commit
fe132c7
·
1 Parent(s): dee5067
Files changed (1) hide show
  1. train.py +1 -1
train.py CHANGED
@@ -154,7 +154,7 @@ for epoch in range(num_epochs):
154
 
155
  # Reset hidden states if needed
156
  h_states = h_states if (step % reset_state_every!= 0) else None
157
- avg_states = torch.mean(h_states).item() if h_states is not None else None
158
 
159
  optimizer.zero_grad()
160
  _, h_states, loss = model.forward(input_ids, h_states)
 
154
 
155
  # Reset hidden states if needed
156
  h_states = h_states if (step % reset_state_every!= 0) else None
157
+ avg_states = sum([torch.mean(h_states[i]).item() for i in range(len(h_states))]) / len(h_states) if h_states is not None else None
158
 
159
  optimizer.zero_grad()
160
  _, h_states, loss = model.forward(input_ids, h_states)