File size: 20,251 Bytes
c535f7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
#!/usr/bin/env python3
"""
Foreground vs Background Token Analysis

Analyzes the distribution of tokens based on whether they are in GT box (foreground) or not (background).
- Foreground: HR tokens (have parent_lr_idx) - inside GT box - Light Blue
- Background: LR tokens (no parent_lr_idx replacement) - outside GT box - Light Red

Visualization:
- X-axis: Feature Norm (0 to max)
- Left Y-axis: Token Count (blue line for foreground, red line for background)
- Right Y-axis: Average Attention Weight
"""

import os
import sys
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import glob
from typing import List, Dict, Tuple
import argparse

# Add custom models path
sys.path.insert(0, '/mnt/rdata4_6/huixin/LLaMA-Factory-main/custom_models/qwen2_5_vl')


def load_mts_sample(npz_path: str) -> dict:
    """Load one MTS attention sample from a .npz archive.

    Args:
        npz_path: Path to the sample's .npz file.

    Returns:
        Plain dict mapping each stored array name to its numpy array.
    """
    archive = np.load(npz_path, allow_pickle=True)
    return dict(archive.items())


def compute_feature_norm_from_attention(attention: np.ndarray, n_visual_tokens: int) -> np.ndarray:
    """
    Proxy feature norm: total attention each visual token receives.

    Visual tokens occupy positions [0, n_visual_tokens); text tokens follow.
    Summing the attention mass flowing from every text token (across all
    heads) into each visual token gives a measure of how strongly that
    visual token is attended to, used as a stand-in for feature importance.

    Args:
        attention: (num_heads, seq_len, seq_len)
        n_visual_tokens: Number of visual tokens

    Returns:
        feature_norms: (n_visual_tokens,) - proxy norm per visual token
    """
    # Slice rows = text tokens (positions after the visual block),
    # columns = visual tokens, keeping all heads.
    received = attention[:, n_visual_tokens:, :n_visual_tokens]

    # Collapse heads first, then text positions, leaving one value per
    # visual token.
    return received.sum(axis=0).sum(axis=0)


def compute_true_feature_norm(vision_features: np.ndarray) -> np.ndarray:
    """
    Compute the true Euclidean (L2) norm of each vision feature vector.

    Args:
        vision_features: (n_visual_tokens, hidden_dim) - vision feature embeddings

    Returns:
        feature_norms: (n_visual_tokens,) - L2 norm for each visual token
    """
    # ||z_i||_2 = sqrt(sum_j z_ij^2), reduced over the hidden dimension.
    squared_sums = (vision_features * vision_features).sum(axis=-1)
    return np.sqrt(squared_sums)


