Update manufacturing Space to latest HyperView public API
Browse files
- .hyperview/extensions/manufacturing-readout/panel.js +86 -66
- Dockerfile +3 -1
- README.md +6 -6
- demo.py +28 -27
- hyper3_clip/__init__.py +0 -3
- hyper3_clip/data/__init__.py +0 -14
- hyper3_clip/data/collators.py +0 -209
- hyper3_clip/data/grit_cleaning.py +0 -554
- hyper3_clip/data/grit_webdataset.py +0 -133
- hyper3_clip/data/manifest_dataset.py +0 -120
- hyper3_clip/data/mixed_dataset.py +0 -68
- hyper3_clip/data/transforms.py +0 -125
- hyper3_clip/data/types.py +0 -48
- hyper3_clip/evaluation/__init__.py +0 -20
- hyper3_clip/evaluation/classification.py +0 -105
- hyper3_clip/evaluation/hierarchical.py +0 -118
- hyper3_clip/evaluation/pep.py +0 -462
- hyper3_clip/evaluation/retrieval.py +0 -215
- hyper3_clip/models/__init__.py +0 -3
- hyper3_clip/models/encoders.py +0 -173
- hyper3_clip/models/experimental.py +0 -587
- hyper3_clip/models/himo.py +0 -55
- hyper3_clip/models/hyper3_clip.py +0 -958
- hyper3_clip/models/lorentz.py +0 -265
- hyper3_clip/models/losses.py +0 -1400
- hyper3_clip/models/objectives.py +0 -580
- hyper3_clip/models/tren.py +0 -255
- hyper3_clip/training/__init__.py +0 -1
- hyper3_clip/training/checkpointing.py +0 -91
- hyper3_clip/training/distributed.py +0 -149
- hyper3_clip/training/engine.py +0 -442
- hyper3_clip/training/logging.py +0 -15
- hyper3_clip/utils/io.py +0 -29
- hyper3_clip_provider.py +0 -133
.hyperview/extensions/manufacturing-readout/panel.js
CHANGED
|
@@ -39,7 +39,6 @@ function normalizeModels(value) {
|
|
| 39 |
displayName: String(model.displayName || model.display_name || model.key || `Model ${index + 1}`),
|
| 40 |
buttonLabel: String(model.buttonLabel || model.button_label || `${model.key || "Model"} query`),
|
| 41 |
layoutKey: model.layoutKey || model.layout_key || null,
|
| 42 |
-
spaceKey: model.spaceKey || model.space_key || null,
|
| 43 |
}))
|
| 44 |
.filter((model) => model.layoutKey);
|
| 45 |
}
|
|
@@ -178,6 +177,77 @@ function CompactEvidence({ item, models }) {
|
|
| 178 |
);
|
| 179 |
}
|
| 180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
function StepBlock({ number, title, children }) {
|
| 182 |
return React.createElement(
|
| 183 |
"div",
|
|
@@ -320,8 +390,7 @@ function choiceFromSimilarity(similarity, examples, models) {
|
|
| 320 |
const sourceKey = source.includes(":") ? source.split(":").pop() : null;
|
| 321 |
const model =
|
| 322 |
models.find((candidate) => candidate.key === sourceKey) ||
|
| 323 |
-
models.find((candidate) => candidate.layoutKey === similarity.layout_key)
|
| 324 |
-
models.find((candidate) => candidate.spaceKey === similarity.space_key);
|
| 325 |
if (!model) return null;
|
| 326 |
const metric = advantageMetric(model.key, item.id);
|
| 327 |
return {
|
|
@@ -332,18 +401,6 @@ function choiceFromSimilarity(similarity, examples, models) {
|
|
| 332 |
};
|
| 333 |
}
|
| 334 |
|
| 335 |
-
async function sendJson(path, payload, method = "POST") {
|
| 336 |
-
const response = await fetch(path, {
|
| 337 |
-
method,
|
| 338 |
-
headers: { "Content-Type": "application/json" },
|
| 339 |
-
body: JSON.stringify(payload),
|
| 340 |
-
});
|
| 341 |
-
if (!response.ok) {
|
| 342 |
-
throw new Error(await response.text());
|
| 343 |
-
}
|
| 344 |
-
return response.json();
|
| 345 |
-
}
|
| 346 |
-
|
| 347 |
function buttonText(model) {
|
| 348 |
if (model.key === "candidate") return "Hyper3";
|
| 349 |
if (model.key === "clip") return "CLIP";
|
|
@@ -416,6 +473,7 @@ function WalkthroughCard({ item, models, onSelectQuery, loadingKey, activeModelK
|
|
| 416 |
{ style: { color: colors.mutedText, fontSize: 10, lineHeight: 1.3 } },
|
| 417 |
"Wrong-line references send operators to the wrong golden sample.",
|
| 418 |
),
|
|
|
|
| 419 |
React.createElement(CompactEvidence, { item, models }),
|
| 420 |
);
|
| 421 |
}
|
|
@@ -424,7 +482,6 @@ export default function ManufacturingPanel() {
|
|
| 424 |
const props = usePanelProps() || {};
|
| 425 |
const commands = usePanelCommands();
|
| 426 |
const runtimeState = usePanelRuntimeState ? usePanelRuntimeState() : {};
|
| 427 |
-
const workspaceId = String(props.workspaceId || props.workspace_id || "manufacturing-visa-reference-clip-hyper3clip");
|
| 428 |
const models = normalizeModels(props.models);
|
| 429 |
const examples = Array.isArray(props.examples) ? props.examples : [];
|
| 430 |
const primaryExample = examples.find((item) => item.id === "fryum") || examples[0] || null;
|
|
@@ -449,6 +506,7 @@ export default function ManufacturingPanel() {
|
|
| 449 |
const item = examples.find((example) => example.queryId === sampleId);
|
| 450 |
const metric = advantageMetric(model.key, item?.id);
|
| 451 |
const nextChoice = {
|
|
|
|
| 452 |
modelName: model.displayName,
|
| 453 |
queryLabel: title(item?.queryLabel || "fryum"),
|
| 454 |
metricLine: metric?.line || null,
|
|
@@ -458,63 +516,25 @@ export default function ManufacturingPanel() {
|
|
| 458 |
setActiveChoice(nextChoice);
|
| 459 |
setLoadingKey(choiceKey);
|
| 460 |
try {
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
source: `manufacturing-demo:${model.key}`,
|
| 471 |
-
focus: "samples",
|
| 472 |
-
persist: "none",
|
| 473 |
-
});
|
| 474 |
-
}
|
| 475 |
-
await sendJson("/api/control/ui/state", {
|
| 476 |
-
workspace_id: workspaceId,
|
| 477 |
-
set_active_layout: true,
|
| 478 |
-
active_layout_key: model.layoutKey,
|
| 479 |
-
set_selection: true,
|
| 480 |
-
selected_ids: [sampleId],
|
| 481 |
-
set_similarity_query: true,
|
| 482 |
-
similarity_query: {
|
| 483 |
-
sample_id: sampleId,
|
| 484 |
-
layout_key: model.layoutKey,
|
| 485 |
-
space_key: model.spaceKey,
|
| 486 |
-
k: 10,
|
| 487 |
-
source: `manufacturing-demo:${model.key}`,
|
| 488 |
-
},
|
| 489 |
-
}, "PATCH");
|
| 490 |
setActiveChoice(nextChoice);
|
| 491 |
setActiveModelKey(model.key);
|
| 492 |
} catch (error) {
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
workspace_id: workspaceId,
|
| 496 |
-
layout_key: model.layoutKey,
|
| 497 |
-
});
|
| 498 |
-
await sendJson("/api/control/ui/similarity", {
|
| 499 |
-
workspace_id: workspaceId,
|
| 500 |
-
sample_id: sampleId,
|
| 501 |
-
layout_key: model.layoutKey,
|
| 502 |
-
space_key: model.spaceKey,
|
| 503 |
-
k: 10,
|
| 504 |
-
source: `manufacturing-demo:${model.key}`,
|
| 505 |
-
});
|
| 506 |
-
setActiveChoice(nextChoice);
|
| 507 |
-
setActiveModelKey(model.key);
|
| 508 |
-
return;
|
| 509 |
-
} catch (fallbackError) {
|
| 510 |
-
const message = fallbackError instanceof Error ? fallbackError.message : String(fallbackError);
|
| 511 |
-
setPanelError(`Could not show neighbors: ${message}`);
|
| 512 |
-
}
|
| 513 |
} finally {
|
| 514 |
setLoadingKey(null);
|
| 515 |
}
|
| 516 |
},
|
| 517 |
-
[commands, examples
|
| 518 |
);
|
| 519 |
|
| 520 |
return React.createElement(
|
|
|
|
| 39 |
displayName: String(model.displayName || model.display_name || model.key || `Model ${index + 1}`),
|
| 40 |
buttonLabel: String(model.buttonLabel || model.button_label || `${model.key || "Model"} query`),
|
| 41 |
layoutKey: model.layoutKey || model.layout_key || null,
|
|
|
|
| 42 |
}))
|
| 43 |
.filter((model) => model.layoutKey);
|
| 44 |
}
|
|
|
|
| 177 |
);
|
| 178 |
}
|
| 179 |
|
| 180 |
+
function ActiveNeighbors({ item, modelKey }) {
|
| 181 |
+
if (!item || !modelKey) return null;
|
| 182 |
+
const summary = item.summaries?.[modelKey] || {};
|
| 183 |
+
const neighbors = Array.isArray(summary.neighbors) ? summary.neighbors.slice(0, 5) : [];
|
| 184 |
+
if (!neighbors.length) return null;
|
| 185 |
+
const cell = {
|
| 186 |
+
padding: "4px 3px",
|
| 187 |
+
borderBottom: `1px solid ${colors.border}`,
|
| 188 |
+
fontSize: 10,
|
| 189 |
+
color: colors.bodyText,
|
| 190 |
+
};
|
| 191 |
+
const head = { ...cell, color: colors.mutedText, fontSize: 9, textTransform: "uppercase" };
|
| 192 |
+
return React.createElement(
|
| 193 |
+
"div",
|
| 194 |
+
{
|
| 195 |
+
style: {
|
| 196 |
+
borderTop: `1px solid ${colors.border}`,
|
| 197 |
+
paddingTop: 7,
|
| 198 |
+
display: "flex",
|
| 199 |
+
flexDirection: "column",
|
| 200 |
+
gap: 4,
|
| 201 |
+
},
|
| 202 |
+
},
|
| 203 |
+
React.createElement(
|
| 204 |
+
"div",
|
| 205 |
+
{ style: { color: colors.strongText, fontSize: 11.5, fontWeight: 900 } },
|
| 206 |
+
modelKey === "candidate" ? "Hyper3 Top Refs" : "CLIP Top Refs",
|
| 207 |
+
),
|
| 208 |
+
React.createElement(
|
| 209 |
+
"table",
|
| 210 |
+
{ style: { width: "100%", borderCollapse: "collapse" } },
|
| 211 |
+
React.createElement(
|
| 212 |
+
"thead",
|
| 213 |
+
null,
|
| 214 |
+
React.createElement(
|
| 215 |
+
"tr",
|
| 216 |
+
null,
|
| 217 |
+
React.createElement("th", { style: head, align: "left" }, "Rank"),
|
| 218 |
+
React.createElement("th", { style: head, align: "left" }, "SKU"),
|
| 219 |
+
React.createElement("th", { style: head, align: "right" }, "Signal"),
|
| 220 |
+
),
|
| 221 |
+
),
|
| 222 |
+
React.createElement(
|
| 223 |
+
"tbody",
|
| 224 |
+
null,
|
| 225 |
+
neighbors.map((neighbor) => {
|
| 226 |
+
const signal = neighbor.sameSkuNormal
|
| 227 |
+
? "correct normal"
|
| 228 |
+
: neighbor.pipeFryumConfusion
|
| 229 |
+
? "wrong line"
|
| 230 |
+
: neighbor.sameSku
|
| 231 |
+
? "same SKU"
|
| 232 |
+
: "other";
|
| 233 |
+
const signalColor = neighbor.sameSkuNormal
|
| 234 |
+
? colors.good
|
| 235 |
+
: neighbor.pipeFryumConfusion
|
| 236 |
+
? colors.error
|
| 237 |
+
: colors.bodyText;
|
| 238 |
+
return React.createElement(
|
| 239 |
+
"tr",
|
| 240 |
+
{ key: `${modelKey}-${neighbor.rank}-${neighbor.id}` },
|
| 241 |
+
React.createElement("td", { style: { ...cell, color: colors.strongText, fontWeight: 800 } }, `#${neighbor.rank}`),
|
| 242 |
+
React.createElement("td", { style: cell }, pretty(neighbor.sku)),
|
| 243 |
+
React.createElement("td", { style: { ...cell, color: signalColor, fontWeight: 800 }, align: "right" }, signal),
|
| 244 |
+
);
|
| 245 |
+
}),
|
| 246 |
+
),
|
| 247 |
+
),
|
| 248 |
+
);
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
function StepBlock({ number, title, children }) {
|
| 252 |
return React.createElement(
|
| 253 |
"div",
|
|
|
|
| 390 |
const sourceKey = source.includes(":") ? source.split(":").pop() : null;
|
| 391 |
const model =
|
| 392 |
models.find((candidate) => candidate.key === sourceKey) ||
|
| 393 |
+
models.find((candidate) => candidate.layoutKey === similarity.layout_key);
|
|
|
|
| 394 |
if (!model) return null;
|
| 395 |
const metric = advantageMetric(model.key, item.id);
|
| 396 |
return {
|
|
|
|
| 401 |
};
|
| 402 |
}
|
| 403 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
function buttonText(model) {
|
| 405 |
if (model.key === "candidate") return "Hyper3";
|
| 406 |
if (model.key === "clip") return "CLIP";
|
|
|
|
| 473 |
{ style: { color: colors.mutedText, fontSize: 10, lineHeight: 1.3 } },
|
| 474 |
"Wrong-line references send operators to the wrong golden sample.",
|
| 475 |
),
|
| 476 |
+
React.createElement(ActiveNeighbors, { item, modelKey: activeModelKey }),
|
| 477 |
React.createElement(CompactEvidence, { item, models }),
|
| 478 |
);
|
| 479 |
}
|
|
|
|
| 482 |
const props = usePanelProps() || {};
|
| 483 |
const commands = usePanelCommands();
|
| 484 |
const runtimeState = usePanelRuntimeState ? usePanelRuntimeState() : {};
|
|
|
|
| 485 |
const models = normalizeModels(props.models);
|
| 486 |
const examples = Array.isArray(props.examples) ? props.examples : [];
|
| 487 |
const primaryExample = examples.find((item) => item.id === "fryum") || examples[0] || null;
|
|
|
|
| 506 |
const item = examples.find((example) => example.queryId === sampleId);
|
| 507 |
const metric = advantageMetric(model.key, item?.id);
|
| 508 |
const nextChoice = {
|
| 509 |
+
modelKey: model.key,
|
| 510 |
modelName: model.displayName,
|
| 511 |
queryLabel: title(item?.queryLabel || "fryum"),
|
| 512 |
metricLine: metric?.line || null,
|
|
|
|
| 516 |
setActiveChoice(nextChoice);
|
| 517 |
setLoadingKey(choiceKey);
|
| 518 |
try {
|
| 519 |
+
await commands.setActiveLayout(model.layoutKey, { persist: true });
|
| 520 |
+
await commands.showSimilar({
|
| 521 |
+
sampleId,
|
| 522 |
+
layoutKey: model.layoutKey,
|
| 523 |
+
k: 10,
|
| 524 |
+
source: `manufacturing-demo:${model.key}`,
|
| 525 |
+
focus: "samples",
|
| 526 |
+
persist: true,
|
| 527 |
+
});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 528 |
setActiveChoice(nextChoice);
|
| 529 |
setActiveModelKey(model.key);
|
| 530 |
} catch (error) {
|
| 531 |
+
const message = error instanceof Error ? error.message : String(error);
|
| 532 |
+
setPanelError(`Could not show neighbors: ${message}`);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 533 |
} finally {
|
| 534 |
setLoadingKey(null);
|
| 535 |
}
|
| 536 |
},
|
| 537 |
+
[commands, examples],
|
| 538 |
);
|
| 539 |
|
| 540 |
return React.createElement(
|
Dockerfile
CHANGED
|
@@ -20,7 +20,8 @@ WORKDIR $HOME/app
|
|
| 20 |
|
| 21 |
RUN pip install --upgrade pip
|
| 22 |
|
| 23 |
-
ARG HYPERVIEW_VERSION=0.6.
|
|
|
|
| 24 |
|
| 25 |
# Install CPU-only PyTorch first so the Space does not pull the default CUDA bundle.
|
| 26 |
RUN pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
|
|
@@ -33,6 +34,7 @@ import hyperview as hv
|
|
| 33 |
print("hyperview", hv.__version__, inspect.signature(hv.launch))
|
| 34 |
PY
|
| 35 |
RUN pip install \
|
|
|
|
| 36 |
"datasets>=4.5.0" \
|
| 37 |
"Pillow>=12.0.0" \
|
| 38 |
"timm>=1.0.0" \
|
|
|
|
| 20 |
|
| 21 |
RUN pip install --upgrade pip
|
| 22 |
|
| 23 |
+
ARG HYPERVIEW_VERSION=0.6.2
|
| 24 |
+
ARG HYPER_MODELS_VERSION=0.3.0
|
| 25 |
|
| 26 |
# Install CPU-only PyTorch first so the Space does not pull the default CUDA bundle.
|
| 27 |
RUN pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
|
|
|
|
| 34 |
print("hyperview", hv.__version__, inspect.signature(hv.launch))
|
| 35 |
PY
|
| 36 |
RUN pip install \
|
| 37 |
+
"hyper-models[ml]==${HYPER_MODELS_VERSION}" \
|
| 38 |
"datasets>=4.5.0" \
|
| 39 |
"Pillow>=12.0.0" \
|
| 40 |
"timm>=1.0.0" \
|
README.md
CHANGED
|
@@ -14,7 +14,7 @@ This Space builds a balanced subset of the VisA industrial visual anomaly
|
|
| 14 |
dataset and opens HyperView with two side-by-side embedding spaces:
|
| 15 |
|
| 16 |
- CLIP ViT-B/32 in a Euclidean 2D layout
|
| 17 |
-
- Hyper3-CLIP `
|
| 18 |
|
| 19 |
The workflow is inspection reference retrieval: given a production-line
|
| 20 |
inspection image, retrieve the right normal references for the same SKU or
|
|
@@ -59,8 +59,8 @@ VISA_SAMPLES_PER_CATEGORY=12 HYPERVIEW_PORT=6265 \
|
|
| 59 |
uv run python hyperview-spaces/spaces/manufacturing-visa-reference-clip-hyper3clip/demo.py
|
| 60 |
```
|
| 61 |
|
| 62 |
-
Hyper3-CLIP weights are loaded
|
| 63 |
-
`hyper3labs/hyper3-clip-v0.5` model repository at runtime. The Space needs
|
| 64 |
-
`HF_TOKEN` secret with access to that model. If unavailable, the Space can
|
| 65 |
-
with a clearly labeled CLIP fallback unless
|
| 66 |
-
is set.
|
|
|
|
| 14 |
dataset and opens HyperView with two side-by-side embedding spaces:
|
| 15 |
|
| 16 |
- CLIP ViT-B/32 in a Euclidean 2D layout
|
| 17 |
+
- Hyper3-CLIP `hyper3-clip-v0.5` from `hyper-models` in a Poincare 2D layout
|
| 18 |
|
| 19 |
The workflow is inspection reference retrieval: given a production-line
|
| 20 |
inspection image, retrieve the right normal references for the same SKU or
|
|
|
|
| 59 |
uv run python hyperview-spaces/spaces/manufacturing-visa-reference-clip-hyper3clip/demo.py
|
| 60 |
```
|
| 61 |
|
| 62 |
+
Hyper3-CLIP weights are loaded through the `hyper-models` catalog entry for the
|
| 63 |
+
gated `hyper3labs/hyper3-clip-v0.5` model repository at runtime. The Space needs
|
| 64 |
+
an `HF_TOKEN` secret with access to that model. If unavailable, the Space can
|
| 65 |
+
start with a clearly labeled CLIP fallback unless
|
| 66 |
+
`HYPERVIEW_ALLOW_CANDIDATE_FALLBACK=0` is set.
|
demo.py
CHANGED
|
@@ -18,7 +18,6 @@ from PIL import Image, ImageOps
|
|
| 18 |
|
| 19 |
import hyperview as hv
|
| 20 |
|
| 21 |
-
|
| 22 |
SPACE_DIR = Path(__file__).resolve().parent
|
| 23 |
SPACE_HOST = os.environ.get("HYPERVIEW_HOST", "127.0.0.1")
|
| 24 |
SPACE_PORT = int(os.environ.get("HYPERVIEW_PORT", "6265"))
|
|
@@ -29,6 +28,11 @@ EXTENSION_DIR = SPACE_DIR / ".hyperview" / "extensions" / "manufacturing-readout
|
|
| 29 |
SAMPLES_PER_CATEGORY = int(os.environ.get("VISA_SAMPLES_PER_CATEGORY", "4"))
|
| 30 |
TRAIN_FRACTION = float(os.environ.get("VISA_TRAIN_FRACTION", "0.5"))
|
| 31 |
IMAGE_MAX_SIZE = (640, 640)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
ALLOW_CANDIDATE_FALLBACK = os.environ.get("HYPERVIEW_ALLOW_CANDIDATE_FALLBACK", "1").lower() in {
|
| 33 |
"1",
|
| 34 |
"true",
|
|
@@ -90,8 +94,8 @@ MODEL_SPECS = [
|
|
| 90 |
"key": "candidate",
|
| 91 |
"display_name": os.environ.get("VISA_CANDIDATE_DISPLAY_NAME", "Hyper3-CLIP"),
|
| 92 |
"button_label": os.environ.get("VISA_CANDIDATE_BUTTON_LABEL", "Show Hyper3 neighbors"),
|
| 93 |
-
"provider": os.environ.get("VISA_CANDIDATE_PROVIDER", "
|
| 94 |
-
"model": os.environ.get("VISA_CANDIDATE_MODEL", "
|
| 95 |
"layout": os.environ.get("VISA_CANDIDATE_LAYOUT", "poincare:2d"),
|
| 96 |
"geometry": os.environ.get("VISA_CANDIDATE_GEOMETRY", "poincare"),
|
| 97 |
"layout_dimension": int(os.environ.get("VISA_CANDIDATE_LAYOUT_DIMENSION", "2")),
|
|
@@ -219,6 +223,7 @@ def add_visa_samples(dataset: hv.Dataset) -> None:
|
|
| 219 |
media_dir = media_root()
|
| 220 |
added = 0
|
| 221 |
updated = 0
|
|
|
|
| 222 |
for record in select_visa_records():
|
| 223 |
sample_id = safe_sample_id(record["category"], record["split_name"], record["row_index"], record["defect_label"])
|
| 224 |
destination = Path(record["local_path"]) if record.get("local_path") else media_dir / f"{sample_id}.jpg"
|
|
@@ -234,12 +239,17 @@ def add_visa_samples(dataset: hv.Dataset) -> None:
|
|
| 234 |
"source_dataset": "BrachioLab/visa",
|
| 235 |
}
|
| 236 |
existed = sample_id in existing_ids
|
|
|
|
|
|
|
|
|
|
| 237 |
dataset.add_image(str(destination), label=record["category"], metadata=metadata, sample_id=sample_id)
|
| 238 |
if existed:
|
| 239 |
updated += 1
|
| 240 |
else:
|
| 241 |
added += 1
|
| 242 |
existing_ids.add(sample_id)
|
|
|
|
|
|
|
| 243 |
print(f"Prepared VisA samples ({added} added, {updated} updated).", flush=True)
|
| 244 |
|
| 245 |
|
|
@@ -262,6 +272,7 @@ def ensure_layouts(dataset: hv.Dataset) -> dict[str, str]:
|
|
| 262 |
)
|
| 263 |
print(warning, flush=True)
|
| 264 |
RUNTIME_WARNINGS.append(warning)
|
|
|
|
| 265 |
spec.update(
|
| 266 |
{
|
| 267 |
"display_name": "Hyper3-CLIP unavailable (CLIP fallback)",
|
|
@@ -270,21 +281,22 @@ def ensure_layouts(dataset: hv.Dataset) -> dict[str, str]:
|
|
| 270 |
"layout_dimension": MODEL_SPECS[0]["layout_dimension"],
|
| 271 |
"panel_title": "Hyper3-CLIP unavailable - showing CLIP fallback",
|
| 272 |
"fallback": True,
|
| 273 |
-
"
|
| 274 |
}
|
| 275 |
)
|
| 276 |
-
layouts[spec["key"]] =
|
| 277 |
continue
|
| 278 |
raise
|
| 279 |
-
spec["space_key"] = space_key
|
| 280 |
print(f"Ensuring {spec['display_name']} layout...", flush=True)
|
| 281 |
-
|
| 282 |
space_key=space_key,
|
| 283 |
layout=spec["layout"],
|
| 284 |
n_neighbors=20,
|
| 285 |
min_dist=0.08,
|
| 286 |
metric=spec["metric"],
|
| 287 |
)
|
|
|
|
|
|
|
| 288 |
return layouts
|
| 289 |
|
| 290 |
|
|
@@ -305,24 +317,19 @@ def model_panel_props(layouts: dict[str, str]) -> list[dict[str, Any]]:
|
|
| 305 |
"displayName": spec["display_name"],
|
| 306 |
"buttonLabel": spec["button_label"],
|
| 307 |
"layoutKey": layout_key,
|
| 308 |
-
"spaceKey": spec.get("space_key") or space_key_from_layout(layout_key),
|
| 309 |
}
|
| 310 |
)
|
| 311 |
return props
|
| 312 |
|
| 313 |
|
| 314 |
-
def space_key_from_layout(layout_key: str) -> str:
|
| 315 |
-
return layout_key.split("__euclidean_umap", 1)[0].split("__poincare_umap", 1)[0]
|
| 316 |
-
|
| 317 |
-
|
| 318 |
def reference_summary(dataset: hv.Dataset, sample_id: str, model_key: str) -> dict[str, Any]:
|
| 319 |
spec = next((item for item in MODEL_SPECS if item["key"] == model_key), None)
|
| 320 |
-
if spec is None or spec.get("
|
| 321 |
return {}
|
| 322 |
query = dataset[sample_id]
|
| 323 |
query_sku = query.metadata.get("sku")
|
| 324 |
query_family = query.metadata.get("product_family")
|
| 325 |
-
neighbors = dataset.find_similar(sample_id, k=10,
|
| 326 |
sku_hits = sum(1 for sample, _distance in neighbors if sample.metadata.get("sku") == query_sku)
|
| 327 |
family_hits = sum(1 for sample, _distance in neighbors if sample.metadata.get("product_family") == query_family)
|
| 328 |
normal_refs = sum(1 for sample, _distance in neighbors if sample.metadata.get("workflow_role") == "normal_reference")
|
|
@@ -431,17 +438,6 @@ def category_strength_rows(dataset: hv.Dataset) -> list[dict[str, str]]:
|
|
| 431 |
return sorted(rows, key=lambda row: float(row["delta"]), reverse=True)[:3]
|
| 432 |
|
| 433 |
|
| 434 |
-
def register_hyper3_clip_provider() -> None:
|
| 435 |
-
from hyperview.runtime import ProviderRegistry
|
| 436 |
-
|
| 437 |
-
ProviderRegistry().register_python(
|
| 438 |
-
"hyper3-clip",
|
| 439 |
-
"hyper3_clip_provider:Hyper3ClipEmbeddings",
|
| 440 |
-
description="Hyper3-CLIP v0.5 image embeddings from hyper3labs/hyper3-clip-v0.5",
|
| 441 |
-
overwrite=True,
|
| 442 |
-
)
|
| 443 |
-
|
| 444 |
-
|
| 445 |
def build_demo_view(dataset: hv.Dataset, layouts: dict[str, str]) -> hv.ui.View:
|
| 446 |
scatter_panels = [
|
| 447 |
hv.ui.Scatter(
|
|
@@ -460,14 +456,20 @@ def build_demo_view(dataset: hv.Dataset, layouts: dict[str, str]) -> hv.ui.View:
|
|
| 460 |
extension="manufacturing-readout",
|
| 461 |
panel="manufacturing-comparison",
|
| 462 |
position="right",
|
|
|
|
| 463 |
props={
|
| 464 |
-
"workspaceId": WORKSPACE_ID,
|
| 465 |
"models": model_panel_props(layouts),
|
| 466 |
"examples": build_examples(dataset),
|
| 467 |
"strengthRows": category_strength_rows(dataset),
|
| 468 |
"warnings": RUNTIME_WARNINGS,
|
| 469 |
},
|
| 470 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 471 |
)
|
| 472 |
|
| 473 |
|
|
@@ -491,7 +493,6 @@ def launch_demo(dataset: hv.Dataset, layouts: dict[str, str]) -> hv.Session:
|
|
| 491 |
|
| 492 |
|
| 493 |
def main() -> None:
|
| 494 |
-
register_hyper3_clip_provider()
|
| 495 |
dataset, layouts = build_dataset()
|
| 496 |
print("Layouts:", flush=True)
|
| 497 |
for spec in MODEL_SPECS:
|
|
|
|
| 18 |
|
| 19 |
import hyperview as hv
|
| 20 |
|
|
|
|
| 21 |
SPACE_DIR = Path(__file__).resolve().parent
|
| 22 |
SPACE_HOST = os.environ.get("HYPERVIEW_HOST", "127.0.0.1")
|
| 23 |
SPACE_PORT = int(os.environ.get("HYPERVIEW_PORT", "6265"))
|
|
|
|
| 28 |
SAMPLES_PER_CATEGORY = int(os.environ.get("VISA_SAMPLES_PER_CATEGORY", "4"))
|
| 29 |
TRAIN_FRACTION = float(os.environ.get("VISA_TRAIN_FRACTION", "0.5"))
|
| 30 |
IMAGE_MAX_SIZE = (640, 640)
|
| 31 |
+
FORCE_SAMPLE_REFRESH = os.environ.get("HYPERVIEW_VISA_FORCE_REFRESH", "").lower() in {
|
| 32 |
+
"1",
|
| 33 |
+
"true",
|
| 34 |
+
"yes",
|
| 35 |
+
}
|
| 36 |
ALLOW_CANDIDATE_FALLBACK = os.environ.get("HYPERVIEW_ALLOW_CANDIDATE_FALLBACK", "1").lower() in {
|
| 37 |
"1",
|
| 38 |
"true",
|
|
|
|
| 94 |
"key": "candidate",
|
| 95 |
"display_name": os.environ.get("VISA_CANDIDATE_DISPLAY_NAME", "Hyper3-CLIP"),
|
| 96 |
"button_label": os.environ.get("VISA_CANDIDATE_BUTTON_LABEL", "Show Hyper3 neighbors"),
|
| 97 |
+
"provider": os.environ.get("VISA_CANDIDATE_PROVIDER", "hyper-models"),
|
| 98 |
+
"model": os.environ.get("VISA_CANDIDATE_MODEL", "hyper3-clip-v0.5"),
|
| 99 |
"layout": os.environ.get("VISA_CANDIDATE_LAYOUT", "poincare:2d"),
|
| 100 |
"geometry": os.environ.get("VISA_CANDIDATE_GEOMETRY", "poincare"),
|
| 101 |
"layout_dimension": int(os.environ.get("VISA_CANDIDATE_LAYOUT_DIMENSION", "2")),
|
|
|
|
| 223 |
media_dir = media_root()
|
| 224 |
added = 0
|
| 225 |
updated = 0
|
| 226 |
+
skipped = 0
|
| 227 |
for record in select_visa_records():
|
| 228 |
sample_id = safe_sample_id(record["category"], record["split_name"], record["row_index"], record["defect_label"])
|
| 229 |
destination = Path(record["local_path"]) if record.get("local_path") else media_dir / f"{sample_id}.jpg"
|
|
|
|
| 239 |
"source_dataset": "BrachioLab/visa",
|
| 240 |
}
|
| 241 |
existed = sample_id in existing_ids
|
| 242 |
+
if existed and not FORCE_SAMPLE_REFRESH:
|
| 243 |
+
skipped += 1
|
| 244 |
+
continue
|
| 245 |
dataset.add_image(str(destination), label=record["category"], metadata=metadata, sample_id=sample_id)
|
| 246 |
if existed:
|
| 247 |
updated += 1
|
| 248 |
else:
|
| 249 |
added += 1
|
| 250 |
existing_ids.add(sample_id)
|
| 251 |
+
if skipped:
|
| 252 |
+
print(f"Skipped {skipped} existing VisA sample rows.", flush=True)
|
| 253 |
print(f"Prepared VisA samples ({added} added, {updated} updated).", flush=True)
|
| 254 |
|
| 255 |
|
|
|
|
| 272 |
)
|
| 273 |
print(warning, flush=True)
|
| 274 |
RUNTIME_WARNINGS.append(warning)
|
| 275 |
+
fallback_layout_key = layouts["clip"]
|
| 276 |
spec.update(
|
| 277 |
{
|
| 278 |
"display_name": "Hyper3-CLIP unavailable (CLIP fallback)",
|
|
|
|
| 281 |
"layout_dimension": MODEL_SPECS[0]["layout_dimension"],
|
| 282 |
"panel_title": "Hyper3-CLIP unavailable - showing CLIP fallback",
|
| 283 |
"fallback": True,
|
| 284 |
+
"layout_key": fallback_layout_key,
|
| 285 |
}
|
| 286 |
)
|
| 287 |
+
layouts[spec["key"]] = fallback_layout_key
|
| 288 |
continue
|
| 289 |
raise
|
|
|
|
| 290 |
print(f"Ensuring {spec['display_name']} layout...", flush=True)
|
| 291 |
+
layout_key = dataset.compute_visualization(
|
| 292 |
space_key=space_key,
|
| 293 |
layout=spec["layout"],
|
| 294 |
n_neighbors=20,
|
| 295 |
min_dist=0.08,
|
| 296 |
metric=spec["metric"],
|
| 297 |
)
|
| 298 |
+
spec["layout_key"] = layout_key
|
| 299 |
+
layouts[spec["key"]] = layout_key
|
| 300 |
return layouts
|
| 301 |
|
| 302 |
|
|
|
|
| 317 |
"displayName": spec["display_name"],
|
| 318 |
"buttonLabel": spec["button_label"],
|
| 319 |
"layoutKey": layout_key,
|
|
|
|
| 320 |
}
|
| 321 |
)
|
| 322 |
return props
|
| 323 |
|
| 324 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
def reference_summary(dataset: hv.Dataset, sample_id: str, model_key: str) -> dict[str, Any]:
|
| 326 |
spec = next((item for item in MODEL_SPECS if item["key"] == model_key), None)
|
| 327 |
+
if spec is None or spec.get("layout_key") is None:
|
| 328 |
return {}
|
| 329 |
query = dataset[sample_id]
|
| 330 |
query_sku = query.metadata.get("sku")
|
| 331 |
query_family = query.metadata.get("product_family")
|
| 332 |
+
neighbors = dataset.find_similar(sample_id, k=10, layout_key=str(spec["layout_key"]))
|
| 333 |
sku_hits = sum(1 for sample, _distance in neighbors if sample.metadata.get("sku") == query_sku)
|
| 334 |
family_hits = sum(1 for sample, _distance in neighbors if sample.metadata.get("product_family") == query_family)
|
| 335 |
normal_refs = sum(1 for sample, _distance in neighbors if sample.metadata.get("workflow_role") == "normal_reference")
|
|
|
|
| 438 |
return sorted(rows, key=lambda row: float(row["delta"]), reverse=True)[:3]
|
| 439 |
|
| 440 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 441 |
def build_demo_view(dataset: hv.Dataset, layouts: dict[str, str]) -> hv.ui.View:
|
| 442 |
scatter_panels = [
|
| 443 |
hv.ui.Scatter(
|
|
|
|
| 456 |
extension="manufacturing-readout",
|
| 457 |
panel="manufacturing-comparison",
|
| 458 |
position="right",
|
| 459 |
+
layout=hv.ui.PanelLayout(width=340, min_width=300),
|
| 460 |
props={
|
|
|
|
| 461 |
"models": model_panel_props(layouts),
|
| 462 |
"examples": build_examples(dataset),
|
| 463 |
"strengthRows": category_strength_rows(dataset),
|
| 464 |
"warnings": RUNTIME_WARNINGS,
|
| 465 |
},
|
| 466 |
),
|
| 467 |
+
hv.ui.Samples(
|
| 468 |
+
id="manufacturing-neighbors",
|
| 469 |
+
title="Step 2 - Retrieved References",
|
| 470 |
+
position="bottom",
|
| 471 |
+
layout=hv.ui.PanelLayout(height=220, min_height=180),
|
| 472 |
+
),
|
| 473 |
)
|
| 474 |
|
| 475 |
|
|
|
|
| 493 |
|
| 494 |
|
| 495 |
def main() -> None:
|
|
|
|
| 496 |
dataset, layouts = build_dataset()
|
| 497 |
print("Layouts:", flush=True)
|
| 498 |
for spec in MODEL_SPECS:
|
hyper3_clip/__init__.py
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
from hyper3_clip.models.hyper3_clip import Hyper3CLIP
|
| 2 |
-
|
| 3 |
-
__all__ = ["Hyper3CLIP"]
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/data/__init__.py
DELETED
|
@@ -1,14 +0,0 @@
|
|
| 1 |
-
from hyper3_clip.data.collators import collate_grounded
|
| 2 |
-
from hyper3_clip.data.grit_webdataset import ProcessedGritDataset
|
| 3 |
-
from hyper3_clip.data.manifest_dataset import GroundedManifestDataset
|
| 4 |
-
from hyper3_clip.data.mixed_dataset import MixedGroundedIterableDataset
|
| 5 |
-
from hyper3_clip.data.types import GroundedParent, GroundedRecord
|
| 6 |
-
|
| 7 |
-
__all__ = [
|
| 8 |
-
"GroundedManifestDataset",
|
| 9 |
-
"GroundedParent",
|
| 10 |
-
"GroundedRecord",
|
| 11 |
-
"MixedGroundedIterableDataset",
|
| 12 |
-
"ProcessedGritDataset",
|
| 13 |
-
"collate_grounded",
|
| 14 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/data/collators.py
DELETED
|
@@ -1,209 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import re
|
| 4 |
-
from typing import Any
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
def _attention_mask(tokens) -> torch.Tensor:
|
| 10 |
-
if "attention_mask" in tokens:
|
| 11 |
-
return tokens["attention_mask"]
|
| 12 |
-
return torch.ones_like(tokens["input_ids"])
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
def collate_grounded(
|
| 16 |
-
batch: list[dict[str, Any]],
|
| 17 |
-
tokenizer,
|
| 18 |
-
max_text_length: int,
|
| 19 |
-
*,
|
| 20 |
-
beta_clip_queries: bool = False,
|
| 21 |
-
beta_clip_max_sentences: int = 5,
|
| 22 |
-
beta_clip_max_phrases: int = 30,
|
| 23 |
-
beta_clip_max_queries_per_image: int | None = None,
|
| 24 |
-
beta_clip_use_part_texts: bool = True,
|
| 25 |
-
) -> dict[str, torch.Tensor]:
|
| 26 |
-
images = torch.stack([item["image"] for item in batch])
|
| 27 |
-
captions = [item["caption"] for item in batch]
|
| 28 |
-
part_image_rows: list[torch.Tensor] = []
|
| 29 |
-
part_texts: list[str] = []
|
| 30 |
-
part_owner: list[int] = []
|
| 31 |
-
|
| 32 |
-
for batch_index, item in enumerate(batch):
|
| 33 |
-
for part_index, part_image in enumerate(item["part_images"]):
|
| 34 |
-
part_image_rows.append(part_image)
|
| 35 |
-
part_texts.append(item["part_texts"][part_index])
|
| 36 |
-
part_owner.append(batch_index)
|
| 37 |
-
|
| 38 |
-
text = tokenizer(captions, padding=True, truncation=True, max_length=max_text_length, return_tensors="pt")
|
| 39 |
-
text_attention_mask = _attention_mask(text)
|
| 40 |
-
if part_image_rows:
|
| 41 |
-
part_images = torch.stack(part_image_rows)
|
| 42 |
-
part_text = tokenizer(part_texts, padding=True, truncation=True, max_length=max_text_length, return_tensors="pt")
|
| 43 |
-
part_text_input_ids = part_text["input_ids"]
|
| 44 |
-
part_text_attention_mask = _attention_mask(part_text)
|
| 45 |
-
else:
|
| 46 |
-
part_images = images.new_zeros((0, *images.shape[1:]))
|
| 47 |
-
empty_text_shape = (0, text["input_ids"].shape[1])
|
| 48 |
-
part_text_input_ids = text["input_ids"].new_zeros(empty_text_shape)
|
| 49 |
-
part_text_attention_mask = text_attention_mask.new_zeros(empty_text_shape)
|
| 50 |
-
|
| 51 |
-
collated = {
|
| 52 |
-
"image": images,
|
| 53 |
-
"part_images": part_images,
|
| 54 |
-
"part_owner": torch.tensor(part_owner, dtype=torch.long),
|
| 55 |
-
"text_input_ids": text["input_ids"],
|
| 56 |
-
"text_attention_mask": text_attention_mask,
|
| 57 |
-
"part_text_input_ids": part_text_input_ids,
|
| 58 |
-
"part_text_attention_mask": part_text_attention_mask,
|
| 59 |
-
}
|
| 60 |
-
|
| 61 |
-
if beta_clip_queries:
|
| 62 |
-
query_texts: list[str] = []
|
| 63 |
-
query_owner: list[int] = []
|
| 64 |
-
query_type: list[int] = []
|
| 65 |
-
query_parent: list[int] = []
|
| 66 |
-
query_weight: list[float] = []
|
| 67 |
-
query_source_part: list[int] = []
|
| 68 |
-
part_offsets = []
|
| 69 |
-
cursor = 0
|
| 70 |
-
for item in batch:
|
| 71 |
-
part_offsets.append(cursor)
|
| 72 |
-
cursor += len(item["part_images"])
|
| 73 |
-
for batch_index, item in enumerate(batch):
|
| 74 |
-
image_queries = _beta_clip_query_items_for_item(
|
| 75 |
-
caption=item["caption"],
|
| 76 |
-
part_texts=item["part_texts"],
|
| 77 |
-
max_sentences=beta_clip_max_sentences,
|
| 78 |
-
max_phrases=beta_clip_max_phrases,
|
| 79 |
-
max_queries=beta_clip_max_queries_per_image,
|
| 80 |
-
use_part_texts=beta_clip_use_part_texts,
|
| 81 |
-
)
|
| 82 |
-
query_offset = len(query_texts)
|
| 83 |
-
for query in image_queries:
|
| 84 |
-
query_texts.append(query["text"])
|
| 85 |
-
query_owner.append(batch_index)
|
| 86 |
-
query_type.append(query["type"])
|
| 87 |
-
local_parent = query["parent"]
|
| 88 |
-
query_parent.append(-1 if local_parent < 0 else query_offset + local_parent)
|
| 89 |
-
query_weight.append(query["weight"])
|
| 90 |
-
local_part = query["source_part"]
|
| 91 |
-
query_source_part.append(-1 if local_part < 0 else part_offsets[batch_index] + local_part)
|
| 92 |
-
query_tokens = tokenizer(query_texts, padding=True, truncation=True, max_length=max_text_length, return_tensors="pt")
|
| 93 |
-
collated.update(
|
| 94 |
-
{
|
| 95 |
-
"beta_query_input_ids": query_tokens["input_ids"],
|
| 96 |
-
"beta_query_attention_mask": _attention_mask(query_tokens),
|
| 97 |
-
"beta_query_owner": torch.tensor(query_owner, dtype=torch.long),
|
| 98 |
-
"beta_query_type": torch.tensor(query_type, dtype=torch.long),
|
| 99 |
-
"beta_query_parent": torch.tensor(query_parent, dtype=torch.long),
|
| 100 |
-
"beta_query_weight": torch.tensor(query_weight, dtype=torch.float32),
|
| 101 |
-
"beta_query_source_part": torch.tensor(query_source_part, dtype=torch.long),
|
| 102 |
-
}
|
| 103 |
-
)
|
| 104 |
-
|
| 105 |
-
return collated
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
def _beta_clip_queries_for_item(
|
| 109 |
-
*,
|
| 110 |
-
caption: str,
|
| 111 |
-
part_texts: list[str],
|
| 112 |
-
max_sentences: int,
|
| 113 |
-
max_phrases: int,
|
| 114 |
-
max_queries: int | None,
|
| 115 |
-
use_part_texts: bool,
|
| 116 |
-
) -> list[str]:
|
| 117 |
-
return [
|
| 118 |
-
query["text"]
|
| 119 |
-
for query in _beta_clip_query_items_for_item(
|
| 120 |
-
caption=caption,
|
| 121 |
-
part_texts=part_texts,
|
| 122 |
-
max_sentences=max_sentences,
|
| 123 |
-
max_phrases=max_phrases,
|
| 124 |
-
max_queries=max_queries,
|
| 125 |
-
use_part_texts=use_part_texts,
|
| 126 |
-
)
|
| 127 |
-
]
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
def _beta_clip_query_items_for_item(
|
| 131 |
-
*,
|
| 132 |
-
caption: str,
|
| 133 |
-
part_texts: list[str],
|
| 134 |
-
max_sentences: int,
|
| 135 |
-
max_phrases: int,
|
| 136 |
-
max_queries: int | None,
|
| 137 |
-
use_part_texts: bool,
|
| 138 |
-
) -> list[dict[str, str | int | float]]:
|
| 139 |
-
queries: list[str] = []
|
| 140 |
-
query_items: list[dict[str, str | int | float]] = []
|
| 141 |
-
seen: set[str] = set()
|
| 142 |
-
|
| 143 |
-
def add_query(text: str, *, query_type: int, parent: int = 0, weight: float = 1.0, source_part: int = -1) -> int:
|
| 144 |
-
normalized = " ".join(str(text).strip().split())
|
| 145 |
-
key = normalized.casefold()
|
| 146 |
-
if len(normalized) >= 3 and key not in seen:
|
| 147 |
-
seen.add(key)
|
| 148 |
-
queries.append(normalized)
|
| 149 |
-
query_items.append(
|
| 150 |
-
{
|
| 151 |
-
"text": normalized,
|
| 152 |
-
"type": query_type,
|
| 153 |
-
"parent": parent,
|
| 154 |
-
"weight": weight,
|
| 155 |
-
"source_part": source_part,
|
| 156 |
-
}
|
| 157 |
-
)
|
| 158 |
-
return len(queries) - 1
|
| 159 |
-
return -1
|
| 160 |
-
|
| 161 |
-
caption_index = add_query(caption, query_type=0, parent=-1, weight=0.0)
|
| 162 |
-
if caption_index < 0:
|
| 163 |
-
caption_index = 0
|
| 164 |
-
sentence_indices: list[int] = []
|
| 165 |
-
for sentence in _split_sentences(caption)[: max(0, max_sentences)]:
|
| 166 |
-
sentence_index = add_query(sentence, query_type=1, parent=caption_index, weight=1.0)
|
| 167 |
-
if sentence_index >= 0:
|
| 168 |
-
sentence_indices.append(sentence_index)
|
| 169 |
-
phrase_parent = sentence_indices[0] if sentence_indices else caption_index
|
| 170 |
-
|
| 171 |
-
phrase_count = 0
|
| 172 |
-
if use_part_texts:
|
| 173 |
-
for part_index, part_text in enumerate(part_texts):
|
| 174 |
-
if phrase_count >= max_phrases:
|
| 175 |
-
break
|
| 176 |
-
before = len(queries)
|
| 177 |
-
add_query(part_text, query_type=2, parent=caption_index, weight=0.75, source_part=part_index)
|
| 178 |
-
phrase_count += int(len(queries) > before)
|
| 179 |
-
|
| 180 |
-
if phrase_count < max_phrases:
|
| 181 |
-
for phrase in _extract_lightweight_phrases(caption):
|
| 182 |
-
if phrase_count >= max_phrases:
|
| 183 |
-
break
|
| 184 |
-
before = len(queries)
|
| 185 |
-
add_query(phrase, query_type=3, parent=phrase_parent, weight=0.5)
|
| 186 |
-
phrase_count += int(len(queries) > before)
|
| 187 |
-
|
| 188 |
-
if max_queries is not None:
|
| 189 |
-
return query_items[: max(1, int(max_queries))]
|
| 190 |
-
return query_items
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
def _split_sentences(text: str) -> list[str]:
|
| 194 |
-
return [part.strip() for part in re.split(r"(?<=[.!?;])\s+|\n+", text) if part.strip()]
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
def _extract_lightweight_phrases(text: str) -> list[str]:
|
| 198 |
-
chunks = re.split(r"[,;:()]|\s+(?:and|with|near|beside|behind|under|above|around|next to)\s+", text, flags=re.I)
|
| 199 |
-
phrases: list[str] = []
|
| 200 |
-
for chunk in chunks:
|
| 201 |
-
words = re.findall(r"[A-Za-z0-9]+(?:[-'][A-Za-z0-9]+)?", chunk)
|
| 202 |
-
if 2 <= len(words) <= 8:
|
| 203 |
-
phrases.append(" ".join(words))
|
| 204 |
-
elif len(words) > 8:
|
| 205 |
-
for start in range(0, len(words) - 1, 4):
|
| 206 |
-
phrase = " ".join(words[start : start + 6])
|
| 207 |
-
if len(phrase.split()) >= 2:
|
| 208 |
-
phrases.append(phrase)
|
| 209 |
-
return phrases
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/data/grit_cleaning.py
DELETED
|
@@ -1,554 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import math
|
| 4 |
-
import re
|
| 5 |
-
import unicodedata
|
| 6 |
-
from dataclasses import dataclass
|
| 7 |
-
from typing import Any
|
| 8 |
-
|
| 9 |
-
import numpy as np
|
| 10 |
-
from PIL import Image
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
SPACE_RE = re.compile(r"\s+")
|
| 14 |
-
URL_RE = re.compile(r"(https?://|www\.|\.com\b|\.net\b|\.org\b)", re.IGNORECASE)
|
| 15 |
-
EMAIL_RE = re.compile(r"\b[\w.+-]+@[\w.-]+\.\w+\b")
|
| 16 |
-
HTML_RE = re.compile(r"<[^>]+>")
|
| 17 |
-
TOKEN_RE = re.compile(r"[a-z0-9]+(?:[-'][a-z0-9]+)?")
|
| 18 |
-
|
| 19 |
-
LEADING_DETERMINERS = {
|
| 20 |
-
"a",
|
| 21 |
-
"an",
|
| 22 |
-
"the",
|
| 23 |
-
"this",
|
| 24 |
-
"that",
|
| 25 |
-
"these",
|
| 26 |
-
"those",
|
| 27 |
-
"his",
|
| 28 |
-
"her",
|
| 29 |
-
"its",
|
| 30 |
-
"their",
|
| 31 |
-
"my",
|
| 32 |
-
"our",
|
| 33 |
-
"your",
|
| 34 |
-
}
|
| 35 |
-
|
| 36 |
-
QUANTITY_WORDS = {
|
| 37 |
-
"one",
|
| 38 |
-
"two",
|
| 39 |
-
"three",
|
| 40 |
-
"four",
|
| 41 |
-
"five",
|
| 42 |
-
"six",
|
| 43 |
-
"many",
|
| 44 |
-
"several",
|
| 45 |
-
"some",
|
| 46 |
-
"few",
|
| 47 |
-
"group",
|
| 48 |
-
"pair",
|
| 49 |
-
}
|
| 50 |
-
|
| 51 |
-
VISUAL_MODIFIERS = {
|
| 52 |
-
"big",
|
| 53 |
-
"small",
|
| 54 |
-
"large",
|
| 55 |
-
"little",
|
| 56 |
-
"old",
|
| 57 |
-
"young",
|
| 58 |
-
"new",
|
| 59 |
-
"red",
|
| 60 |
-
"blue",
|
| 61 |
-
"green",
|
| 62 |
-
"yellow",
|
| 63 |
-
"white",
|
| 64 |
-
"black",
|
| 65 |
-
"brown",
|
| 66 |
-
"gray",
|
| 67 |
-
"grey",
|
| 68 |
-
"orange",
|
| 69 |
-
"pink",
|
| 70 |
-
"purple",
|
| 71 |
-
"colorful",
|
| 72 |
-
"colourful",
|
| 73 |
-
"wooden",
|
| 74 |
-
"metal",
|
| 75 |
-
"plastic",
|
| 76 |
-
"striped",
|
| 77 |
-
}
|
| 78 |
-
|
| 79 |
-
NON_VISUAL_HEADS = {
|
| 80 |
-
"background",
|
| 81 |
-
"foreground",
|
| 82 |
-
"caption",
|
| 83 |
-
"copyright",
|
| 84 |
-
"credit",
|
| 85 |
-
"item",
|
| 86 |
-
"edge",
|
| 87 |
-
"image",
|
| 88 |
-
"left",
|
| 89 |
-
"logo",
|
| 90 |
-
"method",
|
| 91 |
-
"middle",
|
| 92 |
-
"number",
|
| 93 |
-
"photo",
|
| 94 |
-
"photograph",
|
| 95 |
-
"picture",
|
| 96 |
-
"place",
|
| 97 |
-
"right",
|
| 98 |
-
"scene",
|
| 99 |
-
"side",
|
| 100 |
-
"statement",
|
| 101 |
-
"stock",
|
| 102 |
-
"text",
|
| 103 |
-
"thing",
|
| 104 |
-
"view",
|
| 105 |
-
"watermark",
|
| 106 |
-
}
|
| 107 |
-
|
| 108 |
-
NON_VISUAL_PHRASES = {
|
| 109 |
-
"available at",
|
| 110 |
-
"all rights reserved",
|
| 111 |
-
"click here",
|
| 112 |
-
"copyright",
|
| 113 |
-
"getty images",
|
| 114 |
-
"istock",
|
| 115 |
-
"shutterstock",
|
| 116 |
-
"stock photo",
|
| 117 |
-
}
|
| 118 |
-
|
| 119 |
-
ACTION_SPLITS = (
|
| 120 |
-
" standing ",
|
| 121 |
-
" sitting ",
|
| 122 |
-
" lying ",
|
| 123 |
-
" walking ",
|
| 124 |
-
" running ",
|
| 125 |
-
" flying ",
|
| 126 |
-
" eating ",
|
| 127 |
-
" holding ",
|
| 128 |
-
" wearing ",
|
| 129 |
-
" playing ",
|
| 130 |
-
)
|
| 131 |
-
|
| 132 |
-
PREPOSITION_SPLITS = (
|
| 133 |
-
" next to ",
|
| 134 |
-
" in front of ",
|
| 135 |
-
" on top of ",
|
| 136 |
-
" inside ",
|
| 137 |
-
" outside ",
|
| 138 |
-
" with ",
|
| 139 |
-
" without ",
|
| 140 |
-
" near ",
|
| 141 |
-
" beside ",
|
| 142 |
-
" behind ",
|
| 143 |
-
" under ",
|
| 144 |
-
" over ",
|
| 145 |
-
" from ",
|
| 146 |
-
" into ",
|
| 147 |
-
" across ",
|
| 148 |
-
" around ",
|
| 149 |
-
" at ",
|
| 150 |
-
" on ",
|
| 151 |
-
" in ",
|
| 152 |
-
" of ",
|
| 153 |
-
)
|
| 154 |
-
|
| 155 |
-
CANONICAL_REWRITES = {
|
| 156 |
-
"aeroplane": "airplane",
|
| 157 |
-
"aircraft": "airplane",
|
| 158 |
-
"bike": "bicycle",
|
| 159 |
-
"cell phone": "phone",
|
| 160 |
-
"mobile phone": "phone",
|
| 161 |
-
"motorbike": "motorcycle",
|
| 162 |
-
"plant pot": "potted plant",
|
| 163 |
-
"tv": "television",
|
| 164 |
-
}
|
| 165 |
-
|
| 166 |
-
TOKEN_SYNONYMS = {
|
| 167 |
-
"airplane": {"airplane", "aeroplane", "aircraft", "plane"},
|
| 168 |
-
"bicycle": {"bicycle", "bike"},
|
| 169 |
-
"motorcycle": {"motorcycle", "motorbike"},
|
| 170 |
-
"person": {"person", "people", "man", "woman", "boy", "girl", "teenager", "teenagers"},
|
| 171 |
-
"people": {"person", "people", "man", "woman", "men", "women", "children", "teenager", "teenagers"},
|
| 172 |
-
"phone": {"phone", "cell", "mobile", "telephone"},
|
| 173 |
-
"television": {"television", "tv"},
|
| 174 |
-
}
|
| 175 |
-
|
| 176 |
-
HUMAN_GROUP_WORDS = {
|
| 177 |
-
"adults",
|
| 178 |
-
"boys",
|
| 179 |
-
"children",
|
| 180 |
-
"crowd",
|
| 181 |
-
"girls",
|
| 182 |
-
"kids",
|
| 183 |
-
"men",
|
| 184 |
-
"people",
|
| 185 |
-
"teenagers",
|
| 186 |
-
"teens",
|
| 187 |
-
"women",
|
| 188 |
-
}
|
| 189 |
-
|
| 190 |
-
HUMAN_SINGULAR_WORDS = {
|
| 191 |
-
"adult",
|
| 192 |
-
"baby",
|
| 193 |
-
"boy",
|
| 194 |
-
"child",
|
| 195 |
-
"girl",
|
| 196 |
-
"kid",
|
| 197 |
-
"man",
|
| 198 |
-
"person",
|
| 199 |
-
"teenager",
|
| 200 |
-
"woman",
|
| 201 |
-
}
|
| 202 |
-
|
| 203 |
-
HUMAN_ROLE_WORDS = {
|
| 204 |
-
"actor",
|
| 205 |
-
"actress",
|
| 206 |
-
"artist",
|
| 207 |
-
"athlete",
|
| 208 |
-
"boss",
|
| 209 |
-
"coach",
|
| 210 |
-
"doctor",
|
| 211 |
-
"lawyer",
|
| 212 |
-
"manager",
|
| 213 |
-
"minister",
|
| 214 |
-
"musician",
|
| 215 |
-
"player",
|
| 216 |
-
"politician",
|
| 217 |
-
"president",
|
| 218 |
-
"singer",
|
| 219 |
-
"teacher",
|
| 220 |
-
}
|
| 221 |
-
|
| 222 |
-
DEFAULT_HYPERNYMS = {
|
| 223 |
-
"airplane": ("aircraft", "vehicle"),
|
| 224 |
-
"apple": ("fruit", "food"),
|
| 225 |
-
"backpack": ("bag", "accessory"),
|
| 226 |
-
"baseball bat": ("bat", "sports equipment"),
|
| 227 |
-
"bear": ("mammal", "animal"),
|
| 228 |
-
"bicycle": ("vehicle",),
|
| 229 |
-
"bird": ("animal",),
|
| 230 |
-
"boat": ("vehicle",),
|
| 231 |
-
"bottle": ("container",),
|
| 232 |
-
"bus": ("vehicle",),
|
| 233 |
-
"car": ("vehicle",),
|
| 234 |
-
"cat": ("mammal", "animal"),
|
| 235 |
-
"chair": ("furniture",),
|
| 236 |
-
"cup": ("container",),
|
| 237 |
-
"dog": ("mammal", "animal"),
|
| 238 |
-
"flower": ("plant",),
|
| 239 |
-
"fork": ("utensil",),
|
| 240 |
-
"horse": ("mammal", "animal"),
|
| 241 |
-
"knife": ("utensil",),
|
| 242 |
-
"lamp": ("light", "furniture"),
|
| 243 |
-
"laptop": ("computer", "electronic device"),
|
| 244 |
-
"person": ("human", "animal"),
|
| 245 |
-
"phone": ("electronic device",),
|
| 246 |
-
"potted plant": ("plant",),
|
| 247 |
-
"shirt": ("clothing",),
|
| 248 |
-
"shoe": ("footwear", "clothing"),
|
| 249 |
-
"skis": ("sports equipment",),
|
| 250 |
-
"spoon": ("utensil",),
|
| 251 |
-
"sports ball": ("ball", "sports equipment"),
|
| 252 |
-
"table": ("furniture",),
|
| 253 |
-
"television": ("electronic device",),
|
| 254 |
-
"train": ("vehicle",),
|
| 255 |
-
"tree": ("plant",),
|
| 256 |
-
"truck": ("vehicle",),
|
| 257 |
-
}
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
@dataclass(frozen=True)
|
| 261 |
-
class ImageQuality:
|
| 262 |
-
width: int
|
| 263 |
-
height: int
|
| 264 |
-
brightness: float
|
| 265 |
-
contrast: float
|
| 266 |
-
entropy: float
|
| 267 |
-
black_border_fraction: float
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
@dataclass(frozen=True)
|
| 271 |
-
class ParentCleanDecision:
|
| 272 |
-
original_text: str
|
| 273 |
-
canonical_text: str
|
| 274 |
-
keep: bool
|
| 275 |
-
quality_score: float
|
| 276 |
-
reasons: tuple[str, ...]
|
| 277 |
-
hypernyms: tuple[str, ...]
|
| 278 |
-
image_quality: ImageQuality | None = None
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
def clean_parent(
|
| 282 |
-
parent_text: str,
|
| 283 |
-
caption: str = "",
|
| 284 |
-
parent_image: Image.Image | None = None,
|
| 285 |
-
min_score: float = 0.45,
|
| 286 |
-
hypernym_map: dict[str, tuple[str, ...]] | None = None,
|
| 287 |
-
) -> ParentCleanDecision:
|
| 288 |
-
canonical = canonicalize_parent_text(parent_text)
|
| 289 |
-
reasons: list[str] = []
|
| 290 |
-
fatal = False
|
| 291 |
-
|
| 292 |
-
if not canonical:
|
| 293 |
-
reasons.append("empty_after_canonicalization")
|
| 294 |
-
fatal = True
|
| 295 |
-
if looks_like_boilerplate(parent_text):
|
| 296 |
-
reasons.append("boilerplate_or_url")
|
| 297 |
-
fatal = True
|
| 298 |
-
if canonical and is_non_visual_parent(canonical):
|
| 299 |
-
reasons.append("non_visual_parent")
|
| 300 |
-
fatal = True
|
| 301 |
-
if canonical and len(canonical.split()) > 6:
|
| 302 |
-
reasons.append("too_long_for_clean_parent")
|
| 303 |
-
if canonical and caption_duplicates_parent(caption, canonical):
|
| 304 |
-
reasons.append("duplicates_caption")
|
| 305 |
-
fatal = True
|
| 306 |
-
if canonical and not caption_mentions_parent(caption, canonical):
|
| 307 |
-
reasons.append("caption_does_not_mention_parent")
|
| 308 |
-
|
| 309 |
-
image_quality = image_quality_stats(parent_image) if parent_image is not None else None
|
| 310 |
-
if image_quality is not None:
|
| 311 |
-
if image_quality.entropy < 1.0 or image_quality.contrast < 3.0:
|
| 312 |
-
reasons.append("low_information_crop")
|
| 313 |
-
if image_quality.black_border_fraction > 0.65:
|
| 314 |
-
reasons.append("mostly_black_border")
|
| 315 |
-
if (
|
| 316 |
-
"caption_does_not_mention_parent" in reasons
|
| 317 |
-
and "low_information_crop" in reasons
|
| 318 |
-
and "mostly_black_border" in reasons
|
| 319 |
-
):
|
| 320 |
-
reasons.append("text_slide_or_bad_crop")
|
| 321 |
-
fatal = True
|
| 322 |
-
|
| 323 |
-
score = parent_quality_score(canonical, reasons, image_quality)
|
| 324 |
-
hmap = DEFAULT_HYPERNYMS if hypernym_map is None else hypernym_map
|
| 325 |
-
hypernyms = tuple(hmap.get(canonical, ()))
|
| 326 |
-
keep = not fatal and score >= min_score
|
| 327 |
-
return ParentCleanDecision(
|
| 328 |
-
original_text=parent_text,
|
| 329 |
-
canonical_text=canonical,
|
| 330 |
-
keep=keep,
|
| 331 |
-
quality_score=score,
|
| 332 |
-
reasons=tuple(reasons),
|
| 333 |
-
hypernyms=hypernyms,
|
| 334 |
-
image_quality=image_quality,
|
| 335 |
-
)
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
def canonicalize_parent_text(text: str) -> str:
|
| 339 |
-
text = normalize_text(text)
|
| 340 |
-
if not text:
|
| 341 |
-
return ""
|
| 342 |
-
text = strip_boilerplate_tail(text)
|
| 343 |
-
for marker in ACTION_SPLITS:
|
| 344 |
-
if marker in text:
|
| 345 |
-
text = text.split(marker, maxsplit=1)[0].strip()
|
| 346 |
-
break
|
| 347 |
-
human = canonicalize_human_text(text)
|
| 348 |
-
if human:
|
| 349 |
-
return human
|
| 350 |
-
for marker in PREPOSITION_SPLITS:
|
| 351 |
-
if marker in text:
|
| 352 |
-
text = text.split(marker, maxsplit=1)[0].strip()
|
| 353 |
-
break
|
| 354 |
-
tokens = TOKEN_RE.findall(text)
|
| 355 |
-
while tokens and (tokens[0] in LEADING_DETERMINERS or tokens[0] in QUANTITY_WORDS):
|
| 356 |
-
tokens.pop(0)
|
| 357 |
-
if len(tokens) > 2:
|
| 358 |
-
while tokens and tokens[0] in VISUAL_MODIFIERS:
|
| 359 |
-
tokens.pop(0)
|
| 360 |
-
candidate = " ".join(tokens).strip()
|
| 361 |
-
candidate = CANONICAL_REWRITES.get(candidate, candidate)
|
| 362 |
-
if candidate.endswith("s") and candidate[:-1] in DEFAULT_HYPERNYMS:
|
| 363 |
-
candidate = candidate[:-1]
|
| 364 |
-
return candidate
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
def canonicalize_human_text(text: str) -> str:
|
| 368 |
-
tokens = TOKEN_RE.findall(text)
|
| 369 |
-
if not tokens:
|
| 370 |
-
return ""
|
| 371 |
-
token_set = set(tokens)
|
| 372 |
-
if token_set.intersection(HUMAN_GROUP_WORDS):
|
| 373 |
-
return "people"
|
| 374 |
-
for word in ("baby", "woman", "man", "girl", "boy", "child", "teenager", "person"):
|
| 375 |
-
if word in token_set:
|
| 376 |
-
return word
|
| 377 |
-
if token_set.intersection(HUMAN_ROLE_WORDS):
|
| 378 |
-
return "person"
|
| 379 |
-
return ""
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
def normalize_text(text: str) -> str:
|
| 383 |
-
text = unicodedata.normalize("NFKC", str(text))
|
| 384 |
-
text = HTML_RE.sub(" ", text)
|
| 385 |
-
text = text.replace("_", " ").replace("/", " ")
|
| 386 |
-
text = text.strip().lower()
|
| 387 |
-
text = text.strip(" \t\r\n\"'.,;:!?()[]{}")
|
| 388 |
-
return SPACE_RE.sub(" ", text)
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
def strip_boilerplate_tail(text: str) -> str:
|
| 392 |
-
for marker in (" - available at ", " available at ", " | ", " © ", " copyright "):
|
| 393 |
-
if marker in text:
|
| 394 |
-
text = text.split(marker, maxsplit=1)[0]
|
| 395 |
-
return text.strip()
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
def looks_like_boilerplate(text: str) -> bool:
|
| 399 |
-
normalized = normalize_text(text)
|
| 400 |
-
if URL_RE.search(normalized) or EMAIL_RE.search(normalized):
|
| 401 |
-
return True
|
| 402 |
-
return any(phrase in normalized for phrase in NON_VISUAL_PHRASES)
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
def is_non_visual_parent(canonical: str) -> bool:
|
| 406 |
-
tokens = canonical.split()
|
| 407 |
-
if not tokens:
|
| 408 |
-
return True
|
| 409 |
-
if canonical in NON_VISUAL_HEADS:
|
| 410 |
-
return True
|
| 411 |
-
if tokens[-1] in NON_VISUAL_HEADS:
|
| 412 |
-
return True
|
| 413 |
-
if all(token.isdigit() for token in tokens):
|
| 414 |
-
return True
|
| 415 |
-
return False
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
def caption_mentions_parent(caption: str, canonical_parent: str) -> bool:
|
| 419 |
-
if not caption or not canonical_parent:
|
| 420 |
-
return True
|
| 421 |
-
caption_tokens = set(TOKEN_RE.findall(normalize_text(caption)))
|
| 422 |
-
parent_tokens = TOKEN_RE.findall(canonical_parent)
|
| 423 |
-
if not parent_tokens:
|
| 424 |
-
return False
|
| 425 |
-
for token in parent_tokens:
|
| 426 |
-
synonyms = TOKEN_SYNONYMS.get(token, {token})
|
| 427 |
-
if not caption_tokens.intersection(synonyms):
|
| 428 |
-
return False
|
| 429 |
-
return True
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
def caption_duplicates_parent(caption: str, canonical_parent: str) -> bool:
|
| 433 |
-
if not caption or not canonical_parent:
|
| 434 |
-
return False
|
| 435 |
-
caption_tokens = TOKEN_RE.findall(normalize_text(caption))
|
| 436 |
-
parent_tokens = TOKEN_RE.findall(canonical_parent)
|
| 437 |
-
if len(parent_tokens) < 6 or not caption_tokens:
|
| 438 |
-
return False
|
| 439 |
-
caption_set = set(caption_tokens)
|
| 440 |
-
parent_set = set(parent_tokens)
|
| 441 |
-
overlap = len(caption_set.intersection(parent_set))
|
| 442 |
-
return overlap / max(len(parent_set), 1) >= 0.85 and overlap / max(len(caption_set), 1) >= 0.65
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
def image_quality_stats(image: Image.Image | None) -> ImageQuality | None:
|
| 446 |
-
if image is None:
|
| 447 |
-
return None
|
| 448 |
-
rgb = image.convert("RGB")
|
| 449 |
-
width, height = rgb.size
|
| 450 |
-
gray = np.asarray(rgb.convert("L"), dtype=np.float32)
|
| 451 |
-
brightness = float(gray.mean())
|
| 452 |
-
contrast = float(gray.std())
|
| 453 |
-
hist, _ = np.histogram(gray, bins=64, range=(0, 256), density=True)
|
| 454 |
-
hist = hist[hist > 0]
|
| 455 |
-
entropy = float(-(hist * np.log2(hist)).sum())
|
| 456 |
-
border = _border_pixels(gray)
|
| 457 |
-
black_border_fraction = float((border < 8).mean()) if border.size else 0.0
|
| 458 |
-
return ImageQuality(
|
| 459 |
-
width=width,
|
| 460 |
-
height=height,
|
| 461 |
-
brightness=brightness,
|
| 462 |
-
contrast=contrast,
|
| 463 |
-
entropy=entropy,
|
| 464 |
-
black_border_fraction=black_border_fraction,
|
| 465 |
-
)
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
def parent_quality_score(canonical: str, reasons: list[str], image_quality: ImageQuality | None) -> float:
|
| 469 |
-
if not canonical:
|
| 470 |
-
return 0.0
|
| 471 |
-
score = 1.0
|
| 472 |
-
penalties = {
|
| 473 |
-
"caption_does_not_mention_parent": 0.20,
|
| 474 |
-
"too_long_for_clean_parent": 0.20,
|
| 475 |
-
"low_information_crop": 0.25,
|
| 476 |
-
"mostly_black_border": 0.15,
|
| 477 |
-
"non_visual_parent": 0.60,
|
| 478 |
-
"boilerplate_or_url": 0.80,
|
| 479 |
-
"duplicates_caption": 0.80,
|
| 480 |
-
"text_slide_or_bad_crop": 0.80,
|
| 481 |
-
}
|
| 482 |
-
for reason in reasons:
|
| 483 |
-
score -= penalties.get(reason, 0.10)
|
| 484 |
-
if image_quality is not None:
|
| 485 |
-
if image_quality.brightness < 8 or image_quality.brightness > 247:
|
| 486 |
-
score -= 0.10
|
| 487 |
-
if image_quality.contrast > 8 and image_quality.entropy > 2:
|
| 488 |
-
score += 0.05
|
| 489 |
-
if canonical in DEFAULT_HYPERNYMS:
|
| 490 |
-
score += 0.05
|
| 491 |
-
return max(0.0, min(1.0, score))
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
def merge_vlm_decision(
|
| 495 |
-
cheap: ParentCleanDecision,
|
| 496 |
-
vlm_payload: dict[str, Any] | None,
|
| 497 |
-
vlm_can_rescue: bool = False,
|
| 498 |
-
) -> ParentCleanDecision:
|
| 499 |
-
if not vlm_payload:
|
| 500 |
-
return cheap
|
| 501 |
-
reasons = list(cheap.reasons)
|
| 502 |
-
canonical = normalize_text(vlm_payload.get("canonical_parent") or cheap.canonical_text)
|
| 503 |
-
if canonical:
|
| 504 |
-
canonical = canonicalize_parent_text(canonical)
|
| 505 |
-
hypernyms = tuple(
|
| 506 |
-
normalize_text(value)
|
| 507 |
-
for value in vlm_payload.get("hypernyms", cheap.hypernyms)
|
| 508 |
-
if normalize_text(value)
|
| 509 |
-
)
|
| 510 |
-
quality_score = float(vlm_payload.get("quality_score", cheap.quality_score))
|
| 511 |
-
keep_payload = vlm_payload.get("keep")
|
| 512 |
-
if keep_payload is False:
|
| 513 |
-
reasons.append("vlm_reject")
|
| 514 |
-
keep = False
|
| 515 |
-
elif keep_payload is True and vlm_can_rescue:
|
| 516 |
-
keep = quality_score >= 0.45 and bool(canonical)
|
| 517 |
-
else:
|
| 518 |
-
keep = cheap.keep and keep_payload is not False
|
| 519 |
-
reject_reason = normalize_text(vlm_payload.get("reject_reason") or "")
|
| 520 |
-
if reject_reason:
|
| 521 |
-
reasons.append(f"vlm:{reject_reason}")
|
| 522 |
-
return ParentCleanDecision(
|
| 523 |
-
original_text=cheap.original_text,
|
| 524 |
-
canonical_text=canonical,
|
| 525 |
-
keep=keep,
|
| 526 |
-
quality_score=max(0.0, min(1.0, quality_score)),
|
| 527 |
-
reasons=tuple(dict.fromkeys(reasons)),
|
| 528 |
-
hypernyms=hypernyms,
|
| 529 |
-
image_quality=cheap.image_quality,
|
| 530 |
-
)
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
def expand_parent_texts(decision: ParentCleanDecision, add_hypernyms: bool) -> tuple[str, ...]:
|
| 534 |
-
if not decision.keep:
|
| 535 |
-
return ()
|
| 536 |
-
values = [decision.canonical_text]
|
| 537 |
-
if add_hypernyms:
|
| 538 |
-
values.extend(decision.hypernyms)
|
| 539 |
-
return tuple(dict.fromkeys(value for value in values if value))
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
def _border_pixels(gray: np.ndarray) -> np.ndarray:
|
| 543 |
-
if gray.ndim != 2 or min(gray.shape) < 4:
|
| 544 |
-
return np.asarray([], dtype=np.float32)
|
| 545 |
-
width = max(1, min(gray.shape) // 16)
|
| 546 |
-
top = gray[:width, :].reshape(-1)
|
| 547 |
-
bottom = gray[-width:, :].reshape(-1)
|
| 548 |
-
left = gray[:, :width].reshape(-1)
|
| 549 |
-
right = gray[:, -width:].reshape(-1)
|
| 550 |
-
return np.concatenate([top, bottom, left, right])
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
def finite_float(value: float) -> float:
|
| 554 |
-
return value if math.isfinite(value) else 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/data/grit_webdataset.py
DELETED
|
@@ -1,133 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import copy
|
| 4 |
-
import glob
|
| 5 |
-
import hashlib
|
| 6 |
-
import random
|
| 7 |
-
from collections.abc import Iterator, Sequence
|
| 8 |
-
from pathlib import Path
|
| 9 |
-
from typing import Any
|
| 10 |
-
|
| 11 |
-
import torch
|
| 12 |
-
import webdataset as wds
|
| 13 |
-
from PIL import Image
|
| 14 |
-
from torch.utils.data import IterableDataset, get_worker_info
|
| 15 |
-
|
| 16 |
-
from hyper3_clip.data.transforms import build_train_transform
|
| 17 |
-
from hyper3_clip.training.distributed import get_rank, get_world_size
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
PART_SAMPLING_MODES = {"random_one", "all"}
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
class ProcessedGritDataset(IterableDataset):
|
| 24 |
-
"""Reader for official HyCoCLIP processed GRIT shards."""
|
| 25 |
-
|
| 26 |
-
def __init__(
|
| 27 |
-
self,
|
| 28 |
-
tarfiles: Sequence[str],
|
| 29 |
-
image_size: int,
|
| 30 |
-
seed: int,
|
| 31 |
-
shuffle_buffer: int = 4000,
|
| 32 |
-
part_sampling: str = "random_one",
|
| 33 |
-
max_parts: int | None = None,
|
| 34 |
-
train_transform: str = "wide_random_crop",
|
| 35 |
-
image_normalization: str = "imagenet",
|
| 36 |
-
deterministic_transforms: bool = False,
|
| 37 |
-
) -> None:
|
| 38 |
-
self.tarfiles = _expand_tarfiles(tarfiles)
|
| 39 |
-
if not self.tarfiles:
|
| 40 |
-
raise FileNotFoundError(f"No GRIT processed shards matched {tarfiles!r}")
|
| 41 |
-
rank = get_rank()
|
| 42 |
-
world_size = get_world_size()
|
| 43 |
-
self.tarfiles = self.tarfiles[rank::world_size]
|
| 44 |
-
self.shuffle_buffer = shuffle_buffer
|
| 45 |
-
self.seed = seed
|
| 46 |
-
if part_sampling not in PART_SAMPLING_MODES:
|
| 47 |
-
raise ValueError(f"part_sampling must be one of {sorted(PART_SAMPLING_MODES)}, got {part_sampling!r}")
|
| 48 |
-
if max_parts is not None and max_parts <= 0:
|
| 49 |
-
raise ValueError("max_parts must be positive when set")
|
| 50 |
-
self.part_sampling = part_sampling
|
| 51 |
-
self.max_parts = max_parts
|
| 52 |
-
self.deterministic_transforms = deterministic_transforms
|
| 53 |
-
self.transform = build_train_transform(image_size, preset=train_transform, normalization=image_normalization)
|
| 54 |
-
|
| 55 |
-
def __iter__(self) -> Iterator[dict[str, Any]]:
|
| 56 |
-
worker = get_worker_info()
|
| 57 |
-
worker_id = worker.id if worker is not None else 0
|
| 58 |
-
shuffle_rng = random.Random(self.seed + get_rank() * 1_000_003 + worker_id)
|
| 59 |
-
part_rng = random.Random(self.seed + 31_415_926 + get_rank() * 1_000_003 + worker_id)
|
| 60 |
-
pipeline: Any = wds.DataPipeline(
|
| 61 |
-
wds.SimpleShardList(self.tarfiles, seed=self.seed),
|
| 62 |
-
wds.split_by_worker,
|
| 63 |
-
wds.tarfile_to_samples(),
|
| 64 |
-
wds.shuffle(self.shuffle_buffer, initial=self.shuffle_buffer, rng=shuffle_rng),
|
| 65 |
-
wds.decode("pil", handler=wds.warn_and_continue),
|
| 66 |
-
)
|
| 67 |
-
while True:
|
| 68 |
-
pipeline_copy = copy.deepcopy(pipeline)
|
| 69 |
-
for sample in pipeline_copy:
|
| 70 |
-
yield self._decode_sample(sample, part_rng)
|
| 71 |
-
|
| 72 |
-
def _decode_sample(self, sample: dict[str, Any], rng: random.Random) -> dict[str, Any]:
|
| 73 |
-
num_parents = int(_as_text(sample["numparents.txt"]))
|
| 74 |
-
parent_indices = self._select_parent_indices(num_parents, rng)
|
| 75 |
-
parent_keys = [f"parent{parent_index:03d}" for parent_index in parent_indices]
|
| 76 |
-
sample_key = _as_text(sample.get("__key__", ""))
|
| 77 |
-
return {
|
| 78 |
-
"image": self._transform_image(sample["child.jpg"], sample_key, "child"),
|
| 79 |
-
"caption": _as_text(sample["child.txt"]),
|
| 80 |
-
"part_images": [
|
| 81 |
-
self._transform_image(sample[f"{parent_key}.jpg"], sample_key, parent_key) for parent_key in parent_keys
|
| 82 |
-
],
|
| 83 |
-
"part_texts": [_as_text(sample[f"{parent_key}.txt"]) for parent_key in parent_keys],
|
| 84 |
-
}
|
| 85 |
-
|
| 86 |
-
def _select_parent_indices(self, num_parents: int, rng: random.Random) -> list[int]:
|
| 87 |
-
if self.part_sampling == "random_one":
|
| 88 |
-
return [rng.randrange(num_parents)]
|
| 89 |
-
parent_indices = list(range(num_parents))
|
| 90 |
-
if self.max_parts is not None and len(parent_indices) > self.max_parts:
|
| 91 |
-
parent_indices = sorted(rng.sample(parent_indices, k=self.max_parts))
|
| 92 |
-
return parent_indices
|
| 93 |
-
|
| 94 |
-
def _transform_image(self, value: Any, sample_key: str, role: str) -> torch.Tensor:
|
| 95 |
-
image = _as_image(value)
|
| 96 |
-
if not self.deterministic_transforms:
|
| 97 |
-
return self.transform(image)
|
| 98 |
-
transform_seed = _stable_seed(self.seed, sample_key, role)
|
| 99 |
-
python_random_state = random.getstate()
|
| 100 |
-
try:
|
| 101 |
-
random.seed(transform_seed)
|
| 102 |
-
with torch.random.fork_rng(devices=[]):
|
| 103 |
-
torch.manual_seed(transform_seed)
|
| 104 |
-
return self.transform(image)
|
| 105 |
-
finally:
|
| 106 |
-
random.setstate(python_random_state)
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
def _expand_tarfiles(tarfiles: Sequence[str]) -> list[str]:
|
| 110 |
-
expanded: list[str] = []
|
| 111 |
-
for pattern in tarfiles:
|
| 112 |
-
matches = sorted(glob.glob(pattern))
|
| 113 |
-
expanded.extend(matches if matches else [pattern])
|
| 114 |
-
return [str(Path(path)) for path in expanded]
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
def _as_text(value: Any) -> str:
|
| 118 |
-
return value.decode("utf-8") if isinstance(value, bytes) else str(value)
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
def _as_image(value: Any) -> Image.Image:
|
| 122 |
-
if not isinstance(value, Image.Image):
|
| 123 |
-
raise TypeError(f"Expected PIL image from WebDataset decode, got {type(value)!r}")
|
| 124 |
-
return value.convert("RGB")
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
def _stable_seed(seed: int, *parts: str) -> int:
|
| 128 |
-
digest = hashlib.blake2b(digest_size=8)
|
| 129 |
-
digest.update(str(seed).encode("utf-8"))
|
| 130 |
-
for part in parts:
|
| 131 |
-
digest.update(b"\0")
|
| 132 |
-
digest.update(part.encode("utf-8"))
|
| 133 |
-
return int.from_bytes(digest.digest(), byteorder="big", signed=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/data/manifest_dataset.py
DELETED
|
@@ -1,120 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import json
|
| 4 |
-
import random
|
| 5 |
-
from pathlib import Path
|
| 6 |
-
from typing import Any
|
| 7 |
-
|
| 8 |
-
import torch
|
| 9 |
-
from PIL import Image
|
| 10 |
-
from torch.utils.data import Dataset, get_worker_info
|
| 11 |
-
|
| 12 |
-
from hyper3_clip.data.collators import collate_grounded as collate_grounded
|
| 13 |
-
from hyper3_clip.data.transforms import build_train_transform
|
| 14 |
-
from hyper3_clip.data.types import GroundedParent, GroundedRecord
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
__all__ = ["GroundedManifestDataset", "collate_grounded"]
|
| 18 |
-
PART_SAMPLING_MODES = {"random_one", "all"}
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
class GroundedManifestDataset(Dataset):
|
| 22 |
-
"""Manifest dataset with one full image/caption and one or more grounded parents per row."""
|
| 23 |
-
|
| 24 |
-
def __init__(
|
| 25 |
-
self,
|
| 26 |
-
manifests: list[str] | str | Path,
|
| 27 |
-
image_size: int,
|
| 28 |
-
seed: int,
|
| 29 |
-
manifest_weights: list[float] | None = None,
|
| 30 |
-
part_sampling: str = "random_one",
|
| 31 |
-
max_parts: int | None = None,
|
| 32 |
-
train_transform: str = "wide_random_crop",
|
| 33 |
-
image_normalization: str = "imagenet",
|
| 34 |
-
) -> None:
|
| 35 |
-
manifest_paths = [str(manifests)] if isinstance(manifests, str | Path) else manifests
|
| 36 |
-
self.records: list[GroundedRecord] = []
|
| 37 |
-
source_records: list[list[GroundedRecord]] = []
|
| 38 |
-
for manifest_path in manifest_paths:
|
| 39 |
-
rows: list[GroundedRecord] = []
|
| 40 |
-
with Path(manifest_path).open("r", encoding="utf-8") as handle:
|
| 41 |
-
for line in handle:
|
| 42 |
-
if line.strip():
|
| 43 |
-
rows.append(GroundedRecord.from_json(json.loads(line)))
|
| 44 |
-
source_records.append(rows)
|
| 45 |
-
|
| 46 |
-
if manifest_weights is None:
|
| 47 |
-
for rows in source_records:
|
| 48 |
-
self.records.extend(rows)
|
| 49 |
-
else:
|
| 50 |
-
if len(manifest_weights) != len(source_records):
|
| 51 |
-
raise ValueError("manifest_weights must match manifests length")
|
| 52 |
-
max_len = max(len(rows) for rows in source_records if rows)
|
| 53 |
-
for rows, weight in zip(source_records, manifest_weights):
|
| 54 |
-
if not rows or weight <= 0.0:
|
| 55 |
-
continue
|
| 56 |
-
target_len = max(1, int(round(max_len * weight)))
|
| 57 |
-
for idx in range(target_len):
|
| 58 |
-
self.records.append(rows[idx % len(rows)])
|
| 59 |
-
|
| 60 |
-
self.seed = seed
|
| 61 |
-
if part_sampling not in PART_SAMPLING_MODES:
|
| 62 |
-
raise ValueError(f"part_sampling must be one of {sorted(PART_SAMPLING_MODES)}, got {part_sampling!r}")
|
| 63 |
-
if max_parts is not None and max_parts <= 0:
|
| 64 |
-
raise ValueError("max_parts must be positive when set")
|
| 65 |
-
self.part_sampling = part_sampling
|
| 66 |
-
self.max_parts = max_parts
|
| 67 |
-
self.transform = build_train_transform(image_size, preset=train_transform, normalization=image_normalization)
|
| 68 |
-
|
| 69 |
-
def __len__(self) -> int:
|
| 70 |
-
return len(self.records)
|
| 71 |
-
|
| 72 |
-
def __getitem__(self, index: int) -> dict[str, Any]:
|
| 73 |
-
record = self.records[index]
|
| 74 |
-
parents = self._select_parents(index, record.parents)
|
| 75 |
-
return {
|
| 76 |
-
"image": self._load_image(record.image_path),
|
| 77 |
-
"part_images": [self._load_parent_image(record.image_path, parent) for parent in parents],
|
| 78 |
-
"caption": record.caption,
|
| 79 |
-
"part_texts": [parent.text for parent in parents],
|
| 80 |
-
}
|
| 81 |
-
|
| 82 |
-
def _select_parents(self, index: int, parents: tuple[GroundedParent, ...]) -> tuple[GroundedParent, ...]:
|
| 83 |
-
if self.part_sampling == "all":
|
| 84 |
-
if self.max_parts is None or len(parents) <= self.max_parts:
|
| 85 |
-
return parents
|
| 86 |
-
worker = get_worker_info()
|
| 87 |
-
worker_id = worker.id if worker is not None else 0
|
| 88 |
-
rng = random.Random(self.seed + index + 1_000_003 * worker_id)
|
| 89 |
-
parent_indices = sorted(rng.sample(range(len(parents)), k=self.max_parts))
|
| 90 |
-
return tuple(parents[parent_index] for parent_index in parent_indices)
|
| 91 |
-
worker = get_worker_info()
|
| 92 |
-
worker_id = worker.id if worker is not None else 0
|
| 93 |
-
rng = random.Random(self.seed + index + 1_000_003 * worker_id)
|
| 94 |
-
return (parents[rng.randrange(len(parents))],)
|
| 95 |
-
|
| 96 |
-
def _load_image(self, path: Path) -> torch.Tensor:
|
| 97 |
-
with Image.open(path) as image:
|
| 98 |
-
return self.transform(image.convert("RGB"))
|
| 99 |
-
|
| 100 |
-
def _load_parent_image(self, image_path: Path, parent: GroundedParent) -> torch.Tensor:
|
| 101 |
-
source_path = parent.image_path or image_path
|
| 102 |
-
with Image.open(source_path) as image:
|
| 103 |
-
rgb = image.convert("RGB")
|
| 104 |
-
if parent.bbox is not None:
|
| 105 |
-
rgb = _crop_bbox(rgb, parent.bbox)
|
| 106 |
-
return self.transform(rgb)
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
def _crop_bbox(image: Image.Image, bbox: tuple[float, float, float, float]) -> Image.Image:
|
| 110 |
-
width, height = image.size
|
| 111 |
-
left, top, right, bottom = bbox
|
| 112 |
-
crop_box = (
|
| 113 |
-
max(0, min(width, int(round(left)))),
|
| 114 |
-
max(0, min(height, int(round(top)))),
|
| 115 |
-
max(0, min(width, int(round(right)))),
|
| 116 |
-
max(0, min(height, int(round(bottom)))),
|
| 117 |
-
)
|
| 118 |
-
if crop_box[2] <= crop_box[0] or crop_box[3] <= crop_box[1]:
|
| 119 |
-
return image
|
| 120 |
-
return image.crop(crop_box)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/data/mixed_dataset.py
DELETED
|
@@ -1,68 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import random
|
| 4 |
-
from collections.abc import Iterator
|
| 5 |
-
from typing import Any
|
| 6 |
-
|
| 7 |
-
from torch.utils.data import Dataset, IterableDataset, get_worker_info
|
| 8 |
-
|
| 9 |
-
from hyper3_clip.training.distributed import get_rank, get_world_size
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class MixedGroundedIterableDataset(IterableDataset):
|
| 13 |
-
"""Infinite stream that mixes a primary stream with a finite grounded dataset.
|
| 14 |
-
|
| 15 |
-
This is intended for cleaned processed-GRIT plus explicit taxonomy hierarchy
|
| 16 |
-
manifests. The primary stream remains the pacing dataset, while auxiliary
|
| 17 |
-
examples are sampled with a fixed probability.
|
| 18 |
-
"""
|
| 19 |
-
|
| 20 |
-
def __init__(
|
| 21 |
-
self,
|
| 22 |
-
primary: IterableDataset,
|
| 23 |
-
auxiliary: Dataset,
|
| 24 |
-
auxiliary_probability: float,
|
| 25 |
-
seed: int,
|
| 26 |
-
) -> None:
|
| 27 |
-
if not 0.0 <= auxiliary_probability <= 1.0:
|
| 28 |
-
raise ValueError("auxiliary_probability must be in [0, 1]")
|
| 29 |
-
if len(auxiliary) == 0:
|
| 30 |
-
raise ValueError("auxiliary dataset must not be empty")
|
| 31 |
-
self.primary = primary
|
| 32 |
-
self.auxiliary = auxiliary
|
| 33 |
-
self.auxiliary_probability = auxiliary_probability
|
| 34 |
-
self.seed = seed
|
| 35 |
-
|
| 36 |
-
def __iter__(self) -> Iterator[dict[str, Any]]:
|
| 37 |
-
worker = get_worker_info()
|
| 38 |
-
worker_id = worker.id if worker is not None else 0
|
| 39 |
-
num_workers = worker.num_workers if worker is not None else 1
|
| 40 |
-
rank = get_rank()
|
| 41 |
-
world_size = get_world_size()
|
| 42 |
-
rng = random.Random(self.seed + 1_000_003 * rank + 9_176 * worker_id)
|
| 43 |
-
primary_iter = iter(self.primary)
|
| 44 |
-
auxiliary_iter = self._iter_auxiliary_indices(rng, rank, world_size, worker_id, num_workers)
|
| 45 |
-
|
| 46 |
-
while True:
|
| 47 |
-
if rng.random() < self.auxiliary_probability:
|
| 48 |
-
yield self.auxiliary[next(auxiliary_iter)]
|
| 49 |
-
else:
|
| 50 |
-
yield next(primary_iter)
|
| 51 |
-
|
| 52 |
-
def _iter_auxiliary_indices(
|
| 53 |
-
self,
|
| 54 |
-
rng: random.Random,
|
| 55 |
-
rank: int,
|
| 56 |
-
world_size: int,
|
| 57 |
-
worker_id: int,
|
| 58 |
-
num_workers: int,
|
| 59 |
-
) -> Iterator[int]:
|
| 60 |
-
indices = list(range(len(self.auxiliary)))
|
| 61 |
-
indices = indices[rank::world_size]
|
| 62 |
-
indices = indices[worker_id::num_workers]
|
| 63 |
-
if not indices:
|
| 64 |
-
indices = list(range(len(self.auxiliary)))
|
| 65 |
-
while True:
|
| 66 |
-
shuffled = list(indices)
|
| 67 |
-
rng.shuffle(shuffled)
|
| 68 |
-
yield from shuffled
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/data/transforms.py
DELETED
|
@@ -1,125 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from torchvision import transforms
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
| 7 |
-
IMAGENET_STD = (0.229, 0.224, 0.225)
|
| 8 |
-
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
| 9 |
-
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
|
| 10 |
-
SIGLIP_MEAN = (0.5, 0.5, 0.5)
|
| 11 |
-
SIGLIP_STD = (0.5, 0.5, 0.5)
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def normalization_stats(normalization: str) -> tuple[tuple[float, float, float], tuple[float, float, float]]:
|
| 15 |
-
if normalization == "imagenet":
|
| 16 |
-
return IMAGENET_MEAN, IMAGENET_STD
|
| 17 |
-
if normalization == "clip":
|
| 18 |
-
return CLIP_MEAN, CLIP_STD
|
| 19 |
-
if normalization == "siglip":
|
| 20 |
-
return SIGLIP_MEAN, SIGLIP_STD
|
| 21 |
-
raise ValueError("normalization must be one of 'imagenet', 'clip', or 'siglip'")
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def build_train_transform(
|
| 25 |
-
image_size: int,
|
| 26 |
-
preset: str = "wide_random_crop",
|
| 27 |
-
normalization: str = "imagenet",
|
| 28 |
-
) -> transforms.Compose:
|
| 29 |
-
if preset == "wide_random_crop":
|
| 30 |
-
steps = [
|
| 31 |
-
transforms.RandomResizedCrop(
|
| 32 |
-
size=image_size,
|
| 33 |
-
scale=(0.5, 1.0),
|
| 34 |
-
interpolation=transforms.InterpolationMode.BICUBIC,
|
| 35 |
-
),
|
| 36 |
-
transforms.ToTensor(),
|
| 37 |
-
]
|
| 38 |
-
elif preset == "wide_random_crop_light_color":
|
| 39 |
-
steps = [
|
| 40 |
-
transforms.RandomResizedCrop(
|
| 41 |
-
size=image_size,
|
| 42 |
-
scale=(0.5, 1.0),
|
| 43 |
-
interpolation=transforms.InterpolationMode.BICUBIC,
|
| 44 |
-
),
|
| 45 |
-
transforms.RandomApply(
|
| 46 |
-
[
|
| 47 |
-
transforms.ColorJitter(
|
| 48 |
-
brightness=0.2,
|
| 49 |
-
contrast=0.2,
|
| 50 |
-
saturation=0.2,
|
| 51 |
-
hue=0.05,
|
| 52 |
-
)
|
| 53 |
-
],
|
| 54 |
-
p=0.4,
|
| 55 |
-
),
|
| 56 |
-
transforms.ToTensor(),
|
| 57 |
-
]
|
| 58 |
-
elif preset == "medium_random_crop":
|
| 59 |
-
steps = [
|
| 60 |
-
transforms.RandomResizedCrop(
|
| 61 |
-
size=image_size,
|
| 62 |
-
scale=(0.6, 1.0),
|
| 63 |
-
interpolation=transforms.InterpolationMode.BICUBIC,
|
| 64 |
-
),
|
| 65 |
-
transforms.ToTensor(),
|
| 66 |
-
]
|
| 67 |
-
elif preset == "center_crop":
|
| 68 |
-
steps = [
|
| 69 |
-
transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC),
|
| 70 |
-
transforms.CenterCrop(image_size),
|
| 71 |
-
transforms.ToTensor(),
|
| 72 |
-
]
|
| 73 |
-
elif preset == "tight_crop_color_jitter_gray":
|
| 74 |
-
steps = [
|
| 75 |
-
transforms.RandomResizedCrop(
|
| 76 |
-
size=image_size,
|
| 77 |
-
scale=(0.8, 1.0),
|
| 78 |
-
interpolation=transforms.InterpolationMode.BICUBIC,
|
| 79 |
-
),
|
| 80 |
-
transforms.RandomApply(
|
| 81 |
-
[
|
| 82 |
-
transforms.ColorJitter(
|
| 83 |
-
brightness=0.4,
|
| 84 |
-
contrast=0.4,
|
| 85 |
-
saturation=0.4,
|
| 86 |
-
hue=0.1,
|
| 87 |
-
)
|
| 88 |
-
],
|
| 89 |
-
p=0.8,
|
| 90 |
-
),
|
| 91 |
-
transforms.RandomGrayscale(p=0.2),
|
| 92 |
-
transforms.ToTensor(),
|
| 93 |
-
]
|
| 94 |
-
else:
|
| 95 |
-
raise ValueError(
|
| 96 |
-
f"Unsupported train transform preset {preset!r}; "
|
| 97 |
-
"expected 'wide_random_crop', 'wide_random_crop_light_color', "
|
| 98 |
-
"'medium_random_crop', 'tight_crop_color_jitter_gray', or 'center_crop'"
|
| 99 |
-
)
|
| 100 |
-
|
| 101 |
-
mean, std = normalization_stats(normalization)
|
| 102 |
-
return transforms.Compose([*steps, transforms.Normalize(mean=mean, std=std)])
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
def build_eval_transform(image_size: int, normalization: str = "imagenet") -> transforms.Compose:
|
| 106 |
-
mean, std = normalization_stats(normalization)
|
| 107 |
-
return transforms.Compose(
|
| 108 |
-
[
|
| 109 |
-
transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC),
|
| 110 |
-
transforms.CenterCrop(image_size),
|
| 111 |
-
transforms.ToTensor(),
|
| 112 |
-
transforms.Normalize(mean=mean, std=std),
|
| 113 |
-
]
|
| 114 |
-
)
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
def build_retrieval_transform(image_size: int, normalization: str = "imagenet") -> transforms.Compose:
|
| 118 |
-
mean, std = normalization_stats(normalization)
|
| 119 |
-
return transforms.Compose(
|
| 120 |
-
[
|
| 121 |
-
transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
|
| 122 |
-
transforms.ToTensor(),
|
| 123 |
-
transforms.Normalize(mean=mean, std=std),
|
| 124 |
-
]
|
| 125 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/data/types.py
DELETED
|
@@ -1,48 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from dataclasses import dataclass
|
| 4 |
-
from pathlib import Path
|
| 5 |
-
from typing import Any
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
@dataclass(frozen=True)
|
| 9 |
-
class GroundedParent:
|
| 10 |
-
text: str
|
| 11 |
-
image_path: Path | None = None
|
| 12 |
-
bbox: tuple[float, float, float, float] | None = None
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
@dataclass(frozen=True)
|
| 16 |
-
class GroundedRecord:
|
| 17 |
-
image_path: Path
|
| 18 |
-
caption: str
|
| 19 |
-
parents: tuple[GroundedParent, ...]
|
| 20 |
-
|
| 21 |
-
@classmethod
|
| 22 |
-
def from_json(cls, payload: dict[str, Any]) -> "GroundedRecord":
|
| 23 |
-
parents_payload = payload.get("parents")
|
| 24 |
-
if parents_payload is None:
|
| 25 |
-
parents_payload = [
|
| 26 |
-
{
|
| 27 |
-
"text": payload.get("box_text", ""),
|
| 28 |
-
"image_path": payload.get("box_image_path"),
|
| 29 |
-
"bbox": payload.get("bbox"),
|
| 30 |
-
}
|
| 31 |
-
]
|
| 32 |
-
|
| 33 |
-
parents: list[GroundedParent] = []
|
| 34 |
-
for parent_payload in parents_payload:
|
| 35 |
-
text = str(parent_payload.get("text") or parent_payload.get("box_text") or "").strip()
|
| 36 |
-
image_path = parent_payload.get("image_path") or parent_payload.get("box_image_path")
|
| 37 |
-
bbox_payload = parent_payload.get("bbox")
|
| 38 |
-
bbox = None
|
| 39 |
-
if bbox_payload is not None:
|
| 40 |
-
if len(bbox_payload) != 4:
|
| 41 |
-
raise ValueError(f"Expected four bbox values, got {bbox_payload!r}")
|
| 42 |
-
bbox = tuple(float(value) for value in bbox_payload)
|
| 43 |
-
parents.append(GroundedParent(text=text, image_path=Path(image_path) if image_path else None, bbox=bbox))
|
| 44 |
-
|
| 45 |
-
if not parents:
|
| 46 |
-
raise ValueError("Grounded records must include at least one parent")
|
| 47 |
-
|
| 48 |
-
return cls(image_path=Path(payload["image_path"]), caption=str(payload["caption"]), parents=tuple(parents))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/evaluation/__init__.py
DELETED
|
@@ -1,20 +0,0 @@
|
|
| 1 |
-
from hyper3_clip.evaluation.classification import evaluate_imagenet_zero_shot
|
| 2 |
-
from hyper3_clip.evaluation.hierarchical import evaluate_imagenet_hierarchical
|
| 3 |
-
from hyper3_clip.evaluation.pep import PEPEntailmentDataset, evaluate_pep_entailment
|
| 4 |
-
from hyper3_clip.evaluation.retrieval import (
|
| 5 |
-
CocoCaptionRetrieval,
|
| 6 |
-
CocoKarpathyCaptionRetrieval,
|
| 7 |
-
Flickr30kCaptionRetrieval,
|
| 8 |
-
evaluate_caption_retrieval,
|
| 9 |
-
)
|
| 10 |
-
|
| 11 |
-
__all__ = [
|
| 12 |
-
"CocoCaptionRetrieval",
|
| 13 |
-
"CocoKarpathyCaptionRetrieval",
|
| 14 |
-
"Flickr30kCaptionRetrieval",
|
| 15 |
-
"PEPEntailmentDataset",
|
| 16 |
-
"evaluate_caption_retrieval",
|
| 17 |
-
"evaluate_imagenet_hierarchical",
|
| 18 |
-
"evaluate_imagenet_zero_shot",
|
| 19 |
-
"evaluate_pep_entailment",
|
| 20 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/evaluation/classification.py
DELETED
|
@@ -1,105 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
from torch.utils.data import DataLoader
|
| 7 |
-
from torch.utils.data import Subset
|
| 8 |
-
from torchvision import datasets
|
| 9 |
-
|
| 10 |
-
from hyper3_clip.data.transforms import build_eval_transform
|
| 11 |
-
|
| 12 |
-
from hyper3_clip.models.hyper3_clip import Hyper3CLIP
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
IMAGENET_PROMPTS = (
|
| 16 |
-
"i took a picture : itap of a {}.",
|
| 17 |
-
"pics : a bad photo of the {}.",
|
| 18 |
-
"pics : a origami {}.",
|
| 19 |
-
"pics : a photo of the large {}.",
|
| 20 |
-
"pics : a {} in a video game.",
|
| 21 |
-
"pics : art of the {}.",
|
| 22 |
-
"pics : a photo of the small {}.",
|
| 23 |
-
)
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
@torch.inference_mode()
|
| 27 |
-
def evaluate_imagenet_zero_shot(
|
| 28 |
-
model: Hyper3CLIP,
|
| 29 |
-
imagenet_val_root: str | Path,
|
| 30 |
-
device: torch.device,
|
| 31 |
-
batch_size: int = 128,
|
| 32 |
-
image_size: int = 224,
|
| 33 |
-
max_text_length: int = 77,
|
| 34 |
-
max_items: int | None = None,
|
| 35 |
-
prompts: tuple[str, ...] = IMAGENET_PROMPTS,
|
| 36 |
-
) -> dict[str, float]:
|
| 37 |
-
model.eval()
|
| 38 |
-
dataset = datasets.ImageFolder(str(imagenet_val_root), transform=build_eval_transform(image_size))
|
| 39 |
-
class_names = _imagenet_prompt_names(dataset.classes, Path(imagenet_val_root))
|
| 40 |
-
classifier = _build_text_classifier(model, class_names, prompts, device, max_text_length)
|
| 41 |
-
eval_dataset = Subset(dataset, range(min(max_items, len(dataset)))) if max_items is not None else dataset
|
| 42 |
-
loader = DataLoader(eval_dataset, batch_size=batch_size, num_workers=4, pin_memory=device.type == "cuda")
|
| 43 |
-
correct = 0
|
| 44 |
-
total = 0
|
| 45 |
-
per_class_correct = torch.zeros(len(dataset.classes), dtype=torch.float64)
|
| 46 |
-
per_class_total = torch.zeros(len(dataset.classes), dtype=torch.float64)
|
| 47 |
-
for images, targets in loader:
|
| 48 |
-
images = images.to(device, non_blocking=True)
|
| 49 |
-
targets = targets.to(device, non_blocking=True)
|
| 50 |
-
predictions = model.similarity_scores(model.encode_image(images), classifier).argmax(dim=1)
|
| 51 |
-
matches = predictions == targets
|
| 52 |
-
correct += int(matches.sum().item())
|
| 53 |
-
total += targets.numel()
|
| 54 |
-
per_class_correct.scatter_add_(0, targets.cpu(), matches.cpu().double())
|
| 55 |
-
per_class_total.scatter_add_(0, targets.cpu(), torch.ones_like(targets.cpu(), dtype=torch.float64))
|
| 56 |
-
|
| 57 |
-
observed_classes = per_class_total > 0
|
| 58 |
-
mean_per_class = (per_class_correct[observed_classes] / per_class_total[observed_classes]).mean().item()
|
| 59 |
-
top1 = correct / max(total, 1)
|
| 60 |
-
return {"top1": top1, "top1_pct": 100.0 * top1, "mean_per_class_acc_pct": 100.0 * mean_per_class}
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
def _build_text_classifier(
|
| 64 |
-
model: Hyper3CLIP,
|
| 65 |
-
class_names: list[str],
|
| 66 |
-
prompts: tuple[str, ...],
|
| 67 |
-
device: torch.device,
|
| 68 |
-
max_text_length: int,
|
| 69 |
-
) -> torch.Tensor:
|
| 70 |
-
class_embeddings: list[torch.Tensor] = []
|
| 71 |
-
for class_name in class_names:
|
| 72 |
-
readable_name = class_name.replace("_", " ")
|
| 73 |
-
texts = [prompt.format(readable_name) for prompt in prompts]
|
| 74 |
-
tokenized = model.tokenizer(
|
| 75 |
-
texts,
|
| 76 |
-
padding=True,
|
| 77 |
-
truncation=True,
|
| 78 |
-
max_length=max_text_length,
|
| 79 |
-
return_tensors="pt",
|
| 80 |
-
).to(device)
|
| 81 |
-
attention_mask = (
|
| 82 |
-
tokenized.attention_mask if "attention_mask" in tokenized else torch.ones_like(tokenized.input_ids)
|
| 83 |
-
)
|
| 84 |
-
tangent = model.encode_text(tokenized.input_ids, attention_mask, project=False).float().mean(dim=0, keepdim=True)
|
| 85 |
-
class_embeddings.append(model.project_text_features(tangent).squeeze(0))
|
| 86 |
-
return torch.stack(class_embeddings, dim=0)
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
def _imagenet_prompt_names(class_names: list[str], imagenet_val_root: Path) -> list[str]:
|
| 90 |
-
if len(class_names) == 1000 and all(_looks_like_wnid(class_name) for class_name in class_names):
|
| 91 |
-
from torchvision.models._meta import _IMAGENET_CATEGORIES
|
| 92 |
-
|
| 93 |
-
label_index = imagenet_val_root / "imagenet_label_to_wnid.tsv"
|
| 94 |
-
if label_index.exists():
|
| 95 |
-
wnid_to_label: dict[str, int] = {}
|
| 96 |
-
for line in label_index.read_text(encoding="utf-8").splitlines():
|
| 97 |
-
label, wnid = line.split("\t", maxsplit=1)
|
| 98 |
-
wnid_to_label[wnid] = int(label)
|
| 99 |
-
return [_IMAGENET_CATEGORIES[wnid_to_label[class_name]] for class_name in class_names]
|
| 100 |
-
return list(_IMAGENET_CATEGORIES)
|
| 101 |
-
return class_names
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
def _looks_like_wnid(class_name: str) -> bool:
|
| 105 |
-
return len(class_name) == 9 and class_name.startswith("n") and class_name[1:].isdigit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/evaluation/hierarchical.py
DELETED
|
@@ -1,118 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import csv
|
| 4 |
-
import pickle
|
| 5 |
-
from pathlib import Path
|
| 6 |
-
|
| 7 |
-
import networkx as nx
|
| 8 |
-
import torch
|
| 9 |
-
from torch.utils.data import DataLoader, Subset
|
| 10 |
-
from torchvision import datasets
|
| 11 |
-
|
| 12 |
-
from hyper3_clip.data.transforms import build_eval_transform
|
| 13 |
-
from hyper3_clip.evaluation.classification import IMAGENET_PROMPTS, _build_text_classifier, _imagenet_prompt_names, _looks_like_wnid
|
| 14 |
-
from hyper3_clip.models.hyper3_clip import Hyper3CLIP
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
@torch.inference_mode()
def evaluate_imagenet_hierarchical(
    model: Hyper3CLIP,
    imagenet_val_root: str | Path,
    assets_root: str | Path,
    device: torch.device,
    batch_size: int = 128,
    image_size: int = 224,
    max_text_length: int = 77,
    max_items: int | None = None,
    prompts: tuple[str, ...] = IMAGENET_PROMPTS,
) -> dict[str, float]:
    """Zero-shot classify an ImageFolder and average hierarchy-aware error metrics.

    Args:
        model: Model used for ``encode_image`` and ``similarity_scores``.
        imagenet_val_root: Root directory handed to ``datasets.ImageFolder``.
        assets_root: Directory holding ``all_synsets.pkl``,
            ``all_ancestors_indices.pkl`` and ``imagenet_isa.txt``.
        device: Device the images and index tensors are moved to.
        batch_size: DataLoader batch size.
        image_size: Size forwarded to ``build_eval_transform``.
        max_text_length: Forwarded to the text-classifier builder.
        max_items: When set, only the first ``max_items`` dataset entries are evaluated.
        prompts: Prompt templates forwarded to the text-classifier builder.

    Returns:
        A dict with keys ``tie``, ``lca``, ``jaccard``, ``hierarchical_precision``
        and ``hierarchical_recall``; each is the per-batch total summed over all
        batches and divided by the number of evaluated images (or by 1 when no
        image was evaluated).
    """
    model.eval()
    imagenet_root = Path(imagenet_val_root)
    dataset = datasets.ImageFolder(str(imagenet_root), transform=build_eval_transform(image_size))
    class_names = _imagenet_prompt_names(dataset.classes, imagenet_root)
    classifier = _build_text_classifier(model, class_names, prompts, device, max_text_length)
    eval_dataset = Subset(dataset, range(min(max_items, len(dataset)))) if max_items is not None else dataset
    loader = DataLoader(eval_dataset, batch_size=batch_size, num_workers=4, pin_memory=device.type == "cuda")

    assets_path = Path(assets_root)
    # NOTE(security): pickle.load can execute arbitrary code; assets_root must only
    # contain trusted files. The handles are closed deterministically via ``with``.
    with (assets_path / "all_synsets.pkl").open("rb") as handle:
        synsets_ordering = pickle.load(handle)
    with (assets_path / "all_ancestors_indices.pkl").open("rb") as handle:
        ancestor_indices = pickle.load(handle)
    graph = _create_graph_from_edges(assets_path / "imagenet_isa.txt")
    # Lookup tensor translating ImageFolder class indices into positions of synsets_ordering.
    dataset_to_official = _dataset_to_official_indices(dataset.classes, imagenet_root, synsets_ordering).to(device)

    totals = torch.zeros(5, dtype=torch.float64)
    total_count = 0
    for images, targets in loader:
        images = images.to(device, non_blocking=True)
        official_targets = dataset_to_official[targets.to(device, non_blocking=True)]
        dataset_predictions = model.similarity_scores(model.encode_image(images), classifier).argmax(dim=1)
        official_predictions = dataset_to_official[dataset_predictions]
        batch_totals = _hierarchical_totals(
            official_predictions.cpu().tolist(),
            official_targets.cpu().tolist(),
            ancestor_indices,
            graph,
            synsets_ordering,
        )
        totals += torch.tensor(batch_totals, dtype=torch.float64)
        total_count += int(official_targets.numel())

    # max(..., 1) avoids dividing by zero when the loader yields nothing.
    averages = totals / max(total_count, 1)
    return {
        "tie": float(averages[0].item()),
        "lca": float(averages[1].item()),
        "jaccard": float(averages[2].item()),
        "hierarchical_precision": float(averages[3].item()),
        "hierarchical_recall": float(averages[4].item()),
    }
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
def _create_graph_from_edges(edge_file: Path) -> nx.DiGraph:
    """Build a directed graph from a space-delimited ``parent child`` edge list.

    Blank lines are skipped so that a spacer or trailing empty line does not
    abort loading. Any other row that does not contain exactly two fields still
    raises ``ValueError`` from the unpacking below.
    """
    graph = nx.DiGraph()
    with edge_file.open("r", encoding="utf-8") as handle:
        for row in csv.reader(handle, delimiter=" "):
            if not row:
                # csv.reader yields an empty list for a blank line.
                continue
            parent, child = row
            graph.add_edge(parent, child)
    return graph
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
def _dataset_to_official_indices(class_names: list[str], imagenet_val_root: Path, synsets_ordering: list[str]) -> torch.Tensor:
|
| 80 |
-
label_index = imagenet_val_root / "imagenet_label_to_wnid.tsv"
|
| 81 |
-
if label_index.exists():
|
| 82 |
-
wnid_to_label = {}
|
| 83 |
-
for line in label_index.read_text(encoding="utf-8").splitlines():
|
| 84 |
-
label, wnid = line.split("\t", maxsplit=1)
|
| 85 |
-
wnid_to_label[wnid] = int(label)
|
| 86 |
-
return torch.tensor([wnid_to_label[class_name] for class_name in class_names], dtype=torch.long)
|
| 87 |
-
if all(_looks_like_wnid(class_name) for class_name in class_names):
|
| 88 |
-
synset_to_label = {synset: label for label, synset in enumerate(synsets_ordering)}
|
| 89 |
-
return torch.tensor([synset_to_label[class_name] for class_name in class_names], dtype=torch.long)
|
| 90 |
-
return torch.arange(len(class_names), dtype=torch.long)
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
def _hierarchical_totals(
    predicted_labels: list[int],
    true_labels: list[int],
    ancestor_indices: list[list[int]],
    graph: nx.DiGraph,
    synsets_ordering: list[str],
) -> tuple[float, float, float, float, float]:
    """Sum five hierarchy metrics over paired predicted/true labels.

    Returns the un-normalised sums, in this order: shortest-path length between
    the two synsets in the undirected graph, ``|pred ancestors| - |shared| + 1``,
    Jaccard overlap of the ancestor sets, shared / predicted ancestors, and
    shared / true ancestors. The caller is responsible for averaging.
    """
    symmetric_graph = graph.to_undirected()
    path_sum = 0.0
    lca_sum = 0.0
    jaccard_sum = 0.0
    precision_sum = 0.0
    recall_sum = 0.0
    for predicted, expected in zip(predicted_labels, true_labels):
        predicted_synset = synsets_ordering[predicted]
        expected_synset = synsets_ordering[expected]
        predicted_set = set(ancestor_indices[predicted])
        expected_set = set(ancestor_indices[expected])
        shared = predicted_set & expected_set
        combined = predicted_set | expected_set
        path_sum += nx.shortest_path_length(symmetric_graph, source=predicted_synset, target=expected_synset)
        lca_sum += len(predicted_set) - len(shared) + 1
        jaccard_sum += len(shared) / len(combined)
        precision_sum += len(shared) / len(predicted_set)
        recall_sum += len(shared) / len(expected_set)
    return path_sum, lca_sum, jaccard_sum, precision_sum, recall_sum
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/evaluation/pep.py
DELETED
|
@@ -1,462 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import csv
|
| 4 |
-
import hashlib
|
| 5 |
-
import json
|
| 6 |
-
import math
|
| 7 |
-
from urllib.parse import urlparse
|
| 8 |
-
import urllib.request
|
| 9 |
-
from dataclasses import dataclass
|
| 10 |
-
from io import BytesIO
|
| 11 |
-
from pathlib import Path
|
| 12 |
-
from typing import Any
|
| 13 |
-
|
| 14 |
-
import torch
|
| 15 |
-
from PIL import Image
|
| 16 |
-
from torch.utils.data import Dataset
|
| 17 |
-
|
| 18 |
-
from hyper3_clip.data.transforms import build_eval_transform
|
| 19 |
-
from hyper3_clip.models.hyper3_clip import Hyper3CLIP
|
| 20 |
-
from hyper3_clip.models.losses import factor_oxy_angle
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
@dataclass(frozen=True)
class PEPSample:
    """One immutable evaluation record: an image reference plus its captions."""

    # Identifier used in error messages and returned by the dataset.
    image_id: str
    # Local file to open; None when the image is only reachable through image_url.
    image_path: Path | None
    # Remote location of the image; None when only a local path is known.
    image_url: str | None
    # Captions treated as positives for this image; the last entry is used as the
    # "fine" caption when building the all_fine_captions negative pool.
    positive_captions: tuple[str, ...]
    # Captions treated as negatives for this image; empty by default.
    negative_captions: tuple[str, ...] = ()
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
class PEPEntailmentDataset(Dataset):
    """Dataset yielding a transformed image together with its positive/negative captions."""

    def __init__(
        self,
        annotations_path: str | Path,
        image_root: str | Path | None = None,
        image_size: int = 224,
        max_items: int | None = None,
        image_cache_dir: str | Path | None = None,
        allow_image_download: bool = False,
    ) -> None:
        """Load the annotation file and prepare the evaluation transform."""
        loaded_samples, shared_negatives = load_pep_samples(
            annotations_path,
            image_root=image_root,
            max_items=max_items,
        )
        self.samples = loaded_samples
        self.global_negative_captions = shared_negatives
        self.transform = build_eval_transform(image_size)
        if image_cache_dir is None:
            self.image_cache_dir = None
        else:
            self.image_cache_dir = Path(image_cache_dir)
        self.allow_image_download = allow_image_download

    def __len__(self) -> int:
        """Number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, index: int) -> dict[str, Any]:
        """Return the transformed image, its id and its caption tuples for ``index``."""
        record = self.samples[index]
        picture = _load_sample_image(record, self.image_cache_dir, self.allow_image_download)
        tensor = self.transform(picture.convert("RGB"))
        return {
            "image": tensor,
            "image_id": record.image_id,
            "positive_captions": record.positive_captions,
            "negative_captions": record.negative_captions,
        }
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
@torch.inference_mode()
def evaluate_pep_entailment(
    model: Hyper3CLIP,
    annotations_path: str | Path,
    device: torch.device,
    image_root: str | Path | None = None,
    image_size: int = 224,
    max_text_length: int = 77,
    batch_size: int = 128,
    max_items: int | None = None,
    image_cache_dir: str | Path | None = None,
    allow_image_download: bool = False,
    negative_pool_strategy: str = "annotation",
    max_negatives_per_image: int | None = None,
    pair_batch_size: int = 8192,
) -> dict[str, float]:
    """Score (image, caption) pairs for entailment and report ROC AUC / average precision.

    Every caption attached to an image is a positive pair. Negative pairs are
    taken from the sample's own negatives, else from the annotation-level pool,
    else (only with ``negative_pool_strategy="all_fine_captions"``) from the
    last positive caption of every other sample.

    Returns:
        A dict with ``auc_roc`` and ``average_precision`` (also as ``*_pct``),
        the sample and pair counts, and the mean score of positive and negative pairs.

    Raises:
        ValueError: for an unknown strategy, an empty dataset, or when the built
            pairs do not contain both positives and negatives.
    """
    if negative_pool_strategy not in {"annotation", "all_fine_captions"}:
        raise ValueError("negative_pool_strategy must be 'annotation' or 'all_fine_captions'")

    model.eval()
    dataset = PEPEntailmentDataset(
        annotations_path,
        image_root=image_root,
        image_size=image_size,
        max_items=max_items,
        image_cache_dir=image_cache_dir,
        allow_image_download=allow_image_download,
    )
    if len(dataset) == 0:
        raise ValueError("PEP evaluation requires at least one sample")

    image_feats = _encode_images(model, dataset, device, batch_size)
    pair_image_indices, pair_captions, labels = _build_pep_pairs(
        dataset.samples,
        dataset.global_negative_captions,
        negative_pool_strategy=negative_pool_strategy,
        max_negatives_per_image=max_negatives_per_image,
    )
    positive_total = sum(labels)
    pair_total = len(labels)
    # Labels are 0/1, so this is true exactly when one of the two classes is absent.
    if positive_total == 0 or positive_total == pair_total:
        raise ValueError("PEP evaluation requires both positive and negative pairs")

    # Encode each distinct caption once and address it by position afterwards.
    unique_captions = sorted(set(pair_captions))
    caption_position = {text: position for position, text in enumerate(unique_captions)}
    text_feats = _encode_texts(model, unique_captions, device, max_text_length, batch_size)
    pair_text_indices = [caption_position[text] for text in pair_captions]
    scores = _score_pep_pairs(
        model,
        image_feats,
        text_feats,
        pair_image_indices,
        pair_text_indices,
        device,
        pair_batch_size=pair_batch_size,
    )

    auc = _roc_auc_score(labels, scores)
    ap = _average_precision_score(labels, scores)
    positive_scores = []
    negative_scores = []
    for score, label in zip(scores, labels):
        if label == 1:
            positive_scores.append(score)
        elif label == 0:
            negative_scores.append(score)

    return {
        "auc_roc": auc,
        "average_precision": ap,
        "auc_roc_pct": 100.0 * auc,
        "average_precision_pct": 100.0 * ap,
        "num_samples": float(len(dataset.samples)),
        "num_pairs": float(pair_total),
        "num_positive_pairs": float(positive_total),
        "num_negative_pairs": float(pair_total - positive_total),
        "mean_positive_score": float(sum(positive_scores) / len(positive_scores)),
        "mean_negative_score": float(sum(negative_scores) / len(negative_scores)),
    }
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
def probabilistic_entailment_score(
    specific: torch.Tensor,
    general: torch.Tensor,
    kappa: torch.Tensor,
) -> torch.Tensor:
    """Turn the angle from ``factor_oxy_angle`` into a score clamped to ``[0, 1]``.

    The score is ``1 - 2 * angle / pi``; a 2-D angle tensor is first averaged
    over its last dimension.
    """
    angle = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
    if angle.dim() == 2:
        angle = angle.mean(dim=-1)
    raw_score = 1.0 - (2.0 * angle / math.pi)
    return torch.clamp(raw_score, min=0.0, max=1.0)
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
def load_pep_samples(
    annotations_path: str | Path,
    image_root: str | Path | None = None,
    max_items: int | None = None,
) -> tuple[list[PEPSample], tuple[str, ...]]:
    """Read PEP annotations from a ``.jsonl``, ``.json``, ``.csv`` or ``.tsv`` file.

    Returns:
        ``(samples, global_negative_captions)``. The global pool is only ever
        populated for a ``.json`` file whose top level is a dict.

    Raises:
        ValueError: if the file suffix is none of the supported formats.
    """
    path = Path(annotations_path)
    extension = path.suffix.lower()

    if extension == ".jsonl":
        # One JSON object per non-blank line.
        non_blank_lines = (line for line in path.read_text(encoding="utf-8").splitlines() if line.strip())
        samples = [_sample_from_mapping(json.loads(line), image_root) for line in non_blank_lines]
        return _limit_samples(samples, max_items), ()

    if extension == ".json":
        payload = json.loads(path.read_text(encoding="utf-8"))
        if isinstance(payload, dict):
            samples_payload = payload.get("samples", payload.get("data", []))
            pool = (
                payload.get("negative_captions")
                or payload.get("global_negative_captions")
                or payload.get("negative_pool")
            )
            global_negatives = _caption_tuple(pool)
        else:
            # A bare top-level list is the sample list itself.
            samples_payload = payload
            global_negatives = ()
        samples = [_sample_from_mapping(entry, image_root) for entry in samples_payload]
        return _limit_samples(samples, max_items), global_negatives

    if extension in {".csv", ".tsv"}:
        return _load_csv_samples(path, image_root, max_items)

    raise ValueError(f"Unsupported PEP annotations format {path.suffix!r}; expected .json, .jsonl, .csv, or .tsv")
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
def _load_csv_samples(
    path: Path,
    image_root: str | Path | None,
    max_items: int | None,
) -> tuple[list[PEPSample], tuple[str, ...]]:
    """Load samples from a CSV/TSV file; the global negative pool is always empty.

    Two layouts are supported:
      * pair layout - the header has both ``caption`` and ``label``; rows are
        grouped per image (see ``_image_key``) and each caption is filed as
        positive or negative according to ``_truthy_label``;
      * sample layout - every row is converted to one sample directly.

    Raises:
        ValueError: propagated from ``_sample_from_mapping``, e.g. when an image
            in the pair layout has no positive-labelled caption.
    """
    delimiter = "\t" if path.suffix.lower() == ".tsv" else ","
    with path.open("r", encoding="utf-8", newline="") as handle:
        rows = list(csv.DictReader(handle, delimiter=delimiter))
    if not rows:
        return [], ()

    has_pair_labels = "caption" in rows[0] and "label" in rows[0]
    if not has_pair_labels:
        samples = [_sample_from_mapping(row, image_root) for row in rows]
        return _limit_samples(samples, max_items), ()

    grouped: dict[str, dict[str, Any]] = {}
    for row in rows:
        key = _image_key(row)
        item = grouped.get(key)
        if item is None:
            # Seed the group from the first row, but drop the per-pair columns:
            # leaving "caption" in place would let the single-caption fallback
            # promote a negative-labelled caption to a positive when the image
            # has no positive rows at all.
            item = {name: value for name, value in row.items() if name not in ("caption", "label")}
            item["positive_captions"] = []
            item["negative_captions"] = []
            grouped[key] = item
        bucket = "positive_captions" if _truthy_label(row["label"]) else "negative_captions"
        item[bucket].append(row["caption"])

    samples = [_sample_from_mapping(item, image_root) for item in grouped.values()]
    return _limit_samples(samples, max_items), ()
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
def _sample_from_mapping(item: dict[str, Any], image_root: str | Path | None) -> PEPSample:
    """Validate one annotation mapping and convert it into a ``PEPSample``.

    Raises:
        ValueError: if the mapping yields no positive captions, or neither a
            local image path nor an image URL.
    """

    def _describe() -> Any:
        # Identifier shown in error messages only.
        return item.get('id', item.get('image_id', '<unknown>'))

    positive_captions = _extract_positive_captions(item)
    if not positive_captions:
        raise ValueError(f"PEP sample {_describe()!r} has no positive captions")

    local_path = _extract_image_path(item, image_root)
    remote_url = _first_present(item, ("image_url", "url"))
    if local_path is None and not remote_url:
        raise ValueError(f"PEP sample {_describe()!r} has no image path or URL")

    # Fall back to the path, then the URL, when no explicit identifier is given.
    identifier = _first_present(item, ("image_id", "id", "uid")) or local_path or remote_url
    negative_captions = _caption_tuple(_first_present(item, ("negative_captions", "negatives", "negative_pool")))
    return PEPSample(
        image_id=str(identifier),
        image_path=local_path,
        image_url=str(remote_url) if remote_url else None,
        positive_captions=positive_captions,
        negative_captions=negative_captions,
    )
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
def _extract_positive_captions(item: dict[str, Any]) -> tuple[str, ...]:
    """Return the positive captions of ``item``, or ``()`` when none are present.

    The multi-caption keys are tried first; the single ``caption`` field is
    only used when they yield nothing.
    """
    multi_caption_keys = ("positive_captions", "hierarchical_captions", "caption_hierarchy", "captions")
    from_multi = _caption_tuple(_first_present(item, multi_caption_keys))
    if from_multi:
        return from_multi
    single = item.get("caption")
    if not single:
        return ()
    return (str(single).strip(),)
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
def _caption_tuple(raw: Any) -> tuple[str, ...]:
|
| 236 |
-
if raw is None:
|
| 237 |
-
return ()
|
| 238 |
-
if isinstance(raw, str):
|
| 239 |
-
stripped = raw.strip()
|
| 240 |
-
if not stripped:
|
| 241 |
-
return ()
|
| 242 |
-
if stripped.startswith("["):
|
| 243 |
-
try:
|
| 244 |
-
parsed = json.loads(stripped)
|
| 245 |
-
return _caption_tuple(parsed)
|
| 246 |
-
except json.JSONDecodeError:
|
| 247 |
-
pass
|
| 248 |
-
separator = "=>" if "=>" in stripped else "||" if "||" in stripped else None
|
| 249 |
-
values = stripped.split(separator) if separator else [stripped]
|
| 250 |
-
return tuple(value.strip() for value in values if value.strip())
|
| 251 |
-
if isinstance(raw, (list, tuple)):
|
| 252 |
-
return tuple(str(value).strip() for value in raw if str(value).strip())
|
| 253 |
-
return (str(raw).strip(),)
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
def _extract_image_path(item: dict[str, Any], image_root: str | Path | None) -> Path | None:
    """Work out the local image path for ``item``, or ``None`` if it cannot be derived.

    An explicit path field wins and is joined onto ``image_root`` when it is
    relative. Without one, a URL field is mapped into ``image_root``; this
    needs both the URL and ``image_root`` to be present.
    """
    explicit = _first_present(item, ("image_path", "path", "file_name", "filename"))
    if explicit is not None:
        candidate = Path(str(explicit))
        if candidate.is_absolute() or image_root is None:
            return candidate
        return Path(image_root) / candidate

    remote = _first_present(item, ("image_url", "url"))
    if remote is not None and image_root is not None:
        return _url_to_local_image_path(str(remote), Path(image_root))
    return None
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
def _url_to_local_image_path(url: str, image_root: Path) -> Path:
|
| 270 |
-
url_path = Path(urlparse(url).path)
|
| 271 |
-
filename = url_path.name
|
| 272 |
-
if not filename:
|
| 273 |
-
raise ValueError(f"Cannot infer image filename from URL {url!r}")
|
| 274 |
-
candidates = [image_root / filename]
|
| 275 |
-
if url_path.parent.name:
|
| 276 |
-
candidates.append(image_root / url_path.parent.name / filename)
|
| 277 |
-
for candidate in candidates:
|
| 278 |
-
if candidate.exists():
|
| 279 |
-
return candidate
|
| 280 |
-
return candidates[0]
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
def _first_present(item: dict[str, Any], keys: tuple[str, ...]) -> Any:
|
| 284 |
-
for key in keys:
|
| 285 |
-
value = item.get(key)
|
| 286 |
-
if value not in (None, ""):
|
| 287 |
-
return value
|
| 288 |
-
return None
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
def _image_key(row: dict[str, Any]) -> str:
    """Return the grouping key for a CSV row: its first populated identifier-like field, as a string."""
    identifying_columns = (
        "image_id",
        "id",
        "image_path",
        "path",
        "file_name",
        "filename",
        "image_url",
        "url",
    )
    return str(_first_present(row, identifying_columns))
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
def _truthy_label(raw: Any) -> bool:
|
| 296 |
-
return str(raw).strip().lower() in {"1", "true", "yes", "positive", "pos"}
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
def _limit_samples(samples: list[PEPSample], max_items: int | None) -> list[PEPSample]:
|
| 300 |
-
return samples[:max_items] if max_items is not None else samples
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
def _load_sample_image(sample: PEPSample, image_cache_dir: Path | None, allow_image_download: bool) -> Image.Image:
    """Open the image of ``sample`` as RGB, from disk or (if allowed) from its URL.

    Args:
        sample: Record providing ``image_path`` and/or ``image_url``.
        image_cache_dir: Directory for downloaded files; ``None`` downloads
            into memory without caching.
        allow_image_download: Must be true for URL-only samples.

    Raises:
        ValueError: if the sample has neither path nor URL, or it is URL-only
            and downloading is not allowed.
    """
    if sample.image_path is not None:
        with Image.open(sample.image_path) as image:
            return image.convert("RGB")
    if not sample.image_url:
        raise ValueError(f"PEP sample {sample.image_id!r} has no image path or URL")
    if not allow_image_download:
        raise ValueError("PEP sample uses image_url; set allow_image_download=true and image_cache_dir to evaluate it")
    if image_cache_dir is None:
        with urllib.request.urlopen(sample.image_url, timeout=30) as response:
            return Image.open(BytesIO(response.read())).convert("RGB")

    image_cache_dir.mkdir(parents=True, exist_ok=True)
    cache_path = image_cache_dir / _url_cache_name(sample.image_url)
    if not cache_path.exists():
        # Download to a sibling ".part" file and rename it into place, so an
        # interrupted transfer never leaves a truncated file that later runs
        # would mistake for a valid cache entry. The timeout matches the
        # uncached branch above.
        partial_path = cache_path.with_name(cache_path.name + ".part")
        try:
            with urllib.request.urlopen(sample.image_url, timeout=30) as response:
                partial_path.write_bytes(response.read())
            partial_path.replace(cache_path)
        except BaseException:
            if partial_path.exists():
                partial_path.unlink()
            raise
    with Image.open(cache_path) as image:
        return image.convert("RGB")
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
def _url_cache_name(url: str) -> str:
|
| 323 |
-
suffix = Path(url.split("?", maxsplit=1)[0]).suffix or ".jpg"
|
| 324 |
-
digest = hashlib.sha256(url.encode("utf-8")).hexdigest()[:16]
|
| 325 |
-
return f"{digest}{suffix}"
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
def _encode_images(
|
| 329 |
-
model: Hyper3CLIP,
|
| 330 |
-
dataset: PEPEntailmentDataset,
|
| 331 |
-
device: torch.device,
|
| 332 |
-
batch_size: int,
|
| 333 |
-
) -> torch.Tensor:
|
| 334 |
-
feats: list[torch.Tensor] = []
|
| 335 |
-
batch: list[torch.Tensor] = []
|
| 336 |
-
for index in range(len(dataset)):
|
| 337 |
-
batch.append(dataset[index]["image"])
|
| 338 |
-
if len(batch) == batch_size or index == len(dataset) - 1:
|
| 339 |
-
images = torch.stack(batch).to(device)
|
| 340 |
-
feats.append(model.encode_image(images).cpu())
|
| 341 |
-
batch = []
|
| 342 |
-
return torch.cat(feats)
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
def _encode_texts(
|
| 346 |
-
model: Hyper3CLIP,
|
| 347 |
-
captions: list[str],
|
| 348 |
-
device: torch.device,
|
| 349 |
-
max_text_length: int,
|
| 350 |
-
batch_size: int,
|
| 351 |
-
) -> torch.Tensor:
|
| 352 |
-
feats: list[torch.Tensor] = []
|
| 353 |
-
for start in range(0, len(captions), batch_size):
|
| 354 |
-
batch = captions[start : start + batch_size]
|
| 355 |
-
tokenized = model.tokenizer(
|
| 356 |
-
batch,
|
| 357 |
-
padding=True,
|
| 358 |
-
truncation=True,
|
| 359 |
-
max_length=max_text_length,
|
| 360 |
-
return_tensors="pt",
|
| 361 |
-
).to(device)
|
| 362 |
-
attention_mask = (
|
| 363 |
-
tokenized.attention_mask if "attention_mask" in tokenized else torch.ones_like(tokenized.input_ids)
|
| 364 |
-
)
|
| 365 |
-
feats.append(model.encode_text(tokenized.input_ids, attention_mask).cpu())
|
| 366 |
-
return torch.cat(feats)
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
def _build_pep_pairs(
|
| 370 |
-
samples: list[PEPSample],
|
| 371 |
-
global_negative_captions: tuple[str, ...],
|
| 372 |
-
negative_pool_strategy: str,
|
| 373 |
-
max_negatives_per_image: int | None,
|
| 374 |
-
) -> tuple[list[int], list[str], list[int]]:
|
| 375 |
-
fine_caption_pool = tuple(sample.positive_captions[-1] for sample in samples)
|
| 376 |
-
pair_image_indices: list[int] = []
|
| 377 |
-
pair_captions: list[str] = []
|
| 378 |
-
labels: list[int] = []
|
| 379 |
-
for image_index, sample in enumerate(samples):
|
| 380 |
-
positives = set(sample.positive_captions)
|
| 381 |
-
for caption in sample.positive_captions:
|
| 382 |
-
pair_image_indices.append(image_index)
|
| 383 |
-
pair_captions.append(caption)
|
| 384 |
-
labels.append(1)
|
| 385 |
-
|
| 386 |
-
negatives = sample.negative_captions or global_negative_captions
|
| 387 |
-
if not negatives and negative_pool_strategy == "all_fine_captions":
|
| 388 |
-
negatives = tuple(caption for idx, caption in enumerate(fine_caption_pool) if idx != image_index)
|
| 389 |
-
negatives = tuple(caption for caption in negatives if caption not in positives)
|
| 390 |
-
if max_negatives_per_image is not None:
|
| 391 |
-
negatives = negatives[:max_negatives_per_image]
|
| 392 |
-
for caption in negatives:
|
| 393 |
-
pair_image_indices.append(image_index)
|
| 394 |
-
pair_captions.append(caption)
|
| 395 |
-
labels.append(0)
|
| 396 |
-
return pair_image_indices, pair_captions, labels
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
def _score_pep_pairs(
    model: Hyper3CLIP,
    image_feats: torch.Tensor,
    text_feats: torch.Tensor,
    pair_image_indices: list[int],
    pair_text_indices: list[int],
    device: torch.device,
    pair_batch_size: int,
) -> list[float]:
    """Score every (image, text) index pair in chunks of ``pair_batch_size``.

    Image features are passed as the ``specific`` argument and text features
    as the ``general`` argument of ``probabilistic_entailment_score``. Only
    the rows needed for the current chunk are moved to ``device``.
    """
    kappa = model._kappa().detach().to(device)
    chunk_scores: list[torch.Tensor] = []
    for begin in range(0, len(pair_image_indices), pair_batch_size):
        finish = begin + pair_batch_size
        image_rows = torch.tensor(pair_image_indices[begin:finish], dtype=torch.long)
        text_rows = torch.tensor(pair_text_indices[begin:finish], dtype=torch.long)
        specific = image_feats.index_select(0, image_rows).to(device)
        general = text_feats.index_select(0, text_rows).to(device)
        chunk_scores.append(probabilistic_entailment_score(specific, general, kappa).cpu())
    return torch.cat(chunk_scores).tolist()
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
def _roc_auc_score(labels: list[int], scores: list[float]) -> float:
|
| 420 |
-
positives = sum(labels)
|
| 421 |
-
negatives = len(labels) - positives
|
| 422 |
-
if positives == 0 or negatives == 0:
|
| 423 |
-
raise ValueError("ROC AUC requires both positive and negative labels")
|
| 424 |
-
|
| 425 |
-
sorted_pairs = sorted(zip(scores, labels), key=lambda pair: pair[0])
|
| 426 |
-
rank_sum_pos = 0.0
|
| 427 |
-
rank = 1
|
| 428 |
-
index = 0
|
| 429 |
-
while index < len(sorted_pairs):
|
| 430 |
-
end = index + 1
|
| 431 |
-
while end < len(sorted_pairs) and sorted_pairs[end][0] == sorted_pairs[index][0]:
|
| 432 |
-
end += 1
|
| 433 |
-
avg_rank = (rank + rank + (end - index) - 1) / 2.0
|
| 434 |
-
rank_sum_pos += avg_rank * sum(label for _, label in sorted_pairs[index:end])
|
| 435 |
-
rank += end - index
|
| 436 |
-
index = end
|
| 437 |
-
return (rank_sum_pos - positives * (positives + 1) / 2.0) / (positives * negatives)
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
def _average_precision_score(labels: list[int], scores: list[float]) -> float:
|
| 441 |
-
positives = sum(labels)
|
| 442 |
-
if positives == 0:
|
| 443 |
-
raise ValueError("Average precision requires at least one positive label")
|
| 444 |
-
|
| 445 |
-
sorted_pairs = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
|
| 446 |
-
true_positives = 0
|
| 447 |
-
false_positives = 0
|
| 448 |
-
previous_recall = 0.0
|
| 449 |
-
ap = 0.0
|
| 450 |
-
index = 0
|
| 451 |
-
while index < len(sorted_pairs):
|
| 452 |
-
end = index + 1
|
| 453 |
-
while end < len(sorted_pairs) and sorted_pairs[end][0] == sorted_pairs[index][0]:
|
| 454 |
-
end += 1
|
| 455 |
-
true_positives += sum(label for _, label in sorted_pairs[index:end])
|
| 456 |
-
false_positives += (end - index) - sum(label for _, label in sorted_pairs[index:end])
|
| 457 |
-
recall = true_positives / positives
|
| 458 |
-
precision = true_positives / (true_positives + false_positives)
|
| 459 |
-
ap += (recall - previous_recall) * precision
|
| 460 |
-
previous_recall = recall
|
| 461 |
-
index = end
|
| 462 |
-
return ap
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/evaluation/retrieval.py
DELETED
|
@@ -1,215 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import json
|
| 4 |
-
from collections import defaultdict
|
| 5 |
-
from pathlib import Path
|
| 6 |
-
from typing import Any
|
| 7 |
-
|
| 8 |
-
import torch
|
| 9 |
-
from PIL import Image
|
| 10 |
-
from torch.utils.data import Dataset
|
| 11 |
-
|
| 12 |
-
from hyper3_clip.data.transforms import build_retrieval_transform
|
| 13 |
-
|
| 14 |
-
from hyper3_clip.models.hyper3_clip import Hyper3CLIP
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
class CocoCaptionRetrieval(Dataset):
    """Retrieval dataset built from ``annotations/captions_val2017.json`` and the ``val2017`` image folder."""

    def __init__(
        self,
        root: str | Path,
        image_size: int = 224,
        max_items: int | None = None,
        image_normalization: str = "imagenet",
    ) -> None:
        """Index all captioned images under ``root``, ordered by image id."""
        self.root = Path(root)
        annotation_file = self.root / "annotations" / "captions_val2017.json"
        with annotation_file.open("r", encoding="utf-8") as handle:
            payload = json.load(handle)

        file_names = {entry["id"]: entry["file_name"] for entry in payload["images"]}
        captions_by_image: dict[int, list[str]] = defaultdict(list)
        for annotation in payload["annotations"]:
            owner = int(annotation["image_id"])
            captions_by_image[owner].append(str(annotation["caption"]))

        image_dir = self.root / "val2017"
        records = []
        for image_id in sorted(captions_by_image):
            records.append(
                {
                    "image_id": image_id,
                    "image_path": image_dir / file_names[image_id],
                    "captions": captions_by_image[image_id],
                }
            )
        # Keep only the first max_items images when a limit is given.
        self.items = records if max_items is None else records[:max_items]
        self.transform = build_retrieval_transform(image_size, normalization=image_normalization)

    def __len__(self) -> int:
        """Number of indexed images."""
        return len(self.items)

    def __getitem__(self, index: int) -> dict[str, Any]:
        """Return the transformed RGB image with its captions and image id."""
        record = self.items[index]
        with Image.open(record["image_path"]) as picture:
            tensor = self.transform(picture.convert("RGB"))
        return {"image": tensor, "captions": record["captions"], "image_id": record["image_id"]}
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
class CocoKarpathyCaptionRetrieval(Dataset):
    """Image/caption retrieval dataset built from the Karpathy COCO split file.

    Reads ``<root>/karpathy/dataset_coco.json`` and keeps the images whose
    ``split`` field equals the requested split, in file order.
    """

    def __init__(
        self,
        root: str | Path,
        split: str = "test",
        image_size: int = 224,
        max_items: int | None = None,
        image_normalization: str = "imagenet",
    ) -> None:
        """Index the split file and build the image transform.

        Args:
            root: Dataset root; image paths are ``root / filepath / filename`` per record.
            split: Value matched against each record's ``split`` field.
            image_size: Forwarded to ``build_retrieval_transform``.
            max_items: When not ``None``, the selected records are sliced to ``[:max_items]``.
            image_normalization: Forwarded to ``build_retrieval_transform`` as ``normalization``.
        """
        self.root = Path(root)
        split_file = self.root / "karpathy" / "dataset_coco.json"
        with split_file.open("r", encoding="utf-8") as stream:
            karpathy = json.load(stream)

        selected = list(filter(lambda record: record["split"] == split, karpathy["images"]))
        if max_items is not None:
            # Truncation happens before the entries are materialised.
            selected = selected[:max_items]

        entries = []
        for record in selected:
            entries.append(
                {
                    "image_id": record["imgid"],
                    "image_path": self.root / record["filepath"] / record["filename"],
                    "captions": [sentence["raw"].strip() for sentence in record["sentences"]],
                }
            )
        self.items = entries
        self.transform = build_retrieval_transform(image_size, normalization=image_normalization)

    def __len__(self) -> int:
        """Number of indexed images."""
        return len(self.items)

    def __getitem__(self, index: int) -> dict[str, Any]:
        """Load one image as RGB, apply the transform, and return it with its captions and id."""
        entry = self.items[index]
        with Image.open(entry["image_path"]) as raw_image:
            pixels = self.transform(raw_image.convert("RGB"))
        return dict(image=pixels, captions=entry["captions"], image_id=entry["image_id"])
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
class Flickr30kCaptionRetrieval(Dataset):
    """Image/caption retrieval dataset built from ``dataset_flickr30k.json``.

    Reads ``<root>/dataset_flickr30k.json`` and keeps the images whose ``split``
    field equals the requested split. Images are resolved under
    ``<root>/flickr30k_images``.
    """

    def __init__(
        self,
        root: str | Path,
        split: str = "test",
        image_size: int = 224,
        max_items: int | None = None,
        image_normalization: str = "imagenet",
    ) -> None:
        """Index the split file and build the image transform.

        Args:
            root: Dataset root containing the JSON file and ``flickr30k_images/``.
            split: Value matched against each record's ``split`` field.
            image_size: Forwarded to ``build_retrieval_transform``.
            max_items: When not ``None``, the entry list is sliced to ``[:max_items]``.
            image_normalization: Forwarded to ``build_retrieval_transform`` as ``normalization``.
        """
        self.root = Path(root)
        with (self.root / "dataset_flickr30k.json").open("r", encoding="utf-8") as stream:
            flickr = json.load(stream)

        def caption_text(sentence):
            # Prefer the raw sentence; fall back to the space-joined tokens when it is missing/empty.
            return str(sentence.get("raw") or " ".join(sentence.get("tokens", [])))

        selected = []
        for position, record in enumerate(flickr["images"]):
            if record.get("split") == split:
                texts = [caption_text(sentence) for sentence in record["sentences"]]
                selected.append(
                    {
                        # The id is the record's position in the full file, not among the kept records.
                        "image_id": position,
                        "image_path": self.root / "flickr30k_images" / record["filename"],
                        "captions": texts,
                    }
                )
        self.items = selected if max_items is None else selected[:max_items]
        self.transform = build_retrieval_transform(image_size, normalization=image_normalization)

    def __len__(self) -> int:
        """Number of indexed images."""
        return len(self.items)

    def __getitem__(self, index: int) -> dict[str, Any]:
        """Load one image as RGB, apply the transform, and return it with its captions and id."""
        entry = self.items[index]
        with Image.open(entry["image_path"]) as raw_image:
            pixels = self.transform(raw_image.convert("RGB"))
        return dict(image=pixels, captions=entry["captions"], image_id=entry["image_id"])
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
@torch.inference_mode()
def evaluate_caption_retrieval(
    model: Hyper3CLIP,
    dataset: Dataset,
    device: torch.device,
    max_text_length: int = 77,
    batch_size: int = 128,
) -> dict[str, float]:
    """Compute image-to-text and text-to-image recall@{1,5,10} for a caption dataset.

    The dataset is read by integer index; each item must provide ``"image"`` and
    ``"captions"``. Images are encoded in batches with
    ``model.encode_retrieval_image`` and captions with ``model.tokenizer`` +
    ``model.encode_retrieval_text``; features are parked on CPU between batches
    and moved back to ``device`` for scoring.

    Args:
        model: Model exposing ``eval``, ``tokenizer``, ``encode_retrieval_image``,
            ``encode_retrieval_text`` and ``retrieval_similarity_scores``.
        dataset: Indexable dataset of ``{"image", "captions", ...}`` items.
        device: Device used for encoding and scoring.
        max_text_length: ``max_length`` passed to the tokenizer (with truncation).
        batch_size: Batch size for both image and caption encoding; the scoring
            chunk size is ``max(1, min(batch_size, 64))``.

    Returns:
        A dict with the recall fractions (``image_to_text_r{1,5,10}``,
        ``text_to_image_r{1,5,10}``) followed by the same values scaled to
        percentages (``i2t_r{1,5,10}``, ``t2i_r{1,5,10}``).

    NOTE(review): an empty dataset is not special-cased; presumably the
    ``torch.cat`` over an empty feature list fails — confirm if callers rely on it.
    """
    model.eval()
    image_feats: list[torch.Tensor] = []
    captions: list[str] = []
    text_feats: list[torch.Tensor] = []
    text_to_image: list[int] = []

    # The dataset length is loop-invariant; read it once.
    num_items = len(dataset)
    image_batch: list[torch.Tensor] = []
    for item_index in range(num_items):
        item = dataset[item_index]
        image_batch.append(item["image"])
        # Flush on a full batch or on the final item so no image is left behind.
        if len(image_batch) == batch_size or item_index == num_items - 1:
            images = torch.stack(image_batch).to(device)
            image_feats.append(model.encode_retrieval_image(images).cpu())
            image_batch = []
        captions.extend(item["captions"])
        # Remember which image (by dataset index) each caption belongs to.
        text_to_image.extend([item_index] * len(item["captions"]))

    for start in range(0, len(captions), batch_size):
        caption_batch = captions[start : start + batch_size]
        tokenized = model.tokenizer(
            caption_batch,
            padding=True,
            truncation=True,
            max_length=max_text_length,
            return_tensors="pt",
        ).to(device)
        # Some tokenizers do not emit an attention mask; treat every position as valid then.
        attention_mask = (
            tokenized.attention_mask if "attention_mask" in tokenized else torch.ones_like(tokenized.input_ids)
        )
        text_feats.append(model.encode_retrieval_text(tokenized.input_ids, attention_mask).cpu())

    images = torch.cat(image_feats).to(device)
    texts = torch.cat(text_feats).to(device)
    scores_i2t = _retrieval_similarity_scores(model, images, texts, chunk_size=max(1, min(batch_size, 64)))
    scores_t2i = scores_i2t.transpose(0, 1)
    # Scores may come back on CPU (chunked path), so targets follow the score tensor's device.
    target_device = scores_i2t.device
    text_targets = torch.tensor(text_to_image, device=target_device)
    # Built once and shared by all three cut-offs (previously rebuilt per k).
    image_targets = _image_to_text_targets(text_to_image, num_items, target_device)

    cutoffs = (1, 5, 10)
    fractions: dict[str, float] = {}
    for k in cutoffs:
        fractions[f"image_to_text_r{k}"] = _recall_at_k(scores_i2t, image_targets, k)
    for k in cutoffs:
        fractions[f"text_to_image_r{k}"] = _single_target_recall_at_k(scores_t2i, text_targets, k)

    results = dict(fractions)
    for k in cutoffs:
        results[f"i2t_r{k}"] = 100.0 * fractions[f"image_to_text_r{k}"]
    for k in cutoffs:
        results[f"t2i_r{k}"] = 100.0 * fractions[f"text_to_image_r{k}"]
    return results
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
def _retrieval_similarity_scores(
|
| 188 |
-
model: Hyper3CLIP, images: torch.Tensor, texts: torch.Tensor, chunk_size: int
|
| 189 |
-
) -> torch.Tensor:
|
| 190 |
-
if not getattr(model, "retrieval_requires_chunking", False):
|
| 191 |
-
return model.retrieval_similarity_scores(images, texts)
|
| 192 |
-
|
| 193 |
-
chunks: list[torch.Tensor] = []
|
| 194 |
-
for start in range(0, images.shape[0], chunk_size):
|
| 195 |
-
chunk_scores = model.retrieval_similarity_scores(images[start : start + chunk_size], texts)
|
| 196 |
-
chunks.append(chunk_scores.cpu())
|
| 197 |
-
return torch.cat(chunks, dim=0)
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
def _image_to_text_targets(text_to_image: list[int], num_images: int, device: torch.device) -> list[torch.Tensor]:
|
| 201 |
-
targets: list[list[int]] = [[] for _ in range(num_images)]
|
| 202 |
-
for text_index, image_index in enumerate(text_to_image):
|
| 203 |
-
targets[image_index].append(text_index)
|
| 204 |
-
return [torch.tensor(indices, device=device) for indices in targets]
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
def _recall_at_k(scores: torch.Tensor, targets: list[torch.Tensor], k: int) -> float:
|
| 208 |
-
topk = scores.topk(k=min(k, scores.shape[1]), dim=1).indices
|
| 209 |
-
hits = [bool(torch.isin(targets[row], topk[row]).any().item()) for row in range(scores.shape[0])]
|
| 210 |
-
return float(sum(hits) / len(hits))
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
def _single_target_recall_at_k(scores: torch.Tensor, targets: torch.Tensor, k: int) -> float:
|
| 214 |
-
topk = scores.topk(k=min(k, scores.shape[1]), dim=1).indices
|
| 215 |
-
return float((topk == targets[:, None]).any(dim=1).float().mean().item())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/models/__init__.py
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
from hyper3_clip.models.hyper3_clip import Hyper3CLIP
|
| 2 |
-
|
| 3 |
-
__all__ = ["Hyper3CLIP"]
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/models/encoders.py
DELETED
|
@@ -1,173 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import timm
|
| 4 |
-
import torch
|
| 5 |
-
from torch import nn
|
| 6 |
-
from transformers import (
|
| 7 |
-
AutoConfig,
|
| 8 |
-
AutoModel,
|
| 9 |
-
AutoTokenizer,
|
| 10 |
-
CLIPTextConfig,
|
| 11 |
-
CLIPTextModel,
|
| 12 |
-
CLIPTextModelWithProjection,
|
| 13 |
-
CLIPVisionConfig,
|
| 14 |
-
CLIPVisionModel,
|
| 15 |
-
CLIPVisionModelWithProjection,
|
| 16 |
-
SiglipTextConfig,
|
| 17 |
-
SiglipTextModel,
|
| 18 |
-
SiglipVisionConfig,
|
| 19 |
-
SiglipVisionModel,
|
| 20 |
-
)
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
class VisionEncoder(nn.Module):
    """Image backbone wrapper selected by a prefix on ``backbone_name``.

    * ``hf_clip_projected:<name>`` -> ``CLIPVisionModelWithProjection`` (``kind="hf_clip_projected"``,
      ``output_dim = config.projection_dim``)
    * ``hf_clip:<name>`` -> ``CLIPVisionModel`` (``kind="hf_vision"``, ``output_dim = config.hidden_size``)
    * ``hf_siglip:<name>`` -> ``SiglipVisionModel`` (``kind="hf_vision"``, ``output_dim = config.hidden_size``)
    * anything else -> ``timm.create_model`` with ``num_classes=0`` and ``global_pool="avg"``
      (``kind="timm"``, ``output_dim = backbone.num_features``)
    """

    def __init__(self, backbone_name: str, pretrained: bool = True) -> None:
        """Build the backbone; when ``pretrained`` is false, HF models are built from their config only."""
        super().__init__()
        # (prefix, kind tag, model class, config class, config field holding the output width).
        # Order matters only for readability: no prefix is a prefix of another.
        hf_variants = (
            ("hf_clip_projected:", "hf_clip_projected", CLIPVisionModelWithProjection, CLIPVisionConfig, "projection_dim"),
            ("hf_clip:", "hf_vision", CLIPVisionModel, CLIPVisionConfig, "hidden_size"),
            ("hf_siglip:", "hf_vision", SiglipVisionModel, SiglipVisionConfig, "hidden_size"),
        )
        for prefix, kind, model_cls, config_cls, width_field in hf_variants:
            if not backbone_name.startswith(prefix):
                continue
            checkpoint = backbone_name.removeprefix(prefix)
            if pretrained:
                backbone = model_cls.from_pretrained(checkpoint)
            else:
                backbone = model_cls(config_cls.from_pretrained(checkpoint))
            self.kind = kind
            self.backbone = backbone
            self.output_dim = getattr(backbone.config, width_field)
            return

        self.kind = "timm"
        self.backbone = timm.create_model(
            backbone_name,
            pretrained=pretrained,
            num_classes=0,
            global_pool="avg",
        )
        self.output_dim = self.backbone.num_features

    @staticmethod
    def _pooled_or_first_token(outputs):
        """Return ``pooler_output`` when present and not ``None``, else the first token of ``last_hidden_state``."""
        pooled = getattr(outputs, "pooler_output", None)
        return outputs.last_hidden_state[:, 0] if pooled is None else pooled

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Return one pooled feature per image."""
        if self.kind not in ("hf_clip_projected", "hf_vision"):
            return self.backbone(image)
        outputs = self.backbone(pixel_values=image)
        if self.kind == "hf_clip_projected":
            return outputs.image_embeds
        return self._pooled_or_first_token(outputs)

    def forward_with_tokens(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(pooled, tokens)`` where ``tokens`` is the backbone's token-level output.

        Raises:
            RuntimeError: For projected CLIP when no token tensor can be found on the output.
        """
        if self.kind == "hf_clip_projected":
            outputs = self.backbone(pixel_values=image)
            patch_tokens = getattr(outputs, "last_hidden_state", None)
            if patch_tokens is None:
                # Fall back to a nested vision output if the top-level field is missing/None.
                if hasattr(outputs, "vision_model_output"):
                    patch_tokens = outputs.vision_model_output.last_hidden_state
                if patch_tokens is None:
                    raise RuntimeError("Projected CLIP vision output did not include patch tokens")
            return outputs.image_embeds, patch_tokens

        if self.kind == "hf_vision":
            outputs = self.backbone(pixel_values=image)
            return self._pooled_or_first_token(outputs), outputs.last_hidden_state

        backbone = self.backbone
        if hasattr(backbone, "forward_features"):
            feature_map = backbone.forward_features(image)
            if hasattr(backbone, "forward_head"):
                pooled = backbone.forward_head(feature_map, pre_logits=False)
            else:
                pooled = backbone(image)
            return pooled, _tokens_from_features(feature_map)

        # No feature hook available: expose the pooled vector as a single token.
        pooled = backbone(image)
        return pooled, pooled[:, None, :]
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
class TextEncoder(nn.Module):
    """Text backbone wrapper with a configurable pooling strategy.

    Model selection:

    * ``hf_clip_projected:<name>`` -> ``CLIPTextModelWithProjection``
      (``kind="hf_clip_projected"``, ``output_dim = config.projection_dim``)
    * name containing ``"siglip"`` (case-insensitive) -> ``SiglipTextModel``
    * name containing ``"clip"`` (case-insensitive) -> ``CLIPTextModel``
    * anything else -> ``AutoModel``

    The tokenizer is loaded with ``AutoTokenizer`` from the name with the
    ``hf_clip_projected:`` prefix stripped.
    """

    def __init__(self, model_name: str, pretrained: bool = True, pooling: str = "auto") -> None:
        """Load tokenizer and backbone.

        Raises:
            ValueError: If ``pooling`` is not one of ``auto``, ``pooler``, ``cls``, ``mean``,
                or if a generic ``AutoModel`` config has no ``hidden_size``.
        """
        super().__init__()
        if pooling not in {"auto", "pooler", "cls", "mean"}:
            raise ValueError(f"Unsupported text pooling {pooling!r}; expected auto, pooler, cls, or mean")

        projected_prefix = "hf_clip_projected:"
        self.kind = "hf_text"
        self.pooling = pooling
        self.tokenizer = AutoTokenizer.from_pretrained(model_name.removeprefix(projected_prefix))
        lowered_name = model_name.lower()

        if model_name.startswith(projected_prefix):
            self.kind = "hf_clip_projected"
            checkpoint = model_name.removeprefix(projected_prefix)
            self.backbone = (
                CLIPTextModelWithProjection.from_pretrained(checkpoint)
                if pretrained
                else CLIPTextModelWithProjection(CLIPTextConfig.from_pretrained(checkpoint))
            )
            self.output_dim = self.backbone.config.projection_dim
            return

        if "siglip" in lowered_name:
            self.backbone = (
                SiglipTextModel.from_pretrained(model_name)
                if pretrained
                else SiglipTextModel(SiglipTextConfig.from_pretrained(model_name))
            )
            self.output_dim = self.backbone.config.hidden_size
            return

        if "clip" in lowered_name:
            self.backbone = (
                CLIPTextModel.from_pretrained(model_name)
                if pretrained
                else CLIPTextModel(CLIPTextConfig.from_pretrained(model_name))
            )
            self.output_dim = self.backbone.config.hidden_size
            return

        self.backbone = (
            AutoModel.from_pretrained(model_name)
            if pretrained
            else AutoModel.from_config(AutoConfig.from_pretrained(model_name))
        )
        width = getattr(self.backbone.config, "hidden_size", None)
        if width is None:
            raise ValueError(f"Unsupported text model config for {model_name}")
        self.output_dim = width

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Encode token ids into one vector per sequence.

        Projected CLIP returns ``text_embeds``. Otherwise ``mean`` pooling averages
        ``last_hidden_state`` over positions weighted by ``attention_mask``;
        ``auto``/``pooler`` use ``pooler_output`` when available; everything else
        falls back to the first token of ``last_hidden_state``.
        """
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        if self.kind == "hf_clip_projected":
            return outputs.text_embeds

        if self.pooling == "mean":
            token_states = outputs.last_hidden_state
            weights = attention_mask.to(dtype=token_states.dtype).unsqueeze(-1)
            weighted_sum = (token_states * weights).sum(dim=1)
            # Clamp so a fully-masked sequence does not divide by zero.
            return weighted_sum / weights.sum(dim=1).clamp_min(1.0)

        if self.pooling in {"auto", "pooler"}:
            pooled = getattr(outputs, "pooler_output", None)
            if pooled is not None:
                return pooled
        return outputs.last_hidden_state[:, 0]
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
def _tokens_from_features(features: torch.Tensor | dict | tuple | list) -> torch.Tensor:
|
| 156 |
-
if isinstance(features, dict):
|
| 157 |
-
for key in ("x", "last_hidden_state", "features"):
|
| 158 |
-
if key in features:
|
| 159 |
-
features = features[key]
|
| 160 |
-
break
|
| 161 |
-
else:
|
| 162 |
-
features = next(iter(features.values()))
|
| 163 |
-
if isinstance(features, tuple | list):
|
| 164 |
-
features = features[0]
|
| 165 |
-
if not torch.is_tensor(features):
|
| 166 |
-
raise TypeError(f"Expected tensor features, got {type(features)!r}")
|
| 167 |
-
if features.ndim == 4:
|
| 168 |
-
return features.flatten(2).transpose(1, 2)
|
| 169 |
-
if features.ndim == 3:
|
| 170 |
-
return features
|
| 171 |
-
if features.ndim == 2:
|
| 172 |
-
return features[:, None, :]
|
| 173 |
-
raise ValueError(f"Unsupported feature tensor shape {tuple(features.shape)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/models/experimental.py
DELETED
|
@@ -1,587 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from collections.abc import Callable
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn.functional as F
|
| 7 |
-
from torch import Tensor, nn
|
| 8 |
-
|
| 9 |
-
from hyper3_clip.models.lorentz import exp_map0, metric_pairwise_dist
|
| 10 |
-
from hyper3_clip.models.losses import beta_cal_loss
|
| 11 |
-
from hyper3_clip.models.tren import TRENRegionEncoder
|
| 12 |
-
from hyper3_clip.training.distributed import gather_variable_with_grad, gather_with_grad, get_rank
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
ProjectionHeadFactory = Callable[[int, int, int | None], nn.Module]
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class ExperimentalObjectiveMixin:
|
| 19 |
-
@staticmethod
|
| 20 |
-
def _validate_experimental_options(
|
| 21 |
-
*,
|
| 22 |
-
proclip_geometry: str,
|
| 23 |
-
proclip_projection_hidden_dim: int | None,
|
| 24 |
-
proclip_component_dim: int | None,
|
| 25 |
-
beta_clip_weight: float,
|
| 26 |
-
beta_clip_global_weight: float,
|
| 27 |
-
beta_clip_beta: float,
|
| 28 |
-
beta_clip_variant: str,
|
| 29 |
-
beta_clip_similarity: str,
|
| 30 |
-
beta_clip_num_heads: int,
|
| 31 |
-
beta_clip_mlp_ratio: float,
|
| 32 |
-
tren_weight: float,
|
| 33 |
-
tren_visual_distill_weight: float,
|
| 34 |
-
tren_text_distill_weight: float,
|
| 35 |
-
tren_region_text_weight: float,
|
| 36 |
-
tren_num_region_tokens: int,
|
| 37 |
-
tren_num_decoder_layers: int,
|
| 38 |
-
tren_num_attention_heads: int,
|
| 39 |
-
tren_prompt_grid_size: int,
|
| 40 |
-
tren_dropout: float,
|
| 41 |
-
) -> None:
|
| 42 |
-
if proclip_geometry not in {"product", "hyperbolic", "euclidean", "spherical", "clip"}:
|
| 43 |
-
raise ValueError("proclip_geometry must be 'product', 'hyperbolic', 'euclidean', 'spherical', or 'clip'")
|
| 44 |
-
if proclip_projection_hidden_dim is not None and proclip_projection_hidden_dim <= 0:
|
| 45 |
-
raise ValueError("proclip_projection_hidden_dim must be positive when set")
|
| 46 |
-
if proclip_component_dim is not None and proclip_component_dim <= 0:
|
| 47 |
-
raise ValueError("proclip_component_dim must be positive when set")
|
| 48 |
-
if beta_clip_variant not in {"ce", "bce"}:
|
| 49 |
-
raise ValueError("beta_clip_variant must be 'ce' or 'bce'")
|
| 50 |
-
if beta_clip_similarity not in {"metric", "dot"}:
|
| 51 |
-
raise ValueError("beta_clip_similarity must be 'metric' or 'dot'")
|
| 52 |
-
if beta_clip_weight < 0.0:
|
| 53 |
-
raise ValueError("beta_clip_weight must be non-negative")
|
| 54 |
-
if beta_clip_global_weight < 0.0:
|
| 55 |
-
raise ValueError("beta_clip_global_weight must be non-negative")
|
| 56 |
-
if beta_clip_beta < 0.0:
|
| 57 |
-
raise ValueError("beta_clip_beta must be non-negative")
|
| 58 |
-
if beta_clip_num_heads <= 0:
|
| 59 |
-
raise ValueError("beta_clip_num_heads must be positive")
|
| 60 |
-
if beta_clip_mlp_ratio <= 0.0:
|
| 61 |
-
raise ValueError("beta_clip_mlp_ratio must be positive")
|
| 62 |
-
if tren_weight < 0.0:
|
| 63 |
-
raise ValueError("tren_weight must be non-negative")
|
| 64 |
-
if tren_visual_distill_weight < 0.0 or tren_text_distill_weight < 0.0 or tren_region_text_weight < 0.0:
|
| 65 |
-
raise ValueError("T-REN loss weights must be non-negative")
|
| 66 |
-
if tren_num_region_tokens <= 0:
|
| 67 |
-
raise ValueError("tren_num_region_tokens must be positive")
|
| 68 |
-
if tren_num_decoder_layers <= 0:
|
| 69 |
-
raise ValueError("tren_num_decoder_layers must be positive")
|
| 70 |
-
if tren_num_attention_heads <= 0:
|
| 71 |
-
raise ValueError("tren_num_attention_heads must be positive")
|
| 72 |
-
if tren_prompt_grid_size <= 0:
|
| 73 |
-
raise ValueError("tren_prompt_grid_size must be positive")
|
| 74 |
-
if tren_dropout < 0.0:
|
| 75 |
-
raise ValueError("tren_dropout must be non-negative")
|
| 76 |
-
|
| 77 |
-
def _init_experimental_modules(
|
| 78 |
-
self,
|
| 79 |
-
*,
|
| 80 |
-
beta_clip_num_heads: int,
|
| 81 |
-
beta_clip_mlp_ratio: float,
|
| 82 |
-
tren_num_region_tokens: int,
|
| 83 |
-
tren_num_decoder_layers: int,
|
| 84 |
-
tren_num_attention_heads: int,
|
| 85 |
-
tren_prompt_grid_size: int,
|
| 86 |
-
tren_dropout: float,
|
| 87 |
-
projection_hidden_dim: int | None,
|
| 88 |
-
proclip_projection_hidden_dim: int | None,
|
| 89 |
-
projection_head: ProjectionHeadFactory,
|
| 90 |
-
) -> None:
|
| 91 |
-
if self.beta_query_pooling_enabled:
|
| 92 |
-
if self.vision_encoder.output_dim % beta_clip_num_heads != 0:
|
| 93 |
-
raise ValueError("vision encoder output_dim must be divisible by beta_clip_num_heads")
|
| 94 |
-
beta_clip_hidden_dim = max(1, int(round(self.vision_encoder.output_dim * beta_clip_mlp_ratio)))
|
| 95 |
-
self.beta_clip_text_query_proj = nn.Linear(self.text_encoder.output_dim, self.vision_encoder.output_dim)
|
| 96 |
-
self.beta_clip_cross_attention = nn.MultiheadAttention(
|
| 97 |
-
self.vision_encoder.output_dim,
|
| 98 |
-
beta_clip_num_heads,
|
| 99 |
-
batch_first=True,
|
| 100 |
-
)
|
| 101 |
-
self.beta_clip_mlp_norm = nn.LayerNorm(self.vision_encoder.output_dim)
|
| 102 |
-
self.beta_clip_pool_mlp = nn.Sequential(
|
| 103 |
-
nn.Linear(self.vision_encoder.output_dim, beta_clip_hidden_dim),
|
| 104 |
-
nn.GELU(),
|
| 105 |
-
nn.Linear(beta_clip_hidden_dim, self.vision_encoder.output_dim),
|
| 106 |
-
)
|
| 107 |
-
if self.beta_clip_enabled:
|
| 108 |
-
self.beta_clip_logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log())
|
| 109 |
-
if self.tren_enabled:
|
| 110 |
-
self.tren_region_encoder = TRENRegionEncoder(
|
| 111 |
-
vision_dim=self.vision_encoder.output_dim,
|
| 112 |
-
text_dim=self.text_encoder.output_dim,
|
| 113 |
-
num_region_tokens=tren_num_region_tokens,
|
| 114 |
-
num_decoder_layers=tren_num_decoder_layers,
|
| 115 |
-
num_attention_heads=tren_num_attention_heads,
|
| 116 |
-
prompt_grid_size=tren_prompt_grid_size,
|
| 117 |
-
dropout=tren_dropout,
|
| 118 |
-
)
|
| 119 |
-
self.tren_logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log())
|
| 120 |
-
if self.proclip_enabled:
|
| 121 |
-
component_dim = self._proclip_component_dim
|
| 122 |
-
spherical_dim = self._proclip_spherical_ambient_dim
|
| 123 |
-
proclip_hidden_dim = proclip_projection_hidden_dim
|
| 124 |
-
if proclip_hidden_dim is None:
|
| 125 |
-
proclip_hidden_dim = projection_hidden_dim
|
| 126 |
-
if self.proclip_dedicated_hyperbolic:
|
| 127 |
-
self.proclip_image_hyperbolic_proj = projection_head(
|
| 128 |
-
self.vision_encoder.output_dim, self.embed_dim, proclip_hidden_dim
|
| 129 |
-
)
|
| 130 |
-
self.proclip_text_hyperbolic_proj = projection_head(
|
| 131 |
-
self.text_encoder.output_dim, self.embed_dim, proclip_hidden_dim
|
| 132 |
-
)
|
| 133 |
-
self.proclip_image_euclidean_proj = projection_head(
|
| 134 |
-
self.vision_encoder.output_dim, component_dim, proclip_hidden_dim
|
| 135 |
-
)
|
| 136 |
-
self.proclip_text_euclidean_proj = projection_head(
|
| 137 |
-
self.text_encoder.output_dim, component_dim, proclip_hidden_dim
|
| 138 |
-
)
|
| 139 |
-
self.proclip_image_spherical_proj = projection_head(
|
| 140 |
-
self.vision_encoder.output_dim, spherical_dim, proclip_hidden_dim
|
| 141 |
-
)
|
| 142 |
-
self.proclip_text_spherical_proj = projection_head(
|
| 143 |
-
self.text_encoder.output_dim, spherical_dim, proclip_hidden_dim
|
| 144 |
-
)
|
| 145 |
-
self.proclip_logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log())
|
| 146 |
-
self.proclip_log_weights = nn.Parameter(torch.zeros(3))
|
| 147 |
-
|
| 148 |
-
@property
|
| 149 |
-
def proclip_enabled(self) -> bool:
|
| 150 |
-
return (
|
| 151 |
-
self.objective_name == "proclip"
|
| 152 |
-
or self.proclip_component_dim is not None
|
| 153 |
-
or self.proclip_weight > 0.0
|
| 154 |
-
or self.proclip_retrieval
|
| 155 |
-
)
|
| 156 |
-
|
| 157 |
-
@property
|
| 158 |
-
def beta_clip_enabled(self) -> bool:
|
| 159 |
-
return self.beta_clip_weight > 0.0
|
| 160 |
-
|
| 161 |
-
@property
|
| 162 |
-
def beta_query_pooling_enabled(self) -> bool:
|
| 163 |
-
return self.beta_clip_enabled or (
|
| 164 |
-
self.objective_name == "uncha"
|
| 165 |
-
and self.uncha_entailment_loss in {"hier_beta_argent", "hier_beta_sourcepart_argent"}
|
| 166 |
-
)
|
| 167 |
-
|
| 168 |
-
@property
|
| 169 |
-
def tren_enabled(self) -> bool:
|
| 170 |
-
return self.tren_weight > 0.0
|
| 171 |
-
|
| 172 |
-
@property
|
| 173 |
-
def _proclip_component_dim(self) -> int:
|
| 174 |
-
return int(self.proclip_component_dim or self.embed_dim)
|
| 175 |
-
|
| 176 |
-
@property
|
| 177 |
-
def _proclip_spherical_ambient_dim(self) -> int:
|
| 178 |
-
return self._proclip_component_dim + 1
|
| 179 |
-
|
| 180 |
-
def _clamp_experimental_logit_scales(self) -> None:
|
| 181 |
-
if self.proclip_enabled:
|
| 182 |
-
self.proclip_logit_scale.clamp_(max=4.6052)
|
| 183 |
-
if self.beta_clip_enabled:
|
| 184 |
-
self.beta_clip_logit_scale.clamp_(max=4.6052)
|
| 185 |
-
if self.tren_enabled:
|
| 186 |
-
self.tren_logit_scale.clamp_(max=4.6052)
|
| 187 |
-
|
| 188 |
-
def _detached_experimental_logit_scales(self) -> dict[str, torch.Tensor]:
|
| 189 |
-
logs = {}
|
| 190 |
-
if self.proclip_enabled:
|
| 191 |
-
logs.update(self._detached_proclip_logs())
|
| 192 |
-
if self.beta_clip_enabled:
|
| 193 |
-
logs["beta_clip_logit_scale"] = self.beta_clip_logit_scale.exp().detach()
|
| 194 |
-
if self.tren_enabled:
|
| 195 |
-
logs["tren_logit_scale"] = self.tren_logit_scale.exp().detach()
|
| 196 |
-
return logs
|
| 197 |
-
|
| 198 |
-
def _beta_clip_global_contrastive_loss(
    self,
    *,
    image_euc: torch.Tensor,
    text_euc: torch.Tensor,
    targets: torch.Tensor,
) -> torch.Tensor:
    """Symmetric cross-entropy contrastive loss on L2-normalised image/text features.

    Local features are scored against the features gathered via
    ``gather_with_grad``. The temperature is taken from ``logit_scale`` for the
    ``hycoclip`` objective, ``proclip_logit_scale`` for ``proclip`` and
    ``global_logit_scale`` otherwise; its exponential is clamped to at most 100.

    Returns:
        The mean of the image->text and text->image cross-entropy terms.
    """
    local_images = F.normalize(image_euc.float(), dim=-1)
    local_texts = F.normalize(text_euc.float(), dim=-1)
    # Gather order (images, then texts) is kept fixed.
    gathered_images = gather_with_grad(local_images)
    gathered_texts = gather_with_grad(local_texts)

    objective = self.objective_name
    if objective == "hycoclip":
        log_scale = self.logit_scale
    elif objective == "proclip":
        log_scale = self.proclip_logit_scale
    else:
        log_scale = self.global_logit_scale
    temperature = log_scale.exp().clamp(max=100.0)

    image_to_text = local_images @ gathered_texts.T * temperature
    text_to_image = local_texts @ gathered_images.T * temperature
    image_term = F.cross_entropy(image_to_text, targets)
    text_term = F.cross_entropy(text_to_image, targets)
    return 0.5 * (image_term + text_term)
|
| 218 |
-
|
| 219 |
-
def _beta_query_entailment_embeddings(
|
| 220 |
-
self,
|
| 221 |
-
*,
|
| 222 |
-
image_tokens: torch.Tensor,
|
| 223 |
-
beta_query_input_ids: torch.Tensor | None,
|
| 224 |
-
beta_query_attention_mask: torch.Tensor | None,
|
| 225 |
-
beta_query_owner: torch.Tensor | None,
|
| 226 |
-
beta_query_parent: torch.Tensor | None,
|
| 227 |
-
beta_query_weight: torch.Tensor | None,
|
| 228 |
-
beta_query_source_part: torch.Tensor | None,
|
| 229 |
-
kappa: torch.Tensor,
|
| 230 |
-
query_base: torch.Tensor | None = None,
|
| 231 |
-
) -> dict[str, torch.Tensor]:
|
| 232 |
-
if beta_query_input_ids is None or beta_query_attention_mask is None or beta_query_owner is None:
|
| 233 |
-
raise ValueError(f"{self.uncha_entailment_loss} requires beta query tensors from the collator")
|
| 234 |
-
if beta_query_parent is None or beta_query_weight is None:
|
| 235 |
-
raise ValueError(f"{self.uncha_entailment_loss} requires beta query hierarchy metadata from the collator")
|
| 236 |
-
if self.uncha_entailment_loss == "hier_beta_sourcepart_argent" and beta_query_source_part is None:
|
| 237 |
-
raise ValueError("hier_beta_sourcepart_argent requires beta_query_source_part from the collator")
|
| 238 |
-
if beta_query_input_ids.shape[0] == 0:
|
| 239 |
-
source_part = (
|
| 240 |
-
beta_query_source_part.to(device=image_tokens.device, dtype=torch.long)
|
| 241 |
-
if beta_query_source_part is not None
|
| 242 |
-
else beta_query_owner.new_zeros((0,), device=image_tokens.device, dtype=torch.long)
|
| 243 |
-
)
|
| 244 |
-
return {
|
| 245 |
-
"beta_query_image_feats": image_tokens.new_zeros((0, self.embed_dim)),
|
| 246 |
-
"beta_query_text_feats": image_tokens.new_zeros((0, self.embed_dim)),
|
| 247 |
-
"beta_query_owner": beta_query_owner.to(device=image_tokens.device, dtype=torch.long),
|
| 248 |
-
"beta_query_parent": beta_query_parent.to(device=image_tokens.device, dtype=torch.long),
|
| 249 |
-
"beta_query_weight": beta_query_weight.to(device=image_tokens.device, dtype=torch.float32),
|
| 250 |
-
"beta_query_source_part": source_part,
|
| 251 |
-
}
|
| 252 |
-
|
| 253 |
-
query_owner = beta_query_owner.to(device=image_tokens.device, dtype=torch.long)
|
| 254 |
-
if query_base is None:
|
| 255 |
-
query_base = self.encode_text_base(beta_query_input_ids, beta_query_attention_mask)
|
| 256 |
-
conditioned_image_base = self._beta_clip_text_conditioned_pool(image_tokens, query_base, query_owner)
|
| 257 |
-
query_image_euc = self.image_proj(conditioned_image_base)
|
| 258 |
-
query_text_euc = self.text_proj(query_base)
|
| 259 |
-
return {
|
| 260 |
-
"beta_query_image_feats": self.project_image_features(query_image_euc),
|
| 261 |
-
"beta_query_text_feats": self.project_text_features(query_text_euc),
|
| 262 |
-
"beta_query_owner": query_owner,
|
| 263 |
-
"beta_query_parent": beta_query_parent.to(device=image_tokens.device, dtype=torch.long),
|
| 264 |
-
"beta_query_weight": beta_query_weight.to(device=image_tokens.device, dtype=torch.float32),
|
| 265 |
-
**(
|
| 266 |
-
{"beta_query_source_part": beta_query_source_part.to(device=image_tokens.device, dtype=torch.long)}
|
| 267 |
-
if beta_query_source_part is not None
|
| 268 |
-
else {}
|
| 269 |
-
),
|
| 270 |
-
}
|
| 271 |
-
|
| 272 |
-
def _beta_clip_auxiliary_loss(
|
| 273 |
-
self,
|
| 274 |
-
*,
|
| 275 |
-
image_tokens: torch.Tensor,
|
| 276 |
-
beta_query_input_ids: torch.Tensor | None,
|
| 277 |
-
beta_query_attention_mask: torch.Tensor | None,
|
| 278 |
-
beta_query_owner: torch.Tensor | None,
|
| 279 |
-
global_targets: torch.Tensor,
|
| 280 |
-
kappa: torch.Tensor,
|
| 281 |
-
) -> torch.Tensor:
|
| 282 |
-
if beta_query_input_ids is None or beta_query_attention_mask is None or beta_query_owner is None:
|
| 283 |
-
raise ValueError("beta-CLIP auxiliary requires beta query tensors from the collator")
|
| 284 |
-
if beta_query_input_ids.shape[0] == 0:
|
| 285 |
-
return image_tokens.new_zeros(())
|
| 286 |
-
|
| 287 |
-
beta_query_owner = beta_query_owner.to(device=image_tokens.device, dtype=torch.long)
|
| 288 |
-
query_base = self.encode_text_base(beta_query_input_ids, beta_query_attention_mask)
|
| 289 |
-
conditioned_image_base = self._beta_clip_text_conditioned_pool(image_tokens, query_base, beta_query_owner)
|
| 290 |
-
query_image_euc = self.image_proj(conditioned_image_base)
|
| 291 |
-
query_text_euc = self.text_proj(query_base)
|
| 292 |
-
|
| 293 |
-
if self.beta_clip_similarity == "dot":
|
| 294 |
-
query_image_feats = F.normalize(query_image_euc.float(), dim=-1)
|
| 295 |
-
query_text_feats = F.normalize(query_text_euc.float(), dim=-1)
|
| 296 |
-
else:
|
| 297 |
-
query_image_feats = self.project_image_features(query_image_euc)
|
| 298 |
-
query_text_feats = self.project_text_features(query_text_euc)
|
| 299 |
-
|
| 300 |
-
all_query_image_feats, query_counts = gather_variable_with_grad(query_image_feats)
|
| 301 |
-
all_query_text_feats, _ = gather_variable_with_grad(query_text_feats)
|
| 302 |
-
query_offset = query_counts[: get_rank()].sum() if query_counts.numel() > 1 else query_counts.new_zeros(())
|
| 303 |
-
query_targets = torch.arange(query_image_feats.size(0), device=query_image_feats.device) + query_offset
|
| 304 |
-
query_group_ids = global_targets.index_select(0, beta_query_owner)
|
| 305 |
-
all_query_group_ids, _ = gather_variable_with_grad(query_group_ids)
|
| 306 |
-
|
| 307 |
-
scale = self.beta_clip_logit_scale.exp().clamp(max=100.0)
|
| 308 |
-
if self.beta_clip_similarity == "dot":
|
| 309 |
-
logits_i_t = query_image_feats @ all_query_text_feats.T * scale
|
| 310 |
-
logits_t_i = query_text_feats @ all_query_image_feats.T * scale
|
| 311 |
-
else:
|
| 312 |
-
logits_i_t = -metric_pairwise_dist(
|
| 313 |
-
query_image_feats,
|
| 314 |
-
all_query_text_feats,
|
| 315 |
-
kappa,
|
| 316 |
-
product_metric=self.phyclip_product_metric,
|
| 317 |
-
) * scale
|
| 318 |
-
logits_t_i = -metric_pairwise_dist(
|
| 319 |
-
query_text_feats,
|
| 320 |
-
all_query_image_feats,
|
| 321 |
-
kappa,
|
| 322 |
-
product_metric=self.phyclip_product_metric,
|
| 323 |
-
) * scale
|
| 324 |
-
return 0.5 * (
|
| 325 |
-
beta_cal_loss(
|
| 326 |
-
logits_i_t,
|
| 327 |
-
targets=query_targets,
|
| 328 |
-
group_ids=query_group_ids,
|
| 329 |
-
all_group_ids=all_query_group_ids,
|
| 330 |
-
beta=self.beta_clip_beta,
|
| 331 |
-
variant=self.beta_clip_variant,
|
| 332 |
-
)
|
| 333 |
-
+ beta_cal_loss(
|
| 334 |
-
logits_t_i,
|
| 335 |
-
targets=query_targets,
|
| 336 |
-
group_ids=query_group_ids,
|
| 337 |
-
all_group_ids=all_query_group_ids,
|
| 338 |
-
beta=self.beta_clip_beta,
|
| 339 |
-
variant=self.beta_clip_variant,
|
| 340 |
-
)
|
| 341 |
-
)
|
| 342 |
-
|
| 343 |
-
def _beta_clip_text_conditioned_pool(
|
| 344 |
-
self,
|
| 345 |
-
image_tokens: torch.Tensor,
|
| 346 |
-
query_base: torch.Tensor,
|
| 347 |
-
query_owner: torch.Tensor,
|
| 348 |
-
) -> torch.Tensor:
|
| 349 |
-
if image_tokens.ndim != 3:
|
| 350 |
-
raise ValueError("beta-CLIP image tokens must have shape [batch, tokens, dim]")
|
| 351 |
-
if getattr(self, "group_beta_query_pooling", False):
|
| 352 |
-
return self._beta_clip_text_conditioned_pool_grouped(image_tokens, query_base, query_owner)
|
| 353 |
-
if self.beta_clip_drop_cls_token and image_tokens.size(1) > 1:
|
| 354 |
-
image_tokens = image_tokens[:, 1:, :]
|
| 355 |
-
selected_tokens = image_tokens.index_select(0, query_owner).to(dtype=query_base.dtype)
|
| 356 |
-
query = self.beta_clip_text_query_proj(query_base).unsqueeze(1)
|
| 357 |
-
attended, _ = self.beta_clip_cross_attention(query, selected_tokens, selected_tokens, need_weights=False)
|
| 358 |
-
pooled = attended.squeeze(1)
|
| 359 |
-
return pooled + self.beta_clip_pool_mlp(self.beta_clip_mlp_norm(pooled))
|
| 360 |
-
|
| 361 |
-
def _beta_clip_text_conditioned_pool_grouped(
|
| 362 |
-
self,
|
| 363 |
-
image_tokens: torch.Tensor,
|
| 364 |
-
query_base: torch.Tensor,
|
| 365 |
-
query_owner: torch.Tensor,
|
| 366 |
-
) -> torch.Tensor:
|
| 367 |
-
if query_owner.numel() == 0:
|
| 368 |
-
return query_base.new_zeros((0, self.vision_encoder.output_dim))
|
| 369 |
-
if query_owner.min().item() < 0 or query_owner.max().item() >= image_tokens.size(0):
|
| 370 |
-
raise IndexError("beta_query_owner contains an out-of-range image index")
|
| 371 |
-
|
| 372 |
-
tokens = image_tokens[:, 1:, :] if self.beta_clip_drop_cls_token and image_tokens.size(1) > 1 else image_tokens
|
| 373 |
-
tokens = tokens.to(dtype=query_base.dtype)
|
| 374 |
-
query_projected = self.beta_clip_text_query_proj(query_base)
|
| 375 |
-
counts = torch.bincount(query_owner, minlength=image_tokens.size(0))
|
| 376 |
-
max_queries = int(counts.max().item())
|
| 377 |
-
|
| 378 |
-
order = torch.argsort(query_owner)
|
| 379 |
-
sorted_owner = query_owner.index_select(0, order)
|
| 380 |
-
owner_offsets = torch.zeros_like(counts)
|
| 381 |
-
owner_offsets[1:] = counts.cumsum(0)[:-1]
|
| 382 |
-
sorted_positions = torch.arange(query_owner.numel(), device=query_owner.device) - owner_offsets.index_select(
|
| 383 |
-
0, sorted_owner
|
| 384 |
-
)
|
| 385 |
-
positions = torch.empty_like(sorted_positions)
|
| 386 |
-
positions[order] = sorted_positions
|
| 387 |
-
|
| 388 |
-
packed_query = query_projected.new_zeros((image_tokens.size(0), max_queries, query_projected.size(-1)))
|
| 389 |
-
packed_query[query_owner, positions] = query_projected
|
| 390 |
-
attended, _ = self.beta_clip_cross_attention(packed_query, tokens, tokens, need_weights=False)
|
| 391 |
-
pooled = attended[query_owner, positions]
|
| 392 |
-
return pooled + self.beta_clip_pool_mlp(self.beta_clip_mlp_norm(pooled))
|
| 393 |
-
|
| 394 |
-
def _tren_auxiliary_losses(
|
| 395 |
-
self,
|
| 396 |
-
*,
|
| 397 |
-
image_tokens: torch.Tensor,
|
| 398 |
-
part_owner: torch.Tensor,
|
| 399 |
-
part_image_base: torch.Tensor,
|
| 400 |
-
part_text_base: torch.Tensor,
|
| 401 |
-
) -> dict[str, torch.Tensor]:
|
| 402 |
-
zero = image_tokens.new_zeros(())
|
| 403 |
-
if part_owner.numel() == 0:
|
| 404 |
-
return {
|
| 405 |
-
"tren_loss": zero,
|
| 406 |
-
"tren_visual_distill_loss": zero,
|
| 407 |
-
"tren_text_distill_loss": zero,
|
| 408 |
-
"tren_region_text_contrastive_loss": zero,
|
| 409 |
-
"tren_assignment_count": part_owner.new_tensor(0),
|
| 410 |
-
}
|
| 411 |
-
|
| 412 |
-
tren_outputs = self.tren_region_encoder(image_tokens)
|
| 413 |
-
visual_tokens = tren_outputs["visual_tokens"].flatten(1, 2)
|
| 414 |
-
text_tokens = tren_outputs["text_aligned_tokens"].flatten(1, 2)
|
| 415 |
-
|
| 416 |
-
matched_visual: list[torch.Tensor] = []
|
| 417 |
-
matched_text: list[torch.Tensor] = []
|
| 418 |
-
target_visual: list[torch.Tensor] = []
|
| 419 |
-
target_text: list[torch.Tensor] = []
|
| 420 |
-
for owner in range(image_tokens.size(0)):
|
| 421 |
-
region_mask = part_owner == owner
|
| 422 |
-
if not bool(region_mask.any()):
|
| 423 |
-
continue
|
| 424 |
-
owner_target_visual = part_image_base[region_mask].detach()
|
| 425 |
-
owner_target_text = part_text_base[region_mask].detach()
|
| 426 |
-
owner_visual_tokens = visual_tokens[owner]
|
| 427 |
-
owner_text_tokens = text_tokens[owner]
|
| 428 |
-
pred_indices, target_indices = _greedy_region_assignment(owner_visual_tokens, owner_target_visual)
|
| 429 |
-
if pred_indices.numel() == 0:
|
| 430 |
-
continue
|
| 431 |
-
matched_visual.append(owner_visual_tokens.index_select(0, pred_indices))
|
| 432 |
-
matched_text.append(owner_text_tokens.index_select(0, pred_indices))
|
| 433 |
-
target_visual.append(owner_target_visual.index_select(0, target_indices))
|
| 434 |
-
target_text.append(owner_target_text.index_select(0, target_indices))
|
| 435 |
-
|
| 436 |
-
if not matched_visual:
|
| 437 |
-
return {
|
| 438 |
-
"tren_loss": zero,
|
| 439 |
-
"tren_visual_distill_loss": zero,
|
| 440 |
-
"tren_text_distill_loss": zero,
|
| 441 |
-
"tren_region_text_contrastive_loss": zero,
|
| 442 |
-
"tren_assignment_count": part_owner.new_tensor(0),
|
| 443 |
-
}
|
| 444 |
-
|
| 445 |
-
matched_visual_tensor = torch.cat(matched_visual, dim=0)
|
| 446 |
-
matched_text_tensor = torch.cat(matched_text, dim=0)
|
| 447 |
-
target_visual_tensor = torch.cat(target_visual, dim=0)
|
| 448 |
-
target_text_tensor = torch.cat(target_text, dim=0)
|
| 449 |
-
visual_distill = 1.0 - F.cosine_similarity(matched_visual_tensor, target_visual_tensor, dim=-1).mean()
|
| 450 |
-
text_distill = 1.0 - F.cosine_similarity(matched_text_tensor, target_text_tensor, dim=-1).mean()
|
| 451 |
-
region_text = _symmetric_dot_contrastive(
|
| 452 |
-
matched_text_tensor,
|
| 453 |
-
target_text_tensor,
|
| 454 |
-
scale=self.tren_logit_scale.exp().clamp(max=100.0),
|
| 455 |
-
)
|
| 456 |
-
total = (
|
| 457 |
-
self.tren_visual_distill_weight * visual_distill
|
| 458 |
-
+ self.tren_text_distill_weight * text_distill
|
| 459 |
-
+ self.tren_region_text_weight * region_text
|
| 460 |
-
)
|
| 461 |
-
return {
|
| 462 |
-
"tren_loss": total,
|
| 463 |
-
"tren_visual_distill_loss": visual_distill,
|
| 464 |
-
"tren_text_distill_loss": text_distill,
|
| 465 |
-
"tren_region_text_contrastive_loss": region_text,
|
| 466 |
-
"tren_assignment_count": part_owner.new_tensor(matched_visual_tensor.size(0)),
|
| 467 |
-
}
|
| 468 |
-
|
| 469 |
-
def _project_proclip_image_base(self, base_feats: torch.Tensor, hyperbolic: torch.Tensor) -> torch.Tensor:
|
| 470 |
-
if self.proclip_geometry == "clip":
|
| 471 |
-
return F.normalize(base_feats.float(), dim=-1)
|
| 472 |
-
if self.proclip_dedicated_hyperbolic:
|
| 473 |
-
hyperbolic = exp_map0(self.proclip_image_hyperbolic_proj(base_feats.float()), self._kappa().float())
|
| 474 |
-
return self._pack_proclip_features(
|
| 475 |
-
hyperbolic=hyperbolic,
|
| 476 |
-
euclidean=self.proclip_image_euclidean_proj(base_feats.float()),
|
| 477 |
-
spherical=self.proclip_image_spherical_proj(base_feats.float()),
|
| 478 |
-
)
|
| 479 |
-
|
| 480 |
-
def _project_proclip_text_base(self, base_feats: torch.Tensor, hyperbolic: torch.Tensor) -> torch.Tensor:
|
| 481 |
-
if self.proclip_geometry == "clip":
|
| 482 |
-
return F.normalize(base_feats.float(), dim=-1)
|
| 483 |
-
if self.proclip_dedicated_hyperbolic:
|
| 484 |
-
hyperbolic = exp_map0(self.proclip_text_hyperbolic_proj(base_feats.float()), self._kappa().float())
|
| 485 |
-
return self._pack_proclip_features(
|
| 486 |
-
hyperbolic=hyperbolic,
|
| 487 |
-
euclidean=self.proclip_text_euclidean_proj(base_feats.float()),
|
| 488 |
-
spherical=self.proclip_text_spherical_proj(base_feats.float()),
|
| 489 |
-
)
|
| 490 |
-
|
| 491 |
-
def _pack_proclip_features(self, hyperbolic: torch.Tensor, euclidean: torch.Tensor, spherical: torch.Tensor) -> torch.Tensor:
|
| 492 |
-
spherical = F.normalize(spherical.float(), dim=-1)
|
| 493 |
-
if self.proclip_geometry == "hyperbolic":
|
| 494 |
-
return hyperbolic.float()
|
| 495 |
-
if self.proclip_geometry == "euclidean":
|
| 496 |
-
return euclidean.float()
|
| 497 |
-
if self.proclip_geometry == "spherical":
|
| 498 |
-
return spherical
|
| 499 |
-
return torch.cat([hyperbolic.float(), euclidean.float(), spherical], dim=-1)
|
| 500 |
-
|
| 501 |
-
def _split_proclip_features(self, feats: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 502 |
-
hyperbolic_dim = self.embed_dim + 1
|
| 503 |
-
component_dim = self._proclip_component_dim
|
| 504 |
-
spherical_dim = self._proclip_spherical_ambient_dim
|
| 505 |
-
hyperbolic = feats[:, :hyperbolic_dim]
|
| 506 |
-
euclidean = feats[:, hyperbolic_dim : hyperbolic_dim + component_dim]
|
| 507 |
-
spherical = feats[:, hyperbolic_dim + component_dim : hyperbolic_dim + component_dim + spherical_dim]
|
| 508 |
-
return hyperbolic, euclidean, spherical
|
| 509 |
-
|
| 510 |
-
def _proclip_similarity_scores(self, image_feats: torch.Tensor, text_feats: torch.Tensor) -> torch.Tensor:
|
| 511 |
-
if self.proclip_geometry == "clip":
|
| 512 |
-
return image_feats.float() @ text_feats.float().T
|
| 513 |
-
if self.proclip_geometry == "hyperbolic":
|
| 514 |
-
return -metric_pairwise_dist(image_feats, text_feats, self._kappa()).square()
|
| 515 |
-
if self.proclip_geometry == "euclidean":
|
| 516 |
-
return -torch.cdist(image_feats.float(), text_feats.float(), p=2).square()
|
| 517 |
-
if self.proclip_geometry == "spherical":
|
| 518 |
-
dot = (image_feats.float() @ text_feats.float().T).clamp(min=-1.0 + 1e-6, max=1.0 - 1e-6)
|
| 519 |
-
return -torch.acos(dot).square()
|
| 520 |
-
image_hyp, image_euc, image_sph = self._split_proclip_features(image_feats)
|
| 521 |
-
text_hyp, text_euc, text_sph = self._split_proclip_features(text_feats)
|
| 522 |
-
weights = self.proclip_log_weights.exp().to(device=image_feats.device, dtype=torch.float32)
|
| 523 |
-
hyperbolic_dist2 = metric_pairwise_dist(image_hyp, text_hyp, self._kappa()).square()
|
| 524 |
-
euclidean_dist2 = torch.cdist(image_euc.float(), text_euc.float(), p=2).square()
|
| 525 |
-
spherical_dot = (image_sph.float() @ text_sph.float().T).clamp(min=-1.0 + 1e-6, max=1.0 - 1e-6)
|
| 526 |
-
spherical_dist2 = torch.acos(spherical_dot).square()
|
| 527 |
-
return -(weights[0] * hyperbolic_dist2 + weights[1] * euclidean_dist2 + weights[2] * spherical_dist2)
|
| 528 |
-
|
| 529 |
-
def _proclip_contrastive_loss(
|
| 530 |
-
self,
|
| 531 |
-
image_feats: torch.Tensor,
|
| 532 |
-
text_feats: torch.Tensor,
|
| 533 |
-
all_image_feats: torch.Tensor,
|
| 534 |
-
all_text_feats: torch.Tensor,
|
| 535 |
-
targets: torch.Tensor,
|
| 536 |
-
) -> torch.Tensor:
|
| 537 |
-
scale = self.proclip_logit_scale.exp().clamp(max=100.0)
|
| 538 |
-
logits_i_t = self._proclip_similarity_scores(image_feats, all_text_feats) * scale
|
| 539 |
-
logits_t_i = self._proclip_similarity_scores(text_feats, all_image_feats) * scale
|
| 540 |
-
return 0.5 * (F.cross_entropy(logits_i_t, targets) + F.cross_entropy(logits_t_i, targets))
|
| 541 |
-
|
| 542 |
-
def _detached_proclip_logs(self) -> dict[str, torch.Tensor]:
|
| 543 |
-
weights = self.proclip_log_weights.exp().detach()
|
| 544 |
-
return {
|
| 545 |
-
"proclip_logit_scale": self.proclip_logit_scale.exp().detach(),
|
| 546 |
-
"proclip_hyperbolic_weight": weights[0],
|
| 547 |
-
"proclip_euclidean_weight": weights[1],
|
| 548 |
-
"proclip_spherical_weight": weights[2],
|
| 549 |
-
}
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
def _greedy_region_assignment(pred_tokens: torch.Tensor, target_tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 553 |
-
if pred_tokens.numel() == 0 or target_tokens.numel() == 0:
|
| 554 |
-
empty = torch.zeros((0,), dtype=torch.long, device=pred_tokens.device)
|
| 555 |
-
return empty, empty
|
| 556 |
-
similarities = F.normalize(pred_tokens.float(), dim=-1) @ F.normalize(target_tokens.float(), dim=-1).T
|
| 557 |
-
pair_scores = similarities.flatten()
|
| 558 |
-
order = torch.argsort(pair_scores, descending=True)
|
| 559 |
-
used_pred = torch.zeros(pred_tokens.size(0), dtype=torch.bool, device=pred_tokens.device)
|
| 560 |
-
used_target = torch.zeros(target_tokens.size(0), dtype=torch.bool, device=pred_tokens.device)
|
| 561 |
-
pred_indices: list[torch.Tensor] = []
|
| 562 |
-
target_indices: list[torch.Tensor] = []
|
| 563 |
-
for flat_index in order:
|
| 564 |
-
pred_index = torch.div(flat_index, target_tokens.size(0), rounding_mode="floor")
|
| 565 |
-
target_index = flat_index % target_tokens.size(0)
|
| 566 |
-
if used_pred[pred_index] or used_target[target_index]:
|
| 567 |
-
continue
|
| 568 |
-
used_pred[pred_index] = True
|
| 569 |
-
used_target[target_index] = True
|
| 570 |
-
pred_indices.append(pred_index)
|
| 571 |
-
target_indices.append(target_index)
|
| 572 |
-
if len(target_indices) == target_tokens.size(0):
|
| 573 |
-
break
|
| 574 |
-
if not pred_indices:
|
| 575 |
-
empty = torch.zeros((0,), dtype=torch.long, device=pred_tokens.device)
|
| 576 |
-
return empty, empty
|
| 577 |
-
return torch.stack(pred_indices), torch.stack(target_indices)
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
def _symmetric_dot_contrastive(region_tokens: torch.Tensor, text_tokens: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
| 581 |
-
if region_tokens.size(0) == 1:
|
| 582 |
-
return region_tokens.new_zeros(())
|
| 583 |
-
region_tokens = F.normalize(region_tokens.float(), dim=-1)
|
| 584 |
-
text_tokens = F.normalize(text_tokens.float(), dim=-1)
|
| 585 |
-
logits = region_tokens @ text_tokens.T * scale
|
| 586 |
-
targets = torch.arange(logits.size(0), device=logits.device)
|
| 587 |
-
return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/models/himo.py
DELETED
|
@@ -1,55 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from torch import Tensor
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
def hide_reconstruct_embeddings(
    embeddings: Tensor,
    *,
    variance_threshold: float = 0.9,
    detach_pca: bool = True,
    eps: float = 1e-8,
) -> Tensor:
    """HiMo-CLIP HiDe: PCA-reconstruct embeddings using top principal components.

    Given a batch of embeddings ``U ∈ R^{B×D}``, compute mean-centered embeddings,
    perform SVD/PCA, choose the smallest number of components whose cumulative
    explained variance reaches ``variance_threshold``, and reconstruct each
    embedding from this principal subspace:

        u'_i = P^T (P (u_i - ū)) + ū

    where P stacks the selected principal components as rows.

    Args:
        embeddings: ``[batch, dim]`` tensor.
        variance_threshold: fraction of total variance to retain, in ``(0, 1]``.
        detach_pca: when true, the SVD input and the selected components are detached,
            so gradients flow only through the centered embeddings and the mean.
        eps: total variance at or below this value is treated as degenerate.

    Returns:
        The reconstruction, cast back to ``embeddings.dtype``. The input is returned
        unchanged for batches of fewer than two rows or degenerate variance.

    Raises:
        ValueError: if ``embeddings`` is not 2-D or ``variance_threshold`` is outside ``(0, 1]``.
    """
    if embeddings.ndim != 2:
        raise ValueError("hide_reconstruct_embeddings expects a [batch, dim] tensor")
    if not (0.0 < variance_threshold <= 1.0):
        raise ValueError("variance_threshold must be in (0, 1]")
    if embeddings.size(0) < 2:
        return embeddings

    u = embeddings.to(dtype=torch.float32)
    mean = u.mean(dim=0, keepdim=True)
    centered = u - mean
    centered_for_pca = centered.detach() if detach_pca else centered

    # SVD: centered = U S Vh, principal components are rows of Vh.
    _, s, vh = torch.linalg.svd(centered_for_pca, full_matrices=False)
    explained = s.square()
    total = explained.sum()
    if s.numel() == 0 or float(total.item()) <= eps:
        return embeddings

    cumulative = explained.cumsum(dim=0) / total.clamp_min(eps)
    reached = cumulative >= variance_threshold
    if bool(reached.any()):
        m = int(reached.to(dtype=torch.int64).argmax().item()) + 1
    else:
        # Rounding can leave the final cumulative ratio just below the threshold
        # (notably for variance_threshold == 1.0). argmax over an all-False mask is 0,
        # which would keep a single component; keep every component instead.
        m = vh.size(0)
    m = max(1, min(m, vh.size(0)))
    p = vh[:m]
    if detach_pca:
        p = p.detach()

    recon = (centered @ p.T) @ p + mean
    return recon.to(dtype=embeddings.dtype)
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/models/hyper3_clip.py
DELETED
|
@@ -1,958 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
import torch.nn.functional as F
|
| 5 |
-
from torch import nn
|
| 6 |
-
|
| 7 |
-
from hyper3_clip.models.encoders import TextEncoder, VisionEncoder
|
| 8 |
-
from hyper3_clip.models.experimental import ExperimentalObjectiveMixin
|
| 9 |
-
from hyper3_clip.models.himo import hide_reconstruct_embeddings
|
| 10 |
-
from hyper3_clip.models.lorentz import exp_map0, metric_similarity
|
| 11 |
-
from hyper3_clip.models.objectives import build_objective
|
| 12 |
-
from hyper3_clip.training.distributed import (
|
| 13 |
-
gather_with_grad,
|
| 14 |
-
get_rank,
|
| 15 |
-
get_world_size,
|
| 16 |
-
local_target_indices,
|
| 17 |
-
)
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
class Hyper3CLIP(ExperimentalObjectiveMixin, nn.Module):
|
| 21 |
-
def __init__(
|
| 22 |
-
self,
|
| 23 |
-
vision_backbone: str,
|
| 24 |
-
text_model_name: str,
|
| 25 |
-
embed_dim: int,
|
| 26 |
-
curv_init: float,
|
| 27 |
-
learn_curv: bool,
|
| 28 |
-
entail_weight: float,
|
| 29 |
-
inter_aperture_scale: float,
|
| 30 |
-
intra_aperture_scale: float,
|
| 31 |
-
objective: str = "hycoclip",
|
| 32 |
-
uncha_piecewise_factor: float = 0.1,
|
| 33 |
-
uncha_calibration_alpha: float = 10.0,
|
| 34 |
-
uncha_stop_grad_calibration: bool = True,
|
| 35 |
-
vision_pretrained: bool = True,
|
| 36 |
-
text_pretrained: bool = True,
|
| 37 |
-
text_pooling: str = "auto",
|
| 38 |
-
freeze_vision_encoder: bool = False,
|
| 39 |
-
freeze_text_encoder: bool = False,
|
| 40 |
-
normalize_encoder_features: bool = False,
|
| 41 |
-
projection_hidden_dim: int | None = None,
|
| 42 |
-
uncha_entailment_geometry: str = "lorentz",
|
| 43 |
-
uncha_aggregate_weight: float = 0.0,
|
| 44 |
-
uncha_entailment_loss: str = "piecewise",
|
| 45 |
-
uncha_argent_beta: float = 1.0,
|
| 46 |
-
uncha_argent_norm_weight: float = 0.0,
|
| 47 |
-
uncha_argent_aux_weight: float = 0.5,
|
| 48 |
-
uncha_argent_aggregation: str = "uncha",
|
| 49 |
-
uncha_part_weight_power: float = 0.0,
|
| 50 |
-
uncha_contrastive_loss: str = "ce",
|
| 51 |
-
uncha_sigmoid_bias_init: float = -10.0,
|
| 52 |
-
uncha_sigmoid_negative_weight: float = 1.0,
|
| 53 |
-
uncha_part_quality_mode: str = "none",
|
| 54 |
-
uncha_part_quality_topk: int = 5,
|
| 55 |
-
uncha_part_quality_temperature: float = 4.0,
|
| 56 |
-
uncha_entailment_warmup_steps: int = 0,
|
| 57 |
-
uncha_contrastive_global_weight: float = 1.0,
|
| 58 |
-
uncha_contrastive_local_weight: float = 1.0,
|
| 59 |
-
uncha_contrastive_global_local_weight: float = 1.0,
|
| 60 |
-
uncha_global_local_mode: str = "repeat",
|
| 61 |
-
uncha_global_local_metric: str = "distance",
|
| 62 |
-
uncha_global_local_angle_aux_weight: float = 0.0,
|
| 63 |
-
uncha_global_local_angle_aux_mode: str = "contrastive",
|
| 64 |
-
uncha_global_local_angle_aux_scale: float = 5.5,
|
| 65 |
-
uncha_global_local_angle_aux_aperture_scale: float = 1.0,
|
| 66 |
-
uncha_beta_cal_beta: float = 0.0,
|
| 67 |
-
uncha_beta_cal_variant: str = "ce",
|
| 68 |
-
uncha_beta_cal_weight: float = 0.0,
|
| 69 |
-
uncha_himo_component_weight: float = 0.0,
|
| 70 |
-
uncha_himo_variance_threshold: float = 0.9,
|
| 71 |
-
uncha_himo_detach_pca: bool = True,
|
| 72 |
-
uncha_radius_order_weight: float = 0.0,
|
| 73 |
-
uncha_radius_order_margin: float = 0.0,
|
| 74 |
-
uncha_gramian_align_weight: float = 0.0,
|
| 75 |
-
phyclip_subspace_dim: int | None = None,
|
| 76 |
-
phyclip_product_metric: str = "l1",
|
| 77 |
-
proclip_weight: float = 0.0,
|
| 78 |
-
proclip_component_dim: int | None = None,
|
| 79 |
-
proclip_retrieval: bool = False,
|
| 80 |
-
proclip_geometry: str = "product",
|
| 81 |
-
proclip_dedicated_hyperbolic: bool = False,
|
| 82 |
-
proclip_projection_hidden_dim: int | None = None,
|
| 83 |
-
beta_clip_weight: float = 0.0,
|
| 84 |
-
beta_clip_global_weight: float = 0.0,
|
| 85 |
-
beta_clip_beta: float = 0.5,
|
| 86 |
-
beta_clip_variant: str = "ce",
|
| 87 |
-
beta_clip_similarity: str = "metric",
|
| 88 |
-
beta_clip_num_heads: int = 8,
|
| 89 |
-
beta_clip_mlp_ratio: float = 4.0,
|
| 90 |
-
beta_clip_drop_cls_token: bool = True,
|
| 91 |
-
tren_weight: float = 0.0,
|
| 92 |
-
tren_visual_distill_weight: float = 1.0,
|
| 93 |
-
tren_text_distill_weight: float = 1.0,
|
| 94 |
-
tren_region_text_weight: float = 1.0,
|
| 95 |
-
tren_num_region_tokens: int = 3,
|
| 96 |
-
tren_num_decoder_layers: int = 2,
|
| 97 |
-
tren_num_attention_heads: int = 8,
|
| 98 |
-
tren_prompt_grid_size: int = 7,
|
| 99 |
-
tren_dropout: float = 0.1,
|
| 100 |
-
fuse_whole_part_encoder_forwards: bool = False,
|
| 101 |
-
fuse_beta_query_encoder_forwards: bool = False,
|
| 102 |
-
group_beta_query_pooling: bool = False,
|
| 103 |
-
objective_autocast_dtype: str = "float32",
|
| 104 |
-
) -> None:
|
| 105 |
-
super().__init__()
|
| 106 |
-
if objective not in {"hycoclip", "uncha", "proclip"}:
|
| 107 |
-
raise ValueError(f"Unsupported objective {objective!r}; expected 'hycoclip', 'uncha', or 'proclip'")
|
| 108 |
-
if phyclip_product_metric not in {"l1", "l2"}:
|
| 109 |
-
raise ValueError("phyclip_product_metric must be 'l1' or 'l2'")
|
| 110 |
-
self._validate_experimental_options(
|
| 111 |
-
proclip_geometry=proclip_geometry,
|
| 112 |
-
proclip_projection_hidden_dim=proclip_projection_hidden_dim,
|
| 113 |
-
proclip_component_dim=proclip_component_dim,
|
| 114 |
-
beta_clip_weight=beta_clip_weight,
|
| 115 |
-
beta_clip_global_weight=beta_clip_global_weight,
|
| 116 |
-
beta_clip_beta=beta_clip_beta,
|
| 117 |
-
beta_clip_variant=beta_clip_variant,
|
| 118 |
-
beta_clip_similarity=beta_clip_similarity,
|
| 119 |
-
beta_clip_num_heads=beta_clip_num_heads,
|
| 120 |
-
beta_clip_mlp_ratio=beta_clip_mlp_ratio,
|
| 121 |
-
tren_weight=tren_weight,
|
| 122 |
-
tren_visual_distill_weight=tren_visual_distill_weight,
|
| 123 |
-
tren_text_distill_weight=tren_text_distill_weight,
|
| 124 |
-
tren_region_text_weight=tren_region_text_weight,
|
| 125 |
-
tren_num_region_tokens=tren_num_region_tokens,
|
| 126 |
-
tren_num_decoder_layers=tren_num_decoder_layers,
|
| 127 |
-
tren_num_attention_heads=tren_num_attention_heads,
|
| 128 |
-
tren_prompt_grid_size=tren_prompt_grid_size,
|
| 129 |
-
tren_dropout=tren_dropout,
|
| 130 |
-
)
|
| 131 |
-
if objective_autocast_dtype not in {"float32", "fp32", "float16", "fp16", "bfloat16", "bf16"}:
|
| 132 |
-
raise ValueError("objective_autocast_dtype must be one of 'float32', 'float16', or 'bfloat16'")
|
| 133 |
-
if uncha_contrastive_loss not in {"ce", "sigmoid", "siglip", "siglip_metric"}:
|
| 134 |
-
raise ValueError("uncha_contrastive_loss must be 'ce', 'sigmoid', 'siglip', or 'siglip_metric'")
|
| 135 |
-
if uncha_global_local_metric not in {"distance", "angle"}:
|
| 136 |
-
raise ValueError("uncha_global_local_metric must be 'distance' or 'angle'")
|
| 137 |
-
if uncha_global_local_angle_aux_mode not in {"contrastive", "positive_hinge"}:
|
| 138 |
-
raise ValueError("uncha_global_local_angle_aux_mode must be 'contrastive' or 'positive_hinge'")
|
| 139 |
-
if uncha_global_local_angle_aux_weight < 0.0:
|
| 140 |
-
raise ValueError("uncha_global_local_angle_aux_weight must be non-negative")
|
| 141 |
-
if uncha_global_local_angle_aux_scale <= 0.0:
|
| 142 |
-
raise ValueError("uncha_global_local_angle_aux_scale must be positive")
|
| 143 |
-
if uncha_global_local_angle_aux_aperture_scale <= 0.0:
|
| 144 |
-
raise ValueError("uncha_global_local_angle_aux_aperture_scale must be positive")
|
| 145 |
-
if uncha_entailment_warmup_steps < 0:
|
| 146 |
-
raise ValueError("uncha_entailment_warmup_steps must be non-negative")
|
| 147 |
-
self.objective_name = objective
|
| 148 |
-
self.uncha_contrastive_loss = uncha_contrastive_loss
|
| 149 |
-
self.uncha_entailment_loss = uncha_entailment_loss
|
| 150 |
-
self.uncha_entailment_warmup_steps = uncha_entailment_warmup_steps
|
| 151 |
-
self.uncha_himo_component_weight = float(uncha_himo_component_weight)
|
| 152 |
-
self.uncha_himo_variance_threshold = float(uncha_himo_variance_threshold)
|
| 153 |
-
self.uncha_himo_detach_pca = bool(uncha_himo_detach_pca)
|
| 154 |
-
self.proclip_weight = float(proclip_weight)
|
| 155 |
-
self.proclip_retrieval = bool(proclip_retrieval)
|
| 156 |
-
self.proclip_geometry = proclip_geometry
|
| 157 |
-
self.proclip_dedicated_hyperbolic = bool(proclip_dedicated_hyperbolic)
|
| 158 |
-
self.beta_clip_weight = float(beta_clip_weight)
|
| 159 |
-
self.beta_clip_global_weight = float(beta_clip_global_weight)
|
| 160 |
-
self.beta_clip_beta = float(beta_clip_beta)
|
| 161 |
-
self.beta_clip_variant = beta_clip_variant
|
| 162 |
-
self.beta_clip_similarity = beta_clip_similarity
|
| 163 |
-
self.beta_clip_drop_cls_token = bool(beta_clip_drop_cls_token)
|
| 164 |
-
self.tren_weight = float(tren_weight)
|
| 165 |
-
self.tren_visual_distill_weight = float(tren_visual_distill_weight)
|
| 166 |
-
self.tren_text_distill_weight = float(tren_text_distill_weight)
|
| 167 |
-
self.tren_region_text_weight = float(tren_region_text_weight)
|
| 168 |
-
self.fuse_whole_part_encoder_forwards = bool(fuse_whole_part_encoder_forwards)
|
| 169 |
-
self.fuse_beta_query_encoder_forwards = bool(fuse_beta_query_encoder_forwards)
|
| 170 |
-
self.group_beta_query_pooling = bool(group_beta_query_pooling)
|
| 171 |
-
self.objective_autocast_dtype = objective_autocast_dtype
|
| 172 |
-
self.freeze_vision_encoder = bool(freeze_vision_encoder)
|
| 173 |
-
self.freeze_text_encoder = bool(freeze_text_encoder)
|
| 174 |
-
self.normalize_encoder_features = bool(normalize_encoder_features)
|
| 175 |
-
self.phyclip_subspace_dim = phyclip_subspace_dim
|
| 176 |
-
self.phyclip_product_metric = phyclip_product_metric
|
| 177 |
-
self.proclip_component_dim = proclip_component_dim
|
| 178 |
-
if projection_hidden_dim is not None and projection_hidden_dim <= 0:
|
| 179 |
-
raise ValueError("projection_hidden_dim must be positive when set")
|
| 180 |
-
if self.proclip_enabled and phyclip_subspace_dim is not None:
|
| 181 |
-
raise ValueError("ProCLIP mixed-curvature proxy cannot be combined with PHyCLIP Lorentz factors")
|
| 182 |
-
if phyclip_subspace_dim is not None:
|
| 183 |
-
if phyclip_subspace_dim <= 0:
|
| 184 |
-
raise ValueError("phyclip_subspace_dim must be positive when set")
|
| 185 |
-
if embed_dim % phyclip_subspace_dim != 0:
|
| 186 |
-
raise ValueError("embed_dim must be divisible by phyclip_subspace_dim")
|
| 187 |
-
self.phyclip_num_factors = embed_dim // phyclip_subspace_dim
|
| 188 |
-
else:
|
| 189 |
-
self.phyclip_num_factors = 0
|
| 190 |
-
self.vision_encoder = VisionEncoder(vision_backbone, pretrained=vision_pretrained)
|
| 191 |
-
self.text_encoder = TextEncoder(text_model_name, pretrained=text_pretrained, pooling=text_pooling)
|
| 192 |
-
self.tokenizer = self.text_encoder.tokenizer
|
| 193 |
-
self.embed_dim = embed_dim
|
| 194 |
-
if self.freeze_vision_encoder:
|
| 195 |
-
self.vision_encoder.requires_grad_(False)
|
| 196 |
-
self.vision_encoder.eval()
|
| 197 |
-
if self.freeze_text_encoder:
|
| 198 |
-
self.text_encoder.requires_grad_(False)
|
| 199 |
-
self.text_encoder.eval()
|
| 200 |
-
|
| 201 |
-
self.image_proj = _projection_head(self.vision_encoder.output_dim, embed_dim, projection_hidden_dim)
|
| 202 |
-
self.text_proj = _projection_head(self.text_encoder.output_dim, embed_dim, projection_hidden_dim)
|
| 203 |
-
self._init_experimental_modules(
|
| 204 |
-
beta_clip_num_heads=beta_clip_num_heads,
|
| 205 |
-
beta_clip_mlp_ratio=beta_clip_mlp_ratio,
|
| 206 |
-
tren_num_region_tokens=tren_num_region_tokens,
|
| 207 |
-
tren_num_decoder_layers=tren_num_decoder_layers,
|
| 208 |
-
tren_num_attention_heads=tren_num_attention_heads,
|
| 209 |
-
tren_prompt_grid_size=tren_prompt_grid_size,
|
| 210 |
-
tren_dropout=tren_dropout,
|
| 211 |
-
projection_hidden_dim=projection_hidden_dim,
|
| 212 |
-
proclip_projection_hidden_dim=proclip_projection_hidden_dim,
|
| 213 |
-
projection_head=_projection_head,
|
| 214 |
-
)
|
| 215 |
-
|
| 216 |
-
if objective == "hycoclip":
|
| 217 |
-
self.logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log())
|
| 218 |
-
elif objective == "uncha":
|
| 219 |
-
self.global_logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log())
|
| 220 |
-
self.local_logit_scale = nn.Parameter(torch.tensor(1 / 0.05).log())
|
| 221 |
-
self.global_local_logit_scale = nn.Parameter(torch.tensor(1 / 0.06).log())
|
| 222 |
-
if uncha_contrastive_loss in {"sigmoid", "siglip", "siglip_metric"}:
|
| 223 |
-
self.global_logit_bias = nn.Parameter(torch.tensor(float(uncha_sigmoid_bias_init)))
|
| 224 |
-
self.local_logit_bias = nn.Parameter(torch.tensor(float(uncha_sigmoid_bias_init)))
|
| 225 |
-
self.global_local_logit_bias = nn.Parameter(torch.tensor(float(uncha_sigmoid_bias_init)))
|
| 226 |
-
alpha_dim = phyclip_subspace_dim or embed_dim
|
| 227 |
-
alpha_shape = (self.phyclip_num_factors,) if self.phyclip_enabled else ()
|
| 228 |
-
self.visual_alpha = nn.Parameter(torch.full(alpha_shape, alpha_dim**-0.5).log())
|
| 229 |
-
self.textual_alpha = nn.Parameter(torch.full(alpha_shape, alpha_dim**-0.5).log())
|
| 230 |
-
|
| 231 |
-
curv_shape = (self.phyclip_num_factors,) if self.phyclip_enabled else ()
|
| 232 |
-
log_curv = torch.full(curv_shape, curv_init).log()
|
| 233 |
-
self.log_curv = nn.Parameter(log_curv, requires_grad=learn_curv)
|
| 234 |
-
self.curv_min = curv_init / 10.0
|
| 235 |
-
self.curv_max = curv_init * 10.0
|
| 236 |
-
self.objective = None
|
| 237 |
-
if objective != "proclip":
|
| 238 |
-
self.objective = build_objective(
|
| 239 |
-
objective=objective,
|
| 240 |
-
entail_weight=entail_weight,
|
| 241 |
-
inter_aperture_scale=inter_aperture_scale,
|
| 242 |
-
intra_aperture_scale=intra_aperture_scale,
|
| 243 |
-
uncha_piecewise_factor=uncha_piecewise_factor,
|
| 244 |
-
uncha_calibration_alpha=uncha_calibration_alpha,
|
| 245 |
-
uncha_stop_grad_calibration=uncha_stop_grad_calibration,
|
| 246 |
-
uncha_entailment_geometry=uncha_entailment_geometry,
|
| 247 |
-
uncha_aggregate_weight=uncha_aggregate_weight,
|
| 248 |
-
uncha_entailment_loss=uncha_entailment_loss,
|
| 249 |
-
uncha_argent_beta=uncha_argent_beta,
|
| 250 |
-
uncha_argent_norm_weight=uncha_argent_norm_weight,
|
| 251 |
-
uncha_argent_aux_weight=uncha_argent_aux_weight,
|
| 252 |
-
uncha_argent_aggregation=uncha_argent_aggregation,
|
| 253 |
-
uncha_part_weight_power=uncha_part_weight_power,
|
| 254 |
-
uncha_contrastive_loss=uncha_contrastive_loss,
|
| 255 |
-
uncha_sigmoid_negative_weight=uncha_sigmoid_negative_weight,
|
| 256 |
-
uncha_part_quality_mode=uncha_part_quality_mode,
|
| 257 |
-
uncha_part_quality_topk=uncha_part_quality_topk,
|
| 258 |
-
uncha_part_quality_temperature=uncha_part_quality_temperature,
|
| 259 |
-
uncha_contrastive_global_weight=uncha_contrastive_global_weight,
|
| 260 |
-
uncha_contrastive_local_weight=uncha_contrastive_local_weight,
|
| 261 |
-
uncha_contrastive_global_local_weight=uncha_contrastive_global_local_weight,
|
| 262 |
-
uncha_global_local_mode=uncha_global_local_mode,
|
| 263 |
-
uncha_global_local_metric=uncha_global_local_metric,
|
| 264 |
-
uncha_global_local_angle_aux_weight=uncha_global_local_angle_aux_weight,
|
| 265 |
-
uncha_global_local_angle_aux_mode=uncha_global_local_angle_aux_mode,
|
| 266 |
-
uncha_global_local_angle_aux_scale=uncha_global_local_angle_aux_scale,
|
| 267 |
-
uncha_global_local_angle_aux_aperture_scale=uncha_global_local_angle_aux_aperture_scale,
|
| 268 |
-
uncha_beta_cal_beta=uncha_beta_cal_beta,
|
| 269 |
-
uncha_beta_cal_variant=uncha_beta_cal_variant,
|
| 270 |
-
uncha_beta_cal_weight=uncha_beta_cal_weight,
|
| 271 |
-
uncha_himo_component_weight=uncha_himo_component_weight,
|
| 272 |
-
uncha_radius_order_weight=uncha_radius_order_weight,
|
| 273 |
-
uncha_radius_order_margin=uncha_radius_order_margin,
|
| 274 |
-
uncha_gramian_align_weight=uncha_gramian_align_weight,
|
| 275 |
-
product_metric=phyclip_product_metric,
|
| 276 |
-
)
|
| 277 |
-
|
| 278 |
-
def train(self, mode: bool = True) -> "Hyper3CLIP":
    """Switch train/eval mode while keeping frozen encoders in eval mode.

    Delegates to the parent ``train(mode)`` and afterwards calls ``eval()``
    on the vision and/or text encoder whose ``freeze_*`` flag is set, so a
    frozen tower ends up in eval mode whatever ``mode`` is.

    Args:
        mode: Forwarded unchanged to the parent ``train``.

    Returns:
        ``self``.
    """
    super().train(mode)
    frozen_towers = (
        (self.freeze_vision_encoder, self.vision_encoder),
        (self.freeze_text_encoder, self.text_encoder),
    )
    for is_frozen, tower in frozen_towers:
        if is_frozen:
            tower.eval()
    return self
|
| 285 |
-
|
| 286 |
-
@property
def phyclip_enabled(self) -> bool:
    """True when ``phyclip_subspace_dim`` has been set (is not ``None``)."""
    subspace_dim = self.phyclip_subspace_dim
    return subspace_dim is not None
|
| 289 |
-
|
| 290 |
-
def _kappa(self) -> torch.Tensor:
|
| 291 |
-
return self.log_curv.exp().clamp(min=self.curv_min, max=self.curv_max)
|
| 292 |
-
|
| 293 |
-
def encode_image(self, image: torch.Tensor, project: bool = True) -> torch.Tensor:
    """Encode images through the base encoder and the image projection head.

    Args:
        image: Input passed to ``encode_image_base``.
        project: When truthy, the head output is additionally passed through
            ``project_image_features``; otherwise the head output is returned
            as is.

    Returns:
        The (optionally projected) image features.
    """
    head_output = self.image_proj(self.encode_image_base(image))
    return self.project_image_features(head_output) if project else head_output
|
| 298 |
-
|
| 299 |
-
def encode_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, project: bool = True) -> torch.Tensor:
    """Encode text through the base encoder and the text projection head.

    Args:
        input_ids: Token ids passed to ``encode_text_base``.
        attention_mask: Attention mask passed to ``encode_text_base``.
        project: When truthy, the head output is additionally passed through
            ``project_text_features``; otherwise the head output is returned
            as is.

    Returns:
        The (optionally projected) text features.
    """
    head_output = self.text_proj(self.encode_text_base(input_ids, attention_mask))
    return self.project_text_features(head_output) if project else head_output
|
| 304 |
-
|
| 305 |
-
def encode_image_base(self, image: torch.Tensor) -> torch.Tensor:
    """Run the vision encoder and return its (optionally normalised) features.

    Gradients are only recorded while the module is in training mode and the
    vision encoder is not frozen. With a frozen encoder the features are
    detached. When ``normalize_encoder_features`` is set the features are cast
    to float32 and L2-normalised over the last dimension; otherwise they are
    returned without cast or normalisation.
    """
    frozen = self.freeze_vision_encoder
    with torch.set_grad_enabled(self.training and not frozen):
        feats = self.vision_encoder(image)
    if frozen:
        feats = feats.detach()
    if not self.normalize_encoder_features:
        return feats
    return F.normalize(feats.float(), dim=-1)
|
| 310 |
-
|
| 311 |
-
def encode_image_base_with_tokens(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the vision encoder and return ``(features, tokens)``.

    Uses ``vision_encoder.forward_with_tokens``. Gradients are only recorded
    in training mode with an unfrozen vision encoder; with a frozen encoder
    both outputs are detached. Only the features are cast to float32 and
    L2-normalised (when ``normalize_encoder_features`` is set); the tokens
    are returned as the encoder produced them.
    """
    frozen = self.freeze_vision_encoder
    with torch.set_grad_enabled(self.training and not frozen):
        feats, tokens = self.vision_encoder.forward_with_tokens(image)
    if frozen:
        feats, tokens = feats.detach(), tokens.detach()
    if self.normalize_encoder_features:
        feats = F.normalize(feats.float(), dim=-1)
    return feats, tokens
|
| 320 |
-
|
| 321 |
-
def encode_text_base(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Run the text encoder and return its (optionally normalised) features.

    The encoder is called with ``input_ids`` and ``attention_mask`` as keyword
    arguments. Gradients are only recorded in training mode with an unfrozen
    text encoder; with a frozen encoder the features are detached. When
    ``normalize_encoder_features`` is set the features are cast to float32 and
    L2-normalised over the last dimension.
    """
    frozen = self.freeze_text_encoder
    with torch.set_grad_enabled(self.training and not frozen):
        feats = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
    if frozen:
        feats = feats.detach()
    if not self.normalize_encoder_features:
        return feats
    return F.normalize(feats.float(), dim=-1)
|
| 326 |
-
|
| 327 |
-
def project_image_features(self, feats: torch.Tensor) -> torch.Tensor:
    """Project image head outputs using the visual scale and the curvature.

    With PHyCLIP enabled the work is delegated to
    ``_project_product_features`` together with ``visual_alpha``. Otherwise
    the features are cast to float32, multiplied by ``exp(visual_alpha)`` and
    passed to ``exp_map0`` with the float32 curvature from ``_kappa()``.
    """
    if not self.phyclip_enabled:
        scale = self.visual_alpha.exp().float()
        curvature = self._kappa().float()
        return exp_map0(feats.float() * scale, curvature)
    return self._project_product_features(feats, self.visual_alpha)
|
| 331 |
-
|
| 332 |
-
def project_text_features(self, feats: torch.Tensor) -> torch.Tensor:
    """Project text head outputs using the textual scale and the curvature.

    With PHyCLIP enabled the work is delegated to
    ``_project_product_features`` together with ``textual_alpha``. Otherwise
    the features are cast to float32, multiplied by ``exp(textual_alpha)`` and
    passed to ``exp_map0`` with the float32 curvature from ``_kappa()``.
    """
    if not self.phyclip_enabled:
        scale = self.textual_alpha.exp().float()
        curvature = self._kappa().float()
        return exp_map0(feats.float() * scale, curvature)
    return self._project_product_features(feats, self.textual_alpha)
|
| 336 |
-
|
| 337 |
-
def similarity_scores(self, image_feats: torch.Tensor, text_feats: torch.Tensor) -> torch.Tensor:
    """Score image features against text features with ``metric_similarity``.

    The current curvature from ``_kappa()`` and the configured
    ``phyclip_product_metric`` are forwarded to ``metric_similarity``.
    """
    curvature = self._kappa()
    return metric_similarity(
        image_feats,
        text_feats,
        curvature,
        product_metric=self.phyclip_product_metric,
    )
|
| 339 |
-
|
| 340 |
-
def encode_retrieval_image(self, image: torch.Tensor) -> torch.Tensor:
    """Encode images into the embedding used for retrieval.

    The base features go through the image head and
    ``project_image_features``. When ``proclip_retrieval`` is set, that result
    and the base features are further combined by
    ``_project_proclip_image_base``.
    """
    base = self.encode_image_base(image)
    embedded = self.project_image_features(self.image_proj(base))
    if not self.proclip_retrieval:
        return embedded
    return self._project_proclip_image_base(base, embedded)
|
| 346 |
-
|
| 347 |
-
def encode_retrieval_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Encode text into the embedding used for retrieval.

    The base features go through the text head and
    ``project_text_features``. When ``proclip_retrieval`` is set, that result
    and the base features are further combined by
    ``_project_proclip_text_base``.
    """
    base = self.encode_text_base(input_ids, attention_mask)
    embedded = self.project_text_features(self.text_proj(base))
    if not self.proclip_retrieval:
        return embedded
    return self._project_proclip_text_base(base, embedded)
|
| 353 |
-
|
| 354 |
-
def retrieval_similarity_scores(self, image_feats: torch.Tensor, text_feats: torch.Tensor) -> torch.Tensor:
    """Score retrieval embeddings with the scorer matching the retrieval mode.

    Uses ``_proclip_similarity_scores`` when ``proclip_retrieval`` is set and
    ``similarity_scores`` otherwise.
    """
    scorer = self._proclip_similarity_scores if self.proclip_retrieval else self.similarity_scores
    return scorer(image_feats, text_feats)
|
| 358 |
-
|
| 359 |
-
@property
def retrieval_requires_chunking(self) -> bool:
    """True when PHyCLIP is enabled or ProCLIP retrieval is active."""
    phyclip = self.phyclip_enabled
    return phyclip if phyclip else self.proclip_retrieval
|
| 362 |
-
|
| 363 |
-
def _objective_autocast(self, device_type: str):
|
| 364 |
-
dtype = {
|
| 365 |
-
"float32": torch.float32,
|
| 366 |
-
"fp32": torch.float32,
|
| 367 |
-
"float16": torch.float16,
|
| 368 |
-
"fp16": torch.float16,
|
| 369 |
-
"bfloat16": torch.bfloat16,
|
| 370 |
-
"bf16": torch.bfloat16,
|
| 371 |
-
}[self.objective_autocast_dtype]
|
| 372 |
-
enabled = device_type != "cpu" and dtype is not torch.float32
|
| 373 |
-
return torch.autocast(device_type=device_type, dtype=dtype, enabled=enabled)
|
| 374 |
-
|
| 375 |
-
def forward(
    self,
    image: torch.Tensor,
    part_images: torch.Tensor,
    text_input_ids: torch.Tensor,
    text_attention_mask: torch.Tensor,
    part_text_input_ids: torch.Tensor,
    part_text_attention_mask: torch.Tensor,
    part_owner: torch.Tensor,
    step: int | None = None,
    beta_query_input_ids: torch.Tensor | None = None,
    beta_query_attention_mask: torch.Tensor | None = None,
    beta_query_owner: torch.Tensor | None = None,
    beta_query_type: torch.Tensor | None = None,
    beta_query_parent: torch.Tensor | None = None,
    beta_query_weight: torch.Tensor | None = None,
    beta_query_source_part: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
    """Encode a batch of wholes and parts and compute the training losses.

    Steps, in order:
      1. Clamp the logit scales and both alpha parameters in place (no grad).
      2. Encode whole images/texts and part images/texts through one of three
         paths: a fused whole+part+query pass, a fused whole+part pass, or
         separate passes (with image tokens when beta-query pooling or T-REN
         is enabled).
      3. For the ``"proclip"`` objective, return early with only the ProCLIP
         contrastive loss.
      4. Otherwise evaluate ``self.objective`` and add the enabled auxiliary
         losses (beta-CLIP global, beta-CLIP, T-REN, ProCLIP) to ``"loss"``.

    Args:
        image, text_input_ids, text_attention_mask: Whole-sample inputs.
        part_images, part_text_input_ids, part_text_attention_mask: Part
            inputs; an empty first dimension means "no parts".
        part_owner: Moved to the feature device as ``torch.long`` before it is
            handed to the objective.  # presumably maps each part to its
            owning sample — confirm against the collator
        step: Forwarded to ``_entail_weight_scale``.
        beta_query_*: Optional query tensors forwarded to the beta-query
            helpers. ``beta_query_type`` is accepted but not read here.

    Returns:
        The loss dictionary, merged with the detached curvature and
        logit-scale logs.

    Raises:
        RuntimeError: When image patch tokens are required but unavailable, or
            when a non-ProCLIP objective has no objective module.
    """
    hier_beta_losses = {"hier_beta_argent", "hier_beta_sourcepart_argent"}

    with torch.no_grad():
        self._clamp_logit_scales()
        self.visual_alpha.clamp_(max=0.0)
        self.textual_alpha.clamp_(max=0.0)
    # NOTE(review): assumed to sit outside the no_grad block so the curvature
    # keeps receiving gradients through the objective — confirm.
    kappa = self._kappa()

    feature_dim = self.embed_dim
    beta_image_tokens = None
    beta_query_base = None
    part_image_base = part_images.new_zeros((0, self.vision_encoder.output_dim))
    part_text_base = part_images.new_zeros((0, self.text_encoder.output_dim))
    hier_beta_enabled = self.objective_name == "uncha" and self.uncha_entailment_loss in hier_beta_losses

    fuse_queries = (
        hier_beta_enabled
        and self.fuse_beta_query_encoder_forwards
        and not self.tren_enabled
        and beta_query_input_ids is not None
        and beta_query_attention_mask is not None
        and part_images.shape[0] > 0
    )
    if fuse_queries:
        # Single encoder pass over wholes, parts and beta queries.
        (
            image_base,
            text_base,
            image_euc,
            text_euc,
            image_feats,
            text_feats,
            part_image_feats,
            part_text_feats,
            part_image_euc,
            part_text_euc,
            part_image_base,
            part_text_base,
            beta_image_tokens,
            beta_query_base,
        ) = self._encode_hier_beta_whole_parts_and_queries(
            image=image,
            part_images=part_images,
            text_input_ids=text_input_ids,
            text_attention_mask=text_attention_mask,
            part_text_input_ids=part_text_input_ids,
            part_text_attention_mask=part_text_attention_mask,
            beta_query_input_ids=beta_query_input_ids,
            beta_query_attention_mask=beta_query_attention_mask,
        )
    else:
        wants_tokens = self.beta_query_pooling_enabled or self.tren_enabled
        fuse_whole_parts = (
            not wants_tokens
            and self.fuse_whole_part_encoder_forwards
            and self.objective_name != "proclip"
            and part_images.shape[0] > 0
        )
        if fuse_whole_parts:
            # Single encoder pass over wholes and parts.
            (
                image_base,
                text_base,
                image_euc,
                text_euc,
                image_feats,
                text_feats,
                part_image_feats,
                part_text_feats,
                part_image_euc,
                part_text_euc,
                part_image_base,
                part_text_base,
            ) = self._encode_whole_and_parts(
                image=image,
                part_images=part_images,
                text_input_ids=text_input_ids,
                text_attention_mask=text_attention_mask,
                part_text_input_ids=part_text_input_ids,
                part_text_attention_mask=part_text_attention_mask,
            )
        else:
            # Separate passes; patch tokens are only kept when something needs them.
            if wants_tokens:
                image_base, beta_image_tokens = self.encode_image_base_with_tokens(image)
            else:
                image_base = self.encode_image_base(image)
            text_base = self.encode_text_base(text_input_ids, text_attention_mask)
            image_euc = self.image_proj(image_base)
            text_euc = self.text_proj(text_base)
            image_feats = self.project_image_features(image_euc)
            text_feats = self.project_text_features(text_euc)
            (
                part_image_feats,
                part_text_feats,
                part_image_euc,
                part_text_euc,
                part_image_base,
                part_text_base,
            ) = self._encode_parts_with_base(
                part_images=part_images,
                part_text_input_ids=part_text_input_ids,
                part_text_attention_mask=part_text_attention_mask,
                feature_dim=feature_dim,
            )

    targets = local_target_indices(image_feats.size(0), image_feats.device)

    if self.objective_name == "proclip":
        # ProCLIP-only objective: contrastive loss, no entailment, no parts.
        proclip_image_feats = self._project_proclip_image_base(image_base, image_feats)
        proclip_text_feats = self._project_proclip_text_base(text_base, text_feats)
        proclip_loss = self._proclip_contrastive_loss(
            image_feats=proclip_image_feats,
            text_feats=proclip_text_feats,
            all_image_feats=gather_with_grad(proclip_image_feats),
            all_text_feats=gather_with_grad(proclip_text_feats),
            targets=targets,
        )
        proclip_result = {
            "loss": proclip_loss,
            "contrastive_loss": proclip_loss,
            "entailment_loss": proclip_loss.new_zeros(()),
            "part_count": part_owner.new_tensor(0),
            "proclip_contrastive_loss": proclip_loss,
        }
        proclip_result.update(self._detached_kappa_logs(kappa))
        proclip_result.update(self._detached_logit_scales())
        return proclip_result

    himo_text_feats = None
    all_himo_text_feats = None
    if self.objective_name == "uncha" and self.uncha_himo_component_weight > 0.0:
        all_text_euc = gather_with_grad(text_euc)
        all_component_euc = hide_reconstruct_embeddings(
            all_text_euc,
            variance_threshold=self.uncha_himo_variance_threshold,
            detach_pca=self.uncha_himo_detach_pca,
        )
        if get_world_size() > 1:
            # Keep only this rank's rows of the gathered reconstruction.
            local_rows = text_euc.size(0)
            start = local_rows * get_rank()
            component_euc = all_component_euc[start : start + local_rows]
        else:
            component_euc = all_component_euc
        himo_text_feats = self.project_text_features(component_euc)
        all_himo_text_feats = gather_with_grad(himo_text_feats)

    all_image_feats = gather_with_grad(image_feats)
    all_text_feats = gather_with_grad(text_feats)
    all_image_euc = None
    all_text_euc = None
    if self.objective_name == "uncha" and self.uncha_contrastive_loss == "siglip":
        all_image_euc = gather_with_grad(image_euc)
        all_text_euc = gather_with_grad(text_euc)
    part_owner = part_owner.to(device=image_feats.device, dtype=torch.long)

    beta_query_embeddings = {}
    if hier_beta_enabled:
        if beta_image_tokens is None:
            raise RuntimeError(f"{self.uncha_entailment_loss} requires image patch tokens")
        with torch.autocast(device_type=image.device.type, enabled=False):
            beta_query_embeddings = self._beta_query_entailment_embeddings(
                image_tokens=beta_image_tokens.float(),
                beta_query_input_ids=beta_query_input_ids,
                beta_query_attention_mask=beta_query_attention_mask,
                beta_query_owner=beta_query_owner,
                beta_query_parent=beta_query_parent,
                beta_query_weight=beta_query_weight,
                beta_query_source_part=beta_query_source_part,
                kappa=kappa.float(),
                query_base=beta_query_base,
            )

    with self._objective_autocast(image.device.type):
        if self.objective is None:
            raise RuntimeError("Non-ProCLIP forward requires an objective module")
        objective_inputs = {
            "image_feats": image_feats,
            "text_feats": text_feats,
            "part_image_feats": part_image_feats,
            "part_text_feats": part_text_feats,
            "part_owner": part_owner,
            "all_image_feats": all_image_feats,
            "all_text_feats": all_text_feats,
        }
        if all_image_euc is not None and all_text_euc is not None:
            objective_inputs.update(
                image_euc_feats=image_euc,
                text_euc_feats=text_euc,
                part_image_euc_feats=part_image_euc,
                part_text_euc_feats=part_text_euc,
                all_image_euc_feats=all_image_euc,
                all_text_euc_feats=all_text_euc,
            )
        objective_inputs["targets"] = targets
        objective_inputs["kappa"] = kappa
        objective_inputs["entail_weight_scale"] = self._entail_weight_scale(step, image_feats.device)
        objective_inputs.update(beta_query_embeddings)
        if himo_text_feats is not None:
            objective_inputs["himo_text_feats"] = himo_text_feats
            objective_inputs["all_himo_text_feats"] = all_himo_text_feats
        losses = self.objective(objective_inputs, self._objective_logit_scales())

    if self.beta_clip_global_weight > 0.0:
        with torch.autocast(device_type=image.device.type, enabled=False):
            beta_clip_global_loss = self._beta_clip_global_contrastive_loss(
                image_euc=image_euc,
                text_euc=text_euc,
                targets=targets,
            )
        total = losses["loss"] + self.beta_clip_global_weight * beta_clip_global_loss
        losses = dict(losses)
        losses["loss"] = total
        losses["beta_clip_global_loss"] = beta_clip_global_loss

    if self.beta_clip_enabled:
        if beta_image_tokens is None:
            raise RuntimeError("beta-CLIP auxiliary requires image patch tokens")
        with torch.autocast(device_type=image.device.type, enabled=False):
            beta_clip_loss = self._beta_clip_auxiliary_loss(
                image_tokens=beta_image_tokens.float(),
                beta_query_input_ids=beta_query_input_ids,
                beta_query_attention_mask=beta_query_attention_mask,
                beta_query_owner=beta_query_owner,
                global_targets=targets,
                kappa=kappa.float(),
            )
        total = losses["loss"] + self.beta_clip_weight * beta_clip_loss
        losses = dict(losses)
        losses["loss"] = total
        losses["beta_clip_loss"] = beta_clip_loss

    if self.tren_enabled:
        if beta_image_tokens is None:
            raise RuntimeError("T-REN auxiliary requires image patch tokens")
        with torch.autocast(device_type=image.device.type, enabled=False):
            tren_losses = self._tren_auxiliary_losses(
                image_tokens=beta_image_tokens.float(),
                part_owner=part_owner,
                part_image_base=part_image_base.float(),
                part_text_base=part_text_base.float(),
            )
        total = losses["loss"] + self.tren_weight * tren_losses["tren_loss"]
        losses = dict(losses)
        losses["loss"] = total
        losses.update(tren_losses)

    if self.proclip_enabled and self.proclip_weight > 0.0:
        proclip_image_feats = self._project_proclip_image_base(image_base, image_feats)
        proclip_text_feats = self._project_proclip_text_base(text_base, text_feats)
        proclip_loss = self._proclip_contrastive_loss(
            image_feats=proclip_image_feats,
            text_feats=proclip_text_feats,
            all_image_feats=gather_with_grad(proclip_image_feats),
            all_text_feats=gather_with_grad(proclip_text_feats),
            targets=targets,
        )
        total = losses["loss"] + self.proclip_weight * proclip_loss
        losses = dict(losses)
        losses["loss"] = total
        losses["proclip_contrastive_loss"] = proclip_loss

    result = dict(losses)
    result.update(self._detached_kappa_logs(kappa))
    result.update(self._detached_logit_scales())
    return result
|
| 675 |
-
|
| 676 |
-
def _encode_parts(
|
| 677 |
-
self,
|
| 678 |
-
part_images: torch.Tensor,
|
| 679 |
-
part_text_input_ids: torch.Tensor,
|
| 680 |
-
part_text_attention_mask: torch.Tensor,
|
| 681 |
-
feature_dim: int,
|
| 682 |
-
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 683 |
-
if part_images.shape[0] == 0:
|
| 684 |
-
empty = part_images.new_zeros((0, feature_dim))
|
| 685 |
-
return empty, empty, empty, empty
|
| 686 |
-
|
| 687 |
-
part_image_euc = self.image_proj(self.encode_image_base(part_images))
|
| 688 |
-
part_text_euc = self.text_proj(self.encode_text_base(part_text_input_ids, part_text_attention_mask))
|
| 689 |
-
part_image_feats = self.project_image_features(part_image_euc)
|
| 690 |
-
part_text_feats = self.project_text_features(part_text_euc)
|
| 691 |
-
return part_image_feats, part_text_feats, part_image_euc, part_text_euc
|
| 692 |
-
|
| 693 |
-
def _encode_parts_with_base(
|
| 694 |
-
self,
|
| 695 |
-
part_images: torch.Tensor,
|
| 696 |
-
part_text_input_ids: torch.Tensor,
|
| 697 |
-
part_text_attention_mask: torch.Tensor,
|
| 698 |
-
feature_dim: int,
|
| 699 |
-
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 700 |
-
if part_images.shape[0] == 0:
|
| 701 |
-
empty = part_images.new_zeros((0, feature_dim))
|
| 702 |
-
empty_image_base = part_images.new_zeros((0, self.vision_encoder.output_dim))
|
| 703 |
-
empty_text_base = part_images.new_zeros((0, self.text_encoder.output_dim))
|
| 704 |
-
return empty, empty, empty, empty, empty_image_base, empty_text_base
|
| 705 |
-
|
| 706 |
-
part_image_base = self.encode_image_base(part_images)
|
| 707 |
-
part_text_base = self.encode_text_base(part_text_input_ids, part_text_attention_mask)
|
| 708 |
-
part_image_euc = self.image_proj(part_image_base)
|
| 709 |
-
part_text_euc = self.text_proj(part_text_base)
|
| 710 |
-
part_image_feats = self.project_image_features(part_image_euc)
|
| 711 |
-
part_text_feats = self.project_text_features(part_text_euc)
|
| 712 |
-
return part_image_feats, part_text_feats, part_image_euc, part_text_euc, part_image_base, part_text_base
|
| 713 |
-
|
| 714 |
-
def _encode_whole_and_parts(
|
| 715 |
-
self,
|
| 716 |
-
image: torch.Tensor,
|
| 717 |
-
part_images: torch.Tensor,
|
| 718 |
-
text_input_ids: torch.Tensor,
|
| 719 |
-
text_attention_mask: torch.Tensor,
|
| 720 |
-
part_text_input_ids: torch.Tensor,
|
| 721 |
-
part_text_attention_mask: torch.Tensor,
|
| 722 |
-
) -> tuple[
|
| 723 |
-
torch.Tensor,
|
| 724 |
-
torch.Tensor,
|
| 725 |
-
torch.Tensor,
|
| 726 |
-
torch.Tensor,
|
| 727 |
-
torch.Tensor,
|
| 728 |
-
torch.Tensor,
|
| 729 |
-
torch.Tensor,
|
| 730 |
-
torch.Tensor,
|
| 731 |
-
torch.Tensor,
|
| 732 |
-
torch.Tensor,
|
| 733 |
-
torch.Tensor,
|
| 734 |
-
torch.Tensor,
|
| 735 |
-
]:
|
| 736 |
-
batch_size = image.shape[0]
|
| 737 |
-
part_count = part_images.shape[0]
|
| 738 |
-
image_base_all = self.encode_image_base(torch.cat([image, part_images], dim=0))
|
| 739 |
-
image_euc_all = self.image_proj(image_base_all)
|
| 740 |
-
image_feats_all = self.project_image_features(image_euc_all)
|
| 741 |
-
|
| 742 |
-
text_ids, text_mask = self._concat_text_batches(
|
| 743 |
-
text_input_ids,
|
| 744 |
-
text_attention_mask,
|
| 745 |
-
part_text_input_ids,
|
| 746 |
-
part_text_attention_mask,
|
| 747 |
-
)
|
| 748 |
-
text_base_all = self.encode_text_base(text_ids, text_mask)
|
| 749 |
-
text_euc_all = self.text_proj(text_base_all)
|
| 750 |
-
text_feats_all = self.project_text_features(text_euc_all)
|
| 751 |
-
|
| 752 |
-
image_base, part_image_base = image_base_all.split([batch_size, part_count], dim=0)
|
| 753 |
-
text_base, part_text_base = text_base_all.split([batch_size, part_count], dim=0)
|
| 754 |
-
image_euc, part_image_euc = image_euc_all.split([batch_size, part_count], dim=0)
|
| 755 |
-
text_euc, part_text_euc = text_euc_all.split([batch_size, part_count], dim=0)
|
| 756 |
-
image_feats, part_image_feats = image_feats_all.split([batch_size, part_count], dim=0)
|
| 757 |
-
text_feats, part_text_feats = text_feats_all.split([batch_size, part_count], dim=0)
|
| 758 |
-
return (
|
| 759 |
-
image_base,
|
| 760 |
-
text_base,
|
| 761 |
-
image_euc,
|
| 762 |
-
text_euc,
|
| 763 |
-
image_feats,
|
| 764 |
-
text_feats,
|
| 765 |
-
part_image_feats,
|
| 766 |
-
part_text_feats,
|
| 767 |
-
part_image_euc,
|
| 768 |
-
part_text_euc,
|
| 769 |
-
part_image_base,
|
| 770 |
-
part_text_base,
|
| 771 |
-
)
|
| 772 |
-
|
| 773 |
-
def _encode_hier_beta_whole_parts_and_queries(
|
| 774 |
-
self,
|
| 775 |
-
image: torch.Tensor,
|
| 776 |
-
part_images: torch.Tensor,
|
| 777 |
-
text_input_ids: torch.Tensor,
|
| 778 |
-
text_attention_mask: torch.Tensor,
|
| 779 |
-
part_text_input_ids: torch.Tensor,
|
| 780 |
-
part_text_attention_mask: torch.Tensor,
|
| 781 |
-
beta_query_input_ids: torch.Tensor,
|
| 782 |
-
beta_query_attention_mask: torch.Tensor,
|
| 783 |
-
) -> tuple[
|
| 784 |
-
torch.Tensor,
|
| 785 |
-
torch.Tensor,
|
| 786 |
-
torch.Tensor,
|
| 787 |
-
torch.Tensor,
|
| 788 |
-
torch.Tensor,
|
| 789 |
-
torch.Tensor,
|
| 790 |
-
torch.Tensor,
|
| 791 |
-
torch.Tensor,
|
| 792 |
-
torch.Tensor,
|
| 793 |
-
torch.Tensor,
|
| 794 |
-
torch.Tensor,
|
| 795 |
-
torch.Tensor,
|
| 796 |
-
torch.Tensor,
|
| 797 |
-
torch.Tensor,
|
| 798 |
-
]:
|
| 799 |
-
batch_size = image.shape[0]
|
| 800 |
-
part_count = part_images.shape[0]
|
| 801 |
-
query_count = beta_query_input_ids.shape[0]
|
| 802 |
-
|
| 803 |
-
image_base_all, image_tokens_all = self.encode_image_base_with_tokens(torch.cat([image, part_images], dim=0))
|
| 804 |
-
image_euc_all = self.image_proj(image_base_all)
|
| 805 |
-
image_feats_all = self.project_image_features(image_euc_all)
|
| 806 |
-
image_base, part_image_base = image_base_all.split([batch_size, part_count], dim=0)
|
| 807 |
-
image_euc, part_image_euc = image_euc_all.split([batch_size, part_count], dim=0)
|
| 808 |
-
image_feats, part_image_feats = image_feats_all.split([batch_size, part_count], dim=0)
|
| 809 |
-
beta_image_tokens = image_tokens_all[:batch_size]
|
| 810 |
-
|
| 811 |
-
text_ids, text_mask = self._concat_text_batch_list(
|
| 812 |
-
(text_input_ids, text_attention_mask),
|
| 813 |
-
(part_text_input_ids, part_text_attention_mask),
|
| 814 |
-
(beta_query_input_ids, beta_query_attention_mask),
|
| 815 |
-
)
|
| 816 |
-
text_base_all = self.encode_text_base(text_ids, text_mask)
|
| 817 |
-
text_euc_all = self.text_proj(text_base_all)
|
| 818 |
-
text_feats_all = self.project_text_features(text_euc_all)
|
| 819 |
-
text_base, part_text_base, beta_query_base = text_base_all.split([batch_size, part_count, query_count], dim=0)
|
| 820 |
-
text_euc, part_text_euc, _ = text_euc_all.split([batch_size, part_count, query_count], dim=0)
|
| 821 |
-
text_feats, part_text_feats, _ = text_feats_all.split([batch_size, part_count, query_count], dim=0)
|
| 822 |
-
|
| 823 |
-
return (
|
| 824 |
-
image_base,
|
| 825 |
-
text_base,
|
| 826 |
-
image_euc,
|
| 827 |
-
text_euc,
|
| 828 |
-
image_feats,
|
| 829 |
-
text_feats,
|
| 830 |
-
part_image_feats,
|
| 831 |
-
part_text_feats,
|
| 832 |
-
part_image_euc,
|
| 833 |
-
part_text_euc,
|
| 834 |
-
part_image_base,
|
| 835 |
-
part_text_base,
|
| 836 |
-
beta_image_tokens,
|
| 837 |
-
beta_query_base,
|
| 838 |
-
)
|
| 839 |
-
|
| 840 |
-
def _concat_text_batches(
|
| 841 |
-
self,
|
| 842 |
-
text_input_ids: torch.Tensor,
|
| 843 |
-
text_attention_mask: torch.Tensor,
|
| 844 |
-
part_text_input_ids: torch.Tensor,
|
| 845 |
-
part_text_attention_mask: torch.Tensor,
|
| 846 |
-
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 847 |
-
return self._concat_text_batch_list(
|
| 848 |
-
(text_input_ids, text_attention_mask),
|
| 849 |
-
(part_text_input_ids, part_text_attention_mask),
|
| 850 |
-
)
|
| 851 |
-
|
| 852 |
-
def _concat_text_batch_list(
    self,
    *batches: tuple[torch.Tensor, torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pad several (input_ids, attention_mask) batches to one width and stack them.

    Every batch is right-padded along dim 1 to the widest ``input_ids`` among
    the inputs, then the batches are concatenated along dim 0 in the order
    given. Token ids are padded with the tokenizer's ``pad_token_id`` (0 when
    the tokenizer reports ``None``); attention masks are padded with 0.

    Returns:
        ``(input_ids, attention_mask)`` for the concatenated batch.
    """
    widest = max(ids.shape[1] for ids, _ in batches)

    pad_id = self.text_encoder.tokenizer.pad_token_id
    fill_id = 0 if pad_id is None else pad_id

    padded_ids = [_pad_sequence_dim(ids, widest, fill_id) for ids, _ in batches]
    padded_masks = [_pad_sequence_dim(mask, widest, 0) for _, mask in batches]
    return torch.cat(padded_ids, dim=0), torch.cat(padded_masks, dim=0)
|
| 864 |
-
|
| 865 |
-
def _clamp_logit_scales(self) -> None:
|
| 866 |
-
if self.objective_name == "proclip":
|
| 867 |
-
self.proclip_logit_scale.clamp_(max=4.6052)
|
| 868 |
-
self._clamp_experimental_logit_scales()
|
| 869 |
-
return
|
| 870 |
-
if self.objective_name == "hycoclip":
|
| 871 |
-
self.logit_scale.clamp_(max=4.6052)
|
| 872 |
-
self._clamp_experimental_logit_scales()
|
| 873 |
-
return
|
| 874 |
-
self.global_logit_scale.clamp_(max=4.6052)
|
| 875 |
-
self.local_logit_scale.clamp_(max=4.6052)
|
| 876 |
-
self.global_local_logit_scale.clamp_(max=4.6052)
|
| 877 |
-
self._clamp_experimental_logit_scales()
|
| 878 |
-
|
| 879 |
-
def _objective_logit_scales(self) -> torch.Tensor | dict[str, torch.Tensor]:
|
| 880 |
-
if self.objective_name == "hycoclip":
|
| 881 |
-
return self.logit_scale
|
| 882 |
-
if self.objective_name == "proclip":
|
| 883 |
-
return self.proclip_logit_scale
|
| 884 |
-
return {
|
| 885 |
-
"global": self.global_logit_scale,
|
| 886 |
-
"local": self.local_logit_scale,
|
| 887 |
-
"global_local": self.global_local_logit_scale,
|
| 888 |
-
**(
|
| 889 |
-
{
|
| 890 |
-
"global_bias": self.global_logit_bias,
|
| 891 |
-
"local_bias": self.local_logit_bias,
|
| 892 |
-
"global_local_bias": self.global_local_logit_bias,
|
| 893 |
-
}
|
| 894 |
-
if self.uncha_contrastive_loss in {"sigmoid", "siglip", "siglip_metric"}
|
| 895 |
-
else {}
|
| 896 |
-
),
|
| 897 |
-
}
|
| 898 |
-
|
| 899 |
-
def _detached_logit_scales(self) -> dict[str, torch.Tensor]:
|
| 900 |
-
if self.objective_name == "proclip":
|
| 901 |
-
return self._detached_experimental_logit_scales()
|
| 902 |
-
if self.objective_name == "hycoclip":
|
| 903 |
-
logs = {"logit_scale": self.logit_scale.exp().detach()}
|
| 904 |
-
logs.update(self._detached_experimental_logit_scales())
|
| 905 |
-
return logs
|
| 906 |
-
logs = {
|
| 907 |
-
"global_logit_scale": self.global_logit_scale.exp().detach(),
|
| 908 |
-
"local_logit_scale": self.local_logit_scale.exp().detach(),
|
| 909 |
-
"global_local_logit_scale": self.global_local_logit_scale.exp().detach(),
|
| 910 |
-
}
|
| 911 |
-
if self.uncha_contrastive_loss in {"sigmoid", "siglip", "siglip_metric"}:
|
| 912 |
-
logs.update(
|
| 913 |
-
{
|
| 914 |
-
"global_logit_bias": self.global_logit_bias.detach(),
|
| 915 |
-
"local_logit_bias": self.local_logit_bias.detach(),
|
| 916 |
-
"global_local_logit_bias": self.global_local_logit_bias.detach(),
|
| 917 |
-
}
|
| 918 |
-
)
|
| 919 |
-
logs.update(self._detached_experimental_logit_scales())
|
| 920 |
-
return logs
|
| 921 |
-
|
| 922 |
-
def _project_product_features(self, feats: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    """Split features into product factors, rescale them, and apply ``exp_map0``.

    ``feats`` is cast to float32 and reshaped to
    ``[feats.size(0), phyclip_num_factors, phyclip_subspace_dim]``; each factor
    is multiplied by ``exp(alpha)`` for that factor, and the result is passed
    to ``exp_map0`` with the per-factor curvature from ``self._kappa()``.
    """
    batch = feats.size(0)
    factors = feats.float().reshape(batch, self.phyclip_num_factors, self.phyclip_subspace_dim)
    factor_scale = alpha.exp().float().view(1, -1, 1)
    scaled_factors = factors * factor_scale
    curvature = self._kappa().float().view(1, -1, 1)
    return exp_map0(scaled_factors, curvature)
|
| 926 |
-
|
| 927 |
-
def _detached_kappa_logs(self, kappa: torch.Tensor) -> dict[str, torch.Tensor]:
|
| 928 |
-
detached = kappa.detach()
|
| 929 |
-
if detached.numel() == 1:
|
| 930 |
-
return {"kappa": detached.reshape(())}
|
| 931 |
-
return {
|
| 932 |
-
"kappa": detached.mean(),
|
| 933 |
-
"kappa_min": detached.min(),
|
| 934 |
-
"kappa_max": detached.max(),
|
| 935 |
-
}
|
| 936 |
-
|
| 937 |
-
def _entail_weight_scale(self, step: int | None, device: torch.device) -> torch.Tensor:
|
| 938 |
-
if self.uncha_entailment_warmup_steps <= 0 or step is None:
|
| 939 |
-
return torch.ones((), device=device)
|
| 940 |
-
scale = min(1.0, float(step + 1) / float(self.uncha_entailment_warmup_steps))
|
| 941 |
-
return torch.tensor(scale, device=device)
|
| 942 |
-
|
| 943 |
-
|
| 944 |
-
def _projection_head(input_dim: int, output_dim: int, hidden_dim: int | None) -> nn.Module:
|
| 945 |
-
if hidden_dim is None:
|
| 946 |
-
return nn.Linear(input_dim, output_dim)
|
| 947 |
-
return nn.Sequential(
|
| 948 |
-
nn.Linear(input_dim, hidden_dim),
|
| 949 |
-
nn.ReLU(),
|
| 950 |
-
nn.Linear(hidden_dim, output_dim),
|
| 951 |
-
)
|
| 952 |
-
|
| 953 |
-
|
| 954 |
-
def _pad_sequence_dim(tensor: torch.Tensor, target_length: int, value: int) -> torch.Tensor:
|
| 955 |
-
pad = target_length - tensor.shape[1]
|
| 956 |
-
if pad <= 0:
|
| 957 |
-
return tensor
|
| 958 |
-
return F.pad(tensor, (0, pad), value=value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/models/lorentz.py
DELETED
|
@@ -1,265 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import math
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
from torch import Tensor
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
def lorentz_inner(x: Tensor, y: Tensor) -> Tensor:
    """Lorentzian inner product of corresponding vectors in ``x`` and ``y``.

    Index 0 of the last axis is the time coordinate and enters with a minus
    sign; the remaining entries are summed as an ordinary dot product. Inputs
    are cast to float32 first; the last axis is reduced away.
    """
    xf, yf = x.float(), y.float()
    time_term = xf[..., 0] * yf[..., 0]
    space_term = (xf[..., 1:] * yf[..., 1:]).sum(dim=-1)
    return space_term - time_term
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def pairwise_lorentz_inner(x: Tensor, y: Tensor) -> Tensor:
    """All-pairs Lorentzian inner products between rows of ``x`` and rows of ``y``.

    Both inputs are 2-D with the time coordinate in column 0. Entry ``[i, j]``
    of the result is the Lorentzian inner product of ``x[i]`` and ``y[j]``,
    computed in float32.
    """
    xf, yf = x.float(), y.float()
    x_time, x_space = xf[:, :1], xf[:, 1:]
    y_time, y_space = yf[:, :1], yf[:, 1:]
    return x_space @ y_space.T - x_time @ y_time.T
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def exp_map0(u: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
    """Exponential map at the origin: tangent vectors -> hyperboloid points.

    Args:
        u: Tangent vectors; the last axis is the feature axis.
        kappa: Curvature magnitude, broadcastable against the norm of ``u``
            taken over the last axis (keepdim).
        eps: Lower bound used to keep norms and divisors away from zero.

    Returns:
        Float32 tensor whose last axis is one longer than ``u``'s: the time
        coordinate followed by the space coordinates.
    """
    tangent = u.float()
    root_k = torch.sqrt(kappa.float())

    norm = torch.clamp(torch.linalg.norm(tangent, dim=-1, keepdim=True), min=eps)
    angle = root_k * norm
    # Bound the sinh/cosh argument so the outputs cannot overflow.
    bounded_angle = torch.clamp(angle, max=math.asinh(2**15))

    time_coord = torch.cosh(bounded_angle) / root_k
    space_coords = torch.sinh(bounded_angle) * tangent / torch.clamp(angle, min=eps)
    return torch.cat([time_coord, space_coords], dim=-1)
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
def log_map0(x: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
    """Logarithmic map at the origin: hyperboloid points -> tangent vectors.

    Inverse of ``exp_map0`` for points on the Lorentz-model hyperboloid. The
    result lives in the Euclidean tangent space at the origin, so the time
    coordinate is dropped.

    Args:
        x: ``[batch, dim + 1]`` points, or ``[batch, factors, dim + 1]`` for a
            product space.
        kappa: A single curvature for 2-D input; for 3-D input either a single
            curvature (shared by all factors) or one per factor.
        eps: Lower bound for numerical guards (raised to 16 float32 ulps when
            smaller).

    Raises:
        ValueError: on an unsupported rank or a curvature count that does not
            match the input layout.
    """
    points = x.float()
    guard = max(eps, 16.0 * torch.finfo(points.dtype).eps)
    curvature = kappa.to(dtype=torch.float32).flatten()

    rank = points.dim()
    if rank not in (2, 3):
        raise ValueError("log_map0 expects [batch, dim + 1] or [batch, factors, dim + 1] tensors")

    if rank == 2:
        if curvature.numel() != 1:
            raise ValueError("log_map0 expects scalar kappa for non-product embeddings")
        root_k = torch.sqrt(curvature.reshape(()))
    else:
        if curvature.numel() == 1:
            curvature = curvature.expand(points.shape[1])
        if curvature.numel() != points.shape[1]:
            raise ValueError(f"Expected {points.shape[1]} curvatures for product space, got {curvature.numel()}")
        root_k = torch.sqrt(curvature).view(1, -1)

    # Both layouts share the same formula once root_k broadcasts correctly.
    alpha = torch.acosh((root_k * points[..., 0]).clamp_min(1.0 + guard))
    coef = alpha / torch.sinh(alpha).clamp_min(guard)
    return points[..., 1:] * coef.unsqueeze(-1)
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
def pairwise_dist(x: Tensor, y: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
    """All-pairs geodesic distance on the Lorentz model.

    Computes ``acosh(-kappa * <x_i, y_j>_L) / sqrt(kappa)``. The ``acosh``
    argument is clamped just above 1 to stay inside the function's domain.
    """
    curvature = kappa.float()
    # The guard follows the precision of the *input* dtype of x.
    guard = max(eps, 16.0 * torch.finfo(x.dtype).eps)
    arg = torch.clamp((-curvature) * pairwise_lorentz_inner(x, y), min=1.0 + guard)
    return torch.acosh(arg) / torch.sqrt(curvature)
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
def product_pairwise_dist(
    x: Tensor,
    y: Tensor,
    kappa: Tensor,
    metric: str = "l1",
    eps: float = 1e-8,
) -> Tensor:
    """All-pairs distance in an l1/l2 product of Lorentz factors.

    Inputs have shape ``[batch, factors, dim + 1]``. Per-factor geodesic
    distances are combined by their mean (``metric="l1"``, matching the
    official PHyCLIP implementation's mean distance over factors) or by the
    root of their mean square (``metric="l2"``).

    Raises:
        ValueError: if either input is not 3-D, if factor/feature sizes differ,
            or if ``metric`` is not ``"l1"``/``"l2"``.
    """
    if not (x.dim() == 3 and y.dim() == 3):
        raise ValueError("product_pairwise_dist expects [batch, factors, dim + 1] tensors")
    if not (x.shape[1] == y.shape[1] and x.shape[2] == y.shape[2]):
        raise ValueError("Product Lorentz tensors must have matching factor and feature dimensions")

    curvature = _product_kappa(kappa, x.shape[1], x.device).to(dtype=torch.float32)
    guard = max(eps, 16.0 * torch.finfo(x.dtype).eps)
    xf, yf = x.float(), y.float()

    # Per-factor Lorentzian inner product for every (x_b, y_n) pair.
    time_term = xf[:, None, :, 0] * yf[None, :, :, 0]
    space_term = torch.einsum("bkd,nkd->bnk", xf[..., 1:], yf[..., 1:])
    inner = space_term - time_term

    arg = ((-curvature.view(1, 1, -1)) * inner).clamp_min(1.0 + guard)
    per_factor = torch.acosh(arg) / torch.sqrt(curvature).view(1, 1, -1)

    if metric == "l1":
        return per_factor.mean(dim=-1)
    if metric == "l2":
        return per_factor.square().mean(dim=-1).sqrt()
    raise ValueError(f"Unsupported product metric {metric!r}; expected 'l1' or 'l2'")
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
def metric_pairwise_dist(x: Tensor, y: Tensor, kappa: Tensor, product_metric: str = "l1") -> Tensor:
    """All-pairs distance for a single Lorentz space or a product space.

    Dispatches to ``product_pairwise_dist`` when either input is 3-D and to
    ``pairwise_dist`` otherwise.
    """
    is_product = 3 in (x.dim(), y.dim())
    if not is_product:
        return pairwise_dist(x, y, kappa)
    return product_pairwise_dist(x, y, kappa, metric=product_metric)
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
def paired_dist(x: Tensor, y: Tensor, kappa: Tensor, product_metric: str = "l1", eps: float = 1e-8) -> Tensor:
    """Row-wise distance for a single Lorentz space or a product space.

    When either input is 3-D the tensors are treated as
    ``[batch, factors, dim + 1]`` product embeddings and the per-factor
    distances are combined by ``product_metric`` (``"l1"`` mean or ``"l2"``
    root-mean-square). Otherwise the plain Lorentz geodesic distance between
    matching rows is returned.

    Raises:
        ValueError: in product mode, if shapes differ or ``product_metric`` is
            not ``"l1"``/``"l2"``.
    """
    if x.dim() != 3 and y.dim() != 3:
        # Single-space path.
        curvature = kappa.float()
        guard = max(eps, 16.0 * torch.finfo(x.dtype).eps)
        arg = ((-curvature) * lorentz_inner(x, y)).clamp_min(1.0 + guard)
        return torch.acosh(arg) / torch.sqrt(curvature)

    # Product-space path.
    if x.shape != y.shape:
        raise ValueError("Product paired_dist expects matching tensor shapes")
    curvature = _product_kappa(kappa, x.shape[1], x.device).to(dtype=torch.float32)
    guard = max(eps, 16.0 * torch.finfo(x.dtype).eps)
    xf, yf = x.float(), y.float()

    inner = (xf[..., 1:] * yf[..., 1:]).sum(dim=-1) - xf[..., 0] * yf[..., 0]
    arg = ((-curvature.view(1, -1)) * inner).clamp_min(1.0 + guard)
    per_factor = torch.acosh(arg) / torch.sqrt(curvature).view(1, -1)

    if product_metric == "l1":
        return per_factor.mean(dim=-1)
    if product_metric == "l2":
        return per_factor.square().mean(dim=-1).sqrt()
    raise ValueError(f"Unsupported product metric {product_metric!r}; expected 'l1' or 'l2'")
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
def radial_distance(x: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
    """Geodesic distance from the origin.

    On the hyperboloid the time coordinate satisfies
    ``x0 = cosh(sqrt(kappa) * r) / sqrt(kappa)``, hence
    ``r = arcosh(sqrt(kappa) * x0) / sqrt(kappa)``.

    Args:
        x: ``[batch, dim + 1]`` points, or ``[batch, factors, dim + 1]`` for a
            product space (per-factor radii are then averaged).
        kappa: A single curvature for 2-D input; one curvature or one per
            factor for 3-D input.
        eps: Lower bound for the ``acosh`` domain guard.

    Raises:
        ValueError: on an unsupported rank or mismatched curvature count.
    """
    # The guard follows the precision of the input dtype, before the float cast.
    guard = max(eps, 16.0 * torch.finfo(x.dtype).eps)
    points = x.float()
    curvature = kappa.to(dtype=torch.float32).flatten()

    rank = points.dim()
    if rank == 2:
        if curvature.numel() != 1:
            raise ValueError("radial_distance expects scalar kappa for non-product embeddings")
        root_k = torch.sqrt(curvature.reshape(()))
    elif rank == 3:
        if curvature.numel() == 1:
            curvature = curvature.expand(points.shape[1])
        if curvature.numel() != points.shape[1]:
            raise ValueError(f"Expected {points.shape[1]} curvatures for product space, got {curvature.numel()}")
        root_k = torch.sqrt(curvature).view(1, -1)
    else:
        raise ValueError("radial_distance expects [batch, dim + 1] or [batch, factors, dim + 1] tensors")

    radius = torch.acosh((root_k * points[..., 0]).clamp_min(1.0 + guard)) / root_k
    return radius if rank == 2 else radius.mean(dim=-1)
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
def metric_similarity(x: Tensor, y: Tensor, kappa: Tensor, product_metric: str = "l1") -> Tensor:
    """Retrieval/classification similarity for single-space and PHyCLIP-style models.

    Product embeddings (either input 3-D) are scored by the negated product
    distance; single-space embeddings are scored by the pairwise Lorentzian
    inner product (``kappa`` is not used on that path).
    """
    is_product = 3 in (x.dim(), y.dim())
    if not is_product:
        return pairwise_lorentz_inner(x, y)
    return -product_pairwise_dist(x, y, kappa, metric=product_metric)
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
def half_aperture(general: Tensor, kappa: Tensor, min_radius: float = 0.1, eps: float = 1e-8) -> Tensor:
    """Half-aperture of the entailment cone centred at each ``general`` point.

    Computes ``asin(2 * min_radius / (||space(general)|| * sqrt(kappa)))`` per
    row. A small epsilon is added to the denominator and the ratio is capped
    just below 1, so points at or near the origin yield an angle just under
    pi/2 instead of NaN.

    Args:
        general: 2-D tensor with the time coordinate in column 0.
        kappa: Curvature magnitude.
        min_radius: Radius parameter in the numerator of the ratio.
        eps: Lower bound for the numerical guard.
    """
    centre = general.float()
    curvature = kappa.float()
    guard = max(eps, 16.0 * torch.finfo(centre.dtype).eps)

    space_norm = torch.linalg.norm(centre[:, 1:], dim=-1)
    sine = (2.0 * min_radius) / (space_norm * torch.sqrt(curvature) + guard)
    return torch.asin(torch.clamp(sine, max=1.0 - guard))
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
def oxy_angle(specific: Tensor, general: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
    """Row-wise exterior angle between ``specific`` and the cone at ``general``.

    Both inputs are 2-D with the time coordinate in column 0 and are paired
    row by row. The cosine is clamped strictly inside (-1, 1) before ``acos``.
    """
    spec = specific.float()
    gen = general.float()
    curvature = kappa.float()
    guard = max(eps, 16.0 * torch.finfo(spec.dtype).eps)

    scaled_inner = curvature * lorentz_inner(spec, gen)
    numerator = spec[:, 0] + scaled_inner * gen[:, 0]

    gen_space_norm = torch.linalg.norm(gen[:, 1:], dim=-1).clamp_min(guard)
    radicand = (scaled_inner.pow(2) - 1.0).clamp_min(guard)
    denominator = gen_space_norm * torch.sqrt(radicand)

    cosine = torch.clamp(numerator / denominator, min=-1.0 + guard, max=1.0 - guard)
    return torch.acos(cosine)
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
def pairwise_oxy_angle(specific: Tensor, general: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
    """All-pairs exterior angle between specific points and cones at general points.

    Entry ``[n, m]`` of the result is the angle for ``specific[n]`` against
    the entailment cone at ``general[m]``. Both inputs are 2-D with the time
    coordinate in column 0.

    Raises:
        ValueError: if ``kappa`` holds more than one value.
    """
    spec = specific.float()
    gen = general.float()
    curvature = kappa.to(dtype=torch.float32).flatten()
    if curvature.numel() != 1:
        raise ValueError("pairwise_oxy_angle expects scalar kappa for non-product embeddings")
    k = curvature.reshape(())
    guard = max(eps, 16.0 * torch.finfo(spec.dtype).eps)

    spec_time = spec[:, None, 0]
    gen_time = gen[None, :, 0]
    inner = torch.einsum("nd,md->nm", spec[:, 1:], gen[:, 1:]) - spec_time * gen_time
    scaled_inner = k * inner

    numerator = spec_time + scaled_inner * gen_time
    gen_space_norm = torch.linalg.norm(gen[:, 1:], dim=-1).clamp_min(guard)
    radicand = (scaled_inner.pow(2) - 1.0).clamp_min(guard)
    denominator = gen_space_norm[None, :] * torch.sqrt(radicand)

    cosine = torch.clamp(numerator / denominator, min=-1.0 + guard, max=1.0 - guard)
    return torch.acos(cosine)
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
def product_pairwise_oxy_angle(
    specific: Tensor,
    general: Tensor,
    kappa: Tensor,
    metric: str = "l1",
    eps: float = 1e-8,
) -> Tensor:
    """All-pairs exterior angle in an l1/l2 product of Lorentz factors.

    Inputs have shape ``[batch, factors, dim + 1]``. A per-factor angle is
    computed for every (specific, general) pair, then combined by the mean
    (``"l1"``) or the root-mean-square (``"l2"``) over factors.

    Raises:
        ValueError: if either input is not 3-D, if factor/feature sizes differ,
            or if ``metric`` is not ``"l1"``/``"l2"``.
    """
    if not (specific.dim() == 3 and general.dim() == 3):
        raise ValueError("product_pairwise_oxy_angle expects [batch, factors, dim + 1] tensors")
    if not (specific.shape[1] == general.shape[1] and specific.shape[2] == general.shape[2]):
        raise ValueError("Product Lorentz tensors must have matching factor and feature dimensions")

    curvature = _product_kappa(kappa, specific.shape[1], specific.device).to(dtype=torch.float32)
    guard = max(eps, 16.0 * torch.finfo(specific.dtype).eps)
    spec = specific.float()
    gen = general.float()

    spec_time = spec[:, None, :, 0]
    gen_time = gen[None, :, :, 0]
    space_term = torch.einsum(
        "nkd,mkd->nmk",
        spec[..., 1:],
        gen[..., 1:],
    )
    inner = space_term - spec_time * gen_time
    scaled_inner = curvature.view(1, 1, -1) * inner

    numerator = spec_time + scaled_inner * gen_time
    gen_space_norm = torch.linalg.norm(gen[..., 1:], dim=-1).clamp_min(guard)
    radicand = (scaled_inner.pow(2) - 1.0).clamp_min(guard)
    denominator = gen_space_norm[None, :, :] * torch.sqrt(radicand)

    cosine = torch.clamp(numerator / denominator, min=-1.0 + guard, max=1.0 - guard)
    per_factor = torch.acos(cosine)

    if metric == "l1":
        return per_factor.mean(dim=-1)
    if metric == "l2":
        return per_factor.square().mean(dim=-1).sqrt()
    raise ValueError(f"Unsupported product metric {metric!r}; expected 'l1' or 'l2'")
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
def metric_pairwise_oxy_angle(specific: Tensor, general: Tensor, kappa: Tensor, product_metric: str = "l1") -> Tensor:
    """All-pairs oxy-angle for a single Lorentz space or a product space.

    Dispatches to ``product_pairwise_oxy_angle`` when either input is 3-D and
    to ``pairwise_oxy_angle`` otherwise.
    """
    is_product = 3 in (specific.dim(), general.dim())
    if not is_product:
        return pairwise_oxy_angle(specific, general, kappa)
    return product_pairwise_oxy_angle(specific, general, kappa, metric=product_metric)
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
def _product_kappa(kappa: Tensor, num_factors: int, device: torch.device) -> Tensor:
|
| 260 |
-
kappa = kappa.to(device=device, dtype=torch.float32).flatten()
|
| 261 |
-
if kappa.numel() == 1:
|
| 262 |
-
return kappa.expand(num_factors)
|
| 263 |
-
if kappa.numel() != num_factors:
|
| 264 |
-
raise ValueError(f"Expected {num_factors} curvatures for product space, got {kappa.numel()}")
|
| 265 |
-
return kappa
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/models/losses.py
DELETED
|
@@ -1,1400 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import math
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
from torch import Tensor
|
| 7 |
-
import torch.nn.functional as F
|
| 8 |
-
|
| 9 |
-
from hyper3_clip.models.lorentz import (
|
| 10 |
-
half_aperture,
|
| 11 |
-
metric_pairwise_dist,
|
| 12 |
-
metric_pairwise_oxy_angle,
|
| 13 |
-
oxy_angle,
|
| 14 |
-
paired_dist,
|
| 15 |
-
radial_distance,
|
| 16 |
-
)
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def contrastive_ce(logits: Tensor, targets: Tensor | None = None, weights: Tensor | None = None) -> Tensor:
    """Cross-entropy contrastive loss over a logit matrix.

    When ``targets`` is omitted, row ``i`` is matched with column ``i``.
    Per-row losses are reduced with ``weighted_mean`` using ``weights``.
    """
    labels = targets
    if labels is None:
        labels = torch.arange(logits.size(0), device=logits.device)
    per_row = F.cross_entropy(logits, labels, reduction="none")
    return weighted_mean(per_row, weights)
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def contrastive_sigmoid(
    logits: Tensor,
    targets: Tensor | None = None,
    weights: Tensor | None = None,
    negative_weight: float = 1.0,
) -> Tensor:
    """Binary-cross-entropy contrastive loss with {0, 1} pair labels.

    Each row has exactly one positive column (``targets[i]``, defaulting to
    the diagonal). Element-wise BCE-with-logits is optionally rescaled on the
    negatives by ``negative_weight``, averaged over each row, and the row
    losses are reduced with ``weighted_mean``.
    """
    row_index = torch.arange(logits.size(0), device=logits.device)
    positives = row_index if targets is None else targets

    labels = torch.zeros_like(logits)
    labels[row_index, positives] = 1.0

    per_pair = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
    if negative_weight != 1.0:
        negative_scale = logits.new_full((), negative_weight)
        per_pair = per_pair * torch.where(labels > 0.0, torch.ones_like(labels), negative_scale)

    return weighted_mean(per_pair.mean(dim=1), weights)
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def contrastive_siglip(
    logits: Tensor,
    targets: Tensor | None = None,
    weights: Tensor | None = None,
    negative_weight: float = 1.0,
) -> Tensor:
    """SigLIP pairwise sigmoid loss (Zhai et al., ICCV 2023).

    Pair labels are in {+1, -1} and each row is reduced by a sum (not a mean):
        L_i = sum_j softplus(-y_ij * logit_ij)
    Negatives may be rescaled by ``negative_weight``; row losses are then
    reduced with ``weighted_mean``.

    Raises:
        ValueError: if ``logits`` is not a 2-D matrix.
    """
    if logits.ndim != 2:
        raise ValueError("contrastive_siglip expects a [batch, classes] logit matrix")

    row_index = torch.arange(logits.size(0), device=logits.device)
    positives = row_index if targets is None else targets

    signs = logits.new_full(logits.shape, -1.0)
    signs[row_index, positives] = 1.0

    per_pair = F.softplus(-(signs * logits))
    if negative_weight != 1.0:
        negative_scale = logits.new_full((), negative_weight)
        per_pair = per_pair * torch.where(signs > 0.0, torch.ones_like(signs), negative_scale)

    return weighted_mean(per_pair.sum(dim=1), weights)
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
def weighted_mean(values: Tensor, weights: Tensor | None = None) -> Tensor:
    """Weighted mean of ``values``; the plain mean when ``weights`` is ``None``.

    Args:
        values: Tensor of per-sample (or per-element) quantities.
        weights: Optional weights. They are moved to ``values``' device and
            dtype; lower-rank weights get trailing singleton axes appended so
            that e.g. per-row weights apply to every element of a row.

    Returns:
        0-dim tensor ``sum(values * w) / sum(w)``, where ``w`` is the weight
        tensor broadcast against ``values``. The denominator is floored at the
        dtype's machine epsilon, so all-zero weights yield 0 rather than NaN.
    """
    if weights is None:
        return values.mean()
    weights = weights.to(device=values.device, dtype=values.dtype)
    while weights.dim() < values.dim():
        weights = weights.unsqueeze(-1)
    # Normalise by the weights as actually applied. Summing the un-broadcast
    # weights would inflate the result by the broadcast factor (all-ones
    # per-row weights on a [B, K] input would give K * mean, not the mean).
    values_b, weights_b = torch.broadcast_tensors(values, weights)
    denominator = weights_b.sum().clamp_min(torch.finfo(values.dtype).eps)
    return (values_b * weights_b).sum() / denominator
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
def gramian_volume_loss(vectors: Tensor, weights: Tensor | None = None, eps: float = 1e-4) -> Tensor:
    """GRAM-style volume loss for sets of vectors.

    ``vectors`` has shape ``[batch, k, dim]``. Each set of k vectors is
    L2-normalised along ``dim``; the Gramian ``G = V V^T`` is regularised to
    ``G + eps * I`` and the volume ``sqrt(det(G + eps I))`` is reduced over
    the batch with ``weighted_mean``. Sets whose determinant sign is not
    positive contribute a volume of 1.

    Raises:
        ValueError: if ``vectors`` is not 3-D or ``eps`` is not positive.
    """
    if vectors.ndim != 3:
        raise ValueError("gramian_volume_loss expects a [batch, k, dim] tensor")
    if eps <= 0.0:
        raise ValueError("gramian_volume_loss eps must be positive")

    unit = F.normalize(vectors.float(), dim=-1, eps=1e-8)
    gram = unit @ unit.transpose(-1, -2)
    ridge = eps * torch.eye(gram.size(-1), device=gram.device, dtype=gram.dtype)

    # slogdet keeps the determinant numerically stable; sqrt(det) = exp(logdet / 2).
    sign, logabsdet = torch.linalg.slogdet(gram + ridge)
    volume = torch.exp(0.5 * logabsdet)
    volume = torch.where(sign > 0, volume, torch.ones_like(volume))
    return weighted_mean(volume, weights)
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
def radius_order_hinge(
    specific: Tensor,
    general: Tensor,
    kappa: Tensor,
    margin: float,
    weights: Tensor | None = None,
) -> Tensor:
    """Hinge loss pushing ``specific`` at least ``margin`` farther from the origin than ``general``.

    Per-row loss is ``relu(margin + r(general) - r(specific))`` with ``r``
    given by ``radial_distance``; rows are reduced with ``weighted_mean``.

    Raises:
        ValueError: if batch sizes differ or ``margin`` is negative.
    """
    if specific.shape[0] != general.shape[0]:
        raise ValueError("radius_order_hinge expects matching batch dimensions")
    if margin < 0.0:
        raise ValueError("radius_order_hinge margin must be non-negative")

    radius_specific = radial_distance(specific, kappa)
    radius_general = radial_distance(general, kappa)
    violation = float(margin) + radius_general - radius_specific
    return weighted_mean(F.relu(violation), weights)
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
def soft_contrastive_ce(logits: Tensor, target_weights: Tensor, weights: Tensor | None = None) -> Tensor:
|
| 118 |
-
if logits.ndim != 2 or target_weights.ndim != 2:
|
| 119 |
-
raise ValueError("soft_contrastive_ce expects [batch, classes] tensors")
|
| 120 |
-
if logits.shape != target_weights.shape:
|
| 121 |
-
raise ValueError("soft_contrastive_ce requires logits and target_weights to have matching shapes")
|
| 122 |
-
log_probs = F.log_softmax(logits, dim=1)
|
| 123 |
-
losses = -(target_weights.to(dtype=log_probs.dtype) * log_probs).sum(dim=1)
|
| 124 |
-
return weighted_mean(losses, weights)
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
def beta_cal_loss(
|
| 128 |
-
logits: Tensor,
|
| 129 |
-
*,
|
| 130 |
-
targets: Tensor,
|
| 131 |
-
group_ids: Tensor,
|
| 132 |
-
all_group_ids: Tensor,
|
| 133 |
-
beta: float,
|
| 134 |
-
variant: str,
|
| 135 |
-
weights: Tensor | None = None,
|
| 136 |
-
) -> Tensor:
|
| 137 |
-
if beta < 0.0:
|
| 138 |
-
raise ValueError("beta_cal_loss beta must be non-negative")
|
| 139 |
-
if variant not in {"ce", "bce"}:
|
| 140 |
-
raise ValueError("beta_cal_loss variant must be 'ce' or 'bce'")
|
| 141 |
-
if logits.ndim != 2:
|
| 142 |
-
raise ValueError("beta_cal_loss expects a [batch, classes] logit matrix")
|
| 143 |
-
if targets.shape != (logits.size(0),):
|
| 144 |
-
raise ValueError("beta_cal_loss targets must have shape [batch]")
|
| 145 |
-
if group_ids.shape != (logits.size(0),):
|
| 146 |
-
raise ValueError("beta_cal_loss group_ids must have shape [batch]")
|
| 147 |
-
if all_group_ids.shape != (logits.size(1),):
|
| 148 |
-
raise ValueError("beta_cal_loss all_group_ids must have shape [classes]")
|
| 149 |
-
|
| 150 |
-
same_group = group_ids[:, None] == all_group_ids[None, :]
|
| 151 |
-
same_pair = targets[:, None] == torch.arange(logits.size(1), device=logits.device)[None, :]
|
| 152 |
-
|
| 153 |
-
if variant == "ce":
|
| 154 |
-
target_weights = logits.new_zeros(logits.shape)
|
| 155 |
-
target_weights = torch.where(same_pair, logits.new_ones(()), target_weights)
|
| 156 |
-
target_weights = torch.where(same_group & ~same_pair, logits.new_full((), float(beta)), target_weights)
|
| 157 |
-
target_weights = target_weights / target_weights.sum(dim=1, keepdim=True).clamp_min(
|
| 158 |
-
torch.finfo(target_weights.dtype).eps
|
| 159 |
-
)
|
| 160 |
-
return soft_contrastive_ce(logits, target_weights, weights)
|
| 161 |
-
|
| 162 |
-
labels = same_group.to(dtype=logits.dtype)
|
| 163 |
-
element_weights = logits.new_ones(logits.shape)
|
| 164 |
-
element_weights = torch.where(same_group & ~same_pair, logits.new_full((), float(beta)), element_weights)
|
| 165 |
-
element_losses = F.binary_cross_entropy_with_logits(logits, labels, reduction="none") * element_weights
|
| 166 |
-
row_losses = element_losses.mean(dim=1)
|
| 167 |
-
return weighted_mean(row_losses, weights)
|
| 168 |
-
|
| 169 |
-
def compositional_contrastive_loss(
    image_feats: Tensor,
    text_feats: Tensor,
    box_image_feats: Tensor,
    box_text_feats: Tensor,
    kappa: Tensor,
    logit_scale: Tensor,
    all_image_feats: Tensor | None = None,
    all_text_feats: Tensor | None = None,
    targets: Tensor | None = None,
) -> Tensor:
    """Average of four contrastive terms over whole and box features.

    Logits are negated ``metric_pairwise_dist`` values multiplied by
    ``exp(logit_scale)`` (capped at 100). Image-side queries (whole image, box
    image) are scored against the text bank, text-side queries against the
    image bank. The banks default to the local batch when ``all_*_feats`` is
    ``None``. ``targets`` is passed to ``contrastive_ce`` unchanged.
    """
    temperature = logit_scale.exp().clamp(max=100.0)
    image_bank = image_feats if all_image_feats is None else all_image_feats
    text_bank = text_feats if all_text_feats is None else all_text_feats

    def scaled_logits(queries: Tensor, bank: Tensor) -> Tensor:
        return -metric_pairwise_dist(queries, bank, kappa) * temperature

    query_bank_pairs = (
        (image_feats, text_bank),
        (text_feats, image_bank),
        (box_image_feats, text_bank),
        (box_text_feats, image_bank),
    )
    logits = [scaled_logits(queries, bank) for queries, bank in query_bank_pairs]
    terms = [contrastive_ce(entry, targets) for entry in logits]
    return 0.25 * (terms[0] + terms[1] + terms[2] + terms[3])
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
def multi_part_contrastive_loss(
    image_feats: Tensor,
    text_feats: Tensor,
    part_image_feats: Tensor,
    part_text_feats: Tensor,
    part_mask: Tensor,
    kappa: Tensor,
    logit_scale: Tensor,
    all_image_feats: Tensor | None = None,
    all_text_feats: Tensor | None = None,
    targets: Tensor | None = None,
) -> Tensor:
    """Contrastive loss over whole samples and their masked parts.

    Valid parts are selected by ``_flatten_valid_parts`` (which also yields
    the per-part targets). Whole and part queries are scored against the
    whole-sample banks with negated ``metric_pairwise_dist`` scaled by
    ``exp(logit_scale)`` (capped at 100), and the four cross-entropy terms are
    averaged. ``targets`` defaults to ``arange(batch)``.
    """
    temperature = logit_scale.exp().clamp(max=100.0)
    image_bank = image_feats if all_image_feats is None else all_image_feats
    text_bank = text_feats if all_text_feats is None else all_text_feats
    if targets is None:
        targets = torch.arange(image_feats.size(0), device=image_feats.device)

    flat_part_images, flat_part_texts, part_targets = _flatten_valid_parts(
        part_image_feats, part_text_feats, part_mask, targets
    )

    def scaled_logits(queries: Tensor, bank: Tensor) -> Tensor:
        return -metric_pairwise_dist(queries, bank, kappa) * temperature

    whole_image_logits = scaled_logits(image_feats, text_bank)
    whole_text_logits = scaled_logits(text_feats, image_bank)
    part_image_logits = scaled_logits(flat_part_images, text_bank)
    part_text_logits = scaled_logits(flat_part_texts, image_bank)

    total = contrastive_ce(whole_image_logits, targets)
    total = total + contrastive_ce(whole_text_logits, targets)
    total = total + contrastive_ce(part_image_logits, part_targets)
    total = total + contrastive_ce(part_text_logits, part_targets)
    return 0.25 * total
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
def packed_part_contrastive_loss(
    image_feats: Tensor,
    text_feats: Tensor,
    part_image_feats: Tensor,
    part_text_feats: Tensor,
    part_owner: Tensor,
    kappa: Tensor,
    logit_scale: Tensor,
    all_image_feats: Tensor | None = None,
    all_text_feats: Tensor | None = None,
    targets: Tensor | None = None,
) -> Tensor:
    """Contrastive loss for whole samples plus already-packed parts.

    ``part_owner`` indexes ``targets`` to give each packed part the target of
    the sample it belongs to. When ``part_image_feats`` has no elements only
    the whole-sample (global) term is returned; otherwise the global and part
    terms are averaged.
    """
    temperature = logit_scale.exp().clamp(max=100.0)
    image_bank = image_feats if all_image_feats is None else all_image_feats
    text_bank = text_feats if all_text_feats is None else all_text_feats
    if targets is None:
        targets = torch.arange(image_feats.size(0), device=image_feats.device)

    def scaled_logits(queries: Tensor, bank: Tensor) -> Tensor:
        return -metric_pairwise_dist(queries, bank, kappa) * temperature

    whole_image_logits = scaled_logits(image_feats, text_bank)
    whole_text_logits = scaled_logits(text_feats, image_bank)
    global_loss = 0.5 * (
        contrastive_ce(whole_image_logits, targets) + contrastive_ce(whole_text_logits, targets)
    )

    # Nothing packed: the part terms would be computed over empty inputs.
    if part_image_feats.numel() == 0:
        return global_loss

    part_targets = targets[part_owner]
    part_image_logits = scaled_logits(part_image_feats, text_bank)
    part_text_logits = scaled_logits(part_text_feats, image_bank)
    part_loss = 0.5 * (
        contrastive_ce(part_image_logits, part_targets) + contrastive_ce(part_text_logits, part_targets)
    )
    return 0.5 * (global_loss + part_loss)
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
def factor_oxy_angle(specific: Tensor, general: Tensor, kappa: Tensor) -> Tensor:
    """Apply ``oxy_angle`` per factor when ``specific`` is a rank-3 tensor.

    Inputs that are not rank 3 are forwarded to ``oxy_angle`` untouched.
    Rank-3 inputs are read as ``[batch, factors, dim]``: factors are folded
    into the batch axis, each row is paired with its factor's curvature (via
    ``_factor_kappa``), and the result is reshaped to ``[batch, factors]``.
    """
    if specific.dim() != 3:
        return oxy_angle(specific=specific, general=general, kappa=kappa)

    batch_size, num_factors, feature_dim = specific.shape
    per_factor = _factor_kappa(kappa, num_factors, specific.device)
    # Repeat the per-factor curvatures for every batch row, batch-major.
    tiled_kappa = per_factor[None, :].expand(batch_size, num_factors).reshape(-1)
    rows = batch_size * num_factors
    flat_angles = oxy_angle(
        specific=specific.reshape(rows, feature_dim),
        general=general.reshape(rows, feature_dim),
        kappa=tiled_kappa,
    )
    return flat_angles.reshape(batch_size, num_factors)
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
def factor_half_aperture(general: Tensor, kappa: Tensor) -> Tensor:
    """Apply ``half_aperture`` per factor when ``general`` is a rank-3 tensor.

    Inputs that are not rank 3 are forwarded to ``half_aperture`` untouched.
    Rank-3 inputs are read as ``[batch, factors, dim]``: factors are folded
    into the batch axis with a matching per-row curvature (via
    ``_factor_kappa``), and the result is reshaped to ``[batch, factors]``.
    """
    if general.dim() != 3:
        return half_aperture(general=general, kappa=kappa)

    batch_size, num_factors, feature_dim = general.shape
    per_factor = _factor_kappa(kappa, num_factors, general.device)
    # Repeat the per-factor curvatures for every batch row, batch-major.
    tiled_kappa = per_factor[None, :].expand(batch_size, num_factors).reshape(-1)
    flat_apertures = half_aperture(
        general=general.reshape(batch_size * num_factors, feature_dim),
        kappa=tiled_kappa,
    )
    return flat_apertures.reshape(batch_size, num_factors)
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
def _factor_kappa(kappa: Tensor, num_factors: int, device: torch.device) -> Tensor:
|
| 288 |
-
kappa = kappa.to(device=device, dtype=torch.float32).flatten()
|
| 289 |
-
if kappa.numel() == 1:
|
| 290 |
-
return kappa.expand(num_factors)
|
| 291 |
-
if kappa.numel() != num_factors:
|
| 292 |
-
raise ValueError(f"Expected {num_factors} curvatures for product space, got {kappa.numel()}")
|
| 293 |
-
return kappa
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
def entailment_residual(
    specific: Tensor,
    general: Tensor,
    kappa: Tensor,
    aperture_scale: float,
) -> Tensor:
    """Mean amount by which the angle exceeds the scaled half-aperture.

    Computes ``mean(max(angle - aperture_scale * half_aperture, 0))`` using
    ``factor_oxy_angle`` and ``factor_half_aperture``. The mean is taken over
    every element of the residual tensor.
    """
    angle = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
    half_ap = factor_half_aperture(general=general, kappa=kappa)
    excess = angle - (aperture_scale * half_ap)
    return excess.clamp(min=0.0).mean()
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
def weighted_entailment_residual(
    specific: Tensor,
    general: Tensor,
    kappa: Tensor,
    aperture_scale: float,
    weights: Tensor | None = None,
) -> Tensor:
    """Entailment residual reduced with ``weighted_mean``.

    The residual ``max(angle - aperture_scale * half_aperture, 0)`` is computed
    as in ``entailment_residual``. A rank-2 residual is first averaged over its
    last axis so that there is one value per row before weighting.
    """
    angle = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
    half_ap = factor_half_aperture(general=general, kappa=kappa)
    excess = (angle - (aperture_scale * half_ap)).clamp(min=0.0)
    per_row = excess.mean(dim=-1) if excess.dim() == 2 else excess
    return weighted_mean(per_row, weights)
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
def compositional_entailment_loss(
    image_feats: Tensor,
    text_feats: Tensor,
    box_image_feats: Tensor,
    box_text_feats: Tensor,
    kappa: Tensor,
    inter_aperture_scale: float,
    intra_aperture_scale: float,
) -> Tensor:
    """Half the sum of four entailment residuals over whole and box features.

    Cross-modal pairs (text -> image, box text -> box image) use
    ``inter_aperture_scale``; same-modality pairs (box image -> image,
    box text -> text) use ``intra_aperture_scale``. In each pair the first
    item is the ``general`` concept and the second the ``specific`` one.
    """

    def residual(specific: Tensor, general: Tensor, scale: float) -> Tensor:
        return entailment_residual(
            specific=specific, general=general, kappa=kappa, aperture_scale=scale
        )

    text_to_image = residual(image_feats, text_feats, inter_aperture_scale)
    box_text_to_box_image = residual(box_image_feats, box_text_feats, inter_aperture_scale)
    box_image_to_image = residual(image_feats, box_image_feats, intra_aperture_scale)
    box_text_to_text = residual(text_feats, box_text_feats, intra_aperture_scale)

    total = text_to_image + box_text_to_box_image + box_image_to_image + box_text_to_text
    return 0.5 * total
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
def multi_part_entailment_loss(
    image_feats: Tensor,
    text_feats: Tensor,
    part_image_feats: Tensor,
    part_text_feats: Tensor,
    part_mask: Tensor,
    kappa: Tensor,
    inter_aperture_scale: float,
    intra_aperture_scale: float,
) -> Tensor:
    """Entailment loss over whole samples and their masked parts.

    Parts selected by ``part_mask`` are flattened, and each whole-sample
    feature is broadcast across the part axis and masked the same way so that
    every valid part is paired with its owning sample.
    (Assumes ``part_mask`` is a boolean mask over the leading ``[batch, parts]``
    axes of the part tensors - TODO confirm against the caller.)

    Four residuals are combined: text -> image and part text -> part image
    (``inter_aperture_scale``), part image -> image and part text -> text
    (``intra_aperture_scale``).

    Returns:
        ``0.5 *`` the sum of the four residuals. When the mask selects no
        parts, only the text -> image residual is returned, matching
        ``packed_part_entailment_loss``; without this guard the three part
        terms would be means over empty tensors, i.e. NaN.
    """
    part_image_flat = part_image_feats[part_mask]
    part_text_flat = part_text_feats[part_mask]
    image_for_parts = image_feats[:, None, :].expand_as(part_image_feats)[part_mask]
    text_for_parts = text_feats[:, None, :].expand_as(part_text_feats)[part_mask]

    text_to_image = entailment_residual(
        specific=image_feats,
        general=text_feats,
        kappa=kappa,
        aperture_scale=inter_aperture_scale,
    )
    # No valid parts: skip the part terms instead of averaging empty tensors.
    if part_image_flat.numel() == 0:
        return text_to_image

    part_text_to_part_image = entailment_residual(
        specific=part_image_flat,
        general=part_text_flat,
        kappa=kappa,
        aperture_scale=inter_aperture_scale,
    )
    part_image_to_image = entailment_residual(
        specific=image_for_parts,
        general=part_image_flat,
        kappa=kappa,
        aperture_scale=intra_aperture_scale,
    )
    part_text_to_text = entailment_residual(
        specific=text_for_parts,
        general=part_text_flat,
        kappa=kappa,
        aperture_scale=intra_aperture_scale,
    )

    return 0.5 * (text_to_image + part_text_to_part_image + part_image_to_image + part_text_to_text)
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
def packed_part_entailment_loss(
    image_feats: Tensor,
    text_feats: Tensor,
    part_image_feats: Tensor,
    part_text_feats: Tensor,
    part_owner: Tensor,
    kappa: Tensor,
    inter_aperture_scale: float,
    intra_aperture_scale: float,
) -> Tensor:
    """Entailment loss for whole samples plus already-packed parts.

    ``part_owner`` indexes the whole-sample features to pair every packed
    part with the sample it belongs to. When ``part_image_feats`` has no
    elements only the text -> image residual is returned; otherwise the result
    is half the sum of the four residuals.
    """

    def residual(specific: Tensor, general: Tensor, scale: float) -> Tensor:
        return entailment_residual(
            specific=specific, general=general, kappa=kappa, aperture_scale=scale
        )

    text_to_image = residual(image_feats, text_feats, inter_aperture_scale)
    if part_image_feats.numel() == 0:
        return text_to_image

    owner_images = image_feats[part_owner]
    owner_texts = text_feats[part_owner]
    part_text_to_part_image = residual(part_image_feats, part_text_feats, inter_aperture_scale)
    part_image_to_image = residual(owner_images, part_image_feats, intra_aperture_scale)
    part_text_to_text = residual(owner_texts, part_text_feats, intra_aperture_scale)

    total = text_to_image + part_text_to_part_image + part_image_to_image + part_text_to_text
    return 0.5 * total
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
def uncha_contrastive_losses(
|
| 446 |
-
image_feats: Tensor,
|
| 447 |
-
text_feats: Tensor,
|
| 448 |
-
part_image_flat: Tensor,
|
| 449 |
-
part_text_flat: Tensor,
|
| 450 |
-
image_for_parts: Tensor,
|
| 451 |
-
text_for_parts: Tensor,
|
| 452 |
-
kappa: Tensor,
|
| 453 |
-
global_logit_scale: Tensor,
|
| 454 |
-
local_logit_scale: Tensor,
|
| 455 |
-
global_local_logit_scale: Tensor,
|
| 456 |
-
image_euc_feats: Tensor | None = None,
|
| 457 |
-
text_euc_feats: Tensor | None = None,
|
| 458 |
-
part_image_euc_flat: Tensor | None = None,
|
| 459 |
-
part_text_euc_flat: Tensor | None = None,
|
| 460 |
-
image_for_parts_euc: Tensor | None = None,
|
| 461 |
-
text_for_parts_euc: Tensor | None = None,
|
| 462 |
-
all_image_feats: Tensor | None = None,
|
| 463 |
-
all_text_feats: Tensor | None = None,
|
| 464 |
-
all_part_image_feats: Tensor | None = None,
|
| 465 |
-
all_part_text_feats: Tensor | None = None,
|
| 466 |
-
all_image_for_parts: Tensor | None = None,
|
| 467 |
-
all_text_for_parts: Tensor | None = None,
|
| 468 |
-
all_image_euc_feats: Tensor | None = None,
|
| 469 |
-
all_text_euc_feats: Tensor | None = None,
|
| 470 |
-
all_part_image_euc_feats: Tensor | None = None,
|
| 471 |
-
all_part_text_euc_feats: Tensor | None = None,
|
| 472 |
-
all_image_for_parts_euc: Tensor | None = None,
|
| 473 |
-
all_text_for_parts_euc: Tensor | None = None,
|
| 474 |
-
global_targets: Tensor | None = None,
|
| 475 |
-
part_targets: Tensor | None = None,
|
| 476 |
-
part_weights: Tensor | None = None,
|
| 477 |
-
product_metric: str = "l1",
|
| 478 |
-
loss_type: str = "ce",
|
| 479 |
-
contrastive_global_weight: float = 1.0,
|
| 480 |
-
contrastive_local_weight: float = 1.0,
|
| 481 |
-
contrastive_global_local_weight: float = 1.0,
|
| 482 |
-
beta_cal_beta: float = 0.0,
|
| 483 |
-
beta_cal_variant: str = "ce",
|
| 484 |
-
beta_cal_weight: float = 0.0,
|
| 485 |
-
part_group_ids: Tensor | None = None,
|
| 486 |
-
all_part_group_ids: Tensor | None = None,
|
| 487 |
-
global_logit_bias: Tensor | None = None,
|
| 488 |
-
local_logit_bias: Tensor | None = None,
|
| 489 |
-
global_local_logit_bias: Tensor | None = None,
|
| 490 |
-
sigmoid_negative_weight: float = 1.0,
|
| 491 |
-
global_local_mode: str = "repeat",
|
| 492 |
-
global_local_metric: str = "distance",
|
| 493 |
-
global_local_angle_aux_weight: float = 0.0,
|
| 494 |
-
global_local_angle_aux_mode: str = "contrastive",
|
| 495 |
-
global_local_angle_aux_scale: float = 5.5,
|
| 496 |
-
global_local_angle_aux_aperture_scale: float = 1.0,
|
| 497 |
-
) -> dict[str, Tensor]:
|
| 498 |
-
if loss_type not in {"ce", "sigmoid", "siglip", "siglip_metric"}:
|
| 499 |
-
raise ValueError(
|
| 500 |
-
f"Unsupported contrastive loss {loss_type!r}; expected 'ce', 'sigmoid', 'siglip', or 'siglip_metric'"
|
| 501 |
-
)
|
| 502 |
-
if global_local_mode not in {"repeat", "inbatch"}:
|
| 503 |
-
raise ValueError("global_local_mode must be 'repeat' or 'inbatch'")
|
| 504 |
-
if global_local_metric not in {"distance", "angle"}:
|
| 505 |
-
raise ValueError("global_local_metric must be 'distance' or 'angle'")
|
| 506 |
-
if global_local_angle_aux_mode not in {"contrastive", "positive_hinge"}:
|
| 507 |
-
raise ValueError("global_local_angle_aux_mode must be 'contrastive' or 'positive_hinge'")
|
| 508 |
-
if global_local_angle_aux_weight < 0.0:
|
| 509 |
-
raise ValueError("global_local_angle_aux_weight must be non-negative")
|
| 510 |
-
if global_local_angle_aux_scale <= 0.0:
|
| 511 |
-
raise ValueError("global_local_angle_aux_scale must be positive")
|
| 512 |
-
if global_local_angle_aux_aperture_scale <= 0.0:
|
| 513 |
-
raise ValueError("global_local_angle_aux_aperture_scale must be positive")
|
| 514 |
-
all_image_feats = image_feats if all_image_feats is None else all_image_feats
|
| 515 |
-
all_text_feats = text_feats if all_text_feats is None else all_text_feats
|
| 516 |
-
all_part_image_feats = part_image_flat if all_part_image_feats is None else all_part_image_feats
|
| 517 |
-
all_part_text_feats = part_text_flat if all_part_text_feats is None else all_part_text_feats
|
| 518 |
-
all_image_for_parts = image_for_parts if all_image_for_parts is None else all_image_for_parts
|
| 519 |
-
all_text_for_parts = text_for_parts if all_text_for_parts is None else all_text_for_parts
|
| 520 |
-
if global_targets is None:
|
| 521 |
-
global_targets = torch.arange(image_feats.size(0), device=image_feats.device)
|
| 522 |
-
if part_targets is None:
|
| 523 |
-
part_targets = torch.arange(part_image_flat.size(0), device=part_image_flat.device)
|
| 524 |
-
|
| 525 |
-
global_scale = global_logit_scale.exp().clamp(max=100.0)
|
| 526 |
-
local_scale = local_logit_scale.exp().clamp(max=100.0)
|
| 527 |
-
global_local_scale = global_local_logit_scale.exp().clamp(max=100.0)
|
| 528 |
-
|
| 529 |
-
if loss_type == "siglip":
|
| 530 |
-
if image_euc_feats is None or text_euc_feats is None:
|
| 531 |
-
raise ValueError("siglip contrastive requires image_euc_feats and text_euc_feats")
|
| 532 |
-
if image_feats.dim() != 2 or text_feats.dim() != 2:
|
| 533 |
-
raise ValueError("siglip contrastive is only supported for non-product features")
|
| 534 |
-
all_image_euc_feats = image_euc_feats if all_image_euc_feats is None else all_image_euc_feats
|
| 535 |
-
all_text_euc_feats = text_euc_feats if all_text_euc_feats is None else all_text_euc_feats
|
| 536 |
-
zimg = F.normalize(image_euc_feats.float(), dim=-1)
|
| 537 |
-
ztxt = F.normalize(text_euc_feats.float(), dim=-1)
|
| 538 |
-
zimg_all = F.normalize(all_image_euc_feats.float(), dim=-1)
|
| 539 |
-
ztxt_all = F.normalize(all_text_euc_feats.float(), dim=-1)
|
| 540 |
-
image_logits = (zimg @ ztxt_all.T) * global_scale
|
| 541 |
-
text_logits = (ztxt @ zimg_all.T) * global_scale
|
| 542 |
-
else:
|
| 543 |
-
image_logits = -metric_pairwise_dist(image_feats, all_text_feats, kappa, product_metric=product_metric) * global_scale
|
| 544 |
-
text_logits = -metric_pairwise_dist(text_feats, all_image_feats, kappa, product_metric=product_metric) * global_scale
|
| 545 |
-
|
| 546 |
-
if loss_type in {"sigmoid", "siglip", "siglip_metric"}:
|
| 547 |
-
bias = image_logits.new_zeros(()) if global_logit_bias is None else global_logit_bias.to(image_logits.device)
|
| 548 |
-
image_logits = image_logits + bias
|
| 549 |
-
text_logits = text_logits + bias
|
| 550 |
-
global_contrastive = 0.5 * (
|
| 551 |
-
_contrastive_loss(image_logits, global_targets, None, loss_type, sigmoid_negative_weight)
|
| 552 |
-
+ _contrastive_loss(text_logits, global_targets, None, loss_type, sigmoid_negative_weight)
|
| 553 |
-
)
|
| 554 |
-
|
| 555 |
-
if part_image_flat.numel() == 0:
|
| 556 |
-
zero = image_feats.new_zeros(())
|
| 557 |
-
contrastive = contrastive_global_weight * global_contrastive
|
| 558 |
-
return {
|
| 559 |
-
"contrastive_loss": contrastive,
|
| 560 |
-
"global_contrastive_loss": global_contrastive,
|
| 561 |
-
"local_contrastive_loss": zero,
|
| 562 |
-
"global_local_contrastive_loss": zero,
|
| 563 |
-
"global_local_angle_aux_loss": zero,
|
| 564 |
-
"beta_cal_loss": zero,
|
| 565 |
-
}
|
| 566 |
-
|
| 567 |
-
if loss_type == "siglip":
|
| 568 |
-
if part_image_euc_flat is None or part_text_euc_flat is None:
|
| 569 |
-
raise ValueError("siglip contrastive requires part_image_euc_flat and part_text_euc_flat when parts exist")
|
| 570 |
-
all_part_image_euc_feats = part_image_euc_flat if all_part_image_euc_feats is None else all_part_image_euc_feats
|
| 571 |
-
all_part_text_euc_feats = part_text_euc_flat if all_part_text_euc_feats is None else all_part_text_euc_feats
|
| 572 |
-
zpi = F.normalize(part_image_euc_flat.float(), dim=-1)
|
| 573 |
-
zpt = F.normalize(part_text_euc_flat.float(), dim=-1)
|
| 574 |
-
zpi_all = F.normalize(all_part_image_euc_feats.float(), dim=-1)
|
| 575 |
-
zpt_all = F.normalize(all_part_text_euc_feats.float(), dim=-1)
|
| 576 |
-
part_image_logits = (zpi @ zpt_all.T) * local_scale
|
| 577 |
-
part_text_logits = (zpt @ zpi_all.T) * local_scale
|
| 578 |
-
else:
|
| 579 |
-
part_image_logits = -metric_pairwise_dist(part_image_flat, all_part_text_feats, kappa, product_metric=product_metric) * local_scale
|
| 580 |
-
part_text_logits = -metric_pairwise_dist(part_text_flat, all_part_image_feats, kappa, product_metric=product_metric) * local_scale
|
| 581 |
-
|
| 582 |
-
if loss_type in {"sigmoid", "siglip", "siglip_metric"}:
|
| 583 |
-
bias = part_image_logits.new_zeros(()) if local_logit_bias is None else local_logit_bias.to(part_image_logits.device)
|
| 584 |
-
part_image_logits = part_image_logits + bias
|
| 585 |
-
part_text_logits = part_text_logits + bias
|
| 586 |
-
local_contrastive = 0.5 * (
|
| 587 |
-
_contrastive_loss(part_image_logits, part_targets, part_weights, loss_type, sigmoid_negative_weight)
|
| 588 |
-
+ _contrastive_loss(part_text_logits, part_targets, part_weights, loss_type, sigmoid_negative_weight)
|
| 589 |
-
)
|
| 590 |
-
|
| 591 |
-
global_local_contrastive = image_feats.new_zeros(())
|
| 592 |
-
global_local_angle_aux = image_feats.new_zeros(())
|
| 593 |
-
if contrastive_global_local_weight != 0.0:
|
| 594 |
-
if global_local_mode == "inbatch":
|
| 595 |
-
if part_group_ids is None:
|
| 596 |
-
raise ValueError("inbatch global-local contrastive requires part_group_ids to be provided")
|
| 597 |
-
global_local_targets = part_group_ids
|
| 598 |
-
all_text_for_global_local = all_text_feats
|
| 599 |
-
all_image_for_global_local = all_image_feats
|
| 600 |
-
all_text_for_global_local_euc = all_text_euc_feats
|
| 601 |
-
all_image_for_global_local_euc = all_image_euc_feats
|
| 602 |
-
else:
|
| 603 |
-
global_local_targets = part_targets
|
| 604 |
-
all_text_for_global_local = all_text_for_parts
|
| 605 |
-
all_image_for_global_local = all_image_for_parts
|
| 606 |
-
all_text_for_global_local_euc = all_text_for_parts_euc
|
| 607 |
-
all_image_for_global_local_euc = all_image_for_parts_euc
|
| 608 |
-
|
| 609 |
-
image_uncertainty = embedding_uncertainty(part_image_flat).detach()
|
| 610 |
-
text_uncertainty = embedding_uncertainty(part_text_flat).detach()
|
| 611 |
-
image_temp = torch.exp(-0.5 * image_uncertainty).clamp(min=0.1, max=10.0)
|
| 612 |
-
text_temp = torch.exp(-0.5 * text_uncertainty).clamp(min=0.1, max=10.0)
|
| 613 |
-
|
| 614 |
-
if loss_type == "siglip":
|
| 615 |
-
if part_image_euc_flat is None or part_text_euc_flat is None:
|
| 616 |
-
raise ValueError("siglip global-local contrastive requires part_image_euc_flat/part_text_euc_flat")
|
| 617 |
-
if all_text_for_global_local_euc is None or all_image_for_global_local_euc is None:
|
| 618 |
-
raise ValueError("siglip global-local contrastive requires all_image_euc_feats/all_text_euc_feats")
|
| 619 |
-
zpi = F.normalize(part_image_euc_flat.float(), dim=-1)
|
| 620 |
-
zpt = F.normalize(part_text_euc_flat.float(), dim=-1)
|
| 621 |
-
zimg_all = F.normalize(all_image_for_global_local_euc.float(), dim=-1)
|
| 622 |
-
ztxt_all = F.normalize(all_text_for_global_local_euc.float(), dim=-1)
|
| 623 |
-
part_image_to_whole_text = (zpi @ ztxt_all.T) * image_temp[:, None] * global_local_scale
|
| 624 |
-
part_text_to_whole_image = (zpt @ zimg_all.T) * text_temp[:, None] * global_local_scale
|
| 625 |
-
else:
|
| 626 |
-
if global_local_metric == "angle":
|
| 627 |
-
part_image_to_whole_text = -metric_pairwise_oxy_angle(
|
| 628 |
-
part_image_flat,
|
| 629 |
-
all_text_for_global_local,
|
| 630 |
-
kappa,
|
| 631 |
-
product_metric=product_metric,
|
| 632 |
-
)
|
| 633 |
-
part_text_to_whole_image = -metric_pairwise_oxy_angle(
|
| 634 |
-
part_text_flat,
|
| 635 |
-
all_image_for_global_local,
|
| 636 |
-
kappa,
|
| 637 |
-
product_metric=product_metric,
|
| 638 |
-
)
|
| 639 |
-
else:
|
| 640 |
-
part_image_to_whole_text = -metric_pairwise_dist(
|
| 641 |
-
part_image_flat, all_text_for_global_local, kappa, product_metric=product_metric
|
| 642 |
-
)
|
| 643 |
-
part_text_to_whole_image = -metric_pairwise_dist(
|
| 644 |
-
part_text_flat, all_image_for_global_local, kappa, product_metric=product_metric
|
| 645 |
-
)
|
| 646 |
-
part_image_to_whole_text = part_image_to_whole_text * image_temp[:, None] * global_local_scale
|
| 647 |
-
part_text_to_whole_image = part_text_to_whole_image * text_temp[:, None] * global_local_scale
|
| 648 |
-
|
| 649 |
-
if loss_type in {"sigmoid", "siglip", "siglip_metric"}:
|
| 650 |
-
bias = (
|
| 651 |
-
part_image_to_whole_text.new_zeros(())
|
| 652 |
-
if global_local_logit_bias is None
|
| 653 |
-
else global_local_logit_bias.to(part_image_to_whole_text.device)
|
| 654 |
-
)
|
| 655 |
-
part_image_to_whole_text = part_image_to_whole_text + bias
|
| 656 |
-
part_text_to_whole_image = part_text_to_whole_image + bias
|
| 657 |
-
|
| 658 |
-
global_local_contrastive = 0.5 * (
|
| 659 |
-
_contrastive_loss(part_image_to_whole_text, global_local_targets, part_weights, loss_type, sigmoid_negative_weight)
|
| 660 |
-
+ _contrastive_loss(part_text_to_whole_image, global_local_targets, part_weights, loss_type, sigmoid_negative_weight)
|
| 661 |
-
)
|
| 662 |
-
|
| 663 |
-
if global_local_angle_aux_weight > 0.0:
|
| 664 |
-
if global_local_angle_aux_mode == "positive_hinge":
|
| 665 |
-
positive_text = all_text_for_global_local.index_select(0, global_local_targets)
|
| 666 |
-
positive_image = all_image_for_global_local.index_select(0, global_local_targets)
|
| 667 |
-
global_local_angle_aux = 0.5 * (
|
| 668 |
-
weighted_entailment_residual(
|
| 669 |
-
specific=part_image_flat,
|
| 670 |
-
general=positive_text,
|
| 671 |
-
kappa=kappa,
|
| 672 |
-
aperture_scale=global_local_angle_aux_aperture_scale,
|
| 673 |
-
weights=part_weights,
|
| 674 |
-
)
|
| 675 |
-
+ weighted_entailment_residual(
|
| 676 |
-
specific=part_text_flat,
|
| 677 |
-
general=positive_image,
|
| 678 |
-
kappa=kappa,
|
| 679 |
-
aperture_scale=global_local_angle_aux_aperture_scale,
|
| 680 |
-
weights=part_weights,
|
| 681 |
-
)
|
| 682 |
-
)
|
| 683 |
-
elif loss_type != "siglip":
|
| 684 |
-
angle_scale = part_image_flat.new_tensor(float(global_local_angle_aux_scale))
|
| 685 |
-
part_image_to_whole_text_angle = -metric_pairwise_oxy_angle(
|
| 686 |
-
part_image_flat,
|
| 687 |
-
all_text_for_global_local,
|
| 688 |
-
kappa,
|
| 689 |
-
product_metric=product_metric,
|
| 690 |
-
) * image_temp[:, None] * angle_scale
|
| 691 |
-
part_text_to_whole_image_angle = -metric_pairwise_oxy_angle(
|
| 692 |
-
part_text_flat,
|
| 693 |
-
all_image_for_global_local,
|
| 694 |
-
kappa,
|
| 695 |
-
product_metric=product_metric,
|
| 696 |
-
) * text_temp[:, None] * angle_scale
|
| 697 |
-
if loss_type in {"sigmoid", "siglip_metric"}:
|
| 698 |
-
bias = (
|
| 699 |
-
part_image_to_whole_text_angle.new_zeros(())
|
| 700 |
-
if global_local_logit_bias is None
|
| 701 |
-
else global_local_logit_bias.to(part_image_to_whole_text_angle.device)
|
| 702 |
-
)
|
| 703 |
-
part_image_to_whole_text_angle = part_image_to_whole_text_angle + bias
|
| 704 |
-
part_text_to_whole_image_angle = part_text_to_whole_image_angle + bias
|
| 705 |
-
global_local_angle_aux = 0.5 * (
|
| 706 |
-
_contrastive_loss(
|
| 707 |
-
part_image_to_whole_text_angle,
|
| 708 |
-
global_local_targets,
|
| 709 |
-
part_weights,
|
| 710 |
-
loss_type,
|
| 711 |
-
sigmoid_negative_weight,
|
| 712 |
-
)
|
| 713 |
-
+ _contrastive_loss(
|
| 714 |
-
part_text_to_whole_image_angle,
|
| 715 |
-
global_local_targets,
|
| 716 |
-
part_weights,
|
| 717 |
-
loss_type,
|
| 718 |
-
sigmoid_negative_weight,
|
| 719 |
-
)
|
| 720 |
-
)
|
| 721 |
-
|
| 722 |
-
beta_cal = image_feats.new_zeros(())
|
| 723 |
-
if beta_cal_weight > 0.0 and beta_cal_beta > 0.0:
|
| 724 |
-
if part_group_ids is None or all_part_group_ids is None:
|
| 725 |
-
raise ValueError("beta_cal requires part_group_ids and all_part_group_ids to be provided")
|
| 726 |
-
beta_cal = 0.5 * (
|
| 727 |
-
beta_cal_loss(
|
| 728 |
-
part_image_logits,
|
| 729 |
-
targets=part_targets,
|
| 730 |
-
group_ids=part_group_ids,
|
| 731 |
-
all_group_ids=all_part_group_ids,
|
| 732 |
-
beta=beta_cal_beta,
|
| 733 |
-
variant=beta_cal_variant,
|
| 734 |
-
weights=part_weights,
|
| 735 |
-
)
|
| 736 |
-
+ beta_cal_loss(
|
| 737 |
-
part_text_logits,
|
| 738 |
-
targets=part_targets,
|
| 739 |
-
group_ids=part_group_ids,
|
| 740 |
-
all_group_ids=all_part_group_ids,
|
| 741 |
-
beta=beta_cal_beta,
|
| 742 |
-
variant=beta_cal_variant,
|
| 743 |
-
weights=part_weights,
|
| 744 |
-
)
|
| 745 |
-
)
|
| 746 |
-
|
| 747 |
-
contrastive = (
|
| 748 |
-
contrastive_global_weight * global_contrastive
|
| 749 |
-
+ contrastive_local_weight * local_contrastive
|
| 750 |
-
+ contrastive_global_local_weight * global_local_contrastive
|
| 751 |
-
+ global_local_angle_aux_weight * global_local_angle_aux
|
| 752 |
-
+ beta_cal_weight * beta_cal
|
| 753 |
-
)
|
| 754 |
-
return {
|
| 755 |
-
"contrastive_loss": contrastive,
|
| 756 |
-
"global_contrastive_loss": global_contrastive,
|
| 757 |
-
"local_contrastive_loss": local_contrastive,
|
| 758 |
-
"global_local_contrastive_loss": global_local_contrastive,
|
| 759 |
-
"global_local_angle_aux_loss": global_local_angle_aux,
|
| 760 |
-
"beta_cal_loss": beta_cal,
|
| 761 |
-
}
|
| 762 |
-
|
| 763 |
-
|
| 764 |
-
def _contrastive_loss(
|
| 765 |
-
logits: Tensor,
|
| 766 |
-
targets: Tensor,
|
| 767 |
-
weights: Tensor | None,
|
| 768 |
-
loss_type: str,
|
| 769 |
-
sigmoid_negative_weight: float,
|
| 770 |
-
) -> Tensor:
|
| 771 |
-
if loss_type == "ce":
|
| 772 |
-
return contrastive_ce(logits, targets, weights)
|
| 773 |
-
if loss_type == "sigmoid":
|
| 774 |
-
return contrastive_sigmoid(logits, targets, weights, negative_weight=sigmoid_negative_weight)
|
| 775 |
-
if loss_type in {"siglip", "siglip_metric"}:
|
| 776 |
-
return contrastive_siglip(logits, targets, weights, negative_weight=sigmoid_negative_weight)
|
| 777 |
-
raise ValueError(f"Unsupported contrastive loss {loss_type!r}")
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
def uncha_entailment_losses(
|
| 781 |
-
image_feats: Tensor,
|
| 782 |
-
text_feats: Tensor,
|
| 783 |
-
part_image_flat: Tensor,
|
| 784 |
-
part_text_flat: Tensor,
|
| 785 |
-
image_for_parts: Tensor,
|
| 786 |
-
text_for_parts: Tensor,
|
| 787 |
-
kappa: Tensor,
|
| 788 |
-
inter_aperture_scale: float,
|
| 789 |
-
intra_aperture_scale: float,
|
| 790 |
-
piecewise_factor: float = 0.1,
|
| 791 |
-
calibration_alpha: float = 10.0,
|
| 792 |
-
stop_grad_calibration: bool = True,
|
| 793 |
-
geometry: str = "lorentz",
|
| 794 |
-
part_weights: Tensor | None = None,
|
| 795 |
-
) -> dict[str, Tensor]:
|
| 796 |
-
text_image = piecewise_entailment_residual(
|
| 797 |
-
specific=image_feats,
|
| 798 |
-
general=text_feats,
|
| 799 |
-
kappa=kappa,
|
| 800 |
-
aperture_scale=inter_aperture_scale,
|
| 801 |
-
factor=piecewise_factor,
|
| 802 |
-
geometry=geometry,
|
| 803 |
-
)
|
| 804 |
-
text_image_entailment = 0.5 * text_image.mean()
|
| 805 |
-
|
| 806 |
-
if part_image_flat.numel() == 0:
|
| 807 |
-
zero = image_feats.new_zeros(())
|
| 808 |
-
return {
|
| 809 |
-
"entailment_loss": text_image_entailment,
|
| 810 |
-
"text_image_entailment_loss": text_image_entailment,
|
| 811 |
-
"part_text_image_entailment_loss": zero,
|
| 812 |
-
"cross_image_entailment_loss": zero,
|
| 813 |
-
"cross_text_entailment_loss": zero,
|
| 814 |
-
"cross_image_calibration_loss": zero,
|
| 815 |
-
"cross_text_calibration_loss": zero,
|
| 816 |
-
}
|
| 817 |
-
|
| 818 |
-
part_text_image = piecewise_entailment_residual(
|
| 819 |
-
specific=part_image_flat,
|
| 820 |
-
general=part_text_flat,
|
| 821 |
-
kappa=kappa,
|
| 822 |
-
aperture_scale=inter_aperture_scale,
|
| 823 |
-
factor=piecewise_factor,
|
| 824 |
-
geometry=geometry,
|
| 825 |
-
)
|
| 826 |
-
cross_image = piecewise_entailment_residual(
|
| 827 |
-
specific=image_for_parts,
|
| 828 |
-
general=part_image_flat,
|
| 829 |
-
kappa=kappa,
|
| 830 |
-
aperture_scale=intra_aperture_scale,
|
| 831 |
-
factor=piecewise_factor,
|
| 832 |
-
geometry=geometry,
|
| 833 |
-
)
|
| 834 |
-
cross_text = piecewise_entailment_residual(
|
| 835 |
-
specific=text_for_parts,
|
| 836 |
-
general=part_text_flat,
|
| 837 |
-
kappa=kappa,
|
| 838 |
-
aperture_scale=intra_aperture_scale,
|
| 839 |
-
factor=piecewise_factor,
|
| 840 |
-
geometry=geometry,
|
| 841 |
-
)
|
| 842 |
-
|
| 843 |
-
part_text_image_entailment = 0.5 * weighted_mean(part_text_image, part_weights)
|
| 844 |
-
cross_image_entailment, cross_image_calibration = uncertainty_calibrated_entailment_loss(
|
| 845 |
-
cross_image,
|
| 846 |
-
embedding_uncertainty(part_image_flat),
|
| 847 |
-
alpha=calibration_alpha,
|
| 848 |
-
stop_grad=stop_grad_calibration,
|
| 849 |
-
weights=part_weights,
|
| 850 |
-
)
|
| 851 |
-
cross_text_entailment, cross_text_calibration = uncertainty_calibrated_entailment_loss(
|
| 852 |
-
cross_text,
|
| 853 |
-
embedding_uncertainty(part_text_flat),
|
| 854 |
-
alpha=calibration_alpha,
|
| 855 |
-
stop_grad=stop_grad_calibration,
|
| 856 |
-
weights=part_weights,
|
| 857 |
-
)
|
| 858 |
-
|
| 859 |
-
entailment = (
|
| 860 |
-
text_image_entailment
|
| 861 |
-
+ part_text_image_entailment
|
| 862 |
-
+ 0.5 * (cross_image_entailment + cross_text_entailment)
|
| 863 |
-
+ cross_image_calibration
|
| 864 |
-
+ cross_text_calibration
|
| 865 |
-
)
|
| 866 |
-
return {
|
| 867 |
-
"entailment_loss": entailment,
|
| 868 |
-
"text_image_entailment_loss": text_image_entailment,
|
| 869 |
-
"part_text_image_entailment_loss": part_text_image_entailment,
|
| 870 |
-
"cross_image_entailment_loss": cross_image_entailment,
|
| 871 |
-
"cross_text_entailment_loss": cross_text_entailment,
|
| 872 |
-
"cross_image_calibration_loss": cross_image_calibration,
|
| 873 |
-
"cross_text_calibration_loss": cross_text_calibration,
|
| 874 |
-
}
|
| 875 |
-
|
| 876 |
-
|
| 877 |
-
def uncha_argent_entailment_losses(
|
| 878 |
-
image_feats: Tensor,
|
| 879 |
-
text_feats: Tensor,
|
| 880 |
-
part_image_flat: Tensor,
|
| 881 |
-
part_text_flat: Tensor,
|
| 882 |
-
image_for_parts: Tensor,
|
| 883 |
-
text_for_parts: Tensor,
|
| 884 |
-
kappa: Tensor,
|
| 885 |
-
beta: float = 1.0,
|
| 886 |
-
part_weights: Tensor | None = None,
|
| 887 |
-
product_metric: str = "l1",
|
| 888 |
-
aggregation: str = "uncha",
|
| 889 |
-
) -> dict[str, Tensor]:
|
| 890 |
-
if aggregation not in {"uncha", "equal"}:
|
| 891 |
-
raise ValueError("aggregation must be 'uncha' or 'equal'")
|
| 892 |
-
text_image = argent_adaptive_entailment_residual(
|
| 893 |
-
specific=image_feats,
|
| 894 |
-
general=text_feats,
|
| 895 |
-
kappa=kappa,
|
| 896 |
-
adaptive_weight=False,
|
| 897 |
-
beta=beta,
|
| 898 |
-
product_metric=product_metric,
|
| 899 |
-
)
|
| 900 |
-
text_image_entailment = 0.5 * text_image.mean()
|
| 901 |
-
|
| 902 |
-
if part_image_flat.numel() == 0:
|
| 903 |
-
zero = image_feats.new_zeros(())
|
| 904 |
-
norm_regularization = argent_norm_regularization_loss(image_feats, text_feats)
|
| 905 |
-
return {
|
| 906 |
-
"entailment_loss": text_image_entailment,
|
| 907 |
-
"text_image_entailment_loss": text_image_entailment,
|
| 908 |
-
"part_text_image_entailment_loss": zero,
|
| 909 |
-
"cross_image_entailment_loss": zero,
|
| 910 |
-
"cross_text_entailment_loss": zero,
|
| 911 |
-
"cross_image_calibration_loss": zero,
|
| 912 |
-
"cross_text_calibration_loss": zero,
|
| 913 |
-
"norm_regularization_loss": norm_regularization,
|
| 914 |
-
}
|
| 915 |
-
|
| 916 |
-
part_text_image = argent_adaptive_entailment_residual(
|
| 917 |
-
specific=part_image_flat,
|
| 918 |
-
general=part_text_flat,
|
| 919 |
-
kappa=kappa,
|
| 920 |
-
adaptive_weight=False,
|
| 921 |
-
beta=beta,
|
| 922 |
-
product_metric=product_metric,
|
| 923 |
-
)
|
| 924 |
-
cross_image = argent_adaptive_entailment_residual(
|
| 925 |
-
specific=image_for_parts,
|
| 926 |
-
general=part_image_flat,
|
| 927 |
-
kappa=kappa,
|
| 928 |
-
adaptive_weight=True,
|
| 929 |
-
beta=beta,
|
| 930 |
-
product_metric=product_metric,
|
| 931 |
-
)
|
| 932 |
-
cross_text = argent_adaptive_entailment_residual(
|
| 933 |
-
specific=text_for_parts,
|
| 934 |
-
general=part_text_flat,
|
| 935 |
-
kappa=kappa,
|
| 936 |
-
adaptive_weight=True,
|
| 937 |
-
beta=beta,
|
| 938 |
-
product_metric=product_metric,
|
| 939 |
-
)
|
| 940 |
-
|
| 941 |
-
part_text_image_entailment = 0.5 * weighted_mean(part_text_image, part_weights)
|
| 942 |
-
cross_image_entailment = 0.5 * weighted_mean(cross_image, part_weights)
|
| 943 |
-
cross_text_entailment = 0.5 * weighted_mean(cross_text, part_weights)
|
| 944 |
-
norm_regularization = argent_norm_regularization_loss(image_feats, text_feats, part_image_flat, part_text_flat)
|
| 945 |
-
if aggregation == "equal":
|
| 946 |
-
entailment = text_image_entailment + part_text_image_entailment + cross_image_entailment + cross_text_entailment
|
| 947 |
-
else:
|
| 948 |
-
entailment = text_image_entailment + part_text_image_entailment + 0.5 * (
|
| 949 |
-
cross_image_entailment + cross_text_entailment
|
| 950 |
-
)
|
| 951 |
-
diagnostics = argent_entailment_diagnostics(
|
| 952 |
-
image_feats=image_feats,
|
| 953 |
-
text_feats=text_feats,
|
| 954 |
-
part_image_flat=part_image_flat,
|
| 955 |
-
part_text_flat=part_text_flat,
|
| 956 |
-
image_for_parts=image_for_parts,
|
| 957 |
-
text_for_parts=text_for_parts,
|
| 958 |
-
kappa=kappa,
|
| 959 |
-
product_metric=product_metric,
|
| 960 |
-
)
|
| 961 |
-
|
| 962 |
-
return {
|
| 963 |
-
"entailment_loss": entailment,
|
| 964 |
-
"text_image_entailment_loss": text_image_entailment,
|
| 965 |
-
"part_text_image_entailment_loss": part_text_image_entailment,
|
| 966 |
-
"cross_image_entailment_loss": cross_image_entailment,
|
| 967 |
-
"cross_text_entailment_loss": cross_text_entailment,
|
| 968 |
-
"cross_image_calibration_loss": image_feats.new_zeros(()),
|
| 969 |
-
"cross_text_calibration_loss": image_feats.new_zeros(()),
|
| 970 |
-
"norm_regularization_loss": norm_regularization,
|
| 971 |
-
**diagnostics,
|
| 972 |
-
}
|
| 973 |
-
|
| 974 |
-
|
| 975 |
-
def hierarchical_beta_argent_entailment_losses(
|
| 976 |
-
image_feats: Tensor,
|
| 977 |
-
text_feats: Tensor,
|
| 978 |
-
part_image_flat: Tensor,
|
| 979 |
-
part_text_flat: Tensor,
|
| 980 |
-
image_for_parts: Tensor,
|
| 981 |
-
text_for_parts: Tensor,
|
| 982 |
-
beta_query_image_feats: Tensor,
|
| 983 |
-
beta_query_text_feats: Tensor,
|
| 984 |
-
beta_query_owner: Tensor,
|
| 985 |
-
beta_query_parent: Tensor,
|
| 986 |
-
beta_query_weight: Tensor,
|
| 987 |
-
kappa: Tensor,
|
| 988 |
-
beta_query_source_part: Tensor | None = None,
|
| 989 |
-
beta: float = 1.0,
|
| 990 |
-
part_weights: Tensor | None = None,
|
| 991 |
-
product_metric: str = "l1",
|
| 992 |
-
aggregation: str = "uncha",
|
| 993 |
-
) -> dict[str, Tensor]:
|
| 994 |
-
base = uncha_argent_entailment_losses(
|
| 995 |
-
image_feats=image_feats,
|
| 996 |
-
text_feats=text_feats,
|
| 997 |
-
part_image_flat=part_image_flat,
|
| 998 |
-
part_text_flat=part_text_flat,
|
| 999 |
-
image_for_parts=image_for_parts,
|
| 1000 |
-
text_for_parts=text_for_parts,
|
| 1001 |
-
kappa=kappa,
|
| 1002 |
-
beta=beta,
|
| 1003 |
-
part_weights=part_weights,
|
| 1004 |
-
product_metric=product_metric,
|
| 1005 |
-
aggregation=aggregation,
|
| 1006 |
-
)
|
| 1007 |
-
if beta_query_image_feats.numel() == 0:
|
| 1008 |
-
return {
|
| 1009 |
-
**base,
|
| 1010 |
-
"hier_beta_query_text_entailment_loss": image_feats.new_zeros(()),
|
| 1011 |
-
"hier_beta_visual_entailment_loss": image_feats.new_zeros(()),
|
| 1012 |
-
"hier_beta_text_entailment_loss": image_feats.new_zeros(()),
|
| 1013 |
-
"hier_beta_sourcepart_visual_entailment_loss": image_feats.new_zeros(()),
|
| 1014 |
-
"hier_beta_sourcepart_text_entailment_loss": image_feats.new_zeros(()),
|
| 1015 |
-
"hier_beta_query_count": beta_query_owner.new_tensor(0),
|
| 1016 |
-
"hier_beta_sourcepart_query_count": beta_query_owner.new_tensor(0),
|
| 1017 |
-
}
|
| 1018 |
-
|
| 1019 |
-
query_owner = beta_query_owner.to(device=image_feats.device, dtype=torch.long)
|
| 1020 |
-
query_weights = beta_query_weight.to(device=image_feats.device, dtype=torch.float32).clamp_min(0.0)
|
| 1021 |
-
if query_weights.numel() != beta_query_image_feats.size(0):
|
| 1022 |
-
raise ValueError("beta_query_weight must have one value per beta query")
|
| 1023 |
-
query_weights = query_weights / query_weights.mean().clamp_min(torch.finfo(query_weights.dtype).eps)
|
| 1024 |
-
|
| 1025 |
-
query_text = argent_adaptive_entailment_residual(
|
| 1026 |
-
specific=beta_query_image_feats,
|
| 1027 |
-
general=beta_query_text_feats,
|
| 1028 |
-
kappa=kappa,
|
| 1029 |
-
adaptive_weight=False,
|
| 1030 |
-
beta=beta,
|
| 1031 |
-
product_metric=product_metric,
|
| 1032 |
-
)
|
| 1033 |
-
visual_hierarchy = argent_adaptive_entailment_residual(
|
| 1034 |
-
specific=image_feats.index_select(0, query_owner),
|
| 1035 |
-
general=beta_query_image_feats,
|
| 1036 |
-
kappa=kappa,
|
| 1037 |
-
adaptive_weight=True,
|
| 1038 |
-
beta=beta,
|
| 1039 |
-
product_metric=product_metric,
|
| 1040 |
-
)
|
| 1041 |
-
query_text_entailment = 0.5 * weighted_mean(query_text, query_weights)
|
| 1042 |
-
visual_entailment = 0.5 * weighted_mean(visual_hierarchy, query_weights)
|
| 1043 |
-
|
| 1044 |
-
parent = beta_query_parent.to(device=image_feats.device, dtype=torch.long)
|
| 1045 |
-
parent_mask = (parent >= 0) & (parent < beta_query_text_feats.size(0)) & (query_weights > 0.0)
|
| 1046 |
-
if bool(parent_mask.any()):
|
| 1047 |
-
child_text = beta_query_text_feats[parent_mask]
|
| 1048 |
-
parent_text = beta_query_text_feats[parent[parent_mask]]
|
| 1049 |
-
text_hierarchy = argent_adaptive_entailment_residual(
|
| 1050 |
-
specific=parent_text,
|
| 1051 |
-
general=child_text,
|
| 1052 |
-
kappa=kappa,
|
| 1053 |
-
adaptive_weight=True,
|
| 1054 |
-
beta=beta,
|
| 1055 |
-
product_metric=product_metric,
|
| 1056 |
-
)
|
| 1057 |
-
text_entailment = 0.5 * weighted_mean(text_hierarchy, query_weights[parent_mask])
|
| 1058 |
-
else:
|
| 1059 |
-
text_entailment = image_feats.new_zeros(())
|
| 1060 |
-
|
| 1061 |
-
sourcepart_visual_entailment = image_feats.new_zeros(())
|
| 1062 |
-
sourcepart_text_entailment = image_feats.new_zeros(())
|
| 1063 |
-
sourcepart_query_count = beta_query_owner.new_tensor(0)
|
| 1064 |
-
if beta_query_source_part is not None and part_image_flat.numel() > 0:
|
| 1065 |
-
source_part = beta_query_source_part.to(device=image_feats.device, dtype=torch.long)
|
| 1066 |
-
if source_part.numel() != beta_query_image_feats.size(0):
|
| 1067 |
-
raise ValueError("beta_query_source_part must have one value per beta query")
|
| 1068 |
-
source_mask = (
|
| 1069 |
-
(source_part >= 0)
|
| 1070 |
-
& (source_part < part_image_flat.size(0))
|
| 1071 |
-
& (query_weights > 0.0)
|
| 1072 |
-
)
|
| 1073 |
-
if bool(source_mask.any()):
|
| 1074 |
-
source_indices = source_part[source_mask]
|
| 1075 |
-
sourcepart_visual = argent_adaptive_entailment_residual(
|
| 1076 |
-
specific=part_image_flat.index_select(0, source_indices),
|
| 1077 |
-
general=beta_query_image_feats[source_mask],
|
| 1078 |
-
kappa=kappa,
|
| 1079 |
-
adaptive_weight=True,
|
| 1080 |
-
beta=beta,
|
| 1081 |
-
product_metric=product_metric,
|
| 1082 |
-
)
|
| 1083 |
-
sourcepart_text = argent_adaptive_entailment_residual(
|
| 1084 |
-
specific=part_text_flat.index_select(0, source_indices),
|
| 1085 |
-
general=beta_query_text_feats[source_mask],
|
| 1086 |
-
kappa=kappa,
|
| 1087 |
-
adaptive_weight=True,
|
| 1088 |
-
beta=beta,
|
| 1089 |
-
product_metric=product_metric,
|
| 1090 |
-
)
|
| 1091 |
-
source_weights = query_weights[source_mask]
|
| 1092 |
-
sourcepart_visual_entailment = 0.5 * weighted_mean(sourcepart_visual, source_weights)
|
| 1093 |
-
sourcepart_text_entailment = 0.5 * weighted_mean(sourcepart_text, source_weights)
|
| 1094 |
-
sourcepart_query_count = beta_query_owner.new_tensor(int(source_mask.sum().item()))
|
| 1095 |
-
|
| 1096 |
-
norm_regularization = argent_norm_regularization_loss(
|
| 1097 |
-
image_feats,
|
| 1098 |
-
text_feats,
|
| 1099 |
-
part_image_flat,
|
| 1100 |
-
part_text_flat,
|
| 1101 |
-
beta_query_image_feats,
|
| 1102 |
-
beta_query_text_feats,
|
| 1103 |
-
)
|
| 1104 |
-
sourcepart_entailment = 0.5 * (sourcepart_visual_entailment + sourcepart_text_entailment)
|
| 1105 |
-
query_entailment = query_text_entailment + 0.5 * (visual_entailment + text_entailment) + sourcepart_entailment
|
| 1106 |
-
return {
|
| 1107 |
-
**base,
|
| 1108 |
-
"entailment_loss": base["entailment_loss"] + query_entailment,
|
| 1109 |
-
"norm_regularization_loss": norm_regularization,
|
| 1110 |
-
"hier_beta_query_text_entailment_loss": query_text_entailment,
|
| 1111 |
-
"hier_beta_visual_entailment_loss": visual_entailment,
|
| 1112 |
-
"hier_beta_text_entailment_loss": text_entailment,
|
| 1113 |
-
"hier_beta_sourcepart_visual_entailment_loss": sourcepart_visual_entailment,
|
| 1114 |
-
"hier_beta_sourcepart_text_entailment_loss": sourcepart_text_entailment,
|
| 1115 |
-
"hier_beta_query_count": beta_query_owner.new_tensor(beta_query_owner.numel()),
|
| 1116 |
-
"hier_beta_sourcepart_query_count": sourcepart_query_count,
|
| 1117 |
-
}
|
| 1118 |
-
|
| 1119 |
-
|
| 1120 |
-
def argent_entailment_diagnostics(
|
| 1121 |
-
image_feats: Tensor,
|
| 1122 |
-
text_feats: Tensor,
|
| 1123 |
-
part_image_flat: Tensor,
|
| 1124 |
-
part_text_flat: Tensor,
|
| 1125 |
-
image_for_parts: Tensor,
|
| 1126 |
-
text_for_parts: Tensor,
|
| 1127 |
-
kappa: Tensor,
|
| 1128 |
-
product_metric: str = "l1",
|
| 1129 |
-
) -> dict[str, Tensor]:
|
| 1130 |
-
zero = image_feats.new_zeros(())
|
| 1131 |
-
|
| 1132 |
-
def angle_mean(specific: Tensor, general: Tensor) -> Tensor:
|
| 1133 |
-
if specific.numel() == 0:
|
| 1134 |
-
return zero
|
| 1135 |
-
angles = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
|
| 1136 |
-
if angles.dim() == 2:
|
| 1137 |
-
angles = angles.mean(dim=-1)
|
| 1138 |
-
return angles.detach().mean()
|
| 1139 |
-
|
| 1140 |
-
def pent_mean(specific: Tensor, general: Tensor) -> Tensor:
|
| 1141 |
-
if specific.numel() == 0:
|
| 1142 |
-
return zero
|
| 1143 |
-
angles = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
|
| 1144 |
-
if angles.dim() == 2:
|
| 1145 |
-
angles = angles.mean(dim=-1)
|
| 1146 |
-
scores = torch.clamp(1.0 - (2.0 * angles / math.pi), min=0.0, max=1.0)
|
| 1147 |
-
return scores.detach().mean()
|
| 1148 |
-
|
| 1149 |
-
def distance_mean(specific: Tensor, general: Tensor) -> Tensor:
|
| 1150 |
-
if specific.numel() == 0:
|
| 1151 |
-
return zero
|
| 1152 |
-
return lorentz_dist(specific, general, kappa, product_metric=product_metric).detach().mean()
|
| 1153 |
-
|
| 1154 |
-
def adaptive_weight_mean(specific: Tensor, general: Tensor) -> Tensor:
|
| 1155 |
-
if specific.numel() == 0:
|
| 1156 |
-
return zero
|
| 1157 |
-
weights = 1.0 - torch.exp(-lorentz_dist(specific, general, kappa, product_metric=product_metric))
|
| 1158 |
-
return weights.detach().mean()
|
| 1159 |
-
|
| 1160 |
-
def space_norm_mean(embedding: Tensor) -> Tensor:
|
| 1161 |
-
if embedding.numel() == 0:
|
| 1162 |
-
return zero
|
| 1163 |
-
return torch.linalg.norm(_space_components(embedding).float(), dim=-1).detach().mean()
|
| 1164 |
-
|
| 1165 |
-
return {
|
| 1166 |
-
"argent_text_image_angle_mean": angle_mean(image_feats, text_feats),
|
| 1167 |
-
"argent_text_image_pent_mean": pent_mean(image_feats, text_feats),
|
| 1168 |
-
"argent_part_text_image_angle_mean": angle_mean(part_image_flat, part_text_flat),
|
| 1169 |
-
"argent_part_text_image_pent_mean": pent_mean(part_image_flat, part_text_flat),
|
| 1170 |
-
"argent_cross_image_angle_mean": angle_mean(image_for_parts, part_image_flat),
|
| 1171 |
-
"argent_cross_image_pent_mean": pent_mean(image_for_parts, part_image_flat),
|
| 1172 |
-
"argent_cross_image_distance_mean": distance_mean(image_for_parts, part_image_flat),
|
| 1173 |
-
"argent_cross_image_adaptive_weight_mean": adaptive_weight_mean(image_for_parts, part_image_flat),
|
| 1174 |
-
"argent_cross_text_angle_mean": angle_mean(text_for_parts, part_text_flat),
|
| 1175 |
-
"argent_cross_text_pent_mean": pent_mean(text_for_parts, part_text_flat),
|
| 1176 |
-
"argent_cross_text_distance_mean": distance_mean(text_for_parts, part_text_flat),
|
| 1177 |
-
"argent_cross_text_adaptive_weight_mean": adaptive_weight_mean(text_for_parts, part_text_flat),
|
| 1178 |
-
"argent_image_space_norm_mean": space_norm_mean(image_feats),
|
| 1179 |
-
"argent_text_space_norm_mean": space_norm_mean(text_feats),
|
| 1180 |
-
"argent_part_image_space_norm_mean": space_norm_mean(part_image_flat),
|
| 1181 |
-
"argent_part_text_space_norm_mean": space_norm_mean(part_text_flat),
|
| 1182 |
-
}
|
| 1183 |
-
|
| 1184 |
-
|
| 1185 |
-
def part_quality_weights(
|
| 1186 |
-
image_for_parts: Tensor,
|
| 1187 |
-
text_for_parts: Tensor,
|
| 1188 |
-
part_image_flat: Tensor,
|
| 1189 |
-
part_text_flat: Tensor,
|
| 1190 |
-
part_owner: Tensor,
|
| 1191 |
-
batch_size: int,
|
| 1192 |
-
kappa: Tensor,
|
| 1193 |
-
mode: str,
|
| 1194 |
-
topk: int = 5,
|
| 1195 |
-
temperature: float = 4.0,
|
| 1196 |
-
product_metric: str = "l1",
|
| 1197 |
-
) -> tuple[Tensor | None, Tensor, Tensor]:
|
| 1198 |
-
if mode not in {"none", "soft", "topk"}:
|
| 1199 |
-
raise ValueError(f"Unsupported part quality mode {mode!r}; expected 'none', 'soft', or 'topk'")
|
| 1200 |
-
if mode == "none" or part_image_flat.numel() == 0:
|
| 1201 |
-
empty = part_image_flat.new_zeros((part_image_flat.size(0),))
|
| 1202 |
-
return None, empty, empty
|
| 1203 |
-
|
| 1204 |
-
with torch.no_grad():
|
| 1205 |
-
image_parent = torch.exp(-lorentz_dist(part_image_flat, image_for_parts, kappa, product_metric=product_metric))
|
| 1206 |
-
text_parent = torch.exp(-lorentz_dist(part_text_flat, text_for_parts, kappa, product_metric=product_metric))
|
| 1207 |
-
image_text = torch.exp(-lorentz_dist(part_image_flat, part_text_flat, kappa, product_metric=product_metric))
|
| 1208 |
-
scores = torch.stack([image_parent, text_parent, image_text]).mean(dim=0).clamp_min(0.0)
|
| 1209 |
-
|
| 1210 |
-
if mode == "soft":
|
| 1211 |
-
weights = _owner_softmax_weights(scores, part_owner, batch_size, temperature)
|
| 1212 |
-
else:
|
| 1213 |
-
weights = _owner_topk_weights(scores, part_owner, batch_size, topk)
|
| 1214 |
-
weights = weights / weights.mean().clamp_min(torch.finfo(weights.dtype).eps)
|
| 1215 |
-
return weights, scores, (weights > 0.0).to(dtype=scores.dtype)
|
| 1216 |
-
|
| 1217 |
-
|
| 1218 |
-
def _owner_softmax_weights(scores: Tensor, part_owner: Tensor, batch_size: int, temperature: float) -> Tensor:
|
| 1219 |
-
weights = torch.zeros_like(scores)
|
| 1220 |
-
for owner in range(batch_size):
|
| 1221 |
-
mask = part_owner == owner
|
| 1222 |
-
if not bool(mask.any()):
|
| 1223 |
-
continue
|
| 1224 |
-
owner_scores = scores[mask]
|
| 1225 |
-
owner_weights = torch.softmax(owner_scores * temperature, dim=0) * owner_scores.numel()
|
| 1226 |
-
weights[mask] = owner_weights
|
| 1227 |
-
return weights
|
| 1228 |
-
|
| 1229 |
-
|
| 1230 |
-
def _owner_topk_weights(scores: Tensor, part_owner: Tensor, batch_size: int, topk: int) -> Tensor:
|
| 1231 |
-
if topk <= 0:
|
| 1232 |
-
raise ValueError("topk must be positive for top-k part quality weighting")
|
| 1233 |
-
weights = torch.zeros_like(scores)
|
| 1234 |
-
for owner in range(batch_size):
|
| 1235 |
-
indices = torch.nonzero(part_owner == owner, as_tuple=False).flatten()
|
| 1236 |
-
if indices.numel() == 0:
|
| 1237 |
-
continue
|
| 1238 |
-
keep = min(topk, indices.numel())
|
| 1239 |
-
selected = indices[scores[indices].topk(k=keep).indices]
|
| 1240 |
-
weights[selected] = 1.0
|
| 1241 |
-
return weights
|
| 1242 |
-
|
| 1243 |
-
|
| 1244 |
-
def argent_adaptive_entailment_residual(
|
| 1245 |
-
specific: Tensor,
|
| 1246 |
-
general: Tensor,
|
| 1247 |
-
kappa: Tensor,
|
| 1248 |
-
adaptive_weight: bool,
|
| 1249 |
-
beta: float = 1.0,
|
| 1250 |
-
product_metric: str = "l1",
|
| 1251 |
-
) -> Tensor:
|
| 1252 |
-
angles = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
|
| 1253 |
-
if angles.dim() == 2:
|
| 1254 |
-
angles = angles.mean(dim=-1)
|
| 1255 |
-
if adaptive_weight:
|
| 1256 |
-
weights = 1.0 - torch.exp(
|
| 1257 |
-
-lorentz_dist(specific=specific, general=general, kappa=kappa, product_metric=product_metric)
|
| 1258 |
-
)
|
| 1259 |
-
angles = angles * weights
|
| 1260 |
-
return F.huber_loss(angles, torch.zeros_like(angles), delta=beta, reduction="none")
|
| 1261 |
-
|
| 1262 |
-
|
| 1263 |
-
def lorentz_dist(specific: Tensor, general: Tensor, kappa: Tensor, product_metric: str = "l1") -> Tensor:
|
| 1264 |
-
return paired_dist(specific, general, kappa, product_metric=product_metric)
|
| 1265 |
-
|
| 1266 |
-
|
| 1267 |
-
def argent_norm_regularization_loss(*embeddings: Tensor, eps: float = 1e-6) -> Tensor:
|
| 1268 |
-
losses = []
|
| 1269 |
-
for embedding in embeddings:
|
| 1270 |
-
if embedding.numel() == 0:
|
| 1271 |
-
continue
|
| 1272 |
-
space = _space_components(embedding)
|
| 1273 |
-
space_norm = torch.linalg.norm(space.float(), dim=-1).clamp_min(eps)
|
| 1274 |
-
losses.append((space_norm.square() - torch.log(space_norm)).mean())
|
| 1275 |
-
if not losses:
|
| 1276 |
-
raise ValueError("argent_norm_regularization_loss requires at least one non-empty embedding tensor")
|
| 1277 |
-
return torch.stack(losses).mean()
|
| 1278 |
-
|
| 1279 |
-
|
| 1280 |
-
def piecewise_entailment_residual(
|
| 1281 |
-
specific: Tensor,
|
| 1282 |
-
general: Tensor,
|
| 1283 |
-
kappa: Tensor,
|
| 1284 |
-
aperture_scale: float,
|
| 1285 |
-
factor: float = 0.1,
|
| 1286 |
-
geometry: str = "lorentz",
|
| 1287 |
-
) -> Tensor:
|
| 1288 |
-
if geometry == "lorentz":
|
| 1289 |
-
angles = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
|
| 1290 |
-
apertures = factor_half_aperture(general=general, kappa=kappa)
|
| 1291 |
-
elif geometry == "euclidean":
|
| 1292 |
-
angles = euclidean_angle(specific=specific, general=general)
|
| 1293 |
-
apertures = euclidean_half_aperture(general=general, aperture_scale=aperture_scale)
|
| 1294 |
-
aperture_scale = 1.0
|
| 1295 |
-
else:
|
| 1296 |
-
raise ValueError(f"Unsupported entailment geometry {geometry!r}; expected 'lorentz' or 'euclidean'")
|
| 1297 |
-
residual = angles - aperture_scale * apertures
|
| 1298 |
-
loss = torch.where(residual > 0.0, residual + factor * angles, factor * angles)
|
| 1299 |
-
return loss.mean(dim=-1) if loss.dim() == 2 else loss
|
| 1300 |
-
|
| 1301 |
-
|
| 1302 |
-
def euclidean_angle(specific: Tensor, general: Tensor, eps: float = 1e-6) -> Tensor:
|
| 1303 |
-
specific_space = _space_components(specific).float()
|
| 1304 |
-
general_space = _space_components(general).float()
|
| 1305 |
-
numerator = (specific_space * general_space).sum(dim=-1)
|
| 1306 |
-
denominator = torch.linalg.norm(specific_space, dim=-1) * torch.linalg.norm(general_space, dim=-1)
|
| 1307 |
-
dtype_eps = torch.finfo(specific_space.dtype).eps
|
| 1308 |
-
angle_eps = max(eps, 16.0 * dtype_eps)
|
| 1309 |
-
cosine = (numerator / denominator.clamp_min(angle_eps)).clamp(min=-1.0 + angle_eps, max=1.0 - angle_eps)
|
| 1310 |
-
return torch.acos(cosine)
|
| 1311 |
-
|
| 1312 |
-
|
| 1313 |
-
def euclidean_half_aperture(general: Tensor, aperture_scale: float, eps: float = 1e-8) -> Tensor:
|
| 1314 |
-
general_norm = torch.linalg.norm(_space_components(general).float(), dim=-1).clamp_min(eps)
|
| 1315 |
-
return torch.atan(torch.as_tensor(aperture_scale, device=general.device, dtype=general.dtype) / general_norm)
|
| 1316 |
-
|
| 1317 |
-
|
| 1318 |
-
def aggregate_part_consistency_loss(
|
| 1319 |
-
image_feats: Tensor,
|
| 1320 |
-
text_feats: Tensor,
|
| 1321 |
-
part_image_flat: Tensor,
|
| 1322 |
-
part_text_flat: Tensor,
|
| 1323 |
-
part_owner: Tensor,
|
| 1324 |
-
part_weights: Tensor | None = None,
|
| 1325 |
-
) -> Tensor:
|
| 1326 |
-
if part_image_flat.numel() == 0:
|
| 1327 |
-
return image_feats.new_zeros(())
|
| 1328 |
-
|
| 1329 |
-
batch_size = image_feats.size(0)
|
| 1330 |
-
image_space = _space_components(image_feats).reshape(batch_size, -1).float()
|
| 1331 |
-
text_space = _space_components(text_feats).reshape(batch_size, -1).float()
|
| 1332 |
-
part_image_space = _space_components(part_image_flat).reshape(part_image_flat.size(0), -1).float()
|
| 1333 |
-
part_text_space = _space_components(part_text_flat).reshape(part_text_flat.size(0), -1).float()
|
| 1334 |
-
if part_weights is None:
|
| 1335 |
-
counts = torch.bincount(part_owner, minlength=batch_size).to(device=image_feats.device, dtype=image_space.dtype)
|
| 1336 |
-
denom = counts
|
| 1337 |
-
valid = counts > 0
|
| 1338 |
-
weights = part_image_space.new_ones((part_image_space.size(0),))
|
| 1339 |
-
else:
|
| 1340 |
-
weights = part_weights.to(device=image_feats.device, dtype=image_space.dtype).flatten()
|
| 1341 |
-
if weights.numel() != part_owner.numel():
|
| 1342 |
-
raise ValueError("part_weights must have the same number of elements as part_owner when provided")
|
| 1343 |
-
denom = torch.zeros(batch_size, device=image_feats.device, dtype=image_space.dtype)
|
| 1344 |
-
denom.index_add_(0, part_owner, weights)
|
| 1345 |
-
valid = denom > 0
|
| 1346 |
-
|
| 1347 |
-
image_agg = image_space.new_zeros(image_space.shape)
|
| 1348 |
-
text_agg = text_space.new_zeros(text_space.shape)
|
| 1349 |
-
image_agg.index_add_(0, part_owner, part_image_space * weights[:, None])
|
| 1350 |
-
text_agg.index_add_(0, part_owner, part_text_space * weights[:, None])
|
| 1351 |
-
image_agg = image_agg[valid] / denom[valid, None].clamp_min(1.0)
|
| 1352 |
-
text_agg = text_agg[valid] / denom[valid, None].clamp_min(1.0)
|
| 1353 |
-
|
| 1354 |
-
image_space = image_space[valid]
|
| 1355 |
-
text_space = text_space[valid]
|
| 1356 |
-
return 0.25 * (
|
| 1357 |
-
cosine_residual(image_agg, image_space)
|
| 1358 |
-
+ cosine_residual(text_agg, text_space)
|
| 1359 |
-
+ cosine_residual(image_agg, text_space)
|
| 1360 |
-
+ cosine_residual(text_agg, image_space)
|
| 1361 |
-
)
|
| 1362 |
-
|
| 1363 |
-
|
| 1364 |
-
def cosine_residual(x: Tensor, y: Tensor) -> Tensor:
|
| 1365 |
-
return (1.0 - F.cosine_similarity(x, y, dim=-1)).mean()
|
| 1366 |
-
|
| 1367 |
-
|
| 1368 |
-
def uncertainty_calibrated_entailment_loss(
|
| 1369 |
-
entail_residual: Tensor,
|
| 1370 |
-
log_uncertainty: Tensor,
|
| 1371 |
-
alpha: float = 10.0,
|
| 1372 |
-
stop_grad: bool = True,
|
| 1373 |
-
weights: Tensor | None = None,
|
| 1374 |
-
) -> tuple[Tensor, Tensor]:
|
| 1375 |
-
mean_loss = 0.5 * entail_residual
|
| 1376 |
-
uncertainty = torch.exp(log_uncertainty).clamp(min=1e-6, max=1e6)
|
| 1377 |
-
residual = entail_residual.detach() if stop_grad else entail_residual
|
| 1378 |
-
scaled_entail = residual / (uncertainty + 1e-6)
|
| 1379 |
-
calibration_term = 0.5 * scaled_entail + 0.5 * log_uncertainty
|
| 1380 |
-
prob = torch.softmax(log_uncertainty.flatten(), dim=0)
|
| 1381 |
-
entropy = -(prob * torch.log(prob + 1e-8)).sum()
|
| 1382 |
-
calibration_loss = alpha * (calibration_term + entropy)
|
| 1383 |
-
return weighted_mean(mean_loss, weights), weighted_mean(calibration_loss, weights)
|
| 1384 |
-
|
| 1385 |
-
|
| 1386 |
-
def embedding_uncertainty(x: Tensor) -> Tensor:
|
| 1387 |
-
space = _space_components(x)
|
| 1388 |
-
norm = torch.linalg.norm(space.float(), dim=-1)
|
| 1389 |
-
if norm.dim() > 1:
|
| 1390 |
-
norm = norm.mean(dim=-1)
|
| 1391 |
-
return F.softplus(-norm)
|
| 1392 |
-
|
| 1393 |
-
|
| 1394 |
-
def _space_components(x: Tensor) -> Tensor:
|
| 1395 |
-
return x[..., 1:] if x.shape[-1] > 1 else x
|
| 1396 |
-
|
| 1397 |
-
|
| 1398 |
-
def _flatten_valid_parts(part_image_feats: Tensor, part_text_feats: Tensor, part_mask: Tensor, targets: Tensor) -> tuple[Tensor, Tensor, Tensor]:
|
| 1399 |
-
part_targets = targets[:, None].expand_as(part_mask)[part_mask]
|
| 1400 |
-
return part_image_feats[part_mask], part_text_feats[part_mask], part_targets
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/models/objectives.py
DELETED
|
@@ -1,580 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from collections.abc import Mapping
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
from torch import Tensor, nn
|
| 7 |
-
|
| 8 |
-
from hyper3_clip.models.lorentz import log_map0, metric_pairwise_dist
|
| 9 |
-
from hyper3_clip.models.losses import (
|
| 10 |
-
aggregate_part_consistency_loss,
|
| 11 |
-
contrastive_ce,
|
| 12 |
-
gramian_volume_loss,
|
| 13 |
-
hierarchical_beta_argent_entailment_losses,
|
| 14 |
-
packed_part_contrastive_loss,
|
| 15 |
-
packed_part_entailment_loss,
|
| 16 |
-
part_quality_weights,
|
| 17 |
-
radius_order_hinge,
|
| 18 |
-
uncha_argent_entailment_losses,
|
| 19 |
-
uncha_contrastive_losses,
|
| 20 |
-
uncha_entailment_losses,
|
| 21 |
-
)
|
| 22 |
-
from hyper3_clip.training.distributed import gather_variable_many_with_grad, gather_variable_no_grad, get_rank
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
class HyCoCLIPObjective(nn.Module):
    """Contrastive plus entailment objective over packed part embeddings.

    ``forward`` returns ``contrastive + entail_weight * entailment`` together
    with the two individual terms and the number of packed parts.
    """

    def __init__(
        self,
        entail_weight: float,
        inter_aperture_scale: float,
        intra_aperture_scale: float,
        product_metric: str = "l1",
    ) -> None:
        """Store the loss weight and the aperture scales.

        Args:
            entail_weight: Multiplier on the entailment term in the total loss.
            inter_aperture_scale: Passed to ``packed_part_entailment_loss``.
            intra_aperture_scale: Passed to ``packed_part_entailment_loss``.
            product_metric: Stored on the module.
                NOTE(review): not forwarded to either loss call in ``forward``
                -- confirm whether that is intentional.
        """
        super().__init__()
        self.product_metric = product_metric
        self.intra_aperture_scale = intra_aperture_scale
        self.inter_aperture_scale = inter_aperture_scale
        self.entail_weight = entail_weight

    def forward(self, embeddings: Mapping[str, Tensor], logit_scale: Tensor) -> dict[str, Tensor]:
        """Compute the combined loss.

        Args:
            embeddings: Must contain ``part_owner``, ``image_feats``,
                ``text_feats``, ``part_image_feats``, ``part_text_feats`` and
                ``kappa``. ``all_image_feats``, ``all_text_feats`` and
                ``targets`` are optional and default to ``None``.
            logit_scale: Forwarded to the contrastive loss.

        Returns:
            Dict with ``loss``, ``contrastive_loss``, ``entailment_loss`` and
            ``part_count``.
        """
        owner = embeddings["part_owner"].long()

        # Arguments common to both loss functions.
        shared_inputs = dict(
            image_feats=embeddings["image_feats"],
            text_feats=embeddings["text_feats"],
            part_image_feats=embeddings["part_image_feats"],
            part_text_feats=embeddings["part_text_feats"],
            part_owner=owner,
            kappa=embeddings["kappa"],
        )

        contrastive_term = packed_part_contrastive_loss(
            **shared_inputs,
            logit_scale=logit_scale,
            all_image_feats=embeddings.get("all_image_feats"),
            all_text_feats=embeddings.get("all_text_feats"),
            targets=embeddings.get("targets"),
        )
        entailment_term = packed_part_entailment_loss(
            **shared_inputs,
            inter_aperture_scale=self.inter_aperture_scale,
            intra_aperture_scale=self.intra_aperture_scale,
        )

        outputs: dict[str, Tensor] = {}
        outputs["loss"] = contrastive_term + self.entail_weight * entailment_term
        outputs["contrastive_loss"] = contrastive_term
        outputs["entailment_loss"] = entailment_term
        # Number of packed parts, as a tensor with part_owner's dtype and device.
        outputs["part_count"] = owner.new_tensor(owner.numel())
        return outputs
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
class UNCHAObjective(nn.Module):
    """Composite training objective over global and part ("local") embeddings.

    ``forward`` sums a contrastive term, an entailment term (whose flavour is
    chosen by ``entailment_loss``) and several optional auxiliary terms, each
    scaled by a weight configured at construction time.
    """

    def __init__(
        self,
        entail_weight: float,
        inter_aperture_scale: float,
        intra_aperture_scale: float,
        piecewise_factor: float = 0.1,
        calibration_alpha: float = 10.0,
        stop_grad_calibration: bool = True,
        entailment_geometry: str = "lorentz",
        aggregate_weight: float = 0.0,
        entailment_loss: str = "piecewise",
        argent_beta: float = 1.0,
        argent_norm_weight: float = 0.0,
        argent_aux_weight: float = 0.5,
        argent_aggregation: str = "uncha",
        part_weight_power: float = 0.0,
        product_metric: str = "l1",
        contrastive_loss: str = "ce",
        sigmoid_negative_weight: float = 1.0,
        part_quality_mode: str = "none",
        part_quality_topk: int = 5,
        part_quality_temperature: float = 4.0,
        contrastive_global_weight: float = 1.0,
        contrastive_local_weight: float = 1.0,
        contrastive_global_local_weight: float = 1.0,
        beta_cal_beta: float = 0.0,
        beta_cal_variant: str = "ce",
        beta_cal_weight: float = 0.0,
        himo_component_weight: float = 0.0,
        global_local_mode: str = "repeat",
        global_local_metric: str = "distance",
        global_local_angle_aux_weight: float = 0.0,
        global_local_angle_aux_mode: str = "contrastive",
        global_local_angle_aux_scale: float = 5.5,
        global_local_angle_aux_aperture_scale: float = 1.0,
        radius_order_weight: float = 0.0,
        radius_order_margin: float = 0.0,
        gramian_align_weight: float = 0.0,
    ) -> None:
        """Validate the configuration and store it on the module.

        String options are checked against fixed sets of allowed values and a
        few numeric options against sign constraints; all other arguments are
        stored as given (several are coerced with ``float``).

        Raises:
            ValueError: If ``entailment_loss``, ``contrastive_loss``,
                ``beta_cal_variant``, ``argent_aggregation``,
                ``part_quality_mode``, ``global_local_mode``,
                ``global_local_metric`` or ``global_local_angle_aux_mode`` is
                not an allowed value; if ``global_local_angle_aux_weight`` is
                negative; or if ``global_local_angle_aux_scale``,
                ``global_local_angle_aux_aperture_scale`` or
                ``part_quality_topk`` is not positive.
        """
        super().__init__()
        # --- option validation -------------------------------------------------
        if entailment_loss not in {
            "piecewise",
            "argent",
            "piecewise_argent",
            "hier_beta_argent",
            "hier_beta_sourcepart_argent",
        }:
            raise ValueError(
                f"Unsupported UNCHA entailment loss {entailment_loss!r}; "
                "expected 'piecewise', 'argent', 'piecewise_argent', 'hier_beta_argent', "
                "or 'hier_beta_sourcepart_argent'"
            )
        if contrastive_loss not in {"ce", "sigmoid", "siglip", "siglip_metric"}:
            raise ValueError("contrastive_loss must be 'ce', 'sigmoid', 'siglip', or 'siglip_metric'")
        if beta_cal_variant not in {"ce", "bce"}:
            raise ValueError("beta_cal_variant must be 'ce' or 'bce'")
        if argent_aggregation not in {"uncha", "equal"}:
            raise ValueError("argent_aggregation must be 'uncha' or 'equal'")
        if part_quality_mode not in {"none", "soft", "topk"}:
            raise ValueError("part_quality_mode must be 'none', 'soft', or 'topk'")
        if global_local_mode not in {"repeat", "inbatch"}:
            raise ValueError("global_local_mode must be 'repeat' or 'inbatch'")
        if global_local_metric not in {"distance", "angle"}:
            raise ValueError("global_local_metric must be 'distance' or 'angle'")
        if global_local_angle_aux_mode not in {"contrastive", "positive_hinge"}:
            raise ValueError("global_local_angle_aux_mode must be 'contrastive' or 'positive_hinge'")
        if global_local_angle_aux_weight < 0.0:
            raise ValueError("global_local_angle_aux_weight must be non-negative")
        if global_local_angle_aux_scale <= 0.0:
            raise ValueError("global_local_angle_aux_scale must be positive")
        if global_local_angle_aux_aperture_scale <= 0.0:
            raise ValueError("global_local_angle_aux_aperture_scale must be positive")
        if part_quality_topk <= 0:
            raise ValueError("part_quality_topk must be positive")
        # NOTE(review): entailment_geometry and product_metric are stored without
        # validation here; presumably the loss functions validate them -- confirm.
        # --- stored configuration ----------------------------------------------
        self.entail_weight = entail_weight
        self.inter_aperture_scale = inter_aperture_scale
        self.intra_aperture_scale = intra_aperture_scale
        self.piecewise_factor = piecewise_factor
        self.calibration_alpha = calibration_alpha
        self.stop_grad_calibration = stop_grad_calibration
        self.entailment_geometry = entailment_geometry
        self.aggregate_weight = aggregate_weight
        self.entailment_loss = entailment_loss
        self.argent_beta = argent_beta
        self.argent_norm_weight = argent_norm_weight
        self.argent_aux_weight = argent_aux_weight
        self.argent_aggregation = argent_aggregation
        self.part_weight_power = part_weight_power
        self.product_metric = product_metric
        self.contrastive_loss = contrastive_loss
        self.sigmoid_negative_weight = sigmoid_negative_weight
        self.part_quality_mode = part_quality_mode
        self.part_quality_topk = part_quality_topk
        self.part_quality_temperature = part_quality_temperature
        self.contrastive_global_weight = float(contrastive_global_weight)
        self.contrastive_local_weight = float(contrastive_local_weight)
        self.contrastive_global_local_weight = float(contrastive_global_local_weight)
        self.beta_cal_beta = float(beta_cal_beta)
        self.beta_cal_variant = beta_cal_variant
        self.beta_cal_weight = float(beta_cal_weight)
        self.himo_component_weight = float(himo_component_weight)
        self.global_local_mode = global_local_mode
        self.global_local_metric = global_local_metric
        self.global_local_angle_aux_weight = float(global_local_angle_aux_weight)
        self.global_local_angle_aux_mode = global_local_angle_aux_mode
        self.global_local_angle_aux_scale = float(global_local_angle_aux_scale)
        self.global_local_angle_aux_aperture_scale = float(global_local_angle_aux_aperture_scale)
        self.radius_order_weight = float(radius_order_weight)
        self.radius_order_margin = float(radius_order_margin)
        self.gramian_align_weight = float(gramian_align_weight)

    def forward(self, embeddings: Mapping[str, Tensor], logit_scales: Mapping[str, Tensor]) -> dict[str, Tensor]:
        """Compute the total loss and its individual components.

        Args:
            embeddings: Must contain ``part_owner``, ``part_image_feats``,
                ``part_text_feats``, ``image_feats``, ``text_feats``, ``kappa``
                and ``targets``. Euclidean features, gathered ``all_*``
                features, HiMo text features, beta-query entries and
                ``entail_weight_scale`` are read when present.
            logit_scales: Must contain ``global``, ``local`` and
                ``global_local``; the matching ``*_bias`` entries are optional.

        Returns:
            Dict with ``loss`` plus every entry returned by the contrastive
            and entailment loss functions and the auxiliary terms and
            diagnostics computed here.

        Raises:
            ValueError: If ``targets`` is missing, if HiMo text features are
                given without ``all_himo_text_feats``, or if a hierarchical
                beta entailment loss is selected and a required beta-query
                entry is missing.
        """
        part_owner = embeddings["part_owner"].long()
        part_count = part_owner.new_tensor(part_owner.numel())
        part_image_flat = embeddings["part_image_feats"]
        part_text_flat = embeddings["part_text_feats"]
        image_feats = embeddings["image_feats"]
        text_feats = embeddings["text_feats"]

        # Repeat each global embedding once per part it owns (part_owner indexes
        # rows of image_feats / text_feats); use empty (0, D) tensors if no parts.
        if part_owner.numel() == 0:
            image_for_parts = image_feats.new_zeros((0, image_feats.size(-1)))
            text_for_parts = text_feats.new_zeros((0, text_feats.size(-1)))
        else:
            image_for_parts = image_feats[part_owner]
            text_for_parts = text_feats[part_owner]
        # Per-part weights: a count-based weight combined with a quality-based one.
        count_part_weights = _part_weights(part_owner, image_feats.size(0), self.part_weight_power)
        quality_part_weights, quality_scores, quality_keep = part_quality_weights(
            image_for_parts=image_for_parts,
            text_for_parts=text_for_parts,
            part_image_flat=part_image_flat,
            part_text_flat=part_text_flat,
            part_owner=part_owner,
            batch_size=image_feats.size(0),
            kappa=embeddings["kappa"],
            mode=self.part_quality_mode,
            topk=self.part_quality_topk,
            temperature=self.part_quality_temperature,
            product_metric=self.product_metric,
        )
        part_weights = _combine_part_weights(count_part_weights, quality_part_weights)

        # The repeated global features are only gathered when the "repeat"
        # global-local mode is active and its contrastive weight is non-zero.
        needs_repeated_global_local = self.global_local_mode == "repeat" and self.contrastive_global_local_weight != 0.0
        part_feature_tensors = [part_image_flat, part_text_flat]
        if needs_repeated_global_local:
            part_feature_tensors.extend([image_for_parts, text_for_parts])
        gathered_part_features, part_counts = gather_variable_many_with_grad(part_feature_tensors)
        all_part_image_feats = gathered_part_features[0]
        all_part_text_feats = gathered_part_features[1]
        all_image_for_parts = gathered_part_features[2] if needs_repeated_global_local else None
        all_text_for_parts = gathered_part_features[3] if needs_repeated_global_local else None
        # Optional Euclidean features; every derived value stays None when absent.
        image_euc_feats = embeddings.get("image_euc_feats")
        text_euc_feats = embeddings.get("text_euc_feats")
        part_image_euc_flat = embeddings.get("part_image_euc_feats")
        part_text_euc_flat = embeddings.get("part_text_euc_feats")
        image_for_parts_euc = None
        text_for_parts_euc = None
        all_part_image_euc_feats = None
        all_part_text_euc_feats = None
        all_image_for_parts_euc = None
        all_text_for_parts_euc = None
        if (
            image_euc_feats is not None
            and text_euc_feats is not None
            and part_owner.numel() > 0
            and needs_repeated_global_local
        ):
            image_for_parts_euc = image_euc_feats[part_owner]
            text_for_parts_euc = text_euc_feats[part_owner]
        if part_image_euc_flat is not None and part_text_euc_flat is not None:
            euc_feature_tensors = [part_image_euc_flat, part_text_euc_flat]
            if image_for_parts_euc is not None and text_for_parts_euc is not None:
                euc_feature_tensors.extend([image_for_parts_euc, text_for_parts_euc])
            gathered_euc_features, _ = gather_variable_many_with_grad(euc_feature_tensors)
            all_part_image_euc_feats = gathered_euc_features[0]
            all_part_text_euc_feats = gathered_euc_features[1]
            if image_for_parts_euc is not None and text_for_parts_euc is not None:
                all_image_for_parts_euc = gathered_euc_features[2]
                all_text_for_parts_euc = gathered_euc_features[3]
        if "targets" not in embeddings:
            raise ValueError("UNCHAObjective requires 'targets' to compute group-aware losses")
        global_targets = embeddings["targets"]
        # Each part inherits the target of the sample that owns it.
        part_group_ids = global_targets[part_owner] if part_owner.numel() > 0 else part_owner.new_zeros((0,))
        all_part_group_ids = None
        # Group ids are only gathered when the beta-calibration term is enabled.
        if self.beta_cal_weight > 0.0 and self.beta_cal_beta > 0.0:
            all_part_group_ids, _ = gather_variable_no_grad(part_group_ids)
        # Offset = sum of part counts on lower ranks. Presumably this maps local
        # part indices to their positions in the gathered part tensors (assumes
        # gathering concatenates in rank order) -- TODO confirm.
        part_offset = part_counts[: get_rank()].sum() if part_counts.numel() > 1 else part_counts.new_zeros(())
        part_targets = torch.arange(part_image_flat.size(0), device=part_image_flat.device) + part_offset

        contrastive = uncha_contrastive_losses(
            image_feats=image_feats,
            text_feats=text_feats,
            part_image_flat=part_image_flat,
            part_text_flat=part_text_flat,
            image_for_parts=image_for_parts,
            text_for_parts=text_for_parts,
            image_euc_feats=image_euc_feats,
            text_euc_feats=text_euc_feats,
            part_image_euc_flat=part_image_euc_flat,
            part_text_euc_flat=part_text_euc_flat,
            image_for_parts_euc=image_for_parts_euc,
            text_for_parts_euc=text_for_parts_euc,
            kappa=embeddings["kappa"],
            global_logit_scale=logit_scales["global"],
            local_logit_scale=logit_scales["local"],
            global_local_logit_scale=logit_scales["global_local"],
            all_image_feats=embeddings.get("all_image_feats"),
            all_text_feats=embeddings.get("all_text_feats"),
            all_part_image_feats=all_part_image_feats,
            all_part_text_feats=all_part_text_feats,
            all_image_for_parts=all_image_for_parts,
            all_text_for_parts=all_text_for_parts,
            all_image_euc_feats=embeddings.get("all_image_euc_feats"),
            all_text_euc_feats=embeddings.get("all_text_euc_feats"),
            all_part_image_euc_feats=all_part_image_euc_feats,
            all_part_text_euc_feats=all_part_text_euc_feats,
            all_image_for_parts_euc=all_image_for_parts_euc,
            all_text_for_parts_euc=all_text_for_parts_euc,
            global_targets=global_targets,
            part_targets=part_targets,
            part_weights=part_weights,
            product_metric=self.product_metric,
            loss_type=self.contrastive_loss,
            contrastive_global_weight=self.contrastive_global_weight,
            contrastive_local_weight=self.contrastive_local_weight,
            contrastive_global_local_weight=self.contrastive_global_local_weight,
            beta_cal_beta=self.beta_cal_beta,
            beta_cal_variant=self.beta_cal_variant,
            beta_cal_weight=self.beta_cal_weight,
            part_group_ids=part_group_ids,
            all_part_group_ids=all_part_group_ids,
            global_logit_bias=logit_scales.get("global_bias"),
            local_logit_bias=logit_scales.get("local_bias"),
            global_local_logit_bias=logit_scales.get("global_local_bias"),
            sigmoid_negative_weight=self.sigmoid_negative_weight,
            global_local_mode=self.global_local_mode,
            global_local_metric=self.global_local_metric,
            global_local_angle_aux_weight=self.global_local_angle_aux_weight,
            global_local_angle_aux_mode=self.global_local_angle_aux_mode,
            global_local_angle_aux_scale=self.global_local_angle_aux_scale,
            global_local_angle_aux_aperture_scale=self.global_local_angle_aux_aperture_scale,
        )
        # Optional HiMo component term: symmetric cross-entropy over negated
        # pairwise distances, scaled by exp(global logit scale) capped at 100.
        himo_component_loss = image_feats.new_zeros(())
        if self.himo_component_weight > 0.0 and embeddings.get("himo_text_feats") is not None:
            himo_text_feats = embeddings["himo_text_feats"]
            all_himo_text_feats = embeddings.get("all_himo_text_feats")
            if all_himo_text_feats is None:
                raise ValueError("himo_text_feats requires all_himo_text_feats for distributed contrastive loss")
            scale = logit_scales["global"].exp().clamp(max=100.0)
            # NOTE(review): "all_image_feats" is indexed directly below (KeyError if
            # absent), while the contrastive call above reads it with .get().
            logits_i_t = -metric_pairwise_dist(image_feats, all_himo_text_feats, embeddings["kappa"], product_metric=self.product_metric) * scale
            logits_t_i = -metric_pairwise_dist(himo_text_feats, embeddings["all_image_feats"], embeddings["kappa"], product_metric=self.product_metric) * scale
            himo_component_loss = 0.5 * (contrastive_ce(logits_i_t, global_targets) + contrastive_ce(logits_t_i, global_targets))
        # Entailment term: one of three code paths chosen by self.entailment_loss.
        if self.entailment_loss == "argent":
            entailment = uncha_argent_entailment_losses(
                image_feats=image_feats,
                text_feats=text_feats,
                part_image_flat=part_image_flat,
                part_text_flat=part_text_flat,
                image_for_parts=image_for_parts,
                text_for_parts=text_for_parts,
                kappa=embeddings["kappa"],
                beta=self.argent_beta,
                part_weights=part_weights,
                product_metric=self.product_metric,
                aggregation=self.argent_aggregation,
            )
        elif self.entailment_loss in {"hier_beta_argent", "hier_beta_sourcepart_argent"}:
            # The hierarchical variants need the beta-query entries; the
            # source-part variant additionally needs "beta_query_source_part".
            required = (
                "beta_query_image_feats",
                "beta_query_text_feats",
                "beta_query_owner",
                "beta_query_parent",
                "beta_query_weight",
            )
            if self.entailment_loss == "hier_beta_sourcepart_argent":
                required = (*required, "beta_query_source_part")
            missing = [key for key in required if embeddings.get(key) is None]
            if missing:
                raise ValueError(f"{self.entailment_loss} requires beta query embeddings: missing {missing}")
            entailment = hierarchical_beta_argent_entailment_losses(
                image_feats=image_feats,
                text_feats=text_feats,
                part_image_flat=part_image_flat,
                part_text_flat=part_text_flat,
                image_for_parts=image_for_parts,
                text_for_parts=text_for_parts,
                beta_query_image_feats=embeddings["beta_query_image_feats"],
                beta_query_text_feats=embeddings["beta_query_text_feats"],
                beta_query_owner=embeddings["beta_query_owner"],
                beta_query_parent=embeddings["beta_query_parent"],
                beta_query_weight=embeddings["beta_query_weight"],
                beta_query_source_part=embeddings.get("beta_query_source_part")
                if self.entailment_loss == "hier_beta_sourcepart_argent"
                else None,
                kappa=embeddings["kappa"],
                beta=self.argent_beta,
                part_weights=part_weights,
                product_metric=self.product_metric,
                aggregation=self.argent_aggregation,
            )
        else:
            # "piecewise" and "piecewise_argent" both start from the piecewise loss.
            piecewise_entailment = uncha_entailment_losses(
                image_feats=image_feats,
                text_feats=text_feats,
                part_image_flat=part_image_flat,
                part_text_flat=part_text_flat,
                image_for_parts=image_for_parts,
                text_for_parts=text_for_parts,
                kappa=embeddings["kappa"],
                inter_aperture_scale=self.inter_aperture_scale,
                intra_aperture_scale=self.intra_aperture_scale,
                piecewise_factor=self.piecewise_factor,
                calibration_alpha=self.calibration_alpha,
                stop_grad_calibration=self.stop_grad_calibration,
                geometry=self.entailment_geometry,
                part_weights=part_weights,
            )
            if self.entailment_loss == "piecewise_argent":
                argent_entailment = uncha_argent_entailment_losses(
                    image_feats=image_feats,
                    text_feats=text_feats,
                    part_image_flat=part_image_flat,
                    part_text_flat=part_text_flat,
                    image_for_parts=image_for_parts,
                    text_for_parts=text_for_parts,
                    kappa=embeddings["kappa"],
                    beta=self.argent_beta,
                    part_weights=part_weights,
                    product_metric=self.product_metric,
                    aggregation=self.argent_aggregation,
                )
                # Keep every piecewise entry, but replace "entailment_loss" with
                # piecewise + argent_aux_weight * argent and expose both parts.
                entailment = {
                    **piecewise_entailment,
                    "entailment_loss": piecewise_entailment["entailment_loss"]
                    + self.argent_aux_weight * argent_entailment["entailment_loss"],
                    "piecewise_entailment_loss": piecewise_entailment["entailment_loss"],
                    "argent_entailment_loss": argent_entailment["entailment_loss"],
                    "norm_regularization_loss": argent_entailment["norm_regularization_loss"],
                }
            else:
                entailment = piecewise_entailment
        # Computed unconditionally; aggregate_weight only scales it in the total.
        aggregate = aggregate_part_consistency_loss(
            image_feats=image_feats,
            text_feats=text_feats,
            part_image_flat=part_image_flat,
            part_text_flat=part_text_flat,
            part_owner=part_owner,
            part_weights=part_weights,
        )
        # Optional radius-order term: sum of four hinge terms over the
        # (image, text), (part image, part text), (image, part image) and
        # (text, part text) pairs; the last three use the part weights.
        radius_order = image_feats.new_zeros(())
        if self.radius_order_weight > 0.0:
            radius_order = (
                radius_order_hinge(image_feats, text_feats, embeddings["kappa"], self.radius_order_margin)
                + radius_order_hinge(part_image_flat, part_text_flat, embeddings["kappa"], self.radius_order_margin, part_weights)
                + radius_order_hinge(image_for_parts, part_image_flat, embeddings["kappa"], self.radius_order_margin, part_weights)
                + radius_order_hinge(text_for_parts, part_text_flat, embeddings["kappa"], self.radius_order_margin, part_weights)
            )
        # Optional Gramian term, skipped when there are no parts.
        gramian_align = image_feats.new_zeros(())
        if self.gramian_align_weight > 0.0 and part_owner.numel() > 0:
            # Apply log_map0 and flatten any 3-D result to (N, -1).
            def _tangent_flat(x: Tensor) -> Tensor:
                tangent = log_map0(x, embeddings["kappa"])
                return tangent.reshape(tangent.size(0), -1) if tangent.dim() == 3 else tangent

            # Stack the four vectors for each part along a new axis 1.
            gramian_vectors = torch.stack(
                [
                    _tangent_flat(image_for_parts),
                    _tangent_flat(text_for_parts),
                    _tangent_flat(part_image_flat),
                    _tangent_flat(part_text_flat),
                ],
                dim=1,
            )
            gramian_align = gramian_volume_loss(gramian_vectors, part_weights)
        # Optional external multiplier on the entailment weight; defaults to 1.
        entail_weight_scale = embeddings.get("entail_weight_scale", image_feats.new_ones(()))
        total = (
            contrastive["contrastive_loss"]
            + self.himo_component_weight * himo_component_loss
            + self.entail_weight * entail_weight_scale * entailment["entailment_loss"]
            + self.aggregate_weight * aggregate
            + self.radius_order_weight * radius_order
            + self.gramian_align_weight * gramian_align
            + self.argent_norm_weight * entailment.get(
                "norm_regularization_loss",
                image_feats.new_zeros(()),
            )
        )
        # Later entries win on duplicate keys: entries from **entailment
        # override same-named entries from **contrastive.
        return {
            "loss": total,
            **contrastive,
            "himo_component_contrastive_loss": himo_component_loss,
            **entailment,
            "aggregate_consistency_loss": aggregate,
            "radius_order_loss": radius_order,
            "gramian_align_loss": gramian_align,
            "part_count": part_count,
            "entail_weight_scale": entail_weight_scale.detach(),
            "part_quality_mean": (
                image_feats.new_zeros(()) if quality_scores.numel() == 0 else quality_scores.mean().detach()
            ),
            "part_quality_keep_fraction": (
                image_feats.new_zeros(()) if quality_keep.numel() == 0 else quality_keep.mean().detach()
            ),
        }
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
def build_objective(
    objective: str,
    entail_weight: float,
    inter_aperture_scale: float,
    intra_aperture_scale: float,
    uncha_piecewise_factor: float = 0.1,
    uncha_calibration_alpha: float = 10.0,
    uncha_stop_grad_calibration: bool = True,
    uncha_entailment_geometry: str = "lorentz",
    uncha_aggregate_weight: float = 0.0,
    uncha_entailment_loss: str = "piecewise",
    uncha_argent_beta: float = 1.0,
    uncha_argent_norm_weight: float = 0.0,
    uncha_argent_aux_weight: float = 0.5,
    uncha_argent_aggregation: str = "uncha",
    uncha_part_weight_power: float = 0.0,
    uncha_contrastive_loss: str = "ce",
    uncha_sigmoid_negative_weight: float = 1.0,
    uncha_part_quality_mode: str = "none",
    uncha_part_quality_topk: int = 5,
    uncha_part_quality_temperature: float = 4.0,
    uncha_contrastive_global_weight: float = 1.0,
    uncha_contrastive_local_weight: float = 1.0,
    uncha_contrastive_global_local_weight: float = 1.0,
    uncha_beta_cal_beta: float = 0.0,
    uncha_beta_cal_variant: str = "ce",
    uncha_beta_cal_weight: float = 0.0,
    uncha_himo_component_weight: float = 0.0,
    uncha_global_local_mode: str = "repeat",
    uncha_global_local_metric: str = "distance",
    uncha_global_local_angle_aux_weight: float = 0.0,
    uncha_global_local_angle_aux_mode: str = "contrastive",
    uncha_global_local_angle_aux_scale: float = 5.5,
    uncha_global_local_angle_aux_aperture_scale: float = 1.0,
    uncha_radius_order_weight: float = 0.0,
    uncha_radius_order_margin: float = 0.0,
    uncha_gramian_align_weight: float = 0.0,
    product_metric: str = "l1",
) -> nn.Module:
    """Instantiate the training objective selected by ``objective``.

    ``"hycoclip"`` builds a ``HyCoCLIPObjective`` from the four shared
    settings only. ``"uncha"`` builds a ``UNCHAObjective`` from the shared
    settings plus every ``uncha_*`` argument, forwarded with the ``uncha_``
    prefix stripped.

    Raises:
        ValueError: if ``objective`` is neither ``"hycoclip"`` nor ``"uncha"``.
    """
    # Settings that both objective classes receive.
    shared = {
        "entail_weight": entail_weight,
        "inter_aperture_scale": inter_aperture_scale,
        "intra_aperture_scale": intra_aperture_scale,
        "product_metric": product_metric,
    }
    if objective == "hycoclip":
        return HyCoCLIPObjective(**shared)
    if objective != "uncha":
        raise ValueError(f"Unsupported objective {objective!r}; expected 'hycoclip' or 'uncha'")

    # UNCHA-only settings, keyed by the constructor's keyword names.
    uncha_options = {
        "piecewise_factor": uncha_piecewise_factor,
        "calibration_alpha": uncha_calibration_alpha,
        "stop_grad_calibration": uncha_stop_grad_calibration,
        "entailment_geometry": uncha_entailment_geometry,
        "aggregate_weight": uncha_aggregate_weight,
        "entailment_loss": uncha_entailment_loss,
        "argent_beta": uncha_argent_beta,
        "argent_norm_weight": uncha_argent_norm_weight,
        "argent_aux_weight": uncha_argent_aux_weight,
        "argent_aggregation": uncha_argent_aggregation,
        "part_weight_power": uncha_part_weight_power,
        "contrastive_loss": uncha_contrastive_loss,
        "sigmoid_negative_weight": uncha_sigmoid_negative_weight,
        "part_quality_mode": uncha_part_quality_mode,
        "part_quality_topk": uncha_part_quality_topk,
        "part_quality_temperature": uncha_part_quality_temperature,
        "contrastive_global_weight": uncha_contrastive_global_weight,
        "contrastive_local_weight": uncha_contrastive_local_weight,
        "contrastive_global_local_weight": uncha_contrastive_global_local_weight,
        "beta_cal_beta": uncha_beta_cal_beta,
        "beta_cal_variant": uncha_beta_cal_variant,
        "beta_cal_weight": uncha_beta_cal_weight,
        "himo_component_weight": uncha_himo_component_weight,
        "global_local_mode": uncha_global_local_mode,
        "global_local_metric": uncha_global_local_metric,
        "global_local_angle_aux_weight": uncha_global_local_angle_aux_weight,
        "global_local_angle_aux_mode": uncha_global_local_angle_aux_mode,
        "global_local_angle_aux_scale": uncha_global_local_angle_aux_scale,
        "global_local_angle_aux_aperture_scale": uncha_global_local_angle_aux_aperture_scale,
        "radius_order_weight": uncha_radius_order_weight,
        "radius_order_margin": uncha_radius_order_margin,
        "gramian_align_weight": uncha_gramian_align_weight,
    }
    return UNCHAObjective(**shared, **uncha_options)
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
def _part_weights(part_owner: Tensor, batch_size: int, power: float) -> Tensor | None:
|
| 567 |
-
if power <= 0.0 or part_owner.numel() == 0:
|
| 568 |
-
return None
|
| 569 |
-
counts = torch.bincount(part_owner, minlength=batch_size).to(dtype=torch.float32, device=part_owner.device)
|
| 570 |
-
weights = counts[part_owner].clamp_min(1.0).pow(-power)
|
| 571 |
-
return weights / weights.mean().clamp_min(torch.finfo(weights.dtype).eps)
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
def _combine_part_weights(count_weights: Tensor | None, quality_weights: Tensor | None) -> Tensor | None:
|
| 575 |
-
if count_weights is None:
|
| 576 |
-
return quality_weights
|
| 577 |
-
if quality_weights is None:
|
| 578 |
-
return count_weights
|
| 579 |
-
weights = count_weights * quality_weights
|
| 580 |
-
return weights / weights.mean().clamp_min(torch.finfo(weights.dtype).eps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/models/tren.py
DELETED
|
@@ -1,255 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import math
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn.functional as F
|
| 7 |
-
from torch import Tensor, nn
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class FourierPositionEncoding2D(nn.Module):
    """Random Fourier feature encoding of 2-D coordinates.

    A fixed ``(2, dim // 2)`` Gaussian projection matrix is drawn once from
    a private generator seeded with 42 (so every instance with the same
    ``dim`` and ``scale`` gets the same matrix) and stored as the
    ``gaussian_matrix`` buffer.

    Args:
        dim: size of the output encoding; must be a positive even integer.
        scale: multiplier applied to the Gaussian matrix; must be positive.

    Raises:
        ValueError: if ``dim`` or ``scale`` is out of range.
    """

    def __init__(self, dim: int, scale: float = 1.0) -> None:
        super().__init__()
        if dim <= 0 or dim % 2 != 0:
            raise ValueError("FourierPositionEncoding2D dim must be a positive even integer")
        if scale <= 0.0:
            raise ValueError("FourierPositionEncoding2D scale must be positive")
        # Private generator: the draw does not disturb the global RNG stream.
        generator = torch.Generator()
        generator.manual_seed(42)
        self.register_buffer("gaussian_matrix", scale * torch.randn((2, dim // 2), generator=generator))

    def forward(self, coords: Tensor) -> Tensor:
        """Encode ``coords`` (last dimension of size 2) as ``[sin, cos]`` features.

        Coordinates are mapped with ``2 * c - 1`` before projection. The
        result is always float32 with last dimension ``dim``.
        """
        # The buffer follows module-wide casts (.half(), .double(), .to(bfloat16)),
        # while coords are forced to float32; cast the buffer back so the matmul
        # does not fail with a dtype mismatch.
        matrix = self.gaussian_matrix.float()
        projected = (2.0 * coords.float() - 1.0) @ matrix
        projected = 2.0 * math.pi * projected
        return torch.cat([torch.sin(projected), torch.cos(projected)], dim=-1)
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
class _MLPBlock(nn.Module):
|
| 28 |
-
def __init__(self, dim: int, hidden_dim: int, dropout: float) -> None:
|
| 29 |
-
super().__init__()
|
| 30 |
-
self.net = nn.Sequential(
|
| 31 |
-
nn.Linear(dim, hidden_dim),
|
| 32 |
-
nn.GELU(),
|
| 33 |
-
nn.Dropout(dropout),
|
| 34 |
-
nn.Linear(hidden_dim, dim),
|
| 35 |
-
)
|
| 36 |
-
|
| 37 |
-
def forward(self, x: Tensor) -> Tensor:
|
| 38 |
-
return self.net(x)
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
class _AttentionLayer(nn.Module):
|
| 42 |
-
def __init__(
|
| 43 |
-
self,
|
| 44 |
-
q_dim: int,
|
| 45 |
-
kv_dim: int,
|
| 46 |
-
hidden_dim: int,
|
| 47 |
-
*,
|
| 48 |
-
num_heads: int,
|
| 49 |
-
dropout: float,
|
| 50 |
-
use_bias: bool = False,
|
| 51 |
-
use_v_proj: bool = True,
|
| 52 |
-
use_out_proj: bool = True,
|
| 53 |
-
) -> None:
|
| 54 |
-
super().__init__()
|
| 55 |
-
if hidden_dim % num_heads != 0:
|
| 56 |
-
raise ValueError("hidden_dim must be divisible by num_heads")
|
| 57 |
-
if not use_v_proj and kv_dim != hidden_dim:
|
| 58 |
-
raise ValueError("kv_dim must equal hidden_dim when value projection is disabled")
|
| 59 |
-
self.hidden_dim = hidden_dim
|
| 60 |
-
self.num_heads = num_heads
|
| 61 |
-
self.head_dim = hidden_dim // num_heads
|
| 62 |
-
self.q_proj = nn.Linear(q_dim, hidden_dim, bias=use_bias)
|
| 63 |
-
self.k_proj = nn.Linear(kv_dim, hidden_dim, bias=use_bias)
|
| 64 |
-
self.v_proj = nn.Linear(kv_dim, hidden_dim, bias=use_bias) if use_v_proj else nn.Identity()
|
| 65 |
-
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=use_bias) if use_out_proj else nn.Identity()
|
| 66 |
-
self.q_norm = nn.LayerNorm(self.head_dim)
|
| 67 |
-
self.k_norm = nn.LayerNorm(self.head_dim)
|
| 68 |
-
self.dropout = nn.Dropout(dropout)
|
| 69 |
-
self.scale = self.head_dim**-0.5
|
| 70 |
-
|
| 71 |
-
nn.init.kaiming_normal_(self.q_proj.weight, mode="fan_in", nonlinearity="linear")
|
| 72 |
-
nn.init.kaiming_normal_(self.k_proj.weight, mode="fan_in", nonlinearity="linear")
|
| 73 |
-
if isinstance(self.v_proj, nn.Linear):
|
| 74 |
-
nn.init.kaiming_normal_(self.v_proj.weight, mode="fan_in", nonlinearity="linear")
|
| 75 |
-
if isinstance(self.out_proj, nn.Linear):
|
| 76 |
-
nn.init.kaiming_normal_(self.out_proj.weight, mode="fan_in", nonlinearity="linear")
|
| 77 |
-
|
| 78 |
-
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
| 79 |
-
batch_size, q_len, _ = q.shape
|
| 80 |
-
_, kv_len, _ = k.shape
|
| 81 |
-
query = self.q_proj(q).view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 82 |
-
key = self.k_proj(k).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 83 |
-
value = self.v_proj(v).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 84 |
-
|
| 85 |
-
query = self.q_norm(query)
|
| 86 |
-
key = self.k_norm(key)
|
| 87 |
-
attn_scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
|
| 88 |
-
attn_weights = self.dropout(F.softmax(attn_scores, dim=-1))
|
| 89 |
-
out = torch.matmul(attn_weights, value)
|
| 90 |
-
out = out.transpose(1, 2).contiguous().view(batch_size, q_len, self.hidden_dim)
|
| 91 |
-
return self.out_proj(out), attn_weights
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
class _CrossAttentionBlock(nn.Module):
    """Cross-attention followed by an MLP, each with a residual connection.

    The query is layer-normalised before attention; the context is passed
    to the attention layer as given. A final LayerNorm is applied to the
    block output.
    """

    def __init__(self, dim: int, *, num_heads: int, dropout: float) -> None:
        super().__init__()
        self.query_norm = nn.LayerNorm(dim)
        self.cross_attn = _AttentionLayer(dim, dim, dim, num_heads=num_heads, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.mlp_norm = nn.LayerNorm(dim)
        # Feed-forward width is twice the model width.
        self.mlp = _MLPBlock(dim, 2 * dim, dropout)
        self.out_norm = nn.LayerNorm(dim)

    def forward(self, query: Tensor, context: Tensor) -> Tensor:
        """Update ``query`` by attending over ``context`` (used as keys and values)."""
        attended, _weights = self.cross_attn(self.query_norm(query), context, context)
        residual = query + self.dropout(attended)
        refined = residual + self.mlp(self.mlp_norm(residual))
        return self.out_norm(refined)
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
class TRENRegionEncoder(nn.Module):
    """T-REN-style point-prompted region token encoder.

    The module follows the public T-REN architecture: learned k-per-prompt
    query tokens, Fourier 2D prompt/patch position encodings, alternating
    cross-attention and per-prompt self-attention, then final single-head
    attention that pools unprojected patch tokens into region tokens.

    Args:
        vision_dim: feature size of the incoming image tokens.
        text_dim: output size of the text-alignment projection.
        hidden_dim: internal width; ``None`` (or 0) falls back to
            ``vision_dim``, and any other value must equal ``vision_dim``.
        num_region_tokens: learned query tokens per prompt point.
        num_decoder_layers: number of cross-attention / self-attention pairs.
        num_attention_heads: heads used by the decoder attention layers.
        prompt_grid_size: prompts are laid out on a square grid of this side.
        dropout: dropout probability for the decoder layers and the
            text-alignment MLP (the final pooling head uses 0.0).

    Raises:
        ValueError: if a count is not positive, ``hidden_dim`` differs from
            ``vision_dim``, is odd, or is not divisible by
            ``num_attention_heads``.
    """

    def __init__(
        self,
        vision_dim: int,
        text_dim: int,
        *,
        hidden_dim: int | None = None,
        num_region_tokens: int = 3,
        num_decoder_layers: int = 2,
        num_attention_heads: int = 8,
        prompt_grid_size: int = 7,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        if num_region_tokens <= 0:
            raise ValueError("num_region_tokens must be positive")
        if num_decoder_layers <= 0:
            raise ValueError("num_decoder_layers must be positive")
        if prompt_grid_size <= 0:
            raise ValueError("prompt_grid_size must be positive")
        # `or` makes both None and 0 fall back to vision_dim.
        hidden_dim = int(hidden_dim or vision_dim)
        if hidden_dim != vision_dim:
            raise ValueError("TRENRegionEncoder currently requires hidden_dim == vision_dim")
        if hidden_dim % 2 != 0:
            raise ValueError("TRENRegionEncoder hidden_dim must be even for Fourier features")
        if hidden_dim % num_attention_heads != 0:
            raise ValueError("TRENRegionEncoder hidden_dim must be divisible by num_attention_heads")

        self.vision_dim = vision_dim
        self.text_dim = text_dim
        self.hidden_dim = hidden_dim
        self.num_region_tokens = num_region_tokens
        self.prompt_grid_size = prompt_grid_size
        # One encoder is shared by patch positions and prompt positions.
        self.position_encoder = FourierPositionEncoding2D(hidden_dim)
        # Learned query tokens, shared across all prompts.
        self.region_token_embeddings = nn.Embedding(num_region_tokens, hidden_dim)
        nn.init.normal_(self.region_token_embeddings.weight, std=0.02)
        # Per decoder layer: queries cross-attend to the patch tokens ...
        self.region_attention_layers = nn.ModuleList(
            [_CrossAttentionBlock(hidden_dim, num_heads=num_attention_heads, dropout=dropout) for _ in range(num_decoder_layers)]
        )
        self.region_attention_norms = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_decoder_layers)])
        # ... then the tokens of each prompt self-attend among themselves.
        self.prompt_attention_layers = nn.ModuleList(
            [
                _AttentionLayer(
                    hidden_dim,
                    hidden_dim,
                    hidden_dim,
                    num_heads=num_attention_heads,
                    dropout=dropout,
                )
                for _ in range(num_decoder_layers)
            ]
        )
        self.prompt_attention_norms = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_decoder_layers)])
        # Final pooling head: single head, no dropout, no value/output projection,
        # so its output is an attention-weighted sum of the raw value tokens.
        self.token_prediction_head = _AttentionLayer(
            hidden_dim,
            hidden_dim,
            hidden_dim,
            num_heads=1,
            dropout=0.0,
            use_v_proj=False,
            use_out_proj=False,
        )
        # MLP that maps pooled region tokens to text_dim.
        self.text_alignment_block = nn.Sequential(
            nn.Linear(hidden_dim, 2 * hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(2 * hidden_dim, text_dim),
        )

    def forward(self, image_tokens: Tensor) -> dict[str, Tensor]:
        """Encode image tokens into per-prompt region tokens.

        Args:
            image_tokens: ``[batch, tokens, dim]``; the token count must be a
                perfect square, or a perfect square plus one leading token
                that is dropped.

        Returns:
            A dict with:
              * ``visual_tokens``: ``(batch, prompts, num_region_tokens, hidden_dim)``
              * ``text_aligned_tokens``: same leading shape, last dim ``text_dim``
              * ``region_masks``: ``(batch, prompts, num_region_tokens, grid, grid)``
                attention maps scaled so each map's maximum is 1
              * ``prompt_coords``: ``(prompts, 2)`` prompt positions as (x, y)
        """
        patch_tokens, patch_grid = _patch_tokens_and_grid(image_tokens)
        batch_size, patch_count, _ = patch_tokens.shape
        patch_coords = _grid_coords(patch_grid, patch_grid, patch_tokens.device)
        prompt_coords = _grid_coords(self.prompt_grid_size, self.prompt_grid_size, patch_tokens.device)
        prompt_count = prompt_coords.size(0)

        feature_pos = self.position_encoder(patch_coords).to(dtype=patch_tokens.dtype)
        prompt_pos = self.position_encoder(prompt_coords).to(dtype=patch_tokens.dtype)
        # Keys: patch tokens with their position encoding added.
        kv = patch_tokens + feature_pos.unsqueeze(0)
        # Shaped to broadcast over batch and over the region tokens of a prompt.
        prompt_pos = prompt_pos.view(1, prompt_count, 1, self.hidden_dim)

        q = self.region_token_embeddings.weight.to(dtype=patch_tokens.dtype)
        # Same learned tokens for every batch item and every prompt.
        q = q.view(1, 1, self.num_region_tokens, self.hidden_dim).expand(
            batch_size,
            prompt_count,
            self.num_region_tokens,
            self.hidden_dim,
        )
        for region_layer, region_norm, prompt_layer, prompt_norm in zip(
            self.region_attention_layers,
            self.region_attention_norms,
            self.prompt_attention_layers,
            self.prompt_attention_norms,
            strict=True,
        ):
            # The prompt position is re-added at the start of every layer.
            q = q + prompt_pos
            # Cross-attention runs over all prompts' tokens at once.
            q = q.reshape(batch_size, prompt_count * self.num_region_tokens, self.hidden_dim)
            q = region_layer(q, kv)
            q = q.reshape(batch_size, prompt_count, self.num_region_tokens, self.hidden_dim)
            q = region_norm(q)
            # Self-attention is restricted to the tokens of one prompt by
            # folding the prompt axis into the batch axis.
            q = q.reshape(batch_size * prompt_count, self.num_region_tokens, self.hidden_dim)
            q, _ = prompt_layer(q, q, q)
            q = prompt_norm(q)
            q = q.reshape(batch_size, prompt_count, self.num_region_tokens, self.hidden_dim)

        flat_q = q.reshape(batch_size, prompt_count * self.num_region_tokens, self.hidden_dim)
        # Keys carry positions (kv); values are the raw patch tokens.
        visual_tokens, attn_weights = self.token_prediction_head(flat_q, kv, patch_tokens)
        visual_tokens = visual_tokens.reshape(batch_size, prompt_count, self.num_region_tokens, self.hidden_dim)
        # squeeze(1) drops the single-head axis.
        attn_weights = attn_weights.squeeze(1).reshape(batch_size, prompt_count, self.num_region_tokens, patch_count)
        # Divide each attention map by its own maximum (eps-guarded).
        region_masks = attn_weights / attn_weights.amax(dim=-1, keepdim=True).clamp_min(torch.finfo(attn_weights.dtype).eps)
        region_masks = region_masks.reshape(batch_size, prompt_count, self.num_region_tokens, patch_grid, patch_grid)
        text_aligned_tokens = self.text_alignment_block(visual_tokens)
        return {
            "visual_tokens": visual_tokens,
            "text_aligned_tokens": text_aligned_tokens,
            "region_masks": region_masks,
            "prompt_coords": prompt_coords,
        }
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
def _patch_tokens_and_grid(tokens: Tensor) -> tuple[Tensor, int]:
|
| 239 |
-
if tokens.ndim != 3:
|
| 240 |
-
raise ValueError("TRENRegionEncoder expects image tokens with shape [batch, tokens, dim]")
|
| 241 |
-
token_count = tokens.size(1)
|
| 242 |
-
grid = int(math.isqrt(token_count))
|
| 243 |
-
if grid * grid == token_count:
|
| 244 |
-
return tokens, grid
|
| 245 |
-
grid = int(math.isqrt(token_count - 1))
|
| 246 |
-
if grid * grid == token_count - 1:
|
| 247 |
-
return tokens[:, 1:, :], grid
|
| 248 |
-
raise ValueError(f"Cannot infer a square patch grid from {token_count} image tokens")
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
def _grid_coords(height: int, width: int, device: torch.device) -> Tensor:
|
| 252 |
-
y = torch.linspace(0.5 / height, 1.0 - 0.5 / height, height, device=device)
|
| 253 |
-
x = torch.linspace(0.5 / width, 1.0 - 0.5 / width, width, device=device)
|
| 254 |
-
yy, xx = torch.meshgrid(y, x, indexing="ij")
|
| 255 |
-
return torch.stack([xx, yy], dim=-1).reshape(-1, 2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/training/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
__all__: list[str] = []
|
|
|
|
|
|
hyper3_clip/training/checkpointing.py
DELETED
|
@@ -1,91 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
import random
|
| 5 |
-
from typing import Any
|
| 6 |
-
|
| 7 |
-
import numpy as np
|
| 8 |
-
import torch
|
| 9 |
-
from torch import nn
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
def save_checkpoint(
    path: str | Path,
    step: int,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: Any,
    scaler: Any,
    config: dict,
) -> None:
    """Write a training checkpoint to ``path``.

    The checkpoint bundles the step counter, the state dicts of the model,
    optimizer, scheduler and scaler, the run config, and the current RNG
    states. It is first written to ``<name>.tmp`` next to the target and
    then renamed over it, so ``path`` is never left holding a partial write.
    """
    target = Path(path)
    staging = target.with_name(f"{target.name}.tmp")
    payload = dict(
        step=step,
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        scheduler=scheduler.state_dict(),
        scaler=scaler.state_dict(),
        config=config,
        rng=_rng_state(),
    )
    torch.save(payload, staging)
    staging.replace(target)
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
def load_checkpoint(
    path: str | Path,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: Any,
    scaler: Any,
    device: torch.device,
    *,
    model_only: bool = False,
    strict_model: bool = True,
) -> int:
    """Restore training state from ``path`` and return the saved step.

    The model weights are always loaded (``strict_model`` is forwarded to
    ``load_state_dict``). Unless ``model_only`` is set, the optimizer,
    scheduler, scaler and RNG states are restored as well.
    """
    # NOTE(security): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    saved = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(saved["model"], strict=strict_model)
    if not model_only:
        optimizer.load_state_dict(saved["optimizer"])
        scheduler.load_state_dict(saved["scheduler"])
        scaler.load_state_dict(saved["scaler"])
        _set_rng_state(saved["rng"])
    return int(saved["step"])
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def latest_checkpoint(output_dir: str | Path) -> Path | None:
    """Return the ``checkpoint_step_*.pt`` file with the highest step.

    Returns ``None`` when ``output_dir`` contains no matching file.
    """
    # Sorted first so ties on the step number resolve deterministically.
    candidates = sorted(Path(output_dir).glob("checkpoint_step_*.pt"))
    return max(candidates, key=_checkpoint_step, default=None)
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
def _checkpoint_step(path: Path) -> int:
|
| 66 |
-
return int(path.stem.rsplit("_", 1)[1])
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
def _rng_state() -> dict[str, Any]:
|
| 70 |
-
state: dict[str, Any] = {
|
| 71 |
-
"python": random.getstate(),
|
| 72 |
-
"numpy": np.random.get_state(),
|
| 73 |
-
"torch": torch.get_rng_state(),
|
| 74 |
-
}
|
| 75 |
-
if torch.cuda.is_available():
|
| 76 |
-
state["cuda"] = torch.cuda.get_rng_state_all()
|
| 77 |
-
return state
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
def _set_rng_state(state: dict[str, Any]) -> None:
    """Restore the Python, NumPy and torch RNG states from ``state``.

    The CUDA generator states are restored only when CUDA is available and
    ``state`` contains a ``"cuda"`` entry.
    """
    random.setstate(state["python"])
    np.random.set_state(state["numpy"])
    torch.set_rng_state(_cpu_byte_tensor(state["torch"]))
    if not (torch.cuda.is_available() and "cuda" in state):
        return
    # Every saved state is normalised to a CPU uint8 tensor before being set.
    per_device = list(map(_cpu_byte_tensor, state["cuda"]))
    torch.cuda.set_rng_state_all(per_device)
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
def _cpu_byte_tensor(value: Any) -> torch.ByteTensor:
|
| 89 |
-
if isinstance(value, torch.Tensor):
|
| 90 |
-
return value.detach().to(device="cpu", dtype=torch.uint8)
|
| 91 |
-
return torch.as_tensor(value, dtype=torch.uint8, device="cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/training/distributed.py
DELETED
|
@@ -1,149 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from collections.abc import Sequence
|
| 4 |
-
import os
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
import torch.distributed as dist
|
| 8 |
-
from torch.distributed.nn import all_gather as differentiable_all_gather
|
| 9 |
-
from torch import Tensor
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
def init_distributed() -> None:
    """Initialise the default process group when launched by a distributed runner.

    Does nothing unless both ``RANK`` and ``WORLD_SIZE`` are set in the
    environment and no process group exists yet. With CUDA available the
    backend is ``nccl`` and the current device is set to the local rank;
    otherwise the backend is ``gloo``.
    """
    launched = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    if not launched or dist.is_initialized():
        return
    if torch.cuda.is_available():
        backend = "nccl"
        # Select this process's device before the group is created.
        torch.cuda.set_device(get_local_rank())
    else:
        backend = "gloo"
    dist.init_process_group(backend=backend)
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def is_distributed() -> bool:
    """Return True when torch.distributed is available and a process group is initialised."""
    if not dist.is_available():
        return False
    return dist.is_initialized()
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def barrier() -> None:
    """Call ``dist.barrier()`` when running distributed; otherwise do nothing."""
    if not is_distributed():
        return
    dist.barrier()
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def destroy_distributed() -> None:
    """Destroy the process group if one is initialised; otherwise do nothing."""
    if not is_distributed():
        return
    dist.destroy_process_group()
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def get_rank() -> int:
    """Return this process's rank, or 0 when not running distributed."""
    if is_distributed():
        return dist.get_rank()
    return 0
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
def get_world_size() -> int:
    """Return the number of processes in the group, or 1 when not running distributed."""
    if is_distributed():
        return dist.get_world_size()
    return 1
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def get_local_rank() -> int:
    """Return the ``LOCAL_RANK`` environment variable as an int (0 when unset)."""
    raw = os.getenv("LOCAL_RANK", "0")
    return int(raw)
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
def is_main_process() -> bool:
    """Return True on rank 0 (always the case when not running distributed)."""
    rank = get_rank()
    return rank == 0
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def gather_with_grad(tensor: Tensor) -> Tensor:
    """Concatenate ``tensor`` from all ranks along dim 0 using a differentiable all-gather.

    With a world size of 1 the input is returned unchanged.
    """
    if get_world_size() == 1:
        return tensor
    shards = differentiable_all_gather(tensor.contiguous())
    return torch.cat(list(shards), dim=0)
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
def gather_variable_with_grad(tensor: Tensor) -> tuple[Tensor, Tensor]:
    """Gather tensors with variable first-dimension lengths across ranks.

    Returns the gathered tensor together with the count tensor produced by
    ``_variable_gather_metadata``. With a world size of 1 the input tensor
    is returned as-is.
    """
    # Metadata is computed before the world-size check, as in every code path.
    counts, longest, keep_mask = _variable_gather_metadata(tensor)
    if get_world_size() != 1:
        tensor = _gather_variable_from_metadata(tensor, longest, keep_mask)
    return tensor, counts
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
def gather_variable_many_with_grad(tensors: Sequence[Tensor]) -> tuple[list[Tensor], Tensor]:
    """Gather same-length variable tensors while sharing count metadata.

    Tensors with matching dtype, rank and middle dimensions (``shape[1:-1]``)
    are packed along the last dimension so a single differentiable all-gather
    can serve several feature tensors with the same variable first dimension.

    Args:
        tensors: Non-empty sequence of tensors on one device that all share
            the same first-dimension length on this rank.

    Returns:
        ``(gathered, count_tensor)`` where ``gathered`` holds one tensor per
        input, in input order, and ``count_tensor`` is the count metadata
        computed by ``_variable_gather_metadata`` from the first tensor.

    Raises:
        ValueError: If ``tensors`` is empty, contains a 0-dimensional tensor,
            mixes devices, or mixes first-dimension lengths.
    """
    if not tensors:
        raise ValueError("gather_variable_many_with_grad requires at least one tensor")
    first = tensors[0]
    for tensor in tensors:
        # Must come before any shape[0] access: a 0-dim tensor has no first
        # dimension and would otherwise surface as an IndexError.
        if tensor.dim() == 0:
            raise ValueError("variable gather tensors must have at least one dimension")
        if tensor.device != first.device:
            raise ValueError("all tensors must be on the same device")
        if tensor.shape[0] != first.shape[0]:
            raise ValueError("all tensors must have the same first dimension")
    count_tensor, max_count, keep = _variable_gather_metadata(first)
    if get_world_size() == 1:
        return list(tensors), count_tensor

    gathered: list[Tensor | None] = [None] * len(tensors)
    # Group input positions by (dtype, middle shape, rank); each group can be
    # concatenated on the last dimension and gathered in one call.
    groups: dict[tuple[torch.dtype, torch.Size, int], list[int]] = {}
    for index, tensor in enumerate(tensors):
        key = (tensor.dtype, tensor.shape[1:-1], tensor.dim()) if tensor.dim() > 1 else (tensor.dtype, torch.Size(), 1)
        groups.setdefault(key, []).append(index)

    for indices in groups.values():
        group_tensors = [tensors[index] for index in indices]
        if len(group_tensors) == 1 or group_tensors[0].dim() == 1:
            # Nothing to pack: 1-D tensors have no separate last dimension to
            # concatenate on, and a singleton group gains nothing from it.
            for index, tensor in zip(indices, group_tensors, strict=True):
                gathered[index] = _gather_variable_from_metadata(tensor, max_count, keep)
            continue
        widths = [tensor.shape[-1] for tensor in group_tensors]
        packed = torch.cat(group_tensors, dim=-1)
        gathered_packed = _gather_variable_from_metadata(packed, max_count, keep)
        # Undo the packing so every input gets back its own last-dim slice.
        for index, chunk in zip(indices, gathered_packed.split(widths, dim=-1), strict=True):
            gathered[index] = chunk

    if any(tensor is None for tensor in gathered):
        raise RuntimeError("internal error while gathering variable tensors")
    return [tensor for tensor in gathered if tensor is not None], count_tensor
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
def gather_variable_no_grad(tensor: Tensor) -> tuple[Tensor, Tensor]:
    """Gather variable-length tensors that do not require autograd.

    The local tensor is zero-padded to the largest per-rank row count,
    exchanged with ``dist.all_gather``, and the padding rows are then removed
    with the keep-mask.

    Returns:
        ``(gathered, count_tensor)``. With a world size of 1, ``gathered`` is
        the input tensor itself.
    """
    counts, padded_rows, row_mask = _variable_gather_metadata(tensor)
    if get_world_size() == 1:
        return tensor, counts
    local_rows = tensor.shape[0]
    send_buffer = tensor.new_zeros((padded_rows, *tensor.shape[1:]))
    send_buffer[:local_rows] = tensor
    receive_buffers = [torch.zeros_like(send_buffer) for _ in range(get_world_size())]
    dist.all_gather(receive_buffers, send_buffer.contiguous())
    stacked = torch.cat(receive_buffers, dim=0)
    return stacked[row_mask], counts
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
def _variable_gather_metadata(tensor: Tensor) -> tuple[Tensor, int, Tensor]:
    """Compute the bookkeeping needed to gather tensors whose first dimension differs per rank.

    Returns:
        ``(count_tensor, max_count, keep)``: the first-dimension length of
        every rank, the largest of those lengths, and a boolean mask of length
        ``world_size * max_count`` that is True for real rows and False for
        padding rows in the rank-major concatenation of padded tensors.
    """
    num_ranks = get_world_size()
    local_rows = tensor.shape[0]
    local_count = torch.tensor([local_rows], device=tensor.device, dtype=torch.long)
    if num_ranks == 1:
        all_rows = torch.ones(local_rows, device=tensor.device, dtype=torch.bool)
        return local_count, local_rows, all_rows

    per_rank_counts = [torch.zeros_like(local_count) for _ in range(num_ranks)]
    dist.all_gather(per_rank_counts, local_count)
    count_tensor = torch.cat(per_rank_counts)
    max_count = int(count_tensor.max().item())
    keep = torch.zeros(num_ranks * max_count, device=tensor.device, dtype=torch.bool)
    # Each rank owns a max_count-sized slot; only its first `count` rows are real.
    offset = 0
    for count in count_tensor.tolist():
        keep[offset : offset + count] = True
        offset += max_count
    return count_tensor, max_count, keep
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
def _gather_variable_from_metadata(tensor: Tensor, max_count: int, keep: Tensor) -> Tensor:
    """Zero-pad ``tensor`` to ``max_count`` rows, all-gather it with ``differentiable_all_gather``, and drop padding rows.

    ``keep`` is the boolean row mask applied to the concatenated result.
    """
    local_rows = tensor.shape[0]
    padded = tensor.new_zeros((max_count, *tensor.shape[1:]))
    padded[:local_rows] = tensor

    pieces = differentiable_all_gather(padded.contiguous())
    return torch.cat(list(pieces), dim=0)[keep]
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
def local_target_indices(batch_size: int, device: torch.device) -> Tensor:
    """Return ``arange(batch_size)`` on ``device``, shifted by ``batch_size * rank``."""
    local_positions = torch.arange(batch_size, device=device)
    rank_offset = batch_size * get_rank()
    return local_positions + rank_offset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/training/engine.py
DELETED
|
@@ -1,442 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from datetime import datetime, timezone
|
| 4 |
-
import json
|
| 5 |
-
import os
|
| 6 |
-
from pathlib import Path
|
| 7 |
-
import time
|
| 8 |
-
|
| 9 |
-
import torch
|
| 10 |
-
from torch import nn
|
| 11 |
-
from torch.optim import AdamW, Optimizer
|
| 12 |
-
from torch.nn.parallel import DistributedDataParallel
|
| 13 |
-
from torch.utils.data import DataLoader, DistributedSampler, IterableDataset
|
| 14 |
-
from torch.amp import GradScaler
|
| 15 |
-
|
| 16 |
-
from hyper3_clip.data import (
|
| 17 |
-
GroundedManifestDataset,
|
| 18 |
-
MixedGroundedIterableDataset,
|
| 19 |
-
ProcessedGritDataset,
|
| 20 |
-
collate_grounded,
|
| 21 |
-
)
|
| 22 |
-
from hyper3_clip.models.hyper3_clip import Hyper3CLIP
|
| 23 |
-
from hyper3_clip.training.checkpointing import latest_checkpoint, load_checkpoint, save_checkpoint
|
| 24 |
-
from hyper3_clip.training.distributed import (
|
| 25 |
-
barrier,
|
| 26 |
-
destroy_distributed,
|
| 27 |
-
get_local_rank,
|
| 28 |
-
get_rank,
|
| 29 |
-
get_world_size,
|
| 30 |
-
init_distributed,
|
| 31 |
-
is_main_process,
|
| 32 |
-
)
|
| 33 |
-
from hyper3_clip.training.logging import JsonlLogger
|
| 34 |
-
from hyper3_clip.utils.io import ensure_dir, save_yaml, set_seed
|
| 35 |
-
|
| 36 |
-
try:
|
| 37 |
-
from hypercluster.hooks import RunControl
|
| 38 |
-
except ImportError: # pragma: no cover - hypercluster is only present in cluster allocations.
|
| 39 |
-
RunControl = None
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
class CosineWithWarmup:
    """Learning-rate schedule: linear warmup followed by half-cosine decay.

    The schedule is stateless with respect to the step: callers pass the step
    index to ``step`` and the resulting rate is written into every param group
    of the wrapped optimizer.
    """

    def __init__(self, optimizer: torch.optim.Optimizer, warmup_steps: int, total_steps: int, base_lr: float) -> None:
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.base_lr = base_lr

    def step(self, step_idx: int) -> None:
        """Set the ``lr`` of every param group to the schedule value for ``step_idx``."""
        warmup = self.warmup_steps
        if step_idx >= warmup:
            # Fraction of the post-warmup span that has elapsed.
            decay_span = float(max(1, self.total_steps - warmup))
            progress = float(step_idx - warmup) / decay_span
            cosine = torch.cos(torch.tensor(progress * torch.pi)).item()
            lr = self.base_lr * 0.5 * (1.0 + cosine)
        else:
            # Ramp up so that the last warmup step reaches base_lr.
            lr = self.base_lr * float(step_idx + 1) / float(max(1, warmup))
        for group in self.optimizer.param_groups:
            group["lr"] = lr

    def state_dict(self) -> dict[str, int | float]:
        """Return the schedule hyper-parameters (the optimizer is not included)."""
        return dict(warmup_steps=self.warmup_steps, total_steps=self.total_steps, base_lr=self.base_lr)

    def load_state_dict(self, state: dict[str, int | float]) -> None:
        """Restore the hyper-parameters produced by ``state_dict``."""
        self.warmup_steps = int(state["warmup_steps"])
        self.total_steps = int(state["total_steps"])
        self.base_lr = float(state["base_lr"])
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
def _build_optimizer(model: nn.Module, cfg: dict) -> AdamW:
|
| 68 |
-
no_decay_names = set(cfg.get("optimizer", {}).get("no_decay_params", []))
|
| 69 |
-
decay_params = []
|
| 70 |
-
no_decay_params = []
|
| 71 |
-
|
| 72 |
-
for name, param in model.named_parameters():
|
| 73 |
-
if not param.requires_grad:
|
| 74 |
-
continue
|
| 75 |
-
leaf_name = name.split(".")[-1]
|
| 76 |
-
if param.ndim < 2 or leaf_name in no_decay_names or leaf_name == "bias" or "norm" in name.lower():
|
| 77 |
-
no_decay_params.append(param)
|
| 78 |
-
else:
|
| 79 |
-
decay_params.append(param)
|
| 80 |
-
|
| 81 |
-
return AdamW(
|
| 82 |
-
[
|
| 83 |
-
{"params": decay_params, "weight_decay": cfg["training"]["weight_decay"]},
|
| 84 |
-
{"params": no_decay_params, "weight_decay": 0.0},
|
| 85 |
-
],
|
| 86 |
-
lr=cfg["training"]["lr"],
|
| 87 |
-
betas=tuple(cfg["training"]["betas"]),
|
| 88 |
-
)
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
def run_training(config: dict) -> None:
    """Run the full training loop described by ``config``.

    Sets up the (optional) process group, model, data loader, optimizer,
    schedule and grad scaler, optionally resumes from a checkpoint, then
    trains until ``config["training"]["total_steps"]`` optimizer steps are
    done. Rank 0 writes ``config.yaml``, ``metadata.json``, the JSONL training
    log and checkpoints into ``config["output_dir"]``.

    Raises:
        SystemExit: with ``run_control.PAUSED_EXIT_CODE`` when a ``RunControl``
            instance is available and requests a pause.
    """
    # NOTE(review): block nesting below was reconstructed from a flattened
    # listing; confirm it against the original file before relying on it.
    init_distributed()
    # Offset the seed by rank so every process seeds its RNGs differently.
    set_seed(config["seed"] + get_rank())
    ensure_dir(config["output_dir"])
    started_at = utc_timestamp()
    if is_main_process():
        save_yaml(Path(config["output_dir"]) / "config.yaml", config)
        write_metadata(config, status="running", started_at=started_at)

    # Device selection: a per-rank CUDA device when LOCAL_RANK is set, the
    # default CUDA device otherwise, and CPU when CUDA is unavailable.
    if torch.cuda.is_available():
        if "LOCAL_RANK" in os.environ:
            device = torch.device(f"cuda:{get_local_rank()}")
            torch.cuda.set_device(device)
        else:
            device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    if device.type == "cuda":
        torch.cuda.reset_peak_memory_stats()
        torch.backends.cudnn.benchmark = bool(config["training"].get("cudnn_benchmark", False))

    raw_model = Hyper3CLIP(**config["model"]).to(device)
    channels_last = str(config["training"].get("memory_format", "")).lower() == "channels_last"
    if channels_last:
        raw_model = raw_model.to(memory_format=torch.channels_last)
    # `model` is what gets called in the loop (possibly DDP-wrapped);
    # `raw_model` stays the unwrapped module used for the optimizer,
    # tokenizer access and checkpoints.
    model: nn.Module = raw_model
    if get_world_size() > 1:
        device_ids = [get_local_rank()] if device.type == "cuda" else None
        model = DistributedDataParallel(
            raw_model,
            device_ids=device_ids,
            broadcast_buffers=False,
            find_unused_parameters=bool(config["training"].get("find_unused_parameters", False)),
        )
    dataset = _build_dataset(config["data"], config["seed"])
    sampler = _build_sampler(dataset)
    local_batch_size = _local_batch_size(config["training"])
    # Loader options are read from config["data"] first, then config["training"].
    num_workers = config["data"].get("num_workers", config["training"].get("num_workers", 4))
    dataloader_kwargs = {}
    if num_workers > 0:
        dataloader_kwargs["persistent_workers"] = bool(
            config["data"].get("persistent_workers", config["training"].get("persistent_workers", False))
        )
        prefetch_factor = config["data"].get("prefetch_factor", config["training"].get("prefetch_factor"))
        if prefetch_factor is not None:
            dataloader_kwargs["prefetch_factor"] = int(prefetch_factor)
    beta_clip_data_config = config["data"].get("beta_clip", {})
    dataloader = DataLoader(
        dataset,
        batch_size=local_batch_size,
        sampler=sampler,
        # Shuffle in the loader only when no sampler is used and the dataset is map-style.
        shuffle=sampler is None and not isinstance(dataset, IterableDataset),
        num_workers=num_workers,
        pin_memory=bool(config["data"].get("pin_memory", True)),
        drop_last=True,
        collate_fn=lambda x: collate_grounded(
            x,
            tokenizer=raw_model.text_encoder.tokenizer,
            max_text_length=config["data"]["max_text_length"],
            beta_clip_queries=bool(beta_clip_data_config.get("enabled", False)),
            beta_clip_max_sentences=int(beta_clip_data_config.get("max_sentences", 5)),
            beta_clip_max_phrases=int(beta_clip_data_config.get("max_phrases", 30)),
            beta_clip_max_queries_per_image=beta_clip_data_config.get("max_queries_per_image"),
            beta_clip_use_part_texts=bool(beta_clip_data_config.get("use_part_texts", True)),
        ),
        **dataloader_kwargs,
    )

    optimizer = _build_optimizer(model=raw_model, cfg=config)
    scheduler = CosineWithWarmup(
        optimizer=optimizer,
        warmup_steps=config["training"]["warmup_steps"],
        total_steps=config["training"]["total_steps"],
        base_lr=config["training"]["lr"],
    )
    scaler = GradScaler(device.type, enabled=config["training"]["amp"])
    start_step = _resume_step(config, raw_model, optimizer, scheduler, scaler, device)
    # RunControl is None when the optional hypercluster import failed.
    run_control = RunControl.from_env() if RunControl is not None else None

    logger = JsonlLogger(Path(config["output_dir"]) / "train_log.jsonl")

    model.train()
    # `step` counts optimizer steps; `micro_step` counts forward/backward passes.
    step = start_step
    micro_step = 0
    grad_accum_steps = max(1, int(config["training"].get("grad_accum_steps", 1)))
    non_blocking_transfer = bool(config["training"].get("non_blocking_transfer", True))
    micro_batch_global_size = local_batch_size * get_world_size()
    effective_global_batch_size = micro_batch_global_size * grad_accum_steps
    last_step_time = time.perf_counter()
    optimizer.zero_grad(set_to_none=True)
    while step < config["training"]["total_steps"]:
        if sampler is not None:
            # The current optimizer step count is used as the sampler epoch.
            sampler.set_epoch(step)
        for batch in dataloader:
            if step >= config["training"]["total_steps"]:
                break

            # Start of an accumulation window: clear gradients and set the LR for this step.
            if micro_step % grad_accum_steps == 0:
                optimizer.zero_grad(set_to_none=True)
                scheduler.step(step)

            batch = {k: v.to(device, non_blocking=non_blocking_transfer) for k, v in batch.items()}
            if channels_last:
                batch["image"] = batch["image"].contiguous(memory_format=torch.channels_last)
                batch["part_images"] = batch["part_images"].contiguous(memory_format=torch.channels_last)

            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=config["training"]["amp"]):
                out = model(**batch, step=step)
                # Scale the loss so accumulated gradients average over the window.
                loss = out["loss"] / grad_accum_steps

            scaler.scale(loss).backward()
            micro_step += 1
            # Keep accumulating until the window is complete.
            if micro_step % grad_accum_steps != 0:
                continue

            if config["training"]["max_grad_norm"] > 0:
                # Gradients are unscaled before clipping so the norm is measured on true values.
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config["training"]["max_grad_norm"])
            else:
                grad_norm = None
            scaler.step(optimizer)
            scaler.update()

            completed_steps = step + 1
            now = time.perf_counter()
            step_time_seconds = now - last_step_time
            last_step_time = now

            # The row is built on every rank; only rank 0 writes and prints it.
            if completed_steps == 1 or completed_steps % config["training"]["log_interval"] == 0:
                remaining_steps = config["training"]["total_steps"] - completed_steps
                row = {
                    "timestamp": datetime.now(timezone.utc).replace(microsecond=0).isoformat(),
                    "step": completed_steps,
                    "loss": float(out["loss"].detach().cpu().item()),
                    "contrastive_loss": float(out["contrastive_loss"].detach().cpu().item()),
                    "entailment_loss": float(out["entailment_loss"].detach().cpu().item()),
                    "part_count": int(out["part_count"].detach().cpu().item()),
                    "kappa": float(out["kappa"].detach().cpu().item()),
                    "lr": optimizer.param_groups[0]["lr"],
                    "grad_norm": None if grad_norm is None else float(grad_norm.detach().cpu().item()),
                    "step_time_seconds": step_time_seconds,
                    "steps_per_second": 1.0 / max(step_time_seconds, 1e-12),
                    "samples_per_second": effective_global_batch_size / max(step_time_seconds, 1e-12),
                    "samples_seen": completed_steps * effective_global_batch_size,
                    "progress": completed_steps / config["training"]["total_steps"],
                    "eta_seconds": remaining_steps * step_time_seconds,
                    "rank": get_rank(),
                    "world_size": get_world_size(),
                    "local_batch_size": local_batch_size,
                    "micro_batch_global_size": micro_batch_global_size,
                    "global_batch_size": effective_global_batch_size,
                    "grad_accum_steps": grad_accum_steps,
                }
                if device.type == "cuda":
                    row["cuda_max_memory_allocated_mb"] = torch.cuda.max_memory_allocated() / (1024**2)
                # Also log any other single-element tensors the model returned.
                for key, value in out.items():
                    if key in row or key == "loss":
                        continue
                    if torch.is_tensor(value) and value.numel() == 1:
                        row[key] = _scalar_log_value(value)
                if is_main_process():
                    logger.write(row)
                    print(_format_log_row(row), flush=True)

            if is_main_process() and completed_steps > 0 and completed_steps % config["training"]["ckpt_interval"] == 0:
                ckpt_path = str(Path(config["output_dir"]) / f"checkpoint_step_{completed_steps}.pt")
                save_checkpoint(ckpt_path, completed_steps, raw_model, optimizer, scheduler, scaler, config)

            step = completed_steps
            # Pause request: rank 0 checkpoints and records the pause, then all ranks exit.
            if run_control is not None and run_control.should_pause():
                ckpt_path = str(Path(config["output_dir"]) / f"checkpoint_step_{completed_steps}.pt")
                if is_main_process():
                    save_checkpoint(ckpt_path, completed_steps, raw_model, optimizer, scheduler, scaler, config)
                    (Path(config["output_dir"]) / "latest_checkpoint.txt").write_text(f"{ckpt_path}\n", encoding="utf-8")
                    run_control.report_checkpoint(ckpt_path)
                    write_metadata(config, status="paused", started_at=started_at, ended_at=utc_timestamp(), final_step=completed_steps)
                barrier()
                destroy_distributed()
                raise SystemExit(run_control.PAUSED_EXIT_CODE)

    # Normal completion: wait for all ranks, let rank 0 write the final
    # checkpoint and metadata, then tear the process group down.
    barrier()
    if is_main_process():
        final_ckpt = str(Path(config["output_dir"]) / "checkpoint_final.pt")
        save_checkpoint(final_ckpt, step, raw_model, optimizer, scheduler, scaler, config)
        write_metadata(config, status="completed", started_at=started_at, ended_at=utc_timestamp(), final_step=step)
    barrier()
    destroy_distributed()
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
def utc_timestamp() -> str:
    """Return the current UTC time as an ISO-8601 string truncated to whole seconds."""
    now = datetime.now(timezone.utc)
    whole_seconds = now.replace(microsecond=0)
    return whole_seconds.isoformat()
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
def write_metadata(
    config: dict,
    *,
    status: str,
    started_at: str,
    ended_at: str | None = None,
    final_step: int | None = None,
) -> None:
    """Write ``metadata.json`` for the run into ``config["output_dir"]``.

    The file records the run identity from ``config["project"]``, the given
    status and timestamps, a few tags from the data/model config, scheduler
    job details read from environment variables, and the hostname, world size
    and rank of the calling process. An existing file is overwritten.
    """

    def _first_env(*names: str) -> str | None:
        # Mirrors an `a or b or c` chain: the first non-empty value wins,
        # otherwise whatever the last lookup returned.
        value = None
        for name in names:
            value = os.environ.get(name)
            if value:
                return value
        return value

    run_id = config["project"]["experiment"]
    experiment_name = config["project"]["name"]
    tags = {
        "data": config.get("data", {}).get("type", "unknown"),
        "model": config.get("model", {}).get("vision_backbone", "unknown"),
        "objective": config.get("model", {}).get("objective", "hycoclip"),
    }
    job = {
        "job_id": _first_env("JOB_ID", "SCHEDULER_JOB_ID", "SLURM_JOB_ID"),
        "partition": _first_env("JOB_PARTITION", "SCHEDULER_PARTITION", "SLURM_JOB_PARTITION"),
        "num_nodes": _first_env("NUM_NODES", "SLURM_JOB_NUM_NODES"),
        "node_list": _first_env("NODE_LIST", "SLURM_JOB_NODELIST"),
        "gpus": _first_env("GPU_DEVICES", "SLURM_JOB_GPUS", "SLURM_GPUS"),
    }
    env = {
        "hostname": os.environ.get("HOSTNAME"),
        "world_size": str(get_world_size()),
        "rank": str(get_rank()),
    }
    metadata = {
        "run_id": run_id,
        "experiment_name": experiment_name,
        "status": status,
        "start_time": started_at,
        "end_time": ended_at,
        "final_step": final_step,
        "tags": tags,
        "job": job,
        "env": env,
    }
    destination = Path(config["output_dir"]) / "metadata.json"
    serialized = json.dumps(metadata, indent=2, sort_keys=True) + "\n"
    destination.write_text(serialized, encoding="utf-8")
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
def _build_dataset(data_config: dict, seed: int) -> GroundedManifestDataset | ProcessedGritDataset | MixedGroundedIterableDataset:
|
| 324 |
-
data_type = data_config.get("type")
|
| 325 |
-
if data_type is None:
|
| 326 |
-
data_type = "processed_grit" if data_config.get("tarfiles") else "manifest"
|
| 327 |
-
if data_type == "manifest":
|
| 328 |
-
manifests = data_config.get("manifests") or data_config.get("manifest")
|
| 329 |
-
if manifests is None:
|
| 330 |
-
raise ValueError("Manifest training requires data.manifests or data.manifest")
|
| 331 |
-
return GroundedManifestDataset(
|
| 332 |
-
manifests=manifests,
|
| 333 |
-
image_size=data_config["image_size"],
|
| 334 |
-
seed=seed,
|
| 335 |
-
manifest_weights=data_config.get("manifest_weights"),
|
| 336 |
-
part_sampling=data_config.get("part_sampling", "random_one"),
|
| 337 |
-
max_parts=data_config.get("max_parts"),
|
| 338 |
-
train_transform=data_config.get("train_transform", "wide_random_crop"),
|
| 339 |
-
image_normalization=data_config.get("image_normalization", "imagenet"),
|
| 340 |
-
)
|
| 341 |
-
if data_type == "processed_grit":
|
| 342 |
-
return ProcessedGritDataset(
|
| 343 |
-
tarfiles=data_config["tarfiles"],
|
| 344 |
-
image_size=data_config["image_size"],
|
| 345 |
-
seed=seed,
|
| 346 |
-
shuffle_buffer=data_config.get("shuffle_buffer", 4000),
|
| 347 |
-
part_sampling=data_config.get("part_sampling", "random_one"),
|
| 348 |
-
max_parts=data_config.get("max_parts"),
|
| 349 |
-
train_transform=data_config.get("train_transform", "wide_random_crop"),
|
| 350 |
-
image_normalization=data_config.get("image_normalization", "imagenet"),
|
| 351 |
-
deterministic_transforms=data_config.get("deterministic_transforms", False),
|
| 352 |
-
)
|
| 353 |
-
if data_type == "mixed_processed_grit_manifest":
|
| 354 |
-
manifest_config = data_config.get("manifest_data", {})
|
| 355 |
-
manifests = manifest_config.get("manifests") or manifest_config.get("manifest") or data_config.get("manifests")
|
| 356 |
-
if manifests is None:
|
| 357 |
-
raise ValueError("Mixed GRIT+manifest training requires data.manifest_data.manifests")
|
| 358 |
-
primary = ProcessedGritDataset(
|
| 359 |
-
tarfiles=data_config["tarfiles"],
|
| 360 |
-
image_size=data_config["image_size"],
|
| 361 |
-
seed=seed,
|
| 362 |
-
shuffle_buffer=data_config.get("shuffle_buffer", 4000),
|
| 363 |
-
part_sampling=data_config.get("part_sampling", "random_one"),
|
| 364 |
-
max_parts=data_config.get("max_parts"),
|
| 365 |
-
train_transform=data_config.get("train_transform", "wide_random_crop"),
|
| 366 |
-
image_normalization=data_config.get("image_normalization", "imagenet"),
|
| 367 |
-
deterministic_transforms=data_config.get("deterministic_transforms", False),
|
| 368 |
-
)
|
| 369 |
-
auxiliary = GroundedManifestDataset(
|
| 370 |
-
manifests=manifests,
|
| 371 |
-
image_size=manifest_config.get("image_size", data_config["image_size"]),
|
| 372 |
-
seed=seed + 47,
|
| 373 |
-
manifest_weights=manifest_config.get("manifest_weights"),
|
| 374 |
-
part_sampling=manifest_config.get("part_sampling", data_config.get("manifest_part_sampling", "all")),
|
| 375 |
-
max_parts=manifest_config.get("max_parts", data_config.get("manifest_max_parts")),
|
| 376 |
-
train_transform=manifest_config.get("train_transform", data_config.get("train_transform", "wide_random_crop")),
|
| 377 |
-
image_normalization=manifest_config.get("image_normalization", data_config.get("image_normalization", "imagenet")),
|
| 378 |
-
)
|
| 379 |
-
return MixedGroundedIterableDataset(
|
| 380 |
-
primary=primary,
|
| 381 |
-
auxiliary=auxiliary,
|
| 382 |
-
auxiliary_probability=float(data_config.get("manifest_probability", 0.15)),
|
| 383 |
-
seed=seed,
|
| 384 |
-
)
|
| 385 |
-
raise ValueError(f"Unsupported data.type {data_type!r}")
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
def _build_sampler(dataset: GroundedManifestDataset | ProcessedGritDataset | MixedGroundedIterableDataset) -> DistributedSampler | None:
    """Return a shuffling ``DistributedSampler`` for map-style datasets in multi-process runs.

    Returns ``None`` for single-process runs and for ``IterableDataset`` inputs.
    """
    single_process = get_world_size() == 1
    if single_process or isinstance(dataset, IterableDataset):
        return None
    return DistributedSampler(
        dataset,
        num_replicas=get_world_size(),
        rank=get_rank(),
        shuffle=True,
        drop_last=True,
    )
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
def _local_batch_size(training_config: dict) -> int:
|
| 395 |
-
if "batch_size" in training_config:
|
| 396 |
-
return int(training_config["batch_size"])
|
| 397 |
-
global_batch_size = int(training_config["global_batch_size"])
|
| 398 |
-
if global_batch_size % get_world_size() != 0:
|
| 399 |
-
raise ValueError("training.global_batch_size must be divisible by world size")
|
| 400 |
-
return global_batch_size // get_world_size()
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
def _resume_step(
|
| 404 |
-
config: dict,
|
| 405 |
-
model: nn.Module,
|
| 406 |
-
optimizer: Optimizer,
|
| 407 |
-
scheduler: CosineWithWarmup,
|
| 408 |
-
scaler: GradScaler,
|
| 409 |
-
device: torch.device,
|
| 410 |
-
) -> int:
|
| 411 |
-
training_config = config["training"]
|
| 412 |
-
resume_env = training_config.get("resume_from_env", "RESUME_FROM_CHECKPOINT")
|
| 413 |
-
resume_path = os.environ.get(str(resume_env)) if resume_env else None
|
| 414 |
-
if resume_path is None:
|
| 415 |
-
resume_path = training_config.get("resume_from")
|
| 416 |
-
if resume_path is None and training_config.get("resume", False):
|
| 417 |
-
resume_path = latest_checkpoint(config["output_dir"])
|
| 418 |
-
if resume_path is None:
|
| 419 |
-
return 0
|
| 420 |
-
return load_checkpoint(
|
| 421 |
-
resume_path,
|
| 422 |
-
model,
|
| 423 |
-
optimizer,
|
| 424 |
-
scheduler,
|
| 425 |
-
scaler,
|
| 426 |
-
device,
|
| 427 |
-
model_only=bool(training_config.get("resume_model_only", False)),
|
| 428 |
-
strict_model=bool(training_config.get("resume_strict_model", True)),
|
| 429 |
-
)
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
def _format_log_row(row: dict) -> str:
|
| 433 |
-
return " ".join(f"{key}={value}" for key, value in row.items())
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
def _scalar_log_value(value: torch.Tensor) -> float | int:
|
| 437 |
-
detached = value.detach().cpu()
|
| 438 |
-
if detached.dtype == torch.bool:
|
| 439 |
-
return int(detached.item())
|
| 440 |
-
if detached.dtype in (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8):
|
| 441 |
-
return int(detached.item())
|
| 442 |
-
return float(detached.item())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/training/logging.py
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import json
|
| 4 |
-
from pathlib import Path
|
| 5 |
-
from typing import Any
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
class JsonlLogger:
    """Append-only JSON Lines writer.

    The parent directory is created on construction. Every ``write`` call
    opens the file in append mode, writes one JSON object followed by a
    newline, and closes the file again.
    """

    def __init__(self, path: str | Path) -> None:
        self.path = Path(path)
        log_dir = self.path.parent
        log_dir.mkdir(parents=True, exist_ok=True)

    def write(self, row: dict[str, Any]) -> None:
        """Append ``row`` to the log as a single JSON line."""
        with self.path.open("a", encoding="utf-8") as handle:
            encoded = json.dumps(row)
            handle.write(encoded + "\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/utils/io.py
DELETED
|
@@ -1,29 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
import random
|
| 5 |
-
|
| 6 |
-
import numpy as np
|
| 7 |
-
import torch
|
| 8 |
-
import yaml
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def load_yaml(path: str) -> dict:
    """Parse the UTF-8 YAML file at ``path`` with ``yaml.safe_load`` and return the result."""
    source = Path(path)
    with source.open("r", encoding="utf-8") as stream:
        payload = yaml.safe_load(stream)
    return payload
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def save_yaml(path: str | Path, payload: dict) -> None:
    """Write ``payload`` as UTF-8 YAML to ``path``, preserving key order."""
    destination = Path(path)
    with destination.open("w", encoding="utf-8") as stream:
        yaml.safe_dump(payload, stream, sort_keys=False)
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def ensure_dir(path: str) -> None:
    """Create ``path`` and any missing parents; an existing directory is left as is."""
    directory = Path(path)
    directory.mkdir(parents=True, exist_ok=True)
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices) with ``seed``."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip_provider.py
DELETED
|
@@ -1,133 +0,0 @@
|
|
| 1 |
-
"""HyperView embedding provider for the Hyper3-CLIP v0.5 HF checkpoint."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import os
|
| 6 |
-
from pathlib import Path
|
| 7 |
-
from typing import Any
|
| 8 |
-
|
| 9 |
-
import numpy as np
|
| 10 |
-
import torch
|
| 11 |
-
import yaml
|
| 12 |
-
from huggingface_hub import snapshot_download
|
| 13 |
-
from lancedb.embeddings import EmbeddingFunction
|
| 14 |
-
from pydantic import PrivateAttr
|
| 15 |
-
from safetensors.torch import load_file
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class Hyper3ClipEmbeddings(EmbeddingFunction):
|
| 19 |
-
"""Image embeddings from Hyper3-CLIP v0.5 in Lorentz/hyperboloid space."""
|
| 20 |
-
|
| 21 |
-
name: str = "hyper3labs/hyper3-clip-v0.5"
|
| 22 |
-
batch_size: int = 8
|
| 23 |
-
device: str = "cpu"
|
| 24 |
-
|
| 25 |
-
_model: Any = PrivateAttr(default=None)
|
| 26 |
-
_transform: Any = PrivateAttr(default=None)
|
| 27 |
-
|
| 28 |
-
@property
|
| 29 |
-
def geometry(self) -> str:
|
| 30 |
-
return "hyperboloid"
|
| 31 |
-
|
| 32 |
-
@property
|
| 33 |
-
def curvature(self) -> float:
|
| 34 |
-
self._ensure_model()
|
| 35 |
-
return float(self._model._kappa().detach().cpu().reshape(-1)[0].item())
|
| 36 |
-
|
| 37 |
-
def ndims(self) -> int:
|
| 38 |
-
return 513
|
| 39 |
-
|
| 40 |
-
def _ensure_model(self) -> None:
|
| 41 |
-
if self._model is not None:
|
| 42 |
-
return
|
| 43 |
-
|
| 44 |
-
from hyper3_clip import Hyper3CLIP
|
| 45 |
-
from torchvision import transforms
|
| 46 |
-
|
| 47 |
-
token = os.environ.get("HF_TOKEN")
|
| 48 |
-
local_dir = snapshot_download(
|
| 49 |
-
self.name,
|
| 50 |
-
allow_patterns=["config.yaml", "model.safetensors"],
|
| 51 |
-
token=token,
|
| 52 |
-
)
|
| 53 |
-
root = Path(local_dir)
|
| 54 |
-
config = yaml.safe_load((root / "config.yaml").read_text(encoding="utf-8"))
|
| 55 |
-
|
| 56 |
-
model = Hyper3CLIP(**config["model"])
|
| 57 |
-
state = load_file(root / "model.safetensors", device="cpu")
|
| 58 |
-
state = _normalize_checkpoint_keys(state, model)
|
| 59 |
-
model.load_state_dict(state)
|
| 60 |
-
model.to(torch.device(self.device))
|
| 61 |
-
model.eval()
|
| 62 |
-
|
| 63 |
-
self._model = model
|
| 64 |
-
image_size = int(config.get("data", {}).get("image_size", 224))
|
| 65 |
-
self._transform = transforms.Compose(
|
| 66 |
-
[
|
| 67 |
-
transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC),
|
| 68 |
-
transforms.CenterCrop(image_size),
|
| 69 |
-
transforms.ToTensor(),
|
| 70 |
-
transforms.Normalize(
|
| 71 |
-
mean=(0.485, 0.456, 0.406),
|
| 72 |
-
std=(0.229, 0.224, 0.225),
|
| 73 |
-
),
|
| 74 |
-
]
|
| 75 |
-
)
|
| 76 |
-
|
| 77 |
-
def compute_source_embeddings(
|
| 78 |
-
self,
|
| 79 |
-
inputs: Any,
|
| 80 |
-
*args: Any,
|
| 81 |
-
**kwargs: Any,
|
| 82 |
-
) -> list[np.ndarray | None]:
|
| 83 |
-
from PIL import Image
|
| 84 |
-
from hyperview.core.sample import Sample
|
| 85 |
-
|
| 86 |
-
self._ensure_model()
|
| 87 |
-
device = torch.device(self.device)
|
| 88 |
-
images = []
|
| 89 |
-
for item in self.sanitize_input(inputs):
|
| 90 |
-
if isinstance(item, Sample):
|
| 91 |
-
with item.load_image() as img:
|
| 92 |
-
images.append(img.convert("RGB"))
|
| 93 |
-
elif isinstance(item, str):
|
| 94 |
-
with Image.open(item) as img:
|
| 95 |
-
images.append(img.convert("RGB"))
|
| 96 |
-
elif isinstance(item, Image.Image):
|
| 97 |
-
images.append(item.convert("RGB"))
|
| 98 |
-
else:
|
| 99 |
-
raise TypeError(f"Unsupported input type: {type(item)}")
|
| 100 |
-
|
| 101 |
-
outputs: list[np.ndarray | None] = []
|
| 102 |
-
with torch.inference_mode():
|
| 103 |
-
for start in range(0, len(images), self.batch_size):
|
| 104 |
-
batch = images[start:start + self.batch_size]
|
| 105 |
-
tensor = torch.stack([self._transform(image) for image in batch]).to(device)
|
| 106 |
-
encoded = self._model.encode_image(tensor).detach().cpu().numpy().astype(np.float32)
|
| 107 |
-
outputs.extend(encoded)
|
| 108 |
-
return outputs
|
| 109 |
-
|
| 110 |
-
def compute_query_embeddings(
|
| 111 |
-
self,
|
| 112 |
-
query: Any,
|
| 113 |
-
*args: Any,
|
| 114 |
-
**kwargs: Any,
|
| 115 |
-
) -> list[np.ndarray | None]:
|
| 116 |
-
return self.compute_source_embeddings([query], *args, **kwargs)
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
def _normalize_checkpoint_keys(state: dict[str, torch.Tensor], model: torch.nn.Module) -> dict[str, torch.Tensor]:
|
| 120 |
-
"""Handle CLIPTextModel wrapper key drift between training and Space runtime."""
|
| 121 |
-
model_keys = set(model.state_dict())
|
| 122 |
-
old_prefix = "text_encoder.backbone.text_model."
|
| 123 |
-
new_prefix = "text_encoder.backbone."
|
| 124 |
-
if not any(key.startswith(old_prefix) for key in state):
|
| 125 |
-
return state
|
| 126 |
-
if any(key.startswith(old_prefix) for key in model_keys):
|
| 127 |
-
return state
|
| 128 |
-
|
| 129 |
-
normalized: dict[str, torch.Tensor] = {}
|
| 130 |
-
for key, value in state.items():
|
| 131 |
-
candidate = new_prefix + key[len(old_prefix):] if key.startswith(old_prefix) else key
|
| 132 |
-
normalized[candidate if candidate in model_keys else key] = value
|
| 133 |
-
return normalized
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|