VLM-Lens / src /concepts /pca_separation.py
marstin's picture
[martin-dev] add demo v1 test
d425e71
"""PCA scatter plot visualization for VLM concept analysis.
Creates 2D scatter plots of concepts and targets in PCA space for interpretability.
"""
from __future__ import annotations
import os
from typing import Optional
import matplotlib.pyplot as plt
import numpy as np
from pca import (apply_pca_to_layer, extract_concept_from_filename,
group_tensors_by_concept, load_tensors_by_layer)
def create_pca_scatter_plots(
    target_db_path: str,
    concept_db_path: str,
    layer_names: Optional[list[str]] = None,
    output_dir: str = 'output',
    figsize: tuple[int, int] = (12, 8),
    alpha: float = 0.7,
    target_marker_size: int = 100,
    concept_marker_size: int = 50
) -> None:
    """Create 2D PCA scatter plots for concepts and targets.

    Tensors are loaded from both databases, and for every selected layer that
    exists in both, a 2-component PCA is applied via ``apply_pca_to_layer``.
    Concept prototypes are drawn as circles and targets as triangles, coloured
    per concept, and one PNG per layer is written to ``output_dir``.

    Each target concept gets exactly one legend entry regardless of the order
    in which targets appear, and targets whose concept has no prototype colour
    are drawn in black under a single 'Unknown' entry.

    Args:
        target_db_path: Path to target images database
        concept_db_path: Path to concept images database
        layer_names: Layer names to visualize (None for all common layers).
            A single string is accepted and treated as a one-element list.
        output_dir: Directory to save plots (created if missing)
        figsize: Figure size (width, height)
        alpha: Transparency for concept points
        target_marker_size: Size of target markers
        concept_marker_size: Size of concept markers
    """
    print('Creating PCA scatter plots...')

    # Load tensors from both databases
    print(f'Loading tensors from {target_db_path}...')
    target_tensors = load_tensors_by_layer(target_db_path, 'cpu')
    print(f'Loading tensors from {concept_db_path}...')
    concept_tensors = load_tensors_by_layer(concept_db_path, 'cpu')

    # Only layers present in both databases can be plotted together
    common_layers = set(target_tensors.keys()) & set(concept_tensors.keys())
    print(f'Found {len(common_layers)} common layers: {sorted(common_layers)}')
    if not common_layers:
        print('No common layers found between databases!')
        return

    # Determine which layers to visualize
    if layer_names is None:
        layers_to_analyze = sorted(common_layers)
    else:
        if isinstance(layer_names, str):
            layer_names = [layer_names]
        layers_to_analyze = [layer for layer in layer_names if layer in common_layers]
        # Surface requested layers that cannot be plotted instead of dropping them silently
        missing_layers = [layer for layer in layer_names if layer not in common_layers]
        if missing_layers:
            print(f'Requested layers not found in both databases: {missing_layers}')

    os.makedirs(output_dir, exist_ok=True)

    # Create plots for each layer
    for layer in layers_to_analyze:
        print(f'\nProcessing layer: {layer}')
        target_layer_tensors = target_tensors[layer]
        concept_layer_tensors = concept_tensors[layer]
        if not target_layer_tensors or not concept_layer_tensors:
            print(f'Skipping layer {layer} - insufficient data')
            continue

        # Apply PCA with 2 components
        print(' Applying PCA with 2 components...')
        transformed_targets, transformed_concepts, pca_model = apply_pca_to_layer(
            target_layer_tensors, concept_layer_tensors, n_components=2
        )
        if pca_model is None:
            print(f' Failed to apply PCA for layer {layer}')
            continue

        variance_ratio = pca_model.explained_variance_ratio_
        # Both axes are indexed below; bail out early if the model ended up
        # with fewer than two components rather than failing mid-plot.
        if len(variance_ratio) < 2:
            print(f' Skipping layer {layer} - PCA produced fewer than 2 components')
            continue

        # Group concepts for coloring
        concept_groups = group_tensors_by_concept(transformed_concepts)

        fig, ax = plt.subplots(figsize=figsize)

        # One colour per concept, assigned in sorted-name order so the mapping is stable
        concept_names = sorted(concept_groups.keys())
        colors = plt.cm.Set3(np.linspace(0, 1, len(concept_names)))
        color_map = dict(zip(concept_names, colors))

        # Plot concept prototypes (one scatter call, hence one legend entry, per concept)
        for concept_name, concept_data in concept_groups.items():
            concept_coords = np.array([data[0] for data in concept_data])
            ax.scatter(
                concept_coords[:, 0],
                concept_coords[:, 1],
                c=[color_map[concept_name]],
                s=concept_marker_size,
                alpha=alpha,
                label=f'{concept_name} (prototypes)',
                marker='o',
                edgecolors='white',
                linewidth=0.5
            )

        # Plot targets, batched per concept so every concept is labelled exactly once
        target_coords = np.array([data[0] for data in transformed_targets])
        known_rows, unknown_rows = _group_target_rows(transformed_targets, color_map)
        target_batches = [
            (f'{concept} (target)', color_map[concept], rows)
            for concept, rows in known_rows.items()
        ]
        if unknown_rows:
            target_batches.append(('Unknown (target)', 'black', unknown_rows))

        for batch_label, batch_color, rows in target_batches:
            batch_coords = target_coords[rows]
            ax.scatter(
                batch_coords[:, 0],
                batch_coords[:, 1],
                c=[batch_color],
                s=target_marker_size,
                alpha=0.9,
                marker='^',  # Triangle for targets
                edgecolors='black',
                linewidth=1.0,
                label=batch_label
            )

        # Customize the plot
        ax.set_xlabel(f'PC1 ({variance_ratio[0]:.3f} variance explained)')
        ax.set_ylabel(f'PC2 ({variance_ratio[1]:.3f} variance explained)')
        ax.set_title(f'PCA Visualization: Concepts vs Targets\nLayer: {layer}')
        ax.grid(True, alpha=0.3)

        _add_split_legend(ax)

        # Add statistics text
        stats_text = (
            f'Total variance explained: {variance_ratio.sum():.3f}\n'
            f'Concepts: {len(concept_groups)}\n'
            f'Prototypes: {len(transformed_concepts)}\n'
            f'Targets: {len(transformed_targets)}'
        )
        ax.text(
            0.02, 0.98,
            stats_text,
            transform=ax.transAxes,
            verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
            fontsize=9
        )
        fig.tight_layout()

        # Save plot; '/' in layer names would otherwise be read as a directory separator
        plot_filename = f'{output_dir}/pca_scatter_layer_{layer.replace("/", "_")}.png'
        fig.savefig(plot_filename, dpi=300, bbox_inches='tight')
        # Close this specific figure rather than whichever one is "current"
        plt.close(fig)
        print(f' Plot saved: {plot_filename}')

        # Print summary statistics
        print(f' Variance explained: PC1={variance_ratio[0]:.3f}, '
              f'PC2={variance_ratio[1]:.3f}, '
              f'Total={variance_ratio.sum():.3f}')
        print(f' Plotted {len(concept_groups)} concept groups with {len(transformed_concepts)} prototypes')
        print(f' Plotted {len(transformed_targets)} target images')

    print(f'\nPCA scatter plots complete. Plots saved in {output_dir}/')


