#!/usr/bin/env python3
"""
BASE TIER DEEP MODEL ANALYSIS
===============================
Weight-level inspection of three patch-based ViT vision encoders that are
candidate targets for a Procrustes alignment experiment:

  1. clip_l14    — CLIP ViT-L/14    (text-supervised, semantic)
  2. dinov2_b14  — DINOv2 ViT-B/14  (self-supervised, structural)
  3. siglip_b16  — SigLIP ViT-B/16  (sigmoid contrastive, semantic)

Scans performed (weights only; no forward passes are run):
  - Architecture comparison (layers, heads, dims, patch size)
  - Parameter inventory grouped by role
  - Weight statistics per matrix (norms, singular values, sparsity)
  - Patch embedding weights and their pairwise cosine
  - Attention head geometry (Q/K/V weight structure)
  - Cross-model Q-weight cosine at matching depth fractions
  - MLP weight spectrum
  - Position embedding structure
  - LayerNorm weight/bias patterns

NOTE: the backbones do not all share one hidden width (CLIP ViT-L/14 is 1024
wide; its 768-d output comes from a projection that ``CLIPVisionModel`` does
not include). Cross-model comparisons therefore crop tensors to the leading
block they have in common before comparing.
"""

import gc
import itertools
import json

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Keys used for the three models everywhere in the report, in print order.
MODEL_NAMES = ("clip_l14", "dinov2_b14", "siglip_b16")

# Config attribute -> key under which get_arch_info() reports it.
_ARCH_FIELDS = (
    ("hidden_size", "hidden_size"),
    ("intermediate_size", "intermediate_size"),
    ("num_hidden_layers", "num_layers"),
    ("num_attention_heads", "num_heads"),
    ("patch_size", "patch_size"),
    ("image_size", "image_size"),
)

# Substrings that make a parameter "interesting" enough to list in scan 3.
_STAT_KEYWORDS = (
    "patch", "embed", "position", "cls",
    "layer.0.", "layer.5.", "layer.11.", "layer.23.",
    "q_proj", "k_proj", "v_proj", "query", "key", "value",
    "fc1", "fc2", "dense", "out_proj", "layernorm", "head",
)

# Name fragments identifying Q/K/V projections across the HF naming schemes
# (DINOv2: query/key/value, CLIP/SigLIP: q_proj/k_proj/v_proj).
_QKV_TAGS = ("query", "key", "value", "q_proj", "k_proj", "v_proj")


# ══════════════════════════════════════════════════════════════════
# SHARED HELPERS
# ══════════════════════════════════════════════════════════════════

def banner(title):
    """Print a section header framed by two 65-character rules."""
    print(f"\n{'='*65}")
    print(title)
    print(f"{'='*65}")


def effective_rank(sv):
    """Return (sum s)^2 / sum(s^2) for a vector of singular values."""
    return ((sv.sum()**2) / (sv.pow(2).sum() + 1e-12)).item()


def crop_to_common(a, b):
    """Slice two same-rank tensors to the leading block they share on every axis.

    Cropping *before* flattening keeps entry (i, j, ...) of one tensor paired
    with entry (i, j, ...) of the other. Flattening first and truncating the
    vectors would instead pair unrelated rows/columns whenever shapes differ.
    """
    common = tuple(slice(0, min(x, y)) for x, y in zip(a.shape, b.shape))
    return a[common], b[common]


def cropped_cosine(a, b):
    """Cosine similarity between two tensors over their common leading block."""
    a, b = crop_to_common(a, b)
    return F.cosine_similarity(a.reshape(1, -1), b.reshape(1, -1)).item()


# ══════════════════════════════════════════════════════════════════
# LOAD MODELS
# ══════════════════════════════════════════════════════════════════

