Rename cell5_flux_vae_geometric_analysis.py to cell5_quad_vae_geometric_analysis.py
"""
Cell 5: Multi-VAE Geometric Comparison
========================================
Run after Cells 1-4. Reuses their pipeline: PatchCrossAttentionClassifier,
NUM_CLASSES, CLASS_NAMES, cluster_channels_gpu, ExtractionConfig and
MultiScaleExtractor must already be defined.
Processes 4 VAEs sequentially (expected latent shapes; the actual shape of
each VAE's first cached latent is printed at runtime):
    SD 1.5 → 4ch × 64×64    (512px input)
    SDXL   → 4ch × 128×128  (1024px input)
    Flux.1 → 16ch × 128×128 (1024px input)
    Flux.2 → 16ch × 128×128 (1024px input)
Each: load VAE → encode → free → cluster → extract → store
Then: comparative diagnostics
"""
| import os, time, json, zipfile, math | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from pathlib import Path | |
| from tqdm.auto import tqdm | |
| from collections import Counter | |
# === 0. Images ================================================================
# Print the pipeline banner, then declare where the image set lives: a local
# folder first, with a zip on the Hugging Face Hub as the download fallback.
print("\n".join(("=" * 70, "Multi-VAE Geometric Comparison Pipeline", "=" * 70)))

LIMINAL_DIR = '/content/liminal'  # local image root; may be re-pointed below
HF_REPO = "AbstractPhil/grid-geometric-classifier-sliding-proto"  # hosts liminal.zip
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.webp', '.bmp'}  # lower-case suffixes
def find_images(directory, extensions=None):
    """Recursively list image files under *directory*.

    Args:
        directory: Root folder to walk; subfolders are included.
        extensions: Iterable of allowed file suffixes such as ``'.png'``.
            Matching is case-insensitive. Defaults to the module-level
            ``IMAGE_EXTENSIONS``.

    Returns:
        Sorted list of matching file paths (strings). Empty when
        *directory* does not exist.
    """
    if extensions is None:
        extensions = IMAGE_EXTENSIONS
    # Normalise once so callers may pass '.PNG' as well as '.png'.
    allowed = {ext.lower() for ext in extensions}
    if not os.path.exists(directory):
        return []
    return sorted(
        os.path.join(root, fname)
        for root, _, files in os.walk(directory)
        for fname in files
        if Path(fname).suffix.lower() in allowed
    )
# Locate the image set: local folder first, then download + unzip liminal.zip
# from HF_REPO, then scan /content for a folder with any other name that
# contains images (the archive's top-level folder name is not fixed).
image_paths = find_images(LIMINAL_DIR)
if not image_paths:
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        os.system('pip install -q huggingface_hub')
        from huggingface_hub import hf_hub_download
    print(f"Downloading liminal.zip from {HF_REPO}...")
    zip_path = hf_hub_download(repo_id=HF_REPO, filename="liminal.zip")
    with zipfile.ZipFile(zip_path, 'r') as z:
        z.extractall('/content/')
    image_paths = find_images(LIMINAL_DIR)
    if not image_paths:
        # Try finding extracted dir with a different name. Sorted so the
        # chosen folder is deterministic rather than filesystem-order.
        skip_dirs = {'.config', 'sample_data', 'drive', '__pycache__'}
        for d in sorted(os.listdir('/content/')):
            full = f'/content/{d}'
            if d in skip_dirs or not os.path.isdir(full):
                continue
            found = find_images(full)
            if found:
                LIMINAL_DIR = full
                image_paths = found
                break
if not image_paths:
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    raise FileNotFoundError(f"No images found. Check {LIMINAL_DIR}")
print(f"Found {len(image_paths)} images")
# === 1. Classifier ============================================================
print("\n" + "=" * 70)
print("Step 1: Classifier")
print("=" * 70)
# Checkpoint locations tried in order: the exported best model, then the copy
# inside the training-run checkpoint folder.
_ckpt_candidates = [
    '/content/best_vae_ca_classifier.pt',
    '/content/checkpoints_vae_ca/best.pt',
]
ckpt = next((c for c in _ckpt_candidates if os.path.exists(c)), None)
if ckpt is None:
    raise FileNotFoundError(
        "No classifier checkpoint found; looked for: " + ", ".join(_ckpt_candidates))
