Tags: Text Generation, Transformers, Safetensors, HERMES, English, llama, cognitive-control, decode-time-intervention, repetition-suppression, behavioral-control, contrastive-learning, interpretability, activation-engineering, cf-hot, arc, rlhf-analysis, research, conversational, text-generation-inference
Fix device mismatch and tensor.item() bugs
inference.py (+9 -5)
@@ -129,13 +129,15 @@ class MultiHeadPredictor(nn.Module):
         Returns:
             Aggregated features [batch, seq, d_fiber]
         """
+        device = hidden_states[0].device
         fibers = []
         for i, (proj, hidden) in enumerate(zip(self.fiber_projs, hidden_states)):
             if i < len(hidden_states):
+                proj = proj.to(device)
                 fibers.append(proj(hidden.float()))
 
         # Weighted sum across layers
-        weights = F.softmax(self.layer_weights[:len(fibers)], dim=0)
+        weights = F.softmax(self.layer_weights.to(device)[:len(fibers)], dim=0)
         aggregated = sum(w * f for w, f in zip(weights, fibers))
         return aggregated
 
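For context on what this hunk touches: get_fiber_features projects each selected layer's hidden states into a shared fiber space, then combines the projections with a learned softmax-weighted sum. A runnable sketch of that aggregation; the names fiber_projs and layer_weights come from the diff, while the dimensions are made up:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sizes; the real model's dimensions will differ.
num_layers, d_model, d_fiber = 4, 512, 64

# Stand-ins for the per-layer projections and learned layer weights.
hidden_states = [torch.randn(1, 8, d_model) for _ in range(num_layers)]
fiber_projs = nn.ModuleList(nn.Linear(d_model, d_fiber) for _ in range(num_layers))
layer_weights = nn.Parameter(torch.zeros(num_layers))

# Project each layer, then take a convex combination across layers.
fibers = [proj(h.float()) for proj, h in zip(fiber_projs, hidden_states)]
weights = F.softmax(layer_weights[:len(fibers)], dim=0)   # sums to 1
aggregated = sum(w * f for w, f in zip(weights, fibers))  # [1, 8, d_fiber]
```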
@@ -153,9 +155,11 @@ class MultiHeadPredictor(nn.Module):
         if not self.loaded_heads:
             return {}
 
+        device = hidden_states[0].device
         features = self.get_fiber_features(hidden_states)
         risks = {}
         for name in self.loaded_heads:
+            self.heads[name] = self.heads[name].to(device)
             logits = self.heads[name](features).squeeze(-1)
             risks[name] = torch.sigmoid(logits)
         return risks
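Both hunks above fix the same failure mode: the fiber projections, layer_weights, and lazily loaded heads are created on the CPU by default, while hidden_states arrive on the base model's device, so the first forward pass on a GPU raises a device-mismatch RuntimeError. The patch moves each module (and the weight vector) onto the device of the incoming tensors before use. A minimal reproduce-and-fix sketch, with an nn.Linear standing in for one of the lazily loaded heads:

```python
import torch
import torch.nn as nn

# A lazily loaded head: nn.Module parameters live on CPU unless moved.
head = nn.Linear(512, 1)

# Hidden states arrive on whatever device the base model runs on.
device = "cuda" if torch.cuda.is_available() else "cpu"
hidden = torch.randn(1, 8, 512, device=device)

# Without this .to(), head(hidden) raises a RuntimeError on GPU:
# "Expected all tensors to be on the same device".
head = head.to(hidden.device)
scores = head(hidden)  # [1, 8, 1]
```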
@@ -370,28 +374,28 @@ class ARCSystem:
         interventions = {}
 
         # Repetition: suppress recently used tokens
-        if risks.get('repetition', …
+        if risks.get('repetition', 0) > self.config.repetition_threshold:
             for tok in set(recent_tokens[-self.config.repetition_window:]):
                 logits[0, tok] -= self.config.repetition_penalty
             interventions['repetition'] = True
             self.total_interventions['repetition'] += 1
 
         # Hedging: suppress hedge phrase starters
-        if risks.get('hedging', …
+        if risks.get('hedging', 0) > self.config.hedging_threshold:
             for tok in self._hedge_token_ids:
                 logits[0, tok] -= self.config.hedging_penalty
             interventions['hedging'] = True
             self.total_interventions['hedging'] += 1
 
         # Verbosity: suppress filler phrase starters
-        if risks.get('verbosity', …
+        if risks.get('verbosity', 0) > self.config.verbosity_threshold:
             for tok in self._verbose_token_ids:
                 logits[0, tok] -= self.config.verbosity_penalty
             interventions['verbosity'] = True
             self.total_interventions['verbosity'] += 1
 
         # Sycophancy: suppress sycophantic starters
-        if risks.get('sycophancy', …
+        if risks.get('sycophancy', 0) > self.config.sycophancy_threshold:
             for tok in self._sycophancy_token_ids:
                 logits[0, tok] -= self.config.sycophancy_penalty
             interventions['sycophancy'] = True
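The guards above carry the tensor.item() half of the commit title. Each rewritten check compares risks.get(name, 0) directly against a float threshold, with 0 as the default when a head is not loaded. For that to be valid inside an if, the risk must be a Python number or a one-element tensor: both .item() and truthiness raise a RuntimeError on multi-element tensors, which is presumably the bug the old (truncated) lines hit. A sketch of the failure mode and a safe reduction; the shape and the 0.5 threshold are illustrative:

```python
import torch

risks = torch.sigmoid(torch.randn(1, 8))  # per-token risk scores, [1, 8]

# risks.item() raises RuntimeError (more than one element), and
# `if risks > 0.5:` raises too, since the truthiness of a
# multi-element tensor is ambiguous.

# Safe: reduce to a single scalar first, e.g. the latest token's score.
if risks[0, -1].item() > 0.5:
    print("intervene")
```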
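All four branches share one decode-time mechanism: when a predicted risk crosses its configured threshold, a fixed penalty is subtracted from the logits of the targeted token ids before the next token is sampled. A self-contained sketch of that logit-penalty technique; the vocabulary size, threshold, and penalty values are made up:

```python
import torch

def suppress_tokens(logits: torch.Tensor, token_ids: set, penalty: float) -> None:
    """Subtract a fixed penalty from the given token ids' logits
    (in place), lowering their probability at this decoding step."""
    logits[0, list(token_ids)] -= penalty

logits = torch.randn(1, 32000)     # next-token logits, [batch, vocab]
recent_tokens = [42, 17, 42, 99]   # ids generated in the recent window
repetition_risk = 0.91             # scalar score from a prediction head

if repetition_risk > 0.7:          # illustrative threshold
    suppress_tokens(logits, set(recent_tokens), penalty=2.0)
```

Subtracting a finite penalty rather than masking to -inf keeps the intervention soft: a suppressed token can still be sampled if every alternative is even less likely.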