def load_models():
    """Download/load the three vision encoders in eval mode.

    Returns:
        (models, configs): two dicts keyed by the entries of MODEL_NAMES.
    """
    banner("LOADING MODELS")

    # Imported here so the helpers above stay importable without transformers.
    from transformers import (
        CLIPVisionModel, CLIPVisionConfig,
        Dinov2Model, Dinov2Config,
        SiglipVisionModel, SiglipVisionConfig,
    )

    specs = (
        ("clip_l14", "\n Loading CLIP ViT-L/14...",
         CLIPVisionModel, "openai/clip-vit-large-patch14"),
        ("dinov2_b14", " Loading DINOv2 ViT-B/14...",
         Dinov2Model, "facebook/dinov2-base"),
        ("siglip_b16", " Loading SigLIP ViT-B/16-384...",
         SiglipVisionModel, "google/siglip-base-patch16-384"),
    )

    models = {}
    configs = {}
    for key, message, model_cls, repo in specs:
        print(message)
        model = model_cls.from_pretrained(repo).eval()
        models[key] = model
        configs[key] = model.config
        print(f" Loaded: {sum(p.numel() for p in model.parameters()):,} params")
    return models, configs


# ══════════════════════════════════════════════════════════════════
# SCAN 1: ARCHITECTURE COMPARISON
# ══════════════════════════════════════════════════════════════════

def get_arch_info(name, model, config):
    """Collect headline architecture numbers for one model.

    Only attributes actually present on ``config`` are reported;
    ``total_params`` and ``head_dim`` are always included.
    """
    info = {"name": name}
    for attr, key in _ARCH_FIELDS:
        if hasattr(config, attr):
            info[key] = getattr(config, attr)
    info["total_params"] = sum(p.numel() for p in model.parameters())
    info["head_dim"] = info.get("hidden_size", 0) // max(info.get("num_heads", 1), 1)
    return info


def scan_architecture(models, configs):
    """Scan 1: print the architecture summary of every model."""
    banner("SCAN 1: ARCHITECTURE COMPARISON")
    for name in MODEL_NAMES:
        info = get_arch_info(name, models[name], configs[name])
        print(f"\n {name}:")
        for k, v in info.items():
            if k != "name":
                print(f" {k:<20}: {v:>12,}" if isinstance(v, int) else f" {k:<20}: {v}")


# ══════════════════════════════════════════════════════════════════
# SCAN 2: NAMED PARAMETER INVENTORY
# ══════════════════════════════════════════════════════════════════

def categorize_param(pname):
    """Map a parameter name to a coarse role bucket.

    Attention blocks are recognised by either ``attention`` (DINOv2) or
    ``attn`` (CLIP/SigLIP ``self_attn``), and Q/K/V by either naming scheme,
    so the three models are bucketed consistently.
    """
    if "embeddings" in pname:
        return "embeddings"

    if "encoder" in pname and "layer" in pname:
        if "attention" in pname or "attn" in pname:
            if any(tag in pname for tag in _QKV_TAGS):
                return "attn_qkv"
            if "out" in pname or "o_proj" in pname:
                return "attn_out"
            return "attn_other"
        if "mlp" in pname or "intermediate" in pname or "output" in pname:
            return "mlp"
        if "norm" in pname:
            return "layernorm"
        return "encoder_other"

    lowered = pname.lower()
    if "layernorm" in lowered or "layer_norm" in lowered:
        return "final_norm"
    if "head" in pname or "pooler" in pname:
        return "head"
    return "other"


def scan_parameter_inventory(models):
    """Scan 2: per model, total parameters/tensors per role, with sample shapes."""
    banner("SCAN 2: PARAMETER INVENTORY")
    for name in MODEL_NAMES:
        print(f"\n {name}:")
        groups = {}
        for pname, p in models[name].named_parameters():
            bucket = groups.setdefault(
                categorize_param(pname), {"count": 0, "params": 0, "shapes": []})
            bucket["count"] += 1
            bucket["params"] += p.numel()
            if len(bucket["shapes"]) < 3:
                # Last two name components; safe even for single-component names.
                tail = ".".join(pname.split(".")[-2:])
                bucket["shapes"].append(f"{tail}: {list(p.shape)}")

        for cat in sorted(groups):
            g = groups[cat]
            print(f" {cat:<15}: {g['params']:>12,} ({g['count']:2d} tensors)")
            for s in g["shapes"]:
                print(f" {s}")