model = PatchCrossAttentionClassifier(n_classes=NUM_CLASSES)
model.load_state_dict(torch.load(ckpt, map_location='cpu', weights_only=True))
# This cell assumes a CUDA GPU: VAEs and extraction all run on `device`.
device = torch.device('cuda')
model = model.to(device).eval()
print(f"Loaded {sum(p.numel() for p in model.parameters()):,} params from {ckpt}")
| # === 2. VAE Definitions ====================================================== | |
| try: | |
| from diffusers import AutoencoderKL | |
| except ImportError: | |
| os.system('pip install -q diffusers transformers accelerate') | |
| from diffusers import AutoencoderKL | |
| from torchvision import transforms | |
| from PIL import Image | |
# VAEs to compare, processed in this order. Per-entry keys:
#   name       display label; also the key used in all_vae_results
#   model_id   Hugging Face repo id passed to from_pretrained
#   subfolder  folder holding the VAE inside the repo (None: repo root)
#   input_res  square size images are resized to before encoding
#   dtype      torch dtype for the VAE weights and encoder input
#   cache_dir  directory where per-image latents (.pt) are cached
VAE_CONFIGS = [
    dict(name='SD 1.5',
         model_id='stable-diffusion-v1-5/stable-diffusion-v1-5',
         subfolder='vae',
         input_res=512,
         dtype=torch.float16,
         cache_dir='/content/latent_cache_sd15'),
    dict(name='SDXL',
         model_id='madebyollin/sdxl-vae-fp16-fix',
         subfolder=None,
         input_res=1024,
         dtype=torch.float16,
         cache_dir='/content/latent_cache_sdxl'),
    dict(name='Flux.1',
         model_id='black-forest-labs/FLUX.1-dev',
         subfolder='vae',
         input_res=1024,
         dtype=torch.bfloat16,
         cache_dir='/content/latent_cache_flux1'),
    dict(name='Flux.2',
         model_id='black-forest-labs/FLUX.2-dev',
         subfolder='vae',
         input_res=1024,
         dtype=torch.bfloat16,
         cache_dir='/content/latent_cache_flux2'),
]
def get_scales_for_latent(C, H, W):
    """Return extraction scales as (channels, height, width) tuples for a C x H x W latent.

    Three levels, coarse to fine; no L3 (noise) level is produced:
      L0: full latent, capped at 16 channels and 64 x 64
      L1: regional, capped at 8 channels and 32 x 32
      L2: native 16 x 16 patch (classifier resolution), at most 8 channels
    """
    # (channel cap, spatial cap) for the two levels that shrink with the latent.
    capped_levels = ((16, 64), (8, 32))
    levels = [
        (min(C, c_cap), min(H, hw_cap), min(W, hw_cap))
        for c_cap, hw_cap in capped_levels
    ]
    # The patch level is a fixed 16 x 16 regardless of H and W.
    levels.append((min(C, 8), 16, 16))
    return levels
