| """ |
| Cell 4: Multi-Scale Geometric Extraction (Fully Batched) |
| ========================================================= |
| Optimizations: |
| - Multi-image batching: N images → single mega classify call |
| - Fused raw + deviance extraction per image |
| - GPU-only channel clustering (no numpy round-trip) |
| - torch.kthvalue replaces torch.quantile |
| - No torch.cuda.empty_cache() in hot path |
| - All GPU-resident until annotation construction |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from dataclasses import dataclass, field |
| from typing import List, Optional, Tuple |
| import math |
|
|
|
|
@dataclass
class ExtractionConfig:
    """Configuration for multi-scale patch extraction and classification."""
    # Shape (D, H, W) every patch is resized to before classification.
    canonical_shape: Tuple[int, int, int] = (8, 16, 16)
    # Patch sizes (D, H, W) tried per scale level, coarsest first.
    scales: List[Tuple[int, int, int]] = field(default_factory=lambda: [
        (16, 64, 64),
        (8, 32, 32),
        (8, 16, 16),
        (4, 8, 8),
    ])
    # Fractional overlap between adjacent patches; stride = size * (1 - overlap).
    overlap: float = 0.5
    # Minimum top1-minus-top2 softmax margin for an annotation to be kept.
    confidence_threshold: float = 0.7
    # Lower occupancy bound applied after binarization (upper bound 0.95 is
    # hard-coded in binarize_fast).
    min_occupancy: float = 0.01
    # Percentile of per-patch |value| used as the binarization threshold.
    binarize_percentile: float = 90.0
    # Number of channel groups for deviance computation.
    n_channel_groups: int = 8
    # Max patches per classifier forward pass (classify_batch chunk size).
    max_classify_batch: int = 16384
    # NOTE(review): not read anywhere in this cell — presumably consumed by
    # an outer driver loop; confirm before removing.
    image_batch_size: int = 32
    device: str = 'cuda'
|
|
|
|
@dataclass
class GeometricAnnotation:
    """One classified geometric structure found in a latent volume."""
    class_name: str                   # label looked up via CLASS_NAMES[class_idx]
    class_idx: int                    # classifier argmax class index
    confidence: float                 # top-1 minus top-2 softmax probability
    scale_level: int                  # index into ExtractionConfig.scales
    location: Tuple[int, int, int]    # (z, y, x) patch origin in the source volume
    patch_size: Tuple[int, int, int]  # patch extent at the original (pre-resize) scale
    dimension: int = -1               # predicted dimension class (-1 = unset)
    is_curved: bool = False           # classifier's curvature flag (logit > 0)
    curvature_type: str = "none"      # label looked up via CURVATURE_NAMES
    source: str = "raw"               # "raw" channels or "deviance" map
    # For deviance annotations only: the (i, j) channel-group pair whose
    # mean-difference map produced this patch.
    channel_group_pair: Optional[Tuple[int, int]] = None
|
|
|
|
| |
|
|
def extract_patches_gpu(volume, patch_size, overlap=0.5):
    """Vectorized sliding-window patch extraction from a 3-D volume.

    Args:
        volume: (D, H, W) tensor (any device; everything stays on it).
        patch_size: (pz, py, px) patch dimensions.
        overlap: fractional overlap between adjacent patches in [0, 1);
            the stride along each axis is max(1, int(size * (1 - overlap))).

    Returns:
        patches: (N, pz, py, px) tensor gathered via advanced indexing
            (a new contiguous tensor, not views).
        locs: (N, 3) int tensor of (z, y, x) patch origins.
    """
    D, H, W = volume.shape
    pz, py, px = patch_size
    dev = volume.device

    # Zero-pad so at least one full patch fits along every axis.
    if D < pz or H < py or W < px:
        volume = F.pad(volume, (0, max(px - W, 0), 0, max(py - H, 0), 0, max(pz - D, 0)))
        D, H, W = volume.shape

    sz = max(1, int(pz * (1 - overlap)))
    sy = max(1, int(py * (1 - overlap)))
    sx = max(1, int(px * (1 - overlap)))

    # After padding, D >= pz (and likewise for H, W), so D - pz + 1 >= 1 and
    # each arange is guaranteed non-empty — the old empty-range fallbacks
    # were unreachable and have been removed.
    z_s = torch.arange(0, D - pz + 1, sz, device=dev)
    y_s = torch.arange(0, H - py + 1, sy, device=dev)
    x_s = torch.arange(0, W - px + 1, sx, device=dev)

    gz, gy, gx = torch.meshgrid(z_s, y_s, x_s, indexing='ij')
    locs = torch.stack([gz.flatten(), gy.flatten(), gx.flatten()], dim=1)
    N = locs.shape[0]

    # Broadcast patch-origin + within-patch offset into full index grids,
    # then gather every patch in a single advanced-indexing op.
    oz = torch.arange(pz, device=dev)
    oy = torch.arange(py, device=dev)
    ox = torch.arange(px, device=dev)

    z_idx = (locs[:, 0:1] + oz.unsqueeze(0))[:, :, None, None].expand(N, pz, py, px)
    y_idx = (locs[:, 1:2] + oy.unsqueeze(0))[:, None, :, None].expand(N, pz, py, px)
    x_idx = (locs[:, 2:3] + ox.unsqueeze(0))[:, None, None, :].expand(N, pz, py, px)

    return volume[z_idx, y_idx, x_idx], locs
|
|
|
|
def binarize_fast(patches, percentile=90.0, min_occ=0.01):
    """Threshold each patch at its own |value| percentile using kthvalue.

    Returns (binary_patches, kept_indices) where a patch survives only if
    its post-threshold occupancy lies in [min_occ, 0.95].
    """
    n_patches = patches.shape[0]
    voxels = patches[0].numel()
    magnitudes = patches.reshape(n_patches, voxels).abs()

    # The top-k'th |value| is the (V - k + 1)-th smallest; kthvalue gives it
    # without the full sort torch.quantile would incur.
    top_k = max(1, int(voxels * (1.0 - percentile / 100.0)))
    cutoffs = magnitudes.kthvalue(voxels - top_k + 1, dim=1, keepdim=True).values

    masks = (magnitudes >= cutoffs).float()
    occupancy = masks.mean(dim=1)
    valid = (occupancy >= min_occ) & (occupancy <= 0.95)
    kept_indices = valid.nonzero(as_tuple=True)[0]

    binary = masks.reshape(n_patches, *patches.shape[1:])[kept_indices]
    return binary, kept_indices
|
|
|
|
def extract_and_prepare_volume(volume, config):
    """
    Slice one volume into patches at every configured scale, binarize them,
    and resize each survivor to the canonical shape.

    Returns (canonical_patches, meta_list) — both GPU-resident — where each
    meta entry is (scale_level, kept_locations, scale, n_patches).
    Returns (None, []) when no scale yields any patch.
    """
    target = config.canonical_shape
    batches = []
    meta = []

    for level, scale in enumerate(config.scales):
        pz, py, px = scale
        depth, height, width = volume.shape
        # Skip scales the volume cannot accommodate at all.
        if depth < pz or height < py or width < px:
            continue

        raw, locs = extract_patches_gpu(volume, scale, config.overlap)
        binary, kept = binarize_fast(
            raw, config.binarize_percentile, config.min_occupancy)
        if binary.shape[0] == 0:
            continue

        if binary.shape[1:] == tuple(target):
            canon = binary
        else:
            canon = F.interpolate(
                binary.unsqueeze(1), size=target,
                mode='trilinear', align_corners=False).squeeze(1)

        batches.append(canon)
        meta.append((level, locs[kept], scale, canon.shape[0]))

    if not batches:
        return None, []
    return torch.cat(batches, dim=0), meta
|
|
|
|
| |
|
|
def cluster_channels_gpu(latents, n_groups=8):
    """Greedy farthest-point channel clustering, fully on GPU (no numpy).

    Args:
        latents: (N, C, H, W) tensor; channel correlations are averaged
            over the N samples.
        n_groups: desired number of channel groups.

    Returns:
        groups: list of lists of channel indices; every channel appears in
            exactly one group.
        corr: (C, C) mean channel-correlation matrix.
    """
    N, C, H, W = latents.shape
    flat = latents.reshape(N, C, -1)
    # Center and L2-normalize each channel so the bmm below yields
    # Pearson-style correlations.
    flat = flat - flat.mean(dim=-1, keepdim=True)
    flat = F.normalize(flat, dim=-1)
    corr = torch.bmm(flat, flat.transpose(1, 2)).mean(dim=0)

    # Distance = 1 - |corr|: strongly (anti-)correlated channels are "close".
    dist = 1.0 - corr.abs()
    remaining = torch.ones(C, dtype=torch.bool, device=latents.device)
    target_size = max(1, C // n_groups)
    groups = []

    for g in range(n_groups):
        if not remaining.any():
            break
        avail = remaining.nonzero(as_tuple=True)[0]
        if g == 0:
            # First seed: lowest-index unassigned channel.
            seed = avail[0].item()
        else:
            # Farthest-point seeding: pick the unassigned channel whose
            # minimum distance to any already-formed group is largest.
            min_dists = torch.full((C,), float('inf'), device=latents.device)
            for grp in groups:
                grp_t = torch.tensor(grp, device=latents.device)
                d = dist[:, grp_t].min(dim=1).values
                min_dists = torch.min(min_dists, d)
            min_dists[~remaining] = -1  # exclude assigned channels from argmax
            seed = min_dists.argmax().item()

        group = [seed]
        remaining[seed] = False

        # Fill the group with the seed's nearest still-unassigned neighbors.
        dists_from_seed = dist[seed].clone()
        dists_from_seed[~remaining] = float('inf')
        _, nearest = dists_from_seed.topk(min(target_size - 1, remaining.sum().item()), largest=False)
        # Drop any inf entries that slipped into the topk result.
        nearest = nearest[dists_from_seed[nearest] < float('inf')]

        for c in nearest.tolist():
            group.append(c)
            remaining[c] = False
        groups.append(group)

    # Assign any leftover channels to their nearest existing group.
    for c in remaining.nonzero(as_tuple=True)[0].tolist():
        grp_dists = []
        for gi, grp in enumerate(groups):
            grp_t = torch.tensor(grp, device=latents.device)
            grp_dists.append(dist[c, grp_t].min().item())
        groups[min(range(len(groups)), key=lambda i: grp_dists[i])].append(c)

    return groups, corr
|
|
|
|
def compute_deviance_volume(latent, groups):
    """Absolute differences between channel-group mean maps.

    Returns a (n_pairs, H, W) tensor — one |mean_i - mean_j| map per
    unordered group pair in upper-triangular order — plus the matching
    list of (i, j) index pairs.
    """
    means = torch.stack([latent[g].mean(dim=0) for g in groups])
    n = len(groups)
    rows, cols = torch.triu_indices(n, n, offset=1, device=latent.device)
    maps = (means[rows] - means[cols]).abs()
    pair_list = list(zip(rows.cpu().tolist(), cols.cpu().tolist()))
    return maps, pair_list
|
|
|
|
| |
|
|
class MultiScaleExtractor:
    """Batched multi-scale geometric annotation extractor.

    Gathers canonical patches from several latent volumes (raw channels
    plus inter-group deviance maps), classifies them all in one chunked
    mega-batch, and scatters confident predictions back into per-image
    GeometricAnnotation lists.
    """

    def __init__(self, classifier, config=None):
        # classifier must be a module returning a dict with keys
        # "class_logits", "dim_logits", "is_curved_pred", "curv_type_logits"
        # (see classify_batch below).
        self.classifier = classifier
        self.config = config or ExtractionConfig()
        self.classifier.eval()
        self.device = next(classifier.parameters()).device
        # NOTE(review): assumes a CUDA build is present — confirm CPU-only
        # environments are out of scope for this cell.
        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

    @torch.no_grad()
    def classify_batch(self, patches):
        """Classify mega-batch with amp.

        Args:
            patches: (N, ...) tensor; forwarded in chunks of at most
                config.max_classify_batch patches.

        Returns:
            Dict of 1-D length-N tensors (pred_class, confidence, dim_pred,
            curved_pred, curv_type_pred), or None when N == 0.
        """
        N = patches.shape[0]
        if N == 0:
            return None

        max_b = self.config.max_classify_batch
        all_cls, all_conf, all_dim, all_curved, all_curv_type = [], [], [], [], []

        for start in range(0, N, max_b):
            chunk = patches[start:start+max_b]
            with torch.amp.autocast('cuda', dtype=self.amp_dtype):
                out = self.classifier(chunk)

            # Cast logits to fp32 before softmax for numerical stability
            # after the reduced-precision forward pass.
            probs = F.softmax(out["class_logits"].float(), dim=-1)
            top2 = probs.topk(2, dim=-1).values
            all_cls.append(probs.argmax(dim=-1))
            # Confidence = margin between best and runner-up class prob,
            # not the raw max probability.
            all_conf.append(top2[:, 0] - top2[:, 1])
            all_dim.append(out["dim_logits"].argmax(dim=-1))
            all_curved.append(out["is_curved_pred"].squeeze(-1) > 0.0)
            all_curv_type.append(out["curv_type_logits"].argmax(dim=-1))

        return {
            "pred_class": torch.cat(all_cls),
            "confidence": torch.cat(all_conf),
            "dim_pred": torch.cat(all_dim),
            "curved_pred": torch.cat(all_curved),
            "curv_type_pred": torch.cat(all_curv_type),
        }

    def _results_to_annotations(self, results, meta, conf_thresh, source="raw", pair_indices=None):
        """Convert batched results + meta into annotation list.

        `meta` entries are (scale_level, kept_locations, scale, count) in
        the same order the patches were concatenated, so `offset` walks the
        `results` tensors segment by segment.
        """
        annotations = []
        offset = 0

        for level, kept_locs, scale, count in meta:
            chunk_conf = results["confidence"][offset:offset+count]
            mask = chunk_conf >= conf_thresh
            local_idx = mask.nonzero(as_tuple=True)[0]

            if len(local_idx) > 0:
                gi = local_idx + offset
                # One .cpu() transfer per field for the whole chunk; the
                # Python loop below touches only host memory.
                cls = results["pred_class"][gi].cpu()
                conf = results["confidence"][gi].cpu()
                dim = results["dim_pred"][gi].cpu()
                curved = results["curved_pred"][gi].cpu()
                curv = results["curv_type_pred"][gi].cpu()
                locs = kept_locs[local_idx].cpu()

                for i in range(len(local_idx)):
                    # NOTE(review): CLASS_NAMES / CURVATURE_NAMES are not
                    # defined in this cell — presumably module-level lists
                    # from an earlier cell; verify they cover every index
                    # the classifier can emit.
                    ann = GeometricAnnotation(
                        class_name=CLASS_NAMES[cls[i].item()],
                        class_idx=cls[i].item(),
                        confidence=conf[i].item(),
                        scale_level=level,
                        location=tuple(int(x) for x in locs[i].tolist()),
                        patch_size=scale,
                        dimension=dim[i].item(),
                        is_curved=bool(curved[i].item()),
                        curvature_type=CURVATURE_NAMES[curv[i].item()],
                        source=source,
                    )
                    if source == "deviance" and pair_indices is not None:
                        # Deviance volumes stack one (H, W) map per group
                        # pair along depth, so the patch's z origin selects
                        # its (starting) pair index. Patches spanning
                        # multiple pairs get only the first — TODO confirm
                        # this is intended.
                        pair_idx = locs[i][0].item()
                        if pair_idx < len(pair_indices):
                            ann.channel_group_pair = pair_indices[pair_idx]
                    annotations.append(ann)

            offset += count
        return annotations

    def extract_batch(self, latents, channel_groups):
        """
        Process multiple latents simultaneously.
        latents: list of (C, H, W) tensors on GPU
        Returns: list of per-image result dicts
        """
        conf_thresh = self.config.confidence_threshold

        # Phase 1: collect canonical patches from every (image, source);
        # image_segments records how to un-concatenate the results later.
        all_patches = []
        image_segments = []

        for img_idx, latent in enumerate(latents):
            raw_patches, raw_meta = extract_and_prepare_volume(latent, self.config)
            if raw_patches is not None:
                n = raw_patches.shape[0]
                all_patches.append(raw_patches)
                image_segments.append((img_idx, "raw", raw_meta, None, n))

            # Deviance source is optional — only when channel groups given.
            if channel_groups is not None:
                dev_vol, pair_indices = compute_deviance_volume(latent, channel_groups)
                dev_patches, dev_meta = extract_and_prepare_volume(dev_vol, self.config)
                if dev_patches is not None:
                    n = dev_patches.shape[0]
                    all_patches.append(dev_patches)
                    image_segments.append((img_idx, "deviance", dev_meta, pair_indices, n))

        if not all_patches:
            return [{
                'raw_annotations': [], 'deviance_annotations': [],
                'n_raw': 0, 'n_deviance': 0,
            } for _ in latents]

        # Phase 2: one concatenated classifier pass over everything;
        # free the staging list/batch promptly to cap peak GPU memory.
        mega_batch = torch.cat(all_patches, dim=0)
        del all_patches

        results = self.classify_batch(mega_batch)
        del mega_batch

        if results is None:
            return [{
                'raw_annotations': [], 'deviance_annotations': [],
                'n_raw': 0, 'n_deviance': 0,
            } for _ in latents]

        # Phase 3: slice results back out per (image, source) segment.
        per_image = {i: {'raw': [], 'deviance': []} for i in range(len(latents))}
        global_offset = 0

        for img_idx, source, meta, pair_indices, total_count in image_segments:
            seg_results = {
                k: v[global_offset:global_offset+total_count]
                for k, v in results.items()
            }
            anns = self._results_to_annotations(
                seg_results, meta, conf_thresh, source, pair_indices)
            per_image[img_idx][source].extend(anns)
            global_offset += total_count

        del results

        # Phase 4: assemble one output dict per input latent.
        output = []
        for i in range(len(latents)):
            raw = per_image[i]['raw']
            dev = per_image[i]['deviance']
            output.append({
                'raw_annotations': raw,
                'deviance_annotations': dev,
                'n_raw': len(raw),
                'n_deviance': len(dev),
            })

        return output

    def extract_from_latent(self, latent, channel_groups=None):
        # Single-image convenience wrapper over extract_batch.
        return self.extract_batch([latent], channel_groups)[0]
|
|
|
|
# Banner: report the active extraction defaults once the cell loads.
_cfg = ExtractionConfig()
print("✓ Cell 4: Fully batched multi-image extraction")
print(f" Scales: {_cfg.scales}")
print(f" Canonical: {_cfg.canonical_shape}")
print(f" Image batch: {_cfg.image_batch_size}")
print(f" Classify batch: {_cfg.max_classify_batch}")
print(f" Percentile: {_cfg.binarize_percentile}th (kthvalue)")