def _group_target_rows(transformed_targets, color_map):
    """Bucket target row indices by concept, preserving first-seen order.

    Returns a ``(known, unknown)`` pair: ``known`` maps each concept that has
    an entry in ``color_map`` to the list of row indices of its targets, and
    ``unknown`` lists the row indices of all remaining targets.
    """
    known = {}
    unknown = []
    for row, data in enumerate(transformed_targets):
        # data[3] is the image filename the concept is derived from
        concept = extract_concept_from_filename(data[3])
        if concept in color_map:
            known.setdefault(concept, []).append(row)
        else:
            unknown.append(row)
    return known, unknown


def _add_split_legend(ax):
    """Attach separate prototype/target legends to the right of ``ax``.

    Falls back to a single default legend when either kind is absent.
    """
    handles, labels = ax.get_legend_handles_labels()

    # Separate prototype and target entries by the suffix added when plotting
    prototype_handles, prototype_labels = [], []
    target_handles, target_labels = [], []
    for handle, label in zip(handles, labels):
        if '(prototypes)' in label:
            prototype_handles.append(handle)
            prototype_labels.append(label.replace(' (prototypes)', ''))
        elif '(target)' in label:
            target_handles.append(handle)
            target_labels.append(label.replace(' (target)', ''))

    if prototype_handles and target_handles:
        legend1 = ax.legend(
            prototype_handles,
            [f'{label} (○)' for label in prototype_labels],
            title='Concept Prototypes',
            loc='upper left',
            bbox_to_anchor=(1.02, 1.0),
            fontsize=9
        )
        # A second ax.legend() call replaces the first, so re-add it as an artist
        ax.add_artist(legend1)
        ax.legend(
            target_handles,
            [f'{label} (△)' for label in target_labels],
            title='Target Images',
            loc='upper left',
            bbox_to_anchor=(1.02, 0.6),
            fontsize=9
        )
    else:
        ax.legend(bbox_to_anchor=(1.02, 1.0), loc='upper left', fontsize=9)
