VLM-Lens / src /concepts /pca_knn.py
marstin's picture
[martin-dev] add demo v1 test
d425e71
"""Instance-based k-NN extension for VLM concept analysis.
This module extends the existing VLM concept analysis with nearest-neighbor
prototype-based classification. It reuses the existing functions and adds
instance-based readout capabilities.
"""
from __future__ import annotations
from collections import defaultdict
from typing import Any, Optional
import numpy as np
# Import from the existing analysis module
from pca import (analyze_concept_trends, cosine_similarity_numpy,
extract_concept_from_filename, group_tensors_by_concept,
load_tensors_by_layer)
def _build_normalized_prototype_bank(
    concept_tensors: list[tuple[np.ndarray, Any, int, str]]
) -> tuple[Optional[np.ndarray], list[dict[str, Any]]]:
    """Build an (N, d) bank of L2-normalized prototype vectors and metadata.

    Every usable vector is flattened to 1-D and divided by its L2 norm, so a
    dot product against another unit vector is a cosine similarity. Row ``i``
    of the returned matrix always corresponds to ``meta[i]``.

    An entry is skipped when:
      * its vector is ``None``;
      * its norm is zero or non-finite (NaN/inf components);
      * its flattened length differs from the first accepted vector (a
        warning is printed), since such rows cannot be stacked together.

    Args:
        concept_tensors: List of tuples (vec, label, row_id, image_path)

    Returns:
        Tuple of (X matrix (N, d), meta list of dicts with
        concept/row_id/image_path/label). ``(None, [])`` when no entry is
        usable.
    """
    rows: list[np.ndarray] = []
    meta: list[dict[str, Any]] = []
    expected_dim: Optional[int] = None

    for vec, label, row_id, image_path in concept_tensors:
        if vec is None:
            continue
        # Flatten so that (1, d) or other multi-dimensional inputs contribute
        # exactly one bank row; otherwise vstack would emit several rows for a
        # single metadata entry and misalign the two.
        flat = np.asarray(vec).ravel()
        norm = np.linalg.norm(flat)
        if not np.isfinite(norm) or norm == 0:
            continue
        if expected_dim is None:
            expected_dim = flat.shape[0]
        elif flat.shape[0] != expected_dim:
            print(
                f'Warning: prototype {row_id} has dimension {flat.shape[0]}, '
                f'expected {expected_dim}; skipping'
            )
            continue
        rows.append(flat / norm)
        meta.append({
            'concept': extract_concept_from_filename(image_path),
            'row_id': row_id,
            'image_path': image_path,
            'label': label,
        })

    if not rows:
        return None, []
    return np.vstack(rows), meta
def _nearest_prototypes(
    target_vec: np.ndarray,
    X_bank: Optional[np.ndarray],
    meta: list[dict[str, Any]],
    topk: int = 5
) -> list[dict[str, Any]]:
    """Return the ``topk`` prototypes most cosine-similar to ``target_vec``.

    The target is flattened and L2-normalized here; the bank rows are assumed
    to be unit-norm already, so the matrix-vector product yields cosine
    similarities directly.

    An empty list is returned (instead of raising) when no ranking is
    possible: missing/empty bank or metadata, ``topk <= 0``, a target whose
    length differs from the bank's column count, or a target with zero or
    non-finite norm.

    Args:
        target_vec: Target vector with d elements, will be L2-normalized here
        X_bank: Prototype bank matrix (N, d), already normalized
        meta: List of metadata dicts for each prototype (one per bank row)
        topk: Number of top neighbors to return

    Returns:
        List of at most ``topk`` dicts sorted by descending similarity, with
        keys ['concept', 'row_id', 'image_path', 'label', 'sim'].
    """
    if X_bank is None or len(meta) == 0 or topk <= 0:
        return []

    t = np.asarray(target_vec).ravel()
    # A mismatched dimensionality would make the matmul below raise.
    if X_bank.ndim != 2 or X_bank.shape[0] == 0 or X_bank.shape[1] != t.shape[0]:
        return []

    t_norm = np.linalg.norm(t)
    if not np.isfinite(t_norm) or t_norm == 0:
        return []

    sims = X_bank @ (t / t_norm)  # cosine since both sides are unit-norm
    k = min(topk, sims.shape[0])

    # argpartition selects the k largest in O(N); only that slice is sorted.
    idx = np.argpartition(-sims, k - 1)[:k]
    idx = idx[np.argsort(-sims[idx])]

    return [
        {
            'concept': meta[i]['concept'],
            'row_id': meta[i]['row_id'],
            'image_path': meta[i]['image_path'],
            'label': meta[i]['label'],
            'sim': float(sims[i]),
        }
        for i in idx
    ]
