AbstractPhil committed on
Commit
8fb9dff
·
verified ·
1 Parent(s): 8c3a929

Update cell4_vae_pipeline.py

Browse files
Files changed (1) hide show
  1. cell4_vae_pipeline.py +190 -138
cell4_vae_pipeline.py CHANGED
@@ -1,11 +1,10 @@
1
  """
2
- Cell 4: Multi-Scale Geometric Extraction Pipeline
3
- ===================================================
4
  Run after Cells 1-3. Uses globals from prior cells.
5
 
6
- Updated for PatchCrossAttentionClassifier (no Conv3d).
7
- Defines extraction functions and the MultiScaleExtractor class.
8
- Does NOT execute anything — Cell 5 uses these.
9
  """
10
 
11
  import numpy as np
@@ -20,8 +19,8 @@ import math
20
  class ExtractionConfig:
21
  canonical_shape: Tuple[int, int, int] = (8, 16, 16)
22
  scales: List[Tuple[int, int, int]] = field(default_factory=lambda: [
23
- (32, 64, 64), # L0: full latent
24
- (16, 32, 32), # L1: regional
25
  (8, 16, 16), # L2: native patch
26
  (4, 8, 8), # L3: fine detail
27
  ])
@@ -30,6 +29,7 @@ class ExtractionConfig:
30
  min_occupancy: float = 0.005
31
  binarize_percentiles: List[float] = field(default_factory=lambda: [75, 90, 95])
32
  n_channel_groups: int = 8
 
33
  device: str = 'cuda'
34
 
35
 
@@ -48,53 +48,116 @@ class GeometricAnnotation:
48
  channel_group_pair: Optional[Tuple[int, int]] = None
49
 
50
 
51
- def extract_patches_sliding(volume, patch_size, overlap=0.5):
52
- """Extract overlapping patches from a 3D volume."""
 
 
 
 
 
 
53
  D, H, W = volume.shape
54
  pz, py, px = patch_size
55
 
56
- if D < pz or H < py or W < px:
57
- pad_d = max(pz - D, 0)
58
- pad_h = max(py - H, 0)
59
- pad_w = max(px - W, 0)
 
60
  volume = F.pad(volume, (0, pad_w, 0, pad_h, 0, pad_d))
61
  D, H, W = volume.shape
62
 
63
- stride_z = max(1, int(pz * (1 - overlap)))
64
- stride_y = max(1, int(py * (1 - overlap)))
65
- stride_x = max(1, int(px * (1 - overlap)))
66
 
67
- patches = []
68
- for z in range(0, max(1, D - pz + 1), stride_z):
69
- for y in range(0, max(1, H - py + 1), stride_y):
70
- for x in range(0, max(1, W - px + 1), stride_x):
71
- patch = volume[z:z+pz, y:y+py, x:x+px]
72
- patches.append((patch, (z, y, x)))
73
 
74
- return patches
 
 
 
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
- def resize_to_canonical(patch, target=(8, 16, 16)):
78
- """Resize 3D patch to canonical resolution via trilinear interpolation."""
79
- x = patch.unsqueeze(0).unsqueeze(0).float()
80
- x = F.interpolate(x, size=target, mode='trilinear', align_corners=False)
81
- return x.squeeze(0).squeeze(0)
82
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
- def binarize_continuous(patch, percentiles=[75, 90, 95]):
85
- """Binarize continuous patch at multiple percentile thresholds."""
86
- flat = patch.flatten()
87
- nonzero = flat[flat.abs() > 1e-8]
88
- if len(nonzero) < 10:
89
- return [torch.zeros_like(patch)] * len(percentiles)
90
- thresholds = [torch.quantile(nonzero.abs(), p / 100.0).item() for p in percentiles]
91
- return [(patch.abs() >= t).float() for t in thresholds]
92
 
 
93
 
94
  def cluster_channels(latents, n_groups=8):
95
  """
96
  Cluster VAE channels by correlation.
97
- latents: (N, C, H, W) batch
98
  Returns: (groups, corr_matrix)
99
  """
100
  N, C, H, W = latents.shape
