| | |
| | """ |
| | Foreground vs Background Token Analysis |
| | |
| | Analyzes the distribution of tokens based on whether they are in GT box (foreground) or not (background). |
| | - Foreground: HR tokens (have parent_lr_idx) - inside GT box - Light Blue |
| | - Background: LR tokens (no parent_lr_idx replacement) - outside GT box - Light Red |
| | |
| | Visualization: |
| | - X-axis: Feature Norm (0 to max) |
| | - Left Y-axis: Token Count (blue line for foreground, red line for background) |
| | - Right Y-axis: Average Attention Weight |
| | """ |
| |
|
| | import os |
| | import sys |
| | import numpy as np |
| | import matplotlib |
| | matplotlib.use('Agg') |
| | import matplotlib.pyplot as plt |
| | from matplotlib.patches import Patch |
| | import glob |
| | from typing import List, Dict, Tuple |
| | import argparse |
| |
|
| | |
| | sys.path.insert(0, '/mnt/rdata4_6/huixin/LLaMA-Factory-main/custom_models/qwen2_5_vl') |
| |
|
| |
|
def load_mts_sample(npz_path: str) -> dict:
    """Load a single MTS attention sample from a .npz file.

    Args:
        npz_path: Path to the sample's .npz archive.

    Returns:
        Dict mapping each array name in the archive to its loaded ndarray.
    """
    # Use a context manager so the underlying file handle is closed; the
    # original left the NpzFile open for the process lifetime. Arrays are
    # fully materialized inside the `with`, so they stay valid afterwards.
    with np.load(npz_path, allow_pickle=True) as data:
        return {key: data[key] for key in data.files}
| |
|
| |
|
def compute_feature_norm_from_attention(attention: np.ndarray, n_visual_tokens: int) -> np.ndarray:
    """
    Compute a feature-norm proxy from attention patterns.

    For visual tokens, the total attention received from all text tokens
    (summed over heads and text positions) serves as a proxy for feature
    importance/norm when true embeddings are unavailable.

    Args:
        attention: (num_heads, seq_len, seq_len) attention weights; visual
            tokens are assumed to occupy positions [0, n_visual_tokens).
        n_visual_tokens: Number of visual tokens at the start of the sequence.

    Returns:
        feature_norms: (n_visual_tokens,) - proxy norm value per visual token.
    """
    # Rows from n_visual_tokens onward are text queries; columns up to
    # n_visual_tokens are visual keys. (The original also unpacked
    # attention.shape into unused locals — removed.)
    text_to_visual = attention[:, n_visual_tokens:, :n_visual_tokens]
    return text_to_visual.sum(axis=(0, 1))
| |
|
| |
|
def compute_true_feature_norm(vision_features: np.ndarray) -> np.ndarray:
    """Return the true L2 norm of each vision feature vector.

    Args:
        vision_features: (n_visual_tokens, hidden_dim) vision feature
            embeddings.

    Returns:
        (n_visual_tokens,) array of per-token L2 norms.
    """
    return np.linalg.norm(vision_features, ord=2, axis=-1)
