| | """ |
| | Cell 5: Multi-VAE Geometric Comparison |
| | ======================================== |
| | Run after Cells 1-4. Reuses existing pipeline. |
| | |
| | Processes 4 VAEs sequentially: |
| | SD 1.5 β 4ch Γ 64Γ64 (512px input) |
| | SDXL β 4ch Γ 128Γ128 (1024px input) |
| | Flux.1 β 16ch Γ 128Γ128 (1024px input) |
| | Flux.2 β 16ch Γ 128Γ128 (1024px input) |
| | |
| | Each: load VAE β encode β free β cluster β extract β store |
| | Then: comparative diagnostics |
| | """ |
| |
|
| | import os, time, json, zipfile, math |
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from pathlib import Path |
| | from tqdm.auto import tqdm |
| | from collections import Counter |
| |
|
| | |
| | print("=" * 70) |
| | print("Multi-VAE Geometric Comparison Pipeline") |
| | print("=" * 70) |
| |
|
| | IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.webp', '.bmp'} |
| | HF_REPO = "AbstractPhil/grid-geometric-classifier-sliding-proto" |
| | HF_ZIP = "mega_liminal_captioned.zip" |
| | SKIP_DIRS = {'.config', 'sample_data', 'drive', '__pycache__', |
| | 'checkpoints_vae_ca', |
| | 'latent_cache_sd15', 'latent_cache_sdxl', |
| | 'latent_cache_flux1', 'latent_cache_flux2', |
| | 'latent_cache_mega_sd15', 'latent_cache_mega_sdxl', |
| | 'latent_cache_mega_flux1', 'latent_cache_mega_flux2'} |
| |
|
def find_images(directory):
    """Recursively collect image file paths under *directory*, sorted.

    Returns an empty list when *directory* does not exist.
    """
    if not os.path.exists(directory):
        return []
    collected = []
    for root, _, filenames in os.walk(directory):
        for filename in filenames:
            if Path(filename).suffix.lower() in IMAGE_EXTENSIONS:
                collected.append(os.path.join(root, filename))
    collected.sort()
    return collected
| |
|
def scan_content_for_images(min_count=100, root='/content/'):
    """Scan *root* for the directory containing the most images.

    Generalized from a hard-coded ``/content/`` scan: *root* may now be
    overridden; the default preserves the original behavior.

    Args:
        min_count: minimum number of images required to accept a directory.
        root: directory whose immediate children are scanned.

    Returns:
        (best_dir, image_paths) for the largest qualifying directory, or
        (None, []) when no child of *root* holds at least *min_count* images.
    """
    best_dir, best_imgs = None, []
    for d in sorted(os.listdir(root)):
        full = os.path.join(root, d)
        # Skip known system/cache directories so we never "find" our own output.
        if os.path.isdir(full) and d not in SKIP_DIRS:
            found = find_images(full)
            if len(found) > len(best_imgs):
                best_dir, best_imgs = full, found
    if len(best_imgs) >= min_count:
        return best_dir, best_imgs
    return None, []
| |
|
# Locate the dataset: prefer images already present under /content/.
LIMINAL_DIR, image_paths = scan_content_for_images()

# Fallback: download and extract the captioned dataset zip from the HF Hub.
if LIMINAL_DIR is None:
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        # Colab-style best-effort install, then retry the import.
        os.system('pip install -q huggingface_hub')
        from huggingface_hub import hf_hub_download
    print(f"Downloading {HF_ZIP} from {HF_REPO}...")
    zip_path = hf_hub_download(repo_id=HF_REPO, filename=HF_ZIP)
    with zipfile.ZipFile(zip_path, 'r') as z:
        z.extractall('/content/')
    # Re-scan now that the archive contents are on disk.
    LIMINAL_DIR, image_paths = scan_content_for_images()

