gary-boon and Claude committed · cd300ee
1 Parent(s): 7dd568f
Fix ablation study for Code Llama compatibility
Fixed AttributeError in ablation generation when using Code Llama:
- CodeGen uses: n_layer, n_head
- Llama/Code Llama uses: num_hidden_layers, num_attention_heads
Changes:
- Added config attribute compatibility variables at start of ablation method
- Replaced hardcoded config.n_layer and config.n_head with compatibility variables
- Now handles both CodeGen and Llama model architectures
This fixes the 500 error in the Ablation Study when using Code Llama 7B (the fallback pattern is sketched below, after the file list).
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
- backend/model_service.py +11 -4
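For readers skimming the diff below: the heart of the fix is an attribute fallback, since Hugging Face model configs expose layer and head counts under different names depending on the architecture. Here is a minimal, dependency-free sketch of that pattern; the `SimpleNamespace` objects are illustrative stand-ins, not the real config classes.

```python
from types import SimpleNamespace

def resolve_layout(config):
    """Return (num_layers, num_heads) for either config naming scheme.

    CodeGen-style configs expose n_layer / n_head, while Llama-style
    configs expose num_hidden_layers / num_attention_heads. Falling
    back through getattr avoids the AttributeError this commit fixes.
    """
    num_layers = getattr(config, 'num_hidden_layers', getattr(config, 'n_layer', 0))
    num_heads = getattr(config, 'num_attention_heads', getattr(config, 'n_head', 0))
    return num_layers, num_heads

# Illustrative stand-ins for the two config styles (not the real classes).
codegen_like = SimpleNamespace(n_layer=20, n_head=16)
llama_like = SimpleNamespace(num_hidden_layers=32, num_attention_heads=32)

print(resolve_layout(codegen_like))  # (20, 16)
print(resolve_layout(llama_like))    # (32, 32)
```

Defaulting to 0 when neither attribute is present makes `range(num_layers)` an empty loop rather than an error.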
backend/model_service.py
CHANGED
@@ -270,18 +270,25 @@ class ModelManager:
         disabled_attention = {int(k) if isinstance(k, str) else k: v for k, v in disabled_attention_raw.items()}
         disabled_ffn = set(disabled_components.get('ffn_layers', [])) if disabled_components else set()

+        # Get config attributes with compatibility for different model architectures
+        # CodeGen uses: n_layer, n_head
+        # Llama/Code Llama uses: num_hidden_layers, num_attention_heads
+        config = self.model.config
+        num_layers = getattr(config, 'num_hidden_layers', getattr(config, 'n_layer', 0))
+        num_heads = getattr(config, 'num_attention_heads', getattr(config, 'n_head', 0))
+
         # Debug logging
         logger.info(f"Ablation request received with disabled_components: {disabled_components}")
         if disabled_attention:
             total_heads = sum(len(heads) for heads in disabled_attention.values())
             logger.info(f"Total attention heads to disable: {total_heads}")
-
+
         # Tokenize input
         inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
         generated_tokens = []
         token_probs = []
         token_strings = []
-
+
         # Create hooks for ablation
         handles = []

@@ -337,7 +344,7 @@ class ModelManager:

         # Apply hooks and log what's being disabled
         total_attention_disabled = 0
-        for layer_idx in range(config.n_layer):
+        for layer_idx in range(num_layers):
             if layer_idx in disabled_layers:
                 # Disable entire layer
                 handle = self.model.transformer.h[layer_idx].register_forward_hook(create_layer_hook())
@@ -362,7 +369,7 @@ class ModelManager:

         # Log summary
         if total_attention_disabled > 0:
-            logger.info(f"Total attention heads disabled: {total_attention_disabled} / {config.n_layer * config.n_head}")
+            logger.info(f"Total attention heads disabled: {total_attention_disabled} / {num_layers * num_heads}")

         # Generation loop - wrapped in try-finally to ensure hooks are removed
         try:
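The `create_layer_hook()` factory used above is defined elsewhere in `model_service.py` and is not part of this diff. As a rough illustration of the mechanism, a PyTorch forward hook can ablate a block by replacing its output with zeros; the toy module and hook body below are assumptions for demonstration, not the Space's actual implementation.

```python
import torch
import torch.nn as nn

def create_layer_hook():
    """Hypothetical layer-ablation hook: replace a block's output with zeros."""
    def hook(module, inputs, output):
        # Transformer blocks often return tuples (hidden_states, ...);
        # zero only the hidden states and pass the rest through untouched.
        if isinstance(output, tuple):
            return (torch.zeros_like(output[0]),) + output[1:]
        return torch.zeros_like(output)
    return hook

# Toy stand-in for a transformer block, just to show the mechanics.
layer = nn.Linear(8, 8)
handle = layer.register_forward_hook(create_layer_hook())

out = layer(torch.randn(2, 8))
print(out.abs().sum().item())  # 0.0 -- the layer's contribution is ablated

handle.remove()  # hooks must be removed after generation
```

Returning a value from a forward hook replaces the module's output, which is why the generation loop must remove the handles afterwards (hence the try/finally noted in the diff).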