Update modeling.py
Browse files- modeling.py +48 -68
modeling.py
CHANGED
|
@@ -1155,49 +1155,31 @@ class PrismaVLModel(PrismaVLPreTrainedModel):
|
|
| 1155 |
if inputs_embeds is None:
|
| 1156 |
inputs_embeds = self.get_input_embeddings()(input_ids)
|
| 1157 |
|
| 1158 |
-
# ===
|
| 1159 |
-
|
| 1160 |
-
|
| 1161 |
-
|
| 1162 |
-
|
| 1163 |
-
|
| 1164 |
-
|
| 1165 |
-
|
| 1166 |
-
|
| 1167 |
-
|
| 1168 |
-
|
| 1169 |
-
|
| 1170 |
-
|
| 1171 |
-
|
| 1172 |
-
|
| 1173 |
-
|
| 1174 |
-
|
| 1175 |
-
|
| 1176 |
-
|
| 1177 |
-
|
| 1178 |
-
|
| 1179 |
-
|
| 1180 |
-
|
| 1181 |
-
|
| 1182 |
-
|
| 1183 |
-
)
|
| 1184 |
-
uncertainty_code = torch.cat([self.prev_uncertainty_code, padding], dim=1)
|
| 1185 |
-
else:
|
| 1186 |
-
uncertainty_code = self.prev_uncertainty_code[:, :seq_len]
|
| 1187 |
-
|
| 1188 |
-
# Look up uncertainty embeddings (256 learned vectors)
|
| 1189 |
-
uncertainty_embeds = self.uncertainty_embeddings(uncertainty_code)
|
| 1190 |
-
|
| 1191 |
-
# Shift right: position i gets uncertainty from position i-1
|
| 1192 |
-
# First position gets zero (no previous uncertainty)
|
| 1193 |
-
uncertainty_shifted = torch.nn.functional.pad(
|
| 1194 |
-
uncertainty_embeds[:, :-1, :],
|
| 1195 |
-
(0, 0, 1, 0), # Pad one position at the start
|
| 1196 |
-
value=0.0
|
| 1197 |
-
)
|
| 1198 |
-
|
| 1199 |
-
# Inject into input: model sees both content and "how uncertain was I?"
|
| 1200 |
-
inputs_embeds = inputs_embeds + uncertainty_shifted
|
| 1201 |
|
| 1202 |
image_mask = None
|
| 1203 |
video_mask = None
|
|
@@ -1465,31 +1447,29 @@ class PrismaVLForConditionalGeneration(PrismaVLPreTrainedModel, GenerationMixin)
|
|
| 1465 |
if labels is not None:
|
| 1466 |
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
|
| 1467 |
|
| 1468 |
-
# ===
|
| 1469 |
-
|
| 1470 |
-
|
| 1471 |
-
|
| 1472 |
-
|
| 1473 |
-
|
| 1474 |
-
|
| 1475 |
-
|
| 1476 |
-
|
| 1477 |
-
|
| 1478 |
-
|
| 1479 |
-
|
| 1480 |
-
|
| 1481 |
-
|
| 1482 |
-
#
|
| 1483 |
-
|
| 1484 |
-
|
| 1485 |
-
|
| 1486 |
-
|
| 1487 |
-
#
|
| 1488 |
-
|
| 1489 |
-
|
| 1490 |
-
|
| 1491 |
-
entropy_norm * (self.model.n_uncertainty_levels - 1)
|
| 1492 |
-
).long().clamp(0, self.model.n_uncertainty_levels - 1)
|
| 1493 |
|
| 1494 |
return PrismaVLCausalLMOutputWithPast(
|
| 1495 |
loss=loss,
|
|
|
|
| 1155 |
if inputs_embeds is None:
|
| 1156 |
inputs_embeds = self.get_input_embeddings()(input_ids)
|
| 1157 |
|
| 1158 |
+
# === PRISMA UNCERTAINTY INJECTION (TRAINING + INFERENCE, ROBUST) ===
|
| 1159 |
+
B, S = inputs_embeds.shape[:2]
|
| 1160 |
+
|
| 1161 |
+
if self.prev_uncertainty_code is not None:
|
| 1162 |
+
if self.training and S > 1:
|
| 1163 |
+
prev_len = self.prev_uncertainty_code.shape[1]
|
| 1164 |
+
curr_len = S - 1
|
| 1165 |
+
|
| 1166 |
+
if prev_len >= curr_len:
|
| 1167 |
+
codes_to_use = self.prev_uncertainty_code[:, :curr_len]
|
| 1168 |
+
else:
|
| 1169 |
+
codes_to_use = F.pad(
|
| 1170 |
+
self.prev_uncertainty_code,
|
| 1171 |
+
(0, curr_len - prev_len),
|
| 1172 |
+
value=self.n_uncertainty_levels // 2, # neutral
|
| 1173 |
+
)
|
| 1174 |
+
|
| 1175 |
+
u = self.uncertainty_embeddings(codes_to_use) # [B, S-1, D]
|
| 1176 |
+
inputs_embeds = inputs_embeds.clone()
|
| 1177 |
+
inputs_embeds[:, 1:, :] += u
|
| 1178 |
+
|
| 1179 |
+
elif not self.training and S == 1:
|
| 1180 |
+
u = self.uncertainty_embeddings(self.prev_uncertainty_code) # [B, 1, D]
|
| 1181 |
+
inputs_embeds = inputs_embeds + u
|
| 1182 |
+
# === END PRISMA INJECTION ===
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1183 |
|
| 1184 |
image_mask = None
|
| 1185 |
video_mask = None
|
|
|
|
| 1447 |
if labels is not None:
|
| 1448 |
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
|
| 1449 |
|
| 1450 |
+
# === PRISMA UNCERTAINTY UPDATE (TRAINING + INFERENCE) ===
|
| 1451 |
+
with torch.no_grad():
|
| 1452 |
+
B, S, V = logits.shape
|
| 1453 |
+
|
| 1454 |
+
probs = logits.softmax(dim=-1)
|
| 1455 |
+
|
| 1456 |
+
# Compute entropy for all available logits
|
| 1457 |
+
entropy = -(probs * torch.log(probs.clamp_min(1e-9))).sum(dim=-1)
|
| 1458 |
+
entropy = entropy / math.log(V)
|
| 1459 |
+
codes = (
|
| 1460 |
+
entropy * (self.model.n_uncertainty_levels - 1)
|
| 1461 |
+
).long().clamp(0, self.model.n_uncertainty_levels - 1)
|
| 1462 |
+
|
| 1463 |
+
if self.training:
|
| 1464 |
+
# Teacher forcing: use entropy[t] → condition token t+1
|
| 1465 |
+
# So we store the quantized entropy codes for positions 0..S-2 (codes[:, :-1])
|
| 1466 |
+
self.model.prev_uncertainty_code = codes[:, :-1] if S > 1 else None
|
| 1467 |
+
|
| 1468 |
+
else:
|
| 1469 |
+
# Inference: keep only the last position's code (S == 1 when decoding one token at a time)
|
| 1470 |
+
self.model.prev_uncertainty_code = codes[:, -1:].contiguous()
|
| 1471 |
+
# === END PRISMA UPDATE ===
|
| 1472 |
+
|
|
|
|
|
|
|
| 1473 |
|
| 1474 |
return PrismaVLCausalLMOutputWithPast(
|
| 1475 |
loss=loss,
|