def create_concept_separation_analysis(
    target_db_path: str,
    concept_db_path: str,
    layer_names: Optional[list[str]] = None,
    output_dir: str = 'output'
) -> None:
    """Analyze concept separation in PCA space.

    For every selected layer present in both databases, a 2-component PCA is
    applied via ``apply_pca_to_layer`` and a plain-text report is written to
    ``<output_dir>/pca_separation_analysis.txt`` containing pairwise distances
    between concept centroids, per-concept standard deviation along each
    component, and the variance explained by the components.

    Concepts are reported in sorted-name order so repeated runs produce the
    same file layout.

    Args:
        target_db_path: Path to target images database
        concept_db_path: Path to concept images database
        layer_names: Layer names to analyze (None for all common layers).
            A single string is accepted and treated as a one-element list.
        output_dir: Directory to save analysis (created if missing)
    """
    print('\nAnalyzing concept separation in PCA space...')

    # Load tensors
    target_tensors = load_tensors_by_layer(target_db_path, 'cpu')
    concept_tensors = load_tensors_by_layer(concept_db_path, 'cpu')
    common_layers = set(target_tensors.keys()) & set(concept_tensors.keys())

    if layer_names is None:
        layers_to_analyze = sorted(common_layers)
    else:
        if isinstance(layer_names, str):
            layer_names = [layer_names]
        layers_to_analyze = [layer for layer in layer_names if layer in common_layers]

    os.makedirs(output_dir, exist_ok=True)
    report_path = f'{output_dir}/pca_separation_analysis.txt'

    # Explicit encoding: concept names come from data and may not fit the
    # platform's default text encoding.
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write('PCA Concept Separation Analysis\n')
        f.write('=' * 40 + '\n\n')

        for layer in layers_to_analyze:
            target_layer_tensors = target_tensors[layer]
            concept_layer_tensors = concept_tensors[layer]
            if not concept_layer_tensors:
                continue

            # Apply PCA; only the transformed concepts are needed for this report
            _, transformed_concepts, pca_model = apply_pca_to_layer(
                target_layer_tensors, concept_layer_tensors, n_components=2
            )
            if pca_model is None:
                continue

            f.write(f'Layer: {layer}\n')
            f.write('-' * 20 + '\n')

            # Group concepts and extract each group's PCA coordinates once
            concept_groups = group_tensors_by_concept(transformed_concepts)
            concept_names = sorted(concept_groups.keys())
            concept_coords = {
                name: np.array([data[0] for data in concept_groups[name]])
                for name in concept_names
            }

            # Concept centroids in PCA space
            concept_centroids = {
                name: np.mean(coords, axis=0)
                for name, coords in concept_coords.items()
            }

            # Pairwise distances between concept centroids (each unordered pair once)
            f.write('Concept centroid distances in PC1-PC2 space:\n')
            for idx, concept1 in enumerate(concept_names):
                for concept2 in concept_names[idx + 1:]:
                    distance = np.linalg.norm(
                        concept_centroids[concept1] - concept_centroids[concept2]
                    )
                    f.write(f' {concept1} - {concept2}: {distance:.3f}\n')

            # Within-concept scatter
            f.write('\nWithin-concept scatter (std dev):\n')
            for concept_name in concept_names:
                coords = concept_coords[concept_name]
                if len(coords) > 1:
                    std_pc1 = np.std(coords[:, 0])
                    std_pc2 = np.std(coords[:, 1])
                    f.write(f' {concept_name}: PC1={std_pc1:.3f}, PC2={std_pc2:.3f}\n')
                else:
                    # A spread needs at least two points; say so instead of
                    # dropping the concept from the report.
                    f.write(f' {concept_name}: n/a (single prototype)\n')

            variance_ratio = pca_model.explained_variance_ratio_
            f.write('\nPCA Statistics:\n')
            f.write(f' PC1 variance explained: {variance_ratio[0]:.3f}\n')
            f.write(f' PC2 variance explained: {variance_ratio[1]:.3f}\n')
            f.write(f' Total variance explained: {variance_ratio.sum():.3f}\n')
            f.write('\n\n')

    print(f'Separation analysis saved to {report_path}')
if __name__ == '__main__':
    # Script entry point: render the per-layer scatter plots, then write the
    # concept-separation report, both into 'output'.

    # Configuration
    target_db_path = 'output/llava.db'
    concept_db_path = 'output/llava-concepts-colors.db'

    # Visualization parameters
    layer_names = None  # None for all layers, or specify: ['layer_name1', 'layer_name2']

    print('=' * 60)
    print('VLM PCA VISUALIZATION')
    print('=' * 60)

    try:
        # Create scatter plots
        create_pca_scatter_plots(
            target_db_path=target_db_path,
            concept_db_path=concept_db_path,
            layer_names=layer_names,
            output_dir='output',
            figsize=(12, 8),
            alpha=0.7,
            target_marker_size=100,
            concept_marker_size=50
        )

        # Analyze concept separation
        create_concept_separation_analysis(
            target_db_path=target_db_path,
            concept_db_path=concept_db_path,
            layer_names=layer_names,
            output_dir='output'
        )
        print('\nVisualization complete!')
    except Exception as e:
        print(f'Error during visualization: {e}')
        import traceback
        traceback.print_exc()
        # Report failure through the exit status so shells and pipelines
        # do not mistake a failed run for a successful one.
        raise SystemExit(1)