"""
Cell 5: Multi-VAE Geometric Comparison
========================================
Run after Cells 1-4. Reuses existing pipeline.

Processes 4 VAEs sequentially (expected latent shapes; the actual shape is
printed for each VAE at runtime):
    SD 1.5  → 4ch  × 64×64   (512px input)
    SDXL    → 4ch  × 128×128 (1024px input)
    Flux.1  → 16ch × 128×128 (1024px input)
    Flux.2  → 16ch × 128×128 (1024px input)

Each: load VAE → encode → free → cluster → extract → store
Then: comparative diagnostics
"""

import hashlib
import json
import math
import os
import time
import zipfile
from collections import Counter
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm

# === 0. Images ================================================================
print("=" * 70)
print("Multi-VAE Geometric Comparison Pipeline")
print("=" * 70)

IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.webp', '.bmp'}
HF_REPO = "AbstractPhil/grid-geometric-classifier-sliding-proto"
HF_ZIP = "mega_liminal_captioned.zip"

# Top-level /content/ entries that are never scanned for images: Colab system
# folders plus this pipeline's own checkpoint and latent-cache directories.
SKIP_DIRS = {
    '.config', 'sample_data', 'drive', '__pycache__', 'checkpoints_vae_ca',
    'latent_cache_sd15', 'latent_cache_sdxl',
    'latent_cache_flux1', 'latent_cache_flux2',
    'latent_cache_mega_sd15', 'latent_cache_mega_sdxl',
    'latent_cache_mega_flux1', 'latent_cache_mega_flux2',
}


def find_images(directory):
    """Return a sorted list of image file paths found recursively under *directory*.

    Returns [] when *directory* does not exist.
    """
    if not os.path.exists(directory):
        return []
    return sorted(
        os.path.join(root, f)
        for root, _, files in os.walk(directory)
        for f in files
        if Path(f).suffix.lower() in IMAGE_EXTENSIONS
    )


def scan_content_for_images(min_count=100):
    """Scan /content/ for the largest image directory.

    Only top-level directories not in SKIP_DIRS are considered; each one is
    searched recursively.

    Returns:
        ``(directory, image_paths)`` for the directory with the most images, or
        ``(None, [])`` if even that directory has fewer than *min_count* images.
    """
    best_dir, best_imgs = None, []
    for d in sorted(os.listdir('/content/')):
        full = f'/content/{d}'
        if os.path.isdir(full) and d not in SKIP_DIRS:
            found = find_images(full)
            if len(found) > len(best_imgs):
                best_dir, best_imgs = full, found
    if len(best_imgs) >= min_count:
        return best_dir, best_imgs
    return None, []


LIMINAL_DIR, image_paths = scan_content_for_images()
if LIMINAL_DIR is None:
    # No usable local dataset: fetch the zip from the Hub and rescan.
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        os.system('pip install -q huggingface_hub')
        from huggingface_hub import hf_hub_download
    print(f"Downloading {HF_ZIP} from {HF_REPO}...")
    zip_path = hf_hub_download(repo_id=HF_REPO, filename=HF_ZIP)
    with zipfile.ZipFile(zip_path, 'r') as z:
        z.extractall('/content/')
    LIMINAL_DIR, image_paths = scan_content_for_images()

# Explicit check instead of `assert`, which is stripped under `python -O`.
if not LIMINAL_DIR or not image_paths:
    raise RuntimeError("No images found in /content/")
print(f"Found {len(image_paths)} images in {LIMINAL_DIR}")

# === 1. Classifier ============================================================
print("\n" + "=" * 70)
print("Step 1: Classifier")
print("=" * 70)

# PatchCrossAttentionClassifier and NUM_CLASSES are not defined in this cell;
# they are expected from the earlier cells ("Run after Cells 1-4").
ckpt = '/content/best_vae_ca_classifier.pt'
if not os.path.exists(ckpt):
    ckpt = '/content/checkpoints_vae_ca/best.pt'
model = PatchCrossAttentionClassifier(n_classes=NUM_CLASSES)
model.load_state_dict(torch.load(ckpt, map_location='cpu', weights_only=True))
device = torch.device('cuda')
model = model.to(device).eval()
print(f"Loaded {sum(p.numel() for p in model.parameters()):,} params")

