Update modeling.py
Browse files- modeling.py +28 -15
modeling.py
CHANGED
|
@@ -1450,27 +1450,40 @@ class PrismaVLForConditionalGeneration(PrismaVLPreTrainedModel, GenerationMixin)
|
|
| 1450 |
# === PRISMA UNCERTAINTY UPDATE (TRAINING + INFERENCE) ===
|
| 1451 |
with torch.no_grad():
|
| 1452 |
B, S, V = logits.shape
|
|
|
|
| 1453 |
|
| 1454 |
-
|
|
|
|
|
|
|
| 1455 |
|
| 1456 |
-
|
| 1457 |
-
|
| 1458 |
-
|
| 1459 |
-
|
| 1460 |
-
|
| 1461 |
-
|
| 1462 |
-
|
| 1463 |
-
|
| 1464 |
-
|
| 1465 |
-
|
| 1466 |
-
|
| 1467 |
|
| 1468 |
else:
|
| 1469 |
-
#
|
| 1470 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1471 |
# === END PRISMA UPDATE ===
|
| 1472 |
|
| 1473 |
-
|
| 1474 |
return PrismaVLCausalLMOutputWithPast(
|
| 1475 |
loss=loss,
|
| 1476 |
logits=logits,
|
|
|
|
# === PRISMA UNCERTAINTY UPDATE (TRAINING + INFERENCE) ===
# Derive a discrete per-position "uncertainty code" from the LM-head logits
# and stash it on the backbone as self.model.prev_uncertainty_code
# (presumably consumed by the next forward pass — confirm against the
# model's embedding/conditioning path). Runs under no_grad: the code is a
# discrete side-channel signal, never a gradient path.
with torch.no_grad():
    B, S, V = logits.shape
    expected_V = self.config.text_config.vocab_size

    if V != expected_V:
        # Logits appear to be sharded (tensor-parallel vocab split): a
        # local softmax/entropy over only one vocab shard would be wrong,
        # so do NOT compute it.
        if self.training:
            # Training: skip the uncertainty signal for the next batch.
            self.model.prev_uncertainty_code = None
        else:
            # Inference: fall back to the neutral (middle) level.
            self.model.prev_uncertainty_code = torch.full(
                (B, 1),
                self.model.n_uncertainty_levels // 2,
                dtype=torch.long,
                device=logits.device,
            )
    else:
        # Full-vocab logits → entropy is well defined.
        #
        # FIX: compute entropy via log_softmax instead of
        # log(softmax(x).clamp_min(1e-9)). In fp16/bf16 many token
        # probabilities underflow to exactly 0 and the clamp floor
        # systematically biases the entropy; log_softmax is evaluated
        # stably from the logits and needs no epsilon.
        log_probs = logits.log_softmax(dim=-1)
        entropy = -(log_probs.exp() * log_probs).sum(dim=-1)
        # Normalize to [0, 1] by the maximum possible entropy, log(V).
        entropy = entropy / math.log(expected_V)

        # Quantize to integer levels {0, ..., n_levels - 1}; the clamp
        # guards the entropy == 1.0 edge case (and any float overshoot).
        codes = (
            entropy * (self.model.n_uncertainty_levels - 1)
        ).long().clamp(0, self.model.n_uncertainty_levels - 1)

        if self.training:
            # Training: drop the last position (code at t presumably
            # conditions t+1 under teacher forcing — TODO confirm);
            # a single-position sequence yields no usable signal.
            self.model.prev_uncertainty_code = codes[:, :-1] if S > 1 else None
        else:
            # Incremental decoding: keep only the newest position.
            self.model.prev_uncertainty_code = codes[:, -1:].contiguous()
# === END PRISMA UPDATE ===
|
| 1486 |
|
|
|
|
| 1487 |
return PrismaVLCausalLMOutputWithPast(
|
| 1488 |
loss=loss,
|
| 1489 |
logits=logits,
|