assert LIMINAL_DIR and len(image_paths) > 0, "No images found in /content/"
print(f"Found {len(image_paths)} images in {LIMINAL_DIR}")
| |
|
| | |
| | print("\n" + "=" * 70) |
| | print("Step 1: Classifier") |
| | print("=" * 70) |
| |
|
| | ckpt = '/content/best_vae_ca_classifier.pt' |
| | if not os.path.exists(ckpt): |
| | ckpt = '/content/checkpoints_vae_ca/best.pt' |
| |
|
| | model = PatchCrossAttentionClassifier(n_classes=NUM_CLASSES) |
| | model.load_state_dict(torch.load(ckpt, map_location='cpu', weights_only=True)) |
| | device = torch.device('cuda') |
| | model = model.to(device).eval() |
| | print(f"Loaded {sum(p.numel() for p in model.parameters()):,} params") |
| |
|
| | |
| |
|
| | try: |
| | from diffusers import AutoencoderKL |
| | except ImportError: |
| | os.system('pip install -q diffusers transformers accelerate') |
| | from diffusers import AutoencoderKL |
| |
|
| | from torchvision import transforms |
| | from PIL import Image |
| |
|
# One entry per VAE to benchmark. Each VAE is loaded, used to encode the full
# dataset into its own cache_dir, then freed before extraction.
#   model_id  : Hugging Face repo id
#   subfolder : repo subfolder holding the VAE (None = repo root)
#   input_res : square pixel resolution images are resized to before encoding
#   dtype     : compute dtype for the VAE forward pass
VAE_CONFIGS = [
    {
        'name': 'SD 1.5',
        'model_id': 'stable-diffusion-v1-5/stable-diffusion-v1-5',
        'subfolder': 'vae',
        'input_res': 512,
        'dtype': torch.float16,
        'cache_dir': '/content/latent_cache_mega_sd15',
    },
    {
        'name': 'SDXL',
        # fp16-safe SDXL VAE re-export (the stock SDXL VAE overflows in fp16).
        'model_id': 'madebyollin/sdxl-vae-fp16-fix',
        'subfolder': None,
        'input_res': 1024,
        'dtype': torch.float16,
        'cache_dir': '/content/latent_cache_mega_sdxl',
    },
    {
        'name': 'Flux.1',
        'model_id': 'black-forest-labs/FLUX.1-dev',
        'subfolder': 'vae',
        'input_res': 1024,
        'dtype': torch.bfloat16,
        'cache_dir': '/content/latent_cache_mega_flux1',
    },
    {
        'name': 'Flux.2',
        'model_id': 'black-forest-labs/FLUX.2-dev',
        'subfolder': 'vae',
        'input_res': 1024,
        'dtype': torch.bfloat16,
        'cache_dir': '/content/latent_cache_mega_flux2',
    },
]
| |
|
| |
|
def get_scales_for_latent(C, H, W):
    """Compute appropriate scales based on actual latent dimensions. No L3 (noise).

    Each scale is a (channels, height, width) tuple clamped to the latent's
    real dimensions. Fix: the finest scale previously hard-coded 16x16
    spatially, which could exceed H or W for small latents; it is now
    clamped like the other scales.

    Args:
        C, H, W: channel count and spatial size of the latent tensor.

    Returns:
        List of three (c, h, w) scale tuples, coarse to fine.
    """
    return [
        (min(C, 16), min(H, 64), min(W, 64)),
        (min(C, 8), min(H, 32), min(W, 32)),
        (min(C, 8), min(H, 16), min(W, 16)),
    ]