# ══════════════════════════════════════════════════════════════════
# SCAN 3: WEIGHT STATISTICS PER LAYER
# ══════════════════════════════════════════════════════════════════

def weight_stats(param):
    """Return summary statistics of one tensor as a dict.

    Always contains shape, norm, mean, std, abs_max and sparsity (fraction of
    entries with magnitude below 1e-6). For 2-D tensors with both sides > 1
    it also contains sv_max, sv_min, sv_ratio and eff_rank.
    """
    p = param.float().detach()
    stats = {
        "shape": list(p.shape),
        "norm": p.norm().item(),
        "mean": p.mean().item(),
        "std": p.std().item(),
        "abs_max": p.abs().max().item(),
        "sparsity": (p.abs() < 1e-6).float().mean().item(),
    }
    if p.dim() == 2 and min(p.shape) > 1:
        sv = torch.linalg.svdvals(p)
        stats["sv_max"] = sv[0].item()
        stats["sv_min"] = sv[-1].item()
        stats["sv_ratio"] = (sv[0] / (sv[-1] + 1e-10)).item()
        stats["eff_rank"] = effective_rank(sv)
    return stats


def scan_weight_statistics(models):
    """Scan 3: table of statistics for the keyword-selected weight tensors."""
    banner("SCAN 3: WEIGHT STATISTICS")
    for name in MODEL_NAMES:
        print(f"\n {name} — key weight matrices:")
        print(f" {'param':<50} {'shape':<20} {'norm':>8} {'std':>8} {'sv_max':>8} {'eff_rank':>9}")
        print(f" {'-'*105}")
        for pname, p in models[name].named_parameters():
            # Skip vectors/scalars and tiny tensors.
            if p.dim() < 2 or p.numel() < 1000:
                continue
            lowered = pname.lower()
            if not any(keyword in lowered for keyword in _STAT_KEYWORDS):
                continue

            s = weight_stats(p)
            sv_max = f"{s.get('sv_max', 0):.4f}" if 'sv_max' in s else " N/A"
            eff_rank = f"{s.get('eff_rank', 0):.1f}" if 'eff_rank' in s else " N/A"
            short_name = pname[-50:]
            shape_str = str(s["shape"])
            print(f" {short_name:<50} {shape_str:<20} {s['norm']:>8.4f} "
                  f"{s['std']:>8.5f} {sv_max:>8} {eff_rank:>9}")


# ══════════════════════════════════════════════════════════════════
# SCAN 4: PATCH EMBEDDING ANALYSIS (the actual patchwork)
# ══════════════════════════════════════════════════════════════════

def scan_patch_embeddings(models):
    """Scan 4: spectrum and filter norms of each patch conv, then pairwise cosine."""
    banner("SCAN 4: PATCH EMBEDDING WEIGHTS")
    patch_embeddings = {}
    for name in MODEL_NAMES:
        for pname, p in models[name].named_parameters():
            lowered = pname.lower()
            if not ("patch" in lowered and "embed" in lowered and p.dim() == 4):
                continue

            w = p.detach().float()
            patch_embeddings[name] = w
            print(f"\n {name}: {pname}")
            print(f" Shape: {list(p.shape)}")
            # (out_channels, in_channels, kernel_h, kernel_w)
            print(f" = {p.shape[0]} filters × {p.shape[1]} channels × {p.shape[2]}×{p.shape[3]} kernel")

            # One row per output filter for the spectral analysis.
            w2d = w.reshape(w.shape[0], -1)
            sv = torch.linalg.svdvals(w2d)
            eff_rank = effective_rank(sv)
            print(f" Spectral: sv_max={sv[0]:.4f} sv_min={sv[-1]:.6f} "
                  f"eff_rank={eff_rank:.1f}/{min(w2d.shape)}")
            print(f" Norm: {w.norm():.4f} Mean: {w.mean():.6f} Std: {w.std():.6f}")

            filter_norms = w2d.norm(dim=1)
            print(f" Filter norms: mean={filter_norms.mean():.4f} "
                  f"std={filter_norms.std():.4f} "
                  f"min={filter_norms.min():.4f} max={filter_norms.max():.4f}")
            break  # only the first matching conv weight per model

    if len(patch_embeddings) >= 2:
        print(f"\n Patch embedding Procrustes alignment:")
        for n1, n2 in itertools.combinations(patch_embeddings, 2):
            # Crop the 4-D kernels (filters, channels, h, w) to their common
            # block so that pixels/channels stay aligned when kernel sizes differ.
            k1, k2 = crop_to_common(patch_embeddings[n1], patch_embeddings[n2])
            a = k1.reshape(k1.shape[0], -1)
            b = k2.reshape(k2.shape[0], -1)
            d_min, d_feat = a.shape
            # Mean per-filter cosine; cosine_similarity normalises internally.
            cos = F.cosine_similarity(a, b, dim=1).mean().item()
            print(f" {n1} × {n2}: raw_cos={cos:.4f} (d_min={d_min}, d_feat={d_feat})")


