#!/usr/bin/env python3 """ Foreground vs Background Token Analysis Analyzes the distribution of tokens based on whether they are in GT box (foreground) or not (background). - Foreground: HR tokens (have parent_lr_idx) - inside GT box - Light Blue - Background: LR tokens (no parent_lr_idx replacement) - outside GT box - Light Red Visualization: - X-axis: Feature Norm (0 to max) - Left Y-axis: Token Count (blue line for foreground, red line for background) - Right Y-axis: Average Attention Weight """ import os import sys import numpy as np import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt from matplotlib.patches import Patch import glob from typing import List, Dict, Tuple import argparse # Add custom models path sys.path.insert(0, '/mnt/rdata4_6/huixin/LLaMA-Factory-main/custom_models/qwen2_5_vl') def load_mts_sample(npz_path: str) -> dict: """Load a single MTS attention sample from .npz file.""" data = np.load(npz_path, allow_pickle=True) return {key: data[key] for key in data.files} def compute_feature_norm_from_attention(attention: np.ndarray, n_visual_tokens: int) -> np.ndarray: """ Compute feature norm proxy from attention patterns. For visual tokens, we can use the sum of attention received from all text tokens as a proxy for feature importance/norm. Args: attention: (num_heads, seq_len, seq_len) n_visual_tokens: Number of visual tokens Returns: feature_norms: (n_visual_tokens,) - norm value for each visual token """ num_heads, seq_len, _ = attention.shape # Extract text->visual attention: text tokens attend to visual tokens text_start_idx = n_visual_tokens text_to_visual = attention[:, text_start_idx:, :n_visual_tokens] # (num_heads, n_text, n_visual) # Sum attention received by each visual token across all text tokens and heads # This gives a measure of how much each visual token is "attended to" feature_norms = text_to_visual.sum(axis=(0, 1)) # (n_visual,) return feature_norms def compute_true_feature_norm(vision_features: np.ndarray) -> np.ndarray: """ Compute true L2 norm of vision features. Args: vision_features: (n_visual_tokens, hidden_dim) - vision feature embeddings Returns: feature_norms: (n_visual_tokens,) - L2 norm for each visual token """ # Compute L2 norm: ||z_i||_2 = sqrt(sum(z_i^2)) feature_norms = np.linalg.norm(vision_features, ord=2, axis=-1) # (n_visual,) return feature_norms def analyze_foreground_background_distribution( sample_data: dict, layer_idx: int, n_bins: int = 50 ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """ Analyze foreground vs background token distribution. Args: sample_data: Loaded .npz sample data layer_idx: Layer index to analyze n_bins: Number of bins for histogram Returns: bin_centers: (n_bins,) - center of each bin for x-axis fg_counts: (n_bins,) - foreground token count per bin bg_counts: (n_bins,) - background token count per bin fg_attn_mean: (n_bins,) - mean attention for foreground tokens per bin bg_attn_mean: (n_bins,) - mean attention for background tokens per bin """ attn_key = f'attention_layer_{layer_idx}' if attn_key not in sample_data: return None, None, None, None, None attention = sample_data[attn_key] # (num_heads, seq_len, seq_len) is_hr_token = sample_data['is_hr_token'] # (n_tokens,) bool mask n_visual_merged = sample_data.get('n_visual_merged', None) if n_visual_merged is None: # Estimate from grid shape lr_grid_shape = tuple(sample_data['lr_grid_shape']) t, h, w = lr_grid_shape spatial_merge_size = 2 n_visual_merged = (t * h * w) // (spatial_merge_size ** 2) n_visual_tokens = int(n_visual_merged) # Compute feature norms # Try to use true L2 norm from vision features if available vision_features = sample_data.get('vision_features', None) if vision_features is not None and len(vision_features) >= n_visual_tokens: # Use true L2 norm feature_norms = compute_true_feature_norm(vision_features[:n_visual_tokens]) norm_type = "L2 Norm" else: # Fallback: use attention-based proxy feature_norms = compute_feature_norm_from_attention(attention, n_visual_tokens) norm_type = "Attention Received (proxy)" print(f" Warning: No vision_features found, using attention proxy for feature norm") # Extract attention weights for each visual token (mean over text tokens and heads) text_start_idx = n_visual_tokens text_to_visual = attention[:, text_start_idx:, :n_visual_tokens] # (num_heads, n_text, n_visual) attn_weights = text_to_visual.mean(axis=(0, 1)) # (n_visual,) # Classify tokens as foreground (HR) or background (LR) # KEY: is_hr_token corresponds to UNMERGED tokens (before 2x2 spatial merge) # We need to map unmerged->merged to get foreground mask for the final sequence merge_unit = 4 # 2x2 spatial merge = 4 tokens per merged token is_hr_token_merged = np.zeros(n_visual_tokens, dtype=bool) n_mts_unmerged = len(is_hr_token) # Map: unmerged HR tokens → merged positions for unmerged_idx in range(n_mts_unmerged): if is_hr_token[unmerged_idx]: merged_idx = unmerged_idx // merge_unit if merged_idx < n_visual_tokens: is_hr_token_merged[merged_idx] = True # Foreground = positions with HR tokens, Background = positions with only LR foreground_mask = is_hr_token_merged background_mask = ~is_hr_token_merged print(f" Visual tokens: {n_visual_tokens}, HR: {foreground_mask.sum()}, LR: {background_mask.sum()}") # Get foreground and background data fg_norms = feature_norms[foreground_mask] bg_norms = feature_norms[background_mask] fg_attn = attn_weights[foreground_mask] bg_attn = attn_weights[background_mask] # Create bins based on full range norm_min = 0 norm_max = feature_norms.max() bins = np.linspace(norm_min, norm_max, n_bins + 1) bin_centers = (bins[:-1] + bins[1:]) / 2 # Compute histograms for token counts fg_counts, _ = np.histogram(fg_norms, bins=bins) bg_counts, _ = np.histogram(bg_norms, bins=bins) # Compute mean attention per bin fg_attn_mean = np.full(n_bins, np.nan) # Initialize with NaN bg_attn_mean = np.full(n_bins, np.nan) # Initialize with NaN for i in range(n_bins): # Foreground mask_fg = (fg_norms >= bins[i]) & (fg_norms < bins[i+1]) if mask_fg.sum() > 0: fg_attn_mean[i] = fg_attn[mask_fg].mean() # Background mask_bg = (bg_norms >= bins[i]) & (bg_norms < bins[i+1]) if mask_bg.sum() > 0: bg_attn_mean[i] = bg_attn[mask_bg].mean() return bin_centers, fg_counts, bg_counts, fg_attn_mean, bg_attn_mean def visualize_foreground_background_analysis( bin_centers: np.ndarray, fg_counts: np.ndarray, bg_counts: np.ndarray, fg_attn_mean: np.ndarray, bg_attn_mean: np.ndarray, layer_idx: int, output_path: str, sample_ids: List[int] = None ): """ Create dual Y-axis plot for foreground vs background analysis. Args: bin_centers: (n_bins,) - x-axis values fg_counts: (n_bins,) - foreground token counts bg_counts: (n_bins,) - background token counts fg_attn_mean: (n_bins,) - mean attention for foreground bg_attn_mean: (n_bins,) - mean attention for background layer_idx: Layer index for title output_path: Where to save the plot sample_ids: List of sample IDs used (for title) """ # Set font (use DejaVu Sans which is commonly available on Linux) plt.rcParams['font.family'] = 'sans-serif' plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial', 'Helvetica', 'sans-serif'] fig, ax1 = plt.subplots(figsize=(14, 8)) # Left Y-axis: Token counts ax1.set_xlabel('L2 Norm of Visual Token', fontsize=28, fontweight='bold') ax1.set_ylabel('Token Count', fontsize=28, fontweight='bold', color='black') # Convert 0 counts to NaN so lines don't connect through 0 fg_counts_masked = fg_counts.astype(float) fg_counts_masked[fg_counts_masked == 0] = np.nan bg_counts_masked = bg_counts.astype(float) bg_counts_masked[bg_counts_masked == 0] = np.nan # Plot token counts as lines (will skip NaN values in connections) line_fg_count = ax1.plot(bin_centers, fg_counts_masked, color='dodgerblue', linewidth=2.5, label='Foreground Tokens', marker='o', markersize=4, alpha=0.8) line_bg_count = ax1.plot(bin_centers, bg_counts_masked, color='lightcoral', linewidth=2.5, label='Background Tokens', marker='s', markersize=4, alpha=0.8) ax1.tick_params(axis='y', labelcolor='black', labelsize=24) ax1.tick_params(axis='x', labelsize=24) ax1.grid(True, alpha=0.3, linestyle='--') # Right Y-axis: Attention weights (as bar chart/histogram) ax2 = ax1.twinx() ax2.set_ylabel('Average Attention Weight', fontsize=28, fontweight='bold', color='darkgreen') # Calculate bar width - use full bin width so bars touch each other if len(bin_centers) > 1: bar_width = (bin_centers[1] - bin_centers[0]) / 2 # Half bin width for each bar (no gaps) else: bar_width = 0.5 # Plot attention weights as bars (histograms) - bars touch each other # Foreground: blue bars (left half of bin) ax2.bar(bin_centers - bar_width/2, fg_attn_mean, width=bar_width, color='blue', alpha=0.3, label='Foreground Attention', edgecolor='darkblue', linewidth=0.5, align='center') # Background: red bars (right half of bin) ax2.bar(bin_centers + bar_width/2, bg_attn_mean, width=bar_width, color='red', alpha=0.3, label='Background Attention', edgecolor='darkred', linewidth=0.5, align='center') # Create dummy objects for legend (bar patches) from matplotlib.patches import Patch line_fg_attn = [Patch(facecolor='blue', alpha=0.3, edgecolor='darkblue', linewidth=0.5, label='Foreground Attention')] line_bg_attn = [Patch(facecolor='red', alpha=0.3, edgecolor='darkred', linewidth=0.5, label='Background Attention')] ax2.tick_params(axis='y', labelcolor='darkgreen', labelsize=24) # Title sample_str = f"Samples {sample_ids}" if sample_ids else "All Samples" plt.title(f'Foreground vs Background Token Analysis - Layer {layer_idx}\n{sample_str}', fontsize=32, fontweight='bold', pad=20) # Combined legend lines1 = line_fg_count + line_bg_count lines2 = line_fg_attn + line_bg_attn labels1 = [l.get_label() for l in lines1] labels2 = [l.get_label() for l in lines2] # Create legend with both axes ax1.legend(lines1, labels1, loc='upper left', fontsize=22, framealpha=0.9) ax2.legend(handles=lines2, loc='upper right', fontsize=22, framealpha=0.9) # Add info box with token statistics fg_total = fg_counts.sum() bg_total = bg_counts.sum() fg_attn_overall = np.nanmean(fg_attn_mean) # Use nanmean to ignore NaN values bg_attn_overall = np.nanmean(bg_attn_mean) # Use nanmean to ignore NaN values info_text = f"Avg Attention:\n" info_text += f" Foreground: {fg_attn_overall:.4f}\n" info_text += f" Background: {bg_attn_overall:.4f}" # Place info box at middle left of the plot (to avoid overlap with legend) ax1.text(0.02, 0.5, info_text, transform=ax1.transAxes, fontsize=20, verticalalignment='center', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8, edgecolor='black', linewidth=1.5)) # Add total token statistics at the bottom (below x-axis label) total_text = f"Total Tokens: Foreground = {int(fg_total)}, Background = {int(bg_total)}" fig.text(0.5, 0.02, total_text, ha='center', fontsize=20, fontweight='bold') plt.tight_layout(rect=[0, 0.04, 1, 1]) # Leave space at bottom for total tokens text plt.savefig(output_path, dpi=200, bbox_inches='tight') plt.close() print(f"✓ Saved: {output_path}") def aggregate_multiple_samples( data_dir: str, layer_idx: int, sample_ids: List[int] = None, n_bins: int = 50 ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[int]]: """ Aggregate foreground/background analysis across multiple samples. Args: data_dir: Directory with .npz files layer_idx: Layer to analyze sample_ids: List of sample IDs to use (None = all) n_bins: Number of bins Returns: Aggregated bin_centers, fg_counts, bg_counts, fg_attn_mean, bg_attn_mean, processed_sample_ids """ # Find all .npz files npz_files = sorted(glob.glob(os.path.join(data_dir, 'sample_*.npz'))) if not npz_files: print(f"Error: No .npz files found in {data_dir}") return None, None, None, None, None, [] # Filter by sample_ids if provided if sample_ids: npz_files = [f for f in npz_files if any(f"sample_{sid:04d}.npz" in f for sid in sample_ids)] print(f"Processing {len(npz_files)} samples for layer {layer_idx}...") # Aggregate data all_fg_norms = [] all_bg_norms = [] all_fg_attn = [] all_bg_attn = [] processed_ids = [] for npz_path in npz_files: sample_id = int(os.path.basename(npz_path).split('_')[1].split('.')[0]) sample_data = load_mts_sample(npz_path) attn_key = f'attention_layer_{layer_idx}' if attn_key not in sample_data: print(f" Warning: Layer {layer_idx} not found in sample {sample_id}") continue attention = sample_data[attn_key] is_hr_token = sample_data['is_hr_token'] n_visual_merged = sample_data.get('n_visual_merged', None) if n_visual_merged is None: lr_grid_shape = tuple(sample_data['lr_grid_shape']) t, h, w = lr_grid_shape spatial_merge_size = 2 n_visual_merged = (t * h * w) // (spatial_merge_size ** 2) n_visual_tokens = int(n_visual_merged) # Compute feature norms vision_features = sample_data.get('vision_features', None) if vision_features is not None and len(vision_features) >= n_visual_tokens: feature_norms = compute_true_feature_norm(vision_features[:n_visual_tokens]) else: feature_norms = compute_feature_norm_from_attention(attention, n_visual_tokens) text_start_idx = n_visual_tokens text_to_visual = attention[:, text_start_idx:, :n_visual_tokens] attn_weights = text_to_visual.mean(axis=(0, 1)) # Classify tokens: map unmerged HR tokens to merged positions merge_unit = 4 # 2x2 spatial merge is_hr_token_merged = np.zeros(n_visual_tokens, dtype=bool) for unmerged_idx in range(len(is_hr_token)): if is_hr_token[unmerged_idx]: merged_idx = unmerged_idx // merge_unit if merged_idx < n_visual_tokens: is_hr_token_merged[merged_idx] = True foreground_mask = is_hr_token_merged background_mask = ~is_hr_token_merged # Collect data all_fg_norms.append(feature_norms[foreground_mask]) all_bg_norms.append(feature_norms[background_mask]) all_fg_attn.append(attn_weights[foreground_mask]) all_bg_attn.append(attn_weights[background_mask]) processed_ids.append(sample_id) print(f" ✓ Sample {sample_id}: n_visual={n_visual_tokens}, FG(HR)={foreground_mask.sum()}, BG(LR)={background_mask.sum()}") if not all_fg_norms: print(f"Error: No valid samples found") return None, None, None, None, None, [] # Concatenate all samples all_fg_norms = np.concatenate(all_fg_norms) all_bg_norms = np.concatenate(all_bg_norms) all_fg_attn = np.concatenate(all_fg_attn) all_bg_attn = np.concatenate(all_bg_attn) print(f"\nAggregated: FG={len(all_fg_norms)}, BG={len(all_bg_norms)} tokens") # Create bins norm_max = max(all_fg_norms.max(), all_bg_norms.max()) bins = np.linspace(0, norm_max, n_bins + 1) bin_centers = (bins[:-1] + bins[1:]) / 2 # Compute histograms fg_counts, _ = np.histogram(all_fg_norms, bins=bins) bg_counts, _ = np.histogram(all_bg_norms, bins=bins) # Compute mean attention per bin fg_attn_mean = np.full(n_bins, np.nan) # Initialize with NaN bg_attn_mean = np.full(n_bins, np.nan) # Initialize with NaN for i in range(n_bins): # Foreground mask_fg = (all_fg_norms >= bins[i]) & (all_fg_norms < bins[i+1]) if mask_fg.sum() > 0: fg_attn_mean[i] = all_fg_attn[mask_fg].mean() # Background mask_bg = (all_bg_norms >= bins[i]) & (all_bg_norms < bins[i+1]) if mask_bg.sum() > 0: bg_attn_mean[i] = all_bg_attn[mask_bg].mean() return bin_centers, fg_counts, bg_counts, fg_attn_mean, bg_attn_mean, processed_ids def main(): parser = argparse.ArgumentParser( description='Analyze Foreground vs Background Token Distribution' ) parser.add_argument('--data_dir', type=str, default='/mnt/rdata4_6/huixin/LLaMA-Factory-main/attn_vis/paco_attention_data_checkpoint5877', help='Directory containing .npz files') parser.add_argument('--output_dir', type=str, default='/mnt/rdata4_6/huixin/LLaMA-Factory-main/attn_vis/paco_attention_data_checkpoint5877_analysis', help='Output directory') parser.add_argument('--layers', type=int, nargs='+', default=[2, 6, 15, 27], help='Layers to analyze') parser.add_argument('--samples', type=int, nargs='+', default=None, help='Specific sample IDs (default: all samples)') parser.add_argument('--n_bins', type=int, default=50, help='Number of bins for histogram') args = parser.parse_args() os.makedirs(args.output_dir, exist_ok=True) print("="*70) print("Foreground vs Background Token Analysis") print("="*70) print(f"Data directory: {args.data_dir}") print(f"Output directory: {args.output_dir}") print(f"Layers: {args.layers}") print(f"Bins: {args.n_bins}") print("="*70) # Process each layer for layer_idx in args.layers: print(f"\n{'='*70}") print(f"Processing Layer {layer_idx}") print(f"{'='*70}") # Aggregate across samples bin_centers, fg_counts, bg_counts, fg_attn_mean, bg_attn_mean, sample_ids = \ aggregate_multiple_samples( data_dir=args.data_dir, layer_idx=layer_idx, sample_ids=args.samples, n_bins=args.n_bins ) if bin_centers is None: print(f"Skipping layer {layer_idx} due to errors") continue # Visualize output_path = os.path.join( args.output_dir, f"layer_{layer_idx}_foreground_background_analysis.png" ) visualize_foreground_background_analysis( bin_centers=bin_centers, fg_counts=fg_counts, bg_counts=bg_counts, fg_attn_mean=fg_attn_mean, bg_attn_mean=bg_attn_mean, layer_idx=layer_idx, output_path=output_path, sample_ids=sample_ids ) print(f"\n{'='*70}") print(f"✓ Analysis complete! Results saved to: {args.output_dir}") print(f"{'='*70}") if __name__ == '__main__': main()