mnm-matin commited on
Commit
90cc3c3
·
verified ·
1 Parent(s): 4e18ed4

Update manufacturing Space to latest HyperView public API

Browse files
.hyperview/extensions/manufacturing-readout/panel.js CHANGED
@@ -39,7 +39,6 @@ function normalizeModels(value) {
39
  displayName: String(model.displayName || model.display_name || model.key || `Model ${index + 1}`),
40
  buttonLabel: String(model.buttonLabel || model.button_label || `${model.key || "Model"} query`),
41
  layoutKey: model.layoutKey || model.layout_key || null,
42
- spaceKey: model.spaceKey || model.space_key || null,
43
  }))
44
  .filter((model) => model.layoutKey);
45
  }
@@ -178,6 +177,77 @@ function CompactEvidence({ item, models }) {
178
  );
179
  }
180
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  function StepBlock({ number, title, children }) {
182
  return React.createElement(
183
  "div",
@@ -320,8 +390,7 @@ function choiceFromSimilarity(similarity, examples, models) {
320
  const sourceKey = source.includes(":") ? source.split(":").pop() : null;
321
  const model =
322
  models.find((candidate) => candidate.key === sourceKey) ||
323
- models.find((candidate) => candidate.layoutKey === similarity.layout_key) ||
324
- models.find((candidate) => candidate.spaceKey === similarity.space_key);
325
  if (!model) return null;
326
  const metric = advantageMetric(model.key, item.id);
327
  return {
@@ -332,18 +401,6 @@ function choiceFromSimilarity(similarity, examples, models) {
332
  };
333
  }
334
 
335
- async function sendJson(path, payload, method = "POST") {
336
- const response = await fetch(path, {
337
- method,
338
- headers: { "Content-Type": "application/json" },
339
- body: JSON.stringify(payload),
340
- });
341
- if (!response.ok) {
342
- throw new Error(await response.text());
343
- }
344
- return response.json();
345
- }
346
-
347
  function buttonText(model) {
348
  if (model.key === "candidate") return "Hyper3";
349
  if (model.key === "clip") return "CLIP";
@@ -416,6 +473,7 @@ function WalkthroughCard({ item, models, onSelectQuery, loadingKey, activeModelK
416
  { style: { color: colors.mutedText, fontSize: 10, lineHeight: 1.3 } },
417
  "Wrong-line references send operators to the wrong golden sample.",
418
  ),
 
419
  React.createElement(CompactEvidence, { item, models }),
420
  );
421
  }
@@ -424,7 +482,6 @@ export default function ManufacturingPanel() {
424
  const props = usePanelProps() || {};
425
  const commands = usePanelCommands();
426
  const runtimeState = usePanelRuntimeState ? usePanelRuntimeState() : {};
427
- const workspaceId = String(props.workspaceId || props.workspace_id || "manufacturing-visa-reference-clip-hyper3clip");
428
  const models = normalizeModels(props.models);
429
  const examples = Array.isArray(props.examples) ? props.examples : [];
430
  const primaryExample = examples.find((item) => item.id === "fryum") || examples[0] || null;
@@ -449,6 +506,7 @@ export default function ManufacturingPanel() {
449
  const item = examples.find((example) => example.queryId === sampleId);
450
  const metric = advantageMetric(model.key, item?.id);
451
  const nextChoice = {
 
452
  modelName: model.displayName,
453
  queryLabel: title(item?.queryLabel || "fryum"),
454
  metricLine: metric?.line || null,
@@ -458,63 +516,25 @@ export default function ManufacturingPanel() {
458
  setActiveChoice(nextChoice);
459
  setLoadingKey(choiceKey);
460
  try {
461
- if (commands.setActiveLayout) {
462
- await commands.setActiveLayout(model.layoutKey, { persist: "none" });
463
- }
464
- if (commands.showSimilar) {
465
- await commands.showSimilar({
466
- sampleId,
467
- layoutKey: model.layoutKey,
468
- spaceKey: model.spaceKey,
469
- k: 10,
470
- source: `manufacturing-demo:${model.key}`,
471
- focus: "samples",
472
- persist: "none",
473
- });
474
- }
475
- await sendJson("/api/control/ui/state", {
476
- workspace_id: workspaceId,
477
- set_active_layout: true,
478
- active_layout_key: model.layoutKey,
479
- set_selection: true,
480
- selected_ids: [sampleId],
481
- set_similarity_query: true,
482
- similarity_query: {
483
- sample_id: sampleId,
484
- layout_key: model.layoutKey,
485
- space_key: model.spaceKey,
486
- k: 10,
487
- source: `manufacturing-demo:${model.key}`,
488
- },
489
- }, "PATCH");
490
  setActiveChoice(nextChoice);
491
  setActiveModelKey(model.key);
492
  } catch (error) {
493
- try {
494
- await sendJson("/api/control/ui/layout", {
495
- workspace_id: workspaceId,
496
- layout_key: model.layoutKey,
497
- });
498
- await sendJson("/api/control/ui/similarity", {
499
- workspace_id: workspaceId,
500
- sample_id: sampleId,
501
- layout_key: model.layoutKey,
502
- space_key: model.spaceKey,
503
- k: 10,
504
- source: `manufacturing-demo:${model.key}`,
505
- });
506
- setActiveChoice(nextChoice);
507
- setActiveModelKey(model.key);
508
- return;
509
- } catch (fallbackError) {
510
- const message = fallbackError instanceof Error ? fallbackError.message : String(fallbackError);
511
- setPanelError(`Could not show neighbors: ${message}`);
512
- }
513
  } finally {
514
  setLoadingKey(null);
515
  }
516
  },
517
- [commands, examples, workspaceId],
518
  );
519
 
520
  return React.createElement(
 
39
  displayName: String(model.displayName || model.display_name || model.key || `Model ${index + 1}`),
40
  buttonLabel: String(model.buttonLabel || model.button_label || `${model.key || "Model"} query`),
41
  layoutKey: model.layoutKey || model.layout_key || null,
 
42
  }))
43
  .filter((model) => model.layoutKey);
44
  }
 
177
  );
178
  }
179
 
180
+ function ActiveNeighbors({ item, modelKey }) {
181
+ if (!item || !modelKey) return null;
182
+ const summary = item.summaries?.[modelKey] || {};
183
+ const neighbors = Array.isArray(summary.neighbors) ? summary.neighbors.slice(0, 5) : [];
184
+ if (!neighbors.length) return null;
185
+ const cell = {
186
+ padding: "4px 3px",
187
+ borderBottom: `1px solid ${colors.border}`,
188
+ fontSize: 10,
189
+ color: colors.bodyText,
190
+ };
191
+ const head = { ...cell, color: colors.mutedText, fontSize: 9, textTransform: "uppercase" };
192
+ return React.createElement(
193
+ "div",
194
+ {
195
+ style: {
196
+ borderTop: `1px solid ${colors.border}`,
197
+ paddingTop: 7,
198
+ display: "flex",
199
+ flexDirection: "column",
200
+ gap: 4,
201
+ },
202
+ },
203
+ React.createElement(
204
+ "div",
205
+ { style: { color: colors.strongText, fontSize: 11.5, fontWeight: 900 } },
206
+ modelKey === "candidate" ? "Hyper3 Top Refs" : "CLIP Top Refs",
207
+ ),
208
+ React.createElement(
209
+ "table",
210
+ { style: { width: "100%", borderCollapse: "collapse" } },
211
+ React.createElement(
212
+ "thead",
213
+ null,
214
+ React.createElement(
215
+ "tr",
216
+ null,
217
+ React.createElement("th", { style: head, align: "left" }, "Rank"),
218
+ React.createElement("th", { style: head, align: "left" }, "SKU"),
219
+ React.createElement("th", { style: head, align: "right" }, "Signal"),
220
+ ),
221
+ ),
222
+ React.createElement(
223
+ "tbody",
224
+ null,
225
+ neighbors.map((neighbor) => {
226
+ const signal = neighbor.sameSkuNormal
227
+ ? "correct normal"
228
+ : neighbor.pipeFryumConfusion
229
+ ? "wrong line"
230
+ : neighbor.sameSku
231
+ ? "same SKU"
232
+ : "other";
233
+ const signalColor = neighbor.sameSkuNormal
234
+ ? colors.good
235
+ : neighbor.pipeFryumConfusion
236
+ ? colors.error
237
+ : colors.bodyText;
238
+ return React.createElement(
239
+ "tr",
240
+ { key: `${modelKey}-${neighbor.rank}-${neighbor.id}` },
241
+ React.createElement("td", { style: { ...cell, color: colors.strongText, fontWeight: 800 } }, `#${neighbor.rank}`),
242
+ React.createElement("td", { style: cell }, pretty(neighbor.sku)),
243
+ React.createElement("td", { style: { ...cell, color: signalColor, fontWeight: 800 }, align: "right" }, signal),
244
+ );
245
+ }),
246
+ ),
247
+ ),
248
+ );
249
+ }
250
+
251
  function StepBlock({ number, title, children }) {
252
  return React.createElement(
253
  "div",
 
390
  const sourceKey = source.includes(":") ? source.split(":").pop() : null;
391
  const model =
392
  models.find((candidate) => candidate.key === sourceKey) ||
393
+ models.find((candidate) => candidate.layoutKey === similarity.layout_key);
 
394
  if (!model) return null;
395
  const metric = advantageMetric(model.key, item.id);
396
  return {
 
401
  };
402
  }
403
 
 
 
 
 
 
 
 
 
 
 
 
 
404
  function buttonText(model) {
405
  if (model.key === "candidate") return "Hyper3";
406
  if (model.key === "clip") return "CLIP";
 
473
  { style: { color: colors.mutedText, fontSize: 10, lineHeight: 1.3 } },
474
  "Wrong-line references send operators to the wrong golden sample.",
475
  ),
476
+ React.createElement(ActiveNeighbors, { item, modelKey: activeModelKey }),
477
  React.createElement(CompactEvidence, { item, models }),
478
  );
479
  }
 
482
  const props = usePanelProps() || {};
483
  const commands = usePanelCommands();
484
  const runtimeState = usePanelRuntimeState ? usePanelRuntimeState() : {};
 
485
  const models = normalizeModels(props.models);
486
  const examples = Array.isArray(props.examples) ? props.examples : [];
487
  const primaryExample = examples.find((item) => item.id === "fryum") || examples[0] || null;
 
506
  const item = examples.find((example) => example.queryId === sampleId);
507
  const metric = advantageMetric(model.key, item?.id);
508
  const nextChoice = {
509
+ modelKey: model.key,
510
  modelName: model.displayName,
511
  queryLabel: title(item?.queryLabel || "fryum"),
512
  metricLine: metric?.line || null,
 
516
  setActiveChoice(nextChoice);
517
  setLoadingKey(choiceKey);
518
  try {
519
+ await commands.setActiveLayout(model.layoutKey, { persist: true });
520
+ await commands.showSimilar({
521
+ sampleId,
522
+ layoutKey: model.layoutKey,
523
+ k: 10,
524
+ source: `manufacturing-demo:${model.key}`,
525
+ focus: "samples",
526
+ persist: true,
527
+ });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
528
  setActiveChoice(nextChoice);
529
  setActiveModelKey(model.key);
530
  } catch (error) {
531
+ const message = error instanceof Error ? error.message : String(error);
532
+ setPanelError(`Could not show neighbors: ${message}`);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
533
  } finally {
534
  setLoadingKey(null);
535
  }
536
  },
537
+ [commands, examples],
538
  );
539
 
540
  return React.createElement(
Dockerfile CHANGED
@@ -20,7 +20,8 @@ WORKDIR $HOME/app
20
 
21
  RUN pip install --upgrade pip
22
 
23
- ARG HYPERVIEW_VERSION=0.6.0
 
24
 
25
  # Install CPU-only PyTorch first so the Space does not pull the default CUDA bundle.
26
  RUN pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
@@ -33,6 +34,7 @@ import hyperview as hv
33
  print("hyperview", hv.__version__, inspect.signature(hv.launch))
34
  PY
35
  RUN pip install \
 
36
  "datasets>=4.5.0" \
37
  "Pillow>=12.0.0" \
38
  "timm>=1.0.0" \
 
20
 
21
  RUN pip install --upgrade pip
22
 
23
+ ARG HYPERVIEW_VERSION=0.6.2
24
+ ARG HYPER_MODELS_VERSION=0.3.0
25
 
26
  # Install CPU-only PyTorch first so the Space does not pull the default CUDA bundle.
27
  RUN pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
 
34
  print("hyperview", hv.__version__, inspect.signature(hv.launch))
35
  PY
36
  RUN pip install \
37
+ "hyper-models[ml]==${HYPER_MODELS_VERSION}" \
38
  "datasets>=4.5.0" \
39
  "Pillow>=12.0.0" \
40
  "timm>=1.0.0" \
README.md CHANGED
@@ -14,7 +14,7 @@ This Space builds a balanced subset of the VisA industrial visual anomaly
14
  dataset and opens HyperView with two side-by-side embedding spaces:
15
 
16
  - CLIP ViT-B/32 in a Euclidean 2D layout
17
- - Hyper3-CLIP `hyper3labs/hyper3-clip-v0.5` in a Poincare 2D layout
18
 
19
  The workflow is inspection reference retrieval: given a production-line
20
  inspection image, retrieve the right normal references for the same SKU or
@@ -59,8 +59,8 @@ VISA_SAMPLES_PER_CATEGORY=12 HYPERVIEW_PORT=6265 \
59
  uv run python hyperview-spaces/spaces/manufacturing-visa-reference-clip-hyper3clip/demo.py
60
  ```
61
 
62
- Hyper3-CLIP weights are loaded from the gated
63
- `hyper3labs/hyper3-clip-v0.5` model repository at runtime. The Space needs an
64
- `HF_TOKEN` secret with access to that model. If unavailable, the Space can start
65
- with a clearly labeled CLIP fallback unless `HYPERVIEW_ALLOW_CANDIDATE_FALLBACK=0`
66
- is set.
 
14
  dataset and opens HyperView with two side-by-side embedding spaces:
15
 
16
  - CLIP ViT-B/32 in a Euclidean 2D layout
17
+ - Hyper3-CLIP `hyper3-clip-v0.5` from `hyper-models` in a Poincare 2D layout
18
 
19
  The workflow is inspection reference retrieval: given a production-line
20
  inspection image, retrieve the right normal references for the same SKU or
 
59
  uv run python hyperview-spaces/spaces/manufacturing-visa-reference-clip-hyper3clip/demo.py
60
  ```
61
 
62
+ Hyper3-CLIP weights are loaded through the `hyper-models` catalog entry for the
63
+ gated `hyper3labs/hyper3-clip-v0.5` model repository at runtime. The Space needs
64
+ an `HF_TOKEN` secret with access to that model. If unavailable, the Space can
65
+ start with a clearly labeled CLIP fallback unless
66
+ `HYPERVIEW_ALLOW_CANDIDATE_FALLBACK=0` is set.
demo.py CHANGED
@@ -18,7 +18,6 @@ from PIL import Image, ImageOps
18
 
19
  import hyperview as hv
20
 
21
-
22
  SPACE_DIR = Path(__file__).resolve().parent
23
  SPACE_HOST = os.environ.get("HYPERVIEW_HOST", "127.0.0.1")
24
  SPACE_PORT = int(os.environ.get("HYPERVIEW_PORT", "6265"))
@@ -29,6 +28,11 @@ EXTENSION_DIR = SPACE_DIR / ".hyperview" / "extensions" / "manufacturing-readout
29
  SAMPLES_PER_CATEGORY = int(os.environ.get("VISA_SAMPLES_PER_CATEGORY", "4"))
30
  TRAIN_FRACTION = float(os.environ.get("VISA_TRAIN_FRACTION", "0.5"))
31
  IMAGE_MAX_SIZE = (640, 640)
 
 
 
 
 
32
  ALLOW_CANDIDATE_FALLBACK = os.environ.get("HYPERVIEW_ALLOW_CANDIDATE_FALLBACK", "1").lower() in {
33
  "1",
34
  "true",
@@ -90,8 +94,8 @@ MODEL_SPECS = [
90
  "key": "candidate",
91
  "display_name": os.environ.get("VISA_CANDIDATE_DISPLAY_NAME", "Hyper3-CLIP"),
92
  "button_label": os.environ.get("VISA_CANDIDATE_BUTTON_LABEL", "Show Hyper3 neighbors"),
93
- "provider": os.environ.get("VISA_CANDIDATE_PROVIDER", "hyper3-clip"),
94
- "model": os.environ.get("VISA_CANDIDATE_MODEL", "hyper3labs/hyper3-clip-v0.5"),
95
  "layout": os.environ.get("VISA_CANDIDATE_LAYOUT", "poincare:2d"),
96
  "geometry": os.environ.get("VISA_CANDIDATE_GEOMETRY", "poincare"),
97
  "layout_dimension": int(os.environ.get("VISA_CANDIDATE_LAYOUT_DIMENSION", "2")),
@@ -219,6 +223,7 @@ def add_visa_samples(dataset: hv.Dataset) -> None:
219
  media_dir = media_root()
220
  added = 0
221
  updated = 0
 
222
  for record in select_visa_records():
223
  sample_id = safe_sample_id(record["category"], record["split_name"], record["row_index"], record["defect_label"])
224
  destination = Path(record["local_path"]) if record.get("local_path") else media_dir / f"{sample_id}.jpg"
@@ -234,12 +239,17 @@ def add_visa_samples(dataset: hv.Dataset) -> None:
234
  "source_dataset": "BrachioLab/visa",
235
  }
236
  existed = sample_id in existing_ids
 
 
 
237
  dataset.add_image(str(destination), label=record["category"], metadata=metadata, sample_id=sample_id)
238
  if existed:
239
  updated += 1
240
  else:
241
  added += 1
242
  existing_ids.add(sample_id)
 
 
243
  print(f"Prepared VisA samples ({added} added, {updated} updated).", flush=True)
244
 
245
 
@@ -262,6 +272,7 @@ def ensure_layouts(dataset: hv.Dataset) -> dict[str, str]:
262
  )
263
  print(warning, flush=True)
264
  RUNTIME_WARNINGS.append(warning)
 
265
  spec.update(
266
  {
267
  "display_name": "Hyper3-CLIP unavailable (CLIP fallback)",
@@ -270,21 +281,22 @@ def ensure_layouts(dataset: hv.Dataset) -> dict[str, str]:
270
  "layout_dimension": MODEL_SPECS[0]["layout_dimension"],
271
  "panel_title": "Hyper3-CLIP unavailable - showing CLIP fallback",
272
  "fallback": True,
273
- "space_key": MODEL_SPECS[0].get("space_key"),
274
  }
275
  )
276
- layouts[spec["key"]] = layouts["clip"]
277
  continue
278
  raise
279
- spec["space_key"] = space_key
280
  print(f"Ensuring {spec['display_name']} layout...", flush=True)
281
- layouts[spec["key"]] = dataset.compute_visualization(
282
  space_key=space_key,
283
  layout=spec["layout"],
284
  n_neighbors=20,
285
  min_dist=0.08,
286
  metric=spec["metric"],
287
  )
 
 
288
  return layouts
289
 
290
 
@@ -305,24 +317,19 @@ def model_panel_props(layouts: dict[str, str]) -> list[dict[str, Any]]:
305
  "displayName": spec["display_name"],
306
  "buttonLabel": spec["button_label"],
307
  "layoutKey": layout_key,
308
- "spaceKey": spec.get("space_key") or space_key_from_layout(layout_key),
309
  }
310
  )
311
  return props
312
 
313
 
314
- def space_key_from_layout(layout_key: str) -> str:
315
- return layout_key.split("__euclidean_umap", 1)[0].split("__poincare_umap", 1)[0]
316
-
317
-
318
  def reference_summary(dataset: hv.Dataset, sample_id: str, model_key: str) -> dict[str, Any]:
319
  spec = next((item for item in MODEL_SPECS if item["key"] == model_key), None)
320
- if spec is None or spec.get("space_key") is None:
321
  return {}
322
  query = dataset[sample_id]
323
  query_sku = query.metadata.get("sku")
324
  query_family = query.metadata.get("product_family")
325
- neighbors = dataset.find_similar(sample_id, k=10, space_key=str(spec["space_key"]))
326
  sku_hits = sum(1 for sample, _distance in neighbors if sample.metadata.get("sku") == query_sku)
327
  family_hits = sum(1 for sample, _distance in neighbors if sample.metadata.get("product_family") == query_family)
328
  normal_refs = sum(1 for sample, _distance in neighbors if sample.metadata.get("workflow_role") == "normal_reference")
@@ -431,17 +438,6 @@ def category_strength_rows(dataset: hv.Dataset) -> list[dict[str, str]]:
431
  return sorted(rows, key=lambda row: float(row["delta"]), reverse=True)[:3]
432
 
433
 
434
- def register_hyper3_clip_provider() -> None:
435
- from hyperview.runtime import ProviderRegistry
436
-
437
- ProviderRegistry().register_python(
438
- "hyper3-clip",
439
- "hyper3_clip_provider:Hyper3ClipEmbeddings",
440
- description="Hyper3-CLIP v0.5 image embeddings from hyper3labs/hyper3-clip-v0.5",
441
- overwrite=True,
442
- )
443
-
444
-
445
  def build_demo_view(dataset: hv.Dataset, layouts: dict[str, str]) -> hv.ui.View:
446
  scatter_panels = [
447
  hv.ui.Scatter(
@@ -460,14 +456,20 @@ def build_demo_view(dataset: hv.Dataset, layouts: dict[str, str]) -> hv.ui.View:
460
  extension="manufacturing-readout",
461
  panel="manufacturing-comparison",
462
  position="right",
 
463
  props={
464
- "workspaceId": WORKSPACE_ID,
465
  "models": model_panel_props(layouts),
466
  "examples": build_examples(dataset),
467
  "strengthRows": category_strength_rows(dataset),
468
  "warnings": RUNTIME_WARNINGS,
469
  },
470
  ),
 
 
 
 
 
 
471
  )
472
 
473
 
@@ -491,7 +493,6 @@ def launch_demo(dataset: hv.Dataset, layouts: dict[str, str]) -> hv.Session:
491
 
492
 
493
  def main() -> None:
494
- register_hyper3_clip_provider()
495
  dataset, layouts = build_dataset()
496
  print("Layouts:", flush=True)
497
  for spec in MODEL_SPECS:
 
18
 
19
  import hyperview as hv
20
 
 
21
  SPACE_DIR = Path(__file__).resolve().parent
22
  SPACE_HOST = os.environ.get("HYPERVIEW_HOST", "127.0.0.1")
23
  SPACE_PORT = int(os.environ.get("HYPERVIEW_PORT", "6265"))
 
28
  SAMPLES_PER_CATEGORY = int(os.environ.get("VISA_SAMPLES_PER_CATEGORY", "4"))
29
  TRAIN_FRACTION = float(os.environ.get("VISA_TRAIN_FRACTION", "0.5"))
30
  IMAGE_MAX_SIZE = (640, 640)
31
+ FORCE_SAMPLE_REFRESH = os.environ.get("HYPERVIEW_VISA_FORCE_REFRESH", "").lower() in {
32
+ "1",
33
+ "true",
34
+ "yes",
35
+ }
36
  ALLOW_CANDIDATE_FALLBACK = os.environ.get("HYPERVIEW_ALLOW_CANDIDATE_FALLBACK", "1").lower() in {
37
  "1",
38
  "true",
 
94
  "key": "candidate",
95
  "display_name": os.environ.get("VISA_CANDIDATE_DISPLAY_NAME", "Hyper3-CLIP"),
96
  "button_label": os.environ.get("VISA_CANDIDATE_BUTTON_LABEL", "Show Hyper3 neighbors"),
97
+ "provider": os.environ.get("VISA_CANDIDATE_PROVIDER", "hyper-models"),
98
+ "model": os.environ.get("VISA_CANDIDATE_MODEL", "hyper3-clip-v0.5"),
99
  "layout": os.environ.get("VISA_CANDIDATE_LAYOUT", "poincare:2d"),
100
  "geometry": os.environ.get("VISA_CANDIDATE_GEOMETRY", "poincare"),
101
  "layout_dimension": int(os.environ.get("VISA_CANDIDATE_LAYOUT_DIMENSION", "2")),
 
223
  media_dir = media_root()
224
  added = 0
225
  updated = 0
226
+ skipped = 0
227
  for record in select_visa_records():
228
  sample_id = safe_sample_id(record["category"], record["split_name"], record["row_index"], record["defect_label"])
229
  destination = Path(record["local_path"]) if record.get("local_path") else media_dir / f"{sample_id}.jpg"
 
239
  "source_dataset": "BrachioLab/visa",
240
  }
241
  existed = sample_id in existing_ids
242
+ if existed and not FORCE_SAMPLE_REFRESH:
243
+ skipped += 1
244
+ continue
245
  dataset.add_image(str(destination), label=record["category"], metadata=metadata, sample_id=sample_id)
246
  if existed:
247
  updated += 1
248
  else:
249
  added += 1
250
  existing_ids.add(sample_id)
251
+ if skipped:
252
+ print(f"Skipped {skipped} existing VisA sample rows.", flush=True)
253
  print(f"Prepared VisA samples ({added} added, {updated} updated).", flush=True)
254
 
255
 
 
272
  )
273
  print(warning, flush=True)
274
  RUNTIME_WARNINGS.append(warning)
275
+ fallback_layout_key = layouts["clip"]
276
  spec.update(
277
  {
278
  "display_name": "Hyper3-CLIP unavailable (CLIP fallback)",
 
281
  "layout_dimension": MODEL_SPECS[0]["layout_dimension"],
282
  "panel_title": "Hyper3-CLIP unavailable - showing CLIP fallback",
283
  "fallback": True,
284
+ "layout_key": fallback_layout_key,
285
  }
286
  )
287
+ layouts[spec["key"]] = fallback_layout_key
288
  continue
289
  raise
 
290
  print(f"Ensuring {spec['display_name']} layout...", flush=True)
291
+ layout_key = dataset.compute_visualization(
292
  space_key=space_key,
293
  layout=spec["layout"],
294
  n_neighbors=20,
295
  min_dist=0.08,
296
  metric=spec["metric"],
297
  )
298
+ spec["layout_key"] = layout_key
299
+ layouts[spec["key"]] = layout_key
300
  return layouts
301
 
302
 
 
317
  "displayName": spec["display_name"],
318
  "buttonLabel": spec["button_label"],
319
  "layoutKey": layout_key,
 
320
  }
321
  )
322
  return props
323
 
324
 
 
 
 
 
325
  def reference_summary(dataset: hv.Dataset, sample_id: str, model_key: str) -> dict[str, Any]:
326
  spec = next((item for item in MODEL_SPECS if item["key"] == model_key), None)
327
+ if spec is None or spec.get("layout_key") is None:
328
  return {}
329
  query = dataset[sample_id]
330
  query_sku = query.metadata.get("sku")
331
  query_family = query.metadata.get("product_family")
332
+ neighbors = dataset.find_similar(sample_id, k=10, layout_key=str(spec["layout_key"]))
333
  sku_hits = sum(1 for sample, _distance in neighbors if sample.metadata.get("sku") == query_sku)
334
  family_hits = sum(1 for sample, _distance in neighbors if sample.metadata.get("product_family") == query_family)
335
  normal_refs = sum(1 for sample, _distance in neighbors if sample.metadata.get("workflow_role") == "normal_reference")
 
438
  return sorted(rows, key=lambda row: float(row["delta"]), reverse=True)[:3]
439
 
440
 
 
 
 
 
 
 
 
 
 
 
 
441
  def build_demo_view(dataset: hv.Dataset, layouts: dict[str, str]) -> hv.ui.View:
442
  scatter_panels = [
443
  hv.ui.Scatter(
 
456
  extension="manufacturing-readout",
457
  panel="manufacturing-comparison",
458
  position="right",
459
+ layout=hv.ui.PanelLayout(width=340, min_width=300),
460
  props={
 
461
  "models": model_panel_props(layouts),
462
  "examples": build_examples(dataset),
463
  "strengthRows": category_strength_rows(dataset),
464
  "warnings": RUNTIME_WARNINGS,
465
  },
466
  ),
467
+ hv.ui.Samples(
468
+ id="manufacturing-neighbors",
469
+ title="Step 2 - Retrieved References",
470
+ position="bottom",
471
+ layout=hv.ui.PanelLayout(height=220, min_height=180),
472
+ ),
473
  )
474
 
475
 
 
493
 
494
 
495
  def main() -> None:
 
496
  dataset, layouts = build_dataset()
497
  print("Layouts:", flush=True)
498
  for spec in MODEL_SPECS:
hyper3_clip/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from hyper3_clip.models.hyper3_clip import Hyper3CLIP
2
-
3
- __all__ = ["Hyper3CLIP"]
 
 
 
 
hyper3_clip/data/__init__.py DELETED
@@ -1,14 +0,0 @@
1
- from hyper3_clip.data.collators import collate_grounded
2
- from hyper3_clip.data.grit_webdataset import ProcessedGritDataset
3
- from hyper3_clip.data.manifest_dataset import GroundedManifestDataset
4
- from hyper3_clip.data.mixed_dataset import MixedGroundedIterableDataset
5
- from hyper3_clip.data.types import GroundedParent, GroundedRecord
6
-
7
- __all__ = [
8
- "GroundedManifestDataset",
9
- "GroundedParent",
10
- "GroundedRecord",
11
- "MixedGroundedIterableDataset",
12
- "ProcessedGritDataset",
13
- "collate_grounded",
14
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/data/collators.py DELETED
@@ -1,209 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import re
4
- from typing import Any
5
-
6
- import torch
7
-
8
-
9
- def _attention_mask(tokens) -> torch.Tensor:
10
- if "attention_mask" in tokens:
11
- return tokens["attention_mask"]
12
- return torch.ones_like(tokens["input_ids"])
13
-
14
-
15
- def collate_grounded(
16
- batch: list[dict[str, Any]],
17
- tokenizer,
18
- max_text_length: int,
19
- *,
20
- beta_clip_queries: bool = False,
21
- beta_clip_max_sentences: int = 5,
22
- beta_clip_max_phrases: int = 30,
23
- beta_clip_max_queries_per_image: int | None = None,
24
- beta_clip_use_part_texts: bool = True,
25
- ) -> dict[str, torch.Tensor]:
26
- images = torch.stack([item["image"] for item in batch])
27
- captions = [item["caption"] for item in batch]
28
- part_image_rows: list[torch.Tensor] = []
29
- part_texts: list[str] = []
30
- part_owner: list[int] = []
31
-
32
- for batch_index, item in enumerate(batch):
33
- for part_index, part_image in enumerate(item["part_images"]):
34
- part_image_rows.append(part_image)
35
- part_texts.append(item["part_texts"][part_index])
36
- part_owner.append(batch_index)
37
-
38
- text = tokenizer(captions, padding=True, truncation=True, max_length=max_text_length, return_tensors="pt")
39
- text_attention_mask = _attention_mask(text)
40
- if part_image_rows:
41
- part_images = torch.stack(part_image_rows)
42
- part_text = tokenizer(part_texts, padding=True, truncation=True, max_length=max_text_length, return_tensors="pt")
43
- part_text_input_ids = part_text["input_ids"]
44
- part_text_attention_mask = _attention_mask(part_text)
45
- else:
46
- part_images = images.new_zeros((0, *images.shape[1:]))
47
- empty_text_shape = (0, text["input_ids"].shape[1])
48
- part_text_input_ids = text["input_ids"].new_zeros(empty_text_shape)
49
- part_text_attention_mask = text_attention_mask.new_zeros(empty_text_shape)
50
-
51
- collated = {
52
- "image": images,
53
- "part_images": part_images,
54
- "part_owner": torch.tensor(part_owner, dtype=torch.long),
55
- "text_input_ids": text["input_ids"],
56
- "text_attention_mask": text_attention_mask,
57
- "part_text_input_ids": part_text_input_ids,
58
- "part_text_attention_mask": part_text_attention_mask,
59
- }
60
-
61
- if beta_clip_queries:
62
- query_texts: list[str] = []
63
- query_owner: list[int] = []
64
- query_type: list[int] = []
65
- query_parent: list[int] = []
66
- query_weight: list[float] = []
67
- query_source_part: list[int] = []
68
- part_offsets = []
69
- cursor = 0
70
- for item in batch:
71
- part_offsets.append(cursor)
72
- cursor += len(item["part_images"])
73
- for batch_index, item in enumerate(batch):
74
- image_queries = _beta_clip_query_items_for_item(
75
- caption=item["caption"],
76
- part_texts=item["part_texts"],
77
- max_sentences=beta_clip_max_sentences,
78
- max_phrases=beta_clip_max_phrases,
79
- max_queries=beta_clip_max_queries_per_image,
80
- use_part_texts=beta_clip_use_part_texts,
81
- )
82
- query_offset = len(query_texts)
83
- for query in image_queries:
84
- query_texts.append(query["text"])
85
- query_owner.append(batch_index)
86
- query_type.append(query["type"])
87
- local_parent = query["parent"]
88
- query_parent.append(-1 if local_parent < 0 else query_offset + local_parent)
89
- query_weight.append(query["weight"])
90
- local_part = query["source_part"]
91
- query_source_part.append(-1 if local_part < 0 else part_offsets[batch_index] + local_part)
92
- query_tokens = tokenizer(query_texts, padding=True, truncation=True, max_length=max_text_length, return_tensors="pt")
93
- collated.update(
94
- {
95
- "beta_query_input_ids": query_tokens["input_ids"],
96
- "beta_query_attention_mask": _attention_mask(query_tokens),
97
- "beta_query_owner": torch.tensor(query_owner, dtype=torch.long),
98
- "beta_query_type": torch.tensor(query_type, dtype=torch.long),
99
- "beta_query_parent": torch.tensor(query_parent, dtype=torch.long),
100
- "beta_query_weight": torch.tensor(query_weight, dtype=torch.float32),
101
- "beta_query_source_part": torch.tensor(query_source_part, dtype=torch.long),
102
- }
103
- )
104
-
105
- return collated
106
-
107
-
108
- def _beta_clip_queries_for_item(
109
- *,
110
- caption: str,
111
- part_texts: list[str],
112
- max_sentences: int,
113
- max_phrases: int,
114
- max_queries: int | None,
115
- use_part_texts: bool,
116
- ) -> list[str]:
117
- return [
118
- query["text"]
119
- for query in _beta_clip_query_items_for_item(
120
- caption=caption,
121
- part_texts=part_texts,
122
- max_sentences=max_sentences,
123
- max_phrases=max_phrases,
124
- max_queries=max_queries,
125
- use_part_texts=use_part_texts,
126
- )
127
- ]
128
-
129
-
130
- def _beta_clip_query_items_for_item(
131
- *,
132
- caption: str,
133
- part_texts: list[str],
134
- max_sentences: int,
135
- max_phrases: int,
136
- max_queries: int | None,
137
- use_part_texts: bool,
138
- ) -> list[dict[str, str | int | float]]:
139
- queries: list[str] = []
140
- query_items: list[dict[str, str | int | float]] = []
141
- seen: set[str] = set()
142
-
143
- def add_query(text: str, *, query_type: int, parent: int = 0, weight: float = 1.0, source_part: int = -1) -> int:
144
- normalized = " ".join(str(text).strip().split())
145
- key = normalized.casefold()
146
- if len(normalized) >= 3 and key not in seen:
147
- seen.add(key)
148
- queries.append(normalized)
149
- query_items.append(
150
- {
151
- "text": normalized,
152
- "type": query_type,
153
- "parent": parent,
154
- "weight": weight,
155
- "source_part": source_part,
156
- }
157
- )
158
- return len(queries) - 1
159
- return -1
160
-
161
- caption_index = add_query(caption, query_type=0, parent=-1, weight=0.0)
162
- if caption_index < 0:
163
- caption_index = 0
164
- sentence_indices: list[int] = []
165
- for sentence in _split_sentences(caption)[: max(0, max_sentences)]:
166
- sentence_index = add_query(sentence, query_type=1, parent=caption_index, weight=1.0)
167
- if sentence_index >= 0:
168
- sentence_indices.append(sentence_index)
169
- phrase_parent = sentence_indices[0] if sentence_indices else caption_index
170
-
171
- phrase_count = 0
172
- if use_part_texts:
173
- for part_index, part_text in enumerate(part_texts):
174
- if phrase_count >= max_phrases:
175
- break
176
- before = len(queries)
177
- add_query(part_text, query_type=2, parent=caption_index, weight=0.75, source_part=part_index)
178
- phrase_count += int(len(queries) > before)
179
-
180
- if phrase_count < max_phrases:
181
- for phrase in _extract_lightweight_phrases(caption):
182
- if phrase_count >= max_phrases:
183
- break
184
- before = len(queries)
185
- add_query(phrase, query_type=3, parent=phrase_parent, weight=0.5)
186
- phrase_count += int(len(queries) > before)
187
-
188
- if max_queries is not None:
189
- return query_items[: max(1, int(max_queries))]
190
- return query_items
191
-
192
-
193
- def _split_sentences(text: str) -> list[str]:
194
- return [part.strip() for part in re.split(r"(?<=[.!?;])\s+|\n+", text) if part.strip()]
195
-
196
-
197
- def _extract_lightweight_phrases(text: str) -> list[str]:
198
- chunks = re.split(r"[,;:()]|\s+(?:and|with|near|beside|behind|under|above|around|next to)\s+", text, flags=re.I)
199
- phrases: list[str] = []
200
- for chunk in chunks:
201
- words = re.findall(r"[A-Za-z0-9]+(?:[-'][A-Za-z0-9]+)?", chunk)
202
- if 2 <= len(words) <= 8:
203
- phrases.append(" ".join(words))
204
- elif len(words) > 8:
205
- for start in range(0, len(words) - 1, 4):
206
- phrase = " ".join(words[start : start + 6])
207
- if len(phrase.split()) >= 2:
208
- phrases.append(phrase)
209
- return phrases
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/data/grit_cleaning.py DELETED
@@ -1,554 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import math
4
- import re
5
- import unicodedata
6
- from dataclasses import dataclass
7
- from typing import Any
8
-
9
- import numpy as np
10
- from PIL import Image
11
-
12
-
13
- SPACE_RE = re.compile(r"\s+")
14
- URL_RE = re.compile(r"(https?://|www\.|\.com\b|\.net\b|\.org\b)", re.IGNORECASE)
15
- EMAIL_RE = re.compile(r"\b[\w.+-]+@[\w.-]+\.\w+\b")
16
- HTML_RE = re.compile(r"<[^>]+>")
17
- TOKEN_RE = re.compile(r"[a-z0-9]+(?:[-'][a-z0-9]+)?")
18
-
19
- LEADING_DETERMINERS = {
20
- "a",
21
- "an",
22
- "the",
23
- "this",
24
- "that",
25
- "these",
26
- "those",
27
- "his",
28
- "her",
29
- "its",
30
- "their",
31
- "my",
32
- "our",
33
- "your",
34
- }
35
-
36
- QUANTITY_WORDS = {
37
- "one",
38
- "two",
39
- "three",
40
- "four",
41
- "five",
42
- "six",
43
- "many",
44
- "several",
45
- "some",
46
- "few",
47
- "group",
48
- "pair",
49
- }
50
-
51
- VISUAL_MODIFIERS = {
52
- "big",
53
- "small",
54
- "large",
55
- "little",
56
- "old",
57
- "young",
58
- "new",
59
- "red",
60
- "blue",
61
- "green",
62
- "yellow",
63
- "white",
64
- "black",
65
- "brown",
66
- "gray",
67
- "grey",
68
- "orange",
69
- "pink",
70
- "purple",
71
- "colorful",
72
- "colourful",
73
- "wooden",
74
- "metal",
75
- "plastic",
76
- "striped",
77
- }
78
-
79
- NON_VISUAL_HEADS = {
80
- "background",
81
- "foreground",
82
- "caption",
83
- "copyright",
84
- "credit",
85
- "item",
86
- "edge",
87
- "image",
88
- "left",
89
- "logo",
90
- "method",
91
- "middle",
92
- "number",
93
- "photo",
94
- "photograph",
95
- "picture",
96
- "place",
97
- "right",
98
- "scene",
99
- "side",
100
- "statement",
101
- "stock",
102
- "text",
103
- "thing",
104
- "view",
105
- "watermark",
106
- }
107
-
108
- NON_VISUAL_PHRASES = {
109
- "available at",
110
- "all rights reserved",
111
- "click here",
112
- "copyright",
113
- "getty images",
114
- "istock",
115
- "shutterstock",
116
- "stock photo",
117
- }
118
-
119
- ACTION_SPLITS = (
120
- " standing ",
121
- " sitting ",
122
- " lying ",
123
- " walking ",
124
- " running ",
125
- " flying ",
126
- " eating ",
127
- " holding ",
128
- " wearing ",
129
- " playing ",
130
- )
131
-
132
- PREPOSITION_SPLITS = (
133
- " next to ",
134
- " in front of ",
135
- " on top of ",
136
- " inside ",
137
- " outside ",
138
- " with ",
139
- " without ",
140
- " near ",
141
- " beside ",
142
- " behind ",
143
- " under ",
144
- " over ",
145
- " from ",
146
- " into ",
147
- " across ",
148
- " around ",
149
- " at ",
150
- " on ",
151
- " in ",
152
- " of ",
153
- )
154
-
155
- CANONICAL_REWRITES = {
156
- "aeroplane": "airplane",
157
- "aircraft": "airplane",
158
- "bike": "bicycle",
159
- "cell phone": "phone",
160
- "mobile phone": "phone",
161
- "motorbike": "motorcycle",
162
- "plant pot": "potted plant",
163
- "tv": "television",
164
- }
165
-
166
- TOKEN_SYNONYMS = {
167
- "airplane": {"airplane", "aeroplane", "aircraft", "plane"},
168
- "bicycle": {"bicycle", "bike"},
169
- "motorcycle": {"motorcycle", "motorbike"},
170
- "person": {"person", "people", "man", "woman", "boy", "girl", "teenager", "teenagers"},
171
- "people": {"person", "people", "man", "woman", "men", "women", "children", "teenager", "teenagers"},
172
- "phone": {"phone", "cell", "mobile", "telephone"},
173
- "television": {"television", "tv"},
174
- }
175
-
176
- HUMAN_GROUP_WORDS = {
177
- "adults",
178
- "boys",
179
- "children",
180
- "crowd",
181
- "girls",
182
- "kids",
183
- "men",
184
- "people",
185
- "teenagers",
186
- "teens",
187
- "women",
188
- }
189
-
190
- HUMAN_SINGULAR_WORDS = {
191
- "adult",
192
- "baby",
193
- "boy",
194
- "child",
195
- "girl",
196
- "kid",
197
- "man",
198
- "person",
199
- "teenager",
200
- "woman",
201
- }
202
-
203
- HUMAN_ROLE_WORDS = {
204
- "actor",
205
- "actress",
206
- "artist",
207
- "athlete",
208
- "boss",
209
- "coach",
210
- "doctor",
211
- "lawyer",
212
- "manager",
213
- "minister",
214
- "musician",
215
- "player",
216
- "politician",
217
- "president",
218
- "singer",
219
- "teacher",
220
- }
221
-
222
- DEFAULT_HYPERNYMS = {
223
- "airplane": ("aircraft", "vehicle"),
224
- "apple": ("fruit", "food"),
225
- "backpack": ("bag", "accessory"),
226
- "baseball bat": ("bat", "sports equipment"),
227
- "bear": ("mammal", "animal"),
228
- "bicycle": ("vehicle",),
229
- "bird": ("animal",),
230
- "boat": ("vehicle",),
231
- "bottle": ("container",),
232
- "bus": ("vehicle",),
233
- "car": ("vehicle",),
234
- "cat": ("mammal", "animal"),
235
- "chair": ("furniture",),
236
- "cup": ("container",),
237
- "dog": ("mammal", "animal"),
238
- "flower": ("plant",),
239
- "fork": ("utensil",),
240
- "horse": ("mammal", "animal"),
241
- "knife": ("utensil",),
242
- "lamp": ("light", "furniture"),
243
- "laptop": ("computer", "electronic device"),
244
- "person": ("human", "animal"),
245
- "phone": ("electronic device",),
246
- "potted plant": ("plant",),
247
- "shirt": ("clothing",),
248
- "shoe": ("footwear", "clothing"),
249
- "skis": ("sports equipment",),
250
- "spoon": ("utensil",),
251
- "sports ball": ("ball", "sports equipment"),
252
- "table": ("furniture",),
253
- "television": ("electronic device",),
254
- "train": ("vehicle",),
255
- "tree": ("plant",),
256
- "truck": ("vehicle",),
257
- }
258
-
259
-
260
- @dataclass(frozen=True)
261
- class ImageQuality:
262
- width: int
263
- height: int
264
- brightness: float
265
- contrast: float
266
- entropy: float
267
- black_border_fraction: float
268
-
269
-
270
- @dataclass(frozen=True)
271
- class ParentCleanDecision:
272
- original_text: str
273
- canonical_text: str
274
- keep: bool
275
- quality_score: float
276
- reasons: tuple[str, ...]
277
- hypernyms: tuple[str, ...]
278
- image_quality: ImageQuality | None = None
279
-
280
-
281
- def clean_parent(
282
- parent_text: str,
283
- caption: str = "",
284
- parent_image: Image.Image | None = None,
285
- min_score: float = 0.45,
286
- hypernym_map: dict[str, tuple[str, ...]] | None = None,
287
- ) -> ParentCleanDecision:
288
- canonical = canonicalize_parent_text(parent_text)
289
- reasons: list[str] = []
290
- fatal = False
291
-
292
- if not canonical:
293
- reasons.append("empty_after_canonicalization")
294
- fatal = True
295
- if looks_like_boilerplate(parent_text):
296
- reasons.append("boilerplate_or_url")
297
- fatal = True
298
- if canonical and is_non_visual_parent(canonical):
299
- reasons.append("non_visual_parent")
300
- fatal = True
301
- if canonical and len(canonical.split()) > 6:
302
- reasons.append("too_long_for_clean_parent")
303
- if canonical and caption_duplicates_parent(caption, canonical):
304
- reasons.append("duplicates_caption")
305
- fatal = True
306
- if canonical and not caption_mentions_parent(caption, canonical):
307
- reasons.append("caption_does_not_mention_parent")
308
-
309
- image_quality = image_quality_stats(parent_image) if parent_image is not None else None
310
- if image_quality is not None:
311
- if image_quality.entropy < 1.0 or image_quality.contrast < 3.0:
312
- reasons.append("low_information_crop")
313
- if image_quality.black_border_fraction > 0.65:
314
- reasons.append("mostly_black_border")
315
- if (
316
- "caption_does_not_mention_parent" in reasons
317
- and "low_information_crop" in reasons
318
- and "mostly_black_border" in reasons
319
- ):
320
- reasons.append("text_slide_or_bad_crop")
321
- fatal = True
322
-
323
- score = parent_quality_score(canonical, reasons, image_quality)
324
- hmap = DEFAULT_HYPERNYMS if hypernym_map is None else hypernym_map
325
- hypernyms = tuple(hmap.get(canonical, ()))
326
- keep = not fatal and score >= min_score
327
- return ParentCleanDecision(
328
- original_text=parent_text,
329
- canonical_text=canonical,
330
- keep=keep,
331
- quality_score=score,
332
- reasons=tuple(reasons),
333
- hypernyms=hypernyms,
334
- image_quality=image_quality,
335
- )
336
-
337
-
338
- def canonicalize_parent_text(text: str) -> str:
339
- text = normalize_text(text)
340
- if not text:
341
- return ""
342
- text = strip_boilerplate_tail(text)
343
- for marker in ACTION_SPLITS:
344
- if marker in text:
345
- text = text.split(marker, maxsplit=1)[0].strip()
346
- break
347
- human = canonicalize_human_text(text)
348
- if human:
349
- return human
350
- for marker in PREPOSITION_SPLITS:
351
- if marker in text:
352
- text = text.split(marker, maxsplit=1)[0].strip()
353
- break
354
- tokens = TOKEN_RE.findall(text)
355
- while tokens and (tokens[0] in LEADING_DETERMINERS or tokens[0] in QUANTITY_WORDS):
356
- tokens.pop(0)
357
- if len(tokens) > 2:
358
- while tokens and tokens[0] in VISUAL_MODIFIERS:
359
- tokens.pop(0)
360
- candidate = " ".join(tokens).strip()
361
- candidate = CANONICAL_REWRITES.get(candidate, candidate)
362
- if candidate.endswith("s") and candidate[:-1] in DEFAULT_HYPERNYMS:
363
- candidate = candidate[:-1]
364
- return candidate
365
-
366
-
367
- def canonicalize_human_text(text: str) -> str:
368
- tokens = TOKEN_RE.findall(text)
369
- if not tokens:
370
- return ""
371
- token_set = set(tokens)
372
- if token_set.intersection(HUMAN_GROUP_WORDS):
373
- return "people"
374
- for word in ("baby", "woman", "man", "girl", "boy", "child", "teenager", "person"):
375
- if word in token_set:
376
- return word
377
- if token_set.intersection(HUMAN_ROLE_WORDS):
378
- return "person"
379
- return ""
380
-
381
-
382
- def normalize_text(text: str) -> str:
383
- text = unicodedata.normalize("NFKC", str(text))
384
- text = HTML_RE.sub(" ", text)
385
- text = text.replace("_", " ").replace("/", " ")
386
- text = text.strip().lower()
387
- text = text.strip(" \t\r\n\"'.,;:!?()[]{}")
388
- return SPACE_RE.sub(" ", text)
389
-
390
-
391
- def strip_boilerplate_tail(text: str) -> str:
392
- for marker in (" - available at ", " available at ", " | ", " © ", " copyright "):
393
- if marker in text:
394
- text = text.split(marker, maxsplit=1)[0]
395
- return text.strip()
396
-
397
-
398
- def looks_like_boilerplate(text: str) -> bool:
399
- normalized = normalize_text(text)
400
- if URL_RE.search(normalized) or EMAIL_RE.search(normalized):
401
- return True
402
- return any(phrase in normalized for phrase in NON_VISUAL_PHRASES)
403
-
404
-
405
- def is_non_visual_parent(canonical: str) -> bool:
406
- tokens = canonical.split()
407
- if not tokens:
408
- return True
409
- if canonical in NON_VISUAL_HEADS:
410
- return True
411
- if tokens[-1] in NON_VISUAL_HEADS:
412
- return True
413
- if all(token.isdigit() for token in tokens):
414
- return True
415
- return False
416
-
417
-
418
- def caption_mentions_parent(caption: str, canonical_parent: str) -> bool:
419
- if not caption or not canonical_parent:
420
- return True
421
- caption_tokens = set(TOKEN_RE.findall(normalize_text(caption)))
422
- parent_tokens = TOKEN_RE.findall(canonical_parent)
423
- if not parent_tokens:
424
- return False
425
- for token in parent_tokens:
426
- synonyms = TOKEN_SYNONYMS.get(token, {token})
427
- if not caption_tokens.intersection(synonyms):
428
- return False
429
- return True
430
-
431
-
432
- def caption_duplicates_parent(caption: str, canonical_parent: str) -> bool:
433
- if not caption or not canonical_parent:
434
- return False
435
- caption_tokens = TOKEN_RE.findall(normalize_text(caption))
436
- parent_tokens = TOKEN_RE.findall(canonical_parent)
437
- if len(parent_tokens) < 6 or not caption_tokens:
438
- return False
439
- caption_set = set(caption_tokens)
440
- parent_set = set(parent_tokens)
441
- overlap = len(caption_set.intersection(parent_set))
442
- return overlap / max(len(parent_set), 1) >= 0.85 and overlap / max(len(caption_set), 1) >= 0.65
443
-
444
-
445
- def image_quality_stats(image: Image.Image | None) -> ImageQuality | None:
446
- if image is None:
447
- return None
448
- rgb = image.convert("RGB")
449
- width, height = rgb.size
450
- gray = np.asarray(rgb.convert("L"), dtype=np.float32)
451
- brightness = float(gray.mean())
452
- contrast = float(gray.std())
453
- hist, _ = np.histogram(gray, bins=64, range=(0, 256), density=True)
454
- hist = hist[hist > 0]
455
- entropy = float(-(hist * np.log2(hist)).sum())
456
- border = _border_pixels(gray)
457
- black_border_fraction = float((border < 8).mean()) if border.size else 0.0
458
- return ImageQuality(
459
- width=width,
460
- height=height,
461
- brightness=brightness,
462
- contrast=contrast,
463
- entropy=entropy,
464
- black_border_fraction=black_border_fraction,
465
- )
466
-
467
-
468
- def parent_quality_score(canonical: str, reasons: list[str], image_quality: ImageQuality | None) -> float:
469
- if not canonical:
470
- return 0.0
471
- score = 1.0
472
- penalties = {
473
- "caption_does_not_mention_parent": 0.20,
474
- "too_long_for_clean_parent": 0.20,
475
- "low_information_crop": 0.25,
476
- "mostly_black_border": 0.15,
477
- "non_visual_parent": 0.60,
478
- "boilerplate_or_url": 0.80,
479
- "duplicates_caption": 0.80,
480
- "text_slide_or_bad_crop": 0.80,
481
- }
482
- for reason in reasons:
483
- score -= penalties.get(reason, 0.10)
484
- if image_quality is not None:
485
- if image_quality.brightness < 8 or image_quality.brightness > 247:
486
- score -= 0.10
487
- if image_quality.contrast > 8 and image_quality.entropy > 2:
488
- score += 0.05
489
- if canonical in DEFAULT_HYPERNYMS:
490
- score += 0.05
491
- return max(0.0, min(1.0, score))
492
-
493
-
494
- def merge_vlm_decision(
495
- cheap: ParentCleanDecision,
496
- vlm_payload: dict[str, Any] | None,
497
- vlm_can_rescue: bool = False,
498
- ) -> ParentCleanDecision:
499
- if not vlm_payload:
500
- return cheap
501
- reasons = list(cheap.reasons)
502
- canonical = normalize_text(vlm_payload.get("canonical_parent") or cheap.canonical_text)
503
- if canonical:
504
- canonical = canonicalize_parent_text(canonical)
505
- hypernyms = tuple(
506
- normalize_text(value)
507
- for value in vlm_payload.get("hypernyms", cheap.hypernyms)
508
- if normalize_text(value)
509
- )
510
- quality_score = float(vlm_payload.get("quality_score", cheap.quality_score))
511
- keep_payload = vlm_payload.get("keep")
512
- if keep_payload is False:
513
- reasons.append("vlm_reject")
514
- keep = False
515
- elif keep_payload is True and vlm_can_rescue:
516
- keep = quality_score >= 0.45 and bool(canonical)
517
- else:
518
- keep = cheap.keep and keep_payload is not False
519
- reject_reason = normalize_text(vlm_payload.get("reject_reason") or "")
520
- if reject_reason:
521
- reasons.append(f"vlm:{reject_reason}")
522
- return ParentCleanDecision(
523
- original_text=cheap.original_text,
524
- canonical_text=canonical,
525
- keep=keep,
526
- quality_score=max(0.0, min(1.0, quality_score)),
527
- reasons=tuple(dict.fromkeys(reasons)),
528
- hypernyms=hypernyms,
529
- image_quality=cheap.image_quality,
530
- )
531
-
532
-
533
- def expand_parent_texts(decision: ParentCleanDecision, add_hypernyms: bool) -> tuple[str, ...]:
534
- if not decision.keep:
535
- return ()
536
- values = [decision.canonical_text]
537
- if add_hypernyms:
538
- values.extend(decision.hypernyms)
539
- return tuple(dict.fromkeys(value for value in values if value))
540
-
541
-
542
- def _border_pixels(gray: np.ndarray) -> np.ndarray:
543
- if gray.ndim != 2 or min(gray.shape) < 4:
544
- return np.asarray([], dtype=np.float32)
545
- width = max(1, min(gray.shape) // 16)
546
- top = gray[:width, :].reshape(-1)
547
- bottom = gray[-width:, :].reshape(-1)
548
- left = gray[:, :width].reshape(-1)
549
- right = gray[:, -width:].reshape(-1)
550
- return np.concatenate([top, bottom, left, right])
551
-
552
-
553
- def finite_float(value: float) -> float:
554
- return value if math.isfinite(value) else 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/data/grit_webdataset.py DELETED
@@ -1,133 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import copy
4
- import glob
5
- import hashlib
6
- import random
7
- from collections.abc import Iterator, Sequence
8
- from pathlib import Path
9
- from typing import Any
10
-
11
- import torch
12
- import webdataset as wds
13
- from PIL import Image
14
- from torch.utils.data import IterableDataset, get_worker_info
15
-
16
- from hyper3_clip.data.transforms import build_train_transform
17
- from hyper3_clip.training.distributed import get_rank, get_world_size
18
-
19
-
20
- PART_SAMPLING_MODES = {"random_one", "all"}
21
-
22
-
23
- class ProcessedGritDataset(IterableDataset):
24
- """Reader for official HyCoCLIP processed GRIT shards."""
25
-
26
- def __init__(
27
- self,
28
- tarfiles: Sequence[str],
29
- image_size: int,
30
- seed: int,
31
- shuffle_buffer: int = 4000,
32
- part_sampling: str = "random_one",
33
- max_parts: int | None = None,
34
- train_transform: str = "wide_random_crop",
35
- image_normalization: str = "imagenet",
36
- deterministic_transforms: bool = False,
37
- ) -> None:
38
- self.tarfiles = _expand_tarfiles(tarfiles)
39
- if not self.tarfiles:
40
- raise FileNotFoundError(f"No GRIT processed shards matched {tarfiles!r}")
41
- rank = get_rank()
42
- world_size = get_world_size()
43
- self.tarfiles = self.tarfiles[rank::world_size]
44
- self.shuffle_buffer = shuffle_buffer
45
- self.seed = seed
46
- if part_sampling not in PART_SAMPLING_MODES:
47
- raise ValueError(f"part_sampling must be one of {sorted(PART_SAMPLING_MODES)}, got {part_sampling!r}")
48
- if max_parts is not None and max_parts <= 0:
49
- raise ValueError("max_parts must be positive when set")
50
- self.part_sampling = part_sampling
51
- self.max_parts = max_parts
52
- self.deterministic_transforms = deterministic_transforms
53
- self.transform = build_train_transform(image_size, preset=train_transform, normalization=image_normalization)
54
-
55
- def __iter__(self) -> Iterator[dict[str, Any]]:
56
- worker = get_worker_info()
57
- worker_id = worker.id if worker is not None else 0
58
- shuffle_rng = random.Random(self.seed + get_rank() * 1_000_003 + worker_id)
59
- part_rng = random.Random(self.seed + 31_415_926 + get_rank() * 1_000_003 + worker_id)
60
- pipeline: Any = wds.DataPipeline(
61
- wds.SimpleShardList(self.tarfiles, seed=self.seed),
62
- wds.split_by_worker,
63
- wds.tarfile_to_samples(),
64
- wds.shuffle(self.shuffle_buffer, initial=self.shuffle_buffer, rng=shuffle_rng),
65
- wds.decode("pil", handler=wds.warn_and_continue),
66
- )
67
- while True:
68
- pipeline_copy = copy.deepcopy(pipeline)
69
- for sample in pipeline_copy:
70
- yield self._decode_sample(sample, part_rng)
71
-
72
- def _decode_sample(self, sample: dict[str, Any], rng: random.Random) -> dict[str, Any]:
73
- num_parents = int(_as_text(sample["numparents.txt"]))
74
- parent_indices = self._select_parent_indices(num_parents, rng)
75
- parent_keys = [f"parent{parent_index:03d}" for parent_index in parent_indices]
76
- sample_key = _as_text(sample.get("__key__", ""))
77
- return {
78
- "image": self._transform_image(sample["child.jpg"], sample_key, "child"),
79
- "caption": _as_text(sample["child.txt"]),
80
- "part_images": [
81
- self._transform_image(sample[f"{parent_key}.jpg"], sample_key, parent_key) for parent_key in parent_keys
82
- ],
83
- "part_texts": [_as_text(sample[f"{parent_key}.txt"]) for parent_key in parent_keys],
84
- }
85
-
86
- def _select_parent_indices(self, num_parents: int, rng: random.Random) -> list[int]:
87
- if self.part_sampling == "random_one":
88
- return [rng.randrange(num_parents)]
89
- parent_indices = list(range(num_parents))
90
- if self.max_parts is not None and len(parent_indices) > self.max_parts:
91
- parent_indices = sorted(rng.sample(parent_indices, k=self.max_parts))
92
- return parent_indices
93
-
94
- def _transform_image(self, value: Any, sample_key: str, role: str) -> torch.Tensor:
95
- image = _as_image(value)
96
- if not self.deterministic_transforms:
97
- return self.transform(image)
98
- transform_seed = _stable_seed(self.seed, sample_key, role)
99
- python_random_state = random.getstate()
100
- try:
101
- random.seed(transform_seed)
102
- with torch.random.fork_rng(devices=[]):
103
- torch.manual_seed(transform_seed)
104
- return self.transform(image)
105
- finally:
106
- random.setstate(python_random_state)
107
-
108
-
109
- def _expand_tarfiles(tarfiles: Sequence[str]) -> list[str]:
110
- expanded: list[str] = []
111
- for pattern in tarfiles:
112
- matches = sorted(glob.glob(pattern))
113
- expanded.extend(matches if matches else [pattern])
114
- return [str(Path(path)) for path in expanded]
115
-
116
-
117
- def _as_text(value: Any) -> str:
118
- return value.decode("utf-8") if isinstance(value, bytes) else str(value)
119
-
120
-
121
- def _as_image(value: Any) -> Image.Image:
122
- if not isinstance(value, Image.Image):
123
- raise TypeError(f"Expected PIL image from WebDataset decode, got {type(value)!r}")
124
- return value.convert("RGB")
125
-
126
-
127
- def _stable_seed(seed: int, *parts: str) -> int:
128
- digest = hashlib.blake2b(digest_size=8)
129
- digest.update(str(seed).encode("utf-8"))
130
- for part in parts:
131
- digest.update(b"\0")
132
- digest.update(part.encode("utf-8"))
133
- return int.from_bytes(digest.digest(), byteorder="big", signed=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/data/manifest_dataset.py DELETED
@@ -1,120 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import json
4
- import random
5
- from pathlib import Path
6
- from typing import Any
7
-
8
- import torch
9
- from PIL import Image
10
- from torch.utils.data import Dataset, get_worker_info
11
-
12
- from hyper3_clip.data.collators import collate_grounded as collate_grounded
13
- from hyper3_clip.data.transforms import build_train_transform
14
- from hyper3_clip.data.types import GroundedParent, GroundedRecord
15
-
16
-
17
- __all__ = ["GroundedManifestDataset", "collate_grounded"]
18
- PART_SAMPLING_MODES = {"random_one", "all"}
19
-
20
-
21
- class GroundedManifestDataset(Dataset):
22
- """Manifest dataset with one full image/caption and one or more grounded parents per row."""
23
-
24
- def __init__(
25
- self,
26
- manifests: list[str] | str | Path,
27
- image_size: int,
28
- seed: int,
29
- manifest_weights: list[float] | None = None,
30
- part_sampling: str = "random_one",
31
- max_parts: int | None = None,
32
- train_transform: str = "wide_random_crop",
33
- image_normalization: str = "imagenet",
34
- ) -> None:
35
- manifest_paths = [str(manifests)] if isinstance(manifests, str | Path) else manifests
36
- self.records: list[GroundedRecord] = []
37
- source_records: list[list[GroundedRecord]] = []
38
- for manifest_path in manifest_paths:
39
- rows: list[GroundedRecord] = []
40
- with Path(manifest_path).open("r", encoding="utf-8") as handle:
41
- for line in handle:
42
- if line.strip():
43
- rows.append(GroundedRecord.from_json(json.loads(line)))
44
- source_records.append(rows)
45
-
46
- if manifest_weights is None:
47
- for rows in source_records:
48
- self.records.extend(rows)
49
- else:
50
- if len(manifest_weights) != len(source_records):
51
- raise ValueError("manifest_weights must match manifests length")
52
- max_len = max(len(rows) for rows in source_records if rows)
53
- for rows, weight in zip(source_records, manifest_weights):
54
- if not rows or weight <= 0.0:
55
- continue
56
- target_len = max(1, int(round(max_len * weight)))
57
- for idx in range(target_len):
58
- self.records.append(rows[idx % len(rows)])
59
-
60
- self.seed = seed
61
- if part_sampling not in PART_SAMPLING_MODES:
62
- raise ValueError(f"part_sampling must be one of {sorted(PART_SAMPLING_MODES)}, got {part_sampling!r}")
63
- if max_parts is not None and max_parts <= 0:
64
- raise ValueError("max_parts must be positive when set")
65
- self.part_sampling = part_sampling
66
- self.max_parts = max_parts
67
- self.transform = build_train_transform(image_size, preset=train_transform, normalization=image_normalization)
68
-
69
- def __len__(self) -> int:
70
- return len(self.records)
71
-
72
- def __getitem__(self, index: int) -> dict[str, Any]:
73
- record = self.records[index]
74
- parents = self._select_parents(index, record.parents)
75
- return {
76
- "image": self._load_image(record.image_path),
77
- "part_images": [self._load_parent_image(record.image_path, parent) for parent in parents],
78
- "caption": record.caption,
79
- "part_texts": [parent.text for parent in parents],
80
- }
81
-
82
- def _select_parents(self, index: int, parents: tuple[GroundedParent, ...]) -> tuple[GroundedParent, ...]:
83
- if self.part_sampling == "all":
84
- if self.max_parts is None or len(parents) <= self.max_parts:
85
- return parents
86
- worker = get_worker_info()
87
- worker_id = worker.id if worker is not None else 0
88
- rng = random.Random(self.seed + index + 1_000_003 * worker_id)
89
- parent_indices = sorted(rng.sample(range(len(parents)), k=self.max_parts))
90
- return tuple(parents[parent_index] for parent_index in parent_indices)
91
- worker = get_worker_info()
92
- worker_id = worker.id if worker is not None else 0
93
- rng = random.Random(self.seed + index + 1_000_003 * worker_id)
94
- return (parents[rng.randrange(len(parents))],)
95
-
96
- def _load_image(self, path: Path) -> torch.Tensor:
97
- with Image.open(path) as image:
98
- return self.transform(image.convert("RGB"))
99
-
100
- def _load_parent_image(self, image_path: Path, parent: GroundedParent) -> torch.Tensor:
101
- source_path = parent.image_path or image_path
102
- with Image.open(source_path) as image:
103
- rgb = image.convert("RGB")
104
- if parent.bbox is not None:
105
- rgb = _crop_bbox(rgb, parent.bbox)
106
- return self.transform(rgb)
107
-
108
-
109
- def _crop_bbox(image: Image.Image, bbox: tuple[float, float, float, float]) -> Image.Image:
110
- width, height = image.size
111
- left, top, right, bottom = bbox
112
- crop_box = (
113
- max(0, min(width, int(round(left)))),
114
- max(0, min(height, int(round(top)))),
115
- max(0, min(width, int(round(right)))),
116
- max(0, min(height, int(round(bottom)))),
117
- )
118
- if crop_box[2] <= crop_box[0] or crop_box[3] <= crop_box[1]:
119
- return image
120
- return image.crop(crop_box)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/data/mixed_dataset.py DELETED
@@ -1,68 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import random
4
- from collections.abc import Iterator
5
- from typing import Any
6
-
7
- from torch.utils.data import Dataset, IterableDataset, get_worker_info
8
-
9
- from hyper3_clip.training.distributed import get_rank, get_world_size
10
-
11
-
12
- class MixedGroundedIterableDataset(IterableDataset):
13
- """Infinite stream that mixes a primary stream with a finite grounded dataset.
14
-
15
- This is intended for cleaned processed-GRIT plus explicit taxonomy hierarchy
16
- manifests. The primary stream remains the pacing dataset, while auxiliary
17
- examples are sampled with a fixed probability.
18
- """
19
-
20
- def __init__(
21
- self,
22
- primary: IterableDataset,
23
- auxiliary: Dataset,
24
- auxiliary_probability: float,
25
- seed: int,
26
- ) -> None:
27
- if not 0.0 <= auxiliary_probability <= 1.0:
28
- raise ValueError("auxiliary_probability must be in [0, 1]")
29
- if len(auxiliary) == 0:
30
- raise ValueError("auxiliary dataset must not be empty")
31
- self.primary = primary
32
- self.auxiliary = auxiliary
33
- self.auxiliary_probability = auxiliary_probability
34
- self.seed = seed
35
-
36
- def __iter__(self) -> Iterator[dict[str, Any]]:
37
- worker = get_worker_info()
38
- worker_id = worker.id if worker is not None else 0
39
- num_workers = worker.num_workers if worker is not None else 1
40
- rank = get_rank()
41
- world_size = get_world_size()
42
- rng = random.Random(self.seed + 1_000_003 * rank + 9_176 * worker_id)
43
- primary_iter = iter(self.primary)
44
- auxiliary_iter = self._iter_auxiliary_indices(rng, rank, world_size, worker_id, num_workers)
45
-
46
- while True:
47
- if rng.random() < self.auxiliary_probability:
48
- yield self.auxiliary[next(auxiliary_iter)]
49
- else:
50
- yield next(primary_iter)
51
-
52
- def _iter_auxiliary_indices(
53
- self,
54
- rng: random.Random,
55
- rank: int,
56
- world_size: int,
57
- worker_id: int,
58
- num_workers: int,
59
- ) -> Iterator[int]:
60
- indices = list(range(len(self.auxiliary)))
61
- indices = indices[rank::world_size]
62
- indices = indices[worker_id::num_workers]
63
- if not indices:
64
- indices = list(range(len(self.auxiliary)))
65
- while True:
66
- shuffled = list(indices)
67
- rng.shuffle(shuffled)
68
- yield from shuffled
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/data/transforms.py DELETED
@@ -1,125 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from torchvision import transforms
4
-
5
-
6
- IMAGENET_MEAN = (0.485, 0.456, 0.406)
7
- IMAGENET_STD = (0.229, 0.224, 0.225)
8
- CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
9
- CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
10
- SIGLIP_MEAN = (0.5, 0.5, 0.5)
11
- SIGLIP_STD = (0.5, 0.5, 0.5)
12
-
13
-
14
- def normalization_stats(normalization: str) -> tuple[tuple[float, float, float], tuple[float, float, float]]:
15
- if normalization == "imagenet":
16
- return IMAGENET_MEAN, IMAGENET_STD
17
- if normalization == "clip":
18
- return CLIP_MEAN, CLIP_STD
19
- if normalization == "siglip":
20
- return SIGLIP_MEAN, SIGLIP_STD
21
- raise ValueError("normalization must be one of 'imagenet', 'clip', or 'siglip'")
22
-
23
-
24
- def build_train_transform(
25
- image_size: int,
26
- preset: str = "wide_random_crop",
27
- normalization: str = "imagenet",
28
- ) -> transforms.Compose:
29
- if preset == "wide_random_crop":
30
- steps = [
31
- transforms.RandomResizedCrop(
32
- size=image_size,
33
- scale=(0.5, 1.0),
34
- interpolation=transforms.InterpolationMode.BICUBIC,
35
- ),
36
- transforms.ToTensor(),
37
- ]
38
- elif preset == "wide_random_crop_light_color":
39
- steps = [
40
- transforms.RandomResizedCrop(
41
- size=image_size,
42
- scale=(0.5, 1.0),
43
- interpolation=transforms.InterpolationMode.BICUBIC,
44
- ),
45
- transforms.RandomApply(
46
- [
47
- transforms.ColorJitter(
48
- brightness=0.2,
49
- contrast=0.2,
50
- saturation=0.2,
51
- hue=0.05,
52
- )
53
- ],
54
- p=0.4,
55
- ),
56
- transforms.ToTensor(),
57
- ]
58
- elif preset == "medium_random_crop":
59
- steps = [
60
- transforms.RandomResizedCrop(
61
- size=image_size,
62
- scale=(0.6, 1.0),
63
- interpolation=transforms.InterpolationMode.BICUBIC,
64
- ),
65
- transforms.ToTensor(),
66
- ]
67
- elif preset == "center_crop":
68
- steps = [
69
- transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC),
70
- transforms.CenterCrop(image_size),
71
- transforms.ToTensor(),
72
- ]
73
- elif preset == "tight_crop_color_jitter_gray":
74
- steps = [
75
- transforms.RandomResizedCrop(
76
- size=image_size,
77
- scale=(0.8, 1.0),
78
- interpolation=transforms.InterpolationMode.BICUBIC,
79
- ),
80
- transforms.RandomApply(
81
- [
82
- transforms.ColorJitter(
83
- brightness=0.4,
84
- contrast=0.4,
85
- saturation=0.4,
86
- hue=0.1,
87
- )
88
- ],
89
- p=0.8,
90
- ),
91
- transforms.RandomGrayscale(p=0.2),
92
- transforms.ToTensor(),
93
- ]
94
- else:
95
- raise ValueError(
96
- f"Unsupported train transform preset {preset!r}; "
97
- "expected 'wide_random_crop', 'wide_random_crop_light_color', "
98
- "'medium_random_crop', 'tight_crop_color_jitter_gray', or 'center_crop'"
99
- )
100
-
101
- mean, std = normalization_stats(normalization)
102
- return transforms.Compose([*steps, transforms.Normalize(mean=mean, std=std)])
103
-
104
-
105
- def build_eval_transform(image_size: int, normalization: str = "imagenet") -> transforms.Compose:
106
- mean, std = normalization_stats(normalization)
107
- return transforms.Compose(
108
- [
109
- transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC),
110
- transforms.CenterCrop(image_size),
111
- transforms.ToTensor(),
112
- transforms.Normalize(mean=mean, std=std),
113
- ]
114
- )
115
-
116
-
117
- def build_retrieval_transform(image_size: int, normalization: str = "imagenet") -> transforms.Compose:
118
- mean, std = normalization_stats(normalization)
119
- return transforms.Compose(
120
- [
121
- transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
122
- transforms.ToTensor(),
123
- transforms.Normalize(mean=mean, std=std),
124
- ]
125
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/data/types.py DELETED
@@ -1,48 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from dataclasses import dataclass
4
- from pathlib import Path
5
- from typing import Any
6
-
7
-
8
- @dataclass(frozen=True)
9
- class GroundedParent:
10
- text: str
11
- image_path: Path | None = None
12
- bbox: tuple[float, float, float, float] | None = None
13
-
14
-
15
- @dataclass(frozen=True)
16
- class GroundedRecord:
17
- image_path: Path
18
- caption: str
19
- parents: tuple[GroundedParent, ...]
20
-
21
- @classmethod
22
- def from_json(cls, payload: dict[str, Any]) -> "GroundedRecord":
23
- parents_payload = payload.get("parents")
24
- if parents_payload is None:
25
- parents_payload = [
26
- {
27
- "text": payload.get("box_text", ""),
28
- "image_path": payload.get("box_image_path"),
29
- "bbox": payload.get("bbox"),
30
- }
31
- ]
32
-
33
- parents: list[GroundedParent] = []
34
- for parent_payload in parents_payload:
35
- text = str(parent_payload.get("text") or parent_payload.get("box_text") or "").strip()
36
- image_path = parent_payload.get("image_path") or parent_payload.get("box_image_path")
37
- bbox_payload = parent_payload.get("bbox")
38
- bbox = None
39
- if bbox_payload is not None:
40
- if len(bbox_payload) != 4:
41
- raise ValueError(f"Expected four bbox values, got {bbox_payload!r}")
42
- bbox = tuple(float(value) for value in bbox_payload)
43
- parents.append(GroundedParent(text=text, image_path=Path(image_path) if image_path else None, bbox=bbox))
44
-
45
- if not parents:
46
- raise ValueError("Grounded records must include at least one parent")
47
-
48
- return cls(image_path=Path(payload["image_path"]), caption=str(payload["caption"]), parents=tuple(parents))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/evaluation/__init__.py DELETED
@@ -1,20 +0,0 @@
1
- from hyper3_clip.evaluation.classification import evaluate_imagenet_zero_shot
2
- from hyper3_clip.evaluation.hierarchical import evaluate_imagenet_hierarchical
3
- from hyper3_clip.evaluation.pep import PEPEntailmentDataset, evaluate_pep_entailment
4
- from hyper3_clip.evaluation.retrieval import (
5
- CocoCaptionRetrieval,
6
- CocoKarpathyCaptionRetrieval,
7
- Flickr30kCaptionRetrieval,
8
- evaluate_caption_retrieval,
9
- )
10
-
11
- __all__ = [
12
- "CocoCaptionRetrieval",
13
- "CocoKarpathyCaptionRetrieval",
14
- "Flickr30kCaptionRetrieval",
15
- "PEPEntailmentDataset",
16
- "evaluate_caption_retrieval",
17
- "evaluate_imagenet_hierarchical",
18
- "evaluate_imagenet_zero_shot",
19
- "evaluate_pep_entailment",
20
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/evaluation/classification.py DELETED
@@ -1,105 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from pathlib import Path
4
-
5
- import torch
6
- from torch.utils.data import DataLoader
7
- from torch.utils.data import Subset
8
- from torchvision import datasets
9
-
10
- from hyper3_clip.data.transforms import build_eval_transform
11
-
12
- from hyper3_clip.models.hyper3_clip import Hyper3CLIP
13
-
14
-
15
- IMAGENET_PROMPTS = (
16
- "i took a picture : itap of a {}.",
17
- "pics : a bad photo of the {}.",
18
- "pics : a origami {}.",
19
- "pics : a photo of the large {}.",
20
- "pics : a {} in a video game.",
21
- "pics : art of the {}.",
22
- "pics : a photo of the small {}.",
23
- )
24
-
25
-
26
- @torch.inference_mode()
27
- def evaluate_imagenet_zero_shot(
28
- model: Hyper3CLIP,
29
- imagenet_val_root: str | Path,
30
- device: torch.device,
31
- batch_size: int = 128,
32
- image_size: int = 224,
33
- max_text_length: int = 77,
34
- max_items: int | None = None,
35
- prompts: tuple[str, ...] = IMAGENET_PROMPTS,
36
- ) -> dict[str, float]:
37
- model.eval()
38
- dataset = datasets.ImageFolder(str(imagenet_val_root), transform=build_eval_transform(image_size))
39
- class_names = _imagenet_prompt_names(dataset.classes, Path(imagenet_val_root))
40
- classifier = _build_text_classifier(model, class_names, prompts, device, max_text_length)
41
- eval_dataset = Subset(dataset, range(min(max_items, len(dataset)))) if max_items is not None else dataset
42
- loader = DataLoader(eval_dataset, batch_size=batch_size, num_workers=4, pin_memory=device.type == "cuda")
43
- correct = 0
44
- total = 0
45
- per_class_correct = torch.zeros(len(dataset.classes), dtype=torch.float64)
46
- per_class_total = torch.zeros(len(dataset.classes), dtype=torch.float64)
47
- for images, targets in loader:
48
- images = images.to(device, non_blocking=True)
49
- targets = targets.to(device, non_blocking=True)
50
- predictions = model.similarity_scores(model.encode_image(images), classifier).argmax(dim=1)
51
- matches = predictions == targets
52
- correct += int(matches.sum().item())
53
- total += targets.numel()
54
- per_class_correct.scatter_add_(0, targets.cpu(), matches.cpu().double())
55
- per_class_total.scatter_add_(0, targets.cpu(), torch.ones_like(targets.cpu(), dtype=torch.float64))
56
-
57
- observed_classes = per_class_total > 0
58
- mean_per_class = (per_class_correct[observed_classes] / per_class_total[observed_classes]).mean().item()
59
- top1 = correct / max(total, 1)
60
- return {"top1": top1, "top1_pct": 100.0 * top1, "mean_per_class_acc_pct": 100.0 * mean_per_class}
61
-
62
-
63
- def _build_text_classifier(
64
- model: Hyper3CLIP,
65
- class_names: list[str],
66
- prompts: tuple[str, ...],
67
- device: torch.device,
68
- max_text_length: int,
69
- ) -> torch.Tensor:
70
- class_embeddings: list[torch.Tensor] = []
71
- for class_name in class_names:
72
- readable_name = class_name.replace("_", " ")
73
- texts = [prompt.format(readable_name) for prompt in prompts]
74
- tokenized = model.tokenizer(
75
- texts,
76
- padding=True,
77
- truncation=True,
78
- max_length=max_text_length,
79
- return_tensors="pt",
80
- ).to(device)
81
- attention_mask = (
82
- tokenized.attention_mask if "attention_mask" in tokenized else torch.ones_like(tokenized.input_ids)
83
- )
84
- tangent = model.encode_text(tokenized.input_ids, attention_mask, project=False).float().mean(dim=0, keepdim=True)
85
- class_embeddings.append(model.project_text_features(tangent).squeeze(0))
86
- return torch.stack(class_embeddings, dim=0)
87
-
88
-
89
- def _imagenet_prompt_names(class_names: list[str], imagenet_val_root: Path) -> list[str]:
90
- if len(class_names) == 1000 and all(_looks_like_wnid(class_name) for class_name in class_names):
91
- from torchvision.models._meta import _IMAGENET_CATEGORIES
92
-
93
- label_index = imagenet_val_root / "imagenet_label_to_wnid.tsv"
94
- if label_index.exists():
95
- wnid_to_label: dict[str, int] = {}
96
- for line in label_index.read_text(encoding="utf-8").splitlines():
97
- label, wnid = line.split("\t", maxsplit=1)
98
- wnid_to_label[wnid] = int(label)
99
- return [_IMAGENET_CATEGORIES[wnid_to_label[class_name]] for class_name in class_names]
100
- return list(_IMAGENET_CATEGORIES)
101
- return class_names
102
-
103
-
104
- def _looks_like_wnid(class_name: str) -> bool:
105
- return len(class_name) == 9 and class_name.startswith("n") and class_name[1:].isdigit()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/evaluation/hierarchical.py DELETED
@@ -1,118 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import csv
4
- import pickle
5
- from pathlib import Path
6
-
7
- import networkx as nx
8
- import torch
9
- from torch.utils.data import DataLoader, Subset
10
- from torchvision import datasets
11
-
12
- from hyper3_clip.data.transforms import build_eval_transform
13
- from hyper3_clip.evaluation.classification import IMAGENET_PROMPTS, _build_text_classifier, _imagenet_prompt_names, _looks_like_wnid
14
- from hyper3_clip.models.hyper3_clip import Hyper3CLIP
15
-
16
-
17
- @torch.inference_mode()
18
- def evaluate_imagenet_hierarchical(
19
- model: Hyper3CLIP,
20
- imagenet_val_root: str | Path,
21
- assets_root: str | Path,
22
- device: torch.device,
23
- batch_size: int = 128,
24
- image_size: int = 224,
25
- max_text_length: int = 77,
26
- max_items: int | None = None,
27
- prompts: tuple[str, ...] = IMAGENET_PROMPTS,
28
- ) -> dict[str, float]:
29
- model.eval()
30
- imagenet_root = Path(imagenet_val_root)
31
- dataset = datasets.ImageFolder(str(imagenet_root), transform=build_eval_transform(image_size))
32
- class_names = _imagenet_prompt_names(dataset.classes, imagenet_root)
33
- classifier = _build_text_classifier(model, class_names, prompts, device, max_text_length)
34
- eval_dataset = Subset(dataset, range(min(max_items, len(dataset)))) if max_items is not None else dataset
35
- loader = DataLoader(eval_dataset, batch_size=batch_size, num_workers=4, pin_memory=device.type == "cuda")
36
-
37
- assets_path = Path(assets_root)
38
- synsets_ordering = pickle.load((assets_path / "all_synsets.pkl").open("rb"))
39
- ancestor_indices = pickle.load((assets_path / "all_ancestors_indices.pkl").open("rb"))
40
- graph = _create_graph_from_edges(assets_path / "imagenet_isa.txt")
41
- dataset_to_official = _dataset_to_official_indices(dataset.classes, imagenet_root, synsets_ordering).to(device)
42
-
43
- totals = torch.zeros(5, dtype=torch.float64)
44
- total_count = 0
45
- for images, targets in loader:
46
- images = images.to(device, non_blocking=True)
47
- official_targets = dataset_to_official[targets.to(device, non_blocking=True)]
48
- dataset_predictions = model.similarity_scores(model.encode_image(images), classifier).argmax(dim=1)
49
- official_predictions = dataset_to_official[dataset_predictions]
50
- batch_totals = _hierarchical_totals(
51
- official_predictions.cpu().tolist(),
52
- official_targets.cpu().tolist(),
53
- ancestor_indices,
54
- graph,
55
- synsets_ordering,
56
- )
57
- totals += torch.tensor(batch_totals, dtype=torch.float64)
58
- total_count += int(official_targets.numel())
59
-
60
- averages = totals / max(total_count, 1)
61
- return {
62
- "tie": float(averages[0].item()),
63
- "lca": float(averages[1].item()),
64
- "jaccard": float(averages[2].item()),
65
- "hierarchical_precision": float(averages[3].item()),
66
- "hierarchical_recall": float(averages[4].item()),
67
- }
68
-
69
-
70
- def _create_graph_from_edges(edge_file: Path) -> nx.DiGraph:
71
- graph = nx.DiGraph()
72
- with edge_file.open("r", encoding="utf-8") as handle:
73
- reader = csv.reader(handle, delimiter=" ")
74
- for parent, child in reader:
75
- graph.add_edge(parent, child)
76
- return graph
77
-
78
-
79
- def _dataset_to_official_indices(class_names: list[str], imagenet_val_root: Path, synsets_ordering: list[str]) -> torch.Tensor:
80
- label_index = imagenet_val_root / "imagenet_label_to_wnid.tsv"
81
- if label_index.exists():
82
- wnid_to_label = {}
83
- for line in label_index.read_text(encoding="utf-8").splitlines():
84
- label, wnid = line.split("\t", maxsplit=1)
85
- wnid_to_label[wnid] = int(label)
86
- return torch.tensor([wnid_to_label[class_name] for class_name in class_names], dtype=torch.long)
87
- if all(_looks_like_wnid(class_name) for class_name in class_names):
88
- synset_to_label = {synset: label for label, synset in enumerate(synsets_ordering)}
89
- return torch.tensor([synset_to_label[class_name] for class_name in class_names], dtype=torch.long)
90
- return torch.arange(len(class_names), dtype=torch.long)
91
-
92
-
93
- def _hierarchical_totals(
94
- predicted_labels: list[int],
95
- true_labels: list[int],
96
- ancestor_indices: list[list[int]],
97
- graph: nx.DiGraph,
98
- synsets_ordering: list[str],
99
- ) -> tuple[float, float, float, float, float]:
100
- undirected_graph = graph.to_undirected()
101
- tree_induced_error = 0.0
102
- least_common_ancestor = 0.0
103
- jaccard = 0.0
104
- hierarchical_precision = 0.0
105
- hierarchical_recall = 0.0
106
- for pred_label, true_label in zip(predicted_labels, true_labels):
107
- pred_synset = synsets_ordering[pred_label]
108
- true_synset = synsets_ordering[true_label]
109
- pred_ancestors = set(ancestor_indices[pred_label])
110
- true_ancestors = set(ancestor_indices[true_label])
111
- intersection = pred_ancestors.intersection(true_ancestors)
112
- union = pred_ancestors.union(true_ancestors)
113
- tree_induced_error += nx.shortest_path_length(undirected_graph, source=pred_synset, target=true_synset)
114
- least_common_ancestor += len(pred_ancestors) - len(intersection) + 1
115
- jaccard += len(intersection) / len(union)
116
- hierarchical_precision += len(intersection) / len(pred_ancestors)
117
- hierarchical_recall += len(intersection) / len(true_ancestors)
118
- return tree_induced_error, least_common_ancestor, jaccard, hierarchical_precision, hierarchical_recall
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/evaluation/pep.py DELETED
@@ -1,462 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import csv
4
- import hashlib
5
- import json
6
- import math
7
- from urllib.parse import urlparse
8
- import urllib.request
9
- from dataclasses import dataclass
10
- from io import BytesIO
11
- from pathlib import Path
12
- from typing import Any
13
-
14
- import torch
15
- from PIL import Image
16
- from torch.utils.data import Dataset
17
-
18
- from hyper3_clip.data.transforms import build_eval_transform
19
- from hyper3_clip.models.hyper3_clip import Hyper3CLIP
20
- from hyper3_clip.models.losses import factor_oxy_angle
21
-
22
-
23
- @dataclass(frozen=True)
24
- class PEPSample:
25
- image_id: str
26
- image_path: Path | None
27
- image_url: str | None
28
- positive_captions: tuple[str, ...]
29
- negative_captions: tuple[str, ...] = ()
30
-
31
-
32
- class PEPEntailmentDataset(Dataset):
33
- def __init__(
34
- self,
35
- annotations_path: str | Path,
36
- image_root: str | Path | None = None,
37
- image_size: int = 224,
38
- max_items: int | None = None,
39
- image_cache_dir: str | Path | None = None,
40
- allow_image_download: bool = False,
41
- ) -> None:
42
- self.samples, self.global_negative_captions = load_pep_samples(
43
- annotations_path,
44
- image_root=image_root,
45
- max_items=max_items,
46
- )
47
- self.transform = build_eval_transform(image_size)
48
- self.image_cache_dir = Path(image_cache_dir) if image_cache_dir is not None else None
49
- self.allow_image_download = allow_image_download
50
-
51
- def __len__(self) -> int:
52
- return len(self.samples)
53
-
54
- def __getitem__(self, index: int) -> dict[str, Any]:
55
- sample = self.samples[index]
56
- image = _load_sample_image(sample, self.image_cache_dir, self.allow_image_download)
57
- return {
58
- "image": self.transform(image.convert("RGB")),
59
- "image_id": sample.image_id,
60
- "positive_captions": sample.positive_captions,
61
- "negative_captions": sample.negative_captions,
62
- }
63
-
64
-
65
- @torch.inference_mode()
66
- def evaluate_pep_entailment(
67
- model: Hyper3CLIP,
68
- annotations_path: str | Path,
69
- device: torch.device,
70
- image_root: str | Path | None = None,
71
- image_size: int = 224,
72
- max_text_length: int = 77,
73
- batch_size: int = 128,
74
- max_items: int | None = None,
75
- image_cache_dir: str | Path | None = None,
76
- allow_image_download: bool = False,
77
- negative_pool_strategy: str = "annotation",
78
- max_negatives_per_image: int | None = None,
79
- pair_batch_size: int = 8192,
80
- ) -> dict[str, float]:
81
- """Evaluate ARGENT-style PEP entailment AUC/AP.
82
-
83
- PEP treats image-caption hierarchy evaluation as binary entailment
84
- classification. Positives are the hierarchical captions attached to the same
85
- image. Negatives come either from explicit annotation/global pools, or from
86
- other samples' finest captions when ``negative_pool_strategy`` is
87
- ``"all_fine_captions"``.
88
- """
89
- if negative_pool_strategy not in {"annotation", "all_fine_captions"}:
90
- raise ValueError("negative_pool_strategy must be 'annotation' or 'all_fine_captions'")
91
-
92
- model.eval()
93
- dataset = PEPEntailmentDataset(
94
- annotations_path,
95
- image_root=image_root,
96
- image_size=image_size,
97
- max_items=max_items,
98
- image_cache_dir=image_cache_dir,
99
- allow_image_download=allow_image_download,
100
- )
101
- if len(dataset) == 0:
102
- raise ValueError("PEP evaluation requires at least one sample")
103
-
104
- image_feats = _encode_images(model, dataset, device, batch_size)
105
- pair_image_indices, pair_captions, labels = _build_pep_pairs(
106
- dataset.samples,
107
- dataset.global_negative_captions,
108
- negative_pool_strategy=negative_pool_strategy,
109
- max_negatives_per_image=max_negatives_per_image,
110
- )
111
- if not any(labels) or all(labels):
112
- raise ValueError("PEP evaluation requires both positive and negative pairs")
113
-
114
- captions = sorted(set(pair_captions))
115
- caption_to_index = {caption: index for index, caption in enumerate(captions)}
116
- text_feats = _encode_texts(model, captions, device, max_text_length, batch_size)
117
- pair_text_indices = [caption_to_index[caption] for caption in pair_captions]
118
- scores = _score_pep_pairs(
119
- model,
120
- image_feats,
121
- text_feats,
122
- pair_image_indices,
123
- pair_text_indices,
124
- device,
125
- pair_batch_size=pair_batch_size,
126
- )
127
-
128
- auc = _roc_auc_score(labels, scores)
129
- ap = _average_precision_score(labels, scores)
130
- positive_scores = [score for score, label in zip(scores, labels) if label == 1]
131
- negative_scores = [score for score, label in zip(scores, labels) if label == 0]
132
- return {
133
- "auc_roc": auc,
134
- "average_precision": ap,
135
- "auc_roc_pct": 100.0 * auc,
136
- "average_precision_pct": 100.0 * ap,
137
- "num_samples": float(len(dataset.samples)),
138
- "num_pairs": float(len(labels)),
139
- "num_positive_pairs": float(sum(labels)),
140
- "num_negative_pairs": float(len(labels) - sum(labels)),
141
- "mean_positive_score": float(sum(positive_scores) / len(positive_scores)),
142
- "mean_negative_score": float(sum(negative_scores) / len(negative_scores)),
143
- }
144
-
145
-
146
- def probabilistic_entailment_score(
147
- specific: torch.Tensor,
148
- general: torch.Tensor,
149
- kappa: torch.Tensor,
150
- ) -> torch.Tensor:
151
- angles = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
152
- if angles.dim() == 2:
153
- angles = angles.mean(dim=-1)
154
- return torch.clamp(1.0 - (2.0 * angles / math.pi), min=0.0, max=1.0)
155
-
156
-
157
- def load_pep_samples(
158
- annotations_path: str | Path,
159
- image_root: str | Path | None = None,
160
- max_items: int | None = None,
161
- ) -> tuple[list[PEPSample], tuple[str, ...]]:
162
- path = Path(annotations_path)
163
- if path.suffix.lower() == ".jsonl":
164
- samples = [_sample_from_mapping(json.loads(line), image_root) for line in path.read_text(encoding="utf-8").splitlines() if line.strip()]
165
- return _limit_samples(samples, max_items), ()
166
- if path.suffix.lower() == ".json":
167
- payload = json.loads(path.read_text(encoding="utf-8"))
168
- samples_payload = payload.get("samples", payload.get("data", [])) if isinstance(payload, dict) else payload
169
- global_negatives = _caption_tuple(
170
- payload.get("negative_captions") or payload.get("global_negative_captions") or payload.get("negative_pool")
171
- ) if isinstance(payload, dict) else ()
172
- samples = [_sample_from_mapping(item, image_root) for item in samples_payload]
173
- return _limit_samples(samples, max_items), global_negatives
174
- if path.suffix.lower() in {".csv", ".tsv"}:
175
- return _load_csv_samples(path, image_root, max_items)
176
- raise ValueError(f"Unsupported PEP annotations format {path.suffix!r}; expected .json, .jsonl, .csv, or .tsv")
177
-
178
-
179
- def _load_csv_samples(
180
- path: Path,
181
- image_root: str | Path | None,
182
- max_items: int | None,
183
- ) -> tuple[list[PEPSample], tuple[str, ...]]:
184
- delimiter = "\t" if path.suffix.lower() == ".tsv" else ","
185
- with path.open("r", encoding="utf-8", newline="") as handle:
186
- rows = list(csv.DictReader(handle, delimiter=delimiter))
187
- if not rows:
188
- return [], ()
189
-
190
- has_pair_labels = "caption" in rows[0] and "label" in rows[0]
191
- if has_pair_labels:
192
- grouped: dict[str, dict[str, Any]] = {}
193
- for row in rows:
194
- key = _image_key(row)
195
- item = grouped.setdefault(key, {**row, "positive_captions": [], "negative_captions": []})
196
- if _truthy_label(row["label"]):
197
- item["positive_captions"].append(row["caption"])
198
- else:
199
- item["negative_captions"].append(row["caption"])
200
- samples = [_sample_from_mapping(item, image_root) for item in grouped.values()]
201
- return _limit_samples(samples, max_items), ()
202
-
203
- samples = [_sample_from_mapping(row, image_root) for row in rows]
204
- return _limit_samples(samples, max_items), ()
205
-
206
-
207
- def _sample_from_mapping(item: dict[str, Any], image_root: str | Path | None) -> PEPSample:
208
- positives = _extract_positive_captions(item)
209
- if not positives:
210
- raise ValueError(f"PEP sample {item.get('id', item.get('image_id', '<unknown>'))!r} has no positive captions")
211
- image_path = _extract_image_path(item, image_root)
212
- image_url = _first_present(item, ("image_url", "url"))
213
- if image_path is None and not image_url:
214
- raise ValueError(f"PEP sample {item.get('id', item.get('image_id', '<unknown>'))!r} has no image path or URL")
215
- image_id = str(_first_present(item, ("image_id", "id", "uid")) or image_path or image_url)
216
- negatives = _caption_tuple(_first_present(item, ("negative_captions", "negatives", "negative_pool")))
217
- return PEPSample(
218
- image_id=image_id,
219
- image_path=image_path,
220
- image_url=str(image_url) if image_url else None,
221
- positive_captions=positives,
222
- negative_captions=negatives,
223
- )
224
-
225
-
226
- def _extract_positive_captions(item: dict[str, Any]) -> tuple[str, ...]:
227
- raw = _first_present(item, ("positive_captions", "hierarchical_captions", "caption_hierarchy", "captions"))
228
- captions = _caption_tuple(raw)
229
- if captions:
230
- return captions
231
- caption = item.get("caption")
232
- return (str(caption).strip(),) if caption else ()
233
-
234
-
235
- def _caption_tuple(raw: Any) -> tuple[str, ...]:
236
- if raw is None:
237
- return ()
238
- if isinstance(raw, str):
239
- stripped = raw.strip()
240
- if not stripped:
241
- return ()
242
- if stripped.startswith("["):
243
- try:
244
- parsed = json.loads(stripped)
245
- return _caption_tuple(parsed)
246
- except json.JSONDecodeError:
247
- pass
248
- separator = "=>" if "=>" in stripped else "||" if "||" in stripped else None
249
- values = stripped.split(separator) if separator else [stripped]
250
- return tuple(value.strip() for value in values if value.strip())
251
- if isinstance(raw, (list, tuple)):
252
- return tuple(str(value).strip() for value in raw if str(value).strip())
253
- return (str(raw).strip(),)
254
-
255
-
256
- def _extract_image_path(item: dict[str, Any], image_root: str | Path | None) -> Path | None:
257
- raw = _first_present(item, ("image_path", "path", "file_name", "filename"))
258
- if raw is None:
259
- image_url = _first_present(item, ("image_url", "url"))
260
- if image_url is None or image_root is None:
261
- return None
262
- return _url_to_local_image_path(str(image_url), Path(image_root))
263
- path = Path(str(raw))
264
- if not path.is_absolute() and image_root is not None:
265
- path = Path(image_root) / path
266
- return path
267
-
268
-
269
- def _url_to_local_image_path(url: str, image_root: Path) -> Path:
270
- url_path = Path(urlparse(url).path)
271
- filename = url_path.name
272
- if not filename:
273
- raise ValueError(f"Cannot infer image filename from URL {url!r}")
274
- candidates = [image_root / filename]
275
- if url_path.parent.name:
276
- candidates.append(image_root / url_path.parent.name / filename)
277
- for candidate in candidates:
278
- if candidate.exists():
279
- return candidate
280
- return candidates[0]
281
-
282
-
283
- def _first_present(item: dict[str, Any], keys: tuple[str, ...]) -> Any:
284
- for key in keys:
285
- value = item.get(key)
286
- if value not in (None, ""):
287
- return value
288
- return None
289
-
290
-
291
- def _image_key(row: dict[str, Any]) -> str:
292
- return str(_first_present(row, ("image_id", "id", "image_path", "path", "file_name", "filename", "image_url", "url")))
293
-
294
-
295
- def _truthy_label(raw: Any) -> bool:
296
- return str(raw).strip().lower() in {"1", "true", "yes", "positive", "pos"}
297
-
298
-
299
- def _limit_samples(samples: list[PEPSample], max_items: int | None) -> list[PEPSample]:
300
- return samples[:max_items] if max_items is not None else samples
301
-
302
-
303
- def _load_sample_image(sample: PEPSample, image_cache_dir: Path | None, allow_image_download: bool) -> Image.Image:
304
- if sample.image_path is not None:
305
- with Image.open(sample.image_path) as image:
306
- return image.convert("RGB")
307
- if not sample.image_url:
308
- raise ValueError(f"PEP sample {sample.image_id!r} has no image path or URL")
309
- if not allow_image_download:
310
- raise ValueError("PEP sample uses image_url; set allow_image_download=true and image_cache_dir to evaluate it")
311
- if image_cache_dir is None:
312
- with urllib.request.urlopen(sample.image_url, timeout=30) as response:
313
- return Image.open(BytesIO(response.read())).convert("RGB")
314
- image_cache_dir.mkdir(parents=True, exist_ok=True)
315
- cache_path = image_cache_dir / _url_cache_name(sample.image_url)
316
- if not cache_path.exists():
317
- urllib.request.urlretrieve(sample.image_url, cache_path)
318
- with Image.open(cache_path) as image:
319
- return image.convert("RGB")
320
-
321
-
322
- def _url_cache_name(url: str) -> str:
323
- suffix = Path(url.split("?", maxsplit=1)[0]).suffix or ".jpg"
324
- digest = hashlib.sha256(url.encode("utf-8")).hexdigest()[:16]
325
- return f"{digest}{suffix}"
326
-
327
-
328
- def _encode_images(
329
- model: Hyper3CLIP,
330
- dataset: PEPEntailmentDataset,
331
- device: torch.device,
332
- batch_size: int,
333
- ) -> torch.Tensor:
334
- feats: list[torch.Tensor] = []
335
- batch: list[torch.Tensor] = []
336
- for index in range(len(dataset)):
337
- batch.append(dataset[index]["image"])
338
- if len(batch) == batch_size or index == len(dataset) - 1:
339
- images = torch.stack(batch).to(device)
340
- feats.append(model.encode_image(images).cpu())
341
- batch = []
342
- return torch.cat(feats)
343
-
344
-
345
- def _encode_texts(
346
- model: Hyper3CLIP,
347
- captions: list[str],
348
- device: torch.device,
349
- max_text_length: int,
350
- batch_size: int,
351
- ) -> torch.Tensor:
352
- feats: list[torch.Tensor] = []
353
- for start in range(0, len(captions), batch_size):
354
- batch = captions[start : start + batch_size]
355
- tokenized = model.tokenizer(
356
- batch,
357
- padding=True,
358
- truncation=True,
359
- max_length=max_text_length,
360
- return_tensors="pt",
361
- ).to(device)
362
- attention_mask = (
363
- tokenized.attention_mask if "attention_mask" in tokenized else torch.ones_like(tokenized.input_ids)
364
- )
365
- feats.append(model.encode_text(tokenized.input_ids, attention_mask).cpu())
366
- return torch.cat(feats)
367
-
368
-
369
- def _build_pep_pairs(
370
- samples: list[PEPSample],
371
- global_negative_captions: tuple[str, ...],
372
- negative_pool_strategy: str,
373
- max_negatives_per_image: int | None,
374
- ) -> tuple[list[int], list[str], list[int]]:
375
- fine_caption_pool = tuple(sample.positive_captions[-1] for sample in samples)
376
- pair_image_indices: list[int] = []
377
- pair_captions: list[str] = []
378
- labels: list[int] = []
379
- for image_index, sample in enumerate(samples):
380
- positives = set(sample.positive_captions)
381
- for caption in sample.positive_captions:
382
- pair_image_indices.append(image_index)
383
- pair_captions.append(caption)
384
- labels.append(1)
385
-
386
- negatives = sample.negative_captions or global_negative_captions
387
- if not negatives and negative_pool_strategy == "all_fine_captions":
388
- negatives = tuple(caption for idx, caption in enumerate(fine_caption_pool) if idx != image_index)
389
- negatives = tuple(caption for caption in negatives if caption not in positives)
390
- if max_negatives_per_image is not None:
391
- negatives = negatives[:max_negatives_per_image]
392
- for caption in negatives:
393
- pair_image_indices.append(image_index)
394
- pair_captions.append(caption)
395
- labels.append(0)
396
- return pair_image_indices, pair_captions, labels
397
-
398
-
399
- def _score_pep_pairs(
400
- model: Hyper3CLIP,
401
- image_feats: torch.Tensor,
402
- text_feats: torch.Tensor,
403
- pair_image_indices: list[int],
404
- pair_text_indices: list[int],
405
- device: torch.device,
406
- pair_batch_size: int,
407
- ) -> list[float]:
408
- kappa = model._kappa().detach().to(device)
409
- scores: list[torch.Tensor] = []
410
- for start in range(0, len(pair_image_indices), pair_batch_size):
411
- image_index = torch.tensor(pair_image_indices[start : start + pair_batch_size], dtype=torch.long)
412
- text_index = torch.tensor(pair_text_indices[start : start + pair_batch_size], dtype=torch.long)
413
- batch_images = image_feats.index_select(0, image_index).to(device)
414
- batch_texts = text_feats.index_select(0, text_index).to(device)
415
- scores.append(probabilistic_entailment_score(batch_images, batch_texts, kappa).cpu())
416
- return torch.cat(scores).tolist()
417
-
418
-
419
- def _roc_auc_score(labels: list[int], scores: list[float]) -> float:
420
- positives = sum(labels)
421
- negatives = len(labels) - positives
422
- if positives == 0 or negatives == 0:
423
- raise ValueError("ROC AUC requires both positive and negative labels")
424
-
425
- sorted_pairs = sorted(zip(scores, labels), key=lambda pair: pair[0])
426
- rank_sum_pos = 0.0
427
- rank = 1
428
- index = 0
429
- while index < len(sorted_pairs):
430
- end = index + 1
431
- while end < len(sorted_pairs) and sorted_pairs[end][0] == sorted_pairs[index][0]:
432
- end += 1
433
- avg_rank = (rank + rank + (end - index) - 1) / 2.0
434
- rank_sum_pos += avg_rank * sum(label for _, label in sorted_pairs[index:end])
435
- rank += end - index
436
- index = end
437
- return (rank_sum_pos - positives * (positives + 1) / 2.0) / (positives * negatives)
438
-
439
-
440
- def _average_precision_score(labels: list[int], scores: list[float]) -> float:
441
- positives = sum(labels)
442
- if positives == 0:
443
- raise ValueError("Average precision requires at least one positive label")
444
-
445
- sorted_pairs = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
446
- true_positives = 0
447
- false_positives = 0
448
- previous_recall = 0.0
449
- ap = 0.0
450
- index = 0
451
- while index < len(sorted_pairs):
452
- end = index + 1
453
- while end < len(sorted_pairs) and sorted_pairs[end][0] == sorted_pairs[index][0]:
454
- end += 1
455
- true_positives += sum(label for _, label in sorted_pairs[index:end])
456
- false_positives += (end - index) - sum(label for _, label in sorted_pairs[index:end])
457
- recall = true_positives / positives
458
- precision = true_positives / (true_positives + false_positives)
459
- ap += (recall - previous_recall) * precision
460
- previous_recall = recall
461
- index = end
462
- return ap
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/evaluation/retrieval.py DELETED
@@ -1,215 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import json
4
- from collections import defaultdict
5
- from pathlib import Path
6
- from typing import Any
7
-
8
- import torch
9
- from PIL import Image
10
- from torch.utils.data import Dataset
11
-
12
- from hyper3_clip.data.transforms import build_retrieval_transform
13
-
14
- from hyper3_clip.models.hyper3_clip import Hyper3CLIP
15
-
16
-
17
- class CocoCaptionRetrieval(Dataset):
18
- def __init__(
19
- self,
20
- root: str | Path,
21
- image_size: int = 224,
22
- max_items: int | None = None,
23
- image_normalization: str = "imagenet",
24
- ) -> None:
25
- self.root = Path(root)
26
- with (self.root / "annotations" / "captions_val2017.json").open("r", encoding="utf-8") as handle:
27
- payload = json.load(handle)
28
- images = {item["id"]: item["file_name"] for item in payload["images"]}
29
- captions: dict[int, list[str]] = defaultdict(list)
30
- for annotation in payload["annotations"]:
31
- captions[int(annotation["image_id"])].append(str(annotation["caption"]))
32
- self.items = [
33
- {"image_id": image_id, "image_path": self.root / "val2017" / images[image_id], "captions": captions[image_id]}
34
- for image_id in sorted(captions)
35
- ]
36
- if max_items is not None:
37
- self.items = self.items[:max_items]
38
- self.transform = build_retrieval_transform(image_size, normalization=image_normalization)
39
-
40
- def __len__(self) -> int:
41
- return len(self.items)
42
-
43
- def __getitem__(self, index: int) -> dict[str, Any]:
44
- item = self.items[index]
45
- with Image.open(item["image_path"]) as image:
46
- tensor = self.transform(image.convert("RGB"))
47
- return {"image": tensor, "captions": item["captions"], "image_id": item["image_id"]}
48
-
49
-
50
- class CocoKarpathyCaptionRetrieval(Dataset):
51
- def __init__(
52
- self,
53
- root: str | Path,
54
- split: str = "test",
55
- image_size: int = 224,
56
- max_items: int | None = None,
57
- image_normalization: str = "imagenet",
58
- ) -> None:
59
- self.root = Path(root)
60
- with (self.root / "karpathy" / "dataset_coco.json").open("r", encoding="utf-8") as handle:
61
- payload = json.load(handle)
62
- images = [item for item in payload["images"] if item["split"] == split]
63
- if max_items is not None:
64
- images = images[:max_items]
65
- self.items = [
66
- {
67
- "image_id": item["imgid"],
68
- "image_path": self.root / item["filepath"] / item["filename"],
69
- "captions": [sentence["raw"].strip() for sentence in item["sentences"]],
70
- }
71
- for item in images
72
- ]
73
- self.transform = build_retrieval_transform(image_size, normalization=image_normalization)
74
-
75
- def __len__(self) -> int:
76
- return len(self.items)
77
-
78
- def __getitem__(self, index: int) -> dict[str, Any]:
79
- item = self.items[index]
80
- with Image.open(item["image_path"]) as image:
81
- tensor = self.transform(image.convert("RGB"))
82
- return {"image": tensor, "captions": item["captions"], "image_id": item["image_id"]}
83
-
84
-
85
- class Flickr30kCaptionRetrieval(Dataset):
86
- def __init__(
87
- self,
88
- root: str | Path,
89
- split: str = "test",
90
- image_size: int = 224,
91
- max_items: int | None = None,
92
- image_normalization: str = "imagenet",
93
- ) -> None:
94
- self.root = Path(root)
95
- with (self.root / "dataset_flickr30k.json").open("r", encoding="utf-8") as handle:
96
- payload = json.load(handle)
97
- self.items = []
98
- for index, image_payload in enumerate(payload["images"]):
99
- if image_payload.get("split") != split:
100
- continue
101
- captions = [str(sentence.get("raw") or " ".join(sentence.get("tokens", []))) for sentence in image_payload["sentences"]]
102
- self.items.append(
103
- {
104
- "image_id": index,
105
- "image_path": self.root / "flickr30k_images" / image_payload["filename"],
106
- "captions": captions,
107
- }
108
- )
109
- if max_items is not None:
110
- self.items = self.items[:max_items]
111
- self.transform = build_retrieval_transform(image_size, normalization=image_normalization)
112
-
113
- def __len__(self) -> int:
114
- return len(self.items)
115
-
116
- def __getitem__(self, index: int) -> dict[str, Any]:
117
- item = self.items[index]
118
- with Image.open(item["image_path"]) as image:
119
- tensor = self.transform(image.convert("RGB"))
120
- return {"image": tensor, "captions": item["captions"], "image_id": item["image_id"]}
121
-
122
-
123
- @torch.inference_mode()
124
- def evaluate_caption_retrieval(
125
- model: Hyper3CLIP,
126
- dataset: Dataset,
127
- device: torch.device,
128
- max_text_length: int = 77,
129
- batch_size: int = 128,
130
- ) -> dict[str, float]:
131
- model.eval()
132
- image_feats: list[torch.Tensor] = []
133
- captions: list[str] = []
134
- text_feats: list[torch.Tensor] = []
135
- text_to_image: list[int] = []
136
-
137
- image_batch: list[torch.Tensor] = []
138
- for item_index in range(len(dataset)):
139
- item = dataset[item_index]
140
- image_batch.append(item["image"])
141
- if len(image_batch) == batch_size or item_index == len(dataset) - 1:
142
- images = torch.stack(image_batch).to(device)
143
- image_feats.append(model.encode_retrieval_image(images).cpu())
144
- image_batch = []
145
- captions.extend(item["captions"])
146
- text_to_image.extend([item_index] * len(item["captions"]))
147
-
148
- for start in range(0, len(captions), batch_size):
149
- caption_batch = captions[start : start + batch_size]
150
- tokenized = model.tokenizer(
151
- caption_batch,
152
- padding=True,
153
- truncation=True,
154
- max_length=max_text_length,
155
- return_tensors="pt",
156
- ).to(device)
157
- attention_mask = (
158
- tokenized.attention_mask if "attention_mask" in tokenized else torch.ones_like(tokenized.input_ids)
159
- )
160
- text_feats.append(model.encode_retrieval_text(tokenized.input_ids, attention_mask).cpu())
161
-
162
- images = torch.cat(image_feats).to(device)
163
- texts = torch.cat(text_feats).to(device)
164
- scores_i2t = _retrieval_similarity_scores(model, images, texts, chunk_size=max(1, min(batch_size, 64)))
165
- scores_t2i = scores_i2t.transpose(0, 1)
166
- target_device = scores_i2t.device
167
- text_targets = torch.tensor(text_to_image, device=target_device)
168
- fractions = {
169
- "image_to_text_r1": _recall_at_k(scores_i2t, _image_to_text_targets(text_to_image, len(dataset), target_device), 1),
170
- "image_to_text_r5": _recall_at_k(scores_i2t, _image_to_text_targets(text_to_image, len(dataset), target_device), 5),
171
- "image_to_text_r10": _recall_at_k(scores_i2t, _image_to_text_targets(text_to_image, len(dataset), target_device), 10),
172
- "text_to_image_r1": _single_target_recall_at_k(scores_t2i, text_targets, 1),
173
- "text_to_image_r5": _single_target_recall_at_k(scores_t2i, text_targets, 5),
174
- "text_to_image_r10": _single_target_recall_at_k(scores_t2i, text_targets, 10),
175
- }
176
- return {
177
- **fractions,
178
- "i2t_r1": 100.0 * fractions["image_to_text_r1"],
179
- "i2t_r5": 100.0 * fractions["image_to_text_r5"],
180
- "i2t_r10": 100.0 * fractions["image_to_text_r10"],
181
- "t2i_r1": 100.0 * fractions["text_to_image_r1"],
182
- "t2i_r5": 100.0 * fractions["text_to_image_r5"],
183
- "t2i_r10": 100.0 * fractions["text_to_image_r10"],
184
- }
185
-
186
-
187
- def _retrieval_similarity_scores(
188
- model: Hyper3CLIP, images: torch.Tensor, texts: torch.Tensor, chunk_size: int
189
- ) -> torch.Tensor:
190
- if not getattr(model, "retrieval_requires_chunking", False):
191
- return model.retrieval_similarity_scores(images, texts)
192
-
193
- chunks: list[torch.Tensor] = []
194
- for start in range(0, images.shape[0], chunk_size):
195
- chunk_scores = model.retrieval_similarity_scores(images[start : start + chunk_size], texts)
196
- chunks.append(chunk_scores.cpu())
197
- return torch.cat(chunks, dim=0)
198
-
199
-
200
- def _image_to_text_targets(text_to_image: list[int], num_images: int, device: torch.device) -> list[torch.Tensor]:
201
- targets: list[list[int]] = [[] for _ in range(num_images)]
202
- for text_index, image_index in enumerate(text_to_image):
203
- targets[image_index].append(text_index)
204
- return [torch.tensor(indices, device=device) for indices in targets]
205
-
206
-
207
- def _recall_at_k(scores: torch.Tensor, targets: list[torch.Tensor], k: int) -> float:
208
- topk = scores.topk(k=min(k, scores.shape[1]), dim=1).indices
209
- hits = [bool(torch.isin(targets[row], topk[row]).any().item()) for row in range(scores.shape[0])]
210
- return float(sum(hits) / len(hits))
211
-
212
-
213
- def _single_target_recall_at_k(scores: torch.Tensor, targets: torch.Tensor, k: int) -> float:
214
- topk = scores.topk(k=min(k, scores.shape[1]), dim=1).indices
215
- return float((topk == targets[:, None]).any(dim=1).float().mean().item())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/models/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from hyper3_clip.models.hyper3_clip import Hyper3CLIP
2
-
3
- __all__ = ["Hyper3CLIP"]
 
 
 
 
hyper3_clip/models/encoders.py DELETED
@@ -1,173 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import timm
4
- import torch
5
- from torch import nn
6
- from transformers import (
7
- AutoConfig,
8
- AutoModel,
9
- AutoTokenizer,
10
- CLIPTextConfig,
11
- CLIPTextModel,
12
- CLIPTextModelWithProjection,
13
- CLIPVisionConfig,
14
- CLIPVisionModel,
15
- CLIPVisionModelWithProjection,
16
- SiglipTextConfig,
17
- SiglipTextModel,
18
- SiglipVisionConfig,
19
- SiglipVisionModel,
20
- )
21
-
22
-
23
- class VisionEncoder(nn.Module):
24
- def __init__(self, backbone_name: str, pretrained: bool = True) -> None:
25
- super().__init__()
26
- self.kind = "timm"
27
- if backbone_name.startswith("hf_clip_projected:"):
28
- self.kind = "hf_clip_projected"
29
- model_name = backbone_name.removeprefix("hf_clip_projected:")
30
- self.backbone = (
31
- CLIPVisionModelWithProjection.from_pretrained(model_name)
32
- if pretrained
33
- else CLIPVisionModelWithProjection(CLIPVisionConfig.from_pretrained(model_name))
34
- )
35
- self.output_dim = self.backbone.config.projection_dim
36
- elif backbone_name.startswith("hf_clip:"):
37
- self.kind = "hf_vision"
38
- model_name = backbone_name.removeprefix("hf_clip:")
39
- self.backbone = (
40
- CLIPVisionModel.from_pretrained(model_name)
41
- if pretrained
42
- else CLIPVisionModel(CLIPVisionConfig.from_pretrained(model_name))
43
- )
44
- self.output_dim = self.backbone.config.hidden_size
45
- elif backbone_name.startswith("hf_siglip:"):
46
- self.kind = "hf_vision"
47
- model_name = backbone_name.removeprefix("hf_siglip:")
48
- self.backbone = (
49
- SiglipVisionModel.from_pretrained(model_name)
50
- if pretrained
51
- else SiglipVisionModel(SiglipVisionConfig.from_pretrained(model_name))
52
- )
53
- self.output_dim = self.backbone.config.hidden_size
54
- else:
55
- self.backbone = timm.create_model(
56
- backbone_name,
57
- pretrained=pretrained,
58
- num_classes=0,
59
- global_pool="avg",
60
- )
61
- self.output_dim = self.backbone.num_features
62
-
63
- def forward(self, image: torch.Tensor) -> torch.Tensor:
64
- if self.kind == "hf_clip_projected":
65
- return self.backbone(pixel_values=image).image_embeds
66
- if self.kind == "hf_vision":
67
- out = self.backbone(pixel_values=image)
68
- if hasattr(out, "pooler_output") and out.pooler_output is not None:
69
- return out.pooler_output
70
- return out.last_hidden_state[:, 0]
71
- return self.backbone(image)
72
-
73
- def forward_with_tokens(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
74
- if self.kind == "hf_clip_projected":
75
- out = self.backbone(pixel_values=image)
76
- tokens = getattr(out, "last_hidden_state", None)
77
- if tokens is None and hasattr(out, "vision_model_output"):
78
- tokens = out.vision_model_output.last_hidden_state
79
- if tokens is None:
80
- raise RuntimeError("Projected CLIP vision output did not include patch tokens")
81
- return out.image_embeds, tokens
82
- if self.kind == "hf_vision":
83
- out = self.backbone(pixel_values=image)
84
- if hasattr(out, "pooler_output") and out.pooler_output is not None:
85
- pooled = out.pooler_output
86
- else:
87
- pooled = out.last_hidden_state[:, 0]
88
- return pooled, out.last_hidden_state
89
-
90
- if not hasattr(self.backbone, "forward_features"):
91
- pooled = self.backbone(image)
92
- return pooled, pooled[:, None, :]
93
- features = self.backbone.forward_features(image)
94
- if hasattr(self.backbone, "forward_head"):
95
- pooled = self.backbone.forward_head(features, pre_logits=False)
96
- else:
97
- pooled = self.backbone(image)
98
- return pooled, _tokens_from_features(features)
99
-
100
-
101
- class TextEncoder(nn.Module):
102
- def __init__(self, model_name: str, pretrained: bool = True, pooling: str = "auto") -> None:
103
- super().__init__()
104
- if pooling not in {"auto", "pooler", "cls", "mean"}:
105
- raise ValueError(f"Unsupported text pooling {pooling!r}; expected auto, pooler, cls, or mean")
106
- self.kind = "hf_text"
107
- self.pooling = pooling
108
- tokenizer_name = model_name.removeprefix("hf_clip_projected:")
109
- self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
110
- model_name_lower = model_name.lower()
111
- if model_name.startswith("hf_clip_projected:"):
112
- self.kind = "hf_clip_projected"
113
- projected_model_name = model_name.removeprefix("hf_clip_projected:")
114
- if pretrained:
115
- self.backbone = CLIPTextModelWithProjection.from_pretrained(projected_model_name)
116
- else:
117
- self.backbone = CLIPTextModelWithProjection(CLIPTextConfig.from_pretrained(projected_model_name))
118
- self.output_dim = self.backbone.config.projection_dim
119
- elif "siglip" in model_name_lower:
120
- if pretrained:
121
- self.backbone = SiglipTextModel.from_pretrained(model_name)
122
- else:
123
- self.backbone = SiglipTextModel(SiglipTextConfig.from_pretrained(model_name))
124
- self.output_dim = self.backbone.config.hidden_size
125
- elif "clip" in model_name_lower:
126
- if pretrained:
127
- self.backbone = CLIPTextModel.from_pretrained(model_name)
128
- else:
129
- self.backbone = CLIPTextModel(CLIPTextConfig.from_pretrained(model_name))
130
- self.output_dim = self.backbone.config.hidden_size
131
- else:
132
- if pretrained:
133
- self.backbone = AutoModel.from_pretrained(model_name)
134
- else:
135
- self.backbone = AutoModel.from_config(AutoConfig.from_pretrained(model_name))
136
- hidden_size = getattr(self.backbone.config, "hidden_size", None)
137
- if hidden_size is None:
138
- raise ValueError(f"Unsupported text model config for {model_name}")
139
- self.output_dim = hidden_size
140
-
141
- def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
142
- out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
143
- if self.kind == "hf_clip_projected":
144
- return out.text_embeds
145
- if self.pooling == "mean":
146
- mask = attention_mask.to(dtype=out.last_hidden_state.dtype).unsqueeze(-1)
147
- summed = (out.last_hidden_state * mask).sum(dim=1)
148
- denom = mask.sum(dim=1).clamp_min(1.0)
149
- return summed / denom
150
- if self.pooling in {"auto", "pooler"} and hasattr(out, "pooler_output") and out.pooler_output is not None:
151
- return out.pooler_output
152
- return out.last_hidden_state[:, 0]
153
-
154
-
155
- def _tokens_from_features(features: torch.Tensor | dict | tuple | list) -> torch.Tensor:
156
- if isinstance(features, dict):
157
- for key in ("x", "last_hidden_state", "features"):
158
- if key in features:
159
- features = features[key]
160
- break
161
- else:
162
- features = next(iter(features.values()))
163
- if isinstance(features, tuple | list):
164
- features = features[0]
165
- if not torch.is_tensor(features):
166
- raise TypeError(f"Expected tensor features, got {type(features)!r}")
167
- if features.ndim == 4:
168
- return features.flatten(2).transpose(1, 2)
169
- if features.ndim == 3:
170
- return features
171
- if features.ndim == 2:
172
- return features[:, None, :]
173
- raise ValueError(f"Unsupported feature tensor shape {tuple(features.shape)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/models/experimental.py DELETED
@@ -1,587 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from collections.abc import Callable
4
-
5
- import torch
6
- import torch.nn.functional as F
7
- from torch import Tensor, nn
8
-
9
- from hyper3_clip.models.lorentz import exp_map0, metric_pairwise_dist
10
- from hyper3_clip.models.losses import beta_cal_loss
11
- from hyper3_clip.models.tren import TRENRegionEncoder
12
- from hyper3_clip.training.distributed import gather_variable_with_grad, gather_with_grad, get_rank
13
-
14
-
15
- ProjectionHeadFactory = Callable[[int, int, int | None], nn.Module]
16
-
17
-
18
- class ExperimentalObjectiveMixin:
19
- @staticmethod
20
- def _validate_experimental_options(
21
- *,
22
- proclip_geometry: str,
23
- proclip_projection_hidden_dim: int | None,
24
- proclip_component_dim: int | None,
25
- beta_clip_weight: float,
26
- beta_clip_global_weight: float,
27
- beta_clip_beta: float,
28
- beta_clip_variant: str,
29
- beta_clip_similarity: str,
30
- beta_clip_num_heads: int,
31
- beta_clip_mlp_ratio: float,
32
- tren_weight: float,
33
- tren_visual_distill_weight: float,
34
- tren_text_distill_weight: float,
35
- tren_region_text_weight: float,
36
- tren_num_region_tokens: int,
37
- tren_num_decoder_layers: int,
38
- tren_num_attention_heads: int,
39
- tren_prompt_grid_size: int,
40
- tren_dropout: float,
41
- ) -> None:
42
- if proclip_geometry not in {"product", "hyperbolic", "euclidean", "spherical", "clip"}:
43
- raise ValueError("proclip_geometry must be 'product', 'hyperbolic', 'euclidean', 'spherical', or 'clip'")
44
- if proclip_projection_hidden_dim is not None and proclip_projection_hidden_dim <= 0:
45
- raise ValueError("proclip_projection_hidden_dim must be positive when set")
46
- if proclip_component_dim is not None and proclip_component_dim <= 0:
47
- raise ValueError("proclip_component_dim must be positive when set")
48
- if beta_clip_variant not in {"ce", "bce"}:
49
- raise ValueError("beta_clip_variant must be 'ce' or 'bce'")
50
- if beta_clip_similarity not in {"metric", "dot"}:
51
- raise ValueError("beta_clip_similarity must be 'metric' or 'dot'")
52
- if beta_clip_weight < 0.0:
53
- raise ValueError("beta_clip_weight must be non-negative")
54
- if beta_clip_global_weight < 0.0:
55
- raise ValueError("beta_clip_global_weight must be non-negative")
56
- if beta_clip_beta < 0.0:
57
- raise ValueError("beta_clip_beta must be non-negative")
58
- if beta_clip_num_heads <= 0:
59
- raise ValueError("beta_clip_num_heads must be positive")
60
- if beta_clip_mlp_ratio <= 0.0:
61
- raise ValueError("beta_clip_mlp_ratio must be positive")
62
- if tren_weight < 0.0:
63
- raise ValueError("tren_weight must be non-negative")
64
- if tren_visual_distill_weight < 0.0 or tren_text_distill_weight < 0.0 or tren_region_text_weight < 0.0:
65
- raise ValueError("T-REN loss weights must be non-negative")
66
- if tren_num_region_tokens <= 0:
67
- raise ValueError("tren_num_region_tokens must be positive")
68
- if tren_num_decoder_layers <= 0:
69
- raise ValueError("tren_num_decoder_layers must be positive")
70
- if tren_num_attention_heads <= 0:
71
- raise ValueError("tren_num_attention_heads must be positive")
72
- if tren_prompt_grid_size <= 0:
73
- raise ValueError("tren_prompt_grid_size must be positive")
74
- if tren_dropout < 0.0:
75
- raise ValueError("tren_dropout must be non-negative")
76
-
77
- def _init_experimental_modules(
78
- self,
79
- *,
80
- beta_clip_num_heads: int,
81
- beta_clip_mlp_ratio: float,
82
- tren_num_region_tokens: int,
83
- tren_num_decoder_layers: int,
84
- tren_num_attention_heads: int,
85
- tren_prompt_grid_size: int,
86
- tren_dropout: float,
87
- projection_hidden_dim: int | None,
88
- proclip_projection_hidden_dim: int | None,
89
- projection_head: ProjectionHeadFactory,
90
- ) -> None:
91
- if self.beta_query_pooling_enabled:
92
- if self.vision_encoder.output_dim % beta_clip_num_heads != 0:
93
- raise ValueError("vision encoder output_dim must be divisible by beta_clip_num_heads")
94
- beta_clip_hidden_dim = max(1, int(round(self.vision_encoder.output_dim * beta_clip_mlp_ratio)))
95
- self.beta_clip_text_query_proj = nn.Linear(self.text_encoder.output_dim, self.vision_encoder.output_dim)
96
- self.beta_clip_cross_attention = nn.MultiheadAttention(
97
- self.vision_encoder.output_dim,
98
- beta_clip_num_heads,
99
- batch_first=True,
100
- )
101
- self.beta_clip_mlp_norm = nn.LayerNorm(self.vision_encoder.output_dim)
102
- self.beta_clip_pool_mlp = nn.Sequential(
103
- nn.Linear(self.vision_encoder.output_dim, beta_clip_hidden_dim),
104
- nn.GELU(),
105
- nn.Linear(beta_clip_hidden_dim, self.vision_encoder.output_dim),
106
- )
107
- if self.beta_clip_enabled:
108
- self.beta_clip_logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log())
109
- if self.tren_enabled:
110
- self.tren_region_encoder = TRENRegionEncoder(
111
- vision_dim=self.vision_encoder.output_dim,
112
- text_dim=self.text_encoder.output_dim,
113
- num_region_tokens=tren_num_region_tokens,
114
- num_decoder_layers=tren_num_decoder_layers,
115
- num_attention_heads=tren_num_attention_heads,
116
- prompt_grid_size=tren_prompt_grid_size,
117
- dropout=tren_dropout,
118
- )
119
- self.tren_logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log())
120
- if self.proclip_enabled:
121
- component_dim = self._proclip_component_dim
122
- spherical_dim = self._proclip_spherical_ambient_dim
123
- proclip_hidden_dim = proclip_projection_hidden_dim
124
- if proclip_hidden_dim is None:
125
- proclip_hidden_dim = projection_hidden_dim
126
- if self.proclip_dedicated_hyperbolic:
127
- self.proclip_image_hyperbolic_proj = projection_head(
128
- self.vision_encoder.output_dim, self.embed_dim, proclip_hidden_dim
129
- )
130
- self.proclip_text_hyperbolic_proj = projection_head(
131
- self.text_encoder.output_dim, self.embed_dim, proclip_hidden_dim
132
- )
133
- self.proclip_image_euclidean_proj = projection_head(
134
- self.vision_encoder.output_dim, component_dim, proclip_hidden_dim
135
- )
136
- self.proclip_text_euclidean_proj = projection_head(
137
- self.text_encoder.output_dim, component_dim, proclip_hidden_dim
138
- )
139
- self.proclip_image_spherical_proj = projection_head(
140
- self.vision_encoder.output_dim, spherical_dim, proclip_hidden_dim
141
- )
142
- self.proclip_text_spherical_proj = projection_head(
143
- self.text_encoder.output_dim, spherical_dim, proclip_hidden_dim
144
- )
145
- self.proclip_logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log())
146
- self.proclip_log_weights = nn.Parameter(torch.zeros(3))
147
-
148
- @property
149
- def proclip_enabled(self) -> bool:
150
- return (
151
- self.objective_name == "proclip"
152
- or self.proclip_component_dim is not None
153
- or self.proclip_weight > 0.0
154
- or self.proclip_retrieval
155
- )
156
-
157
- @property
158
- def beta_clip_enabled(self) -> bool:
159
- return self.beta_clip_weight > 0.0
160
-
161
- @property
162
- def beta_query_pooling_enabled(self) -> bool:
163
- return self.beta_clip_enabled or (
164
- self.objective_name == "uncha"
165
- and self.uncha_entailment_loss in {"hier_beta_argent", "hier_beta_sourcepart_argent"}
166
- )
167
-
168
- @property
169
- def tren_enabled(self) -> bool:
170
- return self.tren_weight > 0.0
171
-
172
- @property
173
- def _proclip_component_dim(self) -> int:
174
- return int(self.proclip_component_dim or self.embed_dim)
175
-
176
- @property
177
- def _proclip_spherical_ambient_dim(self) -> int:
178
- return self._proclip_component_dim + 1
179
-
180
- def _clamp_experimental_logit_scales(self) -> None:
181
- if self.proclip_enabled:
182
- self.proclip_logit_scale.clamp_(max=4.6052)
183
- if self.beta_clip_enabled:
184
- self.beta_clip_logit_scale.clamp_(max=4.6052)
185
- if self.tren_enabled:
186
- self.tren_logit_scale.clamp_(max=4.6052)
187
-
188
- def _detached_experimental_logit_scales(self) -> dict[str, torch.Tensor]:
189
- logs = {}
190
- if self.proclip_enabled:
191
- logs.update(self._detached_proclip_logs())
192
- if self.beta_clip_enabled:
193
- logs["beta_clip_logit_scale"] = self.beta_clip_logit_scale.exp().detach()
194
- if self.tren_enabled:
195
- logs["tren_logit_scale"] = self.tren_logit_scale.exp().detach()
196
- return logs
197
-
198
- def _beta_clip_global_contrastive_loss(
199
- self,
200
- *,
201
- image_euc: torch.Tensor,
202
- text_euc: torch.Tensor,
203
- targets: torch.Tensor,
204
- ) -> torch.Tensor:
205
- image_feats = F.normalize(image_euc.float(), dim=-1)
206
- text_feats = F.normalize(text_euc.float(), dim=-1)
207
- all_image_feats = gather_with_grad(image_feats)
208
- all_text_feats = gather_with_grad(text_feats)
209
- if self.objective_name == "hycoclip":
210
- scale = self.logit_scale.exp().clamp(max=100.0)
211
- elif self.objective_name == "proclip":
212
- scale = self.proclip_logit_scale.exp().clamp(max=100.0)
213
- else:
214
- scale = self.global_logit_scale.exp().clamp(max=100.0)
215
- logits_i_t = image_feats @ all_text_feats.T * scale
216
- logits_t_i = text_feats @ all_image_feats.T * scale
217
- return 0.5 * (F.cross_entropy(logits_i_t, targets) + F.cross_entropy(logits_t_i, targets))
218
-
219
- def _beta_query_entailment_embeddings(
220
- self,
221
- *,
222
- image_tokens: torch.Tensor,
223
- beta_query_input_ids: torch.Tensor | None,
224
- beta_query_attention_mask: torch.Tensor | None,
225
- beta_query_owner: torch.Tensor | None,
226
- beta_query_parent: torch.Tensor | None,
227
- beta_query_weight: torch.Tensor | None,
228
- beta_query_source_part: torch.Tensor | None,
229
- kappa: torch.Tensor,
230
- query_base: torch.Tensor | None = None,
231
- ) -> dict[str, torch.Tensor]:
232
- if beta_query_input_ids is None or beta_query_attention_mask is None or beta_query_owner is None:
233
- raise ValueError(f"{self.uncha_entailment_loss} requires beta query tensors from the collator")
234
- if beta_query_parent is None or beta_query_weight is None:
235
- raise ValueError(f"{self.uncha_entailment_loss} requires beta query hierarchy metadata from the collator")
236
- if self.uncha_entailment_loss == "hier_beta_sourcepart_argent" and beta_query_source_part is None:
237
- raise ValueError("hier_beta_sourcepart_argent requires beta_query_source_part from the collator")
238
- if beta_query_input_ids.shape[0] == 0:
239
- source_part = (
240
- beta_query_source_part.to(device=image_tokens.device, dtype=torch.long)
241
- if beta_query_source_part is not None
242
- else beta_query_owner.new_zeros((0,), device=image_tokens.device, dtype=torch.long)
243
- )
244
- return {
245
- "beta_query_image_feats": image_tokens.new_zeros((0, self.embed_dim)),
246
- "beta_query_text_feats": image_tokens.new_zeros((0, self.embed_dim)),
247
- "beta_query_owner": beta_query_owner.to(device=image_tokens.device, dtype=torch.long),
248
- "beta_query_parent": beta_query_parent.to(device=image_tokens.device, dtype=torch.long),
249
- "beta_query_weight": beta_query_weight.to(device=image_tokens.device, dtype=torch.float32),
250
- "beta_query_source_part": source_part,
251
- }
252
-
253
- query_owner = beta_query_owner.to(device=image_tokens.device, dtype=torch.long)
254
- if query_base is None:
255
- query_base = self.encode_text_base(beta_query_input_ids, beta_query_attention_mask)
256
- conditioned_image_base = self._beta_clip_text_conditioned_pool(image_tokens, query_base, query_owner)
257
- query_image_euc = self.image_proj(conditioned_image_base)
258
- query_text_euc = self.text_proj(query_base)
259
- return {
260
- "beta_query_image_feats": self.project_image_features(query_image_euc),
261
- "beta_query_text_feats": self.project_text_features(query_text_euc),
262
- "beta_query_owner": query_owner,
263
- "beta_query_parent": beta_query_parent.to(device=image_tokens.device, dtype=torch.long),
264
- "beta_query_weight": beta_query_weight.to(device=image_tokens.device, dtype=torch.float32),
265
- **(
266
- {"beta_query_source_part": beta_query_source_part.to(device=image_tokens.device, dtype=torch.long)}
267
- if beta_query_source_part is not None
268
- else {}
269
- ),
270
- }
271
-
272
- def _beta_clip_auxiliary_loss(
273
- self,
274
- *,
275
- image_tokens: torch.Tensor,
276
- beta_query_input_ids: torch.Tensor | None,
277
- beta_query_attention_mask: torch.Tensor | None,
278
- beta_query_owner: torch.Tensor | None,
279
- global_targets: torch.Tensor,
280
- kappa: torch.Tensor,
281
- ) -> torch.Tensor:
282
- if beta_query_input_ids is None or beta_query_attention_mask is None or beta_query_owner is None:
283
- raise ValueError("beta-CLIP auxiliary requires beta query tensors from the collator")
284
- if beta_query_input_ids.shape[0] == 0:
285
- return image_tokens.new_zeros(())
286
-
287
- beta_query_owner = beta_query_owner.to(device=image_tokens.device, dtype=torch.long)
288
- query_base = self.encode_text_base(beta_query_input_ids, beta_query_attention_mask)
289
- conditioned_image_base = self._beta_clip_text_conditioned_pool(image_tokens, query_base, beta_query_owner)
290
- query_image_euc = self.image_proj(conditioned_image_base)
291
- query_text_euc = self.text_proj(query_base)
292
-
293
- if self.beta_clip_similarity == "dot":
294
- query_image_feats = F.normalize(query_image_euc.float(), dim=-1)
295
- query_text_feats = F.normalize(query_text_euc.float(), dim=-1)
296
- else:
297
- query_image_feats = self.project_image_features(query_image_euc)
298
- query_text_feats = self.project_text_features(query_text_euc)
299
-
300
- all_query_image_feats, query_counts = gather_variable_with_grad(query_image_feats)
301
- all_query_text_feats, _ = gather_variable_with_grad(query_text_feats)
302
- query_offset = query_counts[: get_rank()].sum() if query_counts.numel() > 1 else query_counts.new_zeros(())
303
- query_targets = torch.arange(query_image_feats.size(0), device=query_image_feats.device) + query_offset
304
- query_group_ids = global_targets.index_select(0, beta_query_owner)
305
- all_query_group_ids, _ = gather_variable_with_grad(query_group_ids)
306
-
307
- scale = self.beta_clip_logit_scale.exp().clamp(max=100.0)
308
- if self.beta_clip_similarity == "dot":
309
- logits_i_t = query_image_feats @ all_query_text_feats.T * scale
310
- logits_t_i = query_text_feats @ all_query_image_feats.T * scale
311
- else:
312
- logits_i_t = -metric_pairwise_dist(
313
- query_image_feats,
314
- all_query_text_feats,
315
- kappa,
316
- product_metric=self.phyclip_product_metric,
317
- ) * scale
318
- logits_t_i = -metric_pairwise_dist(
319
- query_text_feats,
320
- all_query_image_feats,
321
- kappa,
322
- product_metric=self.phyclip_product_metric,
323
- ) * scale
324
- return 0.5 * (
325
- beta_cal_loss(
326
- logits_i_t,
327
- targets=query_targets,
328
- group_ids=query_group_ids,
329
- all_group_ids=all_query_group_ids,
330
- beta=self.beta_clip_beta,
331
- variant=self.beta_clip_variant,
332
- )
333
- + beta_cal_loss(
334
- logits_t_i,
335
- targets=query_targets,
336
- group_ids=query_group_ids,
337
- all_group_ids=all_query_group_ids,
338
- beta=self.beta_clip_beta,
339
- variant=self.beta_clip_variant,
340
- )
341
- )
342
-
343
- def _beta_clip_text_conditioned_pool(
344
- self,
345
- image_tokens: torch.Tensor,
346
- query_base: torch.Tensor,
347
- query_owner: torch.Tensor,
348
- ) -> torch.Tensor:
349
- if image_tokens.ndim != 3:
350
- raise ValueError("beta-CLIP image tokens must have shape [batch, tokens, dim]")
351
- if getattr(self, "group_beta_query_pooling", False):
352
- return self._beta_clip_text_conditioned_pool_grouped(image_tokens, query_base, query_owner)
353
- if self.beta_clip_drop_cls_token and image_tokens.size(1) > 1:
354
- image_tokens = image_tokens[:, 1:, :]
355
- selected_tokens = image_tokens.index_select(0, query_owner).to(dtype=query_base.dtype)
356
- query = self.beta_clip_text_query_proj(query_base).unsqueeze(1)
357
- attended, _ = self.beta_clip_cross_attention(query, selected_tokens, selected_tokens, need_weights=False)
358
- pooled = attended.squeeze(1)
359
- return pooled + self.beta_clip_pool_mlp(self.beta_clip_mlp_norm(pooled))
360
-
361
- def _beta_clip_text_conditioned_pool_grouped(
362
- self,
363
- image_tokens: torch.Tensor,
364
- query_base: torch.Tensor,
365
- query_owner: torch.Tensor,
366
- ) -> torch.Tensor:
367
- if query_owner.numel() == 0:
368
- return query_base.new_zeros((0, self.vision_encoder.output_dim))
369
- if query_owner.min().item() < 0 or query_owner.max().item() >= image_tokens.size(0):
370
- raise IndexError("beta_query_owner contains an out-of-range image index")
371
-
372
- tokens = image_tokens[:, 1:, :] if self.beta_clip_drop_cls_token and image_tokens.size(1) > 1 else image_tokens
373
- tokens = tokens.to(dtype=query_base.dtype)
374
- query_projected = self.beta_clip_text_query_proj(query_base)
375
- counts = torch.bincount(query_owner, minlength=image_tokens.size(0))
376
- max_queries = int(counts.max().item())
377
-
378
- order = torch.argsort(query_owner)
379
- sorted_owner = query_owner.index_select(0, order)
380
- owner_offsets = torch.zeros_like(counts)
381
- owner_offsets[1:] = counts.cumsum(0)[:-1]
382
- sorted_positions = torch.arange(query_owner.numel(), device=query_owner.device) - owner_offsets.index_select(
383
- 0, sorted_owner
384
- )
385
- positions = torch.empty_like(sorted_positions)
386
- positions[order] = sorted_positions
387
-
388
- packed_query = query_projected.new_zeros((image_tokens.size(0), max_queries, query_projected.size(-1)))
389
- packed_query[query_owner, positions] = query_projected
390
- attended, _ = self.beta_clip_cross_attention(packed_query, tokens, tokens, need_weights=False)
391
- pooled = attended[query_owner, positions]
392
- return pooled + self.beta_clip_pool_mlp(self.beta_clip_mlp_norm(pooled))
393
-
394
- def _tren_auxiliary_losses(
395
- self,
396
- *,
397
- image_tokens: torch.Tensor,
398
- part_owner: torch.Tensor,
399
- part_image_base: torch.Tensor,
400
- part_text_base: torch.Tensor,
401
- ) -> dict[str, torch.Tensor]:
402
- zero = image_tokens.new_zeros(())
403
- if part_owner.numel() == 0:
404
- return {
405
- "tren_loss": zero,
406
- "tren_visual_distill_loss": zero,
407
- "tren_text_distill_loss": zero,
408
- "tren_region_text_contrastive_loss": zero,
409
- "tren_assignment_count": part_owner.new_tensor(0),
410
- }
411
-
412
- tren_outputs = self.tren_region_encoder(image_tokens)
413
- visual_tokens = tren_outputs["visual_tokens"].flatten(1, 2)
414
- text_tokens = tren_outputs["text_aligned_tokens"].flatten(1, 2)
415
-
416
- matched_visual: list[torch.Tensor] = []
417
- matched_text: list[torch.Tensor] = []
418
- target_visual: list[torch.Tensor] = []
419
- target_text: list[torch.Tensor] = []
420
- for owner in range(image_tokens.size(0)):
421
- region_mask = part_owner == owner
422
- if not bool(region_mask.any()):
423
- continue
424
- owner_target_visual = part_image_base[region_mask].detach()
425
- owner_target_text = part_text_base[region_mask].detach()
426
- owner_visual_tokens = visual_tokens[owner]
427
- owner_text_tokens = text_tokens[owner]
428
- pred_indices, target_indices = _greedy_region_assignment(owner_visual_tokens, owner_target_visual)
429
- if pred_indices.numel() == 0:
430
- continue
431
- matched_visual.append(owner_visual_tokens.index_select(0, pred_indices))
432
- matched_text.append(owner_text_tokens.index_select(0, pred_indices))
433
- target_visual.append(owner_target_visual.index_select(0, target_indices))
434
- target_text.append(owner_target_text.index_select(0, target_indices))
435
-
436
- if not matched_visual:
437
- return {
438
- "tren_loss": zero,
439
- "tren_visual_distill_loss": zero,
440
- "tren_text_distill_loss": zero,
441
- "tren_region_text_contrastive_loss": zero,
442
- "tren_assignment_count": part_owner.new_tensor(0),
443
- }
444
-
445
- matched_visual_tensor = torch.cat(matched_visual, dim=0)
446
- matched_text_tensor = torch.cat(matched_text, dim=0)
447
- target_visual_tensor = torch.cat(target_visual, dim=0)
448
- target_text_tensor = torch.cat(target_text, dim=0)
449
- visual_distill = 1.0 - F.cosine_similarity(matched_visual_tensor, target_visual_tensor, dim=-1).mean()
450
- text_distill = 1.0 - F.cosine_similarity(matched_text_tensor, target_text_tensor, dim=-1).mean()
451
- region_text = _symmetric_dot_contrastive(
452
- matched_text_tensor,
453
- target_text_tensor,
454
- scale=self.tren_logit_scale.exp().clamp(max=100.0),
455
- )
456
- total = (
457
- self.tren_visual_distill_weight * visual_distill
458
- + self.tren_text_distill_weight * text_distill
459
- + self.tren_region_text_weight * region_text
460
- )
461
- return {
462
- "tren_loss": total,
463
- "tren_visual_distill_loss": visual_distill,
464
- "tren_text_distill_loss": text_distill,
465
- "tren_region_text_contrastive_loss": region_text,
466
- "tren_assignment_count": part_owner.new_tensor(matched_visual_tensor.size(0)),
467
- }
468
-
469
- def _project_proclip_image_base(self, base_feats: torch.Tensor, hyperbolic: torch.Tensor) -> torch.Tensor:
470
- if self.proclip_geometry == "clip":
471
- return F.normalize(base_feats.float(), dim=-1)
472
- if self.proclip_dedicated_hyperbolic:
473
- hyperbolic = exp_map0(self.proclip_image_hyperbolic_proj(base_feats.float()), self._kappa().float())
474
- return self._pack_proclip_features(
475
- hyperbolic=hyperbolic,
476
- euclidean=self.proclip_image_euclidean_proj(base_feats.float()),
477
- spherical=self.proclip_image_spherical_proj(base_feats.float()),
478
- )
479
-
480
- def _project_proclip_text_base(self, base_feats: torch.Tensor, hyperbolic: torch.Tensor) -> torch.Tensor:
481
- if self.proclip_geometry == "clip":
482
- return F.normalize(base_feats.float(), dim=-1)
483
- if self.proclip_dedicated_hyperbolic:
484
- hyperbolic = exp_map0(self.proclip_text_hyperbolic_proj(base_feats.float()), self._kappa().float())
485
- return self._pack_proclip_features(
486
- hyperbolic=hyperbolic,
487
- euclidean=self.proclip_text_euclidean_proj(base_feats.float()),
488
- spherical=self.proclip_text_spherical_proj(base_feats.float()),
489
- )
490
-
491
- def _pack_proclip_features(self, hyperbolic: torch.Tensor, euclidean: torch.Tensor, spherical: torch.Tensor) -> torch.Tensor:
492
- spherical = F.normalize(spherical.float(), dim=-1)
493
- if self.proclip_geometry == "hyperbolic":
494
- return hyperbolic.float()
495
- if self.proclip_geometry == "euclidean":
496
- return euclidean.float()
497
- if self.proclip_geometry == "spherical":
498
- return spherical
499
- return torch.cat([hyperbolic.float(), euclidean.float(), spherical], dim=-1)
500
-
501
- def _split_proclip_features(self, feats: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
502
- hyperbolic_dim = self.embed_dim + 1
503
- component_dim = self._proclip_component_dim
504
- spherical_dim = self._proclip_spherical_ambient_dim
505
- hyperbolic = feats[:, :hyperbolic_dim]
506
- euclidean = feats[:, hyperbolic_dim : hyperbolic_dim + component_dim]
507
- spherical = feats[:, hyperbolic_dim + component_dim : hyperbolic_dim + component_dim + spherical_dim]
508
- return hyperbolic, euclidean, spherical
509
-
510
- def _proclip_similarity_scores(self, image_feats: torch.Tensor, text_feats: torch.Tensor) -> torch.Tensor:
511
- if self.proclip_geometry == "clip":
512
- return image_feats.float() @ text_feats.float().T
513
- if self.proclip_geometry == "hyperbolic":
514
- return -metric_pairwise_dist(image_feats, text_feats, self._kappa()).square()
515
- if self.proclip_geometry == "euclidean":
516
- return -torch.cdist(image_feats.float(), text_feats.float(), p=2).square()
517
- if self.proclip_geometry == "spherical":
518
- dot = (image_feats.float() @ text_feats.float().T).clamp(min=-1.0 + 1e-6, max=1.0 - 1e-6)
519
- return -torch.acos(dot).square()
520
- image_hyp, image_euc, image_sph = self._split_proclip_features(image_feats)
521
- text_hyp, text_euc, text_sph = self._split_proclip_features(text_feats)
522
- weights = self.proclip_log_weights.exp().to(device=image_feats.device, dtype=torch.float32)
523
- hyperbolic_dist2 = metric_pairwise_dist(image_hyp, text_hyp, self._kappa()).square()
524
- euclidean_dist2 = torch.cdist(image_euc.float(), text_euc.float(), p=2).square()
525
- spherical_dot = (image_sph.float() @ text_sph.float().T).clamp(min=-1.0 + 1e-6, max=1.0 - 1e-6)
526
- spherical_dist2 = torch.acos(spherical_dot).square()
527
- return -(weights[0] * hyperbolic_dist2 + weights[1] * euclidean_dist2 + weights[2] * spherical_dist2)
528
-
529
- def _proclip_contrastive_loss(
530
- self,
531
- image_feats: torch.Tensor,
532
- text_feats: torch.Tensor,
533
- all_image_feats: torch.Tensor,
534
- all_text_feats: torch.Tensor,
535
- targets: torch.Tensor,
536
- ) -> torch.Tensor:
537
- scale = self.proclip_logit_scale.exp().clamp(max=100.0)
538
- logits_i_t = self._proclip_similarity_scores(image_feats, all_text_feats) * scale
539
- logits_t_i = self._proclip_similarity_scores(text_feats, all_image_feats) * scale
540
- return 0.5 * (F.cross_entropy(logits_i_t, targets) + F.cross_entropy(logits_t_i, targets))
541
-
542
- def _detached_proclip_logs(self) -> dict[str, torch.Tensor]:
543
- weights = self.proclip_log_weights.exp().detach()
544
- return {
545
- "proclip_logit_scale": self.proclip_logit_scale.exp().detach(),
546
- "proclip_hyperbolic_weight": weights[0],
547
- "proclip_euclidean_weight": weights[1],
548
- "proclip_spherical_weight": weights[2],
549
- }
550
-
551
-
552
- def _greedy_region_assignment(pred_tokens: torch.Tensor, target_tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
553
- if pred_tokens.numel() == 0 or target_tokens.numel() == 0:
554
- empty = torch.zeros((0,), dtype=torch.long, device=pred_tokens.device)
555
- return empty, empty
556
- similarities = F.normalize(pred_tokens.float(), dim=-1) @ F.normalize(target_tokens.float(), dim=-1).T
557
- pair_scores = similarities.flatten()
558
- order = torch.argsort(pair_scores, descending=True)
559
- used_pred = torch.zeros(pred_tokens.size(0), dtype=torch.bool, device=pred_tokens.device)
560
- used_target = torch.zeros(target_tokens.size(0), dtype=torch.bool, device=pred_tokens.device)
561
- pred_indices: list[torch.Tensor] = []
562
- target_indices: list[torch.Tensor] = []
563
- for flat_index in order:
564
- pred_index = torch.div(flat_index, target_tokens.size(0), rounding_mode="floor")
565
- target_index = flat_index % target_tokens.size(0)
566
- if used_pred[pred_index] or used_target[target_index]:
567
- continue
568
- used_pred[pred_index] = True
569
- used_target[target_index] = True
570
- pred_indices.append(pred_index)
571
- target_indices.append(target_index)
572
- if len(target_indices) == target_tokens.size(0):
573
- break
574
- if not pred_indices:
575
- empty = torch.zeros((0,), dtype=torch.long, device=pred_tokens.device)
576
- return empty, empty
577
- return torch.stack(pred_indices), torch.stack(target_indices)
578
-
579
-
580
- def _symmetric_dot_contrastive(region_tokens: torch.Tensor, text_tokens: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
581
- if region_tokens.size(0) == 1:
582
- return region_tokens.new_zeros(())
583
- region_tokens = F.normalize(region_tokens.float(), dim=-1)
584
- text_tokens = F.normalize(text_tokens.float(), dim=-1)
585
- logits = region_tokens @ text_tokens.T * scale
586
- targets = torch.arange(logits.size(0), device=logits.device)
587
- return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/models/himo.py DELETED
@@ -1,55 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import torch
4
- from torch import Tensor
5
-
6
-
7
- def hide_reconstruct_embeddings(
8
- embeddings: Tensor,
9
- *,
10
- variance_threshold: float = 0.9,
11
- detach_pca: bool = True,
12
- eps: float = 1e-8,
13
- ) -> Tensor:
14
- """HiMo-CLIP HiDe: PCA-reconstruct embeddings using top principal components.
15
-
16
- Given a batch of embeddings ``U ∈ R^{B×D}``, compute mean-centered embeddings,
17
- perform SVD/PCA, choose the smallest number of components whose cumulative
18
- explained variance exceeds ``variance_threshold``, and reconstruct each
19
- embedding from this principal subspace:
20
-
21
- u'_i = P^T (P (u_i - ū)) + ū
22
-
23
- where P stacks the selected principal components as rows.
24
- """
25
- if embeddings.ndim != 2:
26
- raise ValueError("hide_reconstruct_embeddings expects a [batch, dim] tensor")
27
- if not (0.0 < variance_threshold <= 1.0):
28
- raise ValueError("variance_threshold must be in (0, 1]")
29
- if embeddings.size(0) < 2:
30
- return embeddings
31
-
32
- u = embeddings.to(dtype=torch.float32)
33
- mean = u.mean(dim=0, keepdim=True)
34
- centered = u - mean
35
- if detach_pca:
36
- centered_for_pca = centered.detach()
37
- else:
38
- centered_for_pca = centered
39
-
40
- # SVD: centered = U S Vh, principal components are rows of Vh.
41
- _, s, vh = torch.linalg.svd(centered_for_pca, full_matrices=False)
42
- if s.numel() == 0 or float((s.square().sum()).item()) <= eps:
43
- return embeddings
44
-
45
- explained = s.square()
46
- cumulative = explained.cumsum(dim=0) / explained.sum().clamp_min(eps)
47
- m = int((cumulative >= variance_threshold).to(dtype=torch.int64).argmax().item()) + 1
48
- m = max(1, min(m, vh.size(0)))
49
- p = vh[:m]
50
- if detach_pca:
51
- p = p.detach()
52
-
53
- recon = (centered @ p.T) @ p + mean
54
- return recon.to(dtype=embeddings.dtype)
55
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/models/hyper3_clip.py DELETED
@@ -1,958 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import torch
4
- import torch.nn.functional as F
5
- from torch import nn
6
-
7
- from hyper3_clip.models.encoders import TextEncoder, VisionEncoder
8
- from hyper3_clip.models.experimental import ExperimentalObjectiveMixin
9
- from hyper3_clip.models.himo import hide_reconstruct_embeddings
10
- from hyper3_clip.models.lorentz import exp_map0, metric_similarity
11
- from hyper3_clip.models.objectives import build_objective
12
- from hyper3_clip.training.distributed import (
13
- gather_with_grad,
14
- get_rank,
15
- get_world_size,
16
- local_target_indices,
17
- )
18
-
19
-
20
- class Hyper3CLIP(ExperimentalObjectiveMixin, nn.Module):
21
- def __init__(
22
- self,
23
- vision_backbone: str,
24
- text_model_name: str,
25
- embed_dim: int,
26
- curv_init: float,
27
- learn_curv: bool,
28
- entail_weight: float,
29
- inter_aperture_scale: float,
30
- intra_aperture_scale: float,
31
- objective: str = "hycoclip",
32
- uncha_piecewise_factor: float = 0.1,
33
- uncha_calibration_alpha: float = 10.0,
34
- uncha_stop_grad_calibration: bool = True,
35
- vision_pretrained: bool = True,
36
- text_pretrained: bool = True,
37
- text_pooling: str = "auto",
38
- freeze_vision_encoder: bool = False,
39
- freeze_text_encoder: bool = False,
40
- normalize_encoder_features: bool = False,
41
- projection_hidden_dim: int | None = None,
42
- uncha_entailment_geometry: str = "lorentz",
43
- uncha_aggregate_weight: float = 0.0,
44
- uncha_entailment_loss: str = "piecewise",
45
- uncha_argent_beta: float = 1.0,
46
- uncha_argent_norm_weight: float = 0.0,
47
- uncha_argent_aux_weight: float = 0.5,
48
- uncha_argent_aggregation: str = "uncha",
49
- uncha_part_weight_power: float = 0.0,
50
- uncha_contrastive_loss: str = "ce",
51
- uncha_sigmoid_bias_init: float = -10.0,
52
- uncha_sigmoid_negative_weight: float = 1.0,
53
- uncha_part_quality_mode: str = "none",
54
- uncha_part_quality_topk: int = 5,
55
- uncha_part_quality_temperature: float = 4.0,
56
- uncha_entailment_warmup_steps: int = 0,
57
- uncha_contrastive_global_weight: float = 1.0,
58
- uncha_contrastive_local_weight: float = 1.0,
59
- uncha_contrastive_global_local_weight: float = 1.0,
60
- uncha_global_local_mode: str = "repeat",
61
- uncha_global_local_metric: str = "distance",
62
- uncha_global_local_angle_aux_weight: float = 0.0,
63
- uncha_global_local_angle_aux_mode: str = "contrastive",
64
- uncha_global_local_angle_aux_scale: float = 5.5,
65
- uncha_global_local_angle_aux_aperture_scale: float = 1.0,
66
- uncha_beta_cal_beta: float = 0.0,
67
- uncha_beta_cal_variant: str = "ce",
68
- uncha_beta_cal_weight: float = 0.0,
69
- uncha_himo_component_weight: float = 0.0,
70
- uncha_himo_variance_threshold: float = 0.9,
71
- uncha_himo_detach_pca: bool = True,
72
- uncha_radius_order_weight: float = 0.0,
73
- uncha_radius_order_margin: float = 0.0,
74
- uncha_gramian_align_weight: float = 0.0,
75
- phyclip_subspace_dim: int | None = None,
76
- phyclip_product_metric: str = "l1",
77
- proclip_weight: float = 0.0,
78
- proclip_component_dim: int | None = None,
79
- proclip_retrieval: bool = False,
80
- proclip_geometry: str = "product",
81
- proclip_dedicated_hyperbolic: bool = False,
82
- proclip_projection_hidden_dim: int | None = None,
83
- beta_clip_weight: float = 0.0,
84
- beta_clip_global_weight: float = 0.0,
85
- beta_clip_beta: float = 0.5,
86
- beta_clip_variant: str = "ce",
87
- beta_clip_similarity: str = "metric",
88
- beta_clip_num_heads: int = 8,
89
- beta_clip_mlp_ratio: float = 4.0,
90
- beta_clip_drop_cls_token: bool = True,
91
- tren_weight: float = 0.0,
92
- tren_visual_distill_weight: float = 1.0,
93
- tren_text_distill_weight: float = 1.0,
94
- tren_region_text_weight: float = 1.0,
95
- tren_num_region_tokens: int = 3,
96
- tren_num_decoder_layers: int = 2,
97
- tren_num_attention_heads: int = 8,
98
- tren_prompt_grid_size: int = 7,
99
- tren_dropout: float = 0.1,
100
- fuse_whole_part_encoder_forwards: bool = False,
101
- fuse_beta_query_encoder_forwards: bool = False,
102
- group_beta_query_pooling: bool = False,
103
- objective_autocast_dtype: str = "float32",
104
- ) -> None:
105
- super().__init__()
106
- if objective not in {"hycoclip", "uncha", "proclip"}:
107
- raise ValueError(f"Unsupported objective {objective!r}; expected 'hycoclip', 'uncha', or 'proclip'")
108
- if phyclip_product_metric not in {"l1", "l2"}:
109
- raise ValueError("phyclip_product_metric must be 'l1' or 'l2'")
110
- self._validate_experimental_options(
111
- proclip_geometry=proclip_geometry,
112
- proclip_projection_hidden_dim=proclip_projection_hidden_dim,
113
- proclip_component_dim=proclip_component_dim,
114
- beta_clip_weight=beta_clip_weight,
115
- beta_clip_global_weight=beta_clip_global_weight,
116
- beta_clip_beta=beta_clip_beta,
117
- beta_clip_variant=beta_clip_variant,
118
- beta_clip_similarity=beta_clip_similarity,
119
- beta_clip_num_heads=beta_clip_num_heads,
120
- beta_clip_mlp_ratio=beta_clip_mlp_ratio,
121
- tren_weight=tren_weight,
122
- tren_visual_distill_weight=tren_visual_distill_weight,
123
- tren_text_distill_weight=tren_text_distill_weight,
124
- tren_region_text_weight=tren_region_text_weight,
125
- tren_num_region_tokens=tren_num_region_tokens,
126
- tren_num_decoder_layers=tren_num_decoder_layers,
127
- tren_num_attention_heads=tren_num_attention_heads,
128
- tren_prompt_grid_size=tren_prompt_grid_size,
129
- tren_dropout=tren_dropout,
130
- )
131
- if objective_autocast_dtype not in {"float32", "fp32", "float16", "fp16", "bfloat16", "bf16"}:
132
- raise ValueError("objective_autocast_dtype must be one of 'float32', 'float16', or 'bfloat16'")
133
- if uncha_contrastive_loss not in {"ce", "sigmoid", "siglip", "siglip_metric"}:
134
- raise ValueError("uncha_contrastive_loss must be 'ce', 'sigmoid', 'siglip', or 'siglip_metric'")
135
- if uncha_global_local_metric not in {"distance", "angle"}:
136
- raise ValueError("uncha_global_local_metric must be 'distance' or 'angle'")
137
- if uncha_global_local_angle_aux_mode not in {"contrastive", "positive_hinge"}:
138
- raise ValueError("uncha_global_local_angle_aux_mode must be 'contrastive' or 'positive_hinge'")
139
- if uncha_global_local_angle_aux_weight < 0.0:
140
- raise ValueError("uncha_global_local_angle_aux_weight must be non-negative")
141
- if uncha_global_local_angle_aux_scale <= 0.0:
142
- raise ValueError("uncha_global_local_angle_aux_scale must be positive")
143
- if uncha_global_local_angle_aux_aperture_scale <= 0.0:
144
- raise ValueError("uncha_global_local_angle_aux_aperture_scale must be positive")
145
- if uncha_entailment_warmup_steps < 0:
146
- raise ValueError("uncha_entailment_warmup_steps must be non-negative")
147
- self.objective_name = objective
148
- self.uncha_contrastive_loss = uncha_contrastive_loss
149
- self.uncha_entailment_loss = uncha_entailment_loss
150
- self.uncha_entailment_warmup_steps = uncha_entailment_warmup_steps
151
- self.uncha_himo_component_weight = float(uncha_himo_component_weight)
152
- self.uncha_himo_variance_threshold = float(uncha_himo_variance_threshold)
153
- self.uncha_himo_detach_pca = bool(uncha_himo_detach_pca)
154
- self.proclip_weight = float(proclip_weight)
155
- self.proclip_retrieval = bool(proclip_retrieval)
156
- self.proclip_geometry = proclip_geometry
157
- self.proclip_dedicated_hyperbolic = bool(proclip_dedicated_hyperbolic)
158
- self.beta_clip_weight = float(beta_clip_weight)
159
- self.beta_clip_global_weight = float(beta_clip_global_weight)
160
- self.beta_clip_beta = float(beta_clip_beta)
161
- self.beta_clip_variant = beta_clip_variant
162
- self.beta_clip_similarity = beta_clip_similarity
163
- self.beta_clip_drop_cls_token = bool(beta_clip_drop_cls_token)
164
- self.tren_weight = float(tren_weight)
165
- self.tren_visual_distill_weight = float(tren_visual_distill_weight)
166
- self.tren_text_distill_weight = float(tren_text_distill_weight)
167
- self.tren_region_text_weight = float(tren_region_text_weight)
168
- self.fuse_whole_part_encoder_forwards = bool(fuse_whole_part_encoder_forwards)
169
- self.fuse_beta_query_encoder_forwards = bool(fuse_beta_query_encoder_forwards)
170
- self.group_beta_query_pooling = bool(group_beta_query_pooling)
171
- self.objective_autocast_dtype = objective_autocast_dtype
172
- self.freeze_vision_encoder = bool(freeze_vision_encoder)
173
- self.freeze_text_encoder = bool(freeze_text_encoder)
174
- self.normalize_encoder_features = bool(normalize_encoder_features)
175
- self.phyclip_subspace_dim = phyclip_subspace_dim
176
- self.phyclip_product_metric = phyclip_product_metric
177
- self.proclip_component_dim = proclip_component_dim
178
- if projection_hidden_dim is not None and projection_hidden_dim <= 0:
179
- raise ValueError("projection_hidden_dim must be positive when set")
180
- if self.proclip_enabled and phyclip_subspace_dim is not None:
181
- raise ValueError("ProCLIP mixed-curvature proxy cannot be combined with PHyCLIP Lorentz factors")
182
- if phyclip_subspace_dim is not None:
183
- if phyclip_subspace_dim <= 0:
184
- raise ValueError("phyclip_subspace_dim must be positive when set")
185
- if embed_dim % phyclip_subspace_dim != 0:
186
- raise ValueError("embed_dim must be divisible by phyclip_subspace_dim")
187
- self.phyclip_num_factors = embed_dim // phyclip_subspace_dim
188
- else:
189
- self.phyclip_num_factors = 0
190
- self.vision_encoder = VisionEncoder(vision_backbone, pretrained=vision_pretrained)
191
- self.text_encoder = TextEncoder(text_model_name, pretrained=text_pretrained, pooling=text_pooling)
192
- self.tokenizer = self.text_encoder.tokenizer
193
- self.embed_dim = embed_dim
194
- if self.freeze_vision_encoder:
195
- self.vision_encoder.requires_grad_(False)
196
- self.vision_encoder.eval()
197
- if self.freeze_text_encoder:
198
- self.text_encoder.requires_grad_(False)
199
- self.text_encoder.eval()
200
-
201
- self.image_proj = _projection_head(self.vision_encoder.output_dim, embed_dim, projection_hidden_dim)
202
- self.text_proj = _projection_head(self.text_encoder.output_dim, embed_dim, projection_hidden_dim)
203
- self._init_experimental_modules(
204
- beta_clip_num_heads=beta_clip_num_heads,
205
- beta_clip_mlp_ratio=beta_clip_mlp_ratio,
206
- tren_num_region_tokens=tren_num_region_tokens,
207
- tren_num_decoder_layers=tren_num_decoder_layers,
208
- tren_num_attention_heads=tren_num_attention_heads,
209
- tren_prompt_grid_size=tren_prompt_grid_size,
210
- tren_dropout=tren_dropout,
211
- projection_hidden_dim=projection_hidden_dim,
212
- proclip_projection_hidden_dim=proclip_projection_hidden_dim,
213
- projection_head=_projection_head,
214
- )
215
-
216
- if objective == "hycoclip":
217
- self.logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log())
218
- elif objective == "uncha":
219
- self.global_logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log())
220
- self.local_logit_scale = nn.Parameter(torch.tensor(1 / 0.05).log())
221
- self.global_local_logit_scale = nn.Parameter(torch.tensor(1 / 0.06).log())
222
- if uncha_contrastive_loss in {"sigmoid", "siglip", "siglip_metric"}:
223
- self.global_logit_bias = nn.Parameter(torch.tensor(float(uncha_sigmoid_bias_init)))
224
- self.local_logit_bias = nn.Parameter(torch.tensor(float(uncha_sigmoid_bias_init)))
225
- self.global_local_logit_bias = nn.Parameter(torch.tensor(float(uncha_sigmoid_bias_init)))
226
- alpha_dim = phyclip_subspace_dim or embed_dim
227
- alpha_shape = (self.phyclip_num_factors,) if self.phyclip_enabled else ()
228
- self.visual_alpha = nn.Parameter(torch.full(alpha_shape, alpha_dim**-0.5).log())
229
- self.textual_alpha = nn.Parameter(torch.full(alpha_shape, alpha_dim**-0.5).log())
230
-
231
- curv_shape = (self.phyclip_num_factors,) if self.phyclip_enabled else ()
232
- log_curv = torch.full(curv_shape, curv_init).log()
233
- self.log_curv = nn.Parameter(log_curv, requires_grad=learn_curv)
234
- self.curv_min = curv_init / 10.0
235
- self.curv_max = curv_init * 10.0
236
- self.objective = None
237
- if objective != "proclip":
238
- self.objective = build_objective(
239
- objective=objective,
240
- entail_weight=entail_weight,
241
- inter_aperture_scale=inter_aperture_scale,
242
- intra_aperture_scale=intra_aperture_scale,
243
- uncha_piecewise_factor=uncha_piecewise_factor,
244
- uncha_calibration_alpha=uncha_calibration_alpha,
245
- uncha_stop_grad_calibration=uncha_stop_grad_calibration,
246
- uncha_entailment_geometry=uncha_entailment_geometry,
247
- uncha_aggregate_weight=uncha_aggregate_weight,
248
- uncha_entailment_loss=uncha_entailment_loss,
249
- uncha_argent_beta=uncha_argent_beta,
250
- uncha_argent_norm_weight=uncha_argent_norm_weight,
251
- uncha_argent_aux_weight=uncha_argent_aux_weight,
252
- uncha_argent_aggregation=uncha_argent_aggregation,
253
- uncha_part_weight_power=uncha_part_weight_power,
254
- uncha_contrastive_loss=uncha_contrastive_loss,
255
- uncha_sigmoid_negative_weight=uncha_sigmoid_negative_weight,
256
- uncha_part_quality_mode=uncha_part_quality_mode,
257
- uncha_part_quality_topk=uncha_part_quality_topk,
258
- uncha_part_quality_temperature=uncha_part_quality_temperature,
259
- uncha_contrastive_global_weight=uncha_contrastive_global_weight,
260
- uncha_contrastive_local_weight=uncha_contrastive_local_weight,
261
- uncha_contrastive_global_local_weight=uncha_contrastive_global_local_weight,
262
- uncha_global_local_mode=uncha_global_local_mode,
263
- uncha_global_local_metric=uncha_global_local_metric,
264
- uncha_global_local_angle_aux_weight=uncha_global_local_angle_aux_weight,
265
- uncha_global_local_angle_aux_mode=uncha_global_local_angle_aux_mode,
266
- uncha_global_local_angle_aux_scale=uncha_global_local_angle_aux_scale,
267
- uncha_global_local_angle_aux_aperture_scale=uncha_global_local_angle_aux_aperture_scale,
268
- uncha_beta_cal_beta=uncha_beta_cal_beta,
269
- uncha_beta_cal_variant=uncha_beta_cal_variant,
270
- uncha_beta_cal_weight=uncha_beta_cal_weight,
271
- uncha_himo_component_weight=uncha_himo_component_weight,
272
- uncha_radius_order_weight=uncha_radius_order_weight,
273
- uncha_radius_order_margin=uncha_radius_order_margin,
274
- uncha_gramian_align_weight=uncha_gramian_align_weight,
275
- product_metric=phyclip_product_metric,
276
- )
277
-
278
def train(self, mode: bool = True) -> Hyper3CLIP:
    """Switch the module's train/eval mode while keeping frozen encoders in eval mode.

    Args:
        mode: Forwarded unchanged to the parent class's ``train``.

    Returns:
        ``self``, so calls can be chained.
    """
    super().train(mode)
    # A frozen encoder is put back into eval mode regardless of ``mode``.
    frozen_pairs = (
        (self.freeze_vision_encoder, self.vision_encoder),
        (self.freeze_text_encoder, self.text_encoder),
    )
    for is_frozen, encoder in frozen_pairs:
        if is_frozen:
            encoder.eval()
    return self
285
-
286
@property
def phyclip_enabled(self) -> bool:
    """Whether a product-space subspace dimension is configured (``phyclip_subspace_dim`` is set)."""
    subspace_dim = self.phyclip_subspace_dim
    return subspace_dim is not None
289
-
290
- def _kappa(self) -> torch.Tensor:
291
- return self.log_curv.exp().clamp(min=self.curv_min, max=self.curv_max)
292
-
293
def encode_image(self, image: torch.Tensor, project: bool = True) -> torch.Tensor:
    """Encode images through the base encoder and ``image_proj``.

    Args:
        image: Input passed to ``encode_image_base``.
        project: When false, return the output of ``image_proj`` directly;
            otherwise additionally apply ``project_image_features``.
    """
    projected_input = self.image_proj(self.encode_image_base(image))
    if project:
        return self.project_image_features(projected_input)
    return projected_input
298
-
299
def encode_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, project: bool = True) -> torch.Tensor:
    """Encode text through the base encoder and ``text_proj``.

    Args:
        input_ids: Token ids passed to ``encode_text_base``.
        attention_mask: Mask passed alongside ``input_ids``.
        project: When false, return the output of ``text_proj`` directly;
            otherwise additionally apply ``project_text_features``.
    """
    projected_input = self.text_proj(self.encode_text_base(input_ids, attention_mask))
    if project:
        return self.project_text_features(projected_input)
    return projected_input
304
-
305
def encode_image_base(self, image: torch.Tensor) -> torch.Tensor:
    """Run the vision encoder and optionally L2-normalize its output.

    Gradients are tracked only while training with an unfrozen vision
    encoder; a frozen encoder's output is additionally detached. When
    ``normalize_encoder_features`` is set, the features are cast to float
    and normalized along the last dimension.
    """
    frozen = self.freeze_vision_encoder
    with torch.set_grad_enabled(self.training and not frozen):
        feats = self.vision_encoder(image)
    if frozen:
        feats = feats.detach()
    if self.normalize_encoder_features:
        return F.normalize(feats.float(), dim=-1)
    return feats
310
-
311
def encode_image_base_with_tokens(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Run ``vision_encoder.forward_with_tokens`` and return ``(feats, tokens)``.

    Gradients are tracked only while training with an unfrozen vision
    encoder; with a frozen encoder both outputs are detached. Only the
    pooled features (not the tokens) are float-cast and L2-normalized when
    ``normalize_encoder_features`` is set.
    """
    frozen = self.freeze_vision_encoder
    with torch.set_grad_enabled(self.training and not frozen):
        pooled, tokens = self.vision_encoder.forward_with_tokens(image)
    if frozen:
        pooled, tokens = pooled.detach(), tokens.detach()
    if self.normalize_encoder_features:
        pooled = F.normalize(pooled.float(), dim=-1)
    return pooled, tokens
320
-
321
def encode_text_base(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Run the text encoder and optionally L2-normalize its output.

    Gradients are tracked only while training with an unfrozen text
    encoder; a frozen encoder's output is additionally detached. When
    ``normalize_encoder_features`` is set, the features are cast to float
    and normalized along the last dimension.
    """
    frozen = self.freeze_text_encoder
    with torch.set_grad_enabled(self.training and not frozen):
        feats = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
    if frozen:
        feats = feats.detach()
    if self.normalize_encoder_features:
        return F.normalize(feats.float(), dim=-1)
    return feats
326
-
327
def project_image_features(self, feats: torch.Tensor) -> torch.Tensor:
    """Map image features through ``exp_map0`` using the visual scale ``visual_alpha``.

    With the product-space option enabled the work is delegated to
    ``_project_product_features``; otherwise the float features are scaled
    by ``exp(visual_alpha)`` and passed to ``exp_map0`` with the clamped
    curvature.
    """
    if not self.phyclip_enabled:
        scaled = feats.float() * self.visual_alpha.exp().float()
        return exp_map0(scaled, self._kappa().float())
    return self._project_product_features(feats, self.visual_alpha)
331
-
332
def project_text_features(self, feats: torch.Tensor) -> torch.Tensor:
    """Map text features through ``exp_map0`` using the textual scale ``textual_alpha``.

    With the product-space option enabled the work is delegated to
    ``_project_product_features``; otherwise the float features are scaled
    by ``exp(textual_alpha)`` and passed to ``exp_map0`` with the clamped
    curvature.
    """
    if not self.phyclip_enabled:
        scaled = feats.float() * self.textual_alpha.exp().float()
        return exp_map0(scaled, self._kappa().float())
    return self._project_product_features(feats, self.textual_alpha)
336
-
337
def similarity_scores(self, image_feats: torch.Tensor, text_feats: torch.Tensor) -> torch.Tensor:
    """Score image features against text features via ``metric_similarity`` at the current curvature."""
    curvature = self._kappa()
    return metric_similarity(
        image_feats,
        text_feats,
        curvature,
        product_metric=self.phyclip_product_metric,
    )
339
-
340
def encode_retrieval_image(self, image: torch.Tensor) -> torch.Tensor:
    """Encode images for retrieval.

    The projected features are always computed; when ``proclip_retrieval``
    is set they are further combined with the base encoder output by
    ``_project_proclip_image_base``.
    """
    base = self.encode_image_base(image)
    projected = self.project_image_features(self.image_proj(base))
    if not self.proclip_retrieval:
        return projected
    return self._project_proclip_image_base(base, projected)
346
-
347
def encode_retrieval_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Encode text for retrieval.

    The projected features are always computed; when ``proclip_retrieval``
    is set they are further combined with the base encoder output by
    ``_project_proclip_text_base``.
    """
    base = self.encode_text_base(input_ids, attention_mask)
    projected = self.project_text_features(self.text_proj(base))
    if not self.proclip_retrieval:
        return projected
    return self._project_proclip_text_base(base, projected)
353
-
354
def retrieval_similarity_scores(self, image_feats: torch.Tensor, text_feats: torch.Tensor) -> torch.Tensor:
    """Score retrieval features with the ProCLIP scorer when ``proclip_retrieval`` is set, else ``similarity_scores``."""
    scorer = self._proclip_similarity_scores if self.proclip_retrieval else self.similarity_scores
    return scorer(image_feats, text_feats)
358
-
359
@property
def retrieval_requires_chunking(self) -> bool:
    """True when either the product-space option or ProCLIP retrieval is active."""
    if self.phyclip_enabled:
        return self.phyclip_enabled
    return self.proclip_retrieval
362
-
363
- def _objective_autocast(self, device_type: str):
364
- dtype = {
365
- "float32": torch.float32,
366
- "fp32": torch.float32,
367
- "float16": torch.float16,
368
- "fp16": torch.float16,
369
- "bfloat16": torch.bfloat16,
370
- "bf16": torch.bfloat16,
371
- }[self.objective_autocast_dtype]
372
- enabled = device_type != "cpu" and dtype is not torch.float32
373
- return torch.autocast(device_type=device_type, dtype=dtype, enabled=enabled)
374
-
375
def forward(
    self,
    image: torch.Tensor,
    part_images: torch.Tensor,
    text_input_ids: torch.Tensor,
    text_attention_mask: torch.Tensor,
    part_text_input_ids: torch.Tensor,
    part_text_attention_mask: torch.Tensor,
    part_owner: torch.Tensor,
    step: int | None = None,
    beta_query_input_ids: torch.Tensor | None = None,
    beta_query_attention_mask: torch.Tensor | None = None,
    beta_query_owner: torch.Tensor | None = None,
    beta_query_type: torch.Tensor | None = None,
    beta_query_parent: torch.Tensor | None = None,
    beta_query_weight: torch.Tensor | None = None,
    beta_query_source_part: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
    """Compute the training losses for one batch.

    Steps, in order:
      1. Clamp the logit scales and the alpha parameters in place under
         ``no_grad``, then read the clamped curvature.
      2. Encode whole images/texts and part images/texts through one of
         four encoder paths selected by the configured flags.
      3. For the ``"proclip"`` objective, return early with only the
         ProCLIP contrastive loss.
      4. Otherwise gather features across ranks, optionally build HiMo
         and beta-query inputs, and evaluate ``self.objective``.
      5. Add the optional beta-CLIP, T-REN and ProCLIP auxiliary losses.

    Args:
        image, text_input_ids, text_attention_mask: Whole-sample inputs.
        part_images, part_text_input_ids, part_text_attention_mask:
            Part-level inputs; the first dimension may be zero.
        part_owner: Moved to the feature device as ``long`` before use.
            Presumably maps each part to its owning sample - confirm
            against the data pipeline.
        step: Forwarded to ``_entail_weight_scale``.
        beta_query_*: Optional query inputs forwarded to the beta-query
            helpers. ``beta_query_type`` is accepted but not read here.

    Returns:
        The loss dictionary, which always has a ``"loss"`` entry, merged
        with the detached curvature and logit-scale logs.

    Raises:
        RuntimeError: If a component that needs image patch tokens runs
            without them, or if a non-ProCLIP objective has no objective
            module.
    """
    # NOTE(review): indentation was reconstructed; the curvature read is placed
    # outside no_grad so it can carry gradient - confirm against the original file.
    with torch.no_grad():
        self._clamp_logit_scales()
        self.visual_alpha.clamp_(max=0.0)
        self.textual_alpha.clamp_(max=0.0)
    kappa = self._kappa()

    feature_dim = self.embed_dim
    beta_image_tokens = None
    beta_query_base = None
    part_image_base = part_images.new_zeros((0, self.vision_encoder.output_dim))
    part_text_base = part_images.new_zeros((0, self.text_encoder.output_dim))
    hier_beta_enabled = self.objective_name == "uncha" and self.uncha_entailment_loss in {
        "hier_beta_argent",
        "hier_beta_sourcepart_argent",
    }

    if (
        hier_beta_enabled
        and self.fuse_beta_query_encoder_forwards
        and not self.tren_enabled
        and beta_query_input_ids is not None
        and beta_query_attention_mask is not None
        and part_images.shape[0] > 0
    ):
        # Fused path: whole samples, parts and beta queries share encoder passes.
        (
            image_base,
            text_base,
            image_euc,
            text_euc,
            image_feats,
            text_feats,
            part_image_feats,
            part_text_feats,
            part_image_euc,
            part_text_euc,
            part_image_base,
            part_text_base,
            beta_image_tokens,
            beta_query_base,
        ) = self._encode_hier_beta_whole_parts_and_queries(
            image=image,
            part_images=part_images,
            text_input_ids=text_input_ids,
            text_attention_mask=text_attention_mask,
            part_text_input_ids=part_text_input_ids,
            part_text_attention_mask=part_text_attention_mask,
            beta_query_input_ids=beta_query_input_ids,
            beta_query_attention_mask=beta_query_attention_mask,
        )
    elif self.beta_query_pooling_enabled or self.tren_enabled:
        # Patch tokens are needed downstream, so use the token-returning encoder.
        image_base, beta_image_tokens = self.encode_image_base_with_tokens(image)
        text_base = self.encode_text_base(text_input_ids, text_attention_mask)
        image_euc = self.image_proj(image_base)
        text_euc = self.text_proj(text_base)
        image_feats = self.project_image_features(image_euc)
        text_feats = self.project_text_features(text_euc)
        (
            part_image_feats,
            part_text_feats,
            part_image_euc,
            part_text_euc,
            part_image_base,
            part_text_base,
        ) = self._encode_parts_with_base(
            part_images=part_images,
            part_text_input_ids=part_text_input_ids,
            part_text_attention_mask=part_text_attention_mask,
            feature_dim=feature_dim,
        )
    elif self.fuse_whole_part_encoder_forwards and self.objective_name != "proclip" and part_images.shape[0] > 0:
        # Fused path: whole samples and parts share encoder passes.
        (
            image_base,
            text_base,
            image_euc,
            text_euc,
            image_feats,
            text_feats,
            part_image_feats,
            part_text_feats,
            part_image_euc,
            part_text_euc,
            part_image_base,
            part_text_base,
        ) = self._encode_whole_and_parts(
            image=image,
            part_images=part_images,
            text_input_ids=text_input_ids,
            text_attention_mask=text_attention_mask,
            part_text_input_ids=part_text_input_ids,
            part_text_attention_mask=part_text_attention_mask,
        )
    else:
        # Default path: separate encoder passes for whole samples and parts.
        image_base = self.encode_image_base(image)
        text_base = self.encode_text_base(text_input_ids, text_attention_mask)
        image_euc = self.image_proj(image_base)
        text_euc = self.text_proj(text_base)
        image_feats = self.project_image_features(image_euc)
        text_feats = self.project_text_features(text_euc)
        (
            part_image_feats,
            part_text_feats,
            part_image_euc,
            part_text_euc,
            part_image_base,
            part_text_base,
        ) = self._encode_parts_with_base(
            part_images=part_images,
            part_text_input_ids=part_text_input_ids,
            part_text_attention_mask=part_text_attention_mask,
            feature_dim=feature_dim,
        )
    targets = local_target_indices(image_feats.size(0), image_feats.device)

    if self.objective_name == "proclip":
        # ProCLIP-only objective: a single contrastive term, no entailment.
        proclip_image_feats = self._project_proclip_image_base(image_base, image_feats)
        proclip_text_feats = self._project_proclip_text_base(text_base, text_feats)
        proclip_loss = self._proclip_contrastive_loss(
            image_feats=proclip_image_feats,
            text_feats=proclip_text_feats,
            all_image_feats=gather_with_grad(proclip_image_feats),
            all_text_feats=gather_with_grad(proclip_text_feats),
            targets=targets,
        )
        zero = proclip_loss.new_zeros(())
        return {
            "loss": proclip_loss,
            "contrastive_loss": proclip_loss,
            "entailment_loss": zero,
            "part_count": part_owner.new_tensor(0),
            "proclip_contrastive_loss": proclip_loss,
            **self._detached_kappa_logs(kappa),
            **self._detached_logit_scales(),
        }

    himo_text_feats = None
    all_himo_text_feats = None
    if self.objective_name == "uncha" and self.uncha_himo_component_weight > 0.0:
        all_text_euc = gather_with_grad(text_euc)
        all_component_euc = hide_reconstruct_embeddings(
            all_text_euc,
            variance_threshold=self.uncha_himo_variance_threshold,
            detach_pca=self.uncha_himo_detach_pca,
        )
        if get_world_size() > 1:
            # Slice this rank's rows back out of the gathered result; this
            # indexing uses the local batch size for every rank's offset.
            start = text_euc.size(0) * get_rank()
            end = start + text_euc.size(0)
            component_euc = all_component_euc[start:end]
        else:
            component_euc = all_component_euc
        himo_text_feats = self.project_text_features(component_euc)
        all_himo_text_feats = gather_with_grad(himo_text_feats)

    all_image_feats = gather_with_grad(image_feats)
    all_text_feats = gather_with_grad(text_feats)
    all_image_euc = None
    all_text_euc = None
    if self.objective_name == "uncha" and self.uncha_contrastive_loss == "siglip":
        all_image_euc = gather_with_grad(image_euc)
        all_text_euc = gather_with_grad(text_euc)
    part_owner = part_owner.to(device=image_feats.device, dtype=torch.long)

    beta_query_embeddings = {}
    if self.objective_name == "uncha" and self.uncha_entailment_loss in {
        "hier_beta_argent",
        "hier_beta_sourcepart_argent",
    }:
        if beta_image_tokens is None:
            raise RuntimeError(f"{self.uncha_entailment_loss} requires image patch tokens")
        # Autocast is switched off here and inputs are cast to float explicitly.
        with torch.autocast(device_type=image.device.type, enabled=False):
            beta_query_embeddings = self._beta_query_entailment_embeddings(
                image_tokens=beta_image_tokens.float(),
                beta_query_input_ids=beta_query_input_ids,
                beta_query_attention_mask=beta_query_attention_mask,
                beta_query_owner=beta_query_owner,
                beta_query_parent=beta_query_parent,
                beta_query_weight=beta_query_weight,
                beta_query_source_part=beta_query_source_part,
                kappa=kappa.float(),
                query_base=beta_query_base,
            )

    with self._objective_autocast(image.device.type):
        if self.objective is None:
            raise RuntimeError("Non-ProCLIP forward requires an objective module")
        # Assemble the objective inputs in the same key order as a single literal would.
        objective_inputs = {
            "image_feats": image_feats,
            "text_feats": text_feats,
            "part_image_feats": part_image_feats,
            "part_text_feats": part_text_feats,
            "part_owner": part_owner,
            "all_image_feats": all_image_feats,
            "all_text_feats": all_text_feats,
        }
        if all_image_euc is not None and all_text_euc is not None:
            objective_inputs.update(
                {
                    "image_euc_feats": image_euc,
                    "text_euc_feats": text_euc,
                    "part_image_euc_feats": part_image_euc,
                    "part_text_euc_feats": part_text_euc,
                    "all_image_euc_feats": all_image_euc,
                    "all_text_euc_feats": all_text_euc,
                }
            )
        objective_inputs["targets"] = targets
        objective_inputs["kappa"] = kappa
        objective_inputs["entail_weight_scale"] = self._entail_weight_scale(step, image_feats.device)
        objective_inputs.update(beta_query_embeddings)
        if himo_text_feats is not None:
            objective_inputs.update(
                {
                    "himo_text_feats": himo_text_feats,
                    "all_himo_text_feats": all_himo_text_feats,
                }
            )
        losses = self.objective(objective_inputs, self._objective_logit_scales())

    if self.beta_clip_global_weight > 0.0:
        with torch.autocast(device_type=image.device.type, enabled=False):
            beta_clip_global_loss = self._beta_clip_global_contrastive_loss(
                image_euc=image_euc,
                text_euc=text_euc,
                targets=targets,
            )
        losses = {
            **losses,
            "loss": losses["loss"] + self.beta_clip_global_weight * beta_clip_global_loss,
            "beta_clip_global_loss": beta_clip_global_loss,
        }

    if self.beta_clip_enabled:
        if beta_image_tokens is None:
            raise RuntimeError("beta-CLIP auxiliary requires image patch tokens")
        with torch.autocast(device_type=image.device.type, enabled=False):
            beta_clip_loss = self._beta_clip_auxiliary_loss(
                image_tokens=beta_image_tokens.float(),
                beta_query_input_ids=beta_query_input_ids,
                beta_query_attention_mask=beta_query_attention_mask,
                beta_query_owner=beta_query_owner,
                global_targets=targets,
                kappa=kappa.float(),
            )
        losses = {
            **losses,
            "loss": losses["loss"] + self.beta_clip_weight * beta_clip_loss,
            "beta_clip_loss": beta_clip_loss,
        }

    if self.tren_enabled:
        if beta_image_tokens is None:
            raise RuntimeError("T-REN auxiliary requires image patch tokens")
        with torch.autocast(device_type=image.device.type, enabled=False):
            tren_losses = self._tren_auxiliary_losses(
                image_tokens=beta_image_tokens.float(),
                part_owner=part_owner,
                part_image_base=part_image_base.float(),
                part_text_base=part_text_base.float(),
            )
        losses = {
            **losses,
            "loss": losses["loss"] + self.tren_weight * tren_losses["tren_loss"],
            **tren_losses,
        }

    if self.proclip_enabled and self.proclip_weight > 0.0:
        proclip_image_feats = self._project_proclip_image_base(image_base, image_feats)
        proclip_text_feats = self._project_proclip_text_base(text_base, text_feats)
        proclip_loss = self._proclip_contrastive_loss(
            image_feats=proclip_image_feats,
            text_feats=proclip_text_feats,
            all_image_feats=gather_with_grad(proclip_image_feats),
            all_text_feats=gather_with_grad(proclip_text_feats),
            targets=targets,
        )
        losses = {
            **losses,
            "loss": losses["loss"] + self.proclip_weight * proclip_loss,
            "proclip_contrastive_loss": proclip_loss,
        }

    return {**losses, **self._detached_kappa_logs(kappa), **self._detached_logit_scales()}
675
-
676
- def _encode_parts(
677
- self,
678
- part_images: torch.Tensor,
679
- part_text_input_ids: torch.Tensor,
680
- part_text_attention_mask: torch.Tensor,
681
- feature_dim: int,
682
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
683
- if part_images.shape[0] == 0:
684
- empty = part_images.new_zeros((0, feature_dim))
685
- return empty, empty, empty, empty
686
-
687
- part_image_euc = self.image_proj(self.encode_image_base(part_images))
688
- part_text_euc = self.text_proj(self.encode_text_base(part_text_input_ids, part_text_attention_mask))
689
- part_image_feats = self.project_image_features(part_image_euc)
690
- part_text_feats = self.project_text_features(part_text_euc)
691
- return part_image_feats, part_text_feats, part_image_euc, part_text_euc
692
-
693
- def _encode_parts_with_base(
694
- self,
695
- part_images: torch.Tensor,
696
- part_text_input_ids: torch.Tensor,
697
- part_text_attention_mask: torch.Tensor,
698
- feature_dim: int,
699
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
700
- if part_images.shape[0] == 0:
701
- empty = part_images.new_zeros((0, feature_dim))
702
- empty_image_base = part_images.new_zeros((0, self.vision_encoder.output_dim))
703
- empty_text_base = part_images.new_zeros((0, self.text_encoder.output_dim))
704
- return empty, empty, empty, empty, empty_image_base, empty_text_base
705
-
706
- part_image_base = self.encode_image_base(part_images)
707
- part_text_base = self.encode_text_base(part_text_input_ids, part_text_attention_mask)
708
- part_image_euc = self.image_proj(part_image_base)
709
- part_text_euc = self.text_proj(part_text_base)
710
- part_image_feats = self.project_image_features(part_image_euc)
711
- part_text_feats = self.project_text_features(part_text_euc)
712
- return part_image_feats, part_text_feats, part_image_euc, part_text_euc, part_image_base, part_text_base
713
-
714
- def _encode_whole_and_parts(
715
- self,
716
- image: torch.Tensor,
717
- part_images: torch.Tensor,
718
- text_input_ids: torch.Tensor,
719
- text_attention_mask: torch.Tensor,
720
- part_text_input_ids: torch.Tensor,
721
- part_text_attention_mask: torch.Tensor,
722
- ) -> tuple[
723
- torch.Tensor,
724
- torch.Tensor,
725
- torch.Tensor,
726
- torch.Tensor,
727
- torch.Tensor,
728
- torch.Tensor,
729
- torch.Tensor,
730
- torch.Tensor,
731
- torch.Tensor,
732
- torch.Tensor,
733
- torch.Tensor,
734
- torch.Tensor,
735
- ]:
736
- batch_size = image.shape[0]
737
- part_count = part_images.shape[0]
738
- image_base_all = self.encode_image_base(torch.cat([image, part_images], dim=0))
739
- image_euc_all = self.image_proj(image_base_all)
740
- image_feats_all = self.project_image_features(image_euc_all)
741
-
742
- text_ids, text_mask = self._concat_text_batches(
743
- text_input_ids,
744
- text_attention_mask,
745
- part_text_input_ids,
746
- part_text_attention_mask,
747
- )
748
- text_base_all = self.encode_text_base(text_ids, text_mask)
749
- text_euc_all = self.text_proj(text_base_all)
750
- text_feats_all = self.project_text_features(text_euc_all)
751
-
752
- image_base, part_image_base = image_base_all.split([batch_size, part_count], dim=0)
753
- text_base, part_text_base = text_base_all.split([batch_size, part_count], dim=0)
754
- image_euc, part_image_euc = image_euc_all.split([batch_size, part_count], dim=0)
755
- text_euc, part_text_euc = text_euc_all.split([batch_size, part_count], dim=0)
756
- image_feats, part_image_feats = image_feats_all.split([batch_size, part_count], dim=0)
757
- text_feats, part_text_feats = text_feats_all.split([batch_size, part_count], dim=0)
758
- return (
759
- image_base,
760
- text_base,
761
- image_euc,
762
- text_euc,
763
- image_feats,
764
- text_feats,
765
- part_image_feats,
766
- part_text_feats,
767
- part_image_euc,
768
- part_text_euc,
769
- part_image_base,
770
- part_text_base,
771
- )
772
-
773
- def _encode_hier_beta_whole_parts_and_queries(
774
- self,
775
- image: torch.Tensor,
776
- part_images: torch.Tensor,
777
- text_input_ids: torch.Tensor,
778
- text_attention_mask: torch.Tensor,
779
- part_text_input_ids: torch.Tensor,
780
- part_text_attention_mask: torch.Tensor,
781
- beta_query_input_ids: torch.Tensor,
782
- beta_query_attention_mask: torch.Tensor,
783
- ) -> tuple[
784
- torch.Tensor,
785
- torch.Tensor,
786
- torch.Tensor,
787
- torch.Tensor,
788
- torch.Tensor,
789
- torch.Tensor,
790
- torch.Tensor,
791
- torch.Tensor,
792
- torch.Tensor,
793
- torch.Tensor,
794
- torch.Tensor,
795
- torch.Tensor,
796
- torch.Tensor,
797
- torch.Tensor,
798
- ]:
799
- batch_size = image.shape[0]
800
- part_count = part_images.shape[0]
801
- query_count = beta_query_input_ids.shape[0]
802
-
803
- image_base_all, image_tokens_all = self.encode_image_base_with_tokens(torch.cat([image, part_images], dim=0))
804
- image_euc_all = self.image_proj(image_base_all)
805
- image_feats_all = self.project_image_features(image_euc_all)
806
- image_base, part_image_base = image_base_all.split([batch_size, part_count], dim=0)
807
- image_euc, part_image_euc = image_euc_all.split([batch_size, part_count], dim=0)
808
- image_feats, part_image_feats = image_feats_all.split([batch_size, part_count], dim=0)
809
- beta_image_tokens = image_tokens_all[:batch_size]
810
-
811
- text_ids, text_mask = self._concat_text_batch_list(
812
- (text_input_ids, text_attention_mask),
813
- (part_text_input_ids, part_text_attention_mask),
814
- (beta_query_input_ids, beta_query_attention_mask),
815
- )
816
- text_base_all = self.encode_text_base(text_ids, text_mask)
817
- text_euc_all = self.text_proj(text_base_all)
818
- text_feats_all = self.project_text_features(text_euc_all)
819
- text_base, part_text_base, beta_query_base = text_base_all.split([batch_size, part_count, query_count], dim=0)
820
- text_euc, part_text_euc, _ = text_euc_all.split([batch_size, part_count, query_count], dim=0)
821
- text_feats, part_text_feats, _ = text_feats_all.split([batch_size, part_count, query_count], dim=0)
822
-
823
- return (
824
- image_base,
825
- text_base,
826
- image_euc,
827
- text_euc,
828
- image_feats,
829
- text_feats,
830
- part_image_feats,
831
- part_text_feats,
832
- part_image_euc,
833
- part_text_euc,
834
- part_image_base,
835
- part_text_base,
836
- beta_image_tokens,
837
- beta_query_base,
838
- )
839
-
840
- def _concat_text_batches(
841
- self,
842
- text_input_ids: torch.Tensor,
843
- text_attention_mask: torch.Tensor,
844
- part_text_input_ids: torch.Tensor,
845
- part_text_attention_mask: torch.Tensor,
846
- ) -> tuple[torch.Tensor, torch.Tensor]:
847
- return self._concat_text_batch_list(
848
- (text_input_ids, text_attention_mask),
849
- (part_text_input_ids, part_text_attention_mask),
850
- )
851
-
852
def _concat_text_batch_list(
    self,
    *batches: tuple[torch.Tensor, torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pad several ``(input_ids, attention_mask)`` batches to a common length and stack them.

    Every batch is padded along dim 1 up to the longest sequence length
    among the batches: ids with the tokenizer's pad token id (``0`` when
    the tokenizer has none) and masks with ``0``. The padded batches are
    concatenated along dim 0 in the order given.

    Returns:
        ``(input_ids, attention_mask)`` for the combined batch.
    """
    longest = max(ids.shape[1] for ids, _ in batches)
    pad_id = self.text_encoder.tokenizer.pad_token_id
    pad_id = 0 if pad_id is None else pad_id

    padded_ids = []
    padded_masks = []
    for ids, mask in batches:
        padded_ids.append(_pad_sequence_dim(ids, longest, pad_id))
        padded_masks.append(_pad_sequence_dim(mask, longest, 0))
    return torch.cat(padded_ids, dim=0), torch.cat(padded_masks, dim=0)
864
-
865
- def _clamp_logit_scales(self) -> None:
866
- if self.objective_name == "proclip":
867
- self.proclip_logit_scale.clamp_(max=4.6052)
868
- self._clamp_experimental_logit_scales()
869
- return
870
- if self.objective_name == "hycoclip":
871
- self.logit_scale.clamp_(max=4.6052)
872
- self._clamp_experimental_logit_scales()
873
- return
874
- self.global_logit_scale.clamp_(max=4.6052)
875
- self.local_logit_scale.clamp_(max=4.6052)
876
- self.global_local_logit_scale.clamp_(max=4.6052)
877
- self._clamp_experimental_logit_scales()
878
-
879
- def _objective_logit_scales(self) -> torch.Tensor | dict[str, torch.Tensor]:
880
- if self.objective_name == "hycoclip":
881
- return self.logit_scale
882
- if self.objective_name == "proclip":
883
- return self.proclip_logit_scale
884
- return {
885
- "global": self.global_logit_scale,
886
- "local": self.local_logit_scale,
887
- "global_local": self.global_local_logit_scale,
888
- **(
889
- {
890
- "global_bias": self.global_logit_bias,
891
- "local_bias": self.local_logit_bias,
892
- "global_local_bias": self.global_local_logit_bias,
893
- }
894
- if self.uncha_contrastive_loss in {"sigmoid", "siglip", "siglip_metric"}
895
- else {}
896
- ),
897
- }
898
-
899
- def _detached_logit_scales(self) -> dict[str, torch.Tensor]:
900
- if self.objective_name == "proclip":
901
- return self._detached_experimental_logit_scales()
902
- if self.objective_name == "hycoclip":
903
- logs = {"logit_scale": self.logit_scale.exp().detach()}
904
- logs.update(self._detached_experimental_logit_scales())
905
- return logs
906
- logs = {
907
- "global_logit_scale": self.global_logit_scale.exp().detach(),
908
- "local_logit_scale": self.local_logit_scale.exp().detach(),
909
- "global_local_logit_scale": self.global_local_logit_scale.exp().detach(),
910
- }
911
- if self.uncha_contrastive_loss in {"sigmoid", "siglip", "siglip_metric"}:
912
- logs.update(
913
- {
914
- "global_logit_bias": self.global_logit_bias.detach(),
915
- "local_logit_bias": self.local_logit_bias.detach(),
916
- "global_local_logit_bias": self.global_local_logit_bias.detach(),
917
- }
918
- )
919
- logs.update(self._detached_experimental_logit_scales())
920
- return logs
921
-
922
def _project_product_features(self, feats: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    """Apply ``exp_map0`` independently per factor of the product space.

    ``feats`` is reshaped to ``(rows, phyclip_num_factors,
    phyclip_subspace_dim)``. Each factor is scaled by its own
    ``exp(alpha)`` entry and mapped with its own curvature entry, both
    broadcast over rows and the subspace dimension.
    """
    rows = feats.size(0)
    per_factor = feats.float().reshape(rows, self.phyclip_num_factors, self.phyclip_subspace_dim)
    factor_scale = alpha.exp().float().view(1, -1, 1)
    factor_curvature = self._kappa().float().view(1, -1, 1)
    return exp_map0(per_factor * factor_scale, factor_curvature)
926
-
927
- def _detached_kappa_logs(self, kappa: torch.Tensor) -> dict[str, torch.Tensor]:
928
- detached = kappa.detach()
929
- if detached.numel() == 1:
930
- return {"kappa": detached.reshape(())}
931
- return {
932
- "kappa": detached.mean(),
933
- "kappa_min": detached.min(),
934
- "kappa_max": detached.max(),
935
- }
936
-
937
- def _entail_weight_scale(self, step: int | None, device: torch.device) -> torch.Tensor:
938
- if self.uncha_entailment_warmup_steps <= 0 or step is None:
939
- return torch.ones((), device=device)
940
- scale = min(1.0, float(step + 1) / float(self.uncha_entailment_warmup_steps))
941
- return torch.tensor(scale, device=device)
942
-
943
-
944
- def _projection_head(input_dim: int, output_dim: int, hidden_dim: int | None) -> nn.Module:
945
- if hidden_dim is None:
946
- return nn.Linear(input_dim, output_dim)
947
- return nn.Sequential(
948
- nn.Linear(input_dim, hidden_dim),
949
- nn.ReLU(),
950
- nn.Linear(hidden_dim, output_dim),
951
- )
952
-
953
-
954
- def _pad_sequence_dim(tensor: torch.Tensor, target_length: int, value: int) -> torch.Tensor:
955
- pad = target_length - tensor.shape[1]
956
- if pad <= 0:
957
- return tensor
958
- return F.pad(tensor, (0, pad), value=value)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/models/lorentz.py DELETED
@@ -1,265 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import math
4
-
5
- import torch
6
- from torch import Tensor
7
-
8
-
9
- def lorentz_inner(x: Tensor, y: Tensor) -> Tensor:
10
- """Compute batched Lorentzian inner product for matching rows."""
11
- x = x.float()
12
- y = y.float()
13
- return -x[..., 0] * y[..., 0] + (x[..., 1:] * y[..., 1:]).sum(dim=-1)
14
-
15
-
16
- def pairwise_lorentz_inner(x: Tensor, y: Tensor) -> Tensor:
17
- """Compute all-pairs Lorentzian inner products."""
18
- x = x.float()
19
- y = y.float()
20
- time = -x[:, :1] @ y[:, :1].T
21
- space = x[:, 1:] @ y[:, 1:].T
22
- return time + space
23
-
24
-
25
- def exp_map0(u: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
26
- """Exponential map at the origin from tangent space to hyperboloid."""
27
- u = u.float()
28
- kappa = kappa.float()
29
- sqrt_k = torch.sqrt(kappa)
30
- norm_u = torch.linalg.norm(u, dim=-1, keepdim=True).clamp_min(eps)
31
- scaled = sqrt_k * norm_u
32
- clipped_scaled = scaled.clamp_max(math.asinh(2**15))
33
- time = torch.cosh(clipped_scaled) / sqrt_k
34
- space = torch.sinh(clipped_scaled) * u / scaled.clamp_min(eps)
35
- return torch.cat([time, space], dim=-1)
36
-
37
-
38
- def log_map0(x: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
39
- """Logarithmic map at the origin from hyperboloid to tangent space.
40
-
41
- Inverts ``exp_map0`` for points on the Lorentz model hyperboloid. Returns
42
- vectors in the Euclidean tangent space at the origin (no time coordinate).
43
- """
44
- x = x.float()
45
- dist_eps = max(eps, 16.0 * torch.finfo(x.dtype).eps)
46
- kappa = kappa.to(dtype=torch.float32).flatten()
47
-
48
- if x.dim() == 2:
49
- if kappa.numel() != 1:
50
- raise ValueError("log_map0 expects scalar kappa for non-product embeddings")
51
- sqrt_k = torch.sqrt(kappa.reshape(()))
52
- alpha = torch.acosh((sqrt_k * x[:, 0]).clamp_min(1.0 + dist_eps))
53
- coef = alpha / torch.sinh(alpha).clamp_min(dist_eps)
54
- return x[:, 1:] * coef.unsqueeze(-1)
55
-
56
- if x.dim() == 3:
57
- if kappa.numel() == 1:
58
- kappa = kappa.expand(x.shape[1])
59
- if kappa.numel() != x.shape[1]:
60
- raise ValueError(f"Expected {x.shape[1]} curvatures for product space, got {kappa.numel()}")
61
- sqrt_k = torch.sqrt(kappa).view(1, -1)
62
- alpha = torch.acosh((sqrt_k * x[..., 0]).clamp_min(1.0 + dist_eps))
63
- coef = alpha / torch.sinh(alpha).clamp_min(dist_eps)
64
- return x[..., 1:] * coef.unsqueeze(-1)
65
-
66
- raise ValueError("log_map0 expects [batch, dim + 1] or [batch, factors, dim + 1] tensors")
67
-
68
-
69
- def pairwise_dist(x: Tensor, y: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
70
- """Pairwise geodesic distance on the Lorentz model."""
71
- kappa = kappa.float()
72
- dist_eps = max(eps, 16.0 * torch.finfo(x.dtype).eps)
73
- prod = (-kappa) * pairwise_lorentz_inner(x, y)
74
- prod = prod.clamp_min(1.0 + dist_eps)
75
- return torch.acosh(prod) / torch.sqrt(kappa)
76
-
77
-
78
- def product_pairwise_dist(
79
- x: Tensor,
80
- y: Tensor,
81
- kappa: Tensor,
82
- metric: str = "l1",
83
- eps: float = 1e-8,
84
- ) -> Tensor:
85
- """Pairwise distance in an l1/l2 product of Lorentz factors.
86
-
87
- Inputs have shape ``[batch, factors, dim + 1]``. For ``metric="l1"``, this
88
- matches the official PHyCLIP implementation's mean distance over factors.
89
- """
90
- if x.dim() != 3 or y.dim() != 3:
91
- raise ValueError("product_pairwise_dist expects [batch, factors, dim + 1] tensors")
92
- if x.shape[1] != y.shape[1] or x.shape[2] != y.shape[2]:
93
- raise ValueError("Product Lorentz tensors must have matching factor and feature dimensions")
94
- kappa = _product_kappa(kappa, x.shape[1], x.device).to(dtype=torch.float32)
95
- dist_eps = max(eps, 16.0 * torch.finfo(x.dtype).eps)
96
- x = x.float()
97
- y = y.float()
98
- inner = -x[:, None, :, 0] * y[None, :, :, 0] + torch.einsum("bkd,nkd->bnk", x[..., 1:], y[..., 1:])
99
- prod = (-kappa.view(1, 1, -1)) * inner
100
- dist = torch.acosh(prod.clamp_min(1.0 + dist_eps)) / torch.sqrt(kappa).view(1, 1, -1)
101
- if metric == "l1":
102
- return dist.mean(dim=-1)
103
- if metric == "l2":
104
- return dist.square().mean(dim=-1).sqrt()
105
- raise ValueError(f"Unsupported product metric {metric!r}; expected 'l1' or 'l2'")
106
-
107
-
108
- def metric_pairwise_dist(x: Tensor, y: Tensor, kappa: Tensor, product_metric: str = "l1") -> Tensor:
109
- """Pairwise distance for either a single Lorentz space or a product space."""
110
- if x.dim() == 3 or y.dim() == 3:
111
- return product_pairwise_dist(x, y, kappa, metric=product_metric)
112
- return pairwise_dist(x, y, kappa)
113
-
114
-
115
- def paired_dist(x: Tensor, y: Tensor, kappa: Tensor, product_metric: str = "l1", eps: float = 1e-8) -> Tensor:
116
- """Row-wise distance for either a single Lorentz space or a product space."""
117
- if x.dim() == 3 or y.dim() == 3:
118
- if x.shape != y.shape:
119
- raise ValueError("Product paired_dist expects matching tensor shapes")
120
- kappa = _product_kappa(kappa, x.shape[1], x.device).to(dtype=torch.float32)
121
- dist_eps = max(eps, 16.0 * torch.finfo(x.dtype).eps)
122
- x = x.float()
123
- y = y.float()
124
- inner = -x[..., 0] * y[..., 0] + (x[..., 1:] * y[..., 1:]).sum(dim=-1)
125
- prod = (-kappa.view(1, -1)) * inner
126
- dist = torch.acosh(prod.clamp_min(1.0 + dist_eps)) / torch.sqrt(kappa).view(1, -1)
127
- if product_metric == "l1":
128
- return dist.mean(dim=-1)
129
- if product_metric == "l2":
130
- return dist.square().mean(dim=-1).sqrt()
131
- raise ValueError(f"Unsupported product metric {product_metric!r}; expected 'l1' or 'l2'")
132
- kappa = kappa.float()
133
- dist_eps = max(eps, 16.0 * torch.finfo(x.dtype).eps)
134
- prod = (-kappa) * lorentz_inner(x, y)
135
- prod = prod.clamp_min(1.0 + dist_eps)
136
- return torch.acosh(prod) / torch.sqrt(kappa)
137
-
138
-
139
- def radial_distance(x: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
140
- """Geodesic distance from the origin.
141
-
142
- For points on the hyperboloid, the time coordinate satisfies
143
- ``x0 = cosh(sqrt(kappa) * r) / sqrt(kappa)``, so we can recover the radial
144
- distance via ``r = arcosh(sqrt(kappa) * x0) / sqrt(kappa)``.
145
- """
146
- dist_eps = max(eps, 16.0 * torch.finfo(x.dtype).eps)
147
- x = x.float()
148
- kappa = kappa.to(dtype=torch.float32).flatten()
149
- if x.dim() == 2:
150
- if kappa.numel() != 1:
151
- raise ValueError("radial_distance expects scalar kappa for non-product embeddings")
152
- sqrt_k = torch.sqrt(kappa.reshape(()))
153
- arg = (sqrt_k * x[:, 0]).clamp_min(1.0 + dist_eps)
154
- return torch.acosh(arg) / sqrt_k
155
- if x.dim() == 3:
156
- if kappa.numel() == 1:
157
- kappa = kappa.expand(x.shape[1])
158
- if kappa.numel() != x.shape[1]:
159
- raise ValueError(f"Expected {x.shape[1]} curvatures for product space, got {kappa.numel()}")
160
- sqrt_k = torch.sqrt(kappa).view(1, -1)
161
- arg = (sqrt_k * x[..., 0]).clamp_min(1.0 + dist_eps)
162
- dist = torch.acosh(arg) / sqrt_k
163
- return dist.mean(dim=-1)
164
- raise ValueError("radial_distance expects [batch, dim + 1] or [batch, factors, dim + 1] tensors")
165
-
166
-
167
- def metric_similarity(x: Tensor, y: Tensor, kappa: Tensor, product_metric: str = "l1") -> Tensor:
168
- """Retrieval/classification similarity for single-space and PHyCLIP-style models."""
169
- if x.dim() == 3 or y.dim() == 3:
170
- return -product_pairwise_dist(x, y, kappa, metric=product_metric)
171
- return pairwise_lorentz_inner(x, y)
172
-
173
-
174
- def half_aperture(general: Tensor, kappa: Tensor, min_radius: float = 0.1, eps: float = 1e-8) -> Tensor:
175
- """Cone half-aperture for entailment cone centered at general concept."""
176
- general = general.float()
177
- kappa = kappa.float()
178
- aperture_eps = max(eps, 16.0 * torch.finfo(general.dtype).eps)
179
- general_norm = torch.linalg.norm(general[:, 1:], dim=-1)
180
- ratio = (2.0 * min_radius) / (general_norm * torch.sqrt(kappa) + aperture_eps)
181
- ratio = ratio.clamp(max=1.0 - aperture_eps)
182
- return torch.asin(ratio)
183
-
184
-
185
- def oxy_angle(specific: Tensor, general: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
186
- """Exterior angle between specific point and entailment cone at general point."""
187
- specific = specific.float()
188
- general = general.float()
189
- kappa = kappa.float()
190
- angle_eps = max(eps, 16.0 * torch.finfo(specific.dtype).eps)
191
- inner = lorentz_inner(specific, general)
192
- numerator = specific[:, 0] + kappa * inner * general[:, 0]
193
- general_norm = torch.linalg.norm(general[:, 1:], dim=-1).clamp_min(angle_eps)
194
- denom_term = (kappa * inner).pow(2) - 1.0
195
- denom = general_norm * torch.sqrt(denom_term.clamp_min(angle_eps))
196
- cosine = (numerator / denom).clamp(min=-1.0 + angle_eps, max=1.0 - angle_eps)
197
- return torch.acos(cosine)
198
-
199
-
200
- def pairwise_oxy_angle(specific: Tensor, general: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
201
- """All-pairs exterior angle between specific points and entailment cones at general points."""
202
- specific = specific.float()
203
- general = general.float()
204
- kappa = kappa.to(dtype=torch.float32).flatten()
205
- if kappa.numel() != 1:
206
- raise ValueError("pairwise_oxy_angle expects scalar kappa for non-product embeddings")
207
- kappa_scalar = kappa.reshape(())
208
- angle_eps = max(eps, 16.0 * torch.finfo(specific.dtype).eps)
209
- inner = -specific[:, None, 0] * general[None, :, 0] + torch.einsum("nd,md->nm", specific[:, 1:], general[:, 1:])
210
- numerator = specific[:, None, 0] + kappa_scalar * inner * general[None, :, 0]
211
- general_norm = torch.linalg.norm(general[:, 1:], dim=-1).clamp_min(angle_eps)
212
- denom_term = (kappa_scalar * inner).pow(2) - 1.0
213
- denom = general_norm[None, :] * torch.sqrt(denom_term.clamp_min(angle_eps))
214
- cosine = (numerator / denom).clamp(min=-1.0 + angle_eps, max=1.0 - angle_eps)
215
- return torch.acos(cosine)
216
-
217
-
218
- def product_pairwise_oxy_angle(
219
- specific: Tensor,
220
- general: Tensor,
221
- kappa: Tensor,
222
- metric: str = "l1",
223
- eps: float = 1e-8,
224
- ) -> Tensor:
225
- """All-pairs exterior angle in an l1/l2 product of Lorentz factors."""
226
- if specific.dim() != 3 or general.dim() != 3:
227
- raise ValueError("product_pairwise_oxy_angle expects [batch, factors, dim + 1] tensors")
228
- if specific.shape[1] != general.shape[1] or specific.shape[2] != general.shape[2]:
229
- raise ValueError("Product Lorentz tensors must have matching factor and feature dimensions")
230
- kappa = _product_kappa(kappa, specific.shape[1], specific.device).to(dtype=torch.float32)
231
- angle_eps = max(eps, 16.0 * torch.finfo(specific.dtype).eps)
232
- specific = specific.float()
233
- general = general.float()
234
- inner = -specific[:, None, :, 0] * general[None, :, :, 0] + torch.einsum(
235
- "nkd,mkd->nmk",
236
- specific[..., 1:],
237
- general[..., 1:],
238
- )
239
- numerator = specific[:, None, :, 0] + (kappa.view(1, 1, -1) * inner) * general[None, :, :, 0]
240
- general_norm = torch.linalg.norm(general[..., 1:], dim=-1).clamp_min(angle_eps)
241
- denom_term = (kappa.view(1, 1, -1) * inner).pow(2) - 1.0
242
- denom = general_norm[None, :, :] * torch.sqrt(denom_term.clamp_min(angle_eps))
243
- cosine = (numerator / denom).clamp(min=-1.0 + angle_eps, max=1.0 - angle_eps)
244
- angles = torch.acos(cosine)
245
- if metric == "l1":
246
- return angles.mean(dim=-1)
247
- if metric == "l2":
248
- return angles.square().mean(dim=-1).sqrt()
249
- raise ValueError(f"Unsupported product metric {metric!r}; expected 'l1' or 'l2'")
250
-
251
-
252
- def metric_pairwise_oxy_angle(specific: Tensor, general: Tensor, kappa: Tensor, product_metric: str = "l1") -> Tensor:
253
- """All-pairs oxy-angle for either a single Lorentz space or a product space."""
254
- if specific.dim() == 3 or general.dim() == 3:
255
- return product_pairwise_oxy_angle(specific, general, kappa, metric=product_metric)
256
- return pairwise_oxy_angle(specific, general, kappa)
257
-
258
-
259
- def _product_kappa(kappa: Tensor, num_factors: int, device: torch.device) -> Tensor:
260
- kappa = kappa.to(device=device, dtype=torch.float32).flatten()
261
- if kappa.numel() == 1:
262
- return kappa.expand(num_factors)
263
- if kappa.numel() != num_factors:
264
- raise ValueError(f"Expected {num_factors} curvatures for product space, got {kappa.numel()}")
265
- return kappa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/models/losses.py DELETED
@@ -1,1400 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import math
4
-
5
- import torch
6
- from torch import Tensor
7
- import torch.nn.functional as F
8
-
9
- from hyper3_clip.models.lorentz import (
10
- half_aperture,
11
- metric_pairwise_dist,
12
- metric_pairwise_oxy_angle,
13
- oxy_angle,
14
- paired_dist,
15
- radial_distance,
16
- )
17
-
18
-
19
- def contrastive_ce(logits: Tensor, targets: Tensor | None = None, weights: Tensor | None = None) -> Tensor:
20
- if targets is None:
21
- targets = torch.arange(logits.size(0), device=logits.device)
22
- losses = F.cross_entropy(logits, targets, reduction="none")
23
- return weighted_mean(losses, weights)
24
-
25
-
26
- def contrastive_sigmoid(
27
- logits: Tensor,
28
- targets: Tensor | None = None,
29
- weights: Tensor | None = None,
30
- negative_weight: float = 1.0,
31
- ) -> Tensor:
32
- if targets is None:
33
- targets = torch.arange(logits.size(0), device=logits.device)
34
- labels = torch.zeros_like(logits)
35
- labels[torch.arange(logits.size(0), device=logits.device), targets] = 1.0
36
- losses = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
37
- if negative_weight != 1.0:
38
- element_weights = torch.where(labels > 0.0, torch.ones_like(labels), logits.new_full((), negative_weight))
39
- losses = losses * element_weights
40
- losses = losses.mean(dim=1)
41
- return weighted_mean(losses, weights)
42
-
43
-
44
- def contrastive_siglip(
45
- logits: Tensor,
46
- targets: Tensor | None = None,
47
- weights: Tensor | None = None,
48
- negative_weight: float = 1.0,
49
- ) -> Tensor:
50
- """SigLIP pairwise sigmoid loss (Zhai et al., ICCV 2023).
51
-
52
- Uses labels in {+1, -1} with a per-row sum (not mean) over pairs:
53
- L_i = sum_j softplus(- y_ij * logit_ij)
54
- """
55
- if logits.ndim != 2:
56
- raise ValueError("contrastive_siglip expects a [batch, classes] logit matrix")
57
- if targets is None:
58
- targets = torch.arange(logits.size(0), device=logits.device)
59
- labels = logits.new_full(logits.shape, -1.0)
60
- labels[torch.arange(logits.size(0), device=logits.device), targets] = 1.0
61
- losses = F.softplus(-(labels * logits))
62
- if negative_weight != 1.0:
63
- element_weights = torch.where(labels > 0.0, torch.ones_like(labels), logits.new_full((), negative_weight))
64
- losses = losses * element_weights
65
- row_losses = losses.sum(dim=1)
66
- return weighted_mean(row_losses, weights)
67
-
68
-
69
- def weighted_mean(values: Tensor, weights: Tensor | None = None) -> Tensor:
70
- if weights is None:
71
- return values.mean()
72
- weights = weights.to(device=values.device, dtype=values.dtype)
73
- while weights.dim() < values.dim():
74
- weights = weights.unsqueeze(-1)
75
- return (values * weights).sum() / weights.sum().clamp_min(torch.finfo(values.dtype).eps)
76
-
77
-
78
- def gramian_volume_loss(vectors: Tensor, weights: Tensor | None = None, eps: float = 1e-4) -> Tensor:
79
- """GRAM-style volume loss for sets of vectors.
80
-
81
- ``vectors`` is expected to have shape ``[batch, k, dim]``. Each set of k
82
- vectors is L2-normalized along ``dim``, then we compute the Gramian
83
- ``G = V V^T`` and return ``sqrt(det(G + eps I))`` averaged over the batch.
84
- """
85
- if vectors.ndim != 3:
86
- raise ValueError("gramian_volume_loss expects a [batch, k, dim] tensor")
87
- if eps <= 0.0:
88
- raise ValueError("gramian_volume_loss eps must be positive")
89
-
90
- vectors = F.normalize(vectors.float(), dim=-1, eps=1e-8)
91
- gram = vectors @ vectors.transpose(-1, -2)
92
- k = gram.size(-1)
93
- gram = gram + eps * torch.eye(k, device=gram.device, dtype=gram.dtype)
94
- sign, logabsdet = torch.linalg.slogdet(gram)
95
- volume = torch.exp(0.5 * logabsdet)
96
- volume = torch.where(sign > 0, volume, volume.new_ones(volume.shape))
97
- return weighted_mean(volume, weights)
98
-
99
-
100
- def radius_order_hinge(
101
- specific: Tensor,
102
- general: Tensor,
103
- kappa: Tensor,
104
- margin: float,
105
- weights: Tensor | None = None,
106
- ) -> Tensor:
107
- if specific.shape[0] != general.shape[0]:
108
- raise ValueError("radius_order_hinge expects matching batch dimensions")
109
- if margin < 0.0:
110
- raise ValueError("radius_order_hinge margin must be non-negative")
111
- specific_radius = radial_distance(specific, kappa)
112
- general_radius = radial_distance(general, kappa)
113
- losses = F.relu(float(margin) + general_radius - specific_radius)
114
- return weighted_mean(losses, weights)
115
-
116
-
117
- def soft_contrastive_ce(logits: Tensor, target_weights: Tensor, weights: Tensor | None = None) -> Tensor:
118
- if logits.ndim != 2 or target_weights.ndim != 2:
119
- raise ValueError("soft_contrastive_ce expects [batch, classes] tensors")
120
- if logits.shape != target_weights.shape:
121
- raise ValueError("soft_contrastive_ce requires logits and target_weights to have matching shapes")
122
- log_probs = F.log_softmax(logits, dim=1)
123
- losses = -(target_weights.to(dtype=log_probs.dtype) * log_probs).sum(dim=1)
124
- return weighted_mean(losses, weights)
125
-
126
-
127
- def beta_cal_loss(
128
- logits: Tensor,
129
- *,
130
- targets: Tensor,
131
- group_ids: Tensor,
132
- all_group_ids: Tensor,
133
- beta: float,
134
- variant: str,
135
- weights: Tensor | None = None,
136
- ) -> Tensor:
137
- if beta < 0.0:
138
- raise ValueError("beta_cal_loss beta must be non-negative")
139
- if variant not in {"ce", "bce"}:
140
- raise ValueError("beta_cal_loss variant must be 'ce' or 'bce'")
141
- if logits.ndim != 2:
142
- raise ValueError("beta_cal_loss expects a [batch, classes] logit matrix")
143
- if targets.shape != (logits.size(0),):
144
- raise ValueError("beta_cal_loss targets must have shape [batch]")
145
- if group_ids.shape != (logits.size(0),):
146
- raise ValueError("beta_cal_loss group_ids must have shape [batch]")
147
- if all_group_ids.shape != (logits.size(1),):
148
- raise ValueError("beta_cal_loss all_group_ids must have shape [classes]")
149
-
150
- same_group = group_ids[:, None] == all_group_ids[None, :]
151
- same_pair = targets[:, None] == torch.arange(logits.size(1), device=logits.device)[None, :]
152
-
153
- if variant == "ce":
154
- target_weights = logits.new_zeros(logits.shape)
155
- target_weights = torch.where(same_pair, logits.new_ones(()), target_weights)
156
- target_weights = torch.where(same_group & ~same_pair, logits.new_full((), float(beta)), target_weights)
157
- target_weights = target_weights / target_weights.sum(dim=1, keepdim=True).clamp_min(
158
- torch.finfo(target_weights.dtype).eps
159
- )
160
- return soft_contrastive_ce(logits, target_weights, weights)
161
-
162
- labels = same_group.to(dtype=logits.dtype)
163
- element_weights = logits.new_ones(logits.shape)
164
- element_weights = torch.where(same_group & ~same_pair, logits.new_full((), float(beta)), element_weights)
165
- element_losses = F.binary_cross_entropy_with_logits(logits, labels, reduction="none") * element_weights
166
- row_losses = element_losses.mean(dim=1)
167
- return weighted_mean(row_losses, weights)
168
-
169
- def compositional_contrastive_loss(
170
- image_feats: Tensor,
171
- text_feats: Tensor,
172
- box_image_feats: Tensor,
173
- box_text_feats: Tensor,
174
- kappa: Tensor,
175
- logit_scale: Tensor,
176
- all_image_feats: Tensor | None = None,
177
- all_text_feats: Tensor | None = None,
178
- targets: Tensor | None = None,
179
- ) -> Tensor:
180
- scale = logit_scale.exp().clamp(max=100.0)
181
- all_image_feats = image_feats if all_image_feats is None else all_image_feats
182
- all_text_feats = text_feats if all_text_feats is None else all_text_feats
183
-
184
- logits_i_t = -metric_pairwise_dist(image_feats, all_text_feats, kappa) * scale
185
- logits_t_i = -metric_pairwise_dist(text_feats, all_image_feats, kappa) * scale
186
- logits_bi_t = -metric_pairwise_dist(box_image_feats, all_text_feats, kappa) * scale
187
- logits_bt_i = -metric_pairwise_dist(box_text_feats, all_image_feats, kappa) * scale
188
-
189
- return 0.25 * (
190
- contrastive_ce(logits_i_t, targets)
191
- + contrastive_ce(logits_t_i, targets)
192
- + contrastive_ce(logits_bi_t, targets)
193
- + contrastive_ce(logits_bt_i, targets)
194
- )
195
-
196
-
197
- def multi_part_contrastive_loss(
198
- image_feats: Tensor,
199
- text_feats: Tensor,
200
- part_image_feats: Tensor,
201
- part_text_feats: Tensor,
202
- part_mask: Tensor,
203
- kappa: Tensor,
204
- logit_scale: Tensor,
205
- all_image_feats: Tensor | None = None,
206
- all_text_feats: Tensor | None = None,
207
- targets: Tensor | None = None,
208
- ) -> Tensor:
209
- scale = logit_scale.exp().clamp(max=100.0)
210
- all_image_feats = image_feats if all_image_feats is None else all_image_feats
211
- all_text_feats = text_feats if all_text_feats is None else all_text_feats
212
- if targets is None:
213
- targets = torch.arange(image_feats.size(0), device=image_feats.device)
214
-
215
- part_image_flat, part_text_flat, part_targets = _flatten_valid_parts(part_image_feats, part_text_feats, part_mask, targets)
216
-
217
- logits_i_t = -metric_pairwise_dist(image_feats, all_text_feats, kappa) * scale
218
- logits_t_i = -metric_pairwise_dist(text_feats, all_image_feats, kappa) * scale
219
- logits_pi_t = -metric_pairwise_dist(part_image_flat, all_text_feats, kappa) * scale
220
- logits_pt_i = -metric_pairwise_dist(part_text_flat, all_image_feats, kappa) * scale
221
-
222
- return 0.25 * (
223
- contrastive_ce(logits_i_t, targets)
224
- + contrastive_ce(logits_t_i, targets)
225
- + contrastive_ce(logits_pi_t, part_targets)
226
- + contrastive_ce(logits_pt_i, part_targets)
227
- )
228
-
229
-
230
- def packed_part_contrastive_loss(
231
- image_feats: Tensor,
232
- text_feats: Tensor,
233
- part_image_feats: Tensor,
234
- part_text_feats: Tensor,
235
- part_owner: Tensor,
236
- kappa: Tensor,
237
- logit_scale: Tensor,
238
- all_image_feats: Tensor | None = None,
239
- all_text_feats: Tensor | None = None,
240
- targets: Tensor | None = None,
241
- ) -> Tensor:
242
- scale = logit_scale.exp().clamp(max=100.0)
243
- all_image_feats = image_feats if all_image_feats is None else all_image_feats
244
- all_text_feats = text_feats if all_text_feats is None else all_text_feats
245
- if targets is None:
246
- targets = torch.arange(image_feats.size(0), device=image_feats.device)
247
-
248
- logits_i_t = -metric_pairwise_dist(image_feats, all_text_feats, kappa) * scale
249
- logits_t_i = -metric_pairwise_dist(text_feats, all_image_feats, kappa) * scale
250
- global_loss = 0.5 * (contrastive_ce(logits_i_t, targets) + contrastive_ce(logits_t_i, targets))
251
-
252
- if part_image_feats.numel() == 0:
253
- return global_loss
254
-
255
- part_targets = targets[part_owner]
256
- logits_pi_t = -metric_pairwise_dist(part_image_feats, all_text_feats, kappa) * scale
257
- logits_pt_i = -metric_pairwise_dist(part_text_feats, all_image_feats, kappa) * scale
258
- part_loss = 0.5 * (contrastive_ce(logits_pi_t, part_targets) + contrastive_ce(logits_pt_i, part_targets))
259
- return 0.5 * (global_loss + part_loss)
260
-
261
-
262
- def factor_oxy_angle(specific: Tensor, general: Tensor, kappa: Tensor) -> Tensor:
263
- if specific.dim() != 3:
264
- return oxy_angle(specific=specific, general=general, kappa=kappa)
265
- batch_size, num_factors, feature_dim = specific.shape
266
- kappa = _factor_kappa(kappa, num_factors, specific.device)
267
- factor_kappa = kappa.view(1, num_factors).expand(batch_size, num_factors).reshape(-1)
268
- return oxy_angle(
269
- specific=specific.reshape(batch_size * num_factors, feature_dim),
270
- general=general.reshape(batch_size * num_factors, feature_dim),
271
- kappa=factor_kappa,
272
- ).reshape(batch_size, num_factors)
273
-
274
-
275
- def factor_half_aperture(general: Tensor, kappa: Tensor) -> Tensor:
276
- if general.dim() != 3:
277
- return half_aperture(general=general, kappa=kappa)
278
- batch_size, num_factors, feature_dim = general.shape
279
- kappa = _factor_kappa(kappa, num_factors, general.device)
280
- factor_kappa = kappa.view(1, num_factors).expand(batch_size, num_factors).reshape(-1)
281
- return half_aperture(
282
- general=general.reshape(batch_size * num_factors, feature_dim),
283
- kappa=factor_kappa,
284
- ).reshape(batch_size, num_factors)
285
-
286
-
287
- def _factor_kappa(kappa: Tensor, num_factors: int, device: torch.device) -> Tensor:
288
- kappa = kappa.to(device=device, dtype=torch.float32).flatten()
289
- if kappa.numel() == 1:
290
- return kappa.expand(num_factors)
291
- if kappa.numel() != num_factors:
292
- raise ValueError(f"Expected {num_factors} curvatures for product space, got {kappa.numel()}")
293
- return kappa
294
-
295
-
296
- def entailment_residual(
297
- specific: Tensor,
298
- general: Tensor,
299
- kappa: Tensor,
300
- aperture_scale: float,
301
- ) -> Tensor:
302
- angles = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
303
- apertures = factor_half_aperture(general=general, kappa=kappa)
304
- return torch.clamp(angles - (aperture_scale * apertures), min=0.0).mean()
305
-
306
-
307
- def weighted_entailment_residual(
308
- specific: Tensor,
309
- general: Tensor,
310
- kappa: Tensor,
311
- aperture_scale: float,
312
- weights: Tensor | None = None,
313
- ) -> Tensor:
314
- angles = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
315
- apertures = factor_half_aperture(general=general, kappa=kappa)
316
- residuals = torch.clamp(angles - (aperture_scale * apertures), min=0.0)
317
- if residuals.dim() == 2:
318
- residuals = residuals.mean(dim=-1)
319
- return weighted_mean(residuals, weights)
320
-
321
-
322
- def compositional_entailment_loss(
323
- image_feats: Tensor,
324
- text_feats: Tensor,
325
- box_image_feats: Tensor,
326
- box_text_feats: Tensor,
327
- kappa: Tensor,
328
- inter_aperture_scale: float,
329
- intra_aperture_scale: float,
330
- ) -> Tensor:
331
- text_to_image = entailment_residual(
332
- specific=image_feats,
333
- general=text_feats,
334
- kappa=kappa,
335
- aperture_scale=inter_aperture_scale,
336
- )
337
- box_text_to_box_image = entailment_residual(
338
- specific=box_image_feats,
339
- general=box_text_feats,
340
- kappa=kappa,
341
- aperture_scale=inter_aperture_scale,
342
- )
343
- box_image_to_image = entailment_residual(
344
- specific=image_feats,
345
- general=box_image_feats,
346
- kappa=kappa,
347
- aperture_scale=intra_aperture_scale,
348
- )
349
- box_text_to_text = entailment_residual(
350
- specific=text_feats,
351
- general=box_text_feats,
352
- kappa=kappa,
353
- aperture_scale=intra_aperture_scale,
354
- )
355
-
356
- return 0.5 * (text_to_image + box_text_to_box_image + box_image_to_image + box_text_to_text)
357
-
358
-
359
- def multi_part_entailment_loss(
360
- image_feats: Tensor,
361
- text_feats: Tensor,
362
- part_image_feats: Tensor,
363
- part_text_feats: Tensor,
364
- part_mask: Tensor,
365
- kappa: Tensor,
366
- inter_aperture_scale: float,
367
- intra_aperture_scale: float,
368
- ) -> Tensor:
369
- part_image_flat = part_image_feats[part_mask]
370
- part_text_flat = part_text_feats[part_mask]
371
- image_for_parts = image_feats[:, None, :].expand_as(part_image_feats)[part_mask]
372
- text_for_parts = text_feats[:, None, :].expand_as(part_text_feats)[part_mask]
373
-
374
- text_to_image = entailment_residual(
375
- specific=image_feats,
376
- general=text_feats,
377
- kappa=kappa,
378
- aperture_scale=inter_aperture_scale,
379
- )
380
- part_text_to_part_image = entailment_residual(
381
- specific=part_image_flat,
382
- general=part_text_flat,
383
- kappa=kappa,
384
- aperture_scale=inter_aperture_scale,
385
- )
386
- part_image_to_image = entailment_residual(
387
- specific=image_for_parts,
388
- general=part_image_flat,
389
- kappa=kappa,
390
- aperture_scale=intra_aperture_scale,
391
- )
392
- part_text_to_text = entailment_residual(
393
- specific=text_for_parts,
394
- general=part_text_flat,
395
- kappa=kappa,
396
- aperture_scale=intra_aperture_scale,
397
- )
398
-
399
- return 0.5 * (text_to_image + part_text_to_part_image + part_image_to_image + part_text_to_text)
400
-
401
-
402
- def packed_part_entailment_loss(
403
- image_feats: Tensor,
404
- text_feats: Tensor,
405
- part_image_feats: Tensor,
406
- part_text_feats: Tensor,
407
- part_owner: Tensor,
408
- kappa: Tensor,
409
- inter_aperture_scale: float,
410
- intra_aperture_scale: float,
411
- ) -> Tensor:
412
- text_to_image = entailment_residual(
413
- specific=image_feats,
414
- general=text_feats,
415
- kappa=kappa,
416
- aperture_scale=inter_aperture_scale,
417
- )
418
- if part_image_feats.numel() == 0:
419
- return text_to_image
420
-
421
- image_for_parts = image_feats[part_owner]
422
- text_for_parts = text_feats[part_owner]
423
- part_text_to_part_image = entailment_residual(
424
- specific=part_image_feats,
425
- general=part_text_feats,
426
- kappa=kappa,
427
- aperture_scale=inter_aperture_scale,
428
- )
429
- part_image_to_image = entailment_residual(
430
- specific=image_for_parts,
431
- general=part_image_feats,
432
- kappa=kappa,
433
- aperture_scale=intra_aperture_scale,
434
- )
435
- part_text_to_text = entailment_residual(
436
- specific=text_for_parts,
437
- general=part_text_feats,
438
- kappa=kappa,
439
- aperture_scale=intra_aperture_scale,
440
- )
441
-
442
- return 0.5 * (text_to_image + part_text_to_part_image + part_image_to_image + part_text_to_text)
443
-
444
-
445
def uncha_contrastive_losses(
    image_feats: Tensor,
    text_feats: Tensor,
    part_image_flat: Tensor,
    part_text_flat: Tensor,
    image_for_parts: Tensor,
    text_for_parts: Tensor,
    kappa: Tensor,
    global_logit_scale: Tensor,
    local_logit_scale: Tensor,
    global_local_logit_scale: Tensor,
    image_euc_feats: Tensor | None = None,
    text_euc_feats: Tensor | None = None,
    part_image_euc_flat: Tensor | None = None,
    part_text_euc_flat: Tensor | None = None,
    image_for_parts_euc: Tensor | None = None,
    text_for_parts_euc: Tensor | None = None,
    all_image_feats: Tensor | None = None,
    all_text_feats: Tensor | None = None,
    all_part_image_feats: Tensor | None = None,
    all_part_text_feats: Tensor | None = None,
    all_image_for_parts: Tensor | None = None,
    all_text_for_parts: Tensor | None = None,
    all_image_euc_feats: Tensor | None = None,
    all_text_euc_feats: Tensor | None = None,
    all_part_image_euc_feats: Tensor | None = None,
    all_part_text_euc_feats: Tensor | None = None,
    all_image_for_parts_euc: Tensor | None = None,
    all_text_for_parts_euc: Tensor | None = None,
    global_targets: Tensor | None = None,
    part_targets: Tensor | None = None,
    part_weights: Tensor | None = None,
    product_metric: str = "l1",
    loss_type: str = "ce",
    contrastive_global_weight: float = 1.0,
    contrastive_local_weight: float = 1.0,
    contrastive_global_local_weight: float = 1.0,
    beta_cal_beta: float = 0.0,
    beta_cal_variant: str = "ce",
    beta_cal_weight: float = 0.0,
    part_group_ids: Tensor | None = None,
    all_part_group_ids: Tensor | None = None,
    global_logit_bias: Tensor | None = None,
    local_logit_bias: Tensor | None = None,
    global_local_logit_bias: Tensor | None = None,
    sigmoid_negative_weight: float = 1.0,
    global_local_mode: str = "repeat",
    global_local_metric: str = "distance",
    global_local_angle_aux_weight: float = 0.0,
    global_local_angle_aux_mode: str = "contrastive",
    global_local_angle_aux_scale: float = 5.5,
    global_local_angle_aux_aperture_scale: float = 1.0,
) -> dict[str, Tensor]:
    """Compute the global, local (part) and global-local contrastive losses.

    Three symmetric (image->text and text->image, averaged with 0.5) terms are
    built and combined with their respective weights:

    * global: whole image vs. whole text;
    * local: part image vs. part text;
    * global-local: part vs. whole of the other modality, with logits scaled by
      a per-part temperature derived from ``embedding_uncertainty``.

    Optionally an angle-based auxiliary term and a ``beta_cal_loss`` term on the
    local logits are added.

    Every ``all_*`` argument defaults to its local counterpart when ``None``
    (NOTE(review): presumably the ``all_*`` tensors are gathered across
    workers - confirm against the caller).

    Args:
        loss_type: ``"ce"``, ``"sigmoid"``, ``"siglip"`` or ``"siglip_metric"``.
            ``"siglip"`` uses cosine similarity of the ``*_euc_*`` features;
            the other types use ``-metric_pairwise_dist`` (or
            ``-metric_pairwise_oxy_angle``) of the main features.
        global_local_mode: ``"repeat"`` pairs parts with the ``*_for_parts``
            tensors and ``part_targets``; ``"inbatch"`` pairs them with the
            whole-sample tensors and ``part_group_ids``.
        global_local_metric: ``"distance"`` or ``"angle"`` (non-siglip only).
        global_local_angle_aux_mode: ``"contrastive"`` or ``"positive_hinge"``.

    Returns:
        A dict with keys ``contrastive_loss`` (the weighted total),
        ``global_contrastive_loss``, ``local_contrastive_loss``,
        ``global_local_contrastive_loss``, ``global_local_angle_aux_loss`` and
        ``beta_cal_loss``. When there are no parts, only the global term
        contributes and the remaining entries are zero.

    Raises:
        ValueError: on an unsupported option value, a non-positive aux scale /
            aperture scale, a negative aux weight, or when a required optional
            tensor for the chosen configuration is missing.
    """
    if loss_type not in {"ce", "sigmoid", "siglip", "siglip_metric"}:
        raise ValueError(
            f"Unsupported contrastive loss {loss_type!r}; expected 'ce', 'sigmoid', 'siglip', or 'siglip_metric'"
        )
    if global_local_mode not in {"repeat", "inbatch"}:
        raise ValueError("global_local_mode must be 'repeat' or 'inbatch'")
    if global_local_metric not in {"distance", "angle"}:
        raise ValueError("global_local_metric must be 'distance' or 'angle'")
    if global_local_angle_aux_mode not in {"contrastive", "positive_hinge"}:
        raise ValueError("global_local_angle_aux_mode must be 'contrastive' or 'positive_hinge'")
    if global_local_angle_aux_weight < 0.0:
        raise ValueError("global_local_angle_aux_weight must be non-negative")
    if global_local_angle_aux_scale <= 0.0:
        raise ValueError("global_local_angle_aux_scale must be positive")
    if global_local_angle_aux_aperture_scale <= 0.0:
        raise ValueError("global_local_angle_aux_aperture_scale must be positive")

    # Default every "all_*" candidate pool to the local tensors.
    all_image_feats = image_feats if all_image_feats is None else all_image_feats
    all_text_feats = text_feats if all_text_feats is None else all_text_feats
    all_part_image_feats = part_image_flat if all_part_image_feats is None else all_part_image_feats
    all_part_text_feats = part_text_flat if all_part_text_feats is None else all_part_text_feats
    all_image_for_parts = image_for_parts if all_image_for_parts is None else all_image_for_parts
    all_text_for_parts = text_for_parts if all_text_for_parts is None else all_text_for_parts
    # Fix: the Euclidean per-part whole-sample pools previously had no local
    # fallback (image_for_parts_euc / text_for_parts_euc were never read), so
    # siglip in "repeat" mode raised even when the local tensors were supplied.
    all_image_for_parts_euc = image_for_parts_euc if all_image_for_parts_euc is None else all_image_for_parts_euc
    all_text_for_parts_euc = text_for_parts_euc if all_text_for_parts_euc is None else all_text_for_parts_euc
    if global_targets is None:
        global_targets = torch.arange(image_feats.size(0), device=image_feats.device)
    if part_targets is None:
        part_targets = torch.arange(part_image_flat.size(0), device=part_image_flat.device)

    # Logit scales are stored in log space; the exponentiated value is capped at 100.
    global_scale = global_logit_scale.exp().clamp(max=100.0)
    local_scale = local_logit_scale.exp().clamp(max=100.0)
    global_local_scale = global_local_logit_scale.exp().clamp(max=100.0)

    # ---- global image <-> text term ----
    if loss_type == "siglip":
        if image_euc_feats is None or text_euc_feats is None:
            raise ValueError("siglip contrastive requires image_euc_feats and text_euc_feats")
        if image_feats.dim() != 2 or text_feats.dim() != 2:
            raise ValueError("siglip contrastive is only supported for non-product features")
        all_image_euc_feats = image_euc_feats if all_image_euc_feats is None else all_image_euc_feats
        all_text_euc_feats = text_euc_feats if all_text_euc_feats is None else all_text_euc_feats
        zimg = F.normalize(image_euc_feats.float(), dim=-1)
        ztxt = F.normalize(text_euc_feats.float(), dim=-1)
        zimg_all = F.normalize(all_image_euc_feats.float(), dim=-1)
        ztxt_all = F.normalize(all_text_euc_feats.float(), dim=-1)
        image_logits = (zimg @ ztxt_all.T) * global_scale
        text_logits = (ztxt @ zimg_all.T) * global_scale
    else:
        image_logits = -metric_pairwise_dist(image_feats, all_text_feats, kappa, product_metric=product_metric) * global_scale
        text_logits = -metric_pairwise_dist(text_feats, all_image_feats, kappa, product_metric=product_metric) * global_scale

    if loss_type in {"sigmoid", "siglip", "siglip_metric"}:
        # Sigmoid-style losses take an additive logit bias (zero when not provided).
        bias = image_logits.new_zeros(()) if global_logit_bias is None else global_logit_bias.to(image_logits.device)
        image_logits = image_logits + bias
        text_logits = text_logits + bias
    global_contrastive = 0.5 * (
        _contrastive_loss(image_logits, global_targets, None, loss_type, sigmoid_negative_weight)
        + _contrastive_loss(text_logits, global_targets, None, loss_type, sigmoid_negative_weight)
    )

    # No parts in this batch: only the global term exists.
    if part_image_flat.numel() == 0:
        zero = image_feats.new_zeros(())
        contrastive = contrastive_global_weight * global_contrastive
        return {
            "contrastive_loss": contrastive,
            "global_contrastive_loss": global_contrastive,
            "local_contrastive_loss": zero,
            "global_local_contrastive_loss": zero,
            "global_local_angle_aux_loss": zero,
            "beta_cal_loss": zero,
        }

    # ---- local part-image <-> part-text term ----
    if loss_type == "siglip":
        if part_image_euc_flat is None or part_text_euc_flat is None:
            raise ValueError("siglip contrastive requires part_image_euc_flat and part_text_euc_flat when parts exist")
        all_part_image_euc_feats = part_image_euc_flat if all_part_image_euc_feats is None else all_part_image_euc_feats
        all_part_text_euc_feats = part_text_euc_flat if all_part_text_euc_feats is None else all_part_text_euc_feats
        zpi = F.normalize(part_image_euc_flat.float(), dim=-1)
        zpt = F.normalize(part_text_euc_flat.float(), dim=-1)
        zpi_all = F.normalize(all_part_image_euc_feats.float(), dim=-1)
        zpt_all = F.normalize(all_part_text_euc_feats.float(), dim=-1)
        part_image_logits = (zpi @ zpt_all.T) * local_scale
        part_text_logits = (zpt @ zpi_all.T) * local_scale
    else:
        part_image_logits = -metric_pairwise_dist(part_image_flat, all_part_text_feats, kappa, product_metric=product_metric) * local_scale
        part_text_logits = -metric_pairwise_dist(part_text_flat, all_part_image_feats, kappa, product_metric=product_metric) * local_scale

    if loss_type in {"sigmoid", "siglip", "siglip_metric"}:
        bias = part_image_logits.new_zeros(()) if local_logit_bias is None else local_logit_bias.to(part_image_logits.device)
        part_image_logits = part_image_logits + bias
        part_text_logits = part_text_logits + bias
    local_contrastive = 0.5 * (
        _contrastive_loss(part_image_logits, part_targets, part_weights, loss_type, sigmoid_negative_weight)
        + _contrastive_loss(part_text_logits, part_targets, part_weights, loss_type, sigmoid_negative_weight)
    )

    # ---- global-local (part <-> whole of the other modality) term ----
    global_local_contrastive = image_feats.new_zeros(())
    global_local_angle_aux = image_feats.new_zeros(())
    if contrastive_global_local_weight != 0.0:
        if global_local_mode == "inbatch":
            if part_group_ids is None:
                raise ValueError("inbatch global-local contrastive requires part_group_ids to be provided")
            global_local_targets = part_group_ids
            all_text_for_global_local = all_text_feats
            all_image_for_global_local = all_image_feats
            all_text_for_global_local_euc = all_text_euc_feats
            all_image_for_global_local_euc = all_image_euc_feats
        else:
            global_local_targets = part_targets
            all_text_for_global_local = all_text_for_parts
            all_image_for_global_local = all_image_for_parts
            all_text_for_global_local_euc = all_text_for_parts_euc
            all_image_for_global_local_euc = all_image_for_parts_euc

        # Per-part temperature exp(-0.5 * uncertainty), clamped to [0.1, 10];
        # detached so the temperature itself receives no gradient.
        image_uncertainty = embedding_uncertainty(part_image_flat).detach()
        text_uncertainty = embedding_uncertainty(part_text_flat).detach()
        image_temp = torch.exp(-0.5 * image_uncertainty).clamp(min=0.1, max=10.0)
        text_temp = torch.exp(-0.5 * text_uncertainty).clamp(min=0.1, max=10.0)

        if loss_type == "siglip":
            if part_image_euc_flat is None or part_text_euc_flat is None:
                raise ValueError("siglip global-local contrastive requires part_image_euc_flat/part_text_euc_flat")
            if all_text_for_global_local_euc is None or all_image_for_global_local_euc is None:
                raise ValueError("siglip global-local contrastive requires all_image_euc_feats/all_text_euc_feats")
            zpi = F.normalize(part_image_euc_flat.float(), dim=-1)
            zpt = F.normalize(part_text_euc_flat.float(), dim=-1)
            zimg_all = F.normalize(all_image_for_global_local_euc.float(), dim=-1)
            ztxt_all = F.normalize(all_text_for_global_local_euc.float(), dim=-1)
            part_image_to_whole_text = (zpi @ ztxt_all.T) * image_temp[:, None] * global_local_scale
            part_text_to_whole_image = (zpt @ zimg_all.T) * text_temp[:, None] * global_local_scale
        else:
            if global_local_metric == "angle":
                part_image_to_whole_text = -metric_pairwise_oxy_angle(
                    part_image_flat,
                    all_text_for_global_local,
                    kappa,
                    product_metric=product_metric,
                )
                part_text_to_whole_image = -metric_pairwise_oxy_angle(
                    part_text_flat,
                    all_image_for_global_local,
                    kappa,
                    product_metric=product_metric,
                )
            else:
                part_image_to_whole_text = -metric_pairwise_dist(
                    part_image_flat, all_text_for_global_local, kappa, product_metric=product_metric
                )
                part_text_to_whole_image = -metric_pairwise_dist(
                    part_text_flat, all_image_for_global_local, kappa, product_metric=product_metric
                )
            part_image_to_whole_text = part_image_to_whole_text * image_temp[:, None] * global_local_scale
            part_text_to_whole_image = part_text_to_whole_image * text_temp[:, None] * global_local_scale

        if loss_type in {"sigmoid", "siglip", "siglip_metric"}:
            bias = (
                part_image_to_whole_text.new_zeros(())
                if global_local_logit_bias is None
                else global_local_logit_bias.to(part_image_to_whole_text.device)
            )
            part_image_to_whole_text = part_image_to_whole_text + bias
            part_text_to_whole_image = part_text_to_whole_image + bias

        global_local_contrastive = 0.5 * (
            _contrastive_loss(part_image_to_whole_text, global_local_targets, part_weights, loss_type, sigmoid_negative_weight)
            + _contrastive_loss(part_text_to_whole_image, global_local_targets, part_weights, loss_type, sigmoid_negative_weight)
        )

        # Optional angle-based auxiliary term. It reuses the targets, candidate
        # pools and temperatures set up above, so it lives inside this branch.
        if global_local_angle_aux_weight > 0.0:
            if global_local_angle_aux_mode == "positive_hinge":
                # Only the matching whole sample of each part is penalised.
                positive_text = all_text_for_global_local.index_select(0, global_local_targets)
                positive_image = all_image_for_global_local.index_select(0, global_local_targets)
                global_local_angle_aux = 0.5 * (
                    weighted_entailment_residual(
                        specific=part_image_flat,
                        general=positive_text,
                        kappa=kappa,
                        aperture_scale=global_local_angle_aux_aperture_scale,
                        weights=part_weights,
                    )
                    + weighted_entailment_residual(
                        specific=part_text_flat,
                        general=positive_image,
                        kappa=kappa,
                        aperture_scale=global_local_angle_aux_aperture_scale,
                        weights=part_weights,
                    )
                )
            elif loss_type != "siglip":
                # "contrastive" aux mode: a second contrastive loss on negated
                # angles with a fixed scale; skipped for siglip, leaving the aux at zero.
                angle_scale = part_image_flat.new_tensor(float(global_local_angle_aux_scale))
                part_image_to_whole_text_angle = -metric_pairwise_oxy_angle(
                    part_image_flat,
                    all_text_for_global_local,
                    kappa,
                    product_metric=product_metric,
                ) * image_temp[:, None] * angle_scale
                part_text_to_whole_image_angle = -metric_pairwise_oxy_angle(
                    part_text_flat,
                    all_image_for_global_local,
                    kappa,
                    product_metric=product_metric,
                ) * text_temp[:, None] * angle_scale
                if loss_type in {"sigmoid", "siglip_metric"}:
                    bias = (
                        part_image_to_whole_text_angle.new_zeros(())
                        if global_local_logit_bias is None
                        else global_local_logit_bias.to(part_image_to_whole_text_angle.device)
                    )
                    part_image_to_whole_text_angle = part_image_to_whole_text_angle + bias
                    part_text_to_whole_image_angle = part_text_to_whole_image_angle + bias
                global_local_angle_aux = 0.5 * (
                    _contrastive_loss(
                        part_image_to_whole_text_angle,
                        global_local_targets,
                        part_weights,
                        loss_type,
                        sigmoid_negative_weight,
                    )
                    + _contrastive_loss(
                        part_text_to_whole_image_angle,
                        global_local_targets,
                        part_weights,
                        loss_type,
                        sigmoid_negative_weight,
                    )
                )

    # ---- optional beta-calibration term on the local logits ----
    beta_cal = image_feats.new_zeros(())
    if beta_cal_weight > 0.0 and beta_cal_beta > 0.0:
        if part_group_ids is None or all_part_group_ids is None:
            raise ValueError("beta_cal requires part_group_ids and all_part_group_ids to be provided")
        beta_cal = 0.5 * (
            beta_cal_loss(
                part_image_logits,
                targets=part_targets,
                group_ids=part_group_ids,
                all_group_ids=all_part_group_ids,
                beta=beta_cal_beta,
                variant=beta_cal_variant,
                weights=part_weights,
            )
            + beta_cal_loss(
                part_text_logits,
                targets=part_targets,
                group_ids=part_group_ids,
                all_group_ids=all_part_group_ids,
                beta=beta_cal_beta,
                variant=beta_cal_variant,
                weights=part_weights,
            )
        )

    contrastive = (
        contrastive_global_weight * global_contrastive
        + contrastive_local_weight * local_contrastive
        + contrastive_global_local_weight * global_local_contrastive
        + global_local_angle_aux_weight * global_local_angle_aux
        + beta_cal_weight * beta_cal
    )
    return {
        "contrastive_loss": contrastive,
        "global_contrastive_loss": global_contrastive,
        "local_contrastive_loss": local_contrastive,
        "global_local_contrastive_loss": global_local_contrastive,
        "global_local_angle_aux_loss": global_local_angle_aux,
        "beta_cal_loss": beta_cal,
    }
762
-
763
-
764
- def _contrastive_loss(
765
- logits: Tensor,
766
- targets: Tensor,
767
- weights: Tensor | None,
768
- loss_type: str,
769
- sigmoid_negative_weight: float,
770
- ) -> Tensor:
771
- if loss_type == "ce":
772
- return contrastive_ce(logits, targets, weights)
773
- if loss_type == "sigmoid":
774
- return contrastive_sigmoid(logits, targets, weights, negative_weight=sigmoid_negative_weight)
775
- if loss_type in {"siglip", "siglip_metric"}:
776
- return contrastive_siglip(logits, targets, weights, negative_weight=sigmoid_negative_weight)
777
- raise ValueError(f"Unsupported contrastive loss {loss_type!r}")
778
-
779
-
780
def uncha_entailment_losses(
    image_feats: Tensor,
    text_feats: Tensor,
    part_image_flat: Tensor,
    part_text_flat: Tensor,
    image_for_parts: Tensor,
    text_for_parts: Tensor,
    kappa: Tensor,
    inter_aperture_scale: float,
    intra_aperture_scale: float,
    piecewise_factor: float = 0.1,
    calibration_alpha: float = 10.0,
    stop_grad_calibration: bool = True,
    geometry: str = "lorentz",
    part_weights: Tensor | None = None,
) -> dict[str, Tensor]:
    """Piecewise entailment losses with uncertainty calibration on the cross terms.

    Four ``piecewise_entailment_residual`` terms are evaluated:

    * image (specific) vs. text (general), ``inter_aperture_scale``;
    * part image vs. part text, ``inter_aperture_scale``;
    * whole image (specific) vs. part image (general), ``intra_aperture_scale``;
    * whole text (specific) vs. part text (general), ``intra_aperture_scale``.

    The two cross terms go through ``uncertainty_calibrated_entailment_loss``
    together with ``embedding_uncertainty`` of the part embedding, which
    returns an entailment value and a calibration value each.

    Returns:
        A dict with ``entailment_loss`` (the total) and its components. When
        there are no parts, only the image/text term is non-zero.
    """

    def _residual(specific: Tensor, general: Tensor, aperture_scale: float) -> Tensor:
        # All four terms share kappa, the piecewise factor and the geometry.
        return piecewise_entailment_residual(
            specific=specific,
            general=general,
            kappa=kappa,
            aperture_scale=aperture_scale,
            factor=piecewise_factor,
            geometry=geometry,
        )

    text_image_entailment = 0.5 * _residual(image_feats, text_feats, inter_aperture_scale).mean()

    if part_image_flat.numel() == 0:
        # Part-free batch: every part-dependent entry is the same scalar zero.
        zero = image_feats.new_zeros(())
        return {
            "entailment_loss": text_image_entailment,
            "text_image_entailment_loss": text_image_entailment,
            "part_text_image_entailment_loss": zero,
            "cross_image_entailment_loss": zero,
            "cross_text_entailment_loss": zero,
            "cross_image_calibration_loss": zero,
            "cross_text_calibration_loss": zero,
        }

    part_pair_residual = _residual(part_image_flat, part_text_flat, inter_aperture_scale)
    image_cross_residual = _residual(image_for_parts, part_image_flat, intra_aperture_scale)
    text_cross_residual = _residual(text_for_parts, part_text_flat, intra_aperture_scale)

    part_text_image_entailment = 0.5 * weighted_mean(part_pair_residual, part_weights)
    calibration_kwargs = dict(
        alpha=calibration_alpha,
        stop_grad=stop_grad_calibration,
        weights=part_weights,
    )
    cross_image_entailment, cross_image_calibration = uncertainty_calibrated_entailment_loss(
        image_cross_residual,
        embedding_uncertainty(part_image_flat),
        **calibration_kwargs,
    )
    cross_text_entailment, cross_text_calibration = uncertainty_calibrated_entailment_loss(
        text_cross_residual,
        embedding_uncertainty(part_text_flat),
        **calibration_kwargs,
    )

    # The cross terms are averaged; the calibration terms enter at full weight.
    total = (
        text_image_entailment
        + part_text_image_entailment
        + 0.5 * (cross_image_entailment + cross_text_entailment)
        + cross_image_calibration
        + cross_text_calibration
    )
    return {
        "entailment_loss": total,
        "text_image_entailment_loss": text_image_entailment,
        "part_text_image_entailment_loss": part_text_image_entailment,
        "cross_image_entailment_loss": cross_image_entailment,
        "cross_text_entailment_loss": cross_text_entailment,
        "cross_image_calibration_loss": cross_image_calibration,
        "cross_text_calibration_loss": cross_text_calibration,
    }
875
-
876
-
877
def uncha_argent_entailment_losses(
    image_feats: Tensor,
    text_feats: Tensor,
    part_image_flat: Tensor,
    part_text_flat: Tensor,
    image_for_parts: Tensor,
    text_for_parts: Tensor,
    kappa: Tensor,
    beta: float = 1.0,
    part_weights: Tensor | None = None,
    product_metric: str = "l1",
    aggregation: str = "uncha",
) -> dict[str, Tensor]:
    """Entailment losses built on ``argent_adaptive_entailment_residual``.

    The image/text and part-image/part-text terms use the residual without
    adaptive weighting; the two cross terms (whole vs. part, per modality) use
    it with adaptive weighting. A norm-regularization value and, when parts
    exist, the ``argent_entailment_diagnostics`` entries are returned alongside.

    Args:
        aggregation: ``"equal"`` sums all four terms; ``"uncha"`` adds only the
            average of the two cross terms.

    Returns:
        A dict with ``entailment_loss`` and its components. Both calibration
        entries are always zero here.

    Raises:
        ValueError: if ``aggregation`` is neither ``"uncha"`` nor ``"equal"``.
    """
    if aggregation not in {"uncha", "equal"}:
        raise ValueError("aggregation must be 'uncha' or 'equal'")

    def _residual(specific: Tensor, general: Tensor, adaptive: bool) -> Tensor:
        return argent_adaptive_entailment_residual(
            specific=specific,
            general=general,
            kappa=kappa,
            adaptive_weight=adaptive,
            beta=beta,
            product_metric=product_metric,
        )

    text_image_entailment = 0.5 * _residual(image_feats, text_feats, False).mean()

    if part_image_flat.numel() == 0:
        # Without parts, the regularizer only sees the whole-sample features.
        zero = image_feats.new_zeros(())
        whole_only_regularization = argent_norm_regularization_loss(image_feats, text_feats)
        return {
            "entailment_loss": text_image_entailment,
            "text_image_entailment_loss": text_image_entailment,
            "part_text_image_entailment_loss": zero,
            "cross_image_entailment_loss": zero,
            "cross_text_entailment_loss": zero,
            "cross_image_calibration_loss": zero,
            "cross_text_calibration_loss": zero,
            "norm_regularization_loss": whole_only_regularization,
        }

    part_pair_residual = _residual(part_image_flat, part_text_flat, False)
    image_cross_residual = _residual(image_for_parts, part_image_flat, True)
    text_cross_residual = _residual(text_for_parts, part_text_flat, True)

    part_text_image_entailment = 0.5 * weighted_mean(part_pair_residual, part_weights)
    cross_image_entailment = 0.5 * weighted_mean(image_cross_residual, part_weights)
    cross_text_entailment = 0.5 * weighted_mean(text_cross_residual, part_weights)
    norm_regularization = argent_norm_regularization_loss(image_feats, text_feats, part_image_flat, part_text_flat)

    if aggregation != "equal":
        # "uncha" aggregation: the two cross terms are averaged.
        total = text_image_entailment + part_text_image_entailment + 0.5 * (
            cross_image_entailment + cross_text_entailment
        )
    else:
        total = text_image_entailment + part_text_image_entailment + cross_image_entailment + cross_text_entailment

    diagnostics = argent_entailment_diagnostics(
        image_feats=image_feats,
        text_feats=text_feats,
        part_image_flat=part_image_flat,
        part_text_flat=part_text_flat,
        image_for_parts=image_for_parts,
        text_for_parts=text_for_parts,
        kappa=kappa,
        product_metric=product_metric,
    )

    losses = {
        "entailment_loss": total,
        "text_image_entailment_loss": text_image_entailment,
        "part_text_image_entailment_loss": part_text_image_entailment,
        "cross_image_entailment_loss": cross_image_entailment,
        "cross_text_entailment_loss": cross_text_entailment,
        "cross_image_calibration_loss": image_feats.new_zeros(()),
        "cross_text_calibration_loss": image_feats.new_zeros(()),
        "norm_regularization_loss": norm_regularization,
    }
    # Diagnostics are merged last, so a colliding key would take the diagnostic value.
    losses.update(diagnostics)
    return losses
973
-
974
-
975
def hierarchical_beta_argent_entailment_losses(
    image_feats: Tensor,
    text_feats: Tensor,
    part_image_flat: Tensor,
    part_text_flat: Tensor,
    image_for_parts: Tensor,
    text_for_parts: Tensor,
    beta_query_image_feats: Tensor,
    beta_query_text_feats: Tensor,
    beta_query_owner: Tensor,
    beta_query_parent: Tensor,
    beta_query_weight: Tensor,
    kappa: Tensor,
    beta_query_source_part: Tensor | None = None,
    beta: float = 1.0,
    part_weights: Tensor | None = None,
    product_metric: str = "l1",
    aggregation: str = "uncha",
) -> dict[str, Tensor]:
    """Extend ``uncha_argent_entailment_losses`` with beta-query hierarchy terms.

    On top of the base losses, the following residuals (all via
    ``argent_adaptive_entailment_residual``) are added for the beta queries:

    * query image (specific) vs. query text (general), no adaptive weight;
    * owner image (specific) vs. query image (general);
    * parent query text (specific) vs. child query text (general), for queries
      whose parent index is in range and whose weight is positive;
    * optionally, source part image/text (specific) vs. query image/text
      (general), for queries with an in-range source-part index.

    Query weights are clamped to be non-negative and divided by their mean.

    Returns:
        The base dict with ``entailment_loss`` and ``norm_regularization_loss``
        replaced and the ``hier_beta_*`` entries added. With no beta queries,
        the base values are kept and the ``hier_beta_*`` entries are zero.

    Raises:
        ValueError: if ``beta_query_weight`` or ``beta_query_source_part`` does
            not have one value per beta query.
    """
    base = uncha_argent_entailment_losses(
        image_feats=image_feats,
        text_feats=text_feats,
        part_image_flat=part_image_flat,
        part_text_flat=part_text_flat,
        image_for_parts=image_for_parts,
        text_for_parts=text_for_parts,
        kappa=kappa,
        beta=beta,
        part_weights=part_weights,
        product_metric=product_metric,
        aggregation=aggregation,
    )
    device = image_feats.device

    if beta_query_image_feats.numel() == 0:
        # No queries: report the base losses with zeroed hierarchy entries.
        without_queries = dict(base)
        without_queries["hier_beta_query_text_entailment_loss"] = image_feats.new_zeros(())
        without_queries["hier_beta_visual_entailment_loss"] = image_feats.new_zeros(())
        without_queries["hier_beta_text_entailment_loss"] = image_feats.new_zeros(())
        without_queries["hier_beta_sourcepart_visual_entailment_loss"] = image_feats.new_zeros(())
        without_queries["hier_beta_sourcepart_text_entailment_loss"] = image_feats.new_zeros(())
        without_queries["hier_beta_query_count"] = beta_query_owner.new_tensor(0)
        without_queries["hier_beta_sourcepart_query_count"] = beta_query_owner.new_tensor(0)
        return without_queries

    num_queries = beta_query_image_feats.size(0)
    query_owner = beta_query_owner.to(device=device, dtype=torch.long)
    query_weights = beta_query_weight.to(device=device, dtype=torch.float32).clamp_min(0.0)
    if query_weights.numel() != num_queries:
        raise ValueError("beta_query_weight must have one value per beta query")
    # Mean-normalise the weights; eps guards against an all-zero weight vector.
    weight_scale = query_weights.mean().clamp_min(torch.finfo(query_weights.dtype).eps)
    query_weights = query_weights / weight_scale

    def _residual(specific: Tensor, general: Tensor, adaptive: bool) -> Tensor:
        return argent_adaptive_entailment_residual(
            specific=specific,
            general=general,
            kappa=kappa,
            adaptive_weight=adaptive,
            beta=beta,
            product_metric=product_metric,
        )

    query_text_residual = _residual(beta_query_image_feats, beta_query_text_feats, False)
    owner_residual = _residual(image_feats.index_select(0, query_owner), beta_query_image_feats, True)
    query_text_entailment = 0.5 * weighted_mean(query_text_residual, query_weights)
    visual_entailment = 0.5 * weighted_mean(owner_residual, query_weights)

    # Text hierarchy: only queries with a valid parent and a positive weight take part.
    parent = beta_query_parent.to(device=device, dtype=torch.long)
    has_parent = (parent >= 0) & (parent < beta_query_text_feats.size(0)) & (query_weights > 0.0)
    text_entailment = image_feats.new_zeros(())
    if bool(has_parent.any()):
        child_text = beta_query_text_feats[has_parent]
        parent_text = beta_query_text_feats[parent[has_parent]]
        parent_child_residual = _residual(parent_text, child_text, True)
        text_entailment = 0.5 * weighted_mean(parent_child_residual, query_weights[has_parent])

    # Source-part terms: optional, and only meaningful when parts exist.
    sourcepart_visual_entailment = image_feats.new_zeros(())
    sourcepart_text_entailment = image_feats.new_zeros(())
    sourcepart_query_count = beta_query_owner.new_tensor(0)
    if beta_query_source_part is not None and part_image_flat.numel() > 0:
        source_part = beta_query_source_part.to(device=device, dtype=torch.long)
        if source_part.numel() != num_queries:
            raise ValueError("beta_query_source_part must have one value per beta query")
        has_source = (
            (source_part >= 0)
            & (source_part < part_image_flat.size(0))
            & (query_weights > 0.0)
        )
        if bool(has_source.any()):
            picked_parts = source_part[has_source]
            part_to_query_image = _residual(
                part_image_flat.index_select(0, picked_parts),
                beta_query_image_feats[has_source],
                True,
            )
            part_to_query_text = _residual(
                part_text_flat.index_select(0, picked_parts),
                beta_query_text_feats[has_source],
                True,
            )
            picked_weights = query_weights[has_source]
            sourcepart_visual_entailment = 0.5 * weighted_mean(part_to_query_image, picked_weights)
            sourcepart_text_entailment = 0.5 * weighted_mean(part_to_query_text, picked_weights)
            sourcepart_query_count = beta_query_owner.new_tensor(int(has_source.sum().item()))

    # The regularizer is recomputed so that it also covers the query embeddings.
    norm_regularization = argent_norm_regularization_loss(
        image_feats,
        text_feats,
        part_image_flat,
        part_text_flat,
        beta_query_image_feats,
        beta_query_text_feats,
    )
    sourcepart_entailment = 0.5 * (sourcepart_visual_entailment + sourcepart_text_entailment)
    query_entailment = query_text_entailment + 0.5 * (visual_entailment + text_entailment) + sourcepart_entailment

    result = dict(base)
    result["entailment_loss"] = base["entailment_loss"] + query_entailment
    result["norm_regularization_loss"] = norm_regularization
    result["hier_beta_query_text_entailment_loss"] = query_text_entailment
    result["hier_beta_visual_entailment_loss"] = visual_entailment
    result["hier_beta_text_entailment_loss"] = text_entailment
    result["hier_beta_sourcepart_visual_entailment_loss"] = sourcepart_visual_entailment
    result["hier_beta_sourcepart_text_entailment_loss"] = sourcepart_text_entailment
    result["hier_beta_query_count"] = beta_query_owner.new_tensor(beta_query_owner.numel())
    result["hier_beta_sourcepart_query_count"] = sourcepart_query_count
    return result
1118
-
1119
-
1120
def argent_entailment_diagnostics(
    image_feats: Tensor,
    text_feats: Tensor,
    part_image_flat: Tensor,
    part_text_flat: Tensor,
    image_for_parts: Tensor,
    text_for_parts: Tensor,
    kappa: Tensor,
    product_metric: str = "l1",
) -> dict[str, Tensor]:
    """Detached summary statistics for the entailment pairs.

    For each (specific, general) pair the mean angle from ``factor_oxy_angle``
    and the mean of ``clamp(1 - 2 * angle / pi, 0, 1)`` ("pent") are reported;
    the two cross pairs additionally report the mean ``lorentz_dist`` and the
    mean of ``1 - exp(-distance)``. Mean space-component norms are reported for
    the four embedding sets. Any statistic whose input is empty is a scalar
    zero. All values are detached.
    """
    zero = image_feats.new_zeros(())

    def _pair_angles(specific: Tensor, general: Tensor) -> Tensor:
        # A 2-D angle tensor is averaged over its last axis first
        # (NOTE(review): presumably one angle per product factor - confirm).
        angles = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
        return angles.mean(dim=-1) if angles.dim() == 2 else angles

    def _pair_distance(specific: Tensor, general: Tensor) -> Tensor:
        return lorentz_dist(specific, general, kappa, product_metric=product_metric)

    def angle_mean(specific: Tensor, general: Tensor) -> Tensor:
        if specific.numel() != 0:
            return _pair_angles(specific, general).detach().mean()
        return zero

    def pent_mean(specific: Tensor, general: Tensor) -> Tensor:
        if specific.numel() != 0:
            angles = _pair_angles(specific, general)
            scores = torch.clamp(1.0 - (2.0 * angles / math.pi), min=0.0, max=1.0)
            return scores.detach().mean()
        return zero

    def distance_mean(specific: Tensor, general: Tensor) -> Tensor:
        if specific.numel() != 0:
            return _pair_distance(specific, general).detach().mean()
        return zero

    def adaptive_weight_mean(specific: Tensor, general: Tensor) -> Tensor:
        if specific.numel() != 0:
            weights = 1.0 - torch.exp(-_pair_distance(specific, general))
            return weights.detach().mean()
        return zero

    def space_norm_mean(embedding: Tensor) -> Tensor:
        if embedding.numel() != 0:
            space = _space_components(embedding).float()
            return torch.linalg.norm(space, dim=-1).detach().mean()
        return zero

    whole_pair = (image_feats, text_feats)
    part_pair = (part_image_flat, part_text_flat)
    image_cross = (image_for_parts, part_image_flat)
    text_cross = (text_for_parts, part_text_flat)

    return {
        "argent_text_image_angle_mean": angle_mean(*whole_pair),
        "argent_text_image_pent_mean": pent_mean(*whole_pair),
        "argent_part_text_image_angle_mean": angle_mean(*part_pair),
        "argent_part_text_image_pent_mean": pent_mean(*part_pair),
        "argent_cross_image_angle_mean": angle_mean(*image_cross),
        "argent_cross_image_pent_mean": pent_mean(*image_cross),
        "argent_cross_image_distance_mean": distance_mean(*image_cross),
        "argent_cross_image_adaptive_weight_mean": adaptive_weight_mean(*image_cross),
        "argent_cross_text_angle_mean": angle_mean(*text_cross),
        "argent_cross_text_pent_mean": pent_mean(*text_cross),
        "argent_cross_text_distance_mean": distance_mean(*text_cross),
        "argent_cross_text_adaptive_weight_mean": adaptive_weight_mean(*text_cross),
        "argent_image_space_norm_mean": space_norm_mean(image_feats),
        "argent_text_space_norm_mean": space_norm_mean(text_feats),
        "argent_part_image_space_norm_mean": space_norm_mean(part_image_flat),
        "argent_part_text_space_norm_mean": space_norm_mean(part_text_flat),
    }
1183
-
1184
-
1185
def part_quality_weights(
    image_for_parts: Tensor,
    text_for_parts: Tensor,
    part_image_flat: Tensor,
    part_text_flat: Tensor,
    part_owner: Tensor,
    batch_size: int,
    kappa: Tensor,
    mode: str,
    topk: int = 5,
    temperature: float = 4.0,
    product_metric: str = "l1",
) -> tuple[Tensor | None, Tensor, Tensor]:
    """Score each part and turn the scores into per-part loss weights.

    A part's score is the mean of ``exp(-lorentz_dist)`` over three pairs:
    part image vs. its whole image, part text vs. its whole text, and part
    image vs. part text. Scores are computed without gradient tracking.

    Args:
        mode: ``"none"`` disables weighting, ``"soft"`` uses a per-owner
            softmax, ``"topk"`` keeps the ``topk`` best parts per owner.

    Returns:
        ``(weights, scores, selected)``. ``weights`` is mean-normalised, or
        ``None`` when ``mode`` is ``"none"`` or there are no parts (in which
        case the other two entries are the same zero vector). ``selected`` is
        ``weights > 0`` cast to the dtype of ``scores``.

    Raises:
        ValueError: if ``mode`` is not one of the three supported names.
    """
    if mode not in {"none", "soft", "topk"}:
        raise ValueError(f"Unsupported part quality mode {mode!r}; expected 'none', 'soft', or 'topk'")

    weighting_disabled = mode == "none" or part_image_flat.numel() == 0
    if weighting_disabled:
        empty = part_image_flat.new_zeros((part_image_flat.size(0),))
        return None, empty, empty

    with torch.no_grad():
        scored_pairs = (
            (part_image_flat, image_for_parts),
            (part_text_flat, text_for_parts),
            (part_image_flat, part_text_flat),
        )
        affinities = [
            torch.exp(-lorentz_dist(left, right, kappa, product_metric=product_metric))
            for left, right in scored_pairs
        ]
        scores = torch.stack(affinities).mean(dim=0).clamp_min(0.0)

    if mode == "soft":
        weights = _owner_softmax_weights(scores, part_owner, batch_size, temperature)
    else:
        weights = _owner_topk_weights(scores, part_owner, batch_size, topk)

    # Rescale to unit mean; eps guards against an all-zero weight vector.
    mean_weight = weights.mean().clamp_min(torch.finfo(weights.dtype).eps)
    weights = weights / mean_weight
    selected = (weights > 0.0).to(dtype=scores.dtype)
    return weights, scores, selected
1216
-
1217
-
1218
- def _owner_softmax_weights(scores: Tensor, part_owner: Tensor, batch_size: int, temperature: float) -> Tensor:
1219
- weights = torch.zeros_like(scores)
1220
- for owner in range(batch_size):
1221
- mask = part_owner == owner
1222
- if not bool(mask.any()):
1223
- continue
1224
- owner_scores = scores[mask]
1225
- owner_weights = torch.softmax(owner_scores * temperature, dim=0) * owner_scores.numel()
1226
- weights[mask] = owner_weights
1227
- return weights
1228
-
1229
-
1230
- def _owner_topk_weights(scores: Tensor, part_owner: Tensor, batch_size: int, topk: int) -> Tensor:
1231
- if topk <= 0:
1232
- raise ValueError("topk must be positive for top-k part quality weighting")
1233
- weights = torch.zeros_like(scores)
1234
- for owner in range(batch_size):
1235
- indices = torch.nonzero(part_owner == owner, as_tuple=False).flatten()
1236
- if indices.numel() == 0:
1237
- continue
1238
- keep = min(topk, indices.numel())
1239
- selected = indices[scores[indices].topk(k=keep).indices]
1240
- weights[selected] = 1.0
1241
- return weights
1242
-
1243
-
1244
- def argent_adaptive_entailment_residual(
1245
- specific: Tensor,
1246
- general: Tensor,
1247
- kappa: Tensor,
1248
- adaptive_weight: bool,
1249
- beta: float = 1.0,
1250
- product_metric: str = "l1",
1251
- ) -> Tensor:
1252
- angles = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
1253
- if angles.dim() == 2:
1254
- angles = angles.mean(dim=-1)
1255
- if adaptive_weight:
1256
- weights = 1.0 - torch.exp(
1257
- -lorentz_dist(specific=specific, general=general, kappa=kappa, product_metric=product_metric)
1258
- )
1259
- angles = angles * weights
1260
- return F.huber_loss(angles, torch.zeros_like(angles), delta=beta, reduction="none")
1261
-
1262
-
1263
- def lorentz_dist(specific: Tensor, general: Tensor, kappa: Tensor, product_metric: str = "l1") -> Tensor:
1264
- return paired_dist(specific, general, kappa, product_metric=product_metric)
1265
-
1266
-
1267
- def argent_norm_regularization_loss(*embeddings: Tensor, eps: float = 1e-6) -> Tensor:
1268
- losses = []
1269
- for embedding in embeddings:
1270
- if embedding.numel() == 0:
1271
- continue
1272
- space = _space_components(embedding)
1273
- space_norm = torch.linalg.norm(space.float(), dim=-1).clamp_min(eps)
1274
- losses.append((space_norm.square() - torch.log(space_norm)).mean())
1275
- if not losses:
1276
- raise ValueError("argent_norm_regularization_loss requires at least one non-empty embedding tensor")
1277
- return torch.stack(losses).mean()
1278
-
1279
-
1280
- def piecewise_entailment_residual(
1281
- specific: Tensor,
1282
- general: Tensor,
1283
- kappa: Tensor,
1284
- aperture_scale: float,
1285
- factor: float = 0.1,
1286
- geometry: str = "lorentz",
1287
- ) -> Tensor:
1288
- if geometry == "lorentz":
1289
- angles = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
1290
- apertures = factor_half_aperture(general=general, kappa=kappa)
1291
- elif geometry == "euclidean":
1292
- angles = euclidean_angle(specific=specific, general=general)
1293
- apertures = euclidean_half_aperture(general=general, aperture_scale=aperture_scale)
1294
- aperture_scale = 1.0
1295
- else:
1296
- raise ValueError(f"Unsupported entailment geometry {geometry!r}; expected 'lorentz' or 'euclidean'")
1297
- residual = angles - aperture_scale * apertures
1298
- loss = torch.where(residual > 0.0, residual + factor * angles, factor * angles)
1299
- return loss.mean(dim=-1) if loss.dim() == 2 else loss
1300
-
1301
-
1302
- def euclidean_angle(specific: Tensor, general: Tensor, eps: float = 1e-6) -> Tensor:
1303
- specific_space = _space_components(specific).float()
1304
- general_space = _space_components(general).float()
1305
- numerator = (specific_space * general_space).sum(dim=-1)
1306
- denominator = torch.linalg.norm(specific_space, dim=-1) * torch.linalg.norm(general_space, dim=-1)
1307
- dtype_eps = torch.finfo(specific_space.dtype).eps
1308
- angle_eps = max(eps, 16.0 * dtype_eps)
1309
- cosine = (numerator / denominator.clamp_min(angle_eps)).clamp(min=-1.0 + angle_eps, max=1.0 - angle_eps)
1310
- return torch.acos(cosine)
1311
-
1312
-
1313
- def euclidean_half_aperture(general: Tensor, aperture_scale: float, eps: float = 1e-8) -> Tensor:
1314
- general_norm = torch.linalg.norm(_space_components(general).float(), dim=-1).clamp_min(eps)
1315
- return torch.atan(torch.as_tensor(aperture_scale, device=general.device, dtype=general.dtype) / general_norm)
1316
-
1317
-
1318
- def aggregate_part_consistency_loss(
1319
- image_feats: Tensor,
1320
- text_feats: Tensor,
1321
- part_image_flat: Tensor,
1322
- part_text_flat: Tensor,
1323
- part_owner: Tensor,
1324
- part_weights: Tensor | None = None,
1325
- ) -> Tensor:
1326
- if part_image_flat.numel() == 0:
1327
- return image_feats.new_zeros(())
1328
-
1329
- batch_size = image_feats.size(0)
1330
- image_space = _space_components(image_feats).reshape(batch_size, -1).float()
1331
- text_space = _space_components(text_feats).reshape(batch_size, -1).float()
1332
- part_image_space = _space_components(part_image_flat).reshape(part_image_flat.size(0), -1).float()
1333
- part_text_space = _space_components(part_text_flat).reshape(part_text_flat.size(0), -1).float()
1334
- if part_weights is None:
1335
- counts = torch.bincount(part_owner, minlength=batch_size).to(device=image_feats.device, dtype=image_space.dtype)
1336
- denom = counts
1337
- valid = counts > 0
1338
- weights = part_image_space.new_ones((part_image_space.size(0),))
1339
- else:
1340
- weights = part_weights.to(device=image_feats.device, dtype=image_space.dtype).flatten()
1341
- if weights.numel() != part_owner.numel():
1342
- raise ValueError("part_weights must have the same number of elements as part_owner when provided")
1343
- denom = torch.zeros(batch_size, device=image_feats.device, dtype=image_space.dtype)
1344
- denom.index_add_(0, part_owner, weights)
1345
- valid = denom > 0
1346
-
1347
- image_agg = image_space.new_zeros(image_space.shape)
1348
- text_agg = text_space.new_zeros(text_space.shape)
1349
- image_agg.index_add_(0, part_owner, part_image_space * weights[:, None])
1350
- text_agg.index_add_(0, part_owner, part_text_space * weights[:, None])
1351
- image_agg = image_agg[valid] / denom[valid, None].clamp_min(1.0)
1352
- text_agg = text_agg[valid] / denom[valid, None].clamp_min(1.0)
1353
-
1354
- image_space = image_space[valid]
1355
- text_space = text_space[valid]
1356
- return 0.25 * (
1357
- cosine_residual(image_agg, image_space)
1358
- + cosine_residual(text_agg, text_space)
1359
- + cosine_residual(image_agg, text_space)
1360
- + cosine_residual(text_agg, image_space)
1361
- )
1362
-
1363
-
1364
- def cosine_residual(x: Tensor, y: Tensor) -> Tensor:
1365
- return (1.0 - F.cosine_similarity(x, y, dim=-1)).mean()
1366
-
1367
-
1368
- def uncertainty_calibrated_entailment_loss(
1369
- entail_residual: Tensor,
1370
- log_uncertainty: Tensor,
1371
- alpha: float = 10.0,
1372
- stop_grad: bool = True,
1373
- weights: Tensor | None = None,
1374
- ) -> tuple[Tensor, Tensor]:
1375
- mean_loss = 0.5 * entail_residual
1376
- uncertainty = torch.exp(log_uncertainty).clamp(min=1e-6, max=1e6)
1377
- residual = entail_residual.detach() if stop_grad else entail_residual
1378
- scaled_entail = residual / (uncertainty + 1e-6)
1379
- calibration_term = 0.5 * scaled_entail + 0.5 * log_uncertainty
1380
- prob = torch.softmax(log_uncertainty.flatten(), dim=0)
1381
- entropy = -(prob * torch.log(prob + 1e-8)).sum()
1382
- calibration_loss = alpha * (calibration_term + entropy)
1383
- return weighted_mean(mean_loss, weights), weighted_mean(calibration_loss, weights)
1384
-
1385
-
1386
- def embedding_uncertainty(x: Tensor) -> Tensor:
1387
- space = _space_components(x)
1388
- norm = torch.linalg.norm(space.float(), dim=-1)
1389
- if norm.dim() > 1:
1390
- norm = norm.mean(dim=-1)
1391
- return F.softplus(-norm)
1392
-
1393
-
1394
- def _space_components(x: Tensor) -> Tensor:
1395
- return x[..., 1:] if x.shape[-1] > 1 else x
1396
-
1397
-
1398
- def _flatten_valid_parts(part_image_feats: Tensor, part_text_feats: Tensor, part_mask: Tensor, targets: Tensor) -> tuple[Tensor, Tensor, Tensor]:
1399
- part_targets = targets[:, None].expand_as(part_mask)[part_mask]
1400
- return part_image_feats[part_mask], part_text_feats[part_mask], part_targets
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/models/objectives.py DELETED
@@ -1,580 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from collections.abc import Mapping
4
-
5
- import torch
6
- from torch import Tensor, nn
7
-
8
- from hyper3_clip.models.lorentz import log_map0, metric_pairwise_dist
9
- from hyper3_clip.models.losses import (
10
- aggregate_part_consistency_loss,
11
- contrastive_ce,
12
- gramian_volume_loss,
13
- hierarchical_beta_argent_entailment_losses,
14
- packed_part_contrastive_loss,
15
- packed_part_entailment_loss,
16
- part_quality_weights,
17
- radius_order_hinge,
18
- uncha_argent_entailment_losses,
19
- uncha_contrastive_losses,
20
- uncha_entailment_losses,
21
- )
22
- from hyper3_clip.training.distributed import gather_variable_many_with_grad, gather_variable_no_grad, get_rank
23
-
24
-
25
- class HyCoCLIPObjective(nn.Module):
26
- def __init__(
27
- self,
28
- entail_weight: float,
29
- inter_aperture_scale: float,
30
- intra_aperture_scale: float,
31
- product_metric: str = "l1",
32
- ) -> None:
33
- super().__init__()
34
- self.entail_weight = entail_weight
35
- self.inter_aperture_scale = inter_aperture_scale
36
- self.intra_aperture_scale = intra_aperture_scale
37
- self.product_metric = product_metric
38
-
39
- def forward(self, embeddings: Mapping[str, Tensor], logit_scale: Tensor) -> dict[str, Tensor]:
40
- part_owner = embeddings["part_owner"].long()
41
- part_count = part_owner.new_tensor(part_owner.numel())
42
- contrastive = packed_part_contrastive_loss(
43
- image_feats=embeddings["image_feats"],
44
- text_feats=embeddings["text_feats"],
45
- part_image_feats=embeddings["part_image_feats"],
46
- part_text_feats=embeddings["part_text_feats"],
47
- part_owner=part_owner,
48
- kappa=embeddings["kappa"],
49
- logit_scale=logit_scale,
50
- all_image_feats=embeddings.get("all_image_feats"),
51
- all_text_feats=embeddings.get("all_text_feats"),
52
- targets=embeddings.get("targets"),
53
- )
54
- entailment = packed_part_entailment_loss(
55
- image_feats=embeddings["image_feats"],
56
- text_feats=embeddings["text_feats"],
57
- part_image_feats=embeddings["part_image_feats"],
58
- part_text_feats=embeddings["part_text_feats"],
59
- part_owner=part_owner,
60
- kappa=embeddings["kappa"],
61
- inter_aperture_scale=self.inter_aperture_scale,
62
- intra_aperture_scale=self.intra_aperture_scale,
63
- )
64
- total = contrastive + self.entail_weight * entailment
65
- return {
66
- "loss": total,
67
- "contrastive_loss": contrastive,
68
- "entailment_loss": entailment,
69
- "part_count": part_count,
70
- }
71
-
72
-
73
- class UNCHAObjective(nn.Module):
74
- def __init__(
75
- self,
76
- entail_weight: float,
77
- inter_aperture_scale: float,
78
- intra_aperture_scale: float,
79
- piecewise_factor: float = 0.1,
80
- calibration_alpha: float = 10.0,
81
- stop_grad_calibration: bool = True,
82
- entailment_geometry: str = "lorentz",
83
- aggregate_weight: float = 0.0,
84
- entailment_loss: str = "piecewise",
85
- argent_beta: float = 1.0,
86
- argent_norm_weight: float = 0.0,
87
- argent_aux_weight: float = 0.5,
88
- argent_aggregation: str = "uncha",
89
- part_weight_power: float = 0.0,
90
- product_metric: str = "l1",
91
- contrastive_loss: str = "ce",
92
- sigmoid_negative_weight: float = 1.0,
93
- part_quality_mode: str = "none",
94
- part_quality_topk: int = 5,
95
- part_quality_temperature: float = 4.0,
96
- contrastive_global_weight: float = 1.0,
97
- contrastive_local_weight: float = 1.0,
98
- contrastive_global_local_weight: float = 1.0,
99
- beta_cal_beta: float = 0.0,
100
- beta_cal_variant: str = "ce",
101
- beta_cal_weight: float = 0.0,
102
- himo_component_weight: float = 0.0,
103
- global_local_mode: str = "repeat",
104
- global_local_metric: str = "distance",
105
- global_local_angle_aux_weight: float = 0.0,
106
- global_local_angle_aux_mode: str = "contrastive",
107
- global_local_angle_aux_scale: float = 5.5,
108
- global_local_angle_aux_aperture_scale: float = 1.0,
109
- radius_order_weight: float = 0.0,
110
- radius_order_margin: float = 0.0,
111
- gramian_align_weight: float = 0.0,
112
- ) -> None:
113
- super().__init__()
114
- if entailment_loss not in {
115
- "piecewise",
116
- "argent",
117
- "piecewise_argent",
118
- "hier_beta_argent",
119
- "hier_beta_sourcepart_argent",
120
- }:
121
- raise ValueError(
122
- f"Unsupported UNCHA entailment loss {entailment_loss!r}; "
123
- "expected 'piecewise', 'argent', 'piecewise_argent', 'hier_beta_argent', "
124
- "or 'hier_beta_sourcepart_argent'"
125
- )
126
- if contrastive_loss not in {"ce", "sigmoid", "siglip", "siglip_metric"}:
127
- raise ValueError("contrastive_loss must be 'ce', 'sigmoid', 'siglip', or 'siglip_metric'")
128
- if beta_cal_variant not in {"ce", "bce"}:
129
- raise ValueError("beta_cal_variant must be 'ce' or 'bce'")
130
- if argent_aggregation not in {"uncha", "equal"}:
131
- raise ValueError("argent_aggregation must be 'uncha' or 'equal'")
132
- if part_quality_mode not in {"none", "soft", "topk"}:
133
- raise ValueError("part_quality_mode must be 'none', 'soft', or 'topk'")
134
- if global_local_mode not in {"repeat", "inbatch"}:
135
- raise ValueError("global_local_mode must be 'repeat' or 'inbatch'")
136
- if global_local_metric not in {"distance", "angle"}:
137
- raise ValueError("global_local_metric must be 'distance' or 'angle'")
138
- if global_local_angle_aux_mode not in {"contrastive", "positive_hinge"}:
139
- raise ValueError("global_local_angle_aux_mode must be 'contrastive' or 'positive_hinge'")
140
- if global_local_angle_aux_weight < 0.0:
141
- raise ValueError("global_local_angle_aux_weight must be non-negative")
142
- if global_local_angle_aux_scale <= 0.0:
143
- raise ValueError("global_local_angle_aux_scale must be positive")
144
- if global_local_angle_aux_aperture_scale <= 0.0:
145
- raise ValueError("global_local_angle_aux_aperture_scale must be positive")
146
- if part_quality_topk <= 0:
147
- raise ValueError("part_quality_topk must be positive")
148
- self.entail_weight = entail_weight
149
- self.inter_aperture_scale = inter_aperture_scale
150
- self.intra_aperture_scale = intra_aperture_scale
151
- self.piecewise_factor = piecewise_factor
152
- self.calibration_alpha = calibration_alpha
153
- self.stop_grad_calibration = stop_grad_calibration
154
- self.entailment_geometry = entailment_geometry
155
- self.aggregate_weight = aggregate_weight
156
- self.entailment_loss = entailment_loss
157
- self.argent_beta = argent_beta
158
- self.argent_norm_weight = argent_norm_weight
159
- self.argent_aux_weight = argent_aux_weight
160
- self.argent_aggregation = argent_aggregation
161
- self.part_weight_power = part_weight_power
162
- self.product_metric = product_metric
163
- self.contrastive_loss = contrastive_loss
164
- self.sigmoid_negative_weight = sigmoid_negative_weight
165
- self.part_quality_mode = part_quality_mode
166
- self.part_quality_topk = part_quality_topk
167
- self.part_quality_temperature = part_quality_temperature
168
- self.contrastive_global_weight = float(contrastive_global_weight)
169
- self.contrastive_local_weight = float(contrastive_local_weight)
170
- self.contrastive_global_local_weight = float(contrastive_global_local_weight)
171
- self.beta_cal_beta = float(beta_cal_beta)
172
- self.beta_cal_variant = beta_cal_variant
173
- self.beta_cal_weight = float(beta_cal_weight)
174
- self.himo_component_weight = float(himo_component_weight)
175
- self.global_local_mode = global_local_mode
176
- self.global_local_metric = global_local_metric
177
- self.global_local_angle_aux_weight = float(global_local_angle_aux_weight)
178
- self.global_local_angle_aux_mode = global_local_angle_aux_mode
179
- self.global_local_angle_aux_scale = float(global_local_angle_aux_scale)
180
- self.global_local_angle_aux_aperture_scale = float(global_local_angle_aux_aperture_scale)
181
- self.radius_order_weight = float(radius_order_weight)
182
- self.radius_order_margin = float(radius_order_margin)
183
- self.gramian_align_weight = float(gramian_align_weight)
184
-
185
- def forward(self, embeddings: Mapping[str, Tensor], logit_scales: Mapping[str, Tensor]) -> dict[str, Tensor]:
186
- part_owner = embeddings["part_owner"].long()
187
- part_count = part_owner.new_tensor(part_owner.numel())
188
- part_image_flat = embeddings["part_image_feats"]
189
- part_text_flat = embeddings["part_text_feats"]
190
- image_feats = embeddings["image_feats"]
191
- text_feats = embeddings["text_feats"]
192
-
193
- if part_owner.numel() == 0:
194
- image_for_parts = image_feats.new_zeros((0, image_feats.size(-1)))
195
- text_for_parts = text_feats.new_zeros((0, text_feats.size(-1)))
196
- else:
197
- image_for_parts = image_feats[part_owner]
198
- text_for_parts = text_feats[part_owner]
199
- count_part_weights = _part_weights(part_owner, image_feats.size(0), self.part_weight_power)
200
- quality_part_weights, quality_scores, quality_keep = part_quality_weights(
201
- image_for_parts=image_for_parts,
202
- text_for_parts=text_for_parts,
203
- part_image_flat=part_image_flat,
204
- part_text_flat=part_text_flat,
205
- part_owner=part_owner,
206
- batch_size=image_feats.size(0),
207
- kappa=embeddings["kappa"],
208
- mode=self.part_quality_mode,
209
- topk=self.part_quality_topk,
210
- temperature=self.part_quality_temperature,
211
- product_metric=self.product_metric,
212
- )
213
- part_weights = _combine_part_weights(count_part_weights, quality_part_weights)
214
-
215
- needs_repeated_global_local = self.global_local_mode == "repeat" and self.contrastive_global_local_weight != 0.0
216
- part_feature_tensors = [part_image_flat, part_text_flat]
217
- if needs_repeated_global_local:
218
- part_feature_tensors.extend([image_for_parts, text_for_parts])
219
- gathered_part_features, part_counts = gather_variable_many_with_grad(part_feature_tensors)
220
- all_part_image_feats = gathered_part_features[0]
221
- all_part_text_feats = gathered_part_features[1]
222
- all_image_for_parts = gathered_part_features[2] if needs_repeated_global_local else None
223
- all_text_for_parts = gathered_part_features[3] if needs_repeated_global_local else None
224
- image_euc_feats = embeddings.get("image_euc_feats")
225
- text_euc_feats = embeddings.get("text_euc_feats")
226
- part_image_euc_flat = embeddings.get("part_image_euc_feats")
227
- part_text_euc_flat = embeddings.get("part_text_euc_feats")
228
- image_for_parts_euc = None
229
- text_for_parts_euc = None
230
- all_part_image_euc_feats = None
231
- all_part_text_euc_feats = None
232
- all_image_for_parts_euc = None
233
- all_text_for_parts_euc = None
234
- if (
235
- image_euc_feats is not None
236
- and text_euc_feats is not None
237
- and part_owner.numel() > 0
238
- and needs_repeated_global_local
239
- ):
240
- image_for_parts_euc = image_euc_feats[part_owner]
241
- text_for_parts_euc = text_euc_feats[part_owner]
242
- if part_image_euc_flat is not None and part_text_euc_flat is not None:
243
- euc_feature_tensors = [part_image_euc_flat, part_text_euc_flat]
244
- if image_for_parts_euc is not None and text_for_parts_euc is not None:
245
- euc_feature_tensors.extend([image_for_parts_euc, text_for_parts_euc])
246
- gathered_euc_features, _ = gather_variable_many_with_grad(euc_feature_tensors)
247
- all_part_image_euc_feats = gathered_euc_features[0]
248
- all_part_text_euc_feats = gathered_euc_features[1]
249
- if image_for_parts_euc is not None and text_for_parts_euc is not None:
250
- all_image_for_parts_euc = gathered_euc_features[2]
251
- all_text_for_parts_euc = gathered_euc_features[3]
252
- if "targets" not in embeddings:
253
- raise ValueError("UNCHAObjective requires 'targets' to compute group-aware losses")
254
- global_targets = embeddings["targets"]
255
- part_group_ids = global_targets[part_owner] if part_owner.numel() > 0 else part_owner.new_zeros((0,))
256
- all_part_group_ids = None
257
- if self.beta_cal_weight > 0.0 and self.beta_cal_beta > 0.0:
258
- all_part_group_ids, _ = gather_variable_no_grad(part_group_ids)
259
- part_offset = part_counts[: get_rank()].sum() if part_counts.numel() > 1 else part_counts.new_zeros(())
260
- part_targets = torch.arange(part_image_flat.size(0), device=part_image_flat.device) + part_offset
261
-
262
- contrastive = uncha_contrastive_losses(
263
- image_feats=image_feats,
264
- text_feats=text_feats,
265
- part_image_flat=part_image_flat,
266
- part_text_flat=part_text_flat,
267
- image_for_parts=image_for_parts,
268
- text_for_parts=text_for_parts,
269
- image_euc_feats=image_euc_feats,
270
- text_euc_feats=text_euc_feats,
271
- part_image_euc_flat=part_image_euc_flat,
272
- part_text_euc_flat=part_text_euc_flat,
273
- image_for_parts_euc=image_for_parts_euc,
274
- text_for_parts_euc=text_for_parts_euc,
275
- kappa=embeddings["kappa"],
276
- global_logit_scale=logit_scales["global"],
277
- local_logit_scale=logit_scales["local"],
278
- global_local_logit_scale=logit_scales["global_local"],
279
- all_image_feats=embeddings.get("all_image_feats"),
280
- all_text_feats=embeddings.get("all_text_feats"),
281
- all_part_image_feats=all_part_image_feats,
282
- all_part_text_feats=all_part_text_feats,
283
- all_image_for_parts=all_image_for_parts,
284
- all_text_for_parts=all_text_for_parts,
285
- all_image_euc_feats=embeddings.get("all_image_euc_feats"),
286
- all_text_euc_feats=embeddings.get("all_text_euc_feats"),
287
- all_part_image_euc_feats=all_part_image_euc_feats,
288
- all_part_text_euc_feats=all_part_text_euc_feats,
289
- all_image_for_parts_euc=all_image_for_parts_euc,
290
- all_text_for_parts_euc=all_text_for_parts_euc,
291
- global_targets=global_targets,
292
- part_targets=part_targets,
293
- part_weights=part_weights,
294
- product_metric=self.product_metric,
295
- loss_type=self.contrastive_loss,
296
- contrastive_global_weight=self.contrastive_global_weight,
297
- contrastive_local_weight=self.contrastive_local_weight,
298
- contrastive_global_local_weight=self.contrastive_global_local_weight,
299
- beta_cal_beta=self.beta_cal_beta,
300
- beta_cal_variant=self.beta_cal_variant,
301
- beta_cal_weight=self.beta_cal_weight,
302
- part_group_ids=part_group_ids,
303
- all_part_group_ids=all_part_group_ids,
304
- global_logit_bias=logit_scales.get("global_bias"),
305
- local_logit_bias=logit_scales.get("local_bias"),
306
- global_local_logit_bias=logit_scales.get("global_local_bias"),
307
- sigmoid_negative_weight=self.sigmoid_negative_weight,
308
- global_local_mode=self.global_local_mode,
309
- global_local_metric=self.global_local_metric,
310
- global_local_angle_aux_weight=self.global_local_angle_aux_weight,
311
- global_local_angle_aux_mode=self.global_local_angle_aux_mode,
312
- global_local_angle_aux_scale=self.global_local_angle_aux_scale,
313
- global_local_angle_aux_aperture_scale=self.global_local_angle_aux_aperture_scale,
314
- )
315
- himo_component_loss = image_feats.new_zeros(())
316
- if self.himo_component_weight > 0.0 and embeddings.get("himo_text_feats") is not None:
317
- himo_text_feats = embeddings["himo_text_feats"]
318
- all_himo_text_feats = embeddings.get("all_himo_text_feats")
319
- if all_himo_text_feats is None:
320
- raise ValueError("himo_text_feats requires all_himo_text_feats for distributed contrastive loss")
321
- scale = logit_scales["global"].exp().clamp(max=100.0)
322
- logits_i_t = -metric_pairwise_dist(image_feats, all_himo_text_feats, embeddings["kappa"], product_metric=self.product_metric) * scale
323
- logits_t_i = -metric_pairwise_dist(himo_text_feats, embeddings["all_image_feats"], embeddings["kappa"], product_metric=self.product_metric) * scale
324
- himo_component_loss = 0.5 * (contrastive_ce(logits_i_t, global_targets) + contrastive_ce(logits_t_i, global_targets))
325
- if self.entailment_loss == "argent":
326
- entailment = uncha_argent_entailment_losses(
327
- image_feats=image_feats,
328
- text_feats=text_feats,
329
- part_image_flat=part_image_flat,
330
- part_text_flat=part_text_flat,
331
- image_for_parts=image_for_parts,
332
- text_for_parts=text_for_parts,
333
- kappa=embeddings["kappa"],
334
- beta=self.argent_beta,
335
- part_weights=part_weights,
336
- product_metric=self.product_metric,
337
- aggregation=self.argent_aggregation,
338
- )
339
- elif self.entailment_loss in {"hier_beta_argent", "hier_beta_sourcepart_argent"}:
340
- required = (
341
- "beta_query_image_feats",
342
- "beta_query_text_feats",
343
- "beta_query_owner",
344
- "beta_query_parent",
345
- "beta_query_weight",
346
- )
347
- if self.entailment_loss == "hier_beta_sourcepart_argent":
348
- required = (*required, "beta_query_source_part")
349
- missing = [key for key in required if embeddings.get(key) is None]
350
- if missing:
351
- raise ValueError(f"{self.entailment_loss} requires beta query embeddings: missing {missing}")
352
- entailment = hierarchical_beta_argent_entailment_losses(
353
- image_feats=image_feats,
354
- text_feats=text_feats,
355
- part_image_flat=part_image_flat,
356
- part_text_flat=part_text_flat,
357
- image_for_parts=image_for_parts,
358
- text_for_parts=text_for_parts,
359
- beta_query_image_feats=embeddings["beta_query_image_feats"],
360
- beta_query_text_feats=embeddings["beta_query_text_feats"],
361
- beta_query_owner=embeddings["beta_query_owner"],
362
- beta_query_parent=embeddings["beta_query_parent"],
363
- beta_query_weight=embeddings["beta_query_weight"],
364
- beta_query_source_part=embeddings.get("beta_query_source_part")
365
- if self.entailment_loss == "hier_beta_sourcepart_argent"
366
- else None,
367
- kappa=embeddings["kappa"],
368
- beta=self.argent_beta,
369
- part_weights=part_weights,
370
- product_metric=self.product_metric,
371
- aggregation=self.argent_aggregation,
372
- )
373
- else:
374
- piecewise_entailment = uncha_entailment_losses(
375
- image_feats=image_feats,
376
- text_feats=text_feats,
377
- part_image_flat=part_image_flat,
378
- part_text_flat=part_text_flat,
379
- image_for_parts=image_for_parts,
380
- text_for_parts=text_for_parts,
381
- kappa=embeddings["kappa"],
382
- inter_aperture_scale=self.inter_aperture_scale,
383
- intra_aperture_scale=self.intra_aperture_scale,
384
- piecewise_factor=self.piecewise_factor,
385
- calibration_alpha=self.calibration_alpha,
386
- stop_grad_calibration=self.stop_grad_calibration,
387
- geometry=self.entailment_geometry,
388
- part_weights=part_weights,
389
- )
390
- if self.entailment_loss == "piecewise_argent":
391
- argent_entailment = uncha_argent_entailment_losses(
392
- image_feats=image_feats,
393
- text_feats=text_feats,
394
- part_image_flat=part_image_flat,
395
- part_text_flat=part_text_flat,
396
- image_for_parts=image_for_parts,
397
- text_for_parts=text_for_parts,
398
- kappa=embeddings["kappa"],
399
- beta=self.argent_beta,
400
- part_weights=part_weights,
401
- product_metric=self.product_metric,
402
- aggregation=self.argent_aggregation,
403
- )
404
- entailment = {
405
- **piecewise_entailment,
406
- "entailment_loss": piecewise_entailment["entailment_loss"]
407
- + self.argent_aux_weight * argent_entailment["entailment_loss"],
408
- "piecewise_entailment_loss": piecewise_entailment["entailment_loss"],
409
- "argent_entailment_loss": argent_entailment["entailment_loss"],
410
- "norm_regularization_loss": argent_entailment["norm_regularization_loss"],
411
- }
412
- else:
413
- entailment = piecewise_entailment
414
- aggregate = aggregate_part_consistency_loss(
415
- image_feats=image_feats,
416
- text_feats=text_feats,
417
- part_image_flat=part_image_flat,
418
- part_text_flat=part_text_flat,
419
- part_owner=part_owner,
420
- part_weights=part_weights,
421
- )
422
- radius_order = image_feats.new_zeros(())
423
- if self.radius_order_weight > 0.0:
424
- radius_order = (
425
- radius_order_hinge(image_feats, text_feats, embeddings["kappa"], self.radius_order_margin)
426
- + radius_order_hinge(part_image_flat, part_text_flat, embeddings["kappa"], self.radius_order_margin, part_weights)
427
- + radius_order_hinge(image_for_parts, part_image_flat, embeddings["kappa"], self.radius_order_margin, part_weights)
428
- + radius_order_hinge(text_for_parts, part_text_flat, embeddings["kappa"], self.radius_order_margin, part_weights)
429
- )
430
- gramian_align = image_feats.new_zeros(())
431
- if self.gramian_align_weight > 0.0 and part_owner.numel() > 0:
432
- def _tangent_flat(x: Tensor) -> Tensor:
433
- tangent = log_map0(x, embeddings["kappa"])
434
- return tangent.reshape(tangent.size(0), -1) if tangent.dim() == 3 else tangent
435
-
436
- gramian_vectors = torch.stack(
437
- [
438
- _tangent_flat(image_for_parts),
439
- _tangent_flat(text_for_parts),
440
- _tangent_flat(part_image_flat),
441
- _tangent_flat(part_text_flat),
442
- ],
443
- dim=1,
444
- )
445
- gramian_align = gramian_volume_loss(gramian_vectors, part_weights)
446
- entail_weight_scale = embeddings.get("entail_weight_scale", image_feats.new_ones(()))
447
- total = (
448
- contrastive["contrastive_loss"]
449
- + self.himo_component_weight * himo_component_loss
450
- + self.entail_weight * entail_weight_scale * entailment["entailment_loss"]
451
- + self.aggregate_weight * aggregate
452
- + self.radius_order_weight * radius_order
453
- + self.gramian_align_weight * gramian_align
454
- + self.argent_norm_weight * entailment.get(
455
- "norm_regularization_loss",
456
- image_feats.new_zeros(()),
457
- )
458
- )
459
- return {
460
- "loss": total,
461
- **contrastive,
462
- "himo_component_contrastive_loss": himo_component_loss,
463
- **entailment,
464
- "aggregate_consistency_loss": aggregate,
465
- "radius_order_loss": radius_order,
466
- "gramian_align_loss": gramian_align,
467
- "part_count": part_count,
468
- "entail_weight_scale": entail_weight_scale.detach(),
469
- "part_quality_mean": (
470
- image_feats.new_zeros(()) if quality_scores.numel() == 0 else quality_scores.mean().detach()
471
- ),
472
- "part_quality_keep_fraction": (
473
- image_feats.new_zeros(()) if quality_keep.numel() == 0 else quality_keep.mean().detach()
474
- ),
475
- }
476
-
477
-
478
- def build_objective(
479
- objective: str,
480
- entail_weight: float,
481
- inter_aperture_scale: float,
482
- intra_aperture_scale: float,
483
- uncha_piecewise_factor: float = 0.1,
484
- uncha_calibration_alpha: float = 10.0,
485
- uncha_stop_grad_calibration: bool = True,
486
- uncha_entailment_geometry: str = "lorentz",
487
- uncha_aggregate_weight: float = 0.0,
488
- uncha_entailment_loss: str = "piecewise",
489
- uncha_argent_beta: float = 1.0,
490
- uncha_argent_norm_weight: float = 0.0,
491
- uncha_argent_aux_weight: float = 0.5,
492
- uncha_argent_aggregation: str = "uncha",
493
- uncha_part_weight_power: float = 0.0,
494
- uncha_contrastive_loss: str = "ce",
495
- uncha_sigmoid_negative_weight: float = 1.0,
496
- uncha_part_quality_mode: str = "none",
497
- uncha_part_quality_topk: int = 5,
498
- uncha_part_quality_temperature: float = 4.0,
499
- uncha_contrastive_global_weight: float = 1.0,
500
- uncha_contrastive_local_weight: float = 1.0,
501
- uncha_contrastive_global_local_weight: float = 1.0,
502
- uncha_beta_cal_beta: float = 0.0,
503
- uncha_beta_cal_variant: str = "ce",
504
- uncha_beta_cal_weight: float = 0.0,
505
- uncha_himo_component_weight: float = 0.0,
506
- uncha_global_local_mode: str = "repeat",
507
- uncha_global_local_metric: str = "distance",
508
- uncha_global_local_angle_aux_weight: float = 0.0,
509
- uncha_global_local_angle_aux_mode: str = "contrastive",
510
- uncha_global_local_angle_aux_scale: float = 5.5,
511
- uncha_global_local_angle_aux_aperture_scale: float = 1.0,
512
- uncha_radius_order_weight: float = 0.0,
513
- uncha_radius_order_margin: float = 0.0,
514
- uncha_gramian_align_weight: float = 0.0,
515
- product_metric: str = "l1",
516
- ) -> nn.Module:
517
- if objective == "hycoclip":
518
- return HyCoCLIPObjective(
519
- entail_weight=entail_weight,
520
- inter_aperture_scale=inter_aperture_scale,
521
- intra_aperture_scale=intra_aperture_scale,
522
- product_metric=product_metric,
523
- )
524
- if objective == "uncha":
525
- return UNCHAObjective(
526
- entail_weight=entail_weight,
527
- inter_aperture_scale=inter_aperture_scale,
528
- intra_aperture_scale=intra_aperture_scale,
529
- piecewise_factor=uncha_piecewise_factor,
530
- calibration_alpha=uncha_calibration_alpha,
531
- stop_grad_calibration=uncha_stop_grad_calibration,
532
- entailment_geometry=uncha_entailment_geometry,
533
- aggregate_weight=uncha_aggregate_weight,
534
- entailment_loss=uncha_entailment_loss,
535
- argent_beta=uncha_argent_beta,
536
- argent_norm_weight=uncha_argent_norm_weight,
537
- argent_aux_weight=uncha_argent_aux_weight,
538
- argent_aggregation=uncha_argent_aggregation,
539
- part_weight_power=uncha_part_weight_power,
540
- product_metric=product_metric,
541
- contrastive_loss=uncha_contrastive_loss,
542
- sigmoid_negative_weight=uncha_sigmoid_negative_weight,
543
- part_quality_mode=uncha_part_quality_mode,
544
- part_quality_topk=uncha_part_quality_topk,
545
- part_quality_temperature=uncha_part_quality_temperature,
546
- contrastive_global_weight=uncha_contrastive_global_weight,
547
- contrastive_local_weight=uncha_contrastive_local_weight,
548
- contrastive_global_local_weight=uncha_contrastive_global_local_weight,
549
- beta_cal_beta=uncha_beta_cal_beta,
550
- beta_cal_variant=uncha_beta_cal_variant,
551
- beta_cal_weight=uncha_beta_cal_weight,
552
- himo_component_weight=uncha_himo_component_weight,
553
- global_local_mode=uncha_global_local_mode,
554
- global_local_metric=uncha_global_local_metric,
555
- global_local_angle_aux_weight=uncha_global_local_angle_aux_weight,
556
- global_local_angle_aux_mode=uncha_global_local_angle_aux_mode,
557
- global_local_angle_aux_scale=uncha_global_local_angle_aux_scale,
558
- global_local_angle_aux_aperture_scale=uncha_global_local_angle_aux_aperture_scale,
559
- radius_order_weight=uncha_radius_order_weight,
560
- radius_order_margin=uncha_radius_order_margin,
561
- gramian_align_weight=uncha_gramian_align_weight,
562
- )
563
- raise ValueError(f"Unsupported objective {objective!r}; expected 'hycoclip' or 'uncha'")
564
-
565
-
566
- def _part_weights(part_owner: Tensor, batch_size: int, power: float) -> Tensor | None:
567
- if power <= 0.0 or part_owner.numel() == 0:
568
- return None
569
- counts = torch.bincount(part_owner, minlength=batch_size).to(dtype=torch.float32, device=part_owner.device)
570
- weights = counts[part_owner].clamp_min(1.0).pow(-power)
571
- return weights / weights.mean().clamp_min(torch.finfo(weights.dtype).eps)
572
-
573
-
574
- def _combine_part_weights(count_weights: Tensor | None, quality_weights: Tensor | None) -> Tensor | None:
575
- if count_weights is None:
576
- return quality_weights
577
- if quality_weights is None:
578
- return count_weights
579
- weights = count_weights * quality_weights
580
- return weights / weights.mean().clamp_min(torch.finfo(weights.dtype).eps)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/models/tren.py DELETED
@@ -1,255 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import math
4
-
5
- import torch
6
- import torch.nn.functional as F
7
- from torch import Tensor, nn
8
-
9
-
10
- class FourierPositionEncoding2D(nn.Module):
11
- def __init__(self, dim: int, scale: float = 1.0) -> None:
12
- super().__init__()
13
- if dim <= 0 or dim % 2 != 0:
14
- raise ValueError("FourierPositionEncoding2D dim must be a positive even integer")
15
- if scale <= 0.0:
16
- raise ValueError("FourierPositionEncoding2D scale must be positive")
17
- generator = torch.Generator()
18
- generator.manual_seed(42)
19
- self.register_buffer("gaussian_matrix", scale * torch.randn((2, dim // 2), generator=generator))
20
-
21
- def forward(self, coords: Tensor) -> Tensor:
22
- projected = (2.0 * coords.float() - 1.0) @ self.gaussian_matrix
23
- projected = 2.0 * math.pi * projected
24
- return torch.cat([torch.sin(projected), torch.cos(projected)], dim=-1)
25
-
26
-
27
- class _MLPBlock(nn.Module):
28
- def __init__(self, dim: int, hidden_dim: int, dropout: float) -> None:
29
- super().__init__()
30
- self.net = nn.Sequential(
31
- nn.Linear(dim, hidden_dim),
32
- nn.GELU(),
33
- nn.Dropout(dropout),
34
- nn.Linear(hidden_dim, dim),
35
- )
36
-
37
- def forward(self, x: Tensor) -> Tensor:
38
- return self.net(x)
39
-
40
-
41
- class _AttentionLayer(nn.Module):
42
- def __init__(
43
- self,
44
- q_dim: int,
45
- kv_dim: int,
46
- hidden_dim: int,
47
- *,
48
- num_heads: int,
49
- dropout: float,
50
- use_bias: bool = False,
51
- use_v_proj: bool = True,
52
- use_out_proj: bool = True,
53
- ) -> None:
54
- super().__init__()
55
- if hidden_dim % num_heads != 0:
56
- raise ValueError("hidden_dim must be divisible by num_heads")
57
- if not use_v_proj and kv_dim != hidden_dim:
58
- raise ValueError("kv_dim must equal hidden_dim when value projection is disabled")
59
- self.hidden_dim = hidden_dim
60
- self.num_heads = num_heads
61
- self.head_dim = hidden_dim // num_heads
62
- self.q_proj = nn.Linear(q_dim, hidden_dim, bias=use_bias)
63
- self.k_proj = nn.Linear(kv_dim, hidden_dim, bias=use_bias)
64
- self.v_proj = nn.Linear(kv_dim, hidden_dim, bias=use_bias) if use_v_proj else nn.Identity()
65
- self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=use_bias) if use_out_proj else nn.Identity()
66
- self.q_norm = nn.LayerNorm(self.head_dim)
67
- self.k_norm = nn.LayerNorm(self.head_dim)
68
- self.dropout = nn.Dropout(dropout)
69
- self.scale = self.head_dim**-0.5
70
-
71
- nn.init.kaiming_normal_(self.q_proj.weight, mode="fan_in", nonlinearity="linear")
72
- nn.init.kaiming_normal_(self.k_proj.weight, mode="fan_in", nonlinearity="linear")
73
- if isinstance(self.v_proj, nn.Linear):
74
- nn.init.kaiming_normal_(self.v_proj.weight, mode="fan_in", nonlinearity="linear")
75
- if isinstance(self.out_proj, nn.Linear):
76
- nn.init.kaiming_normal_(self.out_proj.weight, mode="fan_in", nonlinearity="linear")
77
-
78
- def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
79
- batch_size, q_len, _ = q.shape
80
- _, kv_len, _ = k.shape
81
- query = self.q_proj(q).view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
82
- key = self.k_proj(k).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
83
- value = self.v_proj(v).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
84
-
85
- query = self.q_norm(query)
86
- key = self.k_norm(key)
87
- attn_scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
88
- attn_weights = self.dropout(F.softmax(attn_scores, dim=-1))
89
- out = torch.matmul(attn_weights, value)
90
- out = out.transpose(1, 2).contiguous().view(batch_size, q_len, self.hidden_dim)
91
- return self.out_proj(out), attn_weights
92
-
93
-
94
- class _CrossAttentionBlock(nn.Module):
95
- def __init__(self, dim: int, *, num_heads: int, dropout: float) -> None:
96
- super().__init__()
97
- self.query_norm = nn.LayerNorm(dim)
98
- self.cross_attn = _AttentionLayer(dim, dim, dim, num_heads=num_heads, dropout=dropout)
99
- self.dropout = nn.Dropout(dropout)
100
- self.mlp_norm = nn.LayerNorm(dim)
101
- self.mlp = _MLPBlock(dim, 2 * dim, dropout)
102
- self.out_norm = nn.LayerNorm(dim)
103
-
104
- def forward(self, query: Tensor, context: Tensor) -> Tensor:
105
- x, _ = self.cross_attn(self.query_norm(query), context, context)
106
- x = query + self.dropout(x)
107
- return self.out_norm(x + self.mlp(self.mlp_norm(x)))
108
-
109
-
110
- class TRENRegionEncoder(nn.Module):
111
- """T-REN-style point-prompted region token encoder.
112
-
113
- The module follows the public T-REN architecture: learned k-per-prompt
114
- query tokens, Fourier 2D prompt/patch position encodings, alternating
115
- cross-attention and per-prompt self-attention, then final single-head
116
- attention that pools unprojected patch tokens into region tokens.
117
- """
118
-
119
- def __init__(
120
- self,
121
- vision_dim: int,
122
- text_dim: int,
123
- *,
124
- hidden_dim: int | None = None,
125
- num_region_tokens: int = 3,
126
- num_decoder_layers: int = 2,
127
- num_attention_heads: int = 8,
128
- prompt_grid_size: int = 7,
129
- dropout: float = 0.1,
130
- ) -> None:
131
- super().__init__()
132
- if num_region_tokens <= 0:
133
- raise ValueError("num_region_tokens must be positive")
134
- if num_decoder_layers <= 0:
135
- raise ValueError("num_decoder_layers must be positive")
136
- if prompt_grid_size <= 0:
137
- raise ValueError("prompt_grid_size must be positive")
138
- hidden_dim = int(hidden_dim or vision_dim)
139
- if hidden_dim != vision_dim:
140
- raise ValueError("TRENRegionEncoder currently requires hidden_dim == vision_dim")
141
- if hidden_dim % 2 != 0:
142
- raise ValueError("TRENRegionEncoder hidden_dim must be even for Fourier features")
143
- if hidden_dim % num_attention_heads != 0:
144
- raise ValueError("TRENRegionEncoder hidden_dim must be divisible by num_attention_heads")
145
-
146
- self.vision_dim = vision_dim
147
- self.text_dim = text_dim
148
- self.hidden_dim = hidden_dim
149
- self.num_region_tokens = num_region_tokens
150
- self.prompt_grid_size = prompt_grid_size
151
- self.position_encoder = FourierPositionEncoding2D(hidden_dim)
152
- self.region_token_embeddings = nn.Embedding(num_region_tokens, hidden_dim)
153
- nn.init.normal_(self.region_token_embeddings.weight, std=0.02)
154
- self.region_attention_layers = nn.ModuleList(
155
- [_CrossAttentionBlock(hidden_dim, num_heads=num_attention_heads, dropout=dropout) for _ in range(num_decoder_layers)]
156
- )
157
- self.region_attention_norms = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_decoder_layers)])
158
- self.prompt_attention_layers = nn.ModuleList(
159
- [
160
- _AttentionLayer(
161
- hidden_dim,
162
- hidden_dim,
163
- hidden_dim,
164
- num_heads=num_attention_heads,
165
- dropout=dropout,
166
- )
167
- for _ in range(num_decoder_layers)
168
- ]
169
- )
170
- self.prompt_attention_norms = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_decoder_layers)])
171
- self.token_prediction_head = _AttentionLayer(
172
- hidden_dim,
173
- hidden_dim,
174
- hidden_dim,
175
- num_heads=1,
176
- dropout=0.0,
177
- use_v_proj=False,
178
- use_out_proj=False,
179
- )
180
- self.text_alignment_block = nn.Sequential(
181
- nn.Linear(hidden_dim, 2 * hidden_dim),
182
- nn.GELU(),
183
- nn.Dropout(dropout),
184
- nn.Linear(2 * hidden_dim, text_dim),
185
- )
186
-
187
- def forward(self, image_tokens: Tensor) -> dict[str, Tensor]:
188
- patch_tokens, patch_grid = _patch_tokens_and_grid(image_tokens)
189
- batch_size, patch_count, _ = patch_tokens.shape
190
- patch_coords = _grid_coords(patch_grid, patch_grid, patch_tokens.device)
191
- prompt_coords = _grid_coords(self.prompt_grid_size, self.prompt_grid_size, patch_tokens.device)
192
- prompt_count = prompt_coords.size(0)
193
-
194
- feature_pos = self.position_encoder(patch_coords).to(dtype=patch_tokens.dtype)
195
- prompt_pos = self.position_encoder(prompt_coords).to(dtype=patch_tokens.dtype)
196
- kv = patch_tokens + feature_pos.unsqueeze(0)
197
- prompt_pos = prompt_pos.view(1, prompt_count, 1, self.hidden_dim)
198
-
199
- q = self.region_token_embeddings.weight.to(dtype=patch_tokens.dtype)
200
- q = q.view(1, 1, self.num_region_tokens, self.hidden_dim).expand(
201
- batch_size,
202
- prompt_count,
203
- self.num_region_tokens,
204
- self.hidden_dim,
205
- )
206
- for region_layer, region_norm, prompt_layer, prompt_norm in zip(
207
- self.region_attention_layers,
208
- self.region_attention_norms,
209
- self.prompt_attention_layers,
210
- self.prompt_attention_norms,
211
- strict=True,
212
- ):
213
- q = q + prompt_pos
214
- q = q.reshape(batch_size, prompt_count * self.num_region_tokens, self.hidden_dim)
215
- q = region_layer(q, kv)
216
- q = q.reshape(batch_size, prompt_count, self.num_region_tokens, self.hidden_dim)
217
- q = region_norm(q)
218
- q = q.reshape(batch_size * prompt_count, self.num_region_tokens, self.hidden_dim)
219
- q, _ = prompt_layer(q, q, q)
220
- q = prompt_norm(q)
221
- q = q.reshape(batch_size, prompt_count, self.num_region_tokens, self.hidden_dim)
222
-
223
- flat_q = q.reshape(batch_size, prompt_count * self.num_region_tokens, self.hidden_dim)
224
- visual_tokens, attn_weights = self.token_prediction_head(flat_q, kv, patch_tokens)
225
- visual_tokens = visual_tokens.reshape(batch_size, prompt_count, self.num_region_tokens, self.hidden_dim)
226
- attn_weights = attn_weights.squeeze(1).reshape(batch_size, prompt_count, self.num_region_tokens, patch_count)
227
- region_masks = attn_weights / attn_weights.amax(dim=-1, keepdim=True).clamp_min(torch.finfo(attn_weights.dtype).eps)
228
- region_masks = region_masks.reshape(batch_size, prompt_count, self.num_region_tokens, patch_grid, patch_grid)
229
- text_aligned_tokens = self.text_alignment_block(visual_tokens)
230
- return {
231
- "visual_tokens": visual_tokens,
232
- "text_aligned_tokens": text_aligned_tokens,
233
- "region_masks": region_masks,
234
- "prompt_coords": prompt_coords,
235
- }
236
-
237
-
238
- def _patch_tokens_and_grid(tokens: Tensor) -> tuple[Tensor, int]:
239
- if tokens.ndim != 3:
240
- raise ValueError("TRENRegionEncoder expects image tokens with shape [batch, tokens, dim]")
241
- token_count = tokens.size(1)
242
- grid = int(math.isqrt(token_count))
243
- if grid * grid == token_count:
244
- return tokens, grid
245
- grid = int(math.isqrt(token_count - 1))
246
- if grid * grid == token_count - 1:
247
- return tokens[:, 1:, :], grid
248
- raise ValueError(f"Cannot infer a square patch grid from {token_count} image tokens")
249
-
250
-
251
- def _grid_coords(height: int, width: int, device: torch.device) -> Tensor:
252
- y = torch.linspace(0.5 / height, 1.0 - 0.5 / height, height, device=device)
253
- x = torch.linspace(0.5 / width, 1.0 - 0.5 / width, width, device=device)
254
- yy, xx = torch.meshgrid(y, x, indexing="ij")
255
- return torch.stack([xx, yy], dim=-1).reshape(-1, 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/training/__init__.py DELETED
@@ -1 +0,0 @@
1
- __all__: list[str] = []
 
 
hyper3_clip/training/checkpointing.py DELETED
@@ -1,91 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from pathlib import Path
4
- import random
5
- from typing import Any
6
-
7
- import numpy as np
8
- import torch
9
- from torch import nn
10
-
11
-
12
- def save_checkpoint(
13
- path: str | Path,
14
- step: int,
15
- model: nn.Module,
16
- optimizer: torch.optim.Optimizer,
17
- scheduler: Any,
18
- scaler: Any,
19
- config: dict,
20
- ) -> None:
21
- checkpoint_path = Path(path)
22
- tmp_path = checkpoint_path.with_name(f"{checkpoint_path.name}.tmp")
23
- checkpoint = {
24
- "step": step,
25
- "model": model.state_dict(),
26
- "optimizer": optimizer.state_dict(),
27
- "scheduler": scheduler.state_dict(),
28
- "scaler": scaler.state_dict(),
29
- "config": config,
30
- "rng": _rng_state(),
31
- }
32
- torch.save(checkpoint, tmp_path)
33
- tmp_path.replace(checkpoint_path)
34
-
35
-
36
- def load_checkpoint(
37
- path: str | Path,
38
- model: nn.Module,
39
- optimizer: torch.optim.Optimizer,
40
- scheduler: Any,
41
- scaler: Any,
42
- device: torch.device,
43
- *,
44
- model_only: bool = False,
45
- strict_model: bool = True,
46
- ) -> int:
47
- checkpoint = torch.load(path, map_location=device, weights_only=False)
48
- model.load_state_dict(checkpoint["model"], strict=strict_model)
49
- if model_only:
50
- return int(checkpoint["step"])
51
- optimizer.load_state_dict(checkpoint["optimizer"])
52
- scheduler.load_state_dict(checkpoint["scheduler"])
53
- scaler.load_state_dict(checkpoint["scaler"])
54
- _set_rng_state(checkpoint["rng"])
55
- return int(checkpoint["step"])
56
-
57
-
58
- def latest_checkpoint(output_dir: str | Path) -> Path | None:
59
- paths = sorted(Path(output_dir).glob("checkpoint_step_*.pt"))
60
- if not paths:
61
- return None
62
- return max(paths, key=_checkpoint_step)
63
-
64
-
65
- def _checkpoint_step(path: Path) -> int:
66
- return int(path.stem.rsplit("_", 1)[1])
67
-
68
-
69
- def _rng_state() -> dict[str, Any]:
70
- state: dict[str, Any] = {
71
- "python": random.getstate(),
72
- "numpy": np.random.get_state(),
73
- "torch": torch.get_rng_state(),
74
- }
75
- if torch.cuda.is_available():
76
- state["cuda"] = torch.cuda.get_rng_state_all()
77
- return state
78
-
79
-
80
- def _set_rng_state(state: dict[str, Any]) -> None:
81
- random.setstate(state["python"])
82
- np.random.set_state(state["numpy"])
83
- torch.set_rng_state(_cpu_byte_tensor(state["torch"]))
84
- if torch.cuda.is_available() and "cuda" in state:
85
- torch.cuda.set_rng_state_all([_cpu_byte_tensor(cuda_state) for cuda_state in state["cuda"]])
86
-
87
-
88
- def _cpu_byte_tensor(value: Any) -> torch.ByteTensor:
89
- if isinstance(value, torch.Tensor):
90
- return value.detach().to(device="cpu", dtype=torch.uint8)
91
- return torch.as_tensor(value, dtype=torch.uint8, device="cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/training/distributed.py DELETED
@@ -1,149 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from collections.abc import Sequence
4
- import os
5
-
6
- import torch
7
- import torch.distributed as dist
8
- from torch.distributed.nn import all_gather as differentiable_all_gather
9
- from torch import Tensor
10
-
11
-
12
- def init_distributed() -> None:
13
- if "RANK" in os.environ and "WORLD_SIZE" in os.environ and not dist.is_initialized():
14
- backend = "nccl" if torch.cuda.is_available() else "gloo"
15
- if torch.cuda.is_available():
16
- torch.cuda.set_device(get_local_rank())
17
- dist.init_process_group(backend=backend)
18
-
19
-
20
- def is_distributed() -> bool:
21
- return dist.is_available() and dist.is_initialized()
22
-
23
-
24
- def barrier() -> None:
25
- if is_distributed():
26
- dist.barrier()
27
-
28
-
29
- def destroy_distributed() -> None:
30
- if is_distributed():
31
- dist.destroy_process_group()
32
-
33
-
34
- def get_rank() -> int:
35
- return dist.get_rank() if is_distributed() else 0
36
-
37
-
38
- def get_world_size() -> int:
39
- return dist.get_world_size() if is_distributed() else 1
40
-
41
-
42
- def get_local_rank() -> int:
43
- return int(os.environ.get("LOCAL_RANK", "0"))
44
-
45
-
46
- def is_main_process() -> bool:
47
- return get_rank() == 0
48
-
49
-
50
- def gather_with_grad(tensor: Tensor) -> Tensor:
51
- world_size = get_world_size()
52
- if world_size == 1:
53
- return tensor
54
- return torch.cat(list(differentiable_all_gather(tensor.contiguous())), dim=0)
55
-
56
-
57
- def gather_variable_with_grad(tensor: Tensor) -> tuple[Tensor, Tensor]:
58
- """Gather tensors with variable first-dimension lengths across ranks."""
59
- count_tensor, max_count, keep = _variable_gather_metadata(tensor)
60
- if get_world_size() == 1:
61
- return tensor, count_tensor
62
- return _gather_variable_from_metadata(tensor, max_count, keep), count_tensor
63
-
64
-
65
- def gather_variable_many_with_grad(tensors: Sequence[Tensor]) -> tuple[list[Tensor], Tensor]:
66
- """Gather same-length variable tensors while sharing count metadata.
67
-
68
- Tensors with matching dtype/rank/trailing shape are packed along the last
69
- dimension so a single differentiable all-gather can serve several feature
70
- tensors with the same variable first dimension.
71
- """
72
- if not tensors:
73
- raise ValueError("gather_variable_many_with_grad requires at least one tensor")
74
- first = tensors[0]
75
- for tensor in tensors:
76
- if tensor.device != first.device:
77
- raise ValueError("all tensors must be on the same device")
78
- if tensor.shape[0] != first.shape[0]:
79
- raise ValueError("all tensors must have the same first dimension")
80
- count_tensor, max_count, keep = _variable_gather_metadata(first)
81
- if get_world_size() == 1:
82
- return list(tensors), count_tensor
83
-
84
- gathered: list[Tensor | None] = [None] * len(tensors)
85
- groups: dict[tuple[torch.dtype, torch.Size, int], list[int]] = {}
86
- for index, tensor in enumerate(tensors):
87
- if tensor.dim() == 0:
88
- raise ValueError("variable gather tensors must have at least one dimension")
89
- key = (tensor.dtype, tensor.shape[1:-1], tensor.dim()) if tensor.dim() > 1 else (tensor.dtype, torch.Size(), 1)
90
- groups.setdefault(key, []).append(index)
91
-
92
- for indices in groups.values():
93
- group_tensors = [tensors[index] for index in indices]
94
- if len(group_tensors) == 1 or group_tensors[0].dim() == 1:
95
- for index, tensor in zip(indices, group_tensors, strict=True):
96
- gathered[index] = _gather_variable_from_metadata(tensor, max_count, keep)
97
- continue
98
- widths = [tensor.shape[-1] for tensor in group_tensors]
99
- packed = torch.cat(group_tensors, dim=-1)
100
- gathered_packed = _gather_variable_from_metadata(packed, max_count, keep)
101
- for index, chunk in zip(indices, gathered_packed.split(widths, dim=-1), strict=True):
102
- gathered[index] = chunk
103
-
104
- if any(tensor is None for tensor in gathered):
105
- raise RuntimeError("internal error while gathering variable tensors")
106
- return [tensor for tensor in gathered if tensor is not None], count_tensor
107
-
108
-
109
- def gather_variable_no_grad(tensor: Tensor) -> tuple[Tensor, Tensor]:
110
- """Gather variable-length tensors that do not require autograd."""
111
- count_tensor, max_count, keep = _variable_gather_metadata(tensor)
112
- if get_world_size() == 1:
113
- return tensor, count_tensor
114
- padded = tensor.new_zeros((max_count, *tensor.shape[1:]))
115
- padded[: tensor.shape[0]] = tensor
116
- gathered = [torch.zeros_like(padded) for _ in range(get_world_size())]
117
- dist.all_gather(gathered, padded.contiguous())
118
- return torch.cat(gathered, dim=0)[keep], count_tensor
119
-
120
-
121
- def _variable_gather_metadata(tensor: Tensor) -> tuple[Tensor, int, Tensor]:
122
- world_size = get_world_size()
123
- local_count = torch.tensor([tensor.shape[0]], device=tensor.device, dtype=torch.long)
124
- if world_size == 1:
125
- keep = torch.ones(tensor.shape[0], device=tensor.device, dtype=torch.bool)
126
- return local_count, tensor.shape[0], keep
127
-
128
- counts = [torch.zeros_like(local_count) for _ in range(world_size)]
129
- dist.all_gather(counts, local_count)
130
- count_tensor = torch.cat(counts)
131
- max_count = int(count_tensor.max().item())
132
- keep = torch.zeros(world_size * max_count, device=tensor.device, dtype=torch.bool)
133
- for rank, count in enumerate(count_tensor.tolist()):
134
- start = rank * max_count
135
- keep[start : start + count] = True
136
- return count_tensor, max_count, keep
137
-
138
-
139
- def _gather_variable_from_metadata(tensor: Tensor, max_count: int, keep: Tensor) -> Tensor:
140
- padded_shape = (max_count, *tensor.shape[1:])
141
- padded = tensor.new_zeros(padded_shape)
142
- padded[: tensor.shape[0]] = tensor
143
-
144
- gathered = torch.cat(list(differentiable_all_gather(padded.contiguous())), dim=0)
145
- return gathered[keep]
146
-
147
-
148
- def local_target_indices(batch_size: int, device: torch.device) -> Tensor:
149
- return torch.arange(batch_size, device=device) + batch_size * get_rank()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/training/engine.py DELETED
@@ -1,442 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from datetime import datetime, timezone
4
- import json
5
- import os
6
- from pathlib import Path
7
- import time
8
-
9
- import torch
10
- from torch import nn
11
- from torch.optim import AdamW, Optimizer
12
- from torch.nn.parallel import DistributedDataParallel
13
- from torch.utils.data import DataLoader, DistributedSampler, IterableDataset
14
- from torch.amp import GradScaler
15
-
16
- from hyper3_clip.data import (
17
- GroundedManifestDataset,
18
- MixedGroundedIterableDataset,
19
- ProcessedGritDataset,
20
- collate_grounded,
21
- )
22
- from hyper3_clip.models.hyper3_clip import Hyper3CLIP
23
- from hyper3_clip.training.checkpointing import latest_checkpoint, load_checkpoint, save_checkpoint
24
- from hyper3_clip.training.distributed import (
25
- barrier,
26
- destroy_distributed,
27
- get_local_rank,
28
- get_rank,
29
- get_world_size,
30
- init_distributed,
31
- is_main_process,
32
- )
33
- from hyper3_clip.training.logging import JsonlLogger
34
- from hyper3_clip.utils.io import ensure_dir, save_yaml, set_seed
35
-
36
- try:
37
- from hypercluster.hooks import RunControl
38
- except ImportError: # pragma: no cover - hypercluster is only present in cluster allocations.
39
- RunControl = None
40
-
41
-
42
- class CosineWithWarmup:
43
- def __init__(self, optimizer: torch.optim.Optimizer, warmup_steps: int, total_steps: int, base_lr: float) -> None:
44
- self.optimizer = optimizer
45
- self.warmup_steps = warmup_steps
46
- self.total_steps = total_steps
47
- self.base_lr = base_lr
48
-
49
- def step(self, step_idx: int) -> None:
50
- if step_idx < self.warmup_steps:
51
- lr = self.base_lr * float(step_idx + 1) / float(max(1, self.warmup_steps))
52
- else:
53
- progress = float(step_idx - self.warmup_steps) / float(max(1, self.total_steps - self.warmup_steps))
54
- lr = self.base_lr * 0.5 * (1.0 + torch.cos(torch.tensor(progress * torch.pi)).item())
55
- for group in self.optimizer.param_groups:
56
- group["lr"] = lr
57
-
58
- def state_dict(self) -> dict[str, int | float]:
59
- return {"warmup_steps": self.warmup_steps, "total_steps": self.total_steps, "base_lr": self.base_lr}
60
-
61
- def load_state_dict(self, state: dict[str, int | float]) -> None:
62
- self.warmup_steps = int(state["warmup_steps"])
63
- self.total_steps = int(state["total_steps"])
64
- self.base_lr = float(state["base_lr"])
65
-
66
-
67
- def _build_optimizer(model: nn.Module, cfg: dict) -> AdamW:
68
- no_decay_names = set(cfg.get("optimizer", {}).get("no_decay_params", []))
69
- decay_params = []
70
- no_decay_params = []
71
-
72
- for name, param in model.named_parameters():
73
- if not param.requires_grad:
74
- continue
75
- leaf_name = name.split(".")[-1]
76
- if param.ndim < 2 or leaf_name in no_decay_names or leaf_name == "bias" or "norm" in name.lower():
77
- no_decay_params.append(param)
78
- else:
79
- decay_params.append(param)
80
-
81
- return AdamW(
82
- [
83
- {"params": decay_params, "weight_decay": cfg["training"]["weight_decay"]},
84
- {"params": no_decay_params, "weight_decay": 0.0},
85
- ],
86
- lr=cfg["training"]["lr"],
87
- betas=tuple(cfg["training"]["betas"]),
88
- )
89
-
90
-
91
- def run_training(config: dict) -> None:
92
- init_distributed()
93
- set_seed(config["seed"] + get_rank())
94
- ensure_dir(config["output_dir"])
95
- started_at = utc_timestamp()
96
- if is_main_process():
97
- save_yaml(Path(config["output_dir"]) / "config.yaml", config)
98
- write_metadata(config, status="running", started_at=started_at)
99
-
100
- if torch.cuda.is_available():
101
- if "LOCAL_RANK" in os.environ:
102
- device = torch.device(f"cuda:{get_local_rank()}")
103
- torch.cuda.set_device(device)
104
- else:
105
- device = torch.device("cuda")
106
- else:
107
- device = torch.device("cpu")
108
- if device.type == "cuda":
109
- torch.cuda.reset_peak_memory_stats()
110
- torch.backends.cudnn.benchmark = bool(config["training"].get("cudnn_benchmark", False))
111
-
112
- raw_model = Hyper3CLIP(**config["model"]).to(device)
113
- channels_last = str(config["training"].get("memory_format", "")).lower() == "channels_last"
114
- if channels_last:
115
- raw_model = raw_model.to(memory_format=torch.channels_last)
116
- model: nn.Module = raw_model
117
- if get_world_size() > 1:
118
- device_ids = [get_local_rank()] if device.type == "cuda" else None
119
- model = DistributedDataParallel(
120
- raw_model,
121
- device_ids=device_ids,
122
- broadcast_buffers=False,
123
- find_unused_parameters=bool(config["training"].get("find_unused_parameters", False)),
124
- )
125
- dataset = _build_dataset(config["data"], config["seed"])
126
- sampler = _build_sampler(dataset)
127
- local_batch_size = _local_batch_size(config["training"])
128
- num_workers = config["data"].get("num_workers", config["training"].get("num_workers", 4))
129
- dataloader_kwargs = {}
130
- if num_workers > 0:
131
- dataloader_kwargs["persistent_workers"] = bool(
132
- config["data"].get("persistent_workers", config["training"].get("persistent_workers", False))
133
- )
134
- prefetch_factor = config["data"].get("prefetch_factor", config["training"].get("prefetch_factor"))
135
- if prefetch_factor is not None:
136
- dataloader_kwargs["prefetch_factor"] = int(prefetch_factor)
137
- beta_clip_data_config = config["data"].get("beta_clip", {})
138
- dataloader = DataLoader(
139
- dataset,
140
- batch_size=local_batch_size,
141
- sampler=sampler,
142
- shuffle=sampler is None and not isinstance(dataset, IterableDataset),
143
- num_workers=num_workers,
144
- pin_memory=bool(config["data"].get("pin_memory", True)),
145
- drop_last=True,
146
- collate_fn=lambda x: collate_grounded(
147
- x,
148
- tokenizer=raw_model.text_encoder.tokenizer,
149
- max_text_length=config["data"]["max_text_length"],
150
- beta_clip_queries=bool(beta_clip_data_config.get("enabled", False)),
151
- beta_clip_max_sentences=int(beta_clip_data_config.get("max_sentences", 5)),
152
- beta_clip_max_phrases=int(beta_clip_data_config.get("max_phrases", 30)),
153
- beta_clip_max_queries_per_image=beta_clip_data_config.get("max_queries_per_image"),
154
- beta_clip_use_part_texts=bool(beta_clip_data_config.get("use_part_texts", True)),
155
- ),
156
- **dataloader_kwargs,
157
- )
158
-
159
- optimizer = _build_optimizer(model=raw_model, cfg=config)
160
- scheduler = CosineWithWarmup(
161
- optimizer=optimizer,
162
- warmup_steps=config["training"]["warmup_steps"],
163
- total_steps=config["training"]["total_steps"],
164
- base_lr=config["training"]["lr"],
165
- )
166
- scaler = GradScaler(device.type, enabled=config["training"]["amp"])
167
- start_step = _resume_step(config, raw_model, optimizer, scheduler, scaler, device)
168
- run_control = RunControl.from_env() if RunControl is not None else None
169
-
170
- logger = JsonlLogger(Path(config["output_dir"]) / "train_log.jsonl")
171
-
172
- model.train()
173
- step = start_step
174
- micro_step = 0
175
- grad_accum_steps = max(1, int(config["training"].get("grad_accum_steps", 1)))
176
- non_blocking_transfer = bool(config["training"].get("non_blocking_transfer", True))
177
- micro_batch_global_size = local_batch_size * get_world_size()
178
- effective_global_batch_size = micro_batch_global_size * grad_accum_steps
179
- last_step_time = time.perf_counter()
180
- optimizer.zero_grad(set_to_none=True)
181
- while step < config["training"]["total_steps"]:
182
- if sampler is not None:
183
- sampler.set_epoch(step)
184
- for batch in dataloader:
185
- if step >= config["training"]["total_steps"]:
186
- break
187
-
188
- if micro_step % grad_accum_steps == 0:
189
- optimizer.zero_grad(set_to_none=True)
190
- scheduler.step(step)
191
-
192
- batch = {k: v.to(device, non_blocking=non_blocking_transfer) for k, v in batch.items()}
193
- if channels_last:
194
- batch["image"] = batch["image"].contiguous(memory_format=torch.channels_last)
195
- batch["part_images"] = batch["part_images"].contiguous(memory_format=torch.channels_last)
196
-
197
- with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=config["training"]["amp"]):
198
- out = model(**batch, step=step)
199
- loss = out["loss"] / grad_accum_steps
200
-
201
- scaler.scale(loss).backward()
202
- micro_step += 1
203
- if micro_step % grad_accum_steps != 0:
204
- continue
205
-
206
- if config["training"]["max_grad_norm"] > 0:
207
- scaler.unscale_(optimizer)
208
- grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config["training"]["max_grad_norm"])
209
- else:
210
- grad_norm = None
211
- scaler.step(optimizer)
212
- scaler.update()
213
-
214
- completed_steps = step + 1
215
- now = time.perf_counter()
216
- step_time_seconds = now - last_step_time
217
- last_step_time = now
218
-
219
- if completed_steps == 1 or completed_steps % config["training"]["log_interval"] == 0:
220
- remaining_steps = config["training"]["total_steps"] - completed_steps
221
- row = {
222
- "timestamp": datetime.now(timezone.utc).replace(microsecond=0).isoformat(),
223
- "step": completed_steps,
224
- "loss": float(out["loss"].detach().cpu().item()),
225
- "contrastive_loss": float(out["contrastive_loss"].detach().cpu().item()),
226
- "entailment_loss": float(out["entailment_loss"].detach().cpu().item()),
227
- "part_count": int(out["part_count"].detach().cpu().item()),
228
- "kappa": float(out["kappa"].detach().cpu().item()),
229
- "lr": optimizer.param_groups[0]["lr"],
230
- "grad_norm": None if grad_norm is None else float(grad_norm.detach().cpu().item()),
231
- "step_time_seconds": step_time_seconds,
232
- "steps_per_second": 1.0 / max(step_time_seconds, 1e-12),
233
- "samples_per_second": effective_global_batch_size / max(step_time_seconds, 1e-12),
234
- "samples_seen": completed_steps * effective_global_batch_size,
235
- "progress": completed_steps / config["training"]["total_steps"],
236
- "eta_seconds": remaining_steps * step_time_seconds,
237
- "rank": get_rank(),
238
- "world_size": get_world_size(),
239
- "local_batch_size": local_batch_size,
240
- "micro_batch_global_size": micro_batch_global_size,
241
- "global_batch_size": effective_global_batch_size,
242
- "grad_accum_steps": grad_accum_steps,
243
- }
244
- if device.type == "cuda":
245
- row["cuda_max_memory_allocated_mb"] = torch.cuda.max_memory_allocated() / (1024**2)
246
- for key, value in out.items():
247
- if key in row or key == "loss":
248
- continue
249
- if torch.is_tensor(value) and value.numel() == 1:
250
- row[key] = _scalar_log_value(value)
251
- if is_main_process():
252
- logger.write(row)
253
- print(_format_log_row(row), flush=True)
254
-
255
- if is_main_process() and completed_steps > 0 and completed_steps % config["training"]["ckpt_interval"] == 0:
256
- ckpt_path = str(Path(config["output_dir"]) / f"checkpoint_step_{completed_steps}.pt")
257
- save_checkpoint(ckpt_path, completed_steps, raw_model, optimizer, scheduler, scaler, config)
258
-
259
- step = completed_steps
260
- if run_control is not None and run_control.should_pause():
261
- ckpt_path = str(Path(config["output_dir"]) / f"checkpoint_step_{completed_steps}.pt")
262
- if is_main_process():
263
- save_checkpoint(ckpt_path, completed_steps, raw_model, optimizer, scheduler, scaler, config)
264
- (Path(config["output_dir"]) / "latest_checkpoint.txt").write_text(f"{ckpt_path}\n", encoding="utf-8")
265
- run_control.report_checkpoint(ckpt_path)
266
- write_metadata(config, status="paused", started_at=started_at, ended_at=utc_timestamp(), final_step=completed_steps)
267
- barrier()
268
- destroy_distributed()
269
- raise SystemExit(run_control.PAUSED_EXIT_CODE)
270
-
271
- barrier()
272
- if is_main_process():
273
- final_ckpt = str(Path(config["output_dir"]) / "checkpoint_final.pt")
274
- save_checkpoint(final_ckpt, step, raw_model, optimizer, scheduler, scaler, config)
275
- write_metadata(config, status="completed", started_at=started_at, ended_at=utc_timestamp(), final_step=step)
276
- barrier()
277
- destroy_distributed()
278
-
279
-
280
- def utc_timestamp() -> str:
281
- return datetime.now(timezone.utc).replace(microsecond=0).isoformat()
282
-
283
-
284
- def write_metadata(
285
- config: dict,
286
- *,
287
- status: str,
288
- started_at: str,
289
- ended_at: str | None = None,
290
- final_step: int | None = None,
291
- ) -> None:
292
- metadata = {
293
- "run_id": config["project"]["experiment"],
294
- "experiment_name": config["project"]["name"],
295
- "status": status,
296
- "start_time": started_at,
297
- "end_time": ended_at,
298
- "final_step": final_step,
299
- "tags": {
300
- "data": config.get("data", {}).get("type", "unknown"),
301
- "model": config.get("model", {}).get("vision_backbone", "unknown"),
302
- "objective": config.get("model", {}).get("objective", "hycoclip"),
303
- },
304
- "job": {
305
- "job_id": os.environ.get("JOB_ID") or os.environ.get("SCHEDULER_JOB_ID") or os.environ.get("SLURM_JOB_ID"),
306
- "partition": os.environ.get("JOB_PARTITION")
307
- or os.environ.get("SCHEDULER_PARTITION")
308
- or os.environ.get("SLURM_JOB_PARTITION"),
309
- "num_nodes": os.environ.get("NUM_NODES") or os.environ.get("SLURM_JOB_NUM_NODES"),
310
- "node_list": os.environ.get("NODE_LIST") or os.environ.get("SLURM_JOB_NODELIST"),
311
- "gpus": os.environ.get("GPU_DEVICES") or os.environ.get("SLURM_JOB_GPUS") or os.environ.get("SLURM_GPUS"),
312
- },
313
- "env": {
314
- "hostname": os.environ.get("HOSTNAME"),
315
- "world_size": str(get_world_size()),
316
- "rank": str(get_rank()),
317
- },
318
- }
319
- path = Path(config["output_dir"]) / "metadata.json"
320
- path.write_text(json.dumps(metadata, indent=2, sort_keys=True) + "\n", encoding="utf-8")
321
-
322
-
323
- def _build_dataset(data_config: dict, seed: int) -> GroundedManifestDataset | ProcessedGritDataset | MixedGroundedIterableDataset:
324
- data_type = data_config.get("type")
325
- if data_type is None:
326
- data_type = "processed_grit" if data_config.get("tarfiles") else "manifest"
327
- if data_type == "manifest":
328
- manifests = data_config.get("manifests") or data_config.get("manifest")
329
- if manifests is None:
330
- raise ValueError("Manifest training requires data.manifests or data.manifest")
331
- return GroundedManifestDataset(
332
- manifests=manifests,
333
- image_size=data_config["image_size"],
334
- seed=seed,
335
- manifest_weights=data_config.get("manifest_weights"),
336
- part_sampling=data_config.get("part_sampling", "random_one"),
337
- max_parts=data_config.get("max_parts"),
338
- train_transform=data_config.get("train_transform", "wide_random_crop"),
339
- image_normalization=data_config.get("image_normalization", "imagenet"),
340
- )
341
- if data_type == "processed_grit":
342
- return ProcessedGritDataset(
343
- tarfiles=data_config["tarfiles"],
344
- image_size=data_config["image_size"],
345
- seed=seed,
346
- shuffle_buffer=data_config.get("shuffle_buffer", 4000),
347
- part_sampling=data_config.get("part_sampling", "random_one"),
348
- max_parts=data_config.get("max_parts"),
349
- train_transform=data_config.get("train_transform", "wide_random_crop"),
350
- image_normalization=data_config.get("image_normalization", "imagenet"),
351
- deterministic_transforms=data_config.get("deterministic_transforms", False),
352
- )
353
- if data_type == "mixed_processed_grit_manifest":
354
- manifest_config = data_config.get("manifest_data", {})
355
- manifests = manifest_config.get("manifests") or manifest_config.get("manifest") or data_config.get("manifests")
356
- if manifests is None:
357
- raise ValueError("Mixed GRIT+manifest training requires data.manifest_data.manifests")
358
- primary = ProcessedGritDataset(
359
- tarfiles=data_config["tarfiles"],
360
- image_size=data_config["image_size"],
361
- seed=seed,
362
- shuffle_buffer=data_config.get("shuffle_buffer", 4000),
363
- part_sampling=data_config.get("part_sampling", "random_one"),
364
- max_parts=data_config.get("max_parts"),
365
- train_transform=data_config.get("train_transform", "wide_random_crop"),
366
- image_normalization=data_config.get("image_normalization", "imagenet"),
367
- deterministic_transforms=data_config.get("deterministic_transforms", False),
368
- )
369
- auxiliary = GroundedManifestDataset(
370
- manifests=manifests,
371
- image_size=manifest_config.get("image_size", data_config["image_size"]),
372
- seed=seed + 47,
373
- manifest_weights=manifest_config.get("manifest_weights"),
374
- part_sampling=manifest_config.get("part_sampling", data_config.get("manifest_part_sampling", "all")),
375
- max_parts=manifest_config.get("max_parts", data_config.get("manifest_max_parts")),
376
- train_transform=manifest_config.get("train_transform", data_config.get("train_transform", "wide_random_crop")),
377
- image_normalization=manifest_config.get("image_normalization", data_config.get("image_normalization", "imagenet")),
378
- )
379
- return MixedGroundedIterableDataset(
380
- primary=primary,
381
- auxiliary=auxiliary,
382
- auxiliary_probability=float(data_config.get("manifest_probability", 0.15)),
383
- seed=seed,
384
- )
385
- raise ValueError(f"Unsupported data.type {data_type!r}")
386
-
387
-
388
- def _build_sampler(dataset: GroundedManifestDataset | ProcessedGritDataset | MixedGroundedIterableDataset) -> DistributedSampler | None:
389
- if get_world_size() == 1 or isinstance(dataset, IterableDataset):
390
- return None
391
- return DistributedSampler(dataset, num_replicas=get_world_size(), rank=get_rank(), shuffle=True, drop_last=True)
392
-
393
-
394
- def _local_batch_size(training_config: dict) -> int:
395
- if "batch_size" in training_config:
396
- return int(training_config["batch_size"])
397
- global_batch_size = int(training_config["global_batch_size"])
398
- if global_batch_size % get_world_size() != 0:
399
- raise ValueError("training.global_batch_size must be divisible by world size")
400
- return global_batch_size // get_world_size()
401
-
402
-
403
- def _resume_step(
404
- config: dict,
405
- model: nn.Module,
406
- optimizer: Optimizer,
407
- scheduler: CosineWithWarmup,
408
- scaler: GradScaler,
409
- device: torch.device,
410
- ) -> int:
411
- training_config = config["training"]
412
- resume_env = training_config.get("resume_from_env", "RESUME_FROM_CHECKPOINT")
413
- resume_path = os.environ.get(str(resume_env)) if resume_env else None
414
- if resume_path is None:
415
- resume_path = training_config.get("resume_from")
416
- if resume_path is None and training_config.get("resume", False):
417
- resume_path = latest_checkpoint(config["output_dir"])
418
- if resume_path is None:
419
- return 0
420
- return load_checkpoint(
421
- resume_path,
422
- model,
423
- optimizer,
424
- scheduler,
425
- scaler,
426
- device,
427
- model_only=bool(training_config.get("resume_model_only", False)),
428
- strict_model=bool(training_config.get("resume_strict_model", True)),
429
- )
430
-
431
-
432
- def _format_log_row(row: dict) -> str:
433
- return " ".join(f"{key}={value}" for key, value in row.items())
434
-
435
-
436
- def _scalar_log_value(value: torch.Tensor) -> float | int:
437
- detached = value.detach().cpu()
438
- if detached.dtype == torch.bool:
439
- return int(detached.item())
440
- if detached.dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8):
441
- return int(detached.item())
442
- return float(detached.item())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/training/logging.py DELETED
@@ -1,15 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import json
4
- from pathlib import Path
5
- from typing import Any
6
-
7
-
8
- class JsonlLogger:
9
- def __init__(self, path: str | Path) -> None:
10
- self.path = Path(path)
11
- self.path.parent.mkdir(parents=True, exist_ok=True)
12
-
13
- def write(self, row: dict[str, Any]) -> None:
14
- with self.path.open("a", encoding="utf-8") as handle:
15
- handle.write(json.dumps(row) + "\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/utils/io.py DELETED
@@ -1,29 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from pathlib import Path
4
- import random
5
-
6
- import numpy as np
7
- import torch
8
- import yaml
9
-
10
-
11
- def load_yaml(path: str) -> dict:
12
- with Path(path).open("r", encoding="utf-8") as f:
13
- return yaml.safe_load(f)
14
-
15
-
16
- def save_yaml(path: str | Path, payload: dict) -> None:
17
- with Path(path).open("w", encoding="utf-8") as f:
18
- yaml.safe_dump(payload, f, sort_keys=False)
19
-
20
-
21
- def ensure_dir(path: str) -> None:
22
- Path(path).mkdir(parents=True, exist_ok=True)
23
-
24
-
25
- def set_seed(seed: int) -> None:
26
- random.seed(seed)
27
- np.random.seed(seed)
28
- torch.manual_seed(seed)
29
- torch.cuda.manual_seed_all(seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip_provider.py DELETED
@@ -1,133 +0,0 @@
1
- """HyperView embedding provider for the Hyper3-CLIP v0.5 HF checkpoint."""
2
-
3
- from __future__ import annotations
4
-
5
- import os
6
- from pathlib import Path
7
- from typing import Any
8
-
9
- import numpy as np
10
- import torch
11
- import yaml
12
- from huggingface_hub import snapshot_download
13
- from lancedb.embeddings import EmbeddingFunction
14
- from pydantic import PrivateAttr
15
- from safetensors.torch import load_file
16
-
17
-
18
- class Hyper3ClipEmbeddings(EmbeddingFunction):
19
- """Image embeddings from Hyper3-CLIP v0.5 in Lorentz/hyperboloid space."""
20
-
21
- name: str = "hyper3labs/hyper3-clip-v0.5"
22
- batch_size: int = 8
23
- device: str = "cpu"
24
-
25
- _model: Any = PrivateAttr(default=None)
26
- _transform: Any = PrivateAttr(default=None)
27
-
28
- @property
29
- def geometry(self) -> str:
30
- return "hyperboloid"
31
-
32
- @property
33
- def curvature(self) -> float:
34
- self._ensure_model()
35
- return float(self._model._kappa().detach().cpu().reshape(-1)[0].item())
36
-
37
- def ndims(self) -> int:
38
- return 513
39
-
40
- def _ensure_model(self) -> None:
41
- if self._model is not None:
42
- return
43
-
44
- from hyper3_clip import Hyper3CLIP
45
- from torchvision import transforms
46
-
47
- token = os.environ.get("HF_TOKEN")
48
- local_dir = snapshot_download(
49
- self.name,
50
- allow_patterns=["config.yaml", "model.safetensors"],
51
- token=token,
52
- )
53
- root = Path(local_dir)
54
- config = yaml.safe_load((root / "config.yaml").read_text(encoding="utf-8"))
55
-
56
- model = Hyper3CLIP(**config["model"])
57
- state = load_file(root / "model.safetensors", device="cpu")
58
- state = _normalize_checkpoint_keys(state, model)
59
- model.load_state_dict(state)
60
- model.to(torch.device(self.device))
61
- model.eval()
62
-
63
- self._model = model
64
- image_size = int(config.get("data", {}).get("image_size", 224))
65
- self._transform = transforms.Compose(
66
- [
67
- transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC),
68
- transforms.CenterCrop(image_size),
69
- transforms.ToTensor(),
70
- transforms.Normalize(
71
- mean=(0.485, 0.456, 0.406),
72
- std=(0.229, 0.224, 0.225),
73
- ),
74
- ]
75
- )
76
-
77
- def compute_source_embeddings(
78
- self,
79
- inputs: Any,
80
- *args: Any,
81
- **kwargs: Any,
82
- ) -> list[np.ndarray | None]:
83
- from PIL import Image
84
- from hyperview.core.sample import Sample
85
-
86
- self._ensure_model()
87
- device = torch.device(self.device)
88
- images = []
89
- for item in self.sanitize_input(inputs):
90
- if isinstance(item, Sample):
91
- with item.load_image() as img:
92
- images.append(img.convert("RGB"))
93
- elif isinstance(item, str):
94
- with Image.open(item) as img:
95
- images.append(img.convert("RGB"))
96
- elif isinstance(item, Image.Image):
97
- images.append(item.convert("RGB"))
98
- else:
99
- raise TypeError(f"Unsupported input type: {type(item)}")
100
-
101
- outputs: list[np.ndarray | None] = []
102
- with torch.inference_mode():
103
- for start in range(0, len(images), self.batch_size):
104
- batch = images[start:start + self.batch_size]
105
- tensor = torch.stack([self._transform(image) for image in batch]).to(device)
106
- encoded = self._model.encode_image(tensor).detach().cpu().numpy().astype(np.float32)
107
- outputs.extend(encoded)
108
- return outputs
109
-
110
- def compute_query_embeddings(
111
- self,
112
- query: Any,
113
- *args: Any,
114
- **kwargs: Any,
115
- ) -> list[np.ndarray | None]:
116
- return self.compute_source_embeddings([query], *args, **kwargs)
117
-
118
-
119
- def _normalize_checkpoint_keys(state: dict[str, torch.Tensor], model: torch.nn.Module) -> dict[str, torch.Tensor]:
120
- """Handle CLIPTextModel wrapper key drift between training and Space runtime."""
121
- model_keys = set(model.state_dict())
122
- old_prefix = "text_encoder.backbone.text_model."
123
- new_prefix = "text_encoder.backbone."
124
- if not any(key.startswith(old_prefix) for key in state):
125
- return state
126
- if any(key.startswith(old_prefix) for key in model_keys):
127
- return state
128
-
129
- normalized: dict[str, torch.Tensor] = {}
130
- for key, value in state.items():
131
- candidate = new_prefix + key[len(old_prefix):] if key.startswith(old_prefix) else key
132
- normalized[candidate if candidate in model_keys else key] = value
133
- return normalized