# === 2. VAE Definitions ======================================================
try:
    from diffusers import AutoencoderKL
except ImportError:
    os.system('pip install -q diffusers transformers accelerate')
    from diffusers import AutoencoderKL

from torchvision import transforms
from PIL import Image

# One entry per VAE. 'subfolder' is passed to from_pretrained only when truthy;
# each VAE writes its latents to its own 'cache_dir'.
VAE_CONFIGS = [
    {
        'name': 'SD 1.5',
        'model_id': 'stable-diffusion-v1-5/stable-diffusion-v1-5',
        'subfolder': 'vae',
        'input_res': 512,
        'dtype': torch.float16,
        'cache_dir': '/content/latent_cache_mega_sd15',
    },
    {
        'name': 'SDXL',
        'model_id': 'madebyollin/sdxl-vae-fp16-fix',
        'subfolder': None,
        'input_res': 1024,
        'dtype': torch.float16,
        'cache_dir': '/content/latent_cache_mega_sdxl',
    },
    {
        'name': 'Flux.1',
        'model_id': 'black-forest-labs/FLUX.1-dev',
        'subfolder': 'vae',
        'input_res': 1024,
        'dtype': torch.bfloat16,
        'cache_dir': '/content/latent_cache_mega_flux1',
    },
    {
        'name': 'Flux.2',
        'model_id': 'black-forest-labs/FLUX.2-dev',
        'subfolder': 'vae',
        'input_res': 1024,
        'dtype': torch.bfloat16,
        'cache_dir': '/content/latent_cache_mega_flux2',
    },
]


def get_scales_for_latent(C, H, W):
    """Compute appropriate scales based on actual latent dimensions. No L3 (noise).

    Returns:
        List of ``(channels, height, width)`` tuples, coarsest scale first.
    """
    return [
        # L0: full latent (or capped)
        (min(C, 16), min(H, 64), min(W, 64)),
        # L1: regional
        (min(C, 8), min(H, 32), min(W, 32)),
        # L2: native patch (classifier resolution).
        # NOTE(review): unlike L0/L1 this is not capped to H/W — confirm how the
        # extractor handles latents smaller than 16×16.
        (min(C, 8), 16, 16),
    ]


def _cache_stems(image_paths):
    """Map each image path to a cache-file stem that is unique within *image_paths*.

    A unique file stem is used unchanged, so existing caches stay valid. Stems
    shared by several images (e.g. ``a/001.jpg`` and ``b/001.png``) get a short
    hash of the full path appended. Without it they would write to, and read
    back, the same cache file.
    """
    stem_counts = Counter(Path(p).stem for p in image_paths)
    stems = {}
    for p in image_paths:
        stem = Path(p).stem
        if stem_counts[stem] > 1:
            stem = f"{stem}_{hashlib.sha1(p.encode('utf-8')).hexdigest()[:8]}"
        stems[p] = stem
    return stems