@@ -153,164 +216,153 @@ def cluster_channels(latents, n_groups=8):
153
 
154
  def compute_inter_group_deviances(latent, groups):
155
  """
156
- Compute (H, W) deviance maps between channel groups.
157
  latent: (C, H, W), groups: list of channel index lists
158
- Returns: list of ((group_i, group_j), deviance_map)
159
  """
160
- group_means = torch.stack([latent[grp].mean(dim=0) for grp in groups])
161
  n = len(groups)
162
- deviances = []
163
- for i in range(n):
164
- for j in range(i + 1, n):
165
- dev = (group_means[i] - group_means[j]).abs()
166
- deviances.append(((i, j), dev))
167
- return deviances
168
 
169
 
170
- def deviance_maps_to_3d(deviances, n_groups):
171
- """Stack deviance maps into (n_pairs, H, W) volume."""
172
- return torch.stack([dev for (_, dev) in deviances], dim=0)
173
-
174
 
175
  class MultiScaleExtractor:
176
  """
177
- Confidence-cascaded multi-scale geometric extractor.
178
- Uses trained PatchCrossAttentionClassifier (from Cell 2 globals).
179
  """
180
 
181
  def __init__(self, classifier, config=None):
182
  self.classifier = classifier
183
  self.config = config or ExtractionConfig()
184
  self.classifier.eval()
 
185
 
186
  @torch.no_grad()
187
- def classify_patches(self, patches, max_batch=512):
188
- """Classify batch of (B, 8, 16, 16) patches with chunking to avoid OOM."""
189
- device = next(self.classifier.parameters()).device
190
  N = patches.shape[0]
191
-
192
  all_results = []
193
- for start in range(0, N, max_batch):
194
- chunk = patches[start:start+max_batch].to(device)
 
195
  out = self.classifier(chunk)
196
  probs = F.softmax(out["class_logits"], dim=-1)
197
  max_prob, pred_class = probs.max(dim=-1)
198
  top2 = probs.topk(2, dim=-1).values
199
  margin = top2[:, 0] - top2[:, 1]
200
- dim_pred = out["dim_logits"].argmax(dim=-1)
201
- curved_pred = (out["is_curved_pred"].squeeze(-1) > 0.0)
202
- curv_type_pred = out["curv_type_logits"].argmax(dim=-1)
203
 
204
  all_results.append({
205
  "pred_class": pred_class.cpu(),
206
  "confidence": margin.cpu(),
207
  "max_prob": max_prob.cpu(),
208
- "dim_pred": dim_pred.cpu(),
209
- "curved_pred": curved_pred.cpu(),
210
- "curv_type_pred": curv_type_pred.cpu(),
211
- "features": out["features"].cpu(),
212
  })
213
  del chunk, out, probs
214
- torch.cuda.empty_cache()
215
 
 
 
216
  return {k: torch.cat([r[k] for r in all_results], dim=0)
217
  for k in all_results[0]}
218
 
219
  def extract_from_volume(self, volume, min_confidence=None):
220
  """
221
- Extract annotations via confidence cascade.
222
- volume: (D, H, W) continuous or binary tensor
223
  """
224
  conf_thresh = min_confidence or self.config.confidence_threshold
 
225
  annotations = []
226
 
227
- regions_to_process = [(0, 0, 0, 0,
228
- volume.shape[0], volume.shape[1], volume.shape[2])]
229
 
230
  for level, scale in enumerate(self.config.scales):
231
- if not regions_to_process:
232
- break
233
-
234
- next_regions = []
235
  pz, py, px = scale
 
 
 
 
236
 
