Spaces:
Sleeping
Sleeping
gary-boon
Claude
commited on
Commit
·
4b03268
1
Parent(s):
9e42df9
Fix: Refine layer hook output format handling
Browse files- Simplified logic to match exact output structure
- Ensure compatibility with layer_norm expectations
- Handle all tuple/tensor cases properly
Testing different approach to prevent layer_norm type errors.
🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <noreply@anthropic.com>
- backend/model_service.py +18 -20
backend/model_service.py
CHANGED
|
@@ -298,29 +298,27 @@ class ModelManager:
|
|
| 298 |
|
| 299 |
def create_layer_hook():
|
| 300 |
def hook(module, input, output):
|
| 301 |
-
# Skip layer by
|
| 302 |
-
# The
|
| 303 |
-
#
|
| 304 |
-
# We want to pass the input hidden states as if the layer did nothing
|
| 305 |
|
| 306 |
-
#
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
else:
|
| 310 |
-
input_hidden_states = input
|
| 311 |
|
| 312 |
-
#
|
| 313 |
-
if isinstance(
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
else:
|
| 319 |
-
# Output is a single-element tuple, return the same
|
| 320 |
-
return (input_hidden_states,)
|
| 321 |
-
else:
|
| 322 |
-
# Output is a plain tensor, return input as plain tensor
|
| 323 |
return input_hidden_states
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
return hook
|
| 325 |
|
| 326 |
# Apply hooks and log what's being disabled
|
|
|
|
| 298 |
|
| 299 |
def create_layer_hook():
|
| 300 |
def hook(module, input, output):
|
| 301 |
+
# Skip layer by making it an identity operation
|
| 302 |
+
# The key insight: we must match the EXACT output structure
|
| 303 |
+
# but replace hidden states with input hidden states
|
|
|
|
| 304 |
|
| 305 |
+
# For CodeGen blocks, the input/output structure is:
|
| 306 |
+
# input: (hidden_states,) or just hidden_states
|
| 307 |
+
# output: (hidden_states,) or (hidden_states, presents) etc.
|
|
|
|
|
|
|
| 308 |
|
| 309 |
+
# Get input hidden states
|
| 310 |
+
input_hidden_states = input[0] if isinstance(input, tuple) else input
|
| 311 |
+
|
| 312 |
+
# Match output structure exactly
|
| 313 |
+
if not isinstance(output, tuple):
|
| 314 |
+
# If output is a plain tensor, return input as plain tensor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
return input_hidden_states
|
| 316 |
+
elif len(output) == 1:
|
| 317 |
+
# Single element tuple - preserve as single element tuple
|
| 318 |
+
return (input_hidden_states,)
|
| 319 |
+
else:
|
| 320 |
+
# Multiple elements - keep all but replace hidden states
|
| 321 |
+
return (input_hidden_states,) + output[1:]
|
| 322 |
return hook
|
| 323 |
|
| 324 |
# Apply hooks and log what's being disabled
|