def encode_dataset(vae, image_paths, cache_dir, input_res, dtype, batch_size=4):
    """Encode images with *vae*, caching one ``.pt`` latent per image.

    Each image is resized to ``input_res × input_res`` and normalized to [-1, 1].
    It is then encoded, and the mean of ``vae.encode(...).latent_dist`` is saved
    as float32 on the CPU. Images whose cache file already exists are not
    re-encoded.

    Args:
        vae: Loaded VAE (already on ``device``).
        image_paths: Image file paths to encode.
        cache_dir: Directory for the ``.pt`` latent files (created if missing).
        input_res: Square resize resolution fed to the VAE.
        dtype: Dtype the pixel batch is cast to before encoding.
        batch_size: Images per encoder forward pass.

    Returns:
        Cache-file paths, in ``image_paths`` order, for the images whose latent
        exists on disk. Images that cannot be loaded are reported and left out,
        so every returned path is safe to ``torch.load``.
    """
    os.makedirs(cache_dir, exist_ok=True)
    transform = transforms.Compose([
        transforms.Resize((input_res, input_res)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])

    stems = _cache_stems(image_paths)
    all_paths = [os.path.join(cache_dir, f'{stems[p]}.pt') for p in image_paths]
    need_encode = [(p, cp) for p, cp in zip(image_paths, all_paths)
                   if not os.path.exists(cp)]
    if not need_encode:
        return all_paths

    failed = []
    for i in tqdm(range(0, len(need_encode), batch_size), desc="Encoding", unit="batch"):
        tensors, targets = [], []
        for p, cp in need_encode[i:i + batch_size]:
            try:
                with Image.open(p) as im:
                    t = transform(im.convert('RGB'))
            except Exception as e:  # best effort: skip unreadable images, but record them
                failed.append((p, e))
            else:
                tensors.append(t)
                targets.append(cp)
        if not tensors:
            continue
        batch = torch.stack(tensors).to(device, dtype=dtype)
        with torch.no_grad():
            latents = vae.encode(batch).latent_dist.mean
        for lat, cp in zip(latents, targets):
            torch.save(lat.cpu().float(), cp)
        del batch, latents

    if failed:
        first_path, first_err = failed[0]
        print(f" ⚠ {len(failed)} image(s) could not be loaded and were skipped "
              f"(first: {first_path}: {first_err})")

    # Only paths that were actually written (or were already cached).
    return [cp for cp in all_paths if os.path.exists(cp)]
# === 3. Process Each VAE =====================================================
import shutil


def _encode_existing(vae, vconf, dtype):
    """Encode the dataset for *vconf* and return only latent paths present on disk.

    This is a defensive filter. A cache path for an image that failed to load
    or encode was never written, and it would make the later ``torch.load``
    calls raise FileNotFoundError.
    """
    paths = encode_dataset(
        vae, image_paths, vconf['cache_dir'], vconf['input_res'], dtype)
    return [p for p in paths if os.path.exists(p)]


def _first_latent_is_nan(latent_paths):
    """Cheap cache sanity check: True if the first cached latent contains a NaN."""
    return bool(latent_paths) and bool(torch.isnan(torch.load(latent_paths[0])).any())


all_vae_results = {}

for vconf in VAE_CONFIGS:
    vname = vconf['name']
    print("\n" + "=" * 70)
    print(f"Processing: {vname}")
    print(f" Model: {vconf['model_id']}")
    print("=" * 70)

    # --- Encode ---
    try:
        load_kwargs = dict(torch_dtype=vconf['dtype'])
        if vconf['subfolder']:
            load_kwargs['subfolder'] = vconf['subfolder']
        vae = AutoencoderKL.from_pretrained(
            vconf['model_id'], **load_kwargs
        ).to(device).eval()
    except Exception as e:
        # Best effort: a VAE that cannot be loaded is skipped, the others still run.
        print(f" ⚠ Failed to load {vname}: {e}")
        print(" Skipping...")
        continue

    latent_paths = _encode_existing(vae, vconf, vconf['dtype'])

    # Sanity check on the first latent. A NaN may come from a stale cache, which
    # purging and re-encoding fixes. It may also come from the reduced-precision
    # forward pass itself; re-encoding at the same dtype would just reproduce
    # it, so the second retry runs the VAE in float32.
    if _first_latent_is_nan(latent_paths):
        print(f" ⚠ NaN detected in cached latents — purging {vconf['cache_dir']}")
        shutil.rmtree(vconf['cache_dir'])
        latent_paths = _encode_existing(vae, vconf, vconf['dtype'])
        if _first_latent_is_nan(latent_paths):
            print(f" ⚠ Still NaN after re-encode — retrying {vname} in float32")
            shutil.rmtree(vconf['cache_dir'])
            vae = vae.to(torch.float32)
            latent_paths = _encode_existing(vae, vconf, torch.float32)

    if not latent_paths:
        print(f" ⚠ No latents produced for {vname}, skipping")
        del vae
        torch.cuda.empty_cache()
        continue

    sample = torch.load(latent_paths[0])
    C, H, W = sample.shape
    print(f" Latent: ({C}, {H}, {W}) "
          f"mean={sample.mean():.3f} std={sample.std():.3f} "
          f"[{sample.min():.3f}, {sample.max():.3f}]")
    if torch.isnan(sample).any():
        print(f" ⚠ Latents for {vname} still contain NaN — results will be unreliable")
    del sample

    # Free VAE immediately
    del vae
    torch.cuda.empty_cache()

    # --- Cluster --- (channel grouping from up to 100 latents)
    N_CL = min(100, len(latent_paths))
    sample_batch = torch.stack(
        [torch.load(latent_paths[i]) for i in range(N_CL)]).to(device)
    channel_groups, corr = cluster_channels_gpu(sample_batch, n_groups=min(8, C))
    print(f" Groups: {channel_groups}")
    del sample_batch

    # --- Extract ---
    scales = get_scales_for_latent(C, H, W)
    print(f" Scales: {scales}")
    config = ExtractionConfig(
        confidence_threshold=0.6,
        min_occupancy=0.01,
        image_batch_size=32,
    )
    config.scales = scales
    extractor = MultiScaleExtractor(model, config)
    IMG_BATCH = config.image_batch_size

    # One summary record per image. Everything reported downstream is derived
    # from these records, so no per-annotation list is kept in memory.
    vae_records = []
    for batch_start in tqdm(range(0, len(latent_paths), IMG_BATCH),
                            desc=f"{vname} extract", unit=f"×{IMG_BATCH}"):
        bp = latent_paths[batch_start:batch_start + IMG_BATCH]
        names = [Path(p).stem for p in bp]
        latents = [torch.load(p).to(device) for p in bp]
        batch_results = extractor.extract_batch(latents, channel_groups)
        del latents

        for name, result in zip(names, batch_results):
            raw = result['raw_annotations']
            dev = result['deviance_annotations']
            all_anns = raw + dev
            vae_records.append({
                'name': name,
                'n_total': len(all_anns),
                'n_raw': len(raw),
                'n_deviance': len(dev),
                'classes': Counter(a.class_name for a in all_anns),
                'confidences': [a.confidence for a in all_anns],
                'scales': Counter(a.scale_level for a in all_anns),
                'dimensions': Counter(a.dimension for a in all_anns),
                'curved': sum(1 for a in all_anns if a.is_curved),
            })

    # Summarize: aggregate the per-image records, in record order.
    total = sum(rec['n_total'] for rec in vae_records)
    cls_counts = Counter()
    for rec in vae_records:
        cls_counts.update(rec['classes'])
    confs = [c for rec in vae_records for c in rec['confidences']]

    all_vae_results[vname] = {
        'latent_shape': (C, H, W),
        'n_images': len(vae_records),
        'total_annotations': total,
        'class_counts': cls_counts,
        'mean_confidence': float(np.mean(confs)) if confs else 0,
        'std_confidence': float(np.std(confs)) if confs else 0,
        'records': vae_records,
        'channel_groups': channel_groups,
        'scales': scales,
    }

    if confs:
        print(f" {vname}: {total:,} annotations, conf={np.mean(confs):.3f}")
    else:
        print(f" {vname}: 0 annotations (check latent stats above)")
    for cls, cnt in cls_counts.most_common(5):
        print(f" {cls:20s} {cnt:>10,} ({cnt/max(total,1)*100:5.1f}%)")

    del vae_records, confs
    torch.cuda.empty_cache()
# =============================================================================
# COMPARATIVE ANALYSIS
# =============================================================================
print("\n" + "=" * 70)
print("COMPARATIVE ANALYSIS: VAE Geometric Structures")
print("=" * 70)

vae_names = list(all_vae_results.keys())
if not vae_names:
    print(" ⚠ No VAE produced results — nothing to compare")

# Class name → column in the per-class fraction vectors. It is built once;
# CLASS_NAMES.index() would rescan the list on every lookup.
CLASS_INDEX = {cls: i for i, cls in enumerate(CLASS_NAMES)}


def _class_vector(counts, total):
    """Return a len(CLASS_NAMES) vector holding ``counts[cls] / total`` per class."""
    vec = np.zeros(len(CLASS_NAMES))
    for cls, cnt in counts.items():
        vec[CLASS_INDEX[cls]] = cnt / total
    return vec


def _pct_cell(pct):
    """Format one 11-character cell of the class-presence table (matches the header)."""
    if pct >= 0.5:
        return f" {pct:>9.1f}%"
    if pct > 0:
        return f" {'trace':>10s}"
    return f" {'—':>10s}"


def _describe_similarity(mean_cos):
    """Map a mean per-image cosine to the qualitative label printed in the report."""
    if mean_cos > 0.9:
        return " → Same geometric structure"
    if mean_cos > 0.7:
        return " → Similar structure"
    if mean_cos > 0.4:
        return " → Different structures"
    return " → Very different geometric encoding"


def _json_default(obj):
    """Serialize numpy / torch values that ``json`` cannot handle natively.

    NOTE(review): channel_groups comes from cluster_channels_gpu (another cell),
    whose element types are not visible here. This converter covers the numpy
    and torch cases and still raises TypeError for anything else.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (np.ndarray, torch.Tensor)):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# --- Table: Overview ---
print(f"\n{'─'*70}")
print(f" {'VAE':12s} {'Latent':>14s} {'Ann':>12s} {'Ann/img':>8s} "
      f"{'Conf':>6s} {'Classes':>7s}")
print(f"{'─'*70}")
for vn in vae_names:
    r = all_vae_results[vn]
    sh = f"{r['latent_shape'][0]}×{r['latent_shape'][1]}×{r['latent_shape'][2]}"
    n_cls = len(r['class_counts'])
    ann_per = r['total_annotations'] / max(r['n_images'], 1)
    print(f" {vn:12s} {sh:>14s} {r['total_annotations']:>12,} {ann_per:>8.0f} "
          f"{r['mean_confidence']:>6.3f} {n_cls:>4d}/{len(CLASS_NAMES)}")

# --- Table: Top-5 classes per VAE ---
print(f"\n{'─'*70}")
print("TOP-5 CLASSES PER VAE")
print(f"{'─'*70}")
for vn in vae_names:
    r = all_vae_results[vn]
    total = r['total_annotations']
    classes_str = " ".join(f"{c}:{cnt/max(total,1)*100:.0f}%"
                           for c, cnt in r['class_counts'].most_common(5))
    print(f" {vn:12s} {classes_str}")

# --- Class presence heatmap (which classes appear in which VAEs) ---
print(f"\n{'─'*70}")
print("CLASS PRESENCE ACROSS VAEs (>0.5% of annotations)")
print(f"{'─'*70}")

# Rows ordered by total frequency across all VAEs.
class_totals = Counter()
for vn in vae_names:
    class_totals.update(all_vae_results[vn]['class_counts'])

print(f" {'Class':20s}" + "".join(f" {vn:>10s}" for vn in vae_names))
for cls, _ in class_totals.most_common():
    row_vals = [
        all_vae_results[vn]['class_counts'].get(cls, 0)
        / max(all_vae_results[vn]['total_annotations'], 1) * 100
        for vn in vae_names
    ]
    # Only classes that reach 0.5% in at least one VAE get a row.
    if any(pct >= 0.5 for pct in row_vals):
        print(f" {cls:20s}" + "".join(_pct_cell(pct) for pct in row_vals))

# --- Geometric fingerprint comparison ---
print(f"\n{'─'*70}")
print("GEOMETRIC FINGERPRINT SIMILARITY (cosine between class distributions)")
print(f"{'─'*70}")
if len(vae_names) >= 2:
    vecs = {
        vn: _class_vector(all_vae_results[vn]['class_counts'],
                          max(all_vae_results[vn]['total_annotations'], 1))
        for vn in vae_names
    }
    print(f" {'':12s}" + "".join(f" {vn:>10s}" for vn in vae_names))
    for vn1 in vae_names:
        cells = []
        for vn2 in vae_names:
            v1, v2 = vecs[vn1], vecs[vn2]
            n1, n2 = np.linalg.norm(v1), np.linalg.norm(v2)
            sim = np.dot(v1, v2) / (n1 * n2) if n1 > 0 and n2 > 0 else 0
            cells.append(f" {sim:>10.3f}")
        print(f" {vn1:12s}" + "".join(cells))

# --- Dimensional comparison ---
print(f"\n{'─'*70}")
print("DIMENSIONAL DISTRIBUTION")
print(f"{'─'*70}")
print(f" {'VAE':12s} {'0D':>8s} {'1D':>8s} {'2D':>8s} {'3D':>8s} {'Curved':>8s}")
for vn in vae_names:
    r = all_vae_results[vn]
    total = max(r['total_annotations'], 1)
    dim_c = Counter()
    curved_n = 0
    for rec in r['records']:
        dim_c.update(rec['dimensions'])
        curved_n += rec['curved']
    dim_cells = "".join(f" {dim_c.get(d, 0) / total * 100:>7.1f}%" for d in range(4))
    print(f" {vn:12s}" + dim_cells + f" {curved_n/total*100:>7.1f}%")

# --- Channel group comparison ---
print(f"\n{'─'*70}")
print("CHANNEL GROUPS")
print(f"{'─'*70}")
for vn in vae_names:
    r = all_vae_results[vn]
    sh = r['latent_shape']
    print(f" {vn:12s} ({sh[0]}ch): {r['channel_groups']}")

# --- Per-image cross-VAE consistency ---
if len(vae_names) >= 2:
    print(f"\n{'─'*70}")
    print("PER-IMAGE CROSS-VAE CONSISTENCY")
    print(f"{'─'*70}")
    print(" Do images that are geometrically distinct in one VAE stay distinct in another?")

    # Per-image class-fraction vectors for each VAE, keyed by image name.
    per_vae_vectors = {}
    common = None
    for vn in vae_names:
        name_to_vec = {
            rec['name']: _class_vector(rec['classes'], max(rec['n_total'], 1))
            for rec in all_vae_results[vn]['records']
        }
        per_vae_vectors[vn] = name_to_vec
        common = set(name_to_vec) if common is None else common & set(name_to_vec)
    common_names = sorted(common)[:200]  # sample for speed

    if len(common_names) < 10:
        print(f" Too few images shared by all VAEs ({len(common_names)}) — skipped")
    else:
        for i, vn1 in enumerate(vae_names):
            for vn2 in vae_names[i + 1:]:
                v1_mat = np.stack([per_vae_vectors[vn1][n] for n in common_names])
                v2_mat = np.stack([per_vae_vectors[vn2][n] for n in common_names])
                norms1 = np.linalg.norm(v1_mat, axis=1)
                norms2 = np.linalg.norm(v2_mat, axis=1)
                # An image with no annotations in a VAE has an all-zero vector,
                # so its cosine is undefined. Scoring it as 0 would count it as
                # "maximally different" and drag the mean down, so exclude it.
                valid = (norms1 > 0) & (norms2 > 0)
                n_valid = int(valid.sum())
                if n_valid == 0:
                    print(f" {vn1:12s} ↔ {vn2:12s}: no image annotated in both — skipped")
                    continue
                cos = (np.sum(v1_mat[valid] * v2_mat[valid], axis=1)
                       / (norms1[valid] * norms2[valid]))
                print(f" {vn1:12s} ↔ {vn2:12s}: "
                      f"mean={cos.mean():.3f} std={cos.std():.3f} "
                      f"[{cos.min():.3f}, {cos.max():.3f}] n={n_valid}")
                print(_describe_similarity(cos.mean()))

# === Save =====================================================================
OUT_PATH = '/content/multi_vae_comparison_mega.json'
save_data = {
    vn: {
        'latent_shape': all_vae_results[vn]['latent_shape'],
        'n_images': all_vae_results[vn]['n_images'],
        'total_annotations': all_vae_results[vn]['total_annotations'],
        'class_counts': dict(all_vae_results[vn]['class_counts']),
        'mean_confidence': all_vae_results[vn]['mean_confidence'],
        'scales': [list(s) for s in all_vae_results[vn]['scales']],
        'channel_groups': all_vae_results[vn]['channel_groups'],
    }
    for vn in vae_names
}

with open(OUT_PATH, 'w', encoding='utf-8') as f:
    json.dump(save_data, f, indent=2, default=_json_default)

print(f"\nSaved to {OUT_PATH}")
print("=" * 70)
print("✓ Multi-VAE comparison complete!")
print("=" * 70)