237
- all_patches = []
238
-
239
- for ridx, (lvl, rz0, ry0, rx0, rz1, ry1, rx1) in enumerate(regions_to_process):
240
- if lvl != level:
241
- next_regions.append((lvl, rz0, ry0, rx0, rz1, ry1, rx1))
242
- continue
243
-
244
- subvol = volume[rz0:rz1, ry0:ry1, rx0:rx1]
245
- patches = extract_patches_sliding(subvol, (pz, py, px), self.config.overlap)
246
-
247
- for patch, (lz, ly, lx) in patches:
248
- binary_patches = binarize_continuous(patch, self.config.binarize_percentiles)
249
- for bp in binary_patches:
250
- occ = bp.mean().item()
251
- if occ < self.config.min_occupancy:
252
- continue
253
- canonical = resize_to_canonical(bp, self.config.canonical_shape)
254
- all_patches.append((
255
- canonical,
256
- (rz0 + lz, ry0 + ly, rx0 + lx),
257
- ridx, scale))
258
-
259
- if not all_patches:
260
- regions_to_process = next_regions
261
  continue
262
 
263
- # Batch classify
264
- batch = torch.stack([p[0] for p in all_patches])
265
- results = self.classify_patches(batch)
266
-
267
- for i, (_, loc, ridx, sc) in enumerate(all_patches):
268
- conf = results["confidence"][i].item()
269
- cls_idx = results["pred_class"][i].item()
270
-
271
- if conf >= conf_thresh:
272
- ann = GeometricAnnotation(
273
- class_name=CLASS_NAMES[cls_idx],
274
- class_idx=cls_idx,
275
- confidence=conf,
276
- scale_level=level,
277
- location=loc,
278
- patch_size=sc,
279
- dimension=results["dim_pred"][i].item(),
280
- is_curved=bool(results["curved_pred"][i].item()),
281
- curvature_type=CURVATURE_NAMES[results["curv_type_pred"][i].item()],
282
- )
283
- annotations.append(ann)
284
- elif level < len(self.config.scales) - 1:
285
- z0, y0, x0 = loc
286
- next_regions.append((
287
- level + 1,
288
- z0, y0, x0,
289
- min(z0 + sc[0], volume.shape[0]),
290
- min(y0 + sc[1], volume.shape[1]),
291
- min(x0 + sc[2], volume.shape[2]),
292
- ))
293
-
294
- regions_to_process = next_regions
295
 
296
  return annotations
297
 
298
  def extract_from_latent(self, latent, channel_groups=None):
299
  """
300
- Full extraction for a single Flux 2 VAE latent.
301
  latent: (C, H, W) tensor
302
  """
303
- raw_annotations = self.extract_from_volume(latent)
 
 
 
304
 
 
305
  deviance_annotations = []
306
  if channel_groups is not None:
307
- deviances = compute_inter_group_deviances(latent, channel_groups)
308
- dev_volume = deviance_maps_to_3d(deviances, len(channel_groups))
309
- deviance_annotations = self.extract_from_volume(dev_volume)
310
  for ann in deviance_annotations:
311
  pair_idx = ann.location[0]
312
- if pair_idx < len(deviances):
313
- ann.channel_group_pair = deviances[pair_idx][0]
314
 
315
  return {
316
  'raw_annotations': raw_annotations,
@@ -320,6 +372,6 @@ class MultiScaleExtractor:
320
  }
321
 
322
 
323
- print("✓ Cell 4: Extraction pipeline defined (PatchCrossAttention)")
324
  print(f" Scales: {ExtractionConfig().scales}")
325
  print(f" Canonical: {ExtractionConfig().canonical_shape}")
 
1
  """
2
+ Cell 4: Multi-Scale Geometric Extraction Pipeline (Vectorized)
3
+ ===============================================================
4
  Run after Cells 1-3. Uses globals from prior cells.
5
 
6
+ Fully vectorized — no Python loops over patches.
7
+ Uses unfold for extraction, batched binarization, batched resize.
 
8
  """
9
 
10
  import numpy as np
 
19
  class ExtractionConfig:
20
  canonical_shape: Tuple[int, int, int] = (8, 16, 16)
21
  scales: List[Tuple[int, int, int]] = field(default_factory=lambda: [
22
+ (16, 64, 64), # L0: full latent
23
+ (8, 32, 32), # L1: regional
24
  (8, 16, 16), # L2: native patch
25
  (4, 8, 8), # L3: fine detail
26
  ])
 
29
  min_occupancy: float = 0.005
30
  binarize_percentiles: List[float] = field(default_factory=lambda: [75, 90, 95])
