gary-boon Claude committed
Commit 3ee2b4b · 1 parent: 97df962

Fix: Prevent hook persistence after ablation errors


- Fixed tuple index error in layer hook when input format varies
- Added try-finally block to ensure hooks are ALWAYS removed
- Fixed indentation in generation loop
- Prevents 500 errors from persisting after failed ablation

The issue was that when an ablation hook raised an error, the hooks were never removed,
so ALL subsequent generation calls failed until the Space was restarted.
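
A minimal, self-contained sketch of the cleanup pattern this commit adopts (TinyBlock and
run_with_ablation are illustrative names, not code from model_service.py):

    import torch
    import torch.nn as nn

    class TinyBlock(nn.Module):
        def forward(self, x):
            return x * 2

    def run_with_ablation(model: nn.Module, target: nn.Module, x: torch.Tensor) -> torch.Tensor:
        def skip_layer_hook(module, inputs, output):
            # Forward hooks receive positional inputs as a (possibly empty) tuple;
            # outputs may be a bare tensor or a tuple, so handle both shapes.
            hidden = inputs[0] if isinstance(inputs, tuple) and len(inputs) > 0 else inputs
            return (hidden,) + output[1:] if isinstance(output, tuple) else hidden

        handles = [target.register_forward_hook(skip_layer_hook)]
        try:
            with torch.no_grad():
                return model(x)          # may raise; the hook must not leak either way
        finally:
            for handle in handles:       # runs on success and on error
                handle.remove()

    block = TinyBlock()
    print(run_with_ablation(block, block, torch.ones(3)))  # tensor([1., 1., 1.]) -- layer skipped

Because removal happens in the finally block, a failure inside the hook or the forward pass
leaves the model hook-free for the next request.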

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (1)
  1. backend/model_service.py +82 -73
backend/model_service.py CHANGED
@@ -299,9 +299,16 @@ class ModelManager:
         def create_layer_hook():
             def hook(module, input, output):
                 # Pass through input unchanged (skip layer)
+                # Handle different input formats
+                if isinstance(input, tuple) and len(input) > 0:
+                    input_tensor = input[0]
+                else:
+                    input_tensor = input
+
+                # Return input with same format as output
                 if isinstance(output, tuple):
-                    return (input[0],) + output[1:]
-                return input[0]
+                    return (input_tensor,) + output[1:]
+                return input_tensor
             return hook
 
         # Apply hooks and log what's being disabled
@@ -333,79 +340,81 @@ class ModelManager:
         if total_attention_disabled > 0:
             logger.info(f"Total attention heads disabled: {total_attention_disabled} / {self.model.config.n_layer * self.model.config.n_head}")
 
-        # Generation loop
-        with torch.no_grad():
-            for _ in range(max_tokens):
-                outputs = self.model(**inputs)
-                logits = outputs.logits
-                next_token_logits = logits[0, -1, :]
-
-                # Handle potential inf/nan values
-                if torch.isnan(next_token_logits).any() or torch.isinf(next_token_logits).any():
-                    # Replace inf/nan with reasonable values
-                    next_token_logits = torch.nan_to_num(next_token_logits, nan=0.0, posinf=10.0, neginf=-10.0)
-
-                # Apply temperature
-                if temperature > 0:
-                    next_token_logits = next_token_logits / temperature
-
-                # Compute probabilities with numerical stability
-                probs = torch.softmax(next_token_logits, dim=0)
-
-                # Additional safety check
-                if torch.isnan(probs).any() or (probs < 0).any() or torch.isinf(probs).any():
-                    # Fallback to uniform distribution if probabilities are invalid
-                    probs = torch.ones_like(probs) / probs.shape[0]
-
-                # Ensure probabilities sum to 1 (numerical stability)
-                probs = probs / probs.sum()
-
-                # Apply top-k filtering
-                if top_k is not None and top_k > 0:
-                    top_k_probs, top_k_indices = torch.topk(probs, min(top_k, probs.shape[0]))
-                    probs = torch.zeros_like(probs)
-                    probs[top_k_indices] = top_k_probs
-                    probs = probs / probs.sum()
-
-                # Apply top-p (nucleus) filtering
-                if top_p is not None and top_p < 1.0:
-                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
-                    cumulative_probs = torch.cumsum(sorted_probs, dim=0)
-                    sorted_indices_to_remove = cumulative_probs > top_p
-                    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
-                    sorted_indices_to_remove[0] = False
-                    indices_to_remove = sorted_indices[sorted_indices_to_remove]
-                    probs[indices_to_remove] = 0
-                    probs = probs / probs.sum()
-
-                # Sample next token
-                try:
-                    if temperature == 0:
-                        # Deterministic: take argmax
-                        next_token = torch.argmax(probs, dim=-1).unsqueeze(0)
-                    else:
-                        next_token = torch.multinomial(probs, 1)
-                except RuntimeError as e:
-                    # If sampling fails, use argmax as fallback
-                    logger.warning(f"Sampling failed, using argmax: {e}")
-                    next_token = torch.argmax(probs, dim=-1).unsqueeze(0)
-                generated_tokens.append(next_token.item())
-                token_probs.append(float(probs[next_token.item()]))
-                token_strings.append(self.tokenizer.decode([next_token.item()], skip_special_tokens=True))
-
-                # Update inputs
-                inputs = {
-                    "input_ids": torch.cat([inputs["input_ids"], next_token.unsqueeze(0)], dim=1),
-                    "attention_mask": torch.cat([inputs["attention_mask"], torch.ones((1, 1)).to(self.device)], dim=1)
-                }
-
-                # Check for end of sequence
-                if next_token.item() == self.tokenizer.eos_token_id:
-                    break
-
-        # Remove hooks
-        for handle in handles:
-            handle.remove()
+        # Generation loop - wrapped in try-finally to ensure hooks are removed
+        try:
+            with torch.no_grad():
+                for _ in range(max_tokens):
+                    outputs = self.model(**inputs)
+                    logits = outputs.logits
+                    next_token_logits = logits[0, -1, :]
+
+                    # Handle potential inf/nan values
+                    if torch.isnan(next_token_logits).any() or torch.isinf(next_token_logits).any():
+                        # Replace inf/nan with reasonable values
+                        next_token_logits = torch.nan_to_num(next_token_logits, nan=0.0, posinf=10.0, neginf=-10.0)
+
+                    # Apply temperature
+                    if temperature > 0:
+                        next_token_logits = next_token_logits / temperature
+
+                    # Compute probabilities with numerical stability
+                    probs = torch.softmax(next_token_logits, dim=0)
+
+                    # Additional safety check
+                    if torch.isnan(probs).any() or (probs < 0).any() or torch.isinf(probs).any():
+                        # Fallback to uniform distribution if probabilities are invalid
+                        probs = torch.ones_like(probs) / probs.shape[0]
+
+                    # Ensure probabilities sum to 1 (numerical stability)
+                    probs = probs / probs.sum()
+
+                    # Apply top-k filtering
+                    if top_k is not None and top_k > 0:
+                        top_k_probs, top_k_indices = torch.topk(probs, min(top_k, probs.shape[0]))
+                        probs = torch.zeros_like(probs)
+                        probs[top_k_indices] = top_k_probs
+                        probs = probs / probs.sum()
+
+                    # Apply top-p (nucleus) filtering
+                    if top_p is not None and top_p < 1.0:
+                        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
+                        cumulative_probs = torch.cumsum(sorted_probs, dim=0)
+                        sorted_indices_to_remove = cumulative_probs > top_p
+                        sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
+                        sorted_indices_to_remove[0] = False
+                        indices_to_remove = sorted_indices[sorted_indices_to_remove]
+                        probs[indices_to_remove] = 0
+                        probs = probs / probs.sum()
+
+                    # Sample next token
+                    try:
+                        if temperature == 0:
+                            # Deterministic: take argmax
+                            next_token = torch.argmax(probs, dim=-1).unsqueeze(0)
+                        else:
+                            next_token = torch.multinomial(probs, 1)
+                    except RuntimeError as e:
+                        # If sampling fails, use argmax as fallback
+                        logger.warning(f"Sampling failed, using argmax: {e}")
+                        next_token = torch.argmax(probs, dim=-1).unsqueeze(0)
+                    generated_tokens.append(next_token.item())
+                    token_probs.append(float(probs[next_token.item()]))
+                    token_strings.append(self.tokenizer.decode([next_token.item()], skip_special_tokens=True))
+
+                    # Update inputs
+                    inputs = {
+                        "input_ids": torch.cat([inputs["input_ids"], next_token.unsqueeze(0)], dim=1),
+                        "attention_mask": torch.cat([inputs["attention_mask"], torch.ones((1, 1)).to(self.device)], dim=1)
+                    }
+
+                    # Check for end of sequence
+                    if next_token.item() == self.tokenizer.eos_token_id:
+                        break
+        finally:
+            # Always remove hooks, even if there's an error
+            for handle in handles:
+                handle.remove()
+            logger.info(f"Removed {len(handles)} hooks")
 
         # Decode generated text
         generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
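
As a sanity check on the behaviour described in the commit message, here is a small standalone
repro (illustrative only, not code from this repo): a hook that raises, removed in a finally
block, no longer poisons later forward calls.

    import torch
    import torch.nn as nn

    layer = nn.Linear(4, 4)

    def bad_hook(module, inputs, output):
        return inputs[1]  # IndexError: mimics the original tuple-index bug

    handle = layer.register_forward_hook(bad_hook)
    try:
        layer(torch.randn(2, 4))   # first call fails inside the hook
    except IndexError:
        pass
    finally:
        handle.remove()            # the fix: cleanup always runs

    layer(torch.randn(2, 4))       # later calls succeed; no Space restart needed
    print("hooks still registered:", len(layer._forward_hooks))  # 0 (PyTorch-internal dict, shown only for the demo)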