Update train.py
Browse files
train.py
CHANGED
|
@@ -89,7 +89,7 @@ for epoch in range(num_epochs):
|
|
| 89 |
# h_states = None
|
| 90 |
|
| 91 |
h_states = h_states if (i / batch_size) % reset_state_every != 0 else None
|
| 92 |
-
str_states = ''.join(['{:.3f}, '.format(h_states[0][0][
|
| 93 |
|
| 94 |
optimizer.zero_grad()
|
| 95 |
_, h_states, loss = model.forward(input_ids, h_states)
|
|
|
|
| 89 |
# h_states = None
|
| 90 |
|
| 91 |
h_states = h_states if (i / batch_size) % reset_state_every != 0 else None
|
| 92 |
+
str_states = ''.join(['{:.3f}, '.format(h_states[0][0][i].item()) for i in range(10)]) if h_states is not None else 'None'
|
| 93 |
|
| 94 |
optimizer.zero_grad()
|
| 95 |
_, h_states, loss = model.forward(input_ids, h_states)
|