Upload cnets.py with huggingface_hub
Browse files
cnets.py
CHANGED
|
@@ -869,8 +869,9 @@ class Model(nn.Module):
|
|
| 869 |
loss = -torch.sum(position_mask * plogp, 2).mean()
|
| 870 |
plosses.append(loss)
|
| 871 |
with torch.no_grad():
|
|
|
|
| 872 |
acces.append(((logits.argmax(-1) == target_p.argmax(-1)) * position_mask.squeeze(-1)).sum().item() / (
|
| 873 |
-
|
| 874 |
|
| 875 |
if not last:
|
| 876 |
input_ids = padding(input_ids, left=False)
|
|
|
|
| 869 |
loss = -torch.sum(position_mask * plogp, 2).mean()
|
| 870 |
plosses.append(loss)
|
| 871 |
with torch.no_grad():
|
| 872 |
+
# Fixed: use position_mask.sum() instead of loss_mask.sum() for correct accuracy
|
| 873 |
acces.append(((logits.argmax(-1) == target_p.argmax(-1)) * position_mask.squeeze(-1)).sum().item() / (
|
| 874 |
+
position_mask.sum().item() + 1e-6))
|
| 875 |
|
| 876 |
if not last:
|
| 877 |
input_ids = padding(input_ids, left=False)
|