Spaces: gary-boon · Running on CPU Upgrade

Claude committed · Commit 3ee2b4b · 1 Parent(s): 97df962
Fix: Prevent hook persistence after ablation errors
- Fixed tuple index error in layer hook when input format varies
- Added try-finally block to ensure hooks are ALWAYS removed
- Fixed indentation in generation loop
- Prevents 500 errors from persisting after failed ablation
The issue was that when ablation hooks failed, they weren't removed,
causing ALL subsequent generation calls to fail until the Space was restarted.
🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <noreply@anthropic.com>
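
For context, the failure mode comes from how PyTorch forward hooks work: register_forward_hook() returns a handle, and the hook stays attached to the module until handle.remove() is called, so an exception raised between registration and removal leaves the hook active for every later forward pass. Below is a minimal sketch of the register/try/finally pattern this commit adopts; the generate_with_layer_ablation and run_generation names are illustrative assumptions, and it assumes a GPT-2-style model with blocks at model.transformer.h, not the Space's actual code.

def skip_layer_hook(module, input, output):
    # Pass the block's input through unchanged, preserving tuple outputs.
    input_tensor = input[0] if isinstance(input, tuple) and len(input) > 0 else input
    if isinstance(output, tuple):
        return (input_tensor,) + output[1:]
    return input_tensor

def generate_with_layer_ablation(model, layers_to_skip, run_generation):
    # Register one hook per ablated block (GPT-2-style layout assumed).
    handles = [model.transformer.h[i].register_forward_hook(skip_layer_hook)
               for i in layers_to_skip]
    try:
        return run_generation(model)
    finally:
        # Always detach the hooks, even if generation raised; otherwise the
        # ablation silently applies to every subsequent call.
        for handle in handles:
            handle.remove()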
backend/model_service.py (+82 -73)

backend/model_service.py CHANGED
@@ -299,9 +299,16 @@ class ModelManager:
         def create_layer_hook():
             def hook(module, input, output):
                 # Pass through input unchanged (skip layer)
+                # Handle different input formats
+                if isinstance(input, tuple) and len(input) > 0:
+                    input_tensor = input[0]
+                else:
+                    input_tensor = input
+
+                # Return input with same format as output
                 if isinstance(output, tuple):
-                    return (
-                return
+                    return (input_tensor,) + output[1:]
+                return input_tensor
             return hook

         # Apply hooks and log what's being disabled
@@ -333,79 +340,81 @@ class ModelManager:
         if total_attention_disabled > 0:
             logger.info(f"Total attention heads disabled: {total_attention_disabled} / {self.model.config.n_layer * self.model.config.n_head}")

-        # Generation loop
-        # Apply top-k filtering
-        if top_k is not None and top_k > 0:
-            top_k_probs, top_k_indices = torch.topk(probs, min(top_k, probs.shape[0]))
-            probs = torch.zeros_like(probs)
-            probs[top_k_indices] = top_k_probs
-            probs = probs / probs.sum()
-
-        # Apply top-p (nucleus) filtering
-        if top_p is not None and top_p < 1.0:
-            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
-            cumulative_probs = torch.cumsum(sorted_probs, dim=0)
-            sorted_indices_to_remove = cumulative_probs > top_p
-            sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
-            sorted_indices_to_remove[0] = False
-            indices_to_remove = sorted_indices[sorted_indices_to_remove]
-            probs[indices_to_remove] = 0
+        # Generation loop - wrapped in try-finally to ensure hooks are removed
+        try:
+            with torch.no_grad():
+                for _ in range(max_tokens):
+                    outputs = self.model(**inputs)
+                    logits = outputs.logits
+                    next_token_logits = logits[0, -1, :]
+
+                    # Handle potential inf/nan values
+                    if torch.isnan(next_token_logits).any() or torch.isinf(next_token_logits).any():
+                        # Replace inf/nan with reasonable values
+                        next_token_logits = torch.nan_to_num(next_token_logits, nan=0.0, posinf=10.0, neginf=-10.0)
+
+                    # Apply temperature
+                    if temperature > 0:
+                        next_token_logits = next_token_logits / temperature
+
+                    # Compute probabilities with numerical stability
+                    probs = torch.softmax(next_token_logits, dim=0)
+
+                    # Additional safety check
+                    if torch.isnan(probs).any() or (probs < 0).any() or torch.isinf(probs).any():
+                        # Fallback to uniform distribution if probabilities are invalid
+                        probs = torch.ones_like(probs) / probs.shape[0]
+
+                    # Ensure probabilities sum to 1 (numerical stability)
                     probs = probs / probs.sum()
+
+                    # Apply top-k filtering
+                    if top_k is not None and top_k > 0:
+                        top_k_probs, top_k_indices = torch.topk(probs, min(top_k, probs.shape[0]))
+                        probs = torch.zeros_like(probs)
+                        probs[top_k_indices] = top_k_probs
+                        probs = probs / probs.sum()
+
+                    # Apply top-p (nucleus) filtering
+                    if top_p is not None and top_p < 1.0:
+                        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
+                        cumulative_probs = torch.cumsum(sorted_probs, dim=0)
+                        sorted_indices_to_remove = cumulative_probs > top_p
+                        sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
+                        sorted_indices_to_remove[0] = False
+                        indices_to_remove = sorted_indices[sorted_indices_to_remove]
+                        probs[indices_to_remove] = 0
+                        probs = probs / probs.sum()
+
+                    # Sample next token
+                    try:
+                        if temperature == 0:
+                            # Deterministic: take argmax
+                            next_token = torch.argmax(probs, dim=-1).unsqueeze(0)
+                        else:
+                            next_token = torch.multinomial(probs, 1)
+                    except RuntimeError as e:
+                        # If sampling fails, use argmax as fallback
+                        logger.warning(f"Sampling failed, using argmax: {e}")
                        next_token = torch.argmax(probs, dim=-1).unsqueeze(0)
-            break
-
-        # Remove hooks
-        for handle in handles:
-            handle.remove()
+                    generated_tokens.append(next_token.item())
+                    token_probs.append(float(probs[next_token.item()]))
+                    token_strings.append(self.tokenizer.decode([next_token.item()], skip_special_tokens=True))
+
+                    # Update inputs
+                    inputs = {
+                        "input_ids": torch.cat([inputs["input_ids"], next_token.unsqueeze(0)], dim=1),
+                        "attention_mask": torch.cat([inputs["attention_mask"], torch.ones((1, 1)).to(self.device)], dim=1)
+                    }
+
+                    # Check for end of sequence
+                    if next_token.item() == self.tokenizer.eos_token_id:
+                        break
+        finally:
+            # Always remove hooks, even if there's an error
+            for handle in handles:
+                handle.remove()
+            logger.info(f"Removed {len(handles)} hooks")

         # Decode generated text
         generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
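
As a quick way to confirm the behaviour the commit message describes, the following sketch simulates a failure inside the try block and checks that the hook is gone afterwards. It is not part of the commit and relies on the private _forward_hooks dict that torch.nn.Module uses to store hooks, which may change between PyTorch versions.

import torch
import torch.nn as nn

block = nn.Linear(4, 4)  # stand-in for a transformer block
handle = block.register_forward_hook(lambda module, inputs, output: output)

try:
    raise RuntimeError("simulated ablation failure")
except RuntimeError:
    pass
finally:
    handle.remove()

# After the finally block runs, the hook table is empty again, so later
# forward passes are unaffected by the failed ablation attempt.
assert len(block._forward_hooks) == 0
block(torch.randn(1, 4))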