# geolip-vit-base-x3 / analysis / core_geometric_analyzer.py
# (Hugging Face viewer header preserved as a comment: uploaded by AbstractPhil;
#  renamed from core_geometric_analyzer.py to analysis/core_geometric_analyzer.py;
#  revision 4257bf4, verified)
#!/usr/bin/env python3
"""
BASE TIER DEEP MODEL ANALYSIS
===============================
Tuned specifically for the top 3 reasonable targets with high enough similarity to geometrically align.
This helps align the procrustes experiment to save time.
Three models, all 768-d output, all patch-based ViTs:
1. clip_l14_openai β€” CLIP ViT-L/14 (text-supervised, semantic)
2. dinov2_b14 β€” DINOv2 ViT-B/14 (self-supervised, structural)
3. siglip_b16_384 β€” SigLIP ViT-B/16 (sigmoid contrastive, semantic)
Analyze:
- Full architecture comparison (layers, heads, dims, patch size)
- Weight statistics per layer (norms, spectral radius, sparsity)
- Attention head geometry (Q/K/V weight structure)
- Layer-by-layer representation similarity (CKA, Procrustes)
- Patch embedding weight comparison (the actual patchwork)
- MLP weight spectrum analysis
- Where do they converge internally vs diverge?
"""
# Core numeric / serialization dependencies for the analysis scans below.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
import gc

# Run on the GPU when one is visible, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

_BANNER = "=" * 65
print(_BANNER)
print("BASE TIER DEEP MODEL ANALYSIS")
print(_BANNER)
print(f" Device: {DEVICE}")
# ══════════════════════════════════════════════════════════════════
# LOAD MODELS
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("LOADING MODELS")
print(f"{'='*65}")
from transformers import (
    CLIPVisionModel, CLIPVisionConfig,
    Dinov2Model, Dinov2Config,
    SiglipVisionModel, SiglipVisionConfig,
)

models = {}
configs = {}


def _register(key, model):
    """Store an eval-mode model plus its config and report its parameter count."""
    models[key] = model
    configs[key] = model.config
    print(f" Loaded: {sum(p.numel() for p in model.parameters()):,} params")
    return model


