flpelerin
commited on
Commit
·
a6de2cc
1
Parent(s):
cc0b609
fix
Browse files
train.py
CHANGED
|
@@ -155,7 +155,7 @@ for epoch in range(num_epochs):
|
|
| 155 |
# Reset hidden states if needed
|
| 156 |
h_states = h_states if (step % reset_state_every!= 0) else None
|
| 157 |
avg_states = sum([torch.mean(h_states[i]).item() for i in range(len(h_states))]) / len(h_states) if h_states is not None else None
|
| 158 |
-
var_states = torch.var(torch.cat(
|
| 159 |
|
| 160 |
optimizer.zero_grad()
|
| 161 |
_, h_states, loss = model.forward(input_ids, h_states)
|
|
|
|
| 155 |
# Reset hidden states if needed
|
| 156 |
h_states = h_states if (step % reset_state_every!= 0) else None
|
| 157 |
avg_states = sum([torch.mean(h_states[i]).item() for i in range(len(h_states))]) / len(h_states) if h_states is not None else None
|
| 158 |
+
var_states = torch.var(torch.cat(h_states, dim=0)).item() if h_states else None
|
| 159 |
|
| 160 |
optimizer.zero_grad()
|
| 161 |
_, h_states, loss = model.forward(input_ids, h_states)
|