ehartford committed on
Commit
bb41467
·
verified ·
1 Parent(s): 7d7a63f

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +48 -68
modeling.py CHANGED
@@ -1155,49 +1155,31 @@ class PrismaVLModel(PrismaVLPreTrainedModel):
1155
  if inputs_embeds is None:
1156
  inputs_embeds = self.get_input_embeddings()(input_ids)
1157
 
1158
- # === INJECT 16-BIT UNCERTAINTY SIGNAL ===
1159
- # Add learned uncertainty embedding from previous step
1160
- batch_size, seq_len = inputs_embeds.shape[:2]
1161
-
1162
- # Initialize uncertainty codes if needed
1163
- if self.prev_uncertainty_code is None or self.prev_uncertainty_code.shape[0] != batch_size:
1164
- # First step or batch size changed: use neutral uncertainty (middle of range)
1165
- # 32768 represents "medium uncertainty" (50% of max entropy)
1166
- uncertainty_code = torch.full(
1167
- (batch_size, seq_len),
1168
- self.n_uncertainty_levels // 2, # 32768 for 16-bit
1169
- dtype=torch.long,
1170
- device=inputs_embeds.device
1171
- )
1172
- else:
1173
- # Use uncertainty from previous step
1174
- # Pad or truncate to match current sequence length
1175
- prev_len = self.prev_uncertainty_code.shape[1]
1176
- if prev_len < seq_len:
1177
- # Pad with neutral uncertainty
1178
- padding = torch.full(
1179
- (batch_size, seq_len - prev_len),
1180
- self.n_uncertainty_levels // 2,
1181
- dtype=torch.long,
1182
- device=self.prev_uncertainty_code.device
1183
- )
1184
- uncertainty_code = torch.cat([self.prev_uncertainty_code, padding], dim=1)
1185
- else:
1186
- uncertainty_code = self.prev_uncertainty_code[:, :seq_len]
1187
-
1188
- # Look up uncertainty embeddings (256 learned vectors)
1189
- uncertainty_embeds = self.uncertainty_embeddings(uncertainty_code)
1190
-
1191
- # Shift right: position i gets uncertainty from position i-1
1192
- # First position gets zero (no previous uncertainty)
1193
- uncertainty_shifted = torch.nn.functional.pad(
1194
- uncertainty_embeds[:, :-1, :],
1195
- (0, 0, 1, 0), # Pad one position at the start
1196
- value=0.0
1197
- )
1198
-
1199
- # Inject into input: model sees both content and "how uncertain was I?"
1200
- inputs_embeds = inputs_embeds + uncertainty_shifted
1201
 
1202
  image_mask = None
1203
  video_mask = None
@@ -1465,31 +1447,29 @@ class PrismaVLForConditionalGeneration(PrismaVLPreTrainedModel, GenerationMixin)
1465
  if labels is not None:
1466
  loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
1467
 
1468
- # === COMPUTE UNCERTAINTY FOR NEXT STEP ===
1469
- # Update uncertainty codes based on current predictions
1470
- # Works during both training and inference for full introspective capability
1471
- if logits is not None:
1472
- with torch.no_grad():
1473
- logits_detached = logits.detach()
1474
-
1475
- # Compute probability distribution
1476
- probs = logits_detached.softmax(dim=-1) # [batch, seq, vocab]
1477
-
1478
- # Compute entropy: H = -Σ p log p (uncertainty measure)
1479
- log_probs = torch.log(probs.clamp(min=1e-9))
1480
- entropy = -(probs * log_probs).sum(dim=-1) # [batch, seq]
1481
-
1482
- # Normalize by maximum possible entropy (uniform distribution)
1483
- vocab_size = logits_detached.size(-1)
1484
- max_entropy = math.log(vocab_size)
1485
- entropy_norm = (entropy / max_entropy).clamp(0.0, 1.0)
1486
-
1487
- # Quantize to 16 bits (0-65535)
1488
- # Low entropy (confident) → low code (0-32767)
1489
- # High entropy (uncertain) → high code (32768-65535)
1490
- self.model.prev_uncertainty_code = (
1491
- entropy_norm * (self.model.n_uncertainty_levels - 1)
1492
- ).long().clamp(0, self.model.n_uncertainty_levels - 1)
1493
 
1494
  return PrismaVLCausalLMOutputWithPast(
1495
  loss=loss,
 
1155
  if inputs_embeds is None:
1156
  inputs_embeds = self.get_input_embeddings()(input_ids)
1157
 
1158
+ # === PRISMA UNCERTAINTY INJECTION (TRAINING + INFERENCE, ROBUST) ===
1159
+ B, S = inputs_embeds.shape[:2]
1160
+
1161
+ if self.prev_uncertainty_code is not None:
1162
+ if self.training and S > 1:
1163
+ prev_len = self.prev_uncertainty_code.shape[1]
1164
+ curr_len = S - 1
1165
+
1166
+ if prev_len >= curr_len:
1167
+ codes_to_use = self.prev_uncertainty_code[:, :curr_len]
1168
+ else:
1169
+ codes_to_use = F.pad(
1170
+ self.prev_uncertainty_code,
1171
+ (0, curr_len - prev_len),
1172
+ value=self.n_uncertainty_levels // 2, # neutral
1173
+ )
1174
+
1175
+ u = self.uncertainty_embeddings(codes_to_use) # [B, S-1, D]
1176
+ inputs_embeds = inputs_embeds.clone()
1177
+ inputs_embeds[:, 1:, :] += u
1178
+
1179
+ elif not self.training and S == 1:
1180
+ u = self.uncertainty_embeddings(self.prev_uncertainty_code) # [B, 1, D]
1181
+ inputs_embeds = inputs_embeds + u
1182
+ # === END PRISMA INJECTION ===
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1183
 
1184
  image_mask = None
1185
  video_mask = None
 
1447
  if labels is not None:
1448
  loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
1449
 
1450
+ # === PRISMA UNCERTAINTY UPDATE (TRAINING + INFERENCE) ===
1451
+ with torch.no_grad():
1452
+ B, S, V = logits.shape
1453
+
1454
+ probs = logits.softmax(dim=-1)
1455
+
1456
+ # Compute entropy for all available logits
1457
+ entropy = -(probs * torch.log(probs.clamp_min(1e-9))).sum(dim=-1)
1458
+ entropy = entropy / math.log(V)
1459
+ codes = (
1460
+ entropy * (self.model.n_uncertainty_levels - 1)
1461
+ ).long().clamp(0, self.model.n_uncertainty_levels - 1)
1462
+
1463
+ if self.training:
1464
+ # Teacher forcing: use entropy[t] → condition token t+1
1465
+ # So we store entropy[0..S-2]
1466
+ self.model.prev_uncertainty_code = codes[:, :-1] if S > 1 else None
1467
+
1468
+ else:
1469
+ # Inference: only one step at a time (S == 1)
1470
+ self.model.prev_uncertainty_code = codes[:, -1:].contiguous()
1471
+ # === END PRISMA UPDATE ===
1472
+
 
 
1473
 
1474
  return PrismaVLCausalLMOutputWithPast(
1475
  loss=loss,