ehartford committed on
Commit
1a5225e
·
verified ·
1 Parent(s): bb41467

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +28 -15
modeling.py CHANGED
@@ -1450,27 +1450,40 @@ class PrismaVLForConditionalGeneration(PrismaVLPreTrainedModel, GenerationMixin)
1450
  # === PRISMA UNCERTAINTY UPDATE (TRAINING + INFERENCE) ===
1451
  with torch.no_grad():
1452
  B, S, V = logits.shape
 
1453
 
1454
- probs = logits.softmax(dim=-1)
 
 
1455
 
1456
- # Compute entropy for all available logits
1457
- entropy = -(probs * torch.log(probs.clamp_min(1e-9))).sum(dim=-1)
1458
- entropy = entropy / math.log(V)
1459
- codes = (
1460
- entropy * (self.model.n_uncertainty_levels - 1)
1461
- ).long().clamp(0, self.model.n_uncertainty_levels - 1)
1462
-
1463
- if self.training:
1464
- # Teacher forcing: use entropy[t] → condition token t+1
1465
- # So we store entropy[0..S-2]
1466
- self.model.prev_uncertainty_code = codes[:, :-1] if S > 1 else None
1467
 
1468
  else:
1469
- # Inference: only one step at a time (S == 1)
1470
- self.model.prev_uncertainty_code = codes[:, -1:].contiguous()
 
 
 
 
 
 
 
 
 
 
 
1471
  # === END PRISMA UPDATE ===
1472
 
1473
-
1474
  return PrismaVLCausalLMOutputWithPast(
1475
  loss=loss,
1476
  logits=logits,
 
1450
  # === PRISMA UNCERTAINTY UPDATE (TRAINING + INFERENCE) ===
1451
  with torch.no_grad():
1452
  B, S, V = logits.shape
1453
+ expected_V = self.config.text_config.vocab_size
1454
 
1455
+ if V != expected_V:
1456
+ # Logits are sharded (tensor parallel vocab split)
1457
+ # Softmax/entropy would be incorrect → do NOT compute
1458
 
1459
+ if self.training:
1460
+ # Training: skip uncertainty for next batch
1461
+ self.model.prev_uncertainty_code = None
1462
+ else:
1463
+ # Inference: neutral uncertainty
1464
+ self.model.prev_uncertainty_code = torch.full(
1465
+ (B, 1),
1466
+ self.model.n_uncertainty_levels // 2,
1467
+ dtype=torch.long,
1468
+ device=logits.device,
1469
+ )
1470
 
1471
  else:
1472
+ # Full-vocab logits safe entropy computation
1473
+ probs = logits.softmax(dim=-1)
1474
+ entropy = -(probs * torch.log(probs.clamp_min(1e-9))).sum(dim=-1)
1475
+ entropy = entropy / math.log(expected_V)
1476
+
1477
+ codes = (
1478
+ entropy * (self.model.n_uncertainty_levels - 1)
1479
+ ).long().clamp(0, self.model.n_uncertainty_levels - 1)
1480
+
1481
+ if self.training:
1482
+ self.model.prev_uncertainty_code = codes[:, :-1] if S > 1 else None
1483
+ else:
1484
+ self.model.prev_uncertainty_code = codes[:, -1:].contiguous()
1485
  # === END PRISMA UPDATE ===
1486
 
 
1487
  return PrismaVLCausalLMOutputWithPast(
1488
  loss=loss,
1489
  logits=logits,