#!/usr/bin/env python3
"""
Foreground vs Background Token Analysis
Analyzes the distribution of tokens based on whether they are in GT box (foreground) or not (background).
- Foreground: HR tokens (have parent_lr_idx) - inside GT box - Light Blue
- Background: LR tokens (no parent_lr_idx replacement) - outside GT box - Light Red
Visualization:
- X-axis: Feature Norm (0 to max)
- Left Y-axis: Token Count (blue line for foreground, red line for background)
- Right Y-axis: Average Attention Weight
"""
import os
import sys
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import glob
from typing import List, Dict, Tuple
import argparse
# Add custom models path
sys.path.insert(0, '/mnt/rdata4_6/huixin/LLaMA-Factory-main/custom_models/qwen2_5_vl')
def load_mts_sample(npz_path: str) -> dict:
    """Read one MTS attention sample stored as an .npz archive.

    Args:
        npz_path: Path to the .npz file on disk.

    Returns:
        Dict mapping each stored array name to its array.
    """
    archive = np.load(npz_path, allow_pickle=True)
    return {name: archive[name] for name in archive.files}
def compute_feature_norm_from_attention(attention: np.ndarray, n_visual_tokens: int) -> np.ndarray:
    """Proxy for per-token feature norm based on attention received.

    The total attention each visual token receives from every text token,
    accumulated over all heads, stands in for feature importance/norm when
    true embeddings are unavailable.

    Args:
        attention: Attention tensor of shape (num_heads, seq_len, seq_len).
        n_visual_tokens: Count of visual tokens at the front of the sequence.

    Returns:
        Array of shape (n_visual_tokens,) with one proxy-norm per visual token.
    """
    # Text tokens occupy positions n_visual_tokens..seq_len-1; slice their
    # attention onto the visual block, then collapse heads and text queries.
    received = attention[:, n_visual_tokens:, :n_visual_tokens]  # (num_heads, n_text, n_visual)
    return received.sum(axis=(0, 1))  # (n_visual,)
def compute_true_feature_norm(vision_features: np.ndarray) -> np.ndarray:
    """Return the L2 norm of every vision-feature embedding.

    Args:
        vision_features: Array of shape (n_visual_tokens, hidden_dim).

    Returns:
        Array of shape (n_visual_tokens,) holding ||z_i||_2 per token.
    """
    # ||z_i||_2 = sqrt(sum(z_i^2)) along the hidden dimension
    return np.sqrt(np.square(vision_features).sum(axis=-1))
