LoganResearch committed on
Commit
749c71a
·
verified ·
1 Parent(s): 9728955

Fix device mismatch and tensor.item() bugs

Browse files
Files changed (1) hide show
  1. inference.py +9 -5
inference.py CHANGED
@@ -129,13 +129,15 @@ class MultiHeadPredictor(nn.Module):
129
  Returns:
130
  Aggregated features [batch, seq, d_fiber]
131
  """
 
132
  fibers = []
133
  for i, (proj, hidden) in enumerate(zip(self.fiber_projs, hidden_states)):
134
  if i < len(hidden_states):
 
135
  fibers.append(proj(hidden.float()))
136
 
137
  # Weighted sum across layers
138
- weights = F.softmax(self.layer_weights[:len(fibers)], dim=0)
139
  aggregated = sum(w * f for w, f in zip(weights, fibers))
140
  return aggregated
141
 
@@ -153,9 +155,11 @@ class MultiHeadPredictor(nn.Module):
153
  if not self.loaded_heads:
154
  return {}
155
 
 
156
  features = self.get_fiber_features(hidden_states)
157
  risks = {}
158
  for name in self.loaded_heads:
 
159
  logits = self.heads[name](features).squeeze(-1)
160
  risks[name] = torch.sigmoid(logits)
161
  return risks
@@ -370,28 +374,28 @@ class ARCSystem:
370
  interventions = {}
371
 
372
  # Repetition: suppress recently used tokens
373
- if risks.get('repetition', torch.tensor(0)).item() > self.config.repetition_threshold:
374
  for tok in set(recent_tokens[-self.config.repetition_window:]):
375
  logits[0, tok] -= self.config.repetition_penalty
376
  interventions['repetition'] = True
377
  self.total_interventions['repetition'] += 1
378
 
379
  # Hedging: suppress hedge phrase starters
380
- if risks.get('hedging', torch.tensor(0)).item() > self.config.hedging_threshold:
381
  for tok in self._hedge_token_ids:
382
  logits[0, tok] -= self.config.hedging_penalty
383
  interventions['hedging'] = True
384
  self.total_interventions['hedging'] += 1
385
 
386
  # Verbosity: suppress filler phrase starters
387
- if risks.get('verbosity', torch.tensor(0)).item() > self.config.verbosity_threshold:
388
  for tok in self._verbose_token_ids:
389
  logits[0, tok] -= self.config.verbosity_penalty
390
  interventions['verbosity'] = True
391
  self.total_interventions['verbosity'] += 1
392
 
393
  # Sycophancy: suppress sycophantic starters
394
- if risks.get('sycophancy', torch.tensor(0)).item() > self.config.sycophancy_threshold:
395
  for tok in self._sycophancy_token_ids:
396
  logits[0, tok] -= self.config.sycophancy_penalty
397
  interventions['sycophancy'] = True
 
129
  Returns:
130
  Aggregated features [batch, seq, d_fiber]
131
  """
132
+ device = hidden_states[0].device
133
  fibers = []
134
  for i, (proj, hidden) in enumerate(zip(self.fiber_projs, hidden_states)):
135
  if i < len(hidden_states):
136
+ proj = proj.to(device)
137
  fibers.append(proj(hidden.float()))
138
 
139
  # Weighted sum across layers
140
+ weights = F.softmax(self.layer_weights.to(device)[:len(fibers)], dim=0)
141
  aggregated = sum(w * f for w, f in zip(weights, fibers))
142
  return aggregated
143
 
 
155
  if not self.loaded_heads:
156
  return {}
157
 
158
+ device = hidden_states[0].device
159
  features = self.get_fiber_features(hidden_states)
160
  risks = {}
161
  for name in self.loaded_heads:
162
+ self.heads[name] = self.heads[name].to(device)
163
  logits = self.heads[name](features).squeeze(-1)
164
  risks[name] = torch.sigmoid(logits)
165
  return risks
 
374
  interventions = {}
375
 
376
  # Repetition: suppress recently used tokens
377
+ if risks.get('repetition', 0) > self.config.repetition_threshold:
378
  for tok in set(recent_tokens[-self.config.repetition_window:]):
379
  logits[0, tok] -= self.config.repetition_penalty
380
  interventions['repetition'] = True
381
  self.total_interventions['repetition'] += 1
382
 
383
  # Hedging: suppress hedge phrase starters
384
+ if risks.get('hedging', 0) > self.config.hedging_threshold:
385
  for tok in self._hedge_token_ids:
386
  logits[0, tok] -= self.config.hedging_penalty
387
  interventions['hedging'] = True
388
  self.total_interventions['hedging'] += 1
389
 
390
  # Verbosity: suppress filler phrase starters
391
+ if risks.get('verbosity', 0) > self.config.verbosity_threshold:
392
  for tok in self._verbose_token_ids:
393
  logits[0, tok] -= self.config.verbosity_penalty
394
  interventions['verbosity'] = True
395
  self.total_interventions['verbosity'] += 1
396
 
397
  # Sycophancy: suppress sycophantic starters
398
+ if risks.get('sycophancy', 0) > self.config.sycophancy_threshold:
399
  for tok in self._sycophancy_token_ids:
400
  logits[0, tok] -= self.config.sycophancy_penalty
401
  interventions['sycophancy'] = True