def _knn_weighted_vote(
    neighbors: list[dict[str, Any]],
    p: float = 1.0
) -> tuple[Optional[str], dict[str, float]]:
    """Weighted majority vote over top-k neighbors.

    Each neighbor adds ``max(0, sim) ** p`` to its concept's score. When
    ``p`` is negative, a neighbor whose clipped similarity is 0 contributes
    0 instead of evaluating ``0.0 ** p``, which would raise
    ``ZeroDivisionError``. Ties are resolved in favour of the concept that
    was encountered first in ``neighbors``.

    Args:
        neighbors: List of neighbor dicts with 'concept' and 'sim' keys
        p: Power for weighting (weight = sim^p, negatives clipped to 0)

    Returns:
        Tuple of (winner_concept, score_dict); ``(None, {})`` when
        ``neighbors`` is empty.
    """
    wsum = defaultdict(float)
    for nb in neighbors:
        clipped = max(0.0, nb['sim'])
        if clipped == 0.0 and p < 0:
            # 0 ** negative is undefined; such a neighbor casts no vote.
            weight = 0.0
        else:
            weight = clipped ** p
        wsum[nb['concept']] += weight

    if not wsum:
        return None, {}

    winner = max(wsum.items(), key=lambda kv: kv[1])[0]
    return winner, dict(wsum)
def analyze_target_vs_concepts_with_knn(
    target_tensors: list[tuple[np.ndarray, Any, int, str]],
    concept_tensors: list[tuple[np.ndarray, Any, int, str]],
    layer_name: str,
    knn_topk: int = 5,
    knn_power: float = 1.0
) -> list[dict[str, Any]]:
    """Analyze similarity between targets and concepts with k-NN instance-based prediction.

    For every target this produces:
      * ``concept_analysis``: per-concept min/max/mean cosine similarity,
        min/mean cosine distance, comparison count, and similarity/angle to
        the concept centroid (mean of the concept's vectors);
      * ``instance_knn``: 1-NN concept plus a similarity-weighted vote over
        the ``knn_topk`` nearest prototypes of the whole bank.

    Malformed inputs are skipped rather than aborting the layer: concept
    vectors that are ``None`` or differ in shape from the target are left out
    of the statistics, a concept whose vectors have inconsistent shapes gets
    no centroid, and a target whose element count differs from the bank
    dimensionality gets an empty ``instance_knn`` (a warning is printed in
    each of the shape-related cases).

    Args:
        target_tensors: List of target tuples (vec, label, row_id, image_path)
        concept_tensors: List of concept tuples (vec, label, row_id, image_path)
        layer_name: Name of the current layer
        knn_topk: Number of nearest neighbors to consider
        knn_power: Power for weighted voting (weight = sim^p)

    Returns:
        List with one result dict per target, keyed by 'layer',
        'target_row_id', 'target_label', 'target_image_filename',
        'concept_analysis' and 'instance_knn'.
    """
    concept_groups = group_tensors_by_concept(concept_tensors)
    print(f'Found {len(concept_groups)} concepts: {list(concept_groups.keys())}')
    for concept, tensors in concept_groups.items():
        print(f' {concept}: {len(tensors)} images')

    # Centroids depend only on the concept set, so compute them once.
    concept_centroids: dict[Any, Optional[np.ndarray]] = {}
    for concept_name, tensor_list in concept_groups.items():
        vecs = [t[0] for t in tensor_list if t[0] is not None]
        if vecs and len({v.shape for v in vecs}) == 1:
            concept_centroids[concept_name] = np.mean(np.vstack(vecs), axis=0)
        else:
            if vecs:
                # vstack would raise on mixed shapes; drop the centroid only.
                print(f'Warning: inconsistent vector shapes for concept {concept_name}; skipping centroid')
            concept_centroids[concept_name] = None

    # The prototype bank is shared by all targets of this layer.
    X_bank, bank_meta = _build_normalized_prototype_bank(concept_tensors)
    if X_bank is None:
        print('Warning: prototype bank is empty for this layer; skipping instance-NN.')
    bank_dim = None if X_bank is None else X_bank.shape[1]

    results = []
    for target_vec, target_label, target_row_id, target_image_filename in target_tensors:
        target_result = {
            'layer': layer_name,
            'target_row_id': target_row_id,
            'target_label': target_label,
            'target_image_filename': target_image_filename,
            'concept_analysis': {},
            'instance_knn': {}
        }

        # --- Per-concept statistics ---
        for concept_name, concept_tensor_list in concept_groups.items():
            similarities = []
            for concept_vec, _label, concept_row_id, _filename in concept_tensor_list:
                if concept_vec is None:
                    continue
                if target_vec.shape != concept_vec.shape:
                    print(f'Warning: Shape mismatch between target {target_row_id} and concept {concept_row_id}')
                    continue
                similarities.append(cosine_similarity_numpy(target_vec, concept_vec))

            concept_stats = {}
            if similarities:
                sims = np.array(similarities)
                distances = 1.0 - sims
                concept_stats.update({
                    'min_similarity': float(np.min(sims)),
                    'max_similarity': float(np.max(sims)),
                    'mean_similarity': float(np.mean(sims)),
                    'min_distance': float(np.min(distances)),
                    'mean_distance': float(np.mean(distances)),
                    'num_comparisons': int(len(sims)),
                })

            centroid = concept_centroids.get(concept_name)
            if centroid is not None and centroid.shape == target_vec.shape:
                cen_sim = cosine_similarity_numpy(target_vec, centroid)
                # Clip before arccos: rounding can push the cosine just
                # outside [-1, 1], which would produce NaN.
                cen_ang = float(np.degrees(np.arccos(np.clip(cen_sim, -1.0, 1.0))))
                concept_stats.update({
                    'centroid_similarity': float(cen_sim),
                    'centroid_angular_deg': cen_ang
                })

            if concept_stats:
                target_result['concept_analysis'][concept_name] = concept_stats

        # --- Instance-based nearest-neighbor prediction ---
        if X_bank is not None:
            if target_vec.size != bank_dim:
                # Mirrors the per-concept shape check above: warn and leave
                # 'instance_knn' empty instead of letting the matmul raise.
                print(
                    f'Warning: target {target_row_id} has {target_vec.size} elements but the '
                    f'prototype bank has dimension {bank_dim}; skipping instance-NN.'
                )
            else:
                nbs = _nearest_prototypes(np.ravel(target_vec), X_bank, bank_meta, topk=knn_topk)
                winner_1nn = nbs[0]['concept'] if nbs else None
                voted, vote_scores = _knn_weighted_vote(nbs, p=knn_power) if nbs else (None, {})
                target_result['instance_knn'] = {
                    'top1_concept': winner_1nn,
                    'top1_similarity': nbs[0]['sim'] if nbs else None,
                    'topk_neighbors': nbs,          # dicts with concept, row_id, image_path, label, sim
                    'topk_voted_concept': voted,    # weighted by sim^p over topk (non-negative)
                    'vote_scores': vote_scores,     # dict concept -> accumulated weight
                    'topk': knn_topk,
                    'vote_power': knn_power
                }

        results.append(target_result)
        target_display = target_image_filename if target_image_filename else f'Target_{target_row_id}'
        print(f'Analyzed {target_display} against {len(concept_groups)} concepts')

    return results