# CLIP ViT-L/14 (text-supervised, semantic)
print(f"\n Loading CLIP ViT-L/14...")
clip = _register("clip_l14", CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14").eval())
# DINOv2 ViT-B/14 (self-supervised, structural)
print(f" Loading DINOv2 ViT-B/14...")
dino = _register("dinov2_b14", Dinov2Model.from_pretrained("facebook/dinov2-base").eval())
# SigLIP ViT-B/16 at 384px (sigmoid contrastive, semantic)
print(f" Loading SigLIP ViT-B/16-384...")
siglip = _register("siglip_b16", SiglipVisionModel.from_pretrained("google/siglip-base-patch16-384").eval())
# ══════════════════════════════════════════════════════════════════
# SCAN 1: ARCHITECTURE COMPARISON
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("SCAN 1: ARCHITECTURE COMPARISON")
print(f"{'='*65}")


def get_arch_info(name, model, config):
    """Collect the headline architecture fields for one vision backbone.

    Only attributes actually present on ``config`` are recorded, so the same
    helper works across the CLIP / DINOv2 / SigLIP config classes.
    Returns a dict with the model name, any available size fields, the total
    parameter count, and the derived per-head dimension.
    """
    info = {"name": name}
    # (config attribute, key stored in the summary dict) — insertion order
    # matches the original field order so downstream printing is unchanged.
    attr_map = [
        ("hidden_size", "hidden_size"),
        ("intermediate_size", "intermediate_size"),
        ("num_hidden_layers", "num_layers"),
        ("num_attention_heads", "num_heads"),
        ("patch_size", "patch_size"),
        ("image_size", "image_size"),
    ]
    for attr, key in attr_map:
        if hasattr(config, attr):
            info[key] = getattr(config, attr)
    info["total_params"] = sum(p.numel() for p in model.parameters())
    # Per-head width; guarded so a config without head counts yields 0.
    info["head_dim"] = info.get("hidden_size", 0) // max(info.get("num_heads", 1), 1)
    return info
# Print the architecture summary for each backbone in a fixed order;
# integer fields get thousands separators, everything else prints as-is.
for model_key in ("clip_l14", "dinov2_b14", "siglip_b16"):
    summary = get_arch_info(model_key, models[model_key], configs[model_key])
    print(f"\n {model_key}:")
    for field, value in summary.items():
        if field == "name":
            continue
        if isinstance(value, int):
            print(f" {field:<20}: {value:>12,}")
        else:
            print(f" {field:<20}: {value}")
# ══════════════════════════════════════════════════════════════════
# SCAN 2: NAMED PARAMETER INVENTORY
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("SCAN 2: PARAMETER INVENTORY")
print(f"{'='*65}")


def _categorize_param(pname):
    """Map a parameter name to a coarse category for the inventory table.

    Fixes vs. the previous version:
    - removed a dead loop that located the layer-number token but never used it;
    - "key"/"value" are now recognized as attn_qkv (DINOv2 uses
      attention.attention.query/key/value naming);
    - "attn" is matched in addition to "attention" so CLIP/SigLIP "self_attn"
      parameters are classified as attention instead of falling through to
      mlp / encoder_other.
    Order of checks matters: attention is tested before the generic
    "output" keyword, since attention output projections also contain "out".
    """
    if "embeddings" in pname:
        return "embeddings"
    if "encoder" in pname and "layer" in pname:
        if "attention" in pname or "attn" in pname:
            if ("query" in pname or "key" in pname or "value" in pname
                    or "q_proj" in pname or "k_proj" in pname or "v_proj" in pname):
                return "attn_qkv"
            if "out" in pname or "o_proj" in pname:
                return "attn_out"
            return "attn_other"
        if "mlp" in pname or "intermediate" in pname or "output" in pname:
            return "mlp"
        if "norm" in pname or "layer_norm" in pname:
            return "layernorm"
        return "encoder_other"
    if "layernorm" in pname.lower() or "layer_norm" in pname.lower():
        return "final_norm"
    if "head" in pname or "pooler" in pname:
        return "head"
    return "other"


for name in ["clip_l14", "dinov2_b14", "siglip_b16"]:
    model = models[name]
    print(f"\n {name}:")
    # Aggregate per category: tensor count, parameter count, up to 3 example shapes.
    groups = {}
    for pname, p in model.named_parameters():
        cat = _categorize_param(pname)
        g = groups.setdefault(cat, {"count": 0, "params": 0, "shapes": []})
        g["count"] += 1
        g["params"] += p.numel()
        if len(g["shapes"]) < 3:
            g["shapes"].append(f"{pname.split('.')[-2]}.{pname.split('.')[-1]}: {list(p.shape)}")
    for cat in sorted(groups.keys()):
        g = groups[cat]
        print(f" {cat:<15}: {g['params']:>12,} ({g['count']:2d} tensors)")
        for s in g["shapes"]:
            print(f" {s}")
# ══════════════════════════════════════════════════════════════════
# SCAN 3: WEIGHT STATISTICS PER LAYER
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("SCAN 3: WEIGHT STATISTICS")
print(f"{'='*65}")


def weight_stats(param):
    """Summary statistics for one weight tensor.

    Returns a dict with shape, Frobenius norm, mean/std, max magnitude and
    near-zero sparsity.  Genuine 2-D matrices additionally get singular-value
    statistics: max/min singular value, their ratio, and the
    participation-ratio effective rank.
    """
    w = param.float().detach()
    stats = {
        "shape": list(w.shape),
        "norm": w.norm().item(),
        "mean": w.mean().item(),
        "std": w.std().item(),
        "abs_max": w.abs().max().item(),
        # Fraction of entries that are numerically zero.
        "sparsity": (w.abs() < 1e-6).float().mean().item(),
    }
    # The singular spectrum is only meaningful for non-degenerate matrices.
    if w.dim() == 2 and min(w.shape) > 1:
        singular = torch.linalg.svdvals(w)
        stats["sv_max"] = singular[0].item()
        stats["sv_min"] = singular[-1].item()
        stats["sv_ratio"] = (singular[0] / (singular[-1] + 1e-10)).item()
        # Participation ratio: (sum sv)^2 / sum sv^2.
        stats["eff_rank"] = ((singular.sum() ** 2) / (singular.pow(2).sum() + 1e-12)).item()
    return stats
# Keywords that select the "interesting" matrices for the table below:
# embeddings, a few sampled layers, all attention projections, MLP weights.
_SHOW_KEYWORDS = (
    "patch", "embed", "position", "cls",
    "layer.0.", "layer.5.", "layer.11.",
    "layer.23.", "q_proj", "k_proj", "v_proj",
    "query", "key", "value",
    "fc1", "fc2", "dense", "out_proj",
    "layernorm", "head",
)
for name in ["clip_l14", "dinov2_b14", "siglip_b16"]:
    model = models[name]
    print(f"\n {name} β€” key weight matrices:")
    print(f" {'param':<50} {'shape':<20} {'norm':>8} {'std':>8} {'sv_max':>8} {'eff_rank':>9}")
    print(f" {'-'*105}")
    for pname, p in model.named_parameters():
        # Skip vectors and tiny tensors; only sizeable matrices are informative.
        if p.dim() < 2 or p.numel() < 1000:
            continue
        lowered = pname.lower()
        if not any(kw in lowered for kw in _SHOW_KEYWORDS):
            continue
        s = weight_stats(p)
        # 1-D-only stats print N/A for the spectral columns.
        sv_max = f"{s.get('sv_max', 0):.4f}" if 'sv_max' in s else " N/A"
        eff_rank = f"{s.get('eff_rank', 0):.1f}" if 'eff_rank' in s else " N/A"
        short_name = pname[-50:]  # slicing already leaves short names intact
        shape_str = str(s["shape"])
        print(f" {short_name:<50} {shape_str:<20} {s['norm']:>8.4f} "
              f"{s['std']:>8.5f} {sv_max:>8} {eff_rank:>9}")
# ══════════════════════════════════════════════════════════════════
# SCAN 4: PATCH EMBEDDING ANALYSIS (the actual patchwork)
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("SCAN 4: PATCH EMBEDDING WEIGHTS")
print(f"{'='*65}")
patch_embeddings = {}
for name in ["clip_l14", "dinov2_b14", "siglip_b16"]:
    model = models[name]
    for pname, p in model.named_parameters():
        # The patch embedding is identified as the 4-D conv weight whose name
        # mentions both "patch" and "embed".
        if "patch" in pname.lower() and "embed" in pname.lower() and p.dim() == 4:
            patch_embeddings[name] = p.detach().float()
            print(f"\n {name}: {pname}")
            print(f" Shape: {list(p.shape)}")
            # (out_channels, in_channels, kernel_h, kernel_w)
            print(f" = {p.shape[0]} filters Γ— {p.shape[1]} channels Γ— {p.shape[2]}Γ—{p.shape[3]} kernel")
            # Reshape to 2D for spectral analysis
            w2d = p.detach().float().reshape(p.shape[0], -1) # (out, in*h*w)
            sv = torch.linalg.svdvals(w2d)
            # Participation-ratio effective rank of the flattened filter bank.
            eff_rank = ((sv.sum()**2) / (sv.pow(2).sum() + 1e-12)).item()
            print(f" Spectral: sv_max={sv[0]:.4f} sv_min={sv[-1]:.6f} "
                  f"eff_rank={eff_rank:.1f}/{min(w2d.shape)}")
            print(f" Norm: {p.norm():.4f} Mean: {p.mean():.6f} Std: {p.std():.6f}")
            # Per-filter analysis
            filter_norms = p.detach().float().reshape(p.shape[0], -1).norm(dim=1)
            print(f" Filter norms: mean={filter_norms.mean():.4f} "
                  f"std={filter_norms.std():.4f} "
                  f"min={filter_norms.min():.4f} max={filter_norms.max():.4f}")
            # Only the first matching tensor per model is analyzed.
            break
# Compare patch embeddings pairwise (Procrustes on flattened filters)
if len(patch_embeddings) >= 2:
    print(f"\n Patch embedding Procrustes alignment:")
    names_list = list(patch_embeddings.keys())
    for i in range(len(names_list)):
        for j in range(i+1, len(names_list)):
            n1, n2 = names_list[i], names_list[j]
            p1 = patch_embeddings[n1].reshape(patch_embeddings[n1].shape[0], -1)
            p2 = patch_embeddings[n2].reshape(patch_embeddings[n2].shape[0], -1)
            # Truncate to common dim
            # NOTE(review): rows/features are truncated positionally; filters
            # are not matched semantically, so this is only a crude baseline
            # similarity, not an aligned (Procrustes) score.
            d_min = min(p1.shape[0], p2.shape[0])
            d_feat = min(p1.shape[1], p2.shape[1])
            a = p1[:d_min, :d_feat]; b = p2[:d_min, :d_feat]
            # Raw cosine (mean over filters)
            cos = F.cosine_similarity(
                F.normalize(a, dim=1), F.normalize(b, dim=1), dim=1).mean().item()
            print(f" {n1} Γ— {n2}: raw_cos={cos:.4f} (d_min={d_min}, d_feat={d_feat})")
# ══════════════════════════════════════════════════════════════════
# SCAN 5: ATTENTION HEAD GEOMETRY
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("SCAN 5: ATTENTION HEAD GEOMETRY")
print(f"{'='*65}")


def extract_qkv_weights(model, name):
    """Extract Q, K, V weight matrices from each layer.

    Walks named_parameters in order and tags every 2-D projection weight
    whose name marks it as a query/key/value matrix.  ``name`` is accepted
    for call-site symmetry with the other scans but is unused.
    Returns a list of {"layer", "type", "weight"} dicts in traversal order.
    """
    # (tag, name substrings that identify the projection type)
    type_markers = (
        ("Q", ("query", "q_proj")),
        ("K", ("key", "k_proj")),
        ("V", ("value", "v_proj")),
    )
    collected = []
    for pname, weight in model.named_parameters():
        if weight.dim() != 2:
            continue
        lowered = pname.lower()
        if "weight" not in lowered:
            continue
        for tag, markers in type_markers:
            if any(marker in lowered for marker in markers):
                collected.append({"layer": pname, "type": tag,
                                  "weight": weight.detach().float()})
                break
    return collected
# Per-model Q/K/V geometry table: norms plus pairwise cosines of the
# flattened projection matrices at selected layers.
for name in ["clip_l14", "dinov2_b14", "siglip_b16"]:
    qkv = extract_qkv_weights(models[name], name)
    # Assumes named_parameters yields Q, K, V consecutively for each layer,
    # so every 3 entries form one layer's triple β€” TODO(review): confirm this
    # ordering holds for all three architectures.
    n_layers = len(qkv) // 3
    print(f"\n {name} ({n_layers} layers):")
    print(f" {'layer':>6} {'Q_norm':>8} {'K_norm':>8} {'V_norm':>8} "
          f"{'QK_cos':>8} {'QV_cos':>8} {'KV_cos':>8}")
    for layer_idx in range(n_layers):
        q = qkv[layer_idx * 3]["weight"]
        k = qkv[layer_idx * 3 + 1]["weight"]
        v = qkv[layer_idx * 3 + 2]["weight"]
        q_norm = q.norm().item()
        k_norm = k.norm().item()
        v_norm = v.norm().item()
        # Flatten and compute cosine between Q/K, Q/V, K/V
        qf = q.reshape(-1); kf = k.reshape(-1); vf = v.reshape(-1)
        # Truncate to the shortest flattened length before comparing.
        d = min(qf.shape[0], kf.shape[0], vf.shape[0])
        qk_cos = F.cosine_similarity(qf[:d].unsqueeze(0), kf[:d].unsqueeze(0)).item()
        qv_cos = F.cosine_similarity(qf[:d].unsqueeze(0), vf[:d].unsqueeze(0)).item()
        kv_cos = F.cosine_similarity(kf[:d].unsqueeze(0), vf[:d].unsqueeze(0)).item()
        # Only print the first 3, middle, and last 2 layers; elide the rest.
        if layer_idx < 3 or layer_idx >= n_layers - 2 or layer_idx == n_layers // 2:
            print(f" {layer_idx:>6} {q_norm:>8.3f} {k_norm:>8.3f} {v_norm:>8.3f} "
                  f"{qk_cos:>8.4f} {qv_cos:>8.4f} {kv_cos:>8.4f}")
        elif layer_idx == 3:
            print(f" {'...':>6}")
# ══════════════════════════════════════════════════════════════════
# SCAN 6: CROSS-MODEL QK ALIGNMENT
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("SCAN 6: CROSS-MODEL WEIGHT ALIGNMENT")
print(f"{'='*65}")
# Compare equivalent layers across models
# Use common dimension (768) β€” all three output 768-d
# Compare Q weights, K weights, V weights at equivalent depth fractions
model_qkv = {}
for name in ["clip_l14", "dinov2_b14", "siglip_b16"]:
    model_qkv[name] = extract_qkv_weights(models[name], name)
print(f"\n Cross-model Q weight cosine at equivalent depth fractions:")
print(f" {'depth':>6} {'clipΓ—dino':>10} {'clipΓ—siglip':>12} {'dinoΓ—siglip':>12}")
for name in model_qkv:
    n = len(model_qkv[name]) // 3
    print(f" {name}: {n} layers")
# Compare at 0%, 25%, 50%, 75%, 100% depth
for frac in [0.0, 0.25, 0.5, 0.75, 1.0]:
    vals = {}
    for name in ["clip_l14", "dinov2_b14", "siglip_b16"]:
        qkv = model_qkv[name]
        n = len(qkv) // 3
        # Nearest layer index at this depth fraction (clamped to last layer),
        # so models with different depths are compared at matching fractions.
        idx = min(int(frac * (n - 1)), n - 1)
        # Q weight only, flattened; relies on the Q/K/V-per-layer ordering
        # produced by extract_qkv_weights.
        q = qkv[idx * 3]["weight"].reshape(-1)
        vals[name] = q
    # Truncate to common length
    # NOTE(review): positional truncation of flattened weights β€” a raw
    # similarity proxy, not a rotation-invariant (Procrustes/CKA) measure.
    min_len = min(v.shape[0] for v in vals.values())
    cos_cd = F.cosine_similarity(
        vals["clip_l14"][:min_len].unsqueeze(0),
        vals["dinov2_b14"][:min_len].unsqueeze(0)).item()
    cos_cs = F.cosine_similarity(
        vals["clip_l14"][:min_len].unsqueeze(0),
        vals["siglip_b16"][:min_len].unsqueeze(0)).item()
    cos_ds = F.cosine_similarity(
        vals["dinov2_b14"][:min_len].unsqueeze(0),
        vals["siglip_b16"][:min_len].unsqueeze(0)).item()
    print(f" {frac:>5.0%} {cos_cd:>10.4f} {cos_cs:>12.4f} {cos_ds:>12.4f}")
# ══════════════════════════════════════════════════════════════════
# SCAN 7: MLP WEIGHT SPECTRUM
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("SCAN 7: MLP WEIGHT SPECTRUM")
print(f"{'='*65}")


def _is_mlp_weight(pname, p):
    """True for 2-D MLP projection weights: CLIP/SigLIP fc1/fc2 and DINOv2
    intermediate/output dense layers, excluding attention output projections."""
    if p.dim() != 2:
        return False
    if "fc1" in pname or "fc2" in pname:
        return True
    if "dense" in pname and "weight" in pname:
        if "intermediate" in pname:
            return True
        if "output" in pname and "attention" not in pname:
            return True
    return False


for name in ["clip_l14", "dinov2_b14", "siglip_b16"]:
    model = models[name]
    mlp_weights = [(pname, p.detach().float())
                   for pname, p in model.named_parameters()
                   if _is_mlp_weight(pname, p)]
    print(f"\n {name} MLPs ({len(mlp_weights)} weight matrices):")
    for pname, w in mlp_weights[:6]: # first 3 layers
        sv = torch.linalg.svdvals(w)
        # Participation-ratio effective rank of the MLP matrix.
        eff_rank = ((sv.sum()**2) / (sv.pow(2).sum() + 1e-12)).item()
        short = ".".join(pname.split(".")[-3:])
        print(f" {short:<40} {str(list(w.shape)):<20} "
              f"eff_rank={eff_rank:>6.1f}/{min(w.shape)} "
              f"sv_max={sv[0]:.3f} sv_10={sv[min(9,len(sv)-1)]:.4f}")
    if len(mlp_weights) > 6:
        print(f" ... ({len(mlp_weights) - 6} more)")
# ══════════════════════════════════════════════════════════════════
# SCAN 8: POSITION EMBEDDING ANALYSIS
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("SCAN 8: POSITION EMBEDDINGS")
print(f"{'='*65}")
for name in ["clip_l14", "dinov2_b14", "siglip_b16"]:
    model = models[name]
    for pname, p in model.named_parameters():
        # Any parameter whose name mentions both "position" and "embed".
        if "position" in pname.lower() and "embed" in pname.lower():
            pe = p.detach().float()
            print(f"\n {name}: {pname}")
            print(f" Shape: {list(pe.shape)}")
            print(f" Norm: {pe.norm():.4f} Mean: {pe.mean():.6f} Std: {pe.std():.6f}")
            if pe.dim() >= 2:
                # Self-similarity of position embeddings
                if pe.dim() == 3:
                    # Drop the leading batch-of-1 axis so rows index positions.
                    pe2d = pe.squeeze(0)
                else:
                    pe2d = pe
                # Full (num_pos x num_pos) cosine-similarity matrix via broadcasting.
                sim = F.cosine_similarity(pe2d.unsqueeze(0), pe2d.unsqueeze(1), dim=-1)
                # Diagonal is cos(x, x) = 1; off-diagonal mean excludes it.
                print(f" Self-sim: diag_mean={sim.diag().mean():.4f} "
                      f"off_diag_mean={(sim.sum()-sim.diag().sum()).item()/(sim.numel()-sim.shape[0]):.4f}")
                # Cosine of each position with its successor in storage order.
                print(f" Adjacent pos cos: mean={F.cosine_similarity(pe2d[:-1], pe2d[1:], dim=-1).mean():.4f}")
                # SVD of position embeddings
                sv = torch.linalg.svdvals(pe2d)
                eff_rank = ((sv.sum()**2) / (sv.pow(2).sum() + 1e-12)).item()
                # sv1% = share of total spectral energy in the top singular value.
                print(f" Spectral: eff_rank={eff_rank:.1f}/{min(pe2d.shape)} "
                      f"sv1%={sv[0].pow(2).item()/sv.pow(2).sum().item()*100:.1f}%")
            # Only the first matching parameter per model is analyzed.
            break
# ══════════════════════════════════════════════════════════════════
# SCAN 9: LAYERNORM ANALYSIS
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("SCAN 9: LAYERNORM WEIGHT/BIAS PATTERNS")
print(f"{'='*65}")
for name in ["clip_l14", "dinov2_b14", "siglip_b16"]:
    model = models[name]
    # Collect affine LayerNorm weights and biases in traversal order.
    ln_weights = []
    ln_biases = []
    for pname, p in model.named_parameters():
        lowered = pname.lower()
        if "norm" not in lowered and "layer_norm" not in lowered:
            continue
        tensor = p.detach().float()
        if "weight" in pname:
            ln_weights.append((pname, tensor))
        elif "bias" in pname:
            ln_biases.append((pname, tensor))
    print(f"\n {name} ({len(ln_weights)} LayerNorms):")
    # First few weight/bias pairs, matched positionally by traversal order.
    for (wn, w), (bn, b) in zip(ln_weights[:4], ln_biases[:4]):
        short = ".".join(wn.split(".")[-3:-1])
        print(f" {short:<30} w: mean={w.mean():.4f} std={w.std():.4f} "
              f"b: mean={b.mean():.5f} std={b.std():.4f}")
    # Final LayerNorm
    if ln_weights:
        wn, w = ln_weights[-1]
        bn, b = ln_biases[-1] if ln_biases else ("", torch.zeros_like(w))
        print(f" FINAL: {wn}")
        print(f" weight: mean={w.mean():.4f} std={w.std():.4f} "
              f"min={w.min():.4f} max={w.max():.4f}")
        if ln_biases:
            print(f" bias: mean={b.mean():.5f} std={b.std():.4f}")
# ══════════════════════════════════════════════════════════════════
# SUMMARY
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("ANALYSIS COMPLETE")
print(f"{'='*65}")
# Clean up: drop the big model references, force a GC pass, and release any
# cached CUDA allocations back to the driver (no-op when CUDA is uninitialized).
del models, configs
gc.collect()
torch.cuda.empty_cache()