flpelerin commited on
Commit
a6de2cc
·
1 Parent(s): cc0b609
Files changed (1) hide show
  1. train.py +1 -1
train.py CHANGED
@@ -155,7 +155,7 @@ for epoch in range(num_epochs):
155
  # Reset hidden states if needed
156
  h_states = h_states if (step % reset_state_every!= 0) else None
157
  avg_states = sum([torch.mean(h_states[i]).item() for i in range(len(h_states))]) / len(h_states) if h_states is not None else None
158
- var_states = torch.var(torch.cat(hidden_states_list, dim=0)).item() if hidden_states_list else None
159
 
160
  optimizer.zero_grad()
161
  _, h_states, loss = model.forward(input_ids, h_states)
 
155
  # Reset hidden states if needed
156
  h_states = h_states if (step % reset_state_every!= 0) else None
157
  avg_states = sum([torch.mean(h_states[i]).item() for i in range(len(h_states))]) / len(h_states) if h_states is not None else None
158
+ var_states = torch.var(torch.cat(h_states, dim=0)).item() if h_states else None
159
 
160
  optimizer.zero_grad()
161
  _, h_states, loss = model.forward(input_ids, h_states)