gary-boon Claude committed on
Commit
4b03268
·
1 Parent(s): 9e42df9

Fix: Refine layer hook output format handling

Browse files

- Simplified logic to match exact output structure
- Ensure compatibility with layer_norm expectations
- Handle all tuple/tensor cases properly

Testing different approach to prevent layer_norm type errors.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (1) hide show
  1. backend/model_service.py +18 -20
backend/model_service.py CHANGED
@@ -298,29 +298,27 @@ class ModelManager:
298
 
299
  def create_layer_hook():
300
  def hook(module, input, output):
301
- # Skip layer by passing through input unchanged
302
- # The input to a transformer layer is (hidden_states, optional_attention_mask, ...)
303
- # The output is (hidden_states, optional_attention_weights, ...)
304
- # We want to pass the input hidden states as if the layer did nothing
305
 
306
- # Get the input hidden states
307
- if isinstance(input, tuple) and len(input) > 0:
308
- input_hidden_states = input[0]
309
- else:
310
- input_hidden_states = input
311
 
312
- # Return in the same format as the output
313
- if isinstance(output, tuple):
314
- # Check if there are additional elements to preserve
315
- if len(output) > 1:
316
- # Keep any additional outputs (like attention weights)
317
- return (input_hidden_states,) + output[1:]
318
- else:
319
- # Output is a single-element tuple, return the same
320
- return (input_hidden_states,)
321
- else:
322
- # Output is a plain tensor, return input as plain tensor
323
  return input_hidden_states
 
 
 
 
 
 
324
  return hook
325
 
326
  # Apply hooks and log what's being disabled
 
298
 
299
  def create_layer_hook():
300
  def hook(module, input, output):
301
+ # Skip layer by making it an identity operation
302
+ # The key insight: we must match the EXACT output structure
303
+ # but replace hidden states with input hidden states
 
304
 
305
+ # For CodeGen blocks, the input/output structure is:
306
+ # input: (hidden_states,) or just hidden_states
307
+ # output: (hidden_states,) or (hidden_states, presents) etc.
 
 
308
 
309
+ # Get input hidden states
310
+ input_hidden_states = input[0] if isinstance(input, tuple) else input
311
+
312
+ # Match output structure exactly
313
+ if not isinstance(output, tuple):
314
+ # If output is a plain tensor, return input as plain tensor
 
 
 
 
 
315
  return input_hidden_states
316
+ elif len(output) == 1:
317
+ # Single element tuple - preserve as single element tuple
318
+ return (input_hidden_states,)
319
+ else:
320
+ # Multiple elements - keep all but replace hidden states
321
+ return (input_hidden_states,) + output[1:]
322
  return hook
323
 
324
  # Apply hooks and log what's being disabled