gary-boon Claude commited on
Commit
070f9b8
·
1 Parent(s): 3ee2b4b

Fix: Correct layer hook output format for layer_norm compatibility

Browse files

- Fixed "layer_norm() argument 'input' must be Tensor, not tuple" error
- Hook now properly handles different input/output formats
- Maintains compatibility with subsequent layers expecting tensors

The previous fix was returning tuples when tensors were expected,
breaking layer_norm operations after skipped layers.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (1) hide show
  1. backend/model_service.py +16 -10
backend/model_service.py CHANGED
@@ -298,17 +298,23 @@ class ModelManager:
298
 
299
  def create_layer_hook():
300
  def hook(module, input, output):
301
- # Pass through input unchanged (skip layer)
302
- # Handle different input formats
303
- if isinstance(input, tuple) and len(input) > 0:
304
- input_tensor = input[0]
305
- else:
306
- input_tensor = input
307
 
308
- # Return input with same format as output
309
- if isinstance(output, tuple):
310
- return (input_tensor,) + output[1:]
311
- return input_tensor
 
 
 
 
 
 
 
 
312
  return hook
313
 
314
  # Apply hooks and log what's being disabled
 
298
 
299
  def create_layer_hook():
300
  def hook(module, input, output):
301
+ # Skip layer by passing through input unchanged
302
+ # The input to a transformer layer is (hidden_states, optional_attention_mask, ...)
303
+ # The output is (hidden_states, optional_attention_weights, ...)
304
+ # We want to pass the input hidden states as if the layer did nothing
 
 
305
 
306
+ if isinstance(input, tuple):
307
+ # Return input hidden states, but keep any additional outputs from the layer
308
+ # This maintains compatibility with the expected output format
309
+ if isinstance(output, tuple) and len(output) > 1:
310
+ # Keep attention weights etc. from output if present
311
+ return (input[0],) + output[1:]
312
+ else:
313
+ # Just return the input hidden states
314
+ return input[0]
315
+ else:
316
+ # Simple tensor input/output
317
+ return input
318
  return hook
319
 
320
  # Apply hooks and log what's being disabled