31
  n_channel_groups: int = 8
32
+ max_classify_batch: int = 512
33
  device: str = 'cuda'
34
 
35
 
 
48
  channel_group_pair: Optional[Tuple[int, int]] = None
49
 
50
 
51
# === Vectorized Extraction ====================================================

def extract_patches_unfold(volume, patch_size, overlap=0.5):
    """
    Extract every sliding patch from a 3D volume in one vectorized gather.

    volume: (D, H, W) tensor
    patch_size: (pz, py, px)
    overlap: fraction of patch extent shared by neighboring patches
    Returns: (patches: (N, pz, py, px), locations: (N, 3) patch origins)
    """
    pz, py, px = patch_size
    D, H, W = volume.shape

    # Zero-pad at the high end of each axis so at least one patch fits.
    pads = (max(px - W, 0), max(py - H, 0), max(pz - D, 0))
    if any(pads):
        volume = F.pad(volume, (0, pads[0], 0, pads[1], 0, pads[2]))
        D, H, W = volume.shape

    # Stride per axis derived from the requested overlap (never below 1).
    strides = [max(1, int(extent * (1 - overlap))) for extent in (pz, py, px)]

    # Number of patch origins along each axis.
    counts = [max(1, (dim - ext) // st + 1)
              for dim, ext, st in zip((D, H, W), (pz, py, px), strides)]

    # Patch origin coordinates per axis, clamped inside the (padded) volume.
    dev = volume.device
    starts = [
        (torch.arange(c, device=dev) * st).clamp(max=dim - ext)
        for c, st, dim, ext in zip(counts, strides, (D, H, W), (pz, py, px))
    ]

    # Cartesian product of per-axis origins → (N, 3) location table.
    gz, gy, gx = torch.meshgrid(*starts, indexing='ij')
    locations = torch.stack([gz.flatten(), gy.flatten(), gx.flatten()], dim=1)
    N = locations.shape[0]

    # Absolute per-voxel indices for every patch: origin + within-patch offset.
    z_idx = (locations[:, 0:1] + torch.arange(pz, device=dev))[:, :, None, None]
    y_idx = (locations[:, 1:2] + torch.arange(py, device=dev))[:, None, :, None]
    x_idx = (locations[:, 2:3] + torch.arange(px, device=dev))[:, None, None, :]

    # Advanced indexing with the three expanded index tensors yields
    # (N, pz, py, px) in a single gather — no Python loop over patches.
    patches = volume[z_idx.expand(N, pz, py, px),
                     y_idx.expand(N, pz, py, px),
                     x_idx.expand(N, pz, py, px)]

    return patches, locations
114
+
115
+
116
def binarize_batch(patches, percentiles=(75, 90, 95)):
    """
    Binarize N patches at multiple per-patch percentile thresholds. Vectorized.

    patches: (N, pz, py, px)
    percentiles: iterable of percentile values in [0, 100]
    Returns:
        stacked: (N * len(percentiles), pz, py, px) binary float tensor,
            ordered threshold-major (all patches at percentiles[0] first).
        repeat_idx: (N * len(percentiles),) long tensor mapping each binary
            patch back to its source row in `patches`.
    """
    N = patches.shape[0]
    abs_flat = patches.reshape(N, -1).abs()

    results = []
    for p in percentiles:
        # Per-patch percentile threshold over |values|.
        thresholds = torch.quantile(abs_flat, p / 100.0, dim=1, keepdim=True)  # (N, 1)
        # Guard against all-zero (or near-zero) patches: a threshold of 0 would
        # binarize them to all-ones, letting empty patches sail through the
        # downstream occupancy filter. Clamping keeps them all-zero instead.
        thresholds = thresholds.clamp_min(1e-8)
        binary = (abs_flat >= thresholds).float().reshape(patches.shape)
        results.append(binary)

    # (n_thresh * N, pz, py, px), threshold-major order.
    stacked = torch.cat(results, dim=0)

    # Each original patch index repeated once per threshold, matching `stacked`.
    repeat_idx = torch.arange(N, device=patches.device).repeat(len(percentiles))

    return stacked, repeat_idx
140
 
 
 
 
 
 
141
 
142
def resize_batch(patches, target=(8, 16, 16)):
    """
    Resize a batch of 3D patches to the canonical shape. Vectorized.

    patches: (N, pz, py, px)
    target: desired (tz, ty, tx)
    Returns: (N, tz, ty, tx); the input tensor itself when already canonical.
    """
    # Already at the target resolution — nothing to do.
    if tuple(patches.shape[1:]) == tuple(target):
        return patches
    # Trilinear interpolation expects (N, C, D, H, W); use a single channel.
    expanded = patches.unsqueeze(1)
    resized = F.interpolate(expanded, size=target,
                            mode='trilinear', align_corners=False)
    return resized.squeeze(1)
153
 
 
 
 
 
 
 
 
 
154
 
155
+ # === Channel Clustering =======================================================
156
 
157
  def cluster_channels(latents, n_groups=8):
158
  """
159
  Cluster VAE channels by correlation.
160
+ latents: (N, C, H, W)
161
  Returns: (groups, corr_matrix)
162
  """
163
  N, C, H, W = latents.shape
 
216
 
217
def compute_inter_group_deviances(latent, groups):
    """
    Compute |mean(group_i) - mean(group_j)| maps for all group pairs. Vectorized.

    latent: (C, H, W) tensor
    groups: list of channel-index lists
    Returns:
        deviances: (n_pairs, H, W) absolute difference maps
        pair_indices: list of (i, j) group-index tuples in upper-triangular order
    """
    # Per-group mean over its member channels → (G, H, W).
    means = torch.stack([latent[members].mean(dim=0) for members in groups])
    count = len(groups)
    # Upper-triangular (i < j) pair indices, then one broadcasted subtraction.
    upper = torch.triu_indices(count, count, offset=1)
    left, right = upper[0], upper[1]
    deviances = (means[left] - means[right]).abs()  # (n_pairs, H, W)
    pair_indices = list(zip(left.tolist(), right.tolist()))
    return deviances, pair_indices
 
230
 
231
 
232
+ # === Extractor ================================================================
 
 
 
233
 
234
  class MultiScaleExtractor:
235
  """
236
+ Vectorized multi-scale geometric extractor.
237
+ No Python loops over individual patches.
238
  """
239
 
240
  def __init__(self, classifier, config=None):
241
  self.classifier = classifier
242
  self.config = config or ExtractionConfig()
243
  self.classifier.eval()
244
+ self.device = next(classifier.parameters()).device
245
 
246
  @torch.no_grad()
247
+ def classify_patches(self, patches):
248
+ """Classify (N, 8, 16, 16) patches in chunks."""
 
249
  N = patches.shape[0]
250
+ max_b = self.config.max_classify_batch
251
  all_results = []
252
+
253
+ for start in range(0, N, max_b):
254
+ chunk = patches[start:start+max_b].to(self.device)
255
  out = self.classifier(chunk)
256
  probs = F.softmax(out["class_logits"], dim=-1)
257
  max_prob, pred_class = probs.max(dim=-1)
258
  top2 = probs.topk(2, dim=-1).values
259
  margin = top2[:, 0] - top2[:, 1]
 
 
 
260
 
261
  all_results.append({
262
  "pred_class": pred_class.cpu(),
263
  "confidence": margin.cpu(),
264
  "max_prob": max_prob.cpu(),
265
+ "dim_pred": out["dim_logits"].argmax(dim=-1).cpu(),
266
+ "curved_pred": (out["is_curved_pred"].squeeze(-1) > 0.0).cpu(),
267
+ "curv_type_pred": out["curv_type_logits"].argmax(dim=-1).cpu(),
 
268
  })
269
  del chunk, out, probs
 
270
 
271
+ if not all_results:
272
+ return None
273
  return {k: torch.cat([r[k] for r in all_results], dim=0)
274
  for k in all_results[0]}
275
 
276
  def extract_from_volume(self, volume, min_confidence=None):
277
  """
278
+ Vectorized extraction over all scales.
279
+ volume: (D, H, W) tensor on any device
280
  """
281
  conf_thresh = min_confidence or self.config.confidence_threshold
282
+ canonical = self.config.canonical_shape
283
  annotations = []
284
 
285
+ volume = volume.float().cpu()
 
286
 
287
  for level, scale in enumerate(self.config.scales):
 
 
 
 
288
  pz, py, px = scale
289
+ D, H, W = volume.shape
290
+
291
+ if D < pz or H < py or W < px:
292
+ continue
293
 
294
+ # 1. Extract all patches — vectorized
295
+ patches, locations = extract_patches_unfold(volume, scale, self.config.overlap)
296
+ # patches: (N, pz, py, px), locations: (N, 3)
297
+
298
+ if patches.shape[0] == 0:
299
+ continue
300
+
301
+ # 2. Binarize at all thresholds — vectorized
302
+ binary, repeat_idx = binarize_batch(patches, self.config.binarize_percentiles)
303
+ # binary: (N*n_thresh, pz, py, px), repeat_idx: (N*n_thresh,)
304
+
305
+ # 3. Filter by occupancy — vectorized
306
+ occ = binary.reshape(binary.shape[0], -1).mean(dim=1)
307
+ keep = occ >= self.config.min_occupancy
308
+ binary = binary[keep]
309
+ loc_idx = repeat_idx[keep]
310
+
311
+ if binary.shape[0] == 0:
 
 
 
 
 
 
312
  continue
313
 
314
+ # 4. Resize to canonical — vectorized
315
+ canonical_patches = resize_batch(binary, canonical)
316
+
317
+ # 5. Classify in chunks
318
+ results = self.classify_patches(canonical_patches)
319
+ if results is None:
320
+ continue
321
+
322
+ # 6. Filter by confidence and build annotations
323
+ conf_mask = results["confidence"] >= conf_thresh
324
+ indices = conf_mask.nonzero(as_tuple=True)[0]
325
+
326
+ for i in indices.tolist():
327
+ orig_idx = loc_idx[i].item()
328
+ loc = locations[orig_idx].tolist()
329
+ ann = GeometricAnnotation(
330
+ class_name=CLASS_NAMES[results["pred_class"][i].item()],
331
+ class_idx=results["pred_class"][i].item(),
332
+ confidence=results["confidence"][i].item(),
333
+ scale_level=level,
334
+ location=tuple(int(x) for x in loc),
335
+ patch_size=scale,
336
+ dimension=results["dim_pred"][i].item(),
337
+ is_curved=bool(results["curved_pred"][i].item()),
338
+ curvature_type=CURVATURE_NAMES[results["curv_type_pred"][i].item()],
339
+ )
340
+ annotations.append(ann)
341
+
342
+ del patches, locations, binary, canonical_patches, results
 
 
 
343
 
344
  return annotations
345
 
346
  def extract_from_latent(self, latent, channel_groups=None):
347
  """
348
+ Full extraction for one Flux 2 VAE latent.
349
  latent: (C, H, W) tensor
350
  """
351
+ latent_cpu = latent.cpu().float()
352
+
353
+ # Raw volume: treat channels as depth
354
+ raw_annotations = self.extract_from_volume(latent_cpu)
355
 
356
+ # Deviance volume
357
  deviance_annotations = []
358
  if channel_groups is not None:
359
+ dev_maps, pair_indices = compute_inter_group_deviances(latent_cpu, channel_groups)
360
+ # dev_maps: (n_pairs, H, W) — treat as (D, H, W)
361
+ deviance_annotations = self.extract_from_volume(dev_maps)
362
  for ann in deviance_annotations:
363
  pair_idx = ann.location[0]
364
+ if pair_idx < len(pair_indices):
365
+ ann.channel_group_pair = pair_indices[pair_idx]
366
 
367
  return {
368
  'raw_annotations': raw_annotations,
 
372
  }
373
 
374
 
375
+ print("✓ Cell 4: Vectorized extraction pipeline defined")
376
  print(f" Scales: {ExtractionConfig().scales}")
377
  print(f" Canonical: {ExtractionConfig().canonical_shape}")