def concept_similarity_analysis_with_knn(
    target_db_path: str,
    concept_db_path: str,
    layer_names: Optional[list[str]] = None,
    n_pca_components: Optional[int] = None,
    knn_topk: int = 5,
    knn_power: float = 1.0,
    device: str = 'cpu'
) -> dict[str, dict[str, Any]]:
    """Run the concept similarity + k-NN analysis for each selected layer.

    Tensors are loaded from both databases, the layers present in both are
    determined, and each selected layer is (optionally) PCA-projected and
    then analyzed. Progress and a short per-layer summary are printed.

    Args:
        target_db_path: Path to target images database
        concept_db_path: Path to concept images database
        layer_names: Layer name or list of layer names to analyze
            (None for all common layers)
        n_pca_components: Number of PCA components (None to skip PCA)
        knn_topk: Number of nearest neighbors for k-NN prediction
        knn_power: Power for weighted voting in k-NN
        device: Device string forwarded to the tensor loader

    Returns:
        Dict mapping layer name to a dict with keys 'results', 'pca_model',
        'n_pca_components', 'knn_topk' and 'knn_power'. Empty when the
        databases share no layers or none of the requested layers exist.
    """
    from collections import Counter

    print('Starting concept-based similarity analysis with k-NN...')
    print(f'Target DB: {target_db_path}')
    print(f'Concept DB: {concept_db_path}')
    print(f'PCA components: {n_pca_components}')
    print(f'k-NN parameters: topk={knn_topk}, power={knn_power}')

    print(f'\nLoading tensors from {target_db_path}...')
    targets_by_layer = load_tensors_by_layer(target_db_path, device)
    print(f'Loading tensors from {concept_db_path}...')
    concepts_by_layer = load_tensors_by_layer(concept_db_path, device)

    # Only layers present in both databases can be compared.
    shared = set(targets_by_layer.keys()) & set(concepts_by_layer.keys())
    print(f'\nFound {len(shared)} common layers: {sorted(shared)}')
    if not shared:
        print('No common layers found between databases!')
        return {}

    if layer_names is not None:
        # A bare string is treated as a single requested layer.
        requested = [layer_names] if isinstance(layer_names, str) else layer_names
        selected = [name for name in requested if name in shared]
        print(f'Analyzing specified layers: {selected}')
        absent = set(requested) - shared
        if absent:
            print(f'Warning: Requested layers not found: {absent}')
    else:
        selected = sorted(shared)
        print('Analyzing all common layers')

    if not selected:
        print('No valid layers to analyze!')
        return {}

    all_results = {}
    for layer in selected:
        print(f'\n{"=" * 50}')
        print(f'Processing Layer: {layer}')
        print(f'{"=" * 50}')

        layer_targets = targets_by_layer[layer]
        layer_concepts = concepts_by_layer[layer]
        print(f'Target tensors: {len(layer_targets)}')
        print(f'Concept tensors: {len(layer_concepts)}')

        pca_model = None
        if n_pca_components is not None:
            # Imported lazily so the PCA helper is only needed when used.
            from pca import apply_pca_to_layer
            layer_targets, layer_concepts, pca_model = apply_pca_to_layer(
                layer_targets, layer_concepts, n_pca_components
            )

        per_target = analyze_target_vs_concepts_with_knn(
            layer_targets, layer_concepts, layer,
            knn_topk=knn_topk, knn_power=knn_power
        )
        all_results[layer] = {
            'results': per_target,
            'pca_model': pca_model,
            'n_pca_components': n_pca_components,
            'knn_topk': knn_topk,
            'knn_power': knn_power
        }

        if not per_target:
            continue

        # --- Layer summary ---
        print(f"\nLayer '{layer}' Summary:")
        print(f' Analyzed {len(per_target)} target images')

        # Concept names are read from the first result only.
        first_analysis = per_target[0]['concept_analysis']
        if first_analysis:
            concept_names = list(first_analysis.keys())
            print(f' Against {len(concept_names)} concepts: {concept_names}')

        knn_infos = (entry.get('instance_knn', {}) for entry in per_target)
        top1_labels = [info['top1_concept'] for info in knn_infos if info.get('top1_concept')]
        if top1_labels:
            print(f' k-NN Predictions: {dict(Counter(top1_labels))}')

    return all_results