| |
|
| |
|
def encode_dataset(vae, image_paths, cache_dir, input_res, dtype, batch_size=4):
    """Encode images through *vae* and cache latents to disk as .pt files.

    Fix: previously an unreadable image was silently skipped but its cache
    path was still returned, so a later torch.load() on the missing file
    would crash. Failures are now reported, and the returned list contains
    only paths whose cache file actually exists.

    Args:
        vae: diffusers AutoencoderKL (already on `device`).
        image_paths: list of image file paths.
        cache_dir: directory for cached latents (created if missing).
        input_res: square resolution to resize images to.
        dtype: compute dtype for the VAE forward pass.
        batch_size: images encoded per forward pass.

    Returns:
        List of existing cache-file paths (one per successfully encoded image).
    """
    os.makedirs(cache_dir, exist_ok=True)

    # Map pixels to [-1, 1], the input range diffusion VAEs are trained on.
    transform = transforms.Compose([
        transforms.Resize((input_res, input_res)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3),
    ])

    latent_paths = []
    need_encode = []

    # Cache key is the image stem; only encode what is not already cached.
    for p in image_paths:
        name = Path(p).stem
        cp = os.path.join(cache_dir, f'{name}.pt')
        latent_paths.append(cp)
        if not os.path.exists(cp):
            need_encode.append((p, cp))

    if not need_encode:
        return latent_paths  # every path already exists on disk

    for i in tqdm(range(0, len(need_encode), batch_size), desc="Encoding", unit="batch"):
        batch_items = need_encode[i:i+batch_size]
        imgs = []
        for p, cp in batch_items:
            try:
                imgs.append((transform(Image.open(p).convert('RGB')), cp))
            except Exception as e:
                # Best-effort: skip unreadable images, but report instead of
                # silently swallowing the error.
                print(f"  ! skipping unreadable image {p}: {e}")

        if imgs:
            batch = torch.stack([im[0] for im in imgs]).to(device, dtype=dtype)
            with torch.no_grad():
                # Use the posterior mean (deterministic) rather than sampling.
                latents = vae.encode(batch).latent_dist.mean
            # Cache in fp32 on CPU so downstream loads are dtype-independent.
            for j, (_, cp) in enumerate(imgs):
                torch.save(latents[j].cpu().float(), cp)
            del batch, latents

    # Drop paths whose encode failed so callers only ever load real files.
    return [cp for cp in latent_paths if os.path.exists(cp)]
| |
|
| |
|
| | |
| |
|
# Per-VAE results keyed by display name; filled by the loop below.
all_vae_results = {}

