# Source header (Hugging Face): uploaded by AbstractPhil —
# commit "Update cell4_vae_pipeline.py" (bc985f5, verified).
"""
Cell 4: Multi-Scale Geometric Extraction (Fully Batched)
=========================================================
Optimizations:
- Multi-image batching: N images → single mega classify call
- Fused raw + deviance extraction per image
- GPU-only channel clustering (no numpy round-trip)
- torch.kthvalue replaces torch.quantile
- No torch.cuda.empty_cache() in hot path
- All GPU-resident until annotation construction
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
import math
@dataclass
class ExtractionConfig:
    """Tunable knobs for multi-scale patch extraction and classification."""
    # Target (D, H, W) every surviving patch is resized to before classification.
    canonical_shape: Tuple[int, int, int] = (8, 16, 16)
    # Pyramid of (depth, height, width) patch sizes scanned over each volume.
    scales: List[Tuple[int, int, int]] = field(default_factory=lambda: [
        (16, 64, 64),
        (8, 32, 32),
        (8, 16, 16),
        (4, 8, 8),
    ])
    overlap: float = 0.5  # fraction of patch size shared between neighboring windows
    confidence_threshold: float = 0.7  # min top1-top2 softmax margin to keep a hit
    min_occupancy: float = 0.01  # drop patches whose active fraction is below this
    binarize_percentile: float = 90.0  # magnitudes above this percentile become 1
    n_channel_groups: int = 8  # number of channel clusters for deviance volumes
    max_classify_batch: int = 16384  # chunk size per classifier forward pass
    image_batch_size: int = 32  # process N images simultaneously
    device: str = 'cuda'
@dataclass
class GeometricAnnotation:
    """One classified geometric structure found in a latent volume."""
    class_name: str  # human-readable label looked up from CLASS_NAMES
    class_idx: int  # index of class_name in CLASS_NAMES
    confidence: float  # top1 - top2 softmax probability margin
    scale_level: int  # index into ExtractionConfig.scales
    location: Tuple[int, int, int]  # (z, y, x) patch origin in the source volume
    patch_size: Tuple[int, int, int]  # (pz, py, px) size of the extracted patch
    dimension: int = -1  # predicted intrinsic dimension (-1 = unset)
    is_curved: bool = False
    curvature_type: str = "none"
    source: str = "raw"  # "raw" or "deviance"
    # For deviance hits: which pair of channel groups produced the slice.
    channel_group_pair: Optional[Tuple[int, int]] = None
# === GPU Primitives ===========================================================
def extract_patches_gpu(volume, patch_size, overlap=0.5):
    """Vectorized sliding-window patch extraction, entirely on-device.

    Args:
        volume: (D, H, W) tensor.
        patch_size: (pz, py, px) window dimensions.
        overlap: fraction of each patch shared with its neighbor; the stride
            per axis is ``patch_dim * (1 - overlap)``, floored at 1.

    Returns:
        patches: (N, pz, py, px) tensor of windows gathered by one advanced
            index (no per-patch Python loop).
        locs: (N, 3) tensor of (z, y, x) window origins.
    """
    D, H, W = volume.shape
    pz, py, px = patch_size
    dev = volume.device
    # Zero-pad any axis smaller than the patch so at least one window fits.
    if D < pz or H < py or W < px:
        volume = F.pad(volume, (0, max(px-W, 0), 0, max(py-H, 0), 0, max(pz-D, 0)))
        D, H, W = volume.shape
    sz = max(1, int(pz * (1 - overlap)))
    sy = max(1, int(py * (1 - overlap)))
    sx = max(1, int(px * (1 - overlap)))
    # Each arange starts at 0 with end >= 1 and step >= 1, so it is always
    # non-empty — the original emptiness fallbacks were unreachable dead code.
    z_s = torch.arange(0, max(1, D - pz + 1), sz, device=dev)
    y_s = torch.arange(0, max(1, H - py + 1), sy, device=dev)
    x_s = torch.arange(0, max(1, W - px + 1), sx, device=dev)
    gz, gy, gx = torch.meshgrid(z_s, y_s, x_s, indexing='ij')
    locs = torch.stack([gz.flatten(), gy.flatten(), gx.flatten()], dim=1)
    N = locs.shape[0]
    # Index grids: window origin + within-patch offset, broadcast to
    # (N, pz, py, px) so a single fancy-index gathers every patch at once.
    oz = torch.arange(pz, device=dev)
    oy = torch.arange(py, device=dev)
    ox = torch.arange(px, device=dev)
    z_idx = (locs[:, 0:1] + oz.unsqueeze(0))[:, :, None, None].expand(N, pz, py, px)
    y_idx = (locs[:, 1:2] + oy.unsqueeze(0))[:, None, :, None].expand(N, pz, py, px)
    x_idx = (locs[:, 2:3] + ox.unsqueeze(0))[:, None, None, :].expand(N, pz, py, px)
    return volume[z_idx, y_idx, x_idx], locs
def binarize_fast(patches, percentile=90.0, min_occ=0.01):
    """Per-patch magnitude thresholding via kthvalue (cheaper than quantile).

    Marks the top ``100 - percentile`` percent of |values| in every patch as
    1.0, then drops patches whose active fraction lies outside
    ``[min_occ, 0.95]`` (near-empty or near-saturated masks carry no shape).

    Returns:
        binary: (M, *patch_shape) float mask for the M surviving patches.
        keep_idx: (M,) indices of the survivors into the input batch.
    """
    n_patches = patches.shape[0]
    n_vox = patches[0].numel()
    mags = patches.reshape(n_patches, n_vox).abs()
    top_k = max(1, int(n_vox * (1.0 - percentile / 100.0)))
    # The top_k-th largest magnitude is the (n_vox - top_k + 1)-th smallest.
    thresh = mags.kthvalue(n_vox - top_k + 1, dim=1, keepdim=True).values
    mask = (mags >= thresh).float()
    occupancy = mask.mean(dim=1)
    survivors = ((occupancy >= min_occ) & (occupancy <= 0.95)).nonzero(as_tuple=True)[0]
    return mask.reshape(n_patches, *patches.shape[1:])[survivors], survivors
def extract_and_prepare_volume(volume, config):
    """Run the full per-volume pipeline across every configured scale.

    For each scale: sliding-window extraction, percentile binarization with
    occupancy filtering, then trilinear resize to ``config.canonical_shape``.
    Everything stays on the volume's device.

    Returns:
        (patches, meta): patches is the (N_total, *canonical) tensor of all
        survivors concatenated across scales; meta is a list of
        (scale_level, kept_locations, scale, count) tuples, one per
        contributing scale. Returns (None, []) when nothing survives.
    """
    target = config.canonical_shape
    batches = []
    meta = []
    for level, scale in enumerate(config.scales):
        pz, py, px = scale
        depth, height, width = volume.shape
        # Skip scales that do not fit inside the volume at all.
        if depth < pz or height < py or width < px:
            continue
        raw, locations = extract_patches_gpu(volume, scale, config.overlap)
        masks, survivors = binarize_fast(
            raw, config.binarize_percentile, config.min_occupancy)
        if masks.shape[0] == 0:
            continue
        if masks.shape[1:] == tuple(target):
            canon = masks
        else:
            canon = F.interpolate(
                masks.unsqueeze(1), size=target,
                mode='trilinear', align_corners=False).squeeze(1)
        batches.append(canon)
        meta.append((level, locations[survivors], scale, canon.shape[0]))
    if not batches:
        return None, []
    return torch.cat(batches, dim=0), meta
# === GPU Channel Clustering ===================================================
def cluster_channels_gpu(latents, n_groups=8):
    """Partition latent channels into correlation-based groups, GPU-only.

    Computes a batch-mean channel correlation matrix, then greedily grows up
    to ``n_groups`` clusters: each cluster is seeded with the channel
    farthest from all existing clusters and filled with the seed's nearest
    unassigned neighbors. Any leftover channels join their closest cluster.

    Returns:
        (groups, corr): list of channel-index lists, and the (C, C) mean
        correlation matrix used to build them.
    """
    batch, n_ch, _, _ = latents.shape
    dev = latents.device
    # Center and L2-normalize each channel signal, then average the
    # channel-by-channel correlation over the batch.
    signals = latents.reshape(batch, n_ch, -1)
    signals = F.normalize(signals - signals.mean(dim=-1, keepdim=True), dim=-1)
    corr = torch.bmm(signals, signals.transpose(1, 2)).mean(dim=0)
    dist = 1.0 - corr.abs()
    unassigned = torch.ones(n_ch, dtype=torch.bool, device=dev)
    fill = max(1, n_ch // n_groups)
    groups = []
    for gi in range(n_groups):
        if not unassigned.any():
            break
        if gi == 0:
            # First seed: the lowest-index unassigned channel.
            seed = unassigned.nonzero(as_tuple=True)[0][0].item()
        else:
            # Later seeds: channel maximizing distance to all current groups.
            nearest_grp = torch.full((n_ch,), float('inf'), device=dev)
            for members in groups:
                m = torch.tensor(members, device=dev)
                nearest_grp = torch.min(nearest_grp, dist[:, m].min(dim=1).values)
            nearest_grp[~unassigned] = -1
            seed = nearest_grp.argmax().item()
        members = [seed]
        unassigned[seed] = False
        seed_dist = dist[seed].clone()
        seed_dist[~unassigned] = float('inf')
        want = min(fill - 1, unassigned.sum().item())
        _, picks = seed_dist.topk(want, largest=False)
        picks = picks[seed_dist[picks] < float('inf')]
        for ch in picks.tolist():
            members.append(ch)
            unassigned[ch] = False
        groups.append(members)
    # Attach any stragglers to their nearest existing group.
    for ch in unassigned.nonzero(as_tuple=True)[0].tolist():
        best = min(
            range(len(groups)),
            key=lambda i: dist[ch, torch.tensor(groups[i], device=dev)].min().item())
        groups[best].append(ch)
    return groups, corr
def compute_deviance_volume(latent, groups):
    """Build the inter-group deviance stack for one latent.

    For every unordered pair (i, j) of channel groups, computes the absolute
    difference of the two groups' mean feature maps.

    Returns:
        deviances: (n_pairs, H, W) tensor, one slice per group pair.
        pairs: list of (i, j) group-index tuples aligned with the slices.
    """
    means = torch.stack([latent[g].mean(dim=0) for g in groups])
    n_grp = len(groups)
    rows, cols = torch.triu_indices(n_grp, n_grp, offset=1, device=latent.device)
    stack = (means[rows] - means[cols]).abs()
    return stack, list(zip(rows.cpu().tolist(), cols.cpu().tolist()))
# === Batched Multi-Image Extractor ============================================
class MultiScaleExtractor:
    """Runs the geometry classifier over multi-scale patches of latent batches.

    Patches from every image, every scale, and both sources (raw channels-as-
    depth volume and inter-group deviance volume) are fused into a single
    mega-batch so the classifier forward runs once per ``extract_batch`` call.
    """

    def __init__(self, classifier, config=None):
        self.classifier = classifier
        self.config = config or ExtractionConfig()
        self.classifier.eval()
        self.device = next(classifier.parameters()).device
        # FIX: pick the autocast dtype from the classifier's actual device.
        # The original unconditionally called torch.cuda.is_bf16_supported()
        # and autocast('cuda'), which breaks on CPU-only runs even though the
        # device is configurable. CUDA behavior is unchanged.
        if self.device.type == 'cuda':
            self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            self.amp_dtype = torch.bfloat16  # CPU autocast supports bfloat16

    @torch.no_grad()
    def classify_batch(self, patches):
        """Classify a mega-batch in chunks under autocast.

        Returns a dict of concatenated per-patch predictions
        (pred_class, confidence, dim_pred, curved_pred, curv_type_pred),
        or None when the batch is empty.
        """
        N = patches.shape[0]
        if N == 0:
            return None
        max_b = self.config.max_classify_batch
        all_cls, all_conf, all_dim, all_curved, all_curv_type = [], [], [], [], []
        for start in range(0, N, max_b):
            chunk = patches[start:start+max_b]
            with torch.amp.autocast(self.device.type, dtype=self.amp_dtype):
                out = self.classifier(chunk)
            probs = F.softmax(out["class_logits"].float(), dim=-1)
            top2 = probs.topk(2, dim=-1).values
            all_cls.append(probs.argmax(dim=-1))
            # Confidence is the top-1 minus top-2 probability margin.
            all_conf.append(top2[:, 0] - top2[:, 1])
            all_dim.append(out["dim_logits"].argmax(dim=-1))
            all_curved.append(out["is_curved_pred"].squeeze(-1) > 0.0)
            all_curv_type.append(out["curv_type_logits"].argmax(dim=-1))
        return {
            "pred_class": torch.cat(all_cls),
            "confidence": torch.cat(all_conf),
            "dim_pred": torch.cat(all_dim),
            "curved_pred": torch.cat(all_curved),
            "curv_type_pred": torch.cat(all_curv_type),
        }

    def _results_to_annotations(self, results, meta, conf_thresh, source="raw", pair_indices=None):
        """Convert one segment's batched predictions into annotation objects.

        ``meta`` is the bookkeeping from extract_and_prepare_volume:
        (scale_level, kept_locations, scale, count) per scale, with counts
        summing to this segment's length in ``results``.
        """
        annotations = []
        offset = 0
        for level, kept_locs, scale, count in meta:
            chunk_conf = results["confidence"][offset:offset+count]
            mask = chunk_conf >= conf_thresh
            local_idx = mask.nonzero(as_tuple=True)[0]
            if len(local_idx) > 0:
                gi = local_idx + offset
                # One device->host copy per field per scale, not per annotation.
                cls = results["pred_class"][gi].cpu()
                conf = results["confidence"][gi].cpu()
                dim = results["dim_pred"][gi].cpu()
                curved = results["curved_pred"][gi].cpu()
                curv = results["curv_type_pred"][gi].cpu()
                locs = kept_locs[local_idx].cpu()
                for i in range(len(local_idx)):
                    ann = GeometricAnnotation(
                        class_name=CLASS_NAMES[cls[i].item()],
                        class_idx=cls[i].item(),
                        confidence=conf[i].item(),
                        scale_level=level,
                        location=tuple(int(x) for x in locs[i].tolist()),
                        patch_size=scale,
                        dimension=dim[i].item(),
                        is_curved=bool(curved[i].item()),
                        curvature_type=CURVATURE_NAMES[curv[i].item()],
                        source=source,
                    )
                    if source == "deviance" and pair_indices is not None:
                        # In the deviance stack, pairs lie along the depth
                        # axis, so the patch's z-origin indexes the pair list.
                        pair_idx = locs[i][0].item()
                        if pair_idx < len(pair_indices):
                            ann.channel_group_pair = pair_indices[pair_idx]
                    annotations.append(ann)
            offset += count
        return annotations

    def extract_batch(self, latents, channel_groups):
        """Process multiple latents simultaneously.

        Args:
            latents: list of (C, H, W) tensors on the classifier's device.
            channel_groups: channel clusters from cluster_channels_gpu, or
                None to skip deviance extraction.

        Returns:
            list of per-image dicts with keys raw_annotations,
            deviance_annotations, n_raw, n_deviance.
        """
        conf_thresh = self.config.confidence_threshold
        # Phase 1: extract patches from ALL images, both raw + deviance.
        all_patches = []
        image_segments = []  # (img_idx, source, meta, pair_indices, patch_count)
        for img_idx, latent in enumerate(latents):
            # Raw volume: channels treated as the depth axis.
            raw_patches, raw_meta = extract_and_prepare_volume(latent, self.config)
            if raw_patches is not None:
                n = raw_patches.shape[0]
                all_patches.append(raw_patches)
                image_segments.append((img_idx, "raw", raw_meta, None, n))
            # Deviance volume: one slice per channel-group pair.
            if channel_groups is not None:
                dev_vol, pair_indices = compute_deviance_volume(latent, channel_groups)
                dev_patches, dev_meta = extract_and_prepare_volume(dev_vol, self.config)
                if dev_patches is not None:
                    n = dev_patches.shape[0]
                    all_patches.append(dev_patches)
                    image_segments.append((img_idx, "deviance", dev_meta, pair_indices, n))
        if not all_patches:
            return [{
                'raw_annotations': [], 'deviance_annotations': [],
                'n_raw': 0, 'n_deviance': 0,
            } for _ in latents]
        # Phase 2: SINGLE classify call for ALL images x ALL scales x raw+deviance.
        mega_batch = torch.cat(all_patches, dim=0)
        del all_patches
        results = self.classify_batch(mega_batch)
        del mega_batch
        if results is None:
            return [{
                'raw_annotations': [], 'deviance_annotations': [],
                'n_raw': 0, 'n_deviance': 0,
            } for _ in latents]
        # Phase 3: distribute results back to per-image annotations.
        per_image = {i: {'raw': [], 'deviance': []} for i in range(len(latents))}
        global_offset = 0
        for img_idx, source, meta, pair_indices, total_count in image_segments:
            # Slice this segment's span out of the concatenated results.
            seg_results = {
                k: v[global_offset:global_offset+total_count]
                for k, v in results.items()
            }
            anns = self._results_to_annotations(
                seg_results, meta, conf_thresh, source, pair_indices)
            per_image[img_idx][source].extend(anns)
            global_offset += total_count
        del results
        # Build output in input order.
        output = []
        for i in range(len(latents)):
            raw = per_image[i]['raw']
            dev = per_image[i]['deviance']
            output.append({
                'raw_annotations': raw,
                'deviance_annotations': dev,
                'n_raw': len(raw),
                'n_deviance': len(dev),
            })
        return output

    # Single-image compat wrapper.
    def extract_from_latent(self, latent, channel_groups=None):
        return self.extract_batch([latent], channel_groups)[0]
# Cell banner: report the active extraction defaults.
# One shared instance instead of five redundant ExtractionConfig() constructions.
_cfg = ExtractionConfig()
print("✓ Cell 4: Fully batched multi-image extraction")
print(f" Scales: {_cfg.scales}")
print(f" Canonical: {_cfg.canonical_shape}")
print(f" Image batch: {_cfg.image_batch_size}")
print(f" Classify batch: {_cfg.max_classify_batch}")
print(f" Percentile: {_cfg.binarize_percentile}th (kthvalue)")