def save_knn_analysis_results(
    results: dict[str, dict[str, Any]],
    output_file: str = 'output/knn_similarity_analysis.txt'
) -> None:
    """Save k-NN analysis results to a text file.

    The parent directory of ``output_file`` is created when the path has one;
    a bare filename is written to the current working directory. The file is
    written as UTF-8. For each layer a header is emitted, then for each
    target its 1-NN concept, the k-NN vote (only when k > 1 and a vote
    exists), the three nearest neighbors, and per-concept centroid/mean
    similarities. A missing 1-NN similarity (target had no neighbors) is
    rendered as ``n/a``.

    Args:
        results: Dictionary of analysis results by layer
        output_file: Output filename
    """
    import os

    out_dir = os.path.dirname(output_file)
    if out_dir:
        # os.makedirs('') raises, so only create a directory when one is given.
        os.makedirs(out_dir, exist_ok=True)

    with open(output_file, 'w', encoding='utf-8') as f:
        f.write('VLM Concept Analysis with Instance-based k-NN Prediction\n')
        f.write('=' * 60 + '\n\n')

        for layer, layer_data in results.items():
            layer_results = layer_data['results']
            n_pca_components = layer_data['n_pca_components']
            knn_topk = layer_data.get('knn_topk', 5)
            knn_power = layer_data.get('knn_power', 1.0)

            f.write(f'Layer: {layer}\n')
            if n_pca_components:
                f.write(f'PCA Components: {n_pca_components}\n')
            f.write(f'k-NN Parameters: topk={knn_topk}, power={knn_power}\n')
            f.write('-' * 40 + '\n\n')

            for result in layer_results:
                target_display = result['target_image_filename'] or f'Target_{result["target_row_id"]}'
                f.write(f'Target: {target_display}\n')

                # k-NN predictions
                ik = result.get('instance_knn', {})
                if ik:
                    top1_sim = ik.get('top1_similarity', 0)
                    # The key may be present with value None; ':.4f' on None
                    # would raise TypeError.
                    sim_text = 'n/a' if top1_sim is None else f'{top1_sim:.4f}'
                    f.write(f' 1-NN Concept: {ik.get("top1_concept")} (sim={sim_text})\n')

                    if ik.get('topk_voted_concept') is not None and ik.get('topk', 1) > 1:
                        f.write(f' k-NN Vote (k={ik["topk"]}, p={ik["vote_power"]}): {ik["topk_voted_concept"]}\n')

                    neighbors = ik.get('topk_neighbors', [])
                    if neighbors:
                        f.write(' Top Neighbors:\n')
                        for i, nb in enumerate(neighbors[:3], 1):  # only the top 3
                            f.write(f' {i}. {nb["concept"]} (sim={nb["sim"]:.4f})\n')

                # Per-concept statistics
                for concept_name, stats in result['concept_analysis'].items():
                    f.write(f' Concept {concept_name}:\n')
                    if 'centroid_similarity' in stats:
                        f.write(f' Centroid Similarity: {stats["centroid_similarity"]:.4f}\n')
                    if 'mean_similarity' in stats:
                        f.write(f' Mean Similarity: {stats["mean_similarity"]:.4f}\n')
                f.write('\n')
            f.write('\n')

    print(f'k-NN results saved to {output_file}')
