LLaMA-Factory-main / scripts /visualize_foreground_background_analysis.py
hxxxxxx's picture
Upload scripts/visualize_foreground_background_analysis.py with huggingface_hub
c535f7b verified
#!/usr/bin/env python3
"""
Foreground vs Background Token Analysis
Analyzes the distribution of tokens based on whether they are in GT box (foreground) or not (background).
- Foreground: HR tokens (have parent_lr_idx) - inside GT box - Light Blue
- Background: LR tokens (no parent_lr_idx replacement) - outside GT box - Light Red
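Visualization:
- X-axis: Feature norm of each visual token, binned from 0 to the max (L2 norm of the
  stored vision features, or attention received from text tokens as a fallback proxy)
- Left Y-axis: Token Count (blue line for foreground, red line for background)
- Right Y-axis: Average Attention Weight per bin (side-by-side blue/red bars for foreground/background)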
"""
import os
import sys
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import glob
from typing import List, Dict, Tuple
import argparse
# Prepend the custom_models/qwen2_5_vl directory to the module search path.
# NOTE(review): no module from this directory is imported anywhere in this
# script as shown, and the absolute path is machine-specific; confirm whether
# this line is still needed before running elsewhere.
sys.path.insert(0, '/mnt/rdata4_6/huixin/LLaMA-Factory-main/custom_models/qwen2_5_vl')
def load_mts_sample(npz_path: str) -> dict:
    """Load a single MTS attention sample from an ``.npz`` file.

    Every array stored in the archive is read eagerly into a plain dict, and
    the underlying file is closed before returning. ``np.load`` on an ``.npz``
    returns a lazily-reading ``NpzFile`` that would otherwise keep the handle
    open for as long as it is referenced.

    Args:
        npz_path: Path to the ``.npz`` archive.

    Returns:
        Mapping from each array name in the archive to its loaded ndarray.
    """
    # NOTE: allow_pickle=True permits unpickling object arrays, which can run
    # arbitrary code; only load archives produced by a trusted pipeline.
    with np.load(npz_path, allow_pickle=True) as data:
        # The comprehension reads every member before the archive is closed.
        return {key: data[key] for key in data.files}
def compute_feature_norm_from_attention(attention: np.ndarray, n_visual_tokens: int) -> np.ndarray:
    """
    Approximate a per-visual-token "feature norm" from attention alone.

    The proxy is the total attention mass each visual token receives from the
    text tokens, summed over every head and every text (query) position. The
    slicing treats the first ``n_visual_tokens`` sequence positions as the
    visual tokens and everything after them as text.

    Args:
        attention: (num_heads, seq_len, seq_len) attention weights
        n_visual_tokens: Number of leading visual tokens in the sequence
    Returns:
        feature_norms: (n_visual_tokens,) - summed received attention per visual token
    """
    # Unpacking the shape doubles as a check that the input is 3-D.
    _n_heads, _seq_len, _ = attention.shape
    n_vis = n_visual_tokens
    # Rows are text queries (after the visual prefix); columns are visual keys.
    received = attention[:, n_vis:, :n_vis]
    return received.sum(axis=(0, 1))
def compute_true_feature_norm(vision_features: np.ndarray) -> np.ndarray:
    """
    Return the Euclidean (L2) norm of every visual token embedding.

    Args:
        vision_features: (n_visual_tokens, hidden_dim) - one embedding per visual token
    Returns:
        feature_norms: (n_visual_tokens,) - ||z_i||_2 = sqrt(sum(z_i^2)) per token
    """
    # For a single reduction axis, numpy's default vector norm is the 2-norm.
    hidden_axis = -1
    return np.linalg.norm(vision_features, axis=hidden_axis)
def analyze_foreground_background_distribution(
    sample_data: dict,
    layer_idx: int,
    n_bins: int = 50
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Analyze foreground vs background token distribution for a single sample.

    Visual tokens are classified as foreground (merged positions fed by at
    least one HR token) or background (all other merged positions). Both
    groups are histogrammed by feature norm, and the mean text->visual
    attention is computed over exactly the tokens counted in each bin.

    Args:
        sample_data: Loaded .npz sample data. Keys read: ``attention_layer_{layer_idx}``,
            ``is_hr_token``, and optionally ``n_visual_merged`` (falling back to
            ``lr_grid_shape``) and ``vision_features``.
        layer_idx: Layer index to analyze
        n_bins: Number of bins for histogram
    Returns:
        bin_centers: (n_bins,) - center of each bin for x-axis
        fg_counts: (n_bins,) - foreground token count per bin
        bg_counts: (n_bins,) - background token count per bin
        fg_attn_mean: (n_bins,) - mean attention for foreground tokens per bin (NaN if empty)
        bg_attn_mean: (n_bins,) - mean attention for background tokens per bin (NaN if empty)
        All five values are None when the requested layer is absent from ``sample_data``.
    """
    attn_key = f'attention_layer_{layer_idx}'
    if attn_key not in sample_data:
        return None, None, None, None, None
    attention = sample_data[attn_key]  # (num_heads, seq_len, seq_len)
    is_hr_token = np.asarray(sample_data['is_hr_token'], dtype=bool)  # unmerged mask

    n_visual_merged = sample_data.get('n_visual_merged', None)
    if n_visual_merged is None:
        # Estimate from the LR grid: each merged token covers a 2x2 patch block.
        t, h, w = tuple(sample_data['lr_grid_shape'])
        spatial_merge_size = 2
        n_visual_merged = (t * h * w) // (spatial_merge_size ** 2)
    n_visual_tokens = int(n_visual_merged)

    # Feature norms: true L2 norm when embeddings are stored, else an attention proxy.
    vision_features = sample_data.get('vision_features', None)
    if vision_features is not None and len(vision_features) >= n_visual_tokens:
        feature_norms = compute_true_feature_norm(vision_features[:n_visual_tokens])
    else:
        feature_norms = compute_feature_norm_from_attention(attention, n_visual_tokens)
        print(" Warning: No vision_features found, using attention proxy for feature norm")

    # Mean attention each visual token receives, averaged over heads and text queries.
    # Visual tokens are assumed to occupy the first n_visual_tokens positions.
    text_to_visual = attention[:, n_visual_tokens:, :n_visual_tokens]  # (num_heads, n_text, n_visual)
    attn_weights = text_to_visual.mean(axis=(0, 1))  # (n_visual,)

    # is_hr_token indexes UNMERGED tokens. This code assumes every `merge_unit`
    # consecutive unmerged tokens collapse into one merged token (2x2 spatial
    # merge); a merged position is foreground if any of its sources is HR.
    merge_unit = 4
    merged_hr_idx = np.flatnonzero(is_hr_token) // merge_unit
    merged_hr_idx = merged_hr_idx[merged_hr_idx < n_visual_tokens]
    foreground_mask = np.zeros(n_visual_tokens, dtype=bool)
    foreground_mask[merged_hr_idx] = True
    background_mask = ~foreground_mask
    print(f" Visual tokens: {n_visual_tokens}, HR: {foreground_mask.sum()}, LR: {background_mask.sum()}")

    fg_norms = feature_norms[foreground_mask]
    bg_norms = feature_norms[background_mask]
    fg_attn = attn_weights[foreground_mask]
    bg_attn = attn_weights[background_mask]

    # Bins span [0, max norm over all visual tokens].
    bins = np.linspace(0, feature_norms.max(), n_bins + 1)
    bin_centers = (bins[:-1] + bins[1:]) / 2

    fg_counts, _ = np.histogram(fg_norms, bins=bins)
    bg_counts, _ = np.histogram(bg_norms, bins=bins)

    # Sum attention per bin with the same np.histogram binning as the counts, so
    # every mean covers exactly the tokens that were counted. np.histogram's last
    # bin is closed on the right; a half-open mask would drop the max-norm token.
    fg_attn_sum, _ = np.histogram(fg_norms, bins=bins, weights=fg_attn)
    bg_attn_sum, _ = np.histogram(bg_norms, bins=bins, weights=bg_attn)
    with np.errstate(invalid='ignore', divide='ignore'):
        fg_attn_mean = np.where(fg_counts > 0, fg_attn_sum / fg_counts, np.nan)
        bg_attn_mean = np.where(bg_counts > 0, bg_attn_sum / bg_counts, np.nan)

    return bin_centers, fg_counts, bg_counts, fg_attn_mean, bg_attn_mean
def _token_weighted_attention(attn_mean: np.ndarray, counts: np.ndarray) -> float:
    """Return the token-count-weighted average of per-bin mean attention.

    Bins with no tokens, or with a NaN mean, are ignored. Returns NaN when no
    bin qualifies, instead of emitting numpy's "mean of empty slice" warning.
    """
    means = np.asarray(attn_mean, dtype=float)
    weights = np.asarray(counts, dtype=float)
    valid = (weights > 0) & ~np.isnan(means)
    if not valid.any():
        return float('nan')
    return float(np.average(means[valid], weights=weights[valid]))


def visualize_foreground_background_analysis(
    bin_centers: np.ndarray,
    fg_counts: np.ndarray,
    bg_counts: np.ndarray,
    fg_attn_mean: np.ndarray,
    bg_attn_mean: np.ndarray,
    layer_idx: int,
    output_path: str,
    sample_ids: List[int] = None
):
    """
    Create a dual Y-axis plot for foreground vs background analysis and save it.

    Left axis: token count per norm bin (lines). Right axis: mean attention per
    bin (side-by-side bars). An info box reports each group's average attention,
    weighted by the number of tokens in each bin.

    Args:
        bin_centers: (n_bins,) - x-axis values
        fg_counts: (n_bins,) - foreground token counts
        bg_counts: (n_bins,) - background token counts
        fg_attn_mean: (n_bins,) - mean attention for foreground (NaN for empty bins)
        bg_attn_mean: (n_bins,) - mean attention for background (NaN for empty bins)
        layer_idx: Layer index for title
        output_path: Where to save the plot
        sample_ids: List of sample IDs used (for title)
    """
    # Font settings are scoped to this figure instead of mutating global rcParams.
    font_rc = {
        'font.family': 'sans-serif',
        'font.sans-serif': ['DejaVu Sans', 'Arial', 'Helvetica', 'sans-serif'],
    }
    with plt.rc_context(font_rc):
        fig, ax1 = plt.subplots(figsize=(14, 8))
        try:
            # Left Y-axis: token counts
            ax1.set_xlabel('L2 Norm of Visual Token', fontsize=28, fontweight='bold')
            ax1.set_ylabel('Token Count', fontsize=28, fontweight='bold', color='black')

            # Copy to float and turn empty bins into NaN so lines don't dip to 0.
            fg_counts_masked = np.array(fg_counts, dtype=float)
            fg_counts_masked[fg_counts_masked == 0] = np.nan
            bg_counts_masked = np.array(bg_counts, dtype=float)
            bg_counts_masked[bg_counts_masked == 0] = np.nan

            line_fg_count = ax1.plot(bin_centers, fg_counts_masked, color='dodgerblue', linewidth=2.5,
                                     label='Foreground Tokens', marker='o', markersize=4, alpha=0.8)
            line_bg_count = ax1.plot(bin_centers, bg_counts_masked, color='lightcoral', linewidth=2.5,
                                     label='Background Tokens', marker='s', markersize=4, alpha=0.8)
            ax1.tick_params(axis='y', labelcolor='black', labelsize=24)
            ax1.tick_params(axis='x', labelsize=24)
            ax1.grid(True, alpha=0.3, linestyle='--')

            # Right Y-axis: per-bin mean attention as bars
            ax2 = ax1.twinx()
            ax2.set_ylabel('Average Attention Weight', fontsize=28, fontweight='bold', color='darkgreen')

            # Each bin holds two side-by-side bars (FG left, BG right), each half a
            # bin wide, so together they fill the bin without gaps.
            if len(bin_centers) > 1:
                bar_width = (bin_centers[1] - bin_centers[0]) / 2
            else:
                bar_width = 0.5

            ax2.bar(bin_centers - bar_width / 2, fg_attn_mean,
                    width=bar_width, color='blue', alpha=0.3,
                    label='Foreground Attention', edgecolor='darkblue', linewidth=0.5, align='center')
            ax2.bar(bin_centers + bar_width / 2, bg_attn_mean,
                    width=bar_width, color='red', alpha=0.3,
                    label='Background Attention', edgecolor='darkred', linewidth=0.5, align='center')

            # Proxy patches for the bar legend (Patch is imported at module level).
            line_fg_attn = [Patch(facecolor='blue', alpha=0.3, edgecolor='darkblue', linewidth=0.5,
                                  label='Foreground Attention')]
            line_bg_attn = [Patch(facecolor='red', alpha=0.3, edgecolor='darkred', linewidth=0.5,
                                  label='Background Attention')]
            ax2.tick_params(axis='y', labelcolor='darkgreen', labelsize=24)

            # Title (placed on the twin axes, the current axes after twinx()).
            sample_str = f"Samples {sample_ids}" if sample_ids else "All Samples"
            ax2.set_title(f'Foreground vs Background Token Analysis - Layer {layer_idx}\n{sample_str}',
                          fontsize=32, fontweight='bold', pad=20)

            # One legend per axis
            lines1 = line_fg_count + line_bg_count
            labels1 = [line.get_label() for line in lines1]
            ax1.legend(lines1, labels1, loc='upper left', fontsize=22, framealpha=0.9)
            ax2.legend(handles=line_fg_attn + line_bg_attn, loc='upper right', fontsize=22, framealpha=0.9)

            # Info box: average attention over all tokens of each group. Weighting the
            # per-bin means by counts avoids letting sparse tail bins dominate.
            fg_total = int(np.sum(fg_counts))
            bg_total = int(np.sum(bg_counts))
            fg_attn_overall = _token_weighted_attention(fg_attn_mean, fg_counts)
            bg_attn_overall = _token_weighted_attention(bg_attn_mean, bg_counts)
            info_text = "Avg Attention:\n"
            info_text += f" Foreground: {fg_attn_overall:.4f}\n"
            info_text += f" Background: {bg_attn_overall:.4f}"
            # Middle-left placement keeps the box clear of both legends.
            ax1.text(0.02, 0.5, info_text, transform=ax1.transAxes,
                     fontsize=20, verticalalignment='center',
                     bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8, edgecolor='black', linewidth=1.5))

            # Total token statistics below the x-axis label
            total_text = f"Total Tokens: Foreground = {fg_total}, Background = {bg_total}"
            fig.text(0.5, 0.02, total_text, ha='center', fontsize=20, fontweight='bold')

            fig.tight_layout(rect=[0, 0.04, 1, 1])  # leave room at the bottom for the totals text
            fig.savefig(output_path, dpi=200, bbox_inches='tight')
        finally:
            # Always release the figure, even if drawing or saving fails.
            plt.close(fig)
    print(f"✓ Saved: {output_path}")
def aggregate_multiple_samples(
    data_dir: str,
    layer_idx: int,
    sample_ids: List[int] = None,
    n_bins: int = 50
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[int]]:
    """
    Aggregate foreground/background analysis across multiple samples.

    Tokens from every selected ``sample_<id>.npz`` in ``data_dir`` are pooled
    before binning, so each bin reflects all selected samples together.

    Args:
        data_dir: Directory with sample_*.npz files
        layer_idx: Layer to analyze
        sample_ids: Integer sample IDs to use (None or empty = all)
        n_bins: Number of bins
    Returns:
        Aggregated bin_centers, fg_counts, bg_counts, fg_attn_mean, bg_attn_mean,
        processed_sample_ids. On failure the first five are None and the id list is empty.
    """
    npz_files = sorted(glob.glob(os.path.join(data_dir, 'sample_*.npz')))
    if not npz_files:
        print(f"Error: No .npz files found in {data_dir}")
        return None, None, None, None, None, []

    # Parse each file's integer id once ("sample_0012.npz" -> 12). Filtering then
    # compares ids directly, matching how ids are reported, rather than
    # substring-matching a zero-padded file name against the whole path.
    indexed_files = []
    for npz_path in npz_files:
        id_text = os.path.basename(npz_path).split('_')[1].split('.')[0]
        try:
            indexed_files.append((int(id_text), npz_path))
        except ValueError:
            print(f" Warning: Could not parse sample id from {npz_path}, skipping")

    if sample_ids:
        wanted = set(sample_ids)
        indexed_files = [(sid, path) for sid, path in indexed_files if sid in wanted]

    print(f"Processing {len(indexed_files)} samples for layer {layer_idx}...")

    attn_key = f'attention_layer_{layer_idx}'
    merge_unit = 4  # 2x2 spatial merge: 4 unmerged tokens per merged token
    all_fg_norms, all_bg_norms = [], []
    all_fg_attn, all_bg_attn = [], []
    processed_ids = []

    for sample_id, npz_path in indexed_files:
        sample_data = load_mts_sample(npz_path)
        if attn_key not in sample_data:
            print(f" Warning: Layer {layer_idx} not found in sample {sample_id}")
            continue
        attention = sample_data[attn_key]
        is_hr_token = np.asarray(sample_data['is_hr_token'], dtype=bool)

        n_visual_merged = sample_data.get('n_visual_merged', None)
        if n_visual_merged is None:
            t, h, w = tuple(sample_data['lr_grid_shape'])
            spatial_merge_size = 2
            n_visual_merged = (t * h * w) // (spatial_merge_size ** 2)
        n_visual_tokens = int(n_visual_merged)

        # Feature norms: true L2 norm when embeddings are stored, else an attention proxy.
        vision_features = sample_data.get('vision_features', None)
        if vision_features is not None and len(vision_features) >= n_visual_tokens:
            feature_norms = compute_true_feature_norm(vision_features[:n_visual_tokens])
        else:
            feature_norms = compute_feature_norm_from_attention(attention, n_visual_tokens)

        # Mean text->visual attention per visual token (visual tokens assumed first).
        text_to_visual = attention[:, n_visual_tokens:, :n_visual_tokens]
        attn_weights = text_to_visual.mean(axis=(0, 1))

        # Map unmerged HR tokens to merged positions; assumes each run of
        # `merge_unit` consecutive unmerged tokens forms one merged token.
        merged_hr_idx = np.flatnonzero(is_hr_token) // merge_unit
        merged_hr_idx = merged_hr_idx[merged_hr_idx < n_visual_tokens]
        foreground_mask = np.zeros(n_visual_tokens, dtype=bool)
        foreground_mask[merged_hr_idx] = True
        background_mask = ~foreground_mask

        all_fg_norms.append(feature_norms[foreground_mask])
        all_bg_norms.append(feature_norms[background_mask])
        all_fg_attn.append(attn_weights[foreground_mask])
        all_bg_attn.append(attn_weights[background_mask])
        processed_ids.append(sample_id)
        print(f" ✓ Sample {sample_id}: n_visual={n_visual_tokens}, FG(HR)={foreground_mask.sum()}, BG(LR)={background_mask.sum()}")

    if not processed_ids:
        print("Error: No valid samples found")
        return None, None, None, None, None, []

    all_fg_norms = np.concatenate(all_fg_norms)
    all_bg_norms = np.concatenate(all_bg_norms)
    all_fg_attn = np.concatenate(all_fg_attn)
    all_bg_attn = np.concatenate(all_bg_attn)
    print(f"\nAggregated: FG={len(all_fg_norms)}, BG={len(all_bg_norms)} tokens")

    # Take the max over the combined array: calling .max() on an empty FG or BG
    # array (e.g. no HR tokens in any sample) would raise ValueError.
    all_norms = np.concatenate([all_fg_norms, all_bg_norms])
    if all_norms.size == 0:
        print("Error: Selected samples contain no visual tokens")
        return None, None, None, None, None, []
    bins = np.linspace(0, all_norms.max(), n_bins + 1)
    bin_centers = (bins[:-1] + bins[1:]) / 2

    fg_counts, _ = np.histogram(all_fg_norms, bins=bins)
    bg_counts, _ = np.histogram(all_bg_norms, bins=bins)

    # Sum attention with the same np.histogram binning as the counts, so each mean
    # covers exactly the counted tokens (the last bin includes its right edge).
    fg_attn_sum, _ = np.histogram(all_fg_norms, bins=bins, weights=all_fg_attn)
    bg_attn_sum, _ = np.histogram(all_bg_norms, bins=bins, weights=all_bg_attn)
    with np.errstate(invalid='ignore', divide='ignore'):
        fg_attn_mean = np.where(fg_counts > 0, fg_attn_sum / fg_counts, np.nan)
        bg_attn_mean = np.where(bg_counts > 0, bg_attn_sum / bg_counts, np.nan)

    return bin_centers, fg_counts, bg_counts, fg_attn_mean, bg_attn_mean, processed_ids
def main():
    """Command-line entry point: plot FG/BG token statistics for each requested layer."""
    parser = argparse.ArgumentParser(
        description='Analyze Foreground vs Background Token Distribution'
    )
    # (flag, add_argument keyword options), registered in order so --help is unchanged.
    cli_options = [
        ('--data_dir', dict(
            type=str,
            default='/mnt/rdata4_6/huixin/LLaMA-Factory-main/attn_vis/paco_attention_data_checkpoint5877',
            help='Directory containing .npz files')),
        ('--output_dir', dict(
            type=str,
            default='/mnt/rdata4_6/huixin/LLaMA-Factory-main/attn_vis/paco_attention_data_checkpoint5877_analysis',
            help='Output directory')),
        ('--layers', dict(type=int, nargs='+', default=[2, 6, 15, 27],
                          help='Layers to analyze')),
        ('--samples', dict(type=int, nargs='+', default=None,
                           help='Specific sample IDs (default: all samples)')),
        ('--n_bins', dict(type=int, default=50,
                          help='Number of bins for histogram')),
    ]
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    banner = "=" * 70
    for line in (
        banner,
        "Foreground vs Background Token Analysis",
        banner,
        f"Data directory: {args.data_dir}",
        f"Output directory: {args.output_dir}",
        f"Layers: {args.layers}",
        f"Bins: {args.n_bins}",
        banner,
    ):
        print(line)

    for layer in args.layers:
        print(f"\n{banner}")
        print(f"Processing Layer {layer}")
        print(banner)

        # Pool statistics over all selected samples for this layer.
        results = aggregate_multiple_samples(
            data_dir=args.data_dir,
            layer_idx=layer,
            sample_ids=args.samples,
            n_bins=args.n_bins,
        )
        centers, fg_counts, bg_counts, fg_attn, bg_attn, used_ids = results
        if centers is None:
            print(f"Skipping layer {layer} due to errors")
            continue

        png_name = f"layer_{layer}_foreground_background_analysis.png"
        visualize_foreground_background_analysis(
            bin_centers=centers,
            fg_counts=fg_counts,
            bg_counts=bg_counts,
            fg_attn_mean=fg_attn,
            bg_attn_mean=bg_attn,
            layer_idx=layer,
            output_path=os.path.join(args.output_dir, png_name),
            sample_ids=used_ids,
        )

    print(f"\n{banner}")
    print(f"✓ Analysis complete! Results saved to: {args.output_dir}")
    print(banner)


# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()