def analyze_foreground_background_distribution(
    sample_data: dict,
    layer_idx: int,
    n_bins: int = 50
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Analyze foreground vs background token distribution for one sample.

    Foreground = merged visual positions covered by at least one HR token
    (inside the GT box); background = positions containing only LR tokens.

    Args:
        sample_data: Loaded .npz sample data
        layer_idx: Layer index to analyze
        n_bins: Number of bins for histogram

    Returns:
        bin_centers: (n_bins,) - center of each bin for x-axis
        fg_counts: (n_bins,) - foreground token count per bin
        bg_counts: (n_bins,) - background token count per bin
        fg_attn_mean: (n_bins,) - mean attention for foreground tokens per bin (NaN = empty bin)
        bg_attn_mean: (n_bins,) - mean attention for background tokens per bin (NaN = empty bin)

        All five are None when the requested layer is absent from the sample.
    """
    attn_key = f'attention_layer_{layer_idx}'
    if attn_key not in sample_data:
        return None, None, None, None, None

    attention = sample_data[attn_key]  # (num_heads, seq_len, seq_len)
    is_hr_token = sample_data['is_hr_token']  # (n_unmerged_tokens,) bool mask
    n_visual_merged = sample_data.get('n_visual_merged', None)

    if n_visual_merged is None:
        # Estimate merged token count from the LR grid: a 2x2 spatial merge
        # packs 4 raw tokens into one merged token.
        t, h, w = tuple(sample_data['lr_grid_shape'])
        spatial_merge_size = 2
        n_visual_merged = (t * h * w) // (spatial_merge_size ** 2)

    n_visual_tokens = int(n_visual_merged)

    # Text tokens sit immediately after the visual block; slice the
    # text->visual attention once and reuse it below.
    text_to_visual = attention[:, n_visual_tokens:, :n_visual_tokens]  # (num_heads, n_text, n_visual)

    # Feature norms: prefer the true L2 norm of the stored vision features;
    # otherwise fall back to the total attention each visual token receives
    # (sum over heads and text positions) as a proxy.
    vision_features = sample_data.get('vision_features', None)
    if vision_features is not None and len(vision_features) >= n_visual_tokens:
        feature_norms = compute_true_feature_norm(vision_features[:n_visual_tokens])
    else:
        feature_norms = text_to_visual.sum(axis=(0, 1))
        print(f"    Warning: No vision_features found, using attention proxy for feature norm")

    # Mean attention per visual token, averaged over heads and text tokens.
    attn_weights = text_to_visual.mean(axis=(0, 1))  # (n_visual,)

    # KEY: is_hr_token indexes UNMERGED tokens (before the 2x2 spatial
    # merge); map HR flags onto merged positions to classify the final
    # sequence. A merged position is foreground if ANY of its 4 source
    # tokens is HR.
    foreground_mask = _hr_mask_to_merged(is_hr_token, n_visual_tokens)
    background_mask = ~foreground_mask

    print(f"    Visual tokens: {n_visual_tokens}, HR: {foreground_mask.sum()}, LR: {background_mask.sum()}")

    fg_norms = feature_norms[foreground_mask]
    bg_norms = feature_norms[background_mask]
    fg_attn = attn_weights[foreground_mask]
    bg_attn = attn_weights[background_mask]

    # Bin over the full [0, max] range. np.histogram treats the last bin as
    # right-inclusive; _mean_attention_per_bin matches that convention.
    bins = np.linspace(0, feature_norms.max(), n_bins + 1)
    bin_centers = (bins[:-1] + bins[1:]) / 2

    fg_counts, _ = np.histogram(fg_norms, bins=bins)
    bg_counts, _ = np.histogram(bg_norms, bins=bins)

    fg_attn_mean = _mean_attention_per_bin(fg_norms, fg_attn, bins)
    bg_attn_mean = _mean_attention_per_bin(bg_norms, bg_attn, bins)

    return bin_centers, fg_counts, bg_counts, fg_attn_mean, bg_attn_mean


def _hr_mask_to_merged(is_hr_token: np.ndarray, n_visual_tokens: int, merge_unit: int = 4) -> np.ndarray:
    """Map an unmerged HR-token mask to merged positions (2x2 merge => 4 tokens each)."""
    merged = np.zeros(n_visual_tokens, dtype=bool)
    positions = np.nonzero(np.asarray(is_hr_token))[0] // merge_unit
    positions = positions[positions < n_visual_tokens]
    merged[positions] = True
    return merged


def _mean_attention_per_bin(norms: np.ndarray, attn: np.ndarray, bins: np.ndarray) -> np.ndarray:
    """Mean of `attn` within each norm bin; NaN for empty bins.

    Bins are half-open [lo, hi) except the LAST bin, which includes its
    right edge — matching np.histogram. (The previous strictly-< test on
    every bin silently dropped tokens sitting exactly at the max norm from
    the attention averages while still counting them in the histograms.)
    """
    n_bins = len(bins) - 1
    means = np.full(n_bins, np.nan)
    for i in range(n_bins):
        in_bin = (norms >= bins[i]) & (norms < bins[i + 1])
        if i == n_bins - 1:
            in_bin |= norms == bins[i + 1]  # right-inclusive last bin
        if in_bin.any():
            means[i] = attn[in_bin].mean()
    return means


def visualize_foreground_background_analysis(
    bin_centers: np.ndarray,
    fg_counts: np.ndarray,
    bg_counts: np.ndarray,
    fg_attn_mean: np.ndarray,
    bg_attn_mean: np.ndarray,
    layer_idx: int,
    output_path: str,
    sample_ids: List[int] = None
):
    """
    Create a dual Y-axis plot for foreground vs background analysis.

    Left axis: token counts per feature-norm bin, drawn as lines
    (zero-count bins are blanked so the lines break instead of dipping
    to zero). Right axis: mean attention weight per bin, drawn as paired
    half-width bars (foreground left, background right).

    Args:
        bin_centers: (n_bins,) - x-axis values
        fg_counts: (n_bins,) - foreground token counts
        bg_counts: (n_bins,) - background token counts
        fg_attn_mean: (n_bins,) - mean attention for foreground (NaN = empty bin)
        bg_attn_mean: (n_bins,) - mean attention for background (NaN = empty bin)
        layer_idx: Layer index for title
        output_path: Where to save the plot
        sample_ids: List of sample IDs used (for title)
    """
    # Set font (use DejaVu Sans which is commonly available on Linux)
    plt.rcParams['font.family'] = 'sans-serif'
    plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial', 'Helvetica', 'sans-serif']

    fig, ax1 = plt.subplots(figsize=(14, 8))

    # Left Y-axis: token counts
    ax1.set_xlabel('L2 Norm of Visual Token', fontsize=28, fontweight='bold')
    ax1.set_ylabel('Token Count', fontsize=28, fontweight='bold', color='black')

    # Convert 0 counts to NaN so the count lines don't connect through 0
    fg_counts_masked = fg_counts.astype(float)
    fg_counts_masked[fg_counts_masked == 0] = np.nan
    bg_counts_masked = bg_counts.astype(float)
    bg_counts_masked[bg_counts_masked == 0] = np.nan

    line_fg_count = ax1.plot(bin_centers, fg_counts_masked, color='dodgerblue', linewidth=2.5,
                             label='Foreground Tokens', marker='o', markersize=4, alpha=0.8)
    line_bg_count = ax1.plot(bin_centers, bg_counts_masked, color='lightcoral', linewidth=2.5,
                             label='Background Tokens', marker='s', markersize=4, alpha=0.8)

    ax1.tick_params(axis='y', labelcolor='black', labelsize=24)
    ax1.tick_params(axis='x', labelsize=24)
    ax1.grid(True, alpha=0.3, linestyle='--')

    # Right Y-axis: attention weights (bar chart)
    ax2 = ax1.twinx()
    ax2.set_ylabel('Average Attention Weight', fontsize=28, fontweight='bold', color='darkgreen')

    # Each bin is split into two half-width bars (FG left, BG right) so
    # adjacent bars touch without overlapping.
    if len(bin_centers) > 1:
        bar_width = (bin_centers[1] - bin_centers[0]) / 2
    else:
        bar_width = 0.5

    ax2.bar(bin_centers - bar_width/2, fg_attn_mean,
           width=bar_width, color='blue', alpha=0.3,
           label='Foreground Attention', edgecolor='darkblue', linewidth=0.5, align='center')
    ax2.bar(bin_centers + bar_width/2, bg_attn_mean,
           width=bar_width, color='red', alpha=0.3,
           label='Background Attention', edgecolor='darkred', linewidth=0.5, align='center')

    # Proxy patches for the bar legend. Uses the module-level Patch import;
    # the previous body re-imported Patch locally, which was redundant.
    fg_attn_patch = Patch(facecolor='blue', alpha=0.3, edgecolor='darkblue', linewidth=0.5, label='Foreground Attention')
    bg_attn_patch = Patch(facecolor='red', alpha=0.3, edgecolor='darkred', linewidth=0.5, label='Background Attention')

    ax2.tick_params(axis='y', labelcolor='darkgreen', labelsize=24)

    # Title
    sample_str = f"Samples {sample_ids}" if sample_ids else "All Samples"
    plt.title(f'Foreground vs Background Token Analysis - Layer {layer_idx}\n{sample_str}',
              fontsize=32, fontweight='bold', pad=20)

    # Separate legends: count lines on the left axis, attention bars on the right
    count_lines = line_fg_count + line_bg_count
    ax1.legend(count_lines, [l.get_label() for l in count_lines],
               loc='upper left', fontsize=22, framealpha=0.9)
    ax2.legend(handles=[fg_attn_patch, bg_attn_patch], loc='upper right', fontsize=22, framealpha=0.9)

    # Info box with overall attention statistics (nanmean skips empty bins)
    fg_total = fg_counts.sum()
    bg_total = bg_counts.sum()
    fg_attn_overall = np.nanmean(fg_attn_mean)
    bg_attn_overall = np.nanmean(bg_attn_mean)

    info_text = f"Avg Attention:\n"
    info_text += f"  Foreground: {fg_attn_overall:.4f}\n"
    info_text += f"  Background: {bg_attn_overall:.4f}"

    # Middle-left placement avoids overlap with either legend
    ax1.text(0.02, 0.5, info_text, transform=ax1.transAxes,
            fontsize=20, verticalalignment='center',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8, edgecolor='black', linewidth=1.5))

    # Total token statistics below the x-axis label
    total_text = f"Total Tokens: Foreground = {int(fg_total)}, Background = {int(bg_total)}"
    fig.text(0.5, 0.02, total_text, ha='center', fontsize=20, fontweight='bold')

    plt.tight_layout(rect=[0, 0.04, 1, 1])  # Leave space at bottom for total tokens text
    plt.savefig(output_path, dpi=200, bbox_inches='tight')
    plt.close()
    print(f"✓ Saved: {output_path}")


def aggregate_multiple_samples(
    data_dir: str,
    layer_idx: int,
    sample_ids: List[int] = None,
    n_bins: int = 50
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[int]]:
    """
    Aggregate foreground/background analysis across multiple samples.

    Args:
        data_dir: Directory with .npz files
        layer_idx: Layer to analyze
        sample_ids: List of sample IDs to use (None = all)
        n_bins: Number of bins

    Returns:
        bin_centers, fg_counts, bg_counts, fg_attn_mean, bg_attn_mean,
        processed_sample_ids. The first five are None (and the ID list
        empty) when no usable samples are found. Attention means are NaN
        for bins with no tokens.
    """
    # Find all .npz files
    npz_files = sorted(glob.glob(os.path.join(data_dir, 'sample_*.npz')))

    if not npz_files:
        print(f"Error: No .npz files found in {data_dir}")
        return None, None, None, None, None, []

    # Filter by sample_ids if provided
    if sample_ids:
        npz_files = [f for f in npz_files
                    if any(f"sample_{sid:04d}.npz" in f for sid in sample_ids)]

    print(f"Processing {len(npz_files)} samples for layer {layer_idx}...")

    # Per-sample norm/attention arrays, concatenated at the end
    all_fg_norms = []
    all_bg_norms = []
    all_fg_attn = []
    all_bg_attn = []
    processed_ids = []

    for npz_path in npz_files:
        # Sample ID is encoded in the filename: sample_<id>.npz
        sample_id = int(os.path.basename(npz_path).split('_')[1].split('.')[0])
        data = np.load(npz_path, allow_pickle=True)
        sample_data = {key: data[key] for key in data.files}

        attn_key = f'attention_layer_{layer_idx}'
        if attn_key not in sample_data:
            print(f"  Warning: Layer {layer_idx} not found in sample {sample_id}")
            continue

        attention = sample_data[attn_key]
        is_hr_token = sample_data['is_hr_token']
        n_visual_merged = sample_data.get('n_visual_merged', None)

        if n_visual_merged is None:
            # Derive merged count from the LR grid (2x2 merge = 4 tokens each)
            t, h, w = tuple(sample_data['lr_grid_shape'])
            spatial_merge_size = 2
            n_visual_merged = (t * h * w) // (spatial_merge_size ** 2)

        n_visual_tokens = int(n_visual_merged)

        # Text tokens follow the visual block; slice text->visual attention once
        text_to_visual = attention[:, n_visual_tokens:, :n_visual_tokens]

        # Feature norms: true L2 norm when vision features were saved,
        # otherwise the attention-received proxy (sum over heads/text tokens)
        vision_features = sample_data.get('vision_features', None)
        if vision_features is not None and len(vision_features) >= n_visual_tokens:
            feature_norms = compute_true_feature_norm(vision_features[:n_visual_tokens])
        else:
            feature_norms = text_to_visual.sum(axis=(0, 1))

        attn_weights = text_to_visual.mean(axis=(0, 1))

        # is_hr_token marks UNMERGED tokens; project HR flags onto merged
        # positions (a merged position is FG if any of its 4 sources is HR)
        merge_unit = 4  # 2x2 spatial merge
        is_hr_token_merged = np.zeros(n_visual_tokens, dtype=bool)
        merged_positions = np.nonzero(np.asarray(is_hr_token))[0] // merge_unit
        merged_positions = merged_positions[merged_positions < n_visual_tokens]
        is_hr_token_merged[merged_positions] = True

        foreground_mask = is_hr_token_merged
        background_mask = ~is_hr_token_merged

        # Collect data
        all_fg_norms.append(feature_norms[foreground_mask])
        all_bg_norms.append(feature_norms[background_mask])
        all_fg_attn.append(attn_weights[foreground_mask])
        all_bg_attn.append(attn_weights[background_mask])
        processed_ids.append(sample_id)

        print(f"  ✓ Sample {sample_id}: n_visual={n_visual_tokens}, FG(HR)={foreground_mask.sum()}, BG(LR)={background_mask.sum()}")

    if not all_fg_norms:
        print(f"Error: No valid samples found")
        return None, None, None, None, None, []

    # Concatenate all samples
    all_fg_norms = np.concatenate(all_fg_norms)
    all_bg_norms = np.concatenate(all_bg_norms)
    all_fg_attn = np.concatenate(all_fg_attn)
    all_bg_attn = np.concatenate(all_bg_attn)

    print(f"\nAggregated: FG={len(all_fg_norms)}, BG={len(all_bg_norms)} tokens")

    # Bin over [0, max] of ALL norms. Taking the max of the combined array
    # also guards against one side being empty (the previous
    # max(fg.max(), bg.max()) raised ValueError on an empty array).
    combined_norms = np.concatenate([all_fg_norms, all_bg_norms])
    norm_max = combined_norms.max() if combined_norms.size else 1.0
    bins = np.linspace(0, norm_max, n_bins + 1)
    bin_centers = (bins[:-1] + bins[1:]) / 2

    # Compute histograms (np.histogram's last bin is right-inclusive)
    fg_counts, _ = np.histogram(all_fg_norms, bins=bins)
    bg_counts, _ = np.histogram(all_bg_norms, bins=bins)

    # Mean attention per bin, NaN where empty. Bins are half-open [lo, hi)
    # except the LAST, which includes its right edge to match np.histogram;
    # otherwise tokens sitting exactly at the max norm are counted in the
    # histograms but silently dropped from the attention averages.
    fg_attn_mean = np.full(n_bins, np.nan)
    bg_attn_mean = np.full(n_bins, np.nan)

    for i in range(n_bins):
        mask_fg = (all_fg_norms >= bins[i]) & (all_fg_norms < bins[i+1])
        mask_bg = (all_bg_norms >= bins[i]) & (all_bg_norms < bins[i+1])
        if i == n_bins - 1:
            mask_fg |= all_fg_norms == bins[i+1]
            mask_bg |= all_bg_norms == bins[i+1]
        if mask_fg.any():
            fg_attn_mean[i] = all_fg_attn[mask_fg].mean()
        if mask_bg.any():
            bg_attn_mean[i] = all_bg_attn[mask_bg].mean()

    return bin_centers, fg_counts, bg_counts, fg_attn_mean, bg_attn_mean, processed_ids


def main():
    """CLI entry point: parse arguments, run per-layer FG/BG analysis, save plots."""
    parser = argparse.ArgumentParser(
        description='Analyze Foreground vs Background Token Distribution'
    )
    parser.add_argument('--data_dir', type=str,
                       default='/mnt/rdata4_6/huixin/LLaMA-Factory-main/attn_vis/paco_attention_data_checkpoint5877',
                       help='Directory containing .npz files')
    parser.add_argument('--output_dir', type=str,
                       default='/mnt/rdata4_6/huixin/LLaMA-Factory-main/attn_vis/paco_attention_data_checkpoint5877_analysis',
                       help='Output directory')
    parser.add_argument('--layers', type=int, nargs='+', default=[2, 6, 15, 27],
                       help='Layers to analyze')
    parser.add_argument('--samples', type=int, nargs='+', default=None,
                       help='Specific sample IDs (default: all samples)')
    parser.add_argument('--n_bins', type=int, default=50,
                       help='Number of bins for histogram')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    rule = "=" * 70
    print(rule)
    print("Foreground vs Background Token Analysis")
    print(rule)
    print(f"Data directory: {args.data_dir}")
    print(f"Output directory: {args.output_dir}")
    print(f"Layers: {args.layers}")
    print(f"Bins: {args.n_bins}")
    print(rule)

    # One aggregation + one plot per requested layer
    for layer_idx in args.layers:
        print(f"\n{rule}")
        print(f"Processing Layer {layer_idx}")
        print(rule)

        results = aggregate_multiple_samples(
            data_dir=args.data_dir,
            layer_idx=layer_idx,
            sample_ids=args.samples,
            n_bins=args.n_bins
        )
        bin_centers, fg_counts, bg_counts, fg_attn_mean, bg_attn_mean, sample_ids = results

        if bin_centers is None:
            print(f"Skipping layer {layer_idx} due to errors")
            continue

        plot_path = os.path.join(
            args.output_dir,
            f"layer_{layer_idx}_foreground_background_analysis.png"
        )
        visualize_foreground_background_analysis(
            bin_centers=bin_centers,
            fg_counts=fg_counts,
            bg_counts=bg_counts,
            fg_attn_mean=fg_attn_mean,
            bg_attn_mean=bg_attn_mean,
            layer_idx=layer_idx,
            output_path=plot_path,
            sample_ids=sample_ids
        )

    print(f"\n{rule}")
    print(f"✓ Analysis complete! Results saved to: {args.output_dir}")
    print(rule)


if __name__ == '__main__':
    main()