def analyze_foreground_background_distribution(
    sample_data: dict,
    layer_idx: int,
    n_bins: int = 50
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Analyze foreground vs background token distribution for one sample.

    Foreground = merged positions covered by HR tokens (inside the GT box);
    background = merged positions containing only LR tokens.

    Args:
        sample_data: Loaded .npz sample data
        layer_idx: Layer index to analyze
        n_bins: Number of bins for histogram

    Returns:
        bin_centers: (n_bins,) - center of each bin for x-axis
        fg_counts: (n_bins,) - foreground token count per bin
        bg_counts: (n_bins,) - background token count per bin
        fg_attn_mean: (n_bins,) - mean attention for foreground tokens per bin
        bg_attn_mean: (n_bins,) - mean attention for background tokens per bin
        (all five are None when the requested layer is missing)
    """
    attn_key = f'attention_layer_{layer_idx}'
    if attn_key not in sample_data:
        return None, None, None, None, None

    attention = sample_data[attn_key]  # (num_heads, seq_len, seq_len)
    is_hr_token = sample_data['is_hr_token']  # bool mask over UNMERGED tokens

    n_visual_merged = sample_data.get('n_visual_merged', None)
    if n_visual_merged is None:
        # Estimate the merged token count from the LR grid shape
        lr_grid_shape = tuple(sample_data['lr_grid_shape'])
        t, h, w = lr_grid_shape
        spatial_merge_size = 2
        n_visual_merged = (t * h * w) // (spatial_merge_size ** 2)
    n_visual_tokens = int(n_visual_merged)

    # Feature norms: prefer the true L2 norm of the vision embeddings,
    # otherwise fall back to an attention-received proxy.
    vision_features = sample_data.get('vision_features', None)
    if vision_features is not None and len(vision_features) >= n_visual_tokens:
        feature_norms = compute_true_feature_norm(vision_features[:n_visual_tokens])
    else:
        feature_norms = compute_feature_norm_from_attention(attention, n_visual_tokens)
        print(f" Warning: No vision_features found, using attention proxy for feature norm")

    # Per-token attention weight: mean of text->visual attention over heads and text tokens
    text_start_idx = n_visual_tokens
    text_to_visual = attention[:, text_start_idx:, :n_visual_tokens]  # (num_heads, n_text, n_visual)
    attn_weights = text_to_visual.mean(axis=(0, 1))  # (n_visual,)

    # KEY: is_hr_token indexes UNMERGED tokens (before the 2x2 spatial merge);
    # map each unmerged HR token to its merged position in the final sequence.
    merge_unit = 4  # 2x2 spatial merge = 4 tokens per merged token
    is_hr_token_merged = np.zeros(n_visual_tokens, dtype=bool)
    for unmerged_idx in range(len(is_hr_token)):
        if is_hr_token[unmerged_idx]:
            merged_idx = unmerged_idx // merge_unit
            if merged_idx < n_visual_tokens:
                is_hr_token_merged[merged_idx] = True

    # Foreground = positions with HR tokens, Background = positions with only LR
    foreground_mask = is_hr_token_merged
    background_mask = ~is_hr_token_merged
    print(f" Visual tokens: {n_visual_tokens}, HR: {foreground_mask.sum()}, LR: {background_mask.sum()}")

    fg_norms = feature_norms[foreground_mask]
    bg_norms = feature_norms[background_mask]
    fg_attn = attn_weights[foreground_mask]
    bg_attn = attn_weights[background_mask]

    # Shared bins over the full norm range [0, max]
    bins = np.linspace(0, feature_norms.max(), n_bins + 1)
    bin_centers = (bins[:-1] + bins[1:]) / 2

    fg_counts, _ = np.histogram(fg_norms, bins=bins)
    bg_counts, _ = np.histogram(bg_norms, bins=bins)

    def _binned_attn_mean(norms: np.ndarray, attn: np.ndarray) -> np.ndarray:
        """Mean attention per bin; NaN where a bin holds no tokens."""
        out = np.full(n_bins, np.nan)
        for i in range(n_bins):
            in_bin = (norms >= bins[i]) & (norms < bins[i + 1])
            if i == n_bins - 1:
                # BUGFIX: np.histogram's last bin is right-INCLUSIVE; mirror that
                # here so the token(s) at the maximum norm are not dropped from
                # the mean while still being counted in fg_counts/bg_counts.
                in_bin |= norms == bins[-1]
            if in_bin.any():
                out[i] = attn[in_bin].mean()
        return out

    fg_attn_mean = _binned_attn_mean(fg_norms, fg_attn)
    bg_attn_mean = _binned_attn_mean(bg_norms, bg_attn)

    return bin_centers, fg_counts, bg_counts, fg_attn_mean, bg_attn_mean
def visualize_foreground_background_analysis(
    bin_centers: np.ndarray,
    fg_counts: np.ndarray,
    bg_counts: np.ndarray,
    fg_attn_mean: np.ndarray,
    bg_attn_mean: np.ndarray,
    layer_idx: int,
    output_path: str,
    sample_ids: List[int] = None
):
    """
    Create a dual Y-axis plot comparing foreground vs background tokens.

    Left axis: token counts per norm bin (lines). Right axis: mean attention
    per norm bin (side-by-side bars). The figure is written to ``output_path``.

    Args:
        bin_centers: (n_bins,) - x-axis values
        fg_counts: (n_bins,) - foreground token counts
        bg_counts: (n_bins,) - background token counts
        fg_attn_mean: (n_bins,) - mean attention for foreground
        bg_attn_mean: (n_bins,) - mean attention for background
        layer_idx: Layer index for title
        output_path: Where to save the plot
        sample_ids: List of sample IDs used (for title)
    """
    # Use DejaVu Sans which is commonly available on Linux
    plt.rcParams['font.family'] = 'sans-serif'
    plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial', 'Helvetica', 'sans-serif']

    fig, ax1 = plt.subplots(figsize=(14, 8))

    # Left Y-axis: token counts
    ax1.set_xlabel('L2 Norm of Visual Token', fontsize=28, fontweight='bold')
    ax1.set_ylabel('Token Count', fontsize=28, fontweight='bold', color='black')

    # Mask zero counts as NaN so the count lines break instead of dipping to 0
    fg_counts_masked = fg_counts.astype(float)
    fg_counts_masked[fg_counts_masked == 0] = np.nan
    bg_counts_masked = bg_counts.astype(float)
    bg_counts_masked[bg_counts_masked == 0] = np.nan

    line_fg_count = ax1.plot(bin_centers, fg_counts_masked, color='dodgerblue', linewidth=2.5,
                             label='Foreground Tokens', marker='o', markersize=4, alpha=0.8)
    line_bg_count = ax1.plot(bin_centers, bg_counts_masked, color='lightcoral', linewidth=2.5,
                             label='Background Tokens', marker='s', markersize=4, alpha=0.8)
    ax1.tick_params(axis='y', labelcolor='black', labelsize=24)
    ax1.tick_params(axis='x', labelsize=24)
    ax1.grid(True, alpha=0.3, linestyle='--')

    # Right Y-axis: attention weights as paired bars
    ax2 = ax1.twinx()
    ax2.set_ylabel('Average Attention Weight', fontsize=28, fontweight='bold', color='darkgreen')

    # Half the bin spacing per bar so the FG/BG bars sit side by side without gaps
    if len(bin_centers) > 1:
        bar_width = (bin_centers[1] - bin_centers[0]) / 2
    else:
        bar_width = 0.5

    # Foreground bars on the left half of each bin, background on the right
    ax2.bar(bin_centers - bar_width/2, fg_attn_mean,
            width=bar_width, color='blue', alpha=0.3,
            label='Foreground Attention', edgecolor='darkblue', linewidth=0.5, align='center')
    ax2.bar(bin_centers + bar_width/2, bg_attn_mean,
            width=bar_width, color='red', alpha=0.3,
            label='Background Attention', edgecolor='darkred', linewidth=0.5, align='center')

    # Legend proxies for the bar series (Patch is imported at module level;
    # the previous duplicate in-function import was removed)
    attn_handles = [
        Patch(facecolor='blue', alpha=0.3, edgecolor='darkblue', linewidth=0.5, label='Foreground Attention'),
        Patch(facecolor='red', alpha=0.3, edgecolor='darkred', linewidth=0.5, label='Background Attention'),
    ]
    ax2.tick_params(axis='y', labelcolor='darkgreen', labelsize=24)

    # Title
    sample_str = f"Samples {sample_ids}" if sample_ids else "All Samples"
    plt.title(f'Foreground vs Background Token Analysis - Layer {layer_idx}\n{sample_str}',
              fontsize=32, fontweight='bold', pad=20)

    # Combined legends: count lines on ax1, attention bars on ax2
    count_handles = line_fg_count + line_bg_count
    ax1.legend(count_handles, [h.get_label() for h in count_handles],
               loc='upper left', fontsize=22, framealpha=0.9)
    ax2.legend(handles=attn_handles, loc='upper right', fontsize=22, framealpha=0.9)

    # Info box with overall mean attention (nanmean skips empty bins)
    fg_attn_overall = np.nanmean(fg_attn_mean)
    bg_attn_overall = np.nanmean(bg_attn_mean)
    info_text = f"Avg Attention:\n"
    info_text += f" Foreground: {fg_attn_overall:.4f}\n"
    info_text += f" Background: {bg_attn_overall:.4f}"
    # Middle left placement avoids overlap with both legends
    ax1.text(0.02, 0.5, info_text, transform=ax1.transAxes,
             fontsize=20, verticalalignment='center',
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8, edgecolor='black', linewidth=1.5))

    # Total token statistics below the x-axis label
    total_text = f"Total Tokens: Foreground = {int(fg_counts.sum())}, Background = {int(bg_counts.sum())}"
    fig.text(0.5, 0.02, total_text, ha='center', fontsize=20, fontweight='bold')

    plt.tight_layout(rect=[0, 0.04, 1, 1])  # leave space at bottom for the caption
    plt.savefig(output_path, dpi=200, bbox_inches='tight')
    plt.close()
    print(f"✓ Saved: {output_path}")
def aggregate_multiple_samples(
    data_dir: str,
    layer_idx: int,
    sample_ids: List[int] = None,
    n_bins: int = 50
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[int]]:
    """
    Aggregate foreground/background analysis across multiple samples.

    Args:
        data_dir: Directory with .npz files
        layer_idx: Layer to analyze
        sample_ids: List of sample IDs to use (None = all)
        n_bins: Number of bins

    Returns:
        Aggregated bin_centers, fg_counts, bg_counts, fg_attn_mean,
        bg_attn_mean, processed_sample_ids. The first five are None (and the
        id list empty) when no usable sample is found.
    """
    npz_files = sorted(glob.glob(os.path.join(data_dir, 'sample_*.npz')))
    if not npz_files:
        print(f"Error: No .npz files found in {data_dir}")
        return None, None, None, None, None, []

    # Restrict to the requested sample IDs when provided
    if sample_ids:
        npz_files = [f for f in npz_files
                     if any(f"sample_{sid:04d}.npz" in f for sid in sample_ids)]

    print(f"Processing {len(npz_files)} samples for layer {layer_idx}...")

    all_fg_norms = []
    all_bg_norms = []
    all_fg_attn = []
    all_bg_attn = []
    processed_ids = []

    for npz_path in npz_files:
        # Filenames follow sample_{id:04d}.npz
        sample_id = int(os.path.basename(npz_path).split('_')[1].split('.')[0])
        sample_data = load_mts_sample(npz_path)

        attn_key = f'attention_layer_{layer_idx}'
        if attn_key not in sample_data:
            print(f" Warning: Layer {layer_idx} not found in sample {sample_id}")
            continue

        attention = sample_data[attn_key]
        is_hr_token = sample_data['is_hr_token']

        n_visual_merged = sample_data.get('n_visual_merged', None)
        if n_visual_merged is None:
            # Estimate the merged token count from the LR grid shape
            t, h, w = tuple(sample_data['lr_grid_shape'])
            spatial_merge_size = 2
            n_visual_merged = (t * h * w) // (spatial_merge_size ** 2)
        n_visual_tokens = int(n_visual_merged)

        # Feature norms: true L2 norm when embeddings exist, attention proxy otherwise
        vision_features = sample_data.get('vision_features', None)
        if vision_features is not None and len(vision_features) >= n_visual_tokens:
            feature_norms = compute_true_feature_norm(vision_features[:n_visual_tokens])
        else:
            feature_norms = compute_feature_norm_from_attention(attention, n_visual_tokens)

        # Mean text->visual attention per visual token
        text_to_visual = attention[:, n_visual_tokens:, :n_visual_tokens]
        attn_weights = text_to_visual.mean(axis=(0, 1))

        # Map unmerged HR tokens to merged positions (2x2 spatial merge)
        merge_unit = 4
        is_hr_token_merged = np.zeros(n_visual_tokens, dtype=bool)
        for unmerged_idx in range(len(is_hr_token)):
            if is_hr_token[unmerged_idx]:
                merged_idx = unmerged_idx // merge_unit
                if merged_idx < n_visual_tokens:
                    is_hr_token_merged[merged_idx] = True
        foreground_mask = is_hr_token_merged
        background_mask = ~is_hr_token_merged

        all_fg_norms.append(feature_norms[foreground_mask])
        all_bg_norms.append(feature_norms[background_mask])
        all_fg_attn.append(attn_weights[foreground_mask])
        all_bg_attn.append(attn_weights[background_mask])
        processed_ids.append(sample_id)
        print(f" ✓ Sample {sample_id}: n_visual={n_visual_tokens}, FG(HR)={foreground_mask.sum()}, BG(LR)={background_mask.sum()}")

    if not all_fg_norms:
        print(f"Error: No valid samples found")
        return None, None, None, None, None, []

    all_fg_norms = np.concatenate(all_fg_norms)
    all_bg_norms = np.concatenate(all_bg_norms)
    all_fg_attn = np.concatenate(all_fg_attn)
    all_bg_attn = np.concatenate(all_bg_attn)
    print(f"\nAggregated: FG={len(all_fg_norms)}, BG={len(all_bg_norms)} tokens")

    # ROBUSTNESS: .max() on an empty array raises ValueError (e.g. a batch with
    # zero HR tokens); `initial=0.0` keeps this safe since norms are >= 0.
    norm_max = max(np.max(all_fg_norms, initial=0.0), np.max(all_bg_norms, initial=0.0))
    bins = np.linspace(0, norm_max, n_bins + 1)
    bin_centers = (bins[:-1] + bins[1:]) / 2

    fg_counts, _ = np.histogram(all_fg_norms, bins=bins)
    bg_counts, _ = np.histogram(all_bg_norms, bins=bins)

    def _binned_attn_mean(norms: np.ndarray, attn: np.ndarray) -> np.ndarray:
        """Mean attention per bin; NaN where a bin holds no tokens."""
        out = np.full(n_bins, np.nan)
        for i in range(n_bins):
            in_bin = (norms >= bins[i]) & (norms < bins[i + 1])
            if i == n_bins - 1:
                # BUGFIX: np.histogram's last bin is right-INCLUSIVE; mirror that
                # so tokens at the maximum norm contribute to the mean too.
                in_bin |= norms == bins[-1]
            if in_bin.any():
                out[i] = attn[in_bin].mean()
        return out

    fg_attn_mean = _binned_attn_mean(all_fg_norms, all_fg_attn)
    bg_attn_mean = _binned_attn_mean(all_bg_norms, all_bg_attn)

    return bin_centers, fg_counts, bg_counts, fg_attn_mean, bg_attn_mean, processed_ids
def main():
    """CLI entry point: run the FG/BG analysis for each requested layer."""
    parser = argparse.ArgumentParser(
        description='Analyze Foreground vs Background Token Distribution'
    )
    parser.add_argument('--data_dir', type=str,
                        default='/mnt/rdata4_6/huixin/LLaMA-Factory-main/attn_vis/paco_attention_data_checkpoint5877',
                        help='Directory containing .npz files')
    parser.add_argument('--output_dir', type=str,
                        default='/mnt/rdata4_6/huixin/LLaMA-Factory-main/attn_vis/paco_attention_data_checkpoint5877_analysis',
                        help='Output directory')
    parser.add_argument('--layers', type=int, nargs='+', default=[2, 6, 15, 27],
                        help='Layers to analyze')
    parser.add_argument('--samples', type=int, nargs='+', default=None,
                        help='Specific sample IDs (default: all samples)')
    parser.add_argument('--n_bins', type=int, default=50,
                        help='Number of bins for histogram')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Banner with the run configuration
    rule = "=" * 70
    print(rule)
    print("Foreground vs Background Token Analysis")
    print(rule)
    print(f"Data directory: {args.data_dir}")
    print(f"Output directory: {args.output_dir}")
    print(f"Layers: {args.layers}")
    print(f"Bins: {args.n_bins}")
    print(rule)

    for layer_idx in args.layers:
        print(f"\n{rule}")
        print(f"Processing Layer {layer_idx}")
        print(rule)

        # Aggregate histograms/attention means over all selected samples
        results = aggregate_multiple_samples(
            data_dir=args.data_dir,
            layer_idx=layer_idx,
            sample_ids=args.samples,
            n_bins=args.n_bins
        )
        bin_centers, fg_counts, bg_counts, fg_attn_mean, bg_attn_mean, sample_ids = results
        if bin_centers is None:
            print(f"Skipping layer {layer_idx} due to errors")
            continue

        # Render and save the per-layer figure
        fig_path = os.path.join(
            args.output_dir,
            f"layer_{layer_idx}_foreground_background_analysis.png"
        )
        visualize_foreground_background_analysis(
            bin_centers=bin_centers,
            fg_counts=fg_counts,
            bg_counts=bg_counts,
            fg_attn_mean=fg_attn_mean,
            bg_attn_mean=bg_attn_mean,
            layer_idx=layer_idx,
            output_path=fig_path,
            sample_ids=sample_ids
        )

    print(f"\n{rule}")
    print(f"✓ Analysis complete! Results saved to: {args.output_dir}")
    print(rule)
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == '__main__':
    main()