# ══════════════════════════════════════════════════════════════════
# SCAN 5: ATTENTION HEAD GEOMETRY
# ══════════════════════════════════════════════════════════════════

def extract_qkv_weights(model, name):
    """Extract Q, K, V weight matrices from each layer.

    Returns a flat list of ``{"layer": param_name, "type": "Q"|"K"|"V",
    "weight": tensor}`` in ``named_parameters()`` order. That order is the
    module registration order and is NOT guaranteed to be Q, K, V; use
    ``group_qkv_by_layer`` to obtain per-layer triples.
    ``name`` is accepted for call-site symmetry and is not used.
    """
    layers_qkv = []
    for pname, p in model.named_parameters():
        if p.dim() != 2:
            continue
        plow = pname.lower()
        if "weight" not in plow:
            continue
        if "query" in plow or "q_proj" in plow:
            kind = "Q"
        elif "key" in plow or "k_proj" in plow:
            kind = "K"
        elif "value" in plow or "v_proj" in plow:
            kind = "V"
        else:
            continue
        layers_qkv.append({"layer": pname, "type": kind, "weight": p.detach().float()})
    return layers_qkv


def group_qkv_by_layer(qkv):
    """Group ``extract_qkv_weights`` output into per-layer ``{"Q","K","V"}`` dicts.

    Entries are keyed by the parameter name with its last two components
    (``<proj>.weight``) stripped, so the three projections of one attention
    module end up together regardless of registration order. Layers are
    returned in first-seen order; incomplete groups are dropped.
    """
    by_layer = {}
    for entry in qkv:
        parts = entry["layer"].rsplit(".", 2)
        prefix = parts[0] if len(parts) == 3 else ""
        by_layer.setdefault(prefix, {})[entry["type"]] = entry["weight"]
    return [g for g in by_layer.values() if all(t in g for t in ("Q", "K", "V"))]


