flpelerin
commited on
Commit
·
dee5067
1
Parent(s):
0fe6c32
fix
Browse files
train.py
CHANGED
|
@@ -154,7 +154,7 @@ for epoch in range(num_epochs):
|
|
| 154 |
|
| 155 |
# Reset hidden states if needed
|
| 156 |
h_states = h_states if (step % reset_state_every!= 0) else None
|
| 157 |
-
avg_states =
|
| 158 |
|
| 159 |
optimizer.zero_grad()
|
| 160 |
_, h_states, loss = model.forward(input_ids, h_states)
|
|
|
|
| 154 |
|
| 155 |
# Reset hidden states if needed
|
| 156 |
h_states = h_states if (step % reset_state_every!= 0) else None
|
| 157 |
+
avg_states = torch.mean(h_states).item() if h_states is not None else None
|
| 158 |
|
| 159 |
optimizer.zero_grad()
|
| 160 |
_, h_states, loss = model.forward(input_ids, h_states)
|