flpelerin
committed on
Commit
·
b8f91b6
1
Parent(s):
fe132c7
test
Browse files
train.py
CHANGED
|
@@ -155,6 +155,7 @@ for epoch in range(num_epochs):
|
|
| 155 |
# Reset hidden states if needed
|
| 156 |
h_states = h_states if (step % reset_state_every!= 0) else None
|
| 157 |
avg_states = sum([torch.mean(h_states[i]).item() for i in range(len(h_states))]) / len(h_states) if h_states is not None else None
|
|
|
|
| 158 |
|
| 159 |
optimizer.zero_grad()
|
| 160 |
_, h_states, loss = model.forward(input_ids, h_states)
|
|
|
|
| 155 |
# Reset hidden states if needed
|
| 156 |
h_states = h_states if (step % reset_state_every!= 0) else None
|
| 157 |
avg_states = sum([torch.mean(h_states[i]).item() for i in range(len(h_states))]) / len(h_states) if h_states is not None else None
|
| 158 |
+
var_states = torch.var(torch.cat(hidden_states_list, dim=0)).item() if hidden_states_list else None
|
| 159 |
|
| 160 |
optimizer.zero_grad()
|
| 161 |
_, h_states, loss = model.forward(input_ids, h_states)
|