for vconf in VAE_CONFIGS:
    vname = vconf['name']
    print("\n" + "=" * 70)
    print(f"Processing: {vname}")
    print(f" Model: {vconf['model_id']}")
    print("=" * 70)

    # --- Load the VAE; a failure (gated repo, OOM, ...) skips this VAE. ---
    try:
        load_kwargs = dict(torch_dtype=vconf['dtype'])
        if vconf['subfolder']:
            load_kwargs['subfolder'] = vconf['subfolder']
        vae = AutoencoderKL.from_pretrained(
            vconf['model_id'], **load_kwargs
        ).to(device).eval()
    except Exception as e:
        print(f" β Failed to load {vname}: {e}")
        print(f" Skipping...")
        continue

    # --- Encode the dataset into this VAE's latent cache. ---
    latent_paths = encode_dataset(
        vae, image_paths, vconf['cache_dir'],
        vconf['input_res'], vconf['dtype'])

    if not latent_paths:
        print(f" β No latents produced for {vname}, skipping")
        del vae
        torch.cuda.empty_cache()
        continue

    # Sanity-check one cached latent; if NaN (e.g. fp16 overflow), purge the
    # cache and re-encode once while the VAE is still loaded.
    sample = torch.load(latent_paths[0])
    if torch.isnan(sample).any():
        print(f" β NaN detected in cached latents β purging {vconf['cache_dir']}")
        import shutil
        shutil.rmtree(vconf['cache_dir'])
        latent_paths = encode_dataset(
            vae, image_paths, vconf['cache_dir'],
            vconf['input_res'], vconf['dtype'])
        sample = torch.load(latent_paths[0])

    C, H, W = sample.shape
    print(f" Latent: ({C}, {H}, {W}) "
          f"mean={sample.mean():.3f} std={sample.std():.3f} "
          f"[{sample.min():.3f}, {sample.max():.3f}]")
    del sample

    # Free the VAE before the GPU-heavy clustering/extraction phase.
    del vae
    torch.cuda.empty_cache()

    # --- Cluster channels on a subsample of cached latents. ---
    # cluster_channels_gpu is defined in an earlier cell; `corr` is unused here.
    N_CL = min(100, len(latent_paths))
    sample_batch = torch.stack([torch.load(latent_paths[i]) for i in range(N_CL)]).to(device)
    channel_groups, corr = cluster_channels_gpu(sample_batch, n_groups=min(8, C))
    print(f" Groups: {channel_groups}")
    del sample_batch

    # --- Configure multi-scale extraction for this latent geometry. ---
    scales = get_scales_for_latent(C, H, W)
    print(f" Scales: {scales}")

    # ExtractionConfig / MultiScaleExtractor come from earlier cells.
    config = ExtractionConfig(
        confidence_threshold=0.6,
        min_occupancy=0.01,
        image_batch_size=32,
    )
    config.scales = scales

    extractor = MultiScaleExtractor(model, config)

    IMG_BATCH = config.image_batch_size
    vae_annotations = []   # flat annotation dicts across all images
    vae_records = []       # one summary record per image

    # --- Extract annotations batch-by-batch over the cached latents. ---
    for batch_start in tqdm(range(0, len(latent_paths), IMG_BATCH),
                            desc=f"{vname} extract", unit=f"Γ{IMG_BATCH}"):
        bp = latent_paths[batch_start:batch_start + IMG_BATCH]
        names = [Path(p).stem for p in bp]
        latents = [torch.load(p).to(device) for p in bp]

        batch_results = extractor.extract_batch(latents, channel_groups)
        del latents

        for b_idx, result in enumerate(batch_results):
            # Two annotation streams per image: raw and deviance-based.
            raw = result['raw_annotations']
            dev = result['deviance_annotations']
            all_anns = raw + dev

            # Per-image summary used later for cross-VAE consistency.
            vae_records.append({
                'name': names[b_idx],
                'n_total': len(all_anns),
                'n_raw': len(raw),
                'n_deviance': len(dev),
                'classes': Counter(a.class_name for a in all_anns),
                'confidences': [a.confidence for a in all_anns],
                'scales': Counter(a.scale_level for a in all_anns),
                'dimensions': Counter(a.dimension for a in all_anns),
                'curved': sum(1 for a in all_anns if a.is_curved),
            })

            for a in all_anns:
                vae_annotations.append({
                    'class': a.class_name, 'confidence': a.confidence,
                    'scale': a.scale_level, 'dimension': a.dimension,
                    'curved': a.is_curved, 'curvature': a.curvature_type,
                    'source': a.source,
                })

    # --- Aggregate this VAE's statistics. ---
    total = len(vae_annotations)
    cls_counts = Counter(a['class'] for a in vae_annotations)
    confs = [a['confidence'] for a in vae_annotations]

    all_vae_results[vname] = {
        'latent_shape': (C, H, W),
        'n_images': len(vae_records),
        'total_annotations': total,
        'class_counts': cls_counts,
        'mean_confidence': float(np.mean(confs)) if confs else 0,
        'std_confidence': float(np.std(confs)) if confs else 0,
        'records': vae_records,
        'channel_groups': channel_groups,
        'scales': scales,
    }

    if confs:
        print(f" {vname}: {total:,} annotations, conf={np.mean(confs):.3f}")
    else:
        print(f" {vname}: 0 annotations (check latent stats above)")
    top5 = cls_counts.most_common(5)
    for cls, cnt in top5:
        print(f" {cls:20s} {cnt:>10,} ({cnt/max(total,1)*100:5.1f}%)")

    # Release per-VAE working memory before the next iteration.
    del vae_annotations, vae_records
    torch.cuda.empty_cache()
