Update model_dtat.py
Browse files- model_dtat.py +1 -1
model_dtat.py
CHANGED
|
@@ -317,7 +317,7 @@ class DTATTransformer(nn.Module):
|
|
| 317 |
logits = logits.view(B*T, C)
|
| 318 |
targets = targets.view(B*T)
|
| 319 |
# Calculate loss directly in BPC instead of nats
|
| 320 |
-
loss = F.cross_entropy(logits, targets)
|
| 321 |
|
| 322 |
return logits, loss, importance_scores
|
| 323 |
|
|
|
|
| 317 |
logits = logits.view(B*T, C)
|
| 318 |
targets = targets.view(B*T)
|
| 319 |
# Calculate loss directly in BPC instead of nats
|
| 320 |
+
loss = F.cross_entropy(logits, targets) / math.log(2)
|
| 321 |
|
| 322 |
return logits, loss, importance_scores
|
| 323 |
|