def scan_attention_geometry(models):
    """Scan 5: per-layer Q/K/V norms and pairwise cosines of the flattened weights."""
    banner("SCAN 5: ATTENTION HEAD GEOMETRY")
    for name in MODEL_NAMES:
        layers = group_qkv_by_layer(extract_qkv_weights(models[name], name))
        n_layers = len(layers)
        print(f"\n {name} ({n_layers} layers):")
        print(f" {'layer':>6} {'Q_norm':>8} {'K_norm':>8} {'V_norm':>8} "
              f"{'QK_cos':>8} {'QV_cos':>8} {'KV_cos':>8}")

        for layer_idx, layer in enumerate(layers):
            q, k, v = layer["Q"], layer["K"], layer["V"]
            # Show only the first three, the middle and the last two layers.
            shown = (layer_idx < 3 or layer_idx >= n_layers - 2
                     or layer_idx == n_layers // 2)
            if shown:
                q_norm = q.norm().item()
                k_norm = k.norm().item()
                v_norm = v.norm().item()
                qk_cos = cropped_cosine(q, k)
                qv_cos = cropped_cosine(q, v)
                kv_cos = cropped_cosine(k, v)
                print(f" {layer_idx:>6} {q_norm:>8.3f} {k_norm:>8.3f} {v_norm:>8.3f} "
                      f"{qk_cos:>8.4f} {qv_cos:>8.4f} {kv_cos:>8.4f}")
            elif layer_idx == 3:
                print(f" {'...':>6}")


# ══════════════════════════════════════════════════════════════════
# SCAN 6: CROSS-MODEL QK ALIGNMENT
# ══════════════════════════════════════════════════════════════════

def scan_cross_model_alignment(models):
    """Scan 6: cosine between Q weights of different models at equal depth fractions.

    Widths may differ between models, so each pair is cropped to its common
    leading (out, in) block before the cosine is taken.
    """
    banner("SCAN 6: CROSS-MODEL WEIGHT ALIGNMENT")
    model_layers = {
        name: group_qkv_by_layer(extract_qkv_weights(models[name], name))
        for name in MODEL_NAMES
    }

    print(f"\n Cross-model Q weight cosine at equivalent depth fractions:")
    print(f" {'depth':>6} {'clip×dino':>10} {'clip×siglip':>12} {'dino×siglip':>12}")
    for name, layers in model_layers.items():
        print(f" {name}: {len(layers)} layers")

    if not all(model_layers.values()):
        # Nothing to index into for at least one model.
        print(" (skipped: a model exposes no Q/K/V weight matrices)")
        return

    # Compare at 0%, 25%, 50%, 75%, 100% depth
    for frac in [0.0, 0.25, 0.5, 0.75, 1.0]:
        q_at = {}
        for name, layers in model_layers.items():
            n = len(layers)
            idx = min(int(frac * (n - 1)), n - 1)
            q_at[name] = layers[idx]["Q"]

        cos_cd = cropped_cosine(q_at["clip_l14"], q_at["dinov2_b14"])
        cos_cs = cropped_cosine(q_at["clip_l14"], q_at["siglip_b16"])
        cos_ds = cropped_cosine(q_at["dinov2_b14"], q_at["siglip_b16"])
        print(f" {frac:>5.0%} {cos_cd:>10.4f} {cos_cs:>12.4f} {cos_ds:>12.4f}")


# ══════════════════════════════════════════════════════════════════
# SCAN 7: MLP WEIGHT SPECTRUM
# ══════════════════════════════════════════════════════════════════

def _is_mlp_weight(pname, p):
    """True for 2-D MLP weights under either the fc1/fc2 or dense naming scheme."""
    if p.dim() != 2:
        return False
    if "fc1" in pname or "fc2" in pname:
        return True
    if "dense" in pname and "weight" in pname:
        if "intermediate" in pname:
            return True
        if "output" in pname and "attention" not in pname:
            return True
    return False


def scan_mlp_spectrum(models):
    """Scan 7: effective rank and leading singular values of the first MLP weights."""
    banner("SCAN 7: MLP WEIGHT SPECTRUM")
    for name in MODEL_NAMES:
        mlp_weights = [
            (pname, p.detach().float())
            for pname, p in models[name].named_parameters()
            if _is_mlp_weight(pname, p)
        ]

        print(f"\n {name} MLPs ({len(mlp_weights)} weight matrices):")
        for pname, w in mlp_weights[:6]:  # two matrices per layer -> first 3 layers
            sv = torch.linalg.svdvals(w)
            eff_rank = effective_rank(sv)
            short = ".".join(pname.split(".")[-3:])
            print(f" {short:<40} {str(list(w.shape)):<20} "
                  f"eff_rank={eff_rank:>6.1f}/{min(w.shape)} "
                  f"sv_max={sv[0]:.3f} sv_10={sv[min(9,len(sv)-1)]:.4f}")
        if len(mlp_weights) > 6:
            print(f" ... ({len(mlp_weights) - 6} more)")


# ══════════════════════════════════════════════════════════════════
# SCAN 8: POSITION EMBEDDING ANALYSIS
# ══════════════════════════════════════════════════════════════════

def scan_position_embeddings(models):
    """Scan 8: norm, self-similarity and spectrum of each position embedding table."""
    banner("SCAN 8: POSITION EMBEDDINGS")
    for name in MODEL_NAMES:
        for pname, p in models[name].named_parameters():
            lowered = pname.lower()
            if not ("position" in lowered and "embed" in lowered):
                continue

            pe = p.detach().float()
            print(f"\n {name}: {pname}")
            print(f" Shape: {list(pe.shape)}")
            print(f" Norm: {pe.norm():.4f} Mean: {pe.mean():.6f} Std: {pe.std():.6f}")

            if pe.dim() >= 2:
                pe2d = pe.squeeze(0) if pe.dim() == 3 else pe

                # Cosine matrix via normalise + matmul: same values as a
                # broadcast cosine_similarity, without an (N, N, D) intermediate.
                unit = F.normalize(pe2d, dim=-1)
                sim = unit @ unit.T
                off_diag = (sim.sum() - sim.diag().sum()).item() / (sim.numel() - sim.shape[0])
                print(f" Self-sim: diag_mean={sim.diag().mean():.4f} "
                      f"off_diag_mean={off_diag:.4f}")
                print(f" Adjacent pos cos: mean={F.cosine_similarity(pe2d[:-1], pe2d[1:], dim=-1).mean():.4f}")

                sv = torch.linalg.svdvals(pe2d)
                eff_rank = effective_rank(sv)
                print(f" Spectral: eff_rank={eff_rank:.1f}/{min(pe2d.shape)} "
                      f"sv1%={sv[0].pow(2).item()/sv.pow(2).sum().item()*100:.1f}%")
            break  # only the first position embedding per model


# ══════════════════════════════════════════════════════════════════
# SCAN 9: LAYERNORM ANALYSIS
# ══════════════════════════════════════════════════════════════════

def scan_layernorms(models):
    """Scan 9: weight/bias statistics of the first four and the last norm layers."""
    banner("SCAN 9: LAYERNORM WEIGHT/BIAS PATTERNS")
    for name in MODEL_NAMES:
        ln_weights = []
        ln_biases = []
        for pname, p in models[name].named_parameters():
            if "norm" not in pname.lower():
                continue
            if "weight" in pname:
                ln_weights.append((pname, p.detach().float()))
            elif "bias" in pname:
                ln_biases.append((pname, p.detach().float()))

        print(f"\n {name} ({len(ln_weights)} LayerNorms):")
        # Weights and biases are paired by position in their respective lists.
        for (wn, w), (bn, b) in zip(ln_weights[:4], ln_biases[:4]):
            short = ".".join(wn.split(".")[-3:-1])
            print(f" {short:<30} w: mean={w.mean():.4f} std={w.std():.4f} "
                  f"b: mean={b.mean():.5f} std={b.std():.4f}")

        # Final LayerNorm
        if ln_weights:
            wn, w = ln_weights[-1]
            bn, b = ln_biases[-1] if ln_biases else ("", torch.zeros_like(w))
            print(f" FINAL: {wn}")
            print(f" weight: mean={w.mean():.4f} std={w.std():.4f} "
                  f"min={w.min():.4f} max={w.max():.4f}")
            if ln_biases:
                print(f" bias: mean={b.mean():.5f} std={b.std():.4f}")


# ══════════════════════════════════════════════════════════════════
# ENTRY POINT
# ══════════════════════════════════════════════════════════════════

def main():
    """Load the three models, run all nine scans in order, then free memory."""
    print("=" * 65)
    print("BASE TIER DEEP MODEL ANALYSIS")
    print("=" * 65)
    print(f" Device: {DEVICE}")

    models, configs = load_models()

    scan_architecture(models, configs)
    scan_parameter_inventory(models)
    scan_weight_statistics(models)
    scan_patch_embeddings(models)
    scan_attention_geometry(models)
    scan_cross_model_alignment(models)
    scan_mlp_spectrum(models)
    scan_position_embeddings(models)
    scan_layernorms(models)

    banner("ANALYSIS COMPLETE")

    # Clean up
    del models, configs
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


if __name__ == "__main__":
    main()