def analyze_knn_accuracy(
    results: dict[str, dict[str, Any]],
    ground_truth_concept_extractor: Optional[callable] = None
) -> None:
    """Print 1-NN and k-NN vote accuracy for every layer.

    A target is scored only when the extractor yields a truthy concept for
    its filename and the target has a non-empty 'instance_knn' entry. Nothing
    is returned; the report goes to stdout.

    Args:
        results: Dictionary of analysis results by layer
        ground_truth_concept_extractor: Function mapping a target filename to
            its true concept; defaults to ``extract_concept_from_filename``
    """
    extract_truth = ground_truth_concept_extractor
    if extract_truth is None:
        extract_truth = extract_concept_from_filename

    print(f'\n{"=" * 50}')
    print('k-NN PREDICTION ACCURACY ANALYSIS')
    print(f'{"=" * 50}')

    for layer, layer_data in results.items():
        per_target = layer_data['results']
        k = layer_data.get('knn_topk', 5)

        print(f'\nLayer: {layer}')
        print('-' * 30)

        if not per_target:
            print('No results for this layer')
            continue

        # Hit counters keyed by the prediction field they score.
        hits = {'top1_concept': 0, 'topk_voted_concept': 0}
        scored = 0
        for entry in per_target:
            truth = extract_truth(entry['target_image_filename'])
            if not truth:
                continue
            knn_info = entry.get('instance_knn', {})
            if not knn_info:
                continue
            scored += 1
            for field in hits:
                if knn_info.get(field) == truth:
                    hits[field] += 1

        if scored == 0:
            print(' No valid predictions to evaluate')
            continue

        hits_1nn = hits['top1_concept']
        hits_vote = hits['topk_voted_concept']
        print(f' 1-NN Accuracy: {hits_1nn}/{scored} = {hits_1nn / scored:.3f}')
        print(f' k-NN Accuracy (k={k}): {hits_vote}/{scored} = {hits_vote / scored:.3f}')
if __name__ == '__main__':
    import traceback

    # Run settings: database paths and analysis parameters.
    settings = {
        'target_db_path': 'output/llava.db',
        'concept_db_path': 'output/llava-concepts-colors.db',
        'layer_names': None,       # None analyzes every common layer
        'n_pca_components': 5,     # None keeps the raw embeddings
        'knn_topk': 5,
        'knn_power': 1.0,
        'device': 'cpu',
    }
    output_file = 'output/knn_similarity_analysis.txt'

    print('=' * 60)
    print('VLM CONCEPT ANALYSIS WITH INSTANCE-BASED k-NN')
    print('=' * 60)

    try:
        results = concept_similarity_analysis_with_knn(**settings)

        if not results:
            print('No results generated. Check database compatibility and parameters.')
        else:
            # Detailed per-target report.
            save_knn_analysis_results(results, output_file)

            # Accuracy against ground truth derived from target filenames.
            analyze_knn_accuracy(results)

            # Aggregate trends via the helper imported from the pca module.
            analyze_concept_trends(results)

            print(f'\n{"=" * 50}')
            print('k-NN ANALYSIS COMPLETE')
            print(f'{"=" * 50}')
            print(f'Processed {len(results)} layers')
            print(f'Results saved to: {output_file}')
    except Exception as e:
        # Top-level boundary of the script: report the failure with a traceback.
        print(f'Error during analysis: {e}')
        traceback.print_exc()