| |
|
| |
|
| | |
| | |
| | |
| | print("\n" + "=" * 70) |
| | print("COMPARATIVE ANALYSIS: VAE Geometric Structures") |
| | print("=" * 70) |
| |
|
| | vae_names = list(all_vae_results.keys()) |
| |
|
| | |
| | print(f"\n{'β'*70}") |
| | print(f" {'VAE':12s} {'Latent':>14s} {'Ann':>12s} {'Ann/img':>8s} " |
| | f"{'Conf':>6s} {'Classes':>7s}") |
| | print(f"{'β'*70}") |
| | for vn in vae_names: |
| | r = all_vae_results[vn] |
| | sh = f"{r['latent_shape'][0]}Γ{r['latent_shape'][1]}Γ{r['latent_shape'][2]}" |
| | n_cls = len(r['class_counts']) |
| | ann_per = r['total_annotations'] / max(r['n_images'], 1) |
| | print(f" {vn:12s} {sh:>14s} {r['total_annotations']:>12,} {ann_per:>8.0f} " |
| | f"{r['mean_confidence']:>6.3f} {n_cls:>4d}/38") |
| |
|
| | |
| | print(f"\n{'β'*70}") |
| | print("TOP-5 CLASSES PER VAE") |
| | print(f"{'β'*70}") |
| | for vn in vae_names: |
| | r = all_vae_results[vn] |
| | total = r['total_annotations'] |
| | top5 = r['class_counts'].most_common(5) |
| | classes_str = " ".join(f"{c}:{cnt/max(total,1)*100:.0f}%" for c, cnt in top5) |
| | print(f" {vn:12s} {classes_str}") |
| |
|
| | |
| | print(f"\n{'β'*70}") |
| | print("CLASS PRESENCE ACROSS VAEs (>0.5% of annotations)") |
| | print(f"{'β'*70}") |
| |
|
| | all_classes_seen = set() |
| | for vn in vae_names: |
| | for cls in all_vae_results[vn]['class_counts']: |
| | all_classes_seen.add(cls) |
| |
|
| | |
| | class_totals = Counter() |
| | for vn in vae_names: |
| | class_totals.update(all_vae_results[vn]['class_counts']) |
| |
|
| | print(f" {'Class':20s}", end="") |
| | for vn in vae_names: |
| | print(f" {vn:>10s}", end="") |
| | print() |
| |
|
| | for cls, _ in class_totals.most_common(): |
| | row_vals = [] |
| | any_significant = False |
| | for vn in vae_names: |
| | r = all_vae_results[vn] |
| | total = max(r['total_annotations'], 1) |
| | cnt = r['class_counts'].get(cls, 0) |
| | pct = cnt / total * 100 |
| | row_vals.append(pct) |
| | if pct >= 0.5: |
| | any_significant = True |
| |
|
| | if any_significant: |
| | print(f" {cls:20s}", end="") |
| | for pct in row_vals: |
| | if pct >= 5: |
| | print(f" {pct:>9.1f}%", end="") |
| | elif pct >= 0.5: |
| | print(f" {pct:>9.1f}%", end="") |
| | elif pct > 0: |
| | print(f" trace", end="") |
| | else: |
| | print(f" β", end="") |
| | print() |
| |
|
| | |
| | print(f"\n{'β'*70}") |
| | print("GEOMETRIC FINGERPRINT SIMILARITY (cosine between class distributions)") |
| | print(f"{'β'*70}") |
| |
|
| | if len(vae_names) >= 2: |
| | vecs = {} |
| | for vn in vae_names: |
| | r = all_vae_results[vn] |
| | total = max(r['total_annotations'], 1) |
| | vec = np.zeros(len(CLASS_NAMES)) |
| | for cls, cnt in r['class_counts'].items(): |
| | vec[CLASS_NAMES.index(cls)] = cnt / total |
| | vecs[vn] = vec |
| |
|
| | print(f" {'':12s}", end="") |
| | for vn in vae_names: |
| | print(f" {vn:>10s}", end="") |
| | print() |
| |
|
| | for vn1 in vae_names: |
| | print(f" {vn1:12s}", end="") |
| | for vn2 in vae_names: |
| | v1, v2 = vecs[vn1], vecs[vn2] |
| | n1, n2 = np.linalg.norm(v1), np.linalg.norm(v2) |
| | if n1 > 0 and n2 > 0: |
| | sim = np.dot(v1, v2) / (n1 * n2) |
| | else: |
| | sim = 0 |
| | print(f" {sim:>10.3f}", end="") |
| | print() |
| |
|
| | |
| | print(f"\n{'β'*70}") |
| | print("DIMENSIONAL DISTRIBUTION") |
| | print(f"{'β'*70}") |
| | print(f" {'VAE':12s} {'0D':>8s} {'1D':>8s} {'2D':>8s} {'3D':>8s} {'Curved':>8s}") |
| | for vn in vae_names: |
| | r = all_vae_results[vn] |
| | total = max(r['total_annotations'], 1) |
| | dim_c = Counter() |
| | curved_n = 0 |
| | for rec in r['records']: |
| | dim_c.update(rec['dimensions']) |
| | curved_n += rec['curved'] |
| | print(f" {vn:12s}", end="") |
| | for d in range(4): |
| | pct = dim_c.get(d, 0) / total * 100 |
| | print(f" {pct:>7.1f}%", end="") |
| | print(f" {curved_n/total*100:>7.1f}%") |
| |
|
| | |
| | print(f"\n{'β'*70}") |
| | print("CHANNEL GROUPS") |
| | print(f"{'β'*70}") |
| | for vn in vae_names: |
| | r = all_vae_results[vn] |
| | sh = r['latent_shape'] |
| | print(f" {vn:12s} ({sh[0]}ch): {r['channel_groups']}") |
| |
|
| | |
# --- Per-image consistency: does an image's class distribution in one VAE
#     predict its distribution in another? ---
if len(vae_names) >= 2:
    print(f"\n{'β'*70}")
    print("PER-IMAGE CROSS-VAE CONSISTENCY")
    print(f"{'β'*70}")
    print(" Do images that are geometrically distinct in one VAE stay distinct in another?")

    per_vae_vectors = {}   # vae name -> {image name -> class-frequency vector}
    common_names = None    # image names present in every VAE's records

    for vn in vae_names:
        r = all_vae_results[vn]
        name_to_vec = {}
        for rec in r['records']:
            # Normalized per-image class distribution over CLASS_NAMES
            # (defined in an earlier cell).
            vec = np.zeros(len(CLASS_NAMES))
            total = max(rec['n_total'], 1)
            for cls, cnt in rec['classes'].items():
                vec[CLASS_NAMES.index(cls)] = cnt / total
            name_to_vec[rec['name']] = vec

        per_vae_vectors[vn] = name_to_vec
        names_set = set(name_to_vec.keys())
        # Running intersection across VAEs.
        common_names = names_set if common_names is None else common_names & names_set

    # Cap at 200 images (sorted for determinism) to bound the comparison cost.
    common_names = sorted(common_names)[:200]

    if len(common_names) >= 10:
        # Pairwise over distinct VAE pairs only (upper triangle).
        for vn1_idx, vn1 in enumerate(vae_names):
            for vn2 in vae_names[vn1_idx+1:]:
                v1_mat = np.stack([per_vae_vectors[vn1][n] for n in common_names])
                v2_mat = np.stack([per_vae_vectors[vn2][n] for n in common_names])

                # Row-wise cosine similarity; clip norms to avoid divide-by-zero
                # for images with no annotations in one VAE.
                norms1 = np.linalg.norm(v1_mat, axis=1, keepdims=True)
                norms2 = np.linalg.norm(v2_mat, axis=1, keepdims=True)
                norms1 = np.clip(norms1, 1e-8, None)
                norms2 = np.clip(norms2, 1e-8, None)
                cos = np.sum((v1_mat / norms1) * (v2_mat / norms2), axis=1)

                print(f" {vn1:12s} β {vn2:12s}: "
                      f"mean={cos.mean():.3f} std={cos.std():.3f} "
                      f"[{cos.min():.3f}, {cos.max():.3f}]")
                # Qualitative interpretation thresholds on the mean cosine.
                if cos.mean() > 0.9:
                    print(f" β Same geometric structure")
                elif cos.mean() > 0.7:
                    print(f" β Similar structure")
                elif cos.mean() > 0.4:
                    print(f" β Different structures")
                else:
                    print(f" β Very different geometric encoding")
| |
|
| | |
# --- Persist a JSON-safe summary (drops per-image records; Counter -> dict). ---
save_data = {}
for vn in vae_names:
    r = all_vae_results[vn]
    save_data[vn] = {
        'latent_shape': r['latent_shape'],
        'n_images': r['n_images'],
        'total_annotations': r['total_annotations'],
        'class_counts': dict(r['class_counts']),
        'mean_confidence': r['mean_confidence'],
        'scales': [list(s) for s in r['scales']],
        # assumes cluster_channels_gpu returns JSON-serializable groups
        # (e.g. list of lists) — TODO confirm against its definition
        'channel_groups': r['channel_groups'],
    }

with open('/content/multi_vae_comparison_mega.json', 'w') as f:
    json.dump(save_data, f, indent=2)

print(f"\nSaved to /content/multi_vae_comparison_mega.json")
print("=" * 70)
print("β Multi-VAE comparison complete!")
print("=" * 70)