| |
|
| |
|
def analyze_foreground_background_distribution(
    sample_data: dict,
    layer_idx: int,
    n_bins: int = 50
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Analyze foreground vs background token distribution for one sample.

    Foreground tokens are merged visual tokens flagged as HR (inside the GT
    box); background tokens are the remaining LR tokens.

    Args:
        sample_data: Loaded .npz sample data.
        layer_idx: Layer index to analyze.
        n_bins: Number of bins for the norm histogram.

    Returns:
        bin_centers: (n_bins,) - center of each bin for the x-axis
        fg_counts: (n_bins,) - foreground token count per bin
        bg_counts: (n_bins,) - background token count per bin
        fg_attn_mean: (n_bins,) - mean attention for foreground tokens per bin
        bg_attn_mean: (n_bins,) - mean attention for background tokens per bin
        All five are None when the requested layer is absent from the sample.
    """
    attn_key = f'attention_layer_{layer_idx}'
    if attn_key not in sample_data:
        return None, None, None, None, None

    attention = sample_data[attn_key]
    is_hr_token = sample_data['is_hr_token']
    n_visual_merged = sample_data.get('n_visual_merged', None)

    if n_visual_merged is None:
        # Derive the merged token count from the LR grid shape when not stored.
        lr_grid_shape = tuple(sample_data['lr_grid_shape'])
        t, h, w = lr_grid_shape
        spatial_merge_size = 2
        n_visual_merged = (t * h * w) // (spatial_merge_size ** 2)

    n_visual_tokens = int(n_visual_merged)

    # Prefer the true L2 feature norm; fall back to the attention-received
    # proxy when vision features were not saved with the sample.
    vision_features = sample_data.get('vision_features', None)
    if vision_features is not None and len(vision_features) >= n_visual_tokens:
        feature_norms = compute_true_feature_norm(vision_features[:n_visual_tokens])
    else:
        feature_norms = compute_feature_norm_from_attention(attention, n_visual_tokens)
        print(f" Warning: No vision_features found, using attention proxy for feature norm")

    # Mean attention each visual token receives, averaged over heads and
    # all text-token query positions.
    text_start_idx = n_visual_tokens
    text_to_visual = attention[:, text_start_idx:, :n_visual_tokens]
    attn_weights = text_to_visual.mean(axis=(0, 1))

    # Project the unmerged HR mask onto merged tokens: a merged token is HR
    # if any of its `merge_unit` constituent patches is HR.
    merge_unit = 4
    is_hr_token_merged = np.zeros(n_visual_tokens, dtype=bool)
    n_mts_unmerged = len(is_hr_token)

    for unmerged_idx in range(n_mts_unmerged):
        if is_hr_token[unmerged_idx]:
            merged_idx = unmerged_idx // merge_unit
            if merged_idx < n_visual_tokens:
                is_hr_token_merged[merged_idx] = True

    foreground_mask = is_hr_token_merged
    background_mask = ~is_hr_token_merged

    print(f" Visual tokens: {n_visual_tokens}, HR: {foreground_mask.sum()}, LR: {background_mask.sum()}")

    fg_norms = feature_norms[foreground_mask]
    bg_norms = feature_norms[background_mask]
    fg_attn = attn_weights[foreground_mask]
    bg_attn = attn_weights[background_mask]

    # Bin edges span [0, max norm]; shared by both populations.
    norm_min = 0
    norm_max = feature_norms.max()
    bins = np.linspace(norm_min, norm_max, n_bins + 1)
    bin_centers = (bins[:-1] + bins[1:]) / 2

    fg_counts, _ = np.histogram(fg_norms, bins=bins)
    bg_counts, _ = np.histogram(bg_norms, bins=bins)

    fg_attn_mean = np.full(n_bins, np.nan)
    bg_attn_mean = np.full(n_bins, np.nan)

    for i in range(n_bins):
        # Match np.histogram semantics: every bin is half-open [low, high)
        # EXCEPT the last, which is closed [low, high]. The previous strict
        # `<` upper bound on the last bin dropped tokens whose norm equals
        # norm_max (always at least one token, since bins span up to the
        # max), leaving NaN attention despite a nonzero count there.
        if i == n_bins - 1:
            mask_fg = (fg_norms >= bins[i]) & (fg_norms <= bins[i + 1])
            mask_bg = (bg_norms >= bins[i]) & (bg_norms <= bins[i + 1])
        else:
            mask_fg = (fg_norms >= bins[i]) & (fg_norms < bins[i + 1])
            mask_bg = (bg_norms >= bins[i]) & (bg_norms < bins[i + 1])

        if mask_fg.sum() > 0:
            fg_attn_mean[i] = fg_attn[mask_fg].mean()
        if mask_bg.sum() > 0:
            bg_attn_mean[i] = bg_attn[mask_bg].mean()

    return bin_centers, fg_counts, bg_counts, fg_attn_mean, bg_attn_mean
| |
|
| |
|
def visualize_foreground_background_analysis(
    bin_centers: np.ndarray,
    fg_counts: np.ndarray,
    bg_counts: np.ndarray,
    fg_attn_mean: np.ndarray,
    bg_attn_mean: np.ndarray,
    layer_idx: int,
    output_path: str,
    sample_ids: List[int] = None
):
    """
    Create a dual Y-axis plot for foreground vs background analysis.

    Left axis: token counts per norm bin (lines). Right axis: mean attention
    per bin (paired bars). The figure is saved to `output_path` as a PNG.

    Args:
        bin_centers: (n_bins,) - x-axis values
        fg_counts: (n_bins,) - foreground token counts
        bg_counts: (n_bins,) - background token counts
        fg_attn_mean: (n_bins,) - mean attention for foreground
        bg_attn_mean: (n_bins,) - mean attention for background
        layer_idx: Layer index for the title
        output_path: Where to save the plot
        sample_ids: List of sample IDs used (for the title)
    """
    plt.rcParams['font.family'] = 'sans-serif'
    plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial', 'Helvetica', 'sans-serif']

    fig, ax1 = plt.subplots(figsize=(14, 8))

    ax1.set_xlabel('L2 Norm of Visual Token', fontsize=28, fontweight='bold')
    ax1.set_ylabel('Token Count', fontsize=28, fontweight='bold', color='black')

    # Mask zero-count bins with NaN so the count lines break there instead
    # of dipping to zero.
    fg_counts_masked = fg_counts.astype(float)
    fg_counts_masked[fg_counts_masked == 0] = np.nan
    bg_counts_masked = bg_counts.astype(float)
    bg_counts_masked[bg_counts_masked == 0] = np.nan

    line_fg_count = ax1.plot(bin_centers, fg_counts_masked, color='dodgerblue', linewidth=2.5,
                             label='Foreground Tokens', marker='o', markersize=4, alpha=0.8)
    line_bg_count = ax1.plot(bin_centers, bg_counts_masked, color='lightcoral', linewidth=2.5,
                             label='Background Tokens', marker='s', markersize=4, alpha=0.8)

    ax1.tick_params(axis='y', labelcolor='black', labelsize=24)
    ax1.tick_params(axis='x', labelsize=24)
    ax1.grid(True, alpha=0.3, linestyle='--')

    # Secondary axis for the attention bars.
    ax2 = ax1.twinx()
    ax2.set_ylabel('Average Attention Weight', fontsize=28, fontweight='bold', color='darkgreen')

    if len(bin_centers) > 1:
        bar_width = (bin_centers[1] - bin_centers[0]) / 2
    else:
        bar_width = 0.5

    ax2.bar(bin_centers - bar_width/2, fg_attn_mean,
            width=bar_width, color='blue', alpha=0.3,
            label='Foreground Attention', edgecolor='darkblue', linewidth=0.5, align='center')
    ax2.bar(bin_centers + bar_width/2, bg_attn_mean,
            width=bar_width, color='red', alpha=0.3,
            label='Background Attention', edgecolor='darkred', linewidth=0.5, align='center')

    # Proxy handles for the bar legend. `Patch` is already imported at module
    # level; the original re-imported it locally, which was redundant.
    bar_handles = [
        Patch(facecolor='blue', alpha=0.3, edgecolor='darkblue', linewidth=0.5, label='Foreground Attention'),
        Patch(facecolor='red', alpha=0.3, edgecolor='darkred', linewidth=0.5, label='Background Attention'),
    ]

    ax2.tick_params(axis='y', labelcolor='darkgreen', labelsize=24)

    sample_str = f"Samples {sample_ids}" if sample_ids else "All Samples"
    plt.title(f'Foreground vs Background Token Analysis - Layer {layer_idx}\n{sample_str}',
              fontsize=32, fontweight='bold', pad=20)

    count_handles = line_fg_count + line_bg_count
    count_labels = [l.get_label() for l in count_handles]

    ax1.legend(count_handles, count_labels, loc='upper left', fontsize=22, framealpha=0.9)
    ax2.legend(handles=bar_handles, loc='upper right', fontsize=22, framealpha=0.9)

    # NOTE(review): these are unweighted means over per-bin means (not
    # token-weighted); np.nanmean warns if a series is entirely NaN.
    fg_total = fg_counts.sum()
    bg_total = bg_counts.sum()
    fg_attn_overall = np.nanmean(fg_attn_mean)
    bg_attn_overall = np.nanmean(bg_attn_mean)

    info_text = f"Avg Attention:\n"
    info_text += f" Foreground: {fg_attn_overall:.4f}\n"
    info_text += f" Background: {bg_attn_overall:.4f}"

    ax1.text(0.02, 0.5, info_text, transform=ax1.transAxes,
             fontsize=20, verticalalignment='center',
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8, edgecolor='black', linewidth=1.5))

    total_text = f"Total Tokens: Foreground = {int(fg_total)}, Background = {int(bg_total)}"
    fig.text(0.5, 0.02, total_text, ha='center', fontsize=20, fontweight='bold')

    plt.tight_layout(rect=[0, 0.04, 1, 1])
    plt.savefig(output_path, dpi=200, bbox_inches='tight')
    plt.close()
    print(f"✓ Saved: {output_path}")
| |
|
| |
|
def aggregate_multiple_samples(
    data_dir: str,
    layer_idx: int,
    sample_ids: List[int] = None,
    n_bins: int = 50
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[int]]:
    """
    Aggregate foreground/background analysis across multiple samples.

    Args:
        data_dir: Directory with .npz files.
        layer_idx: Layer to analyze.
        sample_ids: List of sample IDs to use (None = all).
        n_bins: Number of histogram bins.

    Returns:
        Aggregated bin_centers, fg_counts, bg_counts, fg_attn_mean,
        bg_attn_mean, processed_sample_ids. The first five are None (and the
        ID list empty) when no usable samples are found.
    """
    npz_files = sorted(glob.glob(os.path.join(data_dir, 'sample_*.npz')))

    if not npz_files:
        print(f"Error: No .npz files found in {data_dir}")
        return None, None, None, None, None, []

    # Optionally restrict to explicit sample IDs (files use zero-padded names).
    if sample_ids:
        npz_files = [f for f in npz_files
                     if any(f"sample_{sid:04d}.npz" in f for sid in sample_ids)]

    print(f"Processing {len(npz_files)} samples for layer {layer_idx}...")

    all_fg_norms = []
    all_bg_norms = []
    all_fg_attn = []
    all_bg_attn = []
    processed_ids = []

    for npz_path in npz_files:
        sample_id = int(os.path.basename(npz_path).split('_')[1].split('.')[0])
        sample_data = load_mts_sample(npz_path)

        attn_key = f'attention_layer_{layer_idx}'
        if attn_key not in sample_data:
            print(f" Warning: Layer {layer_idx} not found in sample {sample_id}")
            continue

        attention = sample_data[attn_key]
        is_hr_token = sample_data['is_hr_token']
        n_visual_merged = sample_data.get('n_visual_merged', None)

        if n_visual_merged is None:
            # Derive the merged token count from the LR grid when not stored.
            lr_grid_shape = tuple(sample_data['lr_grid_shape'])
            t, h, w = lr_grid_shape
            spatial_merge_size = 2
            n_visual_merged = (t * h * w) // (spatial_merge_size ** 2)

        n_visual_tokens = int(n_visual_merged)

        # True L2 feature norm when available; attention-received proxy otherwise.
        vision_features = sample_data.get('vision_features', None)
        if vision_features is not None and len(vision_features) >= n_visual_tokens:
            feature_norms = compute_true_feature_norm(vision_features[:n_visual_tokens])
        else:
            feature_norms = compute_feature_norm_from_attention(attention, n_visual_tokens)

        text_start_idx = n_visual_tokens
        text_to_visual = attention[:, text_start_idx:, :n_visual_tokens]
        attn_weights = text_to_visual.mean(axis=(0, 1))

        # A merged token counts as HR (foreground) if any of its 4 unmerged
        # constituent patches is HR.
        merge_unit = 4
        is_hr_token_merged = np.zeros(n_visual_tokens, dtype=bool)
        for unmerged_idx in range(len(is_hr_token)):
            if is_hr_token[unmerged_idx]:
                merged_idx = unmerged_idx // merge_unit
                if merged_idx < n_visual_tokens:
                    is_hr_token_merged[merged_idx] = True

        foreground_mask = is_hr_token_merged
        background_mask = ~is_hr_token_merged

        all_fg_norms.append(feature_norms[foreground_mask])
        all_bg_norms.append(feature_norms[background_mask])
        all_fg_attn.append(attn_weights[foreground_mask])
        all_bg_attn.append(attn_weights[background_mask])
        processed_ids.append(sample_id)

        print(f" ✓ Sample {sample_id}: n_visual={n_visual_tokens}, FG(HR)={foreground_mask.sum()}, BG(LR)={background_mask.sum()}")

    if not all_fg_norms:
        print(f"Error: No valid samples found")
        return None, None, None, None, None, []

    all_fg_norms = np.concatenate(all_fg_norms)
    all_bg_norms = np.concatenate(all_bg_norms)
    all_fg_attn = np.concatenate(all_fg_attn)
    all_bg_attn = np.concatenate(all_bg_attn)

    print(f"\nAggregated: FG={len(all_fg_norms)}, BG={len(all_bg_norms)} tokens")

    # Shared bin edges across both populations. `initial=0.0` keeps this from
    # crashing when one population is empty (e.g. no HR tokens anywhere);
    # the previous bare .max() raised on an empty array.
    norm_max = max(all_fg_norms.max(initial=0.0), all_bg_norms.max(initial=0.0))
    bins = np.linspace(0, norm_max, n_bins + 1)
    bin_centers = (bins[:-1] + bins[1:]) / 2

    fg_counts, _ = np.histogram(all_fg_norms, bins=bins)
    bg_counts, _ = np.histogram(all_bg_norms, bins=bins)

    fg_attn_mean = np.full(n_bins, np.nan)
    bg_attn_mean = np.full(n_bins, np.nan)

    for i in range(n_bins):
        # Match np.histogram: bins are half-open [low, high) except the last,
        # which is closed [low, high]. The old strict `<` on the last bin
        # dropped tokens at exactly norm_max, so the last bin showed a count
        # but NaN attention.
        if i == n_bins - 1:
            mask_fg = (all_fg_norms >= bins[i]) & (all_fg_norms <= bins[i + 1])
            mask_bg = (all_bg_norms >= bins[i]) & (all_bg_norms <= bins[i + 1])
        else:
            mask_fg = (all_fg_norms >= bins[i]) & (all_fg_norms < bins[i + 1])
            mask_bg = (all_bg_norms >= bins[i]) & (all_bg_norms < bins[i + 1])

        if mask_fg.sum() > 0:
            fg_attn_mean[i] = all_fg_attn[mask_fg].mean()
        if mask_bg.sum() > 0:
            bg_attn_mean[i] = all_bg_attn[mask_bg].mean()

    return bin_centers, fg_counts, bg_counts, fg_attn_mean, bg_attn_mean, processed_ids
| |
|
| |
|
def main():
    """CLI entry point: aggregate per-layer FG/BG statistics and render one plot per layer."""
    parser = argparse.ArgumentParser(
        description='Analyze Foreground vs Background Token Distribution'
    )
    parser.add_argument('--data_dir', type=str,
                        default='/mnt/rdata4_6/huixin/LLaMA-Factory-main/attn_vis/paco_attention_data_checkpoint5877',
                        help='Directory containing .npz files')
    parser.add_argument('--output_dir', type=str,
                        default='/mnt/rdata4_6/huixin/LLaMA-Factory-main/attn_vis/paco_attention_data_checkpoint5877_analysis',
                        help='Output directory')
    parser.add_argument('--layers', type=int, nargs='+', default=[2, 6, 15, 27],
                        help='Layers to analyze')
    parser.add_argument('--samples', type=int, nargs='+', default=None,
                        help='Specific sample IDs (default: all samples)')
    parser.add_argument('--n_bins', type=int, default=50,
                        help='Number of bins for histogram')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    rule = "=" * 70

    # Run banner.
    print(rule)
    print("Foreground vs Background Token Analysis")
    print(rule)
    print(f"Data directory: {args.data_dir}")
    print(f"Output directory: {args.output_dir}")
    print(f"Layers: {args.layers}")
    print(f"Bins: {args.n_bins}")
    print(rule)

    for layer_idx in args.layers:
        print(f"\n{rule}")
        print(f"Processing Layer {layer_idx}")
        print(rule)

        # Aggregate statistics across all requested samples for this layer.
        result = aggregate_multiple_samples(
            data_dir=args.data_dir,
            layer_idx=layer_idx,
            sample_ids=args.samples,
            n_bins=args.n_bins
        )
        bin_centers, fg_counts, bg_counts, fg_attn_mean, bg_attn_mean, used_ids = result

        if bin_centers is None:
            print(f"Skipping layer {layer_idx} due to errors")
            continue

        plot_path = os.path.join(
            args.output_dir,
            f"layer_{layer_idx}_foreground_background_analysis.png"
        )
        visualize_foreground_background_analysis(
            bin_centers=bin_centers,
            fg_counts=fg_counts,
            bg_counts=bg_counts,
            fg_attn_mean=fg_attn_mean,
            bg_attn_mean=bg_attn_mean,
            layer_idx=layer_idx,
            output_path=plot_path,
            sample_ids=used_ids
        )

    print(f"\n{rule}")
    print(f"✓ Analysis complete! Results saved to: {args.output_dir}")
    print(rule)
| |
|
| |
|
# Run the CLI only when executed as a script (not on import).
if __name__ == '__main__':
    main()
| |
|
| |
|