def encode_dataset(vae, image_paths, cache_dir, input_res, dtype, batch_size=4):
    """Encode images with *vae* and cache each latent as ``<cache_dir>/<stem>.pt``.

    Each image is converted to RGB, resized to ``input_res`` x ``input_res``,
    normalised with mean/std 0.5 per channel, and encoded in batches of
    ``batch_size`` on the module-level ``device`` in ``dtype``. The mean of
    the VAE's ``latent_dist`` is saved as a float32 CPU tensor. Images whose
    cache file already exists are not re-encoded.

    Returns:
        Cache paths, in ``image_paths`` order, for every latent present on
        disk after encoding, so each entry can be passed to ``torch.load``.
        Images that cannot be read are reported and omitted. Images whose
        file stem repeats an earlier image's are also reported and omitted,
        because they would map to the same cache file.
    """
    os.makedirs(cache_dir, exist_ok=True)
    transform = transforms.Compose([
        transforms.Resize((input_res, input_res)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])

    latent_paths = []
    need_encode = []
    seen = set()
    n_collisions = 0
    for p in image_paths:
        cp = os.path.join(cache_dir, f'{Path(p).stem}.pt')
        if cp in seen:
            # Same stem in another subfolder or with another extension would
            # overwrite the earlier image's latent; keep the first one only.
            n_collisions += 1
            continue
        seen.add(cp)
        latent_paths.append(cp)
        if not os.path.exists(cp):
            need_encode.append((p, cp))
    if n_collisions:
        print(f" Skipped {n_collisions} image(s) whose file stem duplicates an earlier image")

    if not need_encode:
        return latent_paths

    failed = []
    for i in tqdm(range(0, len(need_encode), batch_size), desc="Encoding", unit="batch"):
        imgs = []
        for p, cp in need_encode[i:i + batch_size]:
            try:
                # Context manager closes the file handle PIL keeps open.
                with Image.open(p) as im:
                    imgs.append((transform(im.convert('RGB')), cp))
            except Exception as e:  # unreadable/corrupt image: best-effort skip
                failed.append((p, e))
        if not imgs:
            continue
        batch = torch.stack([tensor for tensor, _ in imgs]).to(device, dtype=dtype)
        with torch.no_grad():
            latents = vae.encode(batch).latent_dist.mean
        for latent, (_, cp) in zip(latents, imgs):
            torch.save(latent.cpu().float(), cp)
        del batch, latents

    if failed:
        first_path, first_err = failed[0]
        print(f" Could not read {len(failed)} image(s); first: {first_path} ({first_err})")

    # Failed images never got a cache file; drop them so callers can
    # torch.load every returned path.
    return [cp for cp in latent_paths if os.path.exists(cp)]
# === 3. Process Each VAE =====================================================
# For every VAE in VAE_CONFIGS:
#   load → encode (cached) → free VAE → cluster channels → extract → summarise.
# Results land in all_vae_results[name] with keys: latent_shape, n_images,
# total_annotations, class_counts, mean_confidence, std_confidence, records,
# channel_groups, scales.
import shutil

all_vae_results = {}
for vconf in VAE_CONFIGS:
    vname = vconf['name']
    print("\n" + "=" * 70)
    print(f"Processing: {vname}")
    print(f" Model: {vconf['model_id']}")
    print("=" * 70)

    # --- Load (a missing, gated or incompatible repo skips this VAE) ---
    try:
        load_kwargs = dict(torch_dtype=vconf['dtype'])
        if vconf['subfolder']:
            load_kwargs['subfolder'] = vconf['subfolder']
        vae = AutoencoderKL.from_pretrained(
            vconf['model_id'], **load_kwargs
        ).to(device).eval()
    except Exception as e:
        print(f" ✗ Failed to load {vname}: {e}")
        print(" Skipping...")
        continue

    # --- Encode; the VAE is released afterwards even if encoding raises ---
    try:
        latent_paths = encode_dataset(
            vae, image_paths, vconf['cache_dir'],
            vconf['input_res'], vconf['dtype'])
        # Sanity check: a NaN first latent indicates a bad cache, so purge it
        # and re-encode once while the VAE is still loaded.
        if latent_paths and torch.isnan(torch.load(latent_paths[0])).any():
            print(f" ⚠ NaN detected in cached latents — purging {vconf['cache_dir']}")
            shutil.rmtree(vconf['cache_dir'])
            latent_paths = encode_dataset(
                vae, image_paths, vconf['cache_dir'],
                vconf['input_res'], vconf['dtype'])
    finally:
        # Only cached latents are needed from here on.
        del vae
        torch.cuda.empty_cache()

    if not latent_paths:
        print(f" ⚠ No latents produced for {vname}, skipping")
        continue

    sample = torch.load(latent_paths[0])
    if torch.isnan(sample).any():
        print(f" ⚠ Latents for {vname} still contain NaN after re-encoding")
    C, H, W = sample.shape
    print(f" Latent: ({C}, {H}, {W}) "
          f"mean={sample.mean():.3f} std={sample.std():.3f} "
          f"[{sample.min():.3f}, {sample.max():.3f}]")
    del sample

    # --- Cluster channels on a sample of up to 100 latents ---
    N_CL = min(100, len(latent_paths))
    sample_batch = torch.stack([torch.load(p) for p in latent_paths[:N_CL]]).to(device)
    channel_groups, corr = cluster_channels_gpu(sample_batch, n_groups=min(8, C))
    print(f" Groups: {channel_groups}")
    del sample_batch

    # --- Extract ---
    scales = get_scales_for_latent(C, H, W)
    print(f" Scales: {scales}")
    config = ExtractionConfig(
        confidence_threshold=0.6,
        min_occupancy=0.01,
        image_batch_size=32,
    )
    config.scales = scales
    extractor = MultiScaleExtractor(model, config)
    IMG_BATCH = config.image_batch_size

    # Aggregate class counts and confidences while streaming, instead of
    # materialising one dict per annotation (only class and confidence were
    # ever read back from them).
    cls_counts = Counter()
    confs = []
    vae_records = []
    for batch_start in tqdm(range(0, len(latent_paths), IMG_BATCH),
                            desc=f"{vname} extract", unit=f"×{IMG_BATCH}"):
        bp = latent_paths[batch_start:batch_start + IMG_BATCH]
        latents = [torch.load(p).to(device) for p in bp]
        batch_results = extractor.extract_batch(latents, channel_groups)
        del latents
        for p, result in zip(bp, batch_results):
            raw = result['raw_annotations']
            dev = result['deviance_annotations']
            all_anns = raw + dev
            rec = {
                'name': Path(p).stem,
                'n_total': len(all_anns),
                'n_raw': len(raw),
                'n_deviance': len(dev),
                'classes': Counter(a.class_name for a in all_anns),
                'confidences': [a.confidence for a in all_anns],
                'scales': Counter(a.scale_level for a in all_anns),
                'dimensions': Counter(a.dimension for a in all_anns),
                'curved': sum(1 for a in all_anns if a.is_curved),
            }
            vae_records.append(rec)
            cls_counts.update(rec['classes'])
            confs.extend(rec['confidences'])

    # --- Summarise ---
    total = len(confs)  # one confidence per annotation
    all_vae_results[vname] = {
        'latent_shape': (C, H, W),
        'n_images': len(vae_records),
        'total_annotations': total,
        'class_counts': cls_counts,
        'mean_confidence': float(np.mean(confs)) if confs else 0,
        'std_confidence': float(np.std(confs)) if confs else 0,
        'records': vae_records,
        'channel_groups': channel_groups,
        'scales': scales,
    }
    if confs:
        print(f" {vname}: {total:,} annotations, "
              f"conf={all_vae_results[vname]['mean_confidence']:.3f}")
    else:
        print(f" {vname}: 0 annotations (check latent stats above)")
    for cls, cnt in cls_counts.most_common(5):
        print(f" {cls:20s} {cnt:>10,} ({cnt / max(total, 1) * 100:5.1f}%)")
    del vae_records
    torch.cuda.empty_cache()
# =============================================================================
# COMPARATIVE ANALYSIS
# =============================================================================
print("\n" + "=" * 70)
print("COMPARATIVE ANALYSIS: VAE Geometric Structures")
print("=" * 70)
vae_names = list(all_vae_results)

# --- Table: Overview ---
print(f"\n{'─' * 70}")
print(f" {'VAE':12s} {'Latent':>14s} {'Ann':>12s} {'Ann/img':>8s} "
      f"{'Conf':>6s} {'Classes':>7s}")
print(f"{'─' * 70}")
for vn in vae_names:
    r = all_vae_results[vn]
    sh = "×".join(str(d) for d in r['latent_shape'])
    # Classes column: distinct classes seen / size of the classifier label set
    # (taken from CLASS_NAMES instead of a hard-coded count).
    cls_col = f"{len(r['class_counts'])}/{len(CLASS_NAMES)}"
    ann_per = r['total_annotations'] / max(r['n_images'], 1)
    print(f" {vn:12s} {sh:>14s} {r['total_annotations']:>12,} {ann_per:>8.0f} "
          f"{r['mean_confidence']:>6.3f} {cls_col:>7s}")
# --- Table: Top-5 classes per VAE ---
# Each VAE's five most frequent classes, as % of that VAE's annotations.
print(f"\n{'─' * 70}")
print("TOP-5 CLASSES PER VAE")
print(f"{'─' * 70}")
for vn in vae_names:
    r = all_vae_results[vn]
    denom = max(r['total_annotations'], 1)  # avoid division by zero
    classes_str = " ".join(
        f"{c}:{cnt / denom * 100:.0f}%" for c, cnt in r['class_counts'].most_common(5))
    print(f" {vn:12s} {classes_str}")
# --- Class presence heatmap (which classes appear in which VAEs) ---
print(f"\n{'─' * 70}")
print("CLASS PRESENCE ACROSS VAEs (>0.5% of annotations)")
print(f"{'─' * 70}")
# Rows are ordered by total count summed over all VAEs (most frequent first).
class_totals = Counter()
for vn in vae_names:
    class_totals.update(all_vae_results[vn]['class_counts'])
all_classes_seen = set(class_totals)  # every class observed in any VAE

# Every VAE column is 11 characters: one space plus a 10-character field.
print(f" {'Class':20s}" + "".join(f" {vn:>10s}" for vn in vae_names))
for cls, _ in class_totals.most_common():
    pcts = [
        all_vae_results[vn]['class_counts'].get(cls, 0)
        / max(all_vae_results[vn]['total_annotations'], 1) * 100
        for vn in vae_names
    ]
    # Show a row only if the class reaches 0.5% in at least one VAE.
    if not any(pct >= 0.5 for pct in pcts):
        continue
    cells = []
    for pct in pcts:
        if pct >= 0.5:
            cells.append(f" {pct:>9.1f}%")
        elif pct > 0:
            cells.append(f" {'trace':>10s}")
        else:
            cells.append(f" {'—':>10s}")
    print(f" {cls:20s}" + "".join(cells))
# --- Geometric fingerprint comparison ---
# Cosine similarity between the VAEs' overall class-frequency vectors.
print(f"\n{'─' * 70}")
print("GEOMETRIC FINGERPRINT SIMILARITY (cosine between class distributions)")
print(f"{'─' * 70}")
if len(vae_names) >= 2:
    class_index = {c: i for i, c in enumerate(CLASS_NAMES)}  # O(1) lookups
    vecs = {}
    for vn in vae_names:
        r = all_vae_results[vn]
        total = max(r['total_annotations'], 1)
        vec = np.zeros(len(CLASS_NAMES))
        for cls, cnt in r['class_counts'].items():
            vec[class_index[cls]] = cnt / total
        vecs[vn] = vec

    print(f" {'':12s}" + "".join(f" {vn:>10s}" for vn in vae_names))
    for vn1 in vae_names:
        row = []
        for vn2 in vae_names:
            v1, v2 = vecs[vn1], vecs[vn2]
            n1, n2 = np.linalg.norm(v1), np.linalg.norm(v2)
            # A VAE with no annotations has a zero vector: report 0.
            sim = np.dot(v1, v2) / (n1 * n2) if n1 > 0 and n2 > 0 else 0
            row.append(f" {sim:>10.3f}")
        print(f" {vn1:12s}" + "".join(row))
else:
    print(" (needs at least 2 VAEs)")
# --- Dimensional comparison ---
# Share of annotations with dimension 0-3, plus the share flagged curved,
# summed over each VAE's per-image records.
print(f"\n{'─' * 70}")
print("DIMENSIONAL DISTRIBUTION")
print(f"{'─' * 70}")
print(f" {'VAE':12s} {'0D':>8s} {'1D':>8s} {'2D':>8s} {'3D':>8s} {'Curved':>8s}")
for vn in vae_names:
    r = all_vae_results[vn]
    total = max(r['total_annotations'], 1)
    dim_c = Counter()
    for rec in r['records']:
        dim_c.update(rec['dimensions'])
    curved_n = sum(rec['curved'] for rec in r['records'])
    cells = "".join(f" {dim_c.get(d, 0) / total * 100:>7.1f}%" for d in range(4))
    print(f" {vn:12s}{cells} {curved_n / total * 100:>7.1f}%")
# --- Channel group comparison ---
# Channel groups found by cluster_channels_gpu for each VAE, with its channel count.
print(f"\n{'─' * 70}")
print("CHANNEL GROUPS")
print(f"{'─' * 70}")
for vn in vae_names:
    r = all_vae_results[vn]
    print(f" {vn:12s} ({r['latent_shape'][0]}ch): {r['channel_groups']}")
# --- Per-image cross-VAE consistency ---
# For each pair of VAEs, cosine similarity between the class distribution
# each VAE produced for the *same* image, over images all VAEs processed.
if len(vae_names) >= 2:
    print(f"\n{'─' * 70}")
    print("PER-IMAGE CROSS-VAE CONSISTENCY")
    print(f"{'─' * 70}")
    print(" Does the same image get a similar class distribution from each VAE?")
    class_index = {c: i for i, c in enumerate(CLASS_NAMES)}  # O(1) lookups

    # Per VAE: image name -> class-frequency vector (sums to 1, or all zeros
    # when the image received no annotations).
    per_vae_vectors = {}
    for vn in vae_names:
        name_to_vec = {}
        for rec in all_vae_results[vn]['records']:
            vec = np.zeros(len(CLASS_NAMES))
            total = max(rec['n_total'], 1)
            for cls, cnt in rec['classes'].items():
                vec[class_index[cls]] = cnt / total
            name_to_vec[rec['name']] = vec
        per_vae_vectors[vn] = name_to_vec

    # Only images present for every VAE are comparable; cap at 200 for speed.
    common_names = set.intersection(*(set(m) for m in per_vae_vectors.values()))
    common_names = sorted(common_names)[:200]

    if len(common_names) < 10:
        print(f" Only {len(common_names)} images shared by all VAEs "
              f"(need at least 10); skipped")
    else:
        for vn1_idx, vn1 in enumerate(vae_names):
            for vn2 in vae_names[vn1_idx + 1:]:
                v1_mat = np.stack([per_vae_vectors[vn1][n] for n in common_names])
                v2_mat = np.stack([per_vae_vectors[vn2][n] for n in common_names])
                # Row-wise cosine; clipping the norms keeps all-zero rows at 0.
                norms1 = np.clip(np.linalg.norm(v1_mat, axis=1, keepdims=True), 1e-8, None)
                norms2 = np.clip(np.linalg.norm(v2_mat, axis=1, keepdims=True), 1e-8, None)
                cos = np.sum((v1_mat / norms1) * (v2_mat / norms2), axis=1)
                mean_cos = cos.mean()
                print(f" {vn1:12s} ↔ {vn2:12s}: "
                      f"mean={mean_cos:.3f} std={cos.std():.3f} "
                      f"[{cos.min():.3f}, {cos.max():.3f}]")
                if mean_cos > 0.9:
                    verdict = "Same geometric structure"
                elif mean_cos > 0.7:
                    verdict = "Similar structure"
                elif mean_cos > 0.4:
                    verdict = "Different structures"
                else:
                    verdict = "Very different geometric encoding"
                print(f" → {verdict}")
# === Save =====================================================================
def _json_default(obj):
    """json.dump fallback: convert numpy/torch scalars and arrays to Python values.

    The element type of channel_groups (from cluster_channels_gpu) is not
    visible in this cell. If it holds numpy ints or tensors, a plain
    json.dump would raise TypeError after the whole pipeline has run.
    """
    if hasattr(obj, 'tolist'):  # np.ndarray, np.generic, torch.Tensor
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Summary per VAE (per-image records are not persisted).
save_data = {}
for vn in vae_names:
    r = all_vae_results[vn]
    save_data[vn] = {
        'latent_shape': r['latent_shape'],
        'n_images': r['n_images'],
        'total_annotations': r['total_annotations'],
        'class_counts': dict(r['class_counts']),
        'mean_confidence': r['mean_confidence'],
        'std_confidence': r['std_confidence'],
        'scales': [list(s) for s in r['scales']],
        'channel_groups': r['channel_groups'],
    }
out_path = '/content/multi_vae_comparison.json'
with open(out_path, 'w', encoding='utf-8') as f:
    json.dump(save_data, f, indent=2, default=_json_default)
print(f"\nSaved to {out_path}")
print("=" * 70)
print("✓ Multi-VAE comparison complete!")
print("=" * 70)