github-actions[bot] commited on
Commit ·
ccf4b11
1
Parent(s): 1779c3f
Deploy hyper3labs/HyperView-ABO-Catalog from Hyper3Labs/hyperview-spaces@fd3578c
Browse files- Dockerfile +4 -8
- README.md +9 -8
- demo.py +2 -14
- hyper3_clip/__init__.py +0 -3
- hyper3_clip/models/__init__.py +0 -3
- hyper3_clip/models/encoders.py +0 -173
- hyper3_clip/models/experimental.py +0 -587
- hyper3_clip/models/himo.py +0 -55
- hyper3_clip/models/hyper3_clip.py +0 -958
- hyper3_clip/models/lorentz.py +0 -265
- hyper3_clip/models/losses.py +0 -1400
- hyper3_clip/models/objectives.py +0 -580
- hyper3_clip/models/tren.py +0 -255
- hyper3_clip/training/__init__.py +0 -1
- hyper3_clip/training/distributed.py +0 -149
- hyper3_clip_provider.py +0 -115
Dockerfile
CHANGED
|
@@ -20,7 +20,8 @@ WORKDIR $HOME/app
|
|
| 20 |
|
| 21 |
RUN pip install --upgrade pip
|
| 22 |
|
| 23 |
-
ARG HYPERVIEW_VERSION=0.6.
|
|
|
|
| 24 |
|
| 25 |
# Install CPU-only PyTorch first so the Space does not pull the default CUDA bundle.
|
| 26 |
RUN pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
|
|
@@ -33,14 +34,9 @@ import hyperview as hv
|
|
| 33 |
print("hyperview", hv.__version__, inspect.signature(hv.launch))
|
| 34 |
PY
|
| 35 |
RUN pip install \
|
|
|
|
| 36 |
"datasets>=4.5.0" \
|
| 37 |
-
"Pillow>=12.0.0"
|
| 38 |
-
"timm>=1.0.0" \
|
| 39 |
-
"transformers==4.49.0" \
|
| 40 |
-
"safetensors>=0.4.0" \
|
| 41 |
-
"pyyaml>=6.0.0" \
|
| 42 |
-
"sentencepiece>=0.2.0" \
|
| 43 |
-
"protobuf>=4.25.0"
|
| 44 |
|
| 45 |
COPY --chown=user . .
|
| 46 |
|
|
|
|
| 20 |
|
| 21 |
RUN pip install --upgrade pip
|
| 22 |
|
| 23 |
+
ARG HYPERVIEW_VERSION=0.6.1
|
| 24 |
+
ARG HYPER_MODELS_VERSION=0.3.0
|
| 25 |
|
| 26 |
# Install CPU-only PyTorch first so the Space does not pull the default CUDA bundle.
|
| 27 |
RUN pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
|
|
|
|
| 34 |
print("hyperview", hv.__version__, inspect.signature(hv.launch))
|
| 35 |
PY
|
| 36 |
RUN pip install \
|
| 37 |
+
"hyper-models[ml]==${HYPER_MODELS_VERSION}" \
|
| 38 |
"datasets>=4.5.0" \
|
| 39 |
+
"Pillow>=12.0.0"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
COPY --chown=user . .
|
| 42 |
|
README.md
CHANGED
|
@@ -14,7 +14,7 @@ This demo builds a small Amazon Berkeley Objects product-catalog subset and open
|
|
| 14 |
HyperView with two pinned scatter panels plus a comparison readout:
|
| 15 |
|
| 16 |
- CLIP ViT-B/32 in a Euclidean 2D layout
|
| 17 |
-
- Hyper3-CLIP `
|
| 18 |
|
| 19 |
The right-side panel uses fixed product examples to compare nearest-neighbor
|
| 20 |
behavior for the same query under each model.
|
|
@@ -45,7 +45,7 @@ variables or edit the second entry in `MODEL_SPECS`:
|
|
| 45 |
|
| 46 |
```bash
|
| 47 |
ABO_CANDIDATE_DISPLAY_NAME="New Model" \
|
| 48 |
-
ABO_CANDIDATE_PROVIDER="
|
| 49 |
ABO_CANDIDATE_MODEL="new-model-id" \
|
| 50 |
ABO_CANDIDATE_LAYOUT="poincare:2d" \
|
| 51 |
ABO_CANDIDATE_GEOMETRY="poincare" \
|
|
@@ -61,10 +61,11 @@ JavaScript.
|
|
| 61 |
This folder is intended to deploy to `hyper3labs/HyperView-ABO-Catalog` from
|
| 62 |
the `hyperview-spaces` deployment repository.
|
| 63 |
|
| 64 |
-
The Dockerfile installs `hyperview==0.6.
|
| 65 |
-
wheel includes the built frontend assets, so this
|
| 66 |
-
`static/` bundle or copy frontend files into the
|
|
|
|
| 67 |
|
| 68 |
-
Hyper3-CLIP weights are loaded
|
| 69 |
-
`hyper3labs/hyper3-clip-v0.5` model repository at runtime. The Space needs
|
| 70 |
-
`HF_TOKEN` secret with access to that model.
|
|
|
|
| 14 |
HyperView with two pinned scatter panels plus a comparison readout:
|
| 15 |
|
| 16 |
- CLIP ViT-B/32 in a Euclidean 2D layout
|
| 17 |
+
- Hyper3-CLIP `hyper3-clip-v0.5` from `hyper-models` in a Poincare 2D layout
|
| 18 |
|
| 19 |
The right-side panel uses fixed product examples to compare nearest-neighbor
|
| 20 |
behavior for the same query under each model.
|
|
|
|
| 45 |
|
| 46 |
```bash
|
| 47 |
ABO_CANDIDATE_DISPLAY_NAME="New Model" \
|
| 48 |
+
ABO_CANDIDATE_PROVIDER="hyper-models" \
|
| 49 |
ABO_CANDIDATE_MODEL="new-model-id" \
|
| 50 |
ABO_CANDIDATE_LAYOUT="poincare:2d" \
|
| 51 |
ABO_CANDIDATE_GEOMETRY="poincare" \
|
|
|
|
| 61 |
This folder is intended to deploy to `hyper3labs/HyperView-ABO-Catalog` from
|
| 62 |
the `hyperview-spaces` deployment repository.
|
| 63 |
|
| 64 |
+
The Dockerfile installs `hyperview==0.6.1` and `hyper-models[ml]==0.3.0` from
|
| 65 |
+
PyPI. The released HyperView wheel includes the built frontend assets, so this
|
| 66 |
+
Space does not carry a local `static/` bundle or copy frontend files into the
|
| 67 |
+
installed package.
|
| 68 |
|
| 69 |
+
Hyper3-CLIP weights are loaded through the `hyper-models` catalog entry for the
|
| 70 |
+
gated `hyper3labs/hyper3-clip-v0.5` model repository at runtime. The Space needs
|
| 71 |
+
an `HF_TOKEN` secret with access to that model.
|
demo.py
CHANGED
|
@@ -64,8 +64,8 @@ MODEL_SPECS = [
|
|
| 64 |
"key": "candidate",
|
| 65 |
"display_name": os.environ.get("ABO_CANDIDATE_DISPLAY_NAME", "Hyper3-CLIP"),
|
| 66 |
"button_label": os.environ.get("ABO_CANDIDATE_BUTTON_LABEL", "Hyper3-CLIP query"),
|
| 67 |
-
"provider": os.environ.get("ABO_CANDIDATE_PROVIDER", "
|
| 68 |
-
"model": os.environ.get("ABO_CANDIDATE_MODEL", "
|
| 69 |
"layout": os.environ.get("ABO_CANDIDATE_LAYOUT", "poincare:2d"),
|
| 70 |
"geometry": os.environ.get("ABO_CANDIDATE_GEOMETRY", "poincare"),
|
| 71 |
"layout_dimension": int(os.environ.get("ABO_CANDIDATE_LAYOUT_DIMENSION", "2")),
|
|
@@ -341,17 +341,6 @@ def supported_kwargs(func: Any, kwargs: dict[str, Any]) -> dict[str, Any]:
|
|
| 341 |
return {key: value for key, value in kwargs.items() if key in params}
|
| 342 |
|
| 343 |
|
| 344 |
-
def register_hyper3_clip_provider() -> None:
|
| 345 |
-
from hyperview.runtime import ProviderRegistry
|
| 346 |
-
|
| 347 |
-
ProviderRegistry().register_python(
|
| 348 |
-
"hyper3-clip",
|
| 349 |
-
"hyper3_clip_provider:Hyper3ClipEmbeddings",
|
| 350 |
-
description="Hyper3-CLIP v0.5 image embeddings from hyper3labs/hyper3-clip-v0.5",
|
| 351 |
-
overwrite=True,
|
| 352 |
-
)
|
| 353 |
-
|
| 354 |
-
|
| 355 |
def api_base_url() -> str:
|
| 356 |
host = "127.0.0.1" if SPACE_HOST == "0.0.0.0" else SPACE_HOST
|
| 357 |
return f"http://{host}:{SPACE_PORT}"
|
|
@@ -500,7 +489,6 @@ def launch_demo(dataset: hv.Dataset, layouts: dict[str, str]) -> hv.Session:
|
|
| 500 |
|
| 501 |
|
| 502 |
def main() -> None:
|
| 503 |
-
register_hyper3_clip_provider()
|
| 504 |
dataset, layouts = build_dataset()
|
| 505 |
print("Layouts:", flush=True)
|
| 506 |
for spec in MODEL_SPECS:
|
|
|
|
| 64 |
"key": "candidate",
|
| 65 |
"display_name": os.environ.get("ABO_CANDIDATE_DISPLAY_NAME", "Hyper3-CLIP"),
|
| 66 |
"button_label": os.environ.get("ABO_CANDIDATE_BUTTON_LABEL", "Hyper3-CLIP query"),
|
| 67 |
+
"provider": os.environ.get("ABO_CANDIDATE_PROVIDER", "hyper-models"),
|
| 68 |
+
"model": os.environ.get("ABO_CANDIDATE_MODEL", "hyper3-clip-v0.5"),
|
| 69 |
"layout": os.environ.get("ABO_CANDIDATE_LAYOUT", "poincare:2d"),
|
| 70 |
"geometry": os.environ.get("ABO_CANDIDATE_GEOMETRY", "poincare"),
|
| 71 |
"layout_dimension": int(os.environ.get("ABO_CANDIDATE_LAYOUT_DIMENSION", "2")),
|
|
|
|
| 341 |
return {key: value for key, value in kwargs.items() if key in params}
|
| 342 |
|
| 343 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 344 |
def api_base_url() -> str:
|
| 345 |
host = "127.0.0.1" if SPACE_HOST == "0.0.0.0" else SPACE_HOST
|
| 346 |
return f"http://{host}:{SPACE_PORT}"
|
|
|
|
| 489 |
|
| 490 |
|
| 491 |
def main() -> None:
|
|
|
|
| 492 |
dataset, layouts = build_dataset()
|
| 493 |
print("Layouts:", flush=True)
|
| 494 |
for spec in MODEL_SPECS:
|
hyper3_clip/__init__.py
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
from hyper3_clip.models.hyper3_clip import Hyper3CLIP
|
| 2 |
-
|
| 3 |
-
__all__ = ["Hyper3CLIP"]
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/models/__init__.py
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
from hyper3_clip.models.hyper3_clip import Hyper3CLIP
|
| 2 |
-
|
| 3 |
-
__all__ = ["Hyper3CLIP"]
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/models/encoders.py
DELETED
|
@@ -1,173 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import timm
|
| 4 |
-
import torch
|
| 5 |
-
from torch import nn
|
| 6 |
-
from transformers import (
|
| 7 |
-
AutoConfig,
|
| 8 |
-
AutoModel,
|
| 9 |
-
AutoTokenizer,
|
| 10 |
-
CLIPTextConfig,
|
| 11 |
-
CLIPTextModel,
|
| 12 |
-
CLIPTextModelWithProjection,
|
| 13 |
-
CLIPVisionConfig,
|
| 14 |
-
CLIPVisionModel,
|
| 15 |
-
CLIPVisionModelWithProjection,
|
| 16 |
-
SiglipTextConfig,
|
| 17 |
-
SiglipTextModel,
|
| 18 |
-
SiglipVisionConfig,
|
| 19 |
-
SiglipVisionModel,
|
| 20 |
-
)
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
class VisionEncoder(nn.Module):
|
| 24 |
-
def __init__(self, backbone_name: str, pretrained: bool = True) -> None:
|
| 25 |
-
super().__init__()
|
| 26 |
-
self.kind = "timm"
|
| 27 |
-
if backbone_name.startswith("hf_clip_projected:"):
|
| 28 |
-
self.kind = "hf_clip_projected"
|
| 29 |
-
model_name = backbone_name.removeprefix("hf_clip_projected:")
|
| 30 |
-
self.backbone = (
|
| 31 |
-
CLIPVisionModelWithProjection.from_pretrained(model_name)
|
| 32 |
-
if pretrained
|
| 33 |
-
else CLIPVisionModelWithProjection(CLIPVisionConfig.from_pretrained(model_name))
|
| 34 |
-
)
|
| 35 |
-
self.output_dim = self.backbone.config.projection_dim
|
| 36 |
-
elif backbone_name.startswith("hf_clip:"):
|
| 37 |
-
self.kind = "hf_vision"
|
| 38 |
-
model_name = backbone_name.removeprefix("hf_clip:")
|
| 39 |
-
self.backbone = (
|
| 40 |
-
CLIPVisionModel.from_pretrained(model_name)
|
| 41 |
-
if pretrained
|
| 42 |
-
else CLIPVisionModel(CLIPVisionConfig.from_pretrained(model_name))
|
| 43 |
-
)
|
| 44 |
-
self.output_dim = self.backbone.config.hidden_size
|
| 45 |
-
elif backbone_name.startswith("hf_siglip:"):
|
| 46 |
-
self.kind = "hf_vision"
|
| 47 |
-
model_name = backbone_name.removeprefix("hf_siglip:")
|
| 48 |
-
self.backbone = (
|
| 49 |
-
SiglipVisionModel.from_pretrained(model_name)
|
| 50 |
-
if pretrained
|
| 51 |
-
else SiglipVisionModel(SiglipVisionConfig.from_pretrained(model_name))
|
| 52 |
-
)
|
| 53 |
-
self.output_dim = self.backbone.config.hidden_size
|
| 54 |
-
else:
|
| 55 |
-
self.backbone = timm.create_model(
|
| 56 |
-
backbone_name,
|
| 57 |
-
pretrained=pretrained,
|
| 58 |
-
num_classes=0,
|
| 59 |
-
global_pool="avg",
|
| 60 |
-
)
|
| 61 |
-
self.output_dim = self.backbone.num_features
|
| 62 |
-
|
| 63 |
-
def forward(self, image: torch.Tensor) -> torch.Tensor:
|
| 64 |
-
if self.kind == "hf_clip_projected":
|
| 65 |
-
return self.backbone(pixel_values=image).image_embeds
|
| 66 |
-
if self.kind == "hf_vision":
|
| 67 |
-
out = self.backbone(pixel_values=image)
|
| 68 |
-
if hasattr(out, "pooler_output") and out.pooler_output is not None:
|
| 69 |
-
return out.pooler_output
|
| 70 |
-
return out.last_hidden_state[:, 0]
|
| 71 |
-
return self.backbone(image)
|
| 72 |
-
|
| 73 |
-
def forward_with_tokens(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 74 |
-
if self.kind == "hf_clip_projected":
|
| 75 |
-
out = self.backbone(pixel_values=image)
|
| 76 |
-
tokens = getattr(out, "last_hidden_state", None)
|
| 77 |
-
if tokens is None and hasattr(out, "vision_model_output"):
|
| 78 |
-
tokens = out.vision_model_output.last_hidden_state
|
| 79 |
-
if tokens is None:
|
| 80 |
-
raise RuntimeError("Projected CLIP vision output did not include patch tokens")
|
| 81 |
-
return out.image_embeds, tokens
|
| 82 |
-
if self.kind == "hf_vision":
|
| 83 |
-
out = self.backbone(pixel_values=image)
|
| 84 |
-
if hasattr(out, "pooler_output") and out.pooler_output is not None:
|
| 85 |
-
pooled = out.pooler_output
|
| 86 |
-
else:
|
| 87 |
-
pooled = out.last_hidden_state[:, 0]
|
| 88 |
-
return pooled, out.last_hidden_state
|
| 89 |
-
|
| 90 |
-
if not hasattr(self.backbone, "forward_features"):
|
| 91 |
-
pooled = self.backbone(image)
|
| 92 |
-
return pooled, pooled[:, None, :]
|
| 93 |
-
features = self.backbone.forward_features(image)
|
| 94 |
-
if hasattr(self.backbone, "forward_head"):
|
| 95 |
-
pooled = self.backbone.forward_head(features, pre_logits=False)
|
| 96 |
-
else:
|
| 97 |
-
pooled = self.backbone(image)
|
| 98 |
-
return pooled, _tokens_from_features(features)
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
class TextEncoder(nn.Module):
|
| 102 |
-
def __init__(self, model_name: str, pretrained: bool = True, pooling: str = "auto") -> None:
|
| 103 |
-
super().__init__()
|
| 104 |
-
if pooling not in {"auto", "pooler", "cls", "mean"}:
|
| 105 |
-
raise ValueError(f"Unsupported text pooling {pooling!r}; expected auto, pooler, cls, or mean")
|
| 106 |
-
self.kind = "hf_text"
|
| 107 |
-
self.pooling = pooling
|
| 108 |
-
tokenizer_name = model_name.removeprefix("hf_clip_projected:")
|
| 109 |
-
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
| 110 |
-
model_name_lower = model_name.lower()
|
| 111 |
-
if model_name.startswith("hf_clip_projected:"):
|
| 112 |
-
self.kind = "hf_clip_projected"
|
| 113 |
-
projected_model_name = model_name.removeprefix("hf_clip_projected:")
|
| 114 |
-
if pretrained:
|
| 115 |
-
self.backbone = CLIPTextModelWithProjection.from_pretrained(projected_model_name)
|
| 116 |
-
else:
|
| 117 |
-
self.backbone = CLIPTextModelWithProjection(CLIPTextConfig.from_pretrained(projected_model_name))
|
| 118 |
-
self.output_dim = self.backbone.config.projection_dim
|
| 119 |
-
elif "siglip" in model_name_lower:
|
| 120 |
-
if pretrained:
|
| 121 |
-
self.backbone = SiglipTextModel.from_pretrained(model_name)
|
| 122 |
-
else:
|
| 123 |
-
self.backbone = SiglipTextModel(SiglipTextConfig.from_pretrained(model_name))
|
| 124 |
-
self.output_dim = self.backbone.config.hidden_size
|
| 125 |
-
elif "clip" in model_name_lower:
|
| 126 |
-
if pretrained:
|
| 127 |
-
self.backbone = CLIPTextModel.from_pretrained(model_name)
|
| 128 |
-
else:
|
| 129 |
-
self.backbone = CLIPTextModel(CLIPTextConfig.from_pretrained(model_name))
|
| 130 |
-
self.output_dim = self.backbone.config.hidden_size
|
| 131 |
-
else:
|
| 132 |
-
if pretrained:
|
| 133 |
-
self.backbone = AutoModel.from_pretrained(model_name)
|
| 134 |
-
else:
|
| 135 |
-
self.backbone = AutoModel.from_config(AutoConfig.from_pretrained(model_name))
|
| 136 |
-
hidden_size = getattr(self.backbone.config, "hidden_size", None)
|
| 137 |
-
if hidden_size is None:
|
| 138 |
-
raise ValueError(f"Unsupported text model config for {model_name}")
|
| 139 |
-
self.output_dim = hidden_size
|
| 140 |
-
|
| 141 |
-
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
| 142 |
-
out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
|
| 143 |
-
if self.kind == "hf_clip_projected":
|
| 144 |
-
return out.text_embeds
|
| 145 |
-
if self.pooling == "mean":
|
| 146 |
-
mask = attention_mask.to(dtype=out.last_hidden_state.dtype).unsqueeze(-1)
|
| 147 |
-
summed = (out.last_hidden_state * mask).sum(dim=1)
|
| 148 |
-
denom = mask.sum(dim=1).clamp_min(1.0)
|
| 149 |
-
return summed / denom
|
| 150 |
-
if self.pooling in {"auto", "pooler"} and hasattr(out, "pooler_output") and out.pooler_output is not None:
|
| 151 |
-
return out.pooler_output
|
| 152 |
-
return out.last_hidden_state[:, 0]
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
def _tokens_from_features(features: torch.Tensor | dict | tuple | list) -> torch.Tensor:
|
| 156 |
-
if isinstance(features, dict):
|
| 157 |
-
for key in ("x", "last_hidden_state", "features"):
|
| 158 |
-
if key in features:
|
| 159 |
-
features = features[key]
|
| 160 |
-
break
|
| 161 |
-
else:
|
| 162 |
-
features = next(iter(features.values()))
|
| 163 |
-
if isinstance(features, tuple | list):
|
| 164 |
-
features = features[0]
|
| 165 |
-
if not torch.is_tensor(features):
|
| 166 |
-
raise TypeError(f"Expected tensor features, got {type(features)!r}")
|
| 167 |
-
if features.ndim == 4:
|
| 168 |
-
return features.flatten(2).transpose(1, 2)
|
| 169 |
-
if features.ndim == 3:
|
| 170 |
-
return features
|
| 171 |
-
if features.ndim == 2:
|
| 172 |
-
return features[:, None, :]
|
| 173 |
-
raise ValueError(f"Unsupported feature tensor shape {tuple(features.shape)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/models/experimental.py
DELETED
|
@@ -1,587 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from collections.abc import Callable
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn.functional as F
|
| 7 |
-
from torch import Tensor, nn
|
| 8 |
-
|
| 9 |
-
from hyper3_clip.models.lorentz import exp_map0, metric_pairwise_dist
|
| 10 |
-
from hyper3_clip.models.losses import beta_cal_loss
|
| 11 |
-
from hyper3_clip.models.tren import TRENRegionEncoder
|
| 12 |
-
from hyper3_clip.training.distributed import gather_variable_with_grad, gather_with_grad, get_rank
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
ProjectionHeadFactory = Callable[[int, int, int | None], nn.Module]
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class ExperimentalObjectiveMixin:
|
| 19 |
-
@staticmethod
|
| 20 |
-
def _validate_experimental_options(
|
| 21 |
-
*,
|
| 22 |
-
proclip_geometry: str,
|
| 23 |
-
proclip_projection_hidden_dim: int | None,
|
| 24 |
-
proclip_component_dim: int | None,
|
| 25 |
-
beta_clip_weight: float,
|
| 26 |
-
beta_clip_global_weight: float,
|
| 27 |
-
beta_clip_beta: float,
|
| 28 |
-
beta_clip_variant: str,
|
| 29 |
-
beta_clip_similarity: str,
|
| 30 |
-
beta_clip_num_heads: int,
|
| 31 |
-
beta_clip_mlp_ratio: float,
|
| 32 |
-
tren_weight: float,
|
| 33 |
-
tren_visual_distill_weight: float,
|
| 34 |
-
tren_text_distill_weight: float,
|
| 35 |
-
tren_region_text_weight: float,
|
| 36 |
-
tren_num_region_tokens: int,
|
| 37 |
-
tren_num_decoder_layers: int,
|
| 38 |
-
tren_num_attention_heads: int,
|
| 39 |
-
tren_prompt_grid_size: int,
|
| 40 |
-
tren_dropout: float,
|
| 41 |
-
) -> None:
|
| 42 |
-
if proclip_geometry not in {"product", "hyperbolic", "euclidean", "spherical", "clip"}:
|
| 43 |
-
raise ValueError("proclip_geometry must be 'product', 'hyperbolic', 'euclidean', 'spherical', or 'clip'")
|
| 44 |
-
if proclip_projection_hidden_dim is not None and proclip_projection_hidden_dim <= 0:
|
| 45 |
-
raise ValueError("proclip_projection_hidden_dim must be positive when set")
|
| 46 |
-
if proclip_component_dim is not None and proclip_component_dim <= 0:
|
| 47 |
-
raise ValueError("proclip_component_dim must be positive when set")
|
| 48 |
-
if beta_clip_variant not in {"ce", "bce"}:
|
| 49 |
-
raise ValueError("beta_clip_variant must be 'ce' or 'bce'")
|
| 50 |
-
if beta_clip_similarity not in {"metric", "dot"}:
|
| 51 |
-
raise ValueError("beta_clip_similarity must be 'metric' or 'dot'")
|
| 52 |
-
if beta_clip_weight < 0.0:
|
| 53 |
-
raise ValueError("beta_clip_weight must be non-negative")
|
| 54 |
-
if beta_clip_global_weight < 0.0:
|
| 55 |
-
raise ValueError("beta_clip_global_weight must be non-negative")
|
| 56 |
-
if beta_clip_beta < 0.0:
|
| 57 |
-
raise ValueError("beta_clip_beta must be non-negative")
|
| 58 |
-
if beta_clip_num_heads <= 0:
|
| 59 |
-
raise ValueError("beta_clip_num_heads must be positive")
|
| 60 |
-
if beta_clip_mlp_ratio <= 0.0:
|
| 61 |
-
raise ValueError("beta_clip_mlp_ratio must be positive")
|
| 62 |
-
if tren_weight < 0.0:
|
| 63 |
-
raise ValueError("tren_weight must be non-negative")
|
| 64 |
-
if tren_visual_distill_weight < 0.0 or tren_text_distill_weight < 0.0 or tren_region_text_weight < 0.0:
|
| 65 |
-
raise ValueError("T-REN loss weights must be non-negative")
|
| 66 |
-
if tren_num_region_tokens <= 0:
|
| 67 |
-
raise ValueError("tren_num_region_tokens must be positive")
|
| 68 |
-
if tren_num_decoder_layers <= 0:
|
| 69 |
-
raise ValueError("tren_num_decoder_layers must be positive")
|
| 70 |
-
if tren_num_attention_heads <= 0:
|
| 71 |
-
raise ValueError("tren_num_attention_heads must be positive")
|
| 72 |
-
if tren_prompt_grid_size <= 0:
|
| 73 |
-
raise ValueError("tren_prompt_grid_size must be positive")
|
| 74 |
-
if tren_dropout < 0.0:
|
| 75 |
-
raise ValueError("tren_dropout must be non-negative")
|
| 76 |
-
|
| 77 |
-
def _init_experimental_modules(
|
| 78 |
-
self,
|
| 79 |
-
*,
|
| 80 |
-
beta_clip_num_heads: int,
|
| 81 |
-
beta_clip_mlp_ratio: float,
|
| 82 |
-
tren_num_region_tokens: int,
|
| 83 |
-
tren_num_decoder_layers: int,
|
| 84 |
-
tren_num_attention_heads: int,
|
| 85 |
-
tren_prompt_grid_size: int,
|
| 86 |
-
tren_dropout: float,
|
| 87 |
-
projection_hidden_dim: int | None,
|
| 88 |
-
proclip_projection_hidden_dim: int | None,
|
| 89 |
-
projection_head: ProjectionHeadFactory,
|
| 90 |
-
) -> None:
|
| 91 |
-
if self.beta_query_pooling_enabled:
|
| 92 |
-
if self.vision_encoder.output_dim % beta_clip_num_heads != 0:
|
| 93 |
-
raise ValueError("vision encoder output_dim must be divisible by beta_clip_num_heads")
|
| 94 |
-
beta_clip_hidden_dim = max(1, int(round(self.vision_encoder.output_dim * beta_clip_mlp_ratio)))
|
| 95 |
-
self.beta_clip_text_query_proj = nn.Linear(self.text_encoder.output_dim, self.vision_encoder.output_dim)
|
| 96 |
-
self.beta_clip_cross_attention = nn.MultiheadAttention(
|
| 97 |
-
self.vision_encoder.output_dim,
|
| 98 |
-
beta_clip_num_heads,
|
| 99 |
-
batch_first=True,
|
| 100 |
-
)
|
| 101 |
-
self.beta_clip_mlp_norm = nn.LayerNorm(self.vision_encoder.output_dim)
|
| 102 |
-
self.beta_clip_pool_mlp = nn.Sequential(
|
| 103 |
-
nn.Linear(self.vision_encoder.output_dim, beta_clip_hidden_dim),
|
| 104 |
-
nn.GELU(),
|
| 105 |
-
nn.Linear(beta_clip_hidden_dim, self.vision_encoder.output_dim),
|
| 106 |
-
)
|
| 107 |
-
if self.beta_clip_enabled:
|
| 108 |
-
self.beta_clip_logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log())
|
| 109 |
-
if self.tren_enabled:
|
| 110 |
-
self.tren_region_encoder = TRENRegionEncoder(
|
| 111 |
-
vision_dim=self.vision_encoder.output_dim,
|
| 112 |
-
text_dim=self.text_encoder.output_dim,
|
| 113 |
-
num_region_tokens=tren_num_region_tokens,
|
| 114 |
-
num_decoder_layers=tren_num_decoder_layers,
|
| 115 |
-
num_attention_heads=tren_num_attention_heads,
|
| 116 |
-
prompt_grid_size=tren_prompt_grid_size,
|
| 117 |
-
dropout=tren_dropout,
|
| 118 |
-
)
|
| 119 |
-
self.tren_logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log())
|
| 120 |
-
if self.proclip_enabled:
|
| 121 |
-
component_dim = self._proclip_component_dim
|
| 122 |
-
spherical_dim = self._proclip_spherical_ambient_dim
|
| 123 |
-
proclip_hidden_dim = proclip_projection_hidden_dim
|
| 124 |
-
if proclip_hidden_dim is None:
|
| 125 |
-
proclip_hidden_dim = projection_hidden_dim
|
| 126 |
-
if self.proclip_dedicated_hyperbolic:
|
| 127 |
-
self.proclip_image_hyperbolic_proj = projection_head(
|
| 128 |
-
self.vision_encoder.output_dim, self.embed_dim, proclip_hidden_dim
|
| 129 |
-
)
|
| 130 |
-
self.proclip_text_hyperbolic_proj = projection_head(
|
| 131 |
-
self.text_encoder.output_dim, self.embed_dim, proclip_hidden_dim
|
| 132 |
-
)
|
| 133 |
-
self.proclip_image_euclidean_proj = projection_head(
|
| 134 |
-
self.vision_encoder.output_dim, component_dim, proclip_hidden_dim
|
| 135 |
-
)
|
| 136 |
-
self.proclip_text_euclidean_proj = projection_head(
|
| 137 |
-
self.text_encoder.output_dim, component_dim, proclip_hidden_dim
|
| 138 |
-
)
|
| 139 |
-
self.proclip_image_spherical_proj = projection_head(
|
| 140 |
-
self.vision_encoder.output_dim, spherical_dim, proclip_hidden_dim
|
| 141 |
-
)
|
| 142 |
-
self.proclip_text_spherical_proj = projection_head(
|
| 143 |
-
self.text_encoder.output_dim, spherical_dim, proclip_hidden_dim
|
| 144 |
-
)
|
| 145 |
-
self.proclip_logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log())
|
| 146 |
-
self.proclip_log_weights = nn.Parameter(torch.zeros(3))
|
| 147 |
-
|
| 148 |
-
@property
|
| 149 |
-
def proclip_enabled(self) -> bool:
|
| 150 |
-
return (
|
| 151 |
-
self.objective_name == "proclip"
|
| 152 |
-
or self.proclip_component_dim is not None
|
| 153 |
-
or self.proclip_weight > 0.0
|
| 154 |
-
or self.proclip_retrieval
|
| 155 |
-
)
|
| 156 |
-
|
| 157 |
-
@property
|
| 158 |
-
def beta_clip_enabled(self) -> bool:
|
| 159 |
-
return self.beta_clip_weight > 0.0
|
| 160 |
-
|
| 161 |
-
@property
|
| 162 |
-
def beta_query_pooling_enabled(self) -> bool:
|
| 163 |
-
return self.beta_clip_enabled or (
|
| 164 |
-
self.objective_name == "uncha"
|
| 165 |
-
and self.uncha_entailment_loss in {"hier_beta_argent", "hier_beta_sourcepart_argent"}
|
| 166 |
-
)
|
| 167 |
-
|
| 168 |
-
@property
|
| 169 |
-
def tren_enabled(self) -> bool:
|
| 170 |
-
return self.tren_weight > 0.0
|
| 171 |
-
|
| 172 |
-
@property
|
| 173 |
-
def _proclip_component_dim(self) -> int:
|
| 174 |
-
return int(self.proclip_component_dim or self.embed_dim)
|
| 175 |
-
|
| 176 |
-
@property
|
| 177 |
-
def _proclip_spherical_ambient_dim(self) -> int:
|
| 178 |
-
return self._proclip_component_dim + 1
|
| 179 |
-
|
| 180 |
-
def _clamp_experimental_logit_scales(self) -> None:
|
| 181 |
-
if self.proclip_enabled:
|
| 182 |
-
self.proclip_logit_scale.clamp_(max=4.6052)
|
| 183 |
-
if self.beta_clip_enabled:
|
| 184 |
-
self.beta_clip_logit_scale.clamp_(max=4.6052)
|
| 185 |
-
if self.tren_enabled:
|
| 186 |
-
self.tren_logit_scale.clamp_(max=4.6052)
|
| 187 |
-
|
| 188 |
-
def _detached_experimental_logit_scales(self) -> dict[str, torch.Tensor]:
|
| 189 |
-
logs = {}
|
| 190 |
-
if self.proclip_enabled:
|
| 191 |
-
logs.update(self._detached_proclip_logs())
|
| 192 |
-
if self.beta_clip_enabled:
|
| 193 |
-
logs["beta_clip_logit_scale"] = self.beta_clip_logit_scale.exp().detach()
|
| 194 |
-
if self.tren_enabled:
|
| 195 |
-
logs["tren_logit_scale"] = self.tren_logit_scale.exp().detach()
|
| 196 |
-
return logs
|
| 197 |
-
|
| 198 |
-
def _beta_clip_global_contrastive_loss(
|
| 199 |
-
self,
|
| 200 |
-
*,
|
| 201 |
-
image_euc: torch.Tensor,
|
| 202 |
-
text_euc: torch.Tensor,
|
| 203 |
-
targets: torch.Tensor,
|
| 204 |
-
) -> torch.Tensor:
|
| 205 |
-
image_feats = F.normalize(image_euc.float(), dim=-1)
|
| 206 |
-
text_feats = F.normalize(text_euc.float(), dim=-1)
|
| 207 |
-
all_image_feats = gather_with_grad(image_feats)
|
| 208 |
-
all_text_feats = gather_with_grad(text_feats)
|
| 209 |
-
if self.objective_name == "hycoclip":
|
| 210 |
-
scale = self.logit_scale.exp().clamp(max=100.0)
|
| 211 |
-
elif self.objective_name == "proclip":
|
| 212 |
-
scale = self.proclip_logit_scale.exp().clamp(max=100.0)
|
| 213 |
-
else:
|
| 214 |
-
scale = self.global_logit_scale.exp().clamp(max=100.0)
|
| 215 |
-
logits_i_t = image_feats @ all_text_feats.T * scale
|
| 216 |
-
logits_t_i = text_feats @ all_image_feats.T * scale
|
| 217 |
-
return 0.5 * (F.cross_entropy(logits_i_t, targets) + F.cross_entropy(logits_t_i, targets))
|
| 218 |
-
|
| 219 |
-
def _beta_query_entailment_embeddings(
|
| 220 |
-
self,
|
| 221 |
-
*,
|
| 222 |
-
image_tokens: torch.Tensor,
|
| 223 |
-
beta_query_input_ids: torch.Tensor | None,
|
| 224 |
-
beta_query_attention_mask: torch.Tensor | None,
|
| 225 |
-
beta_query_owner: torch.Tensor | None,
|
| 226 |
-
beta_query_parent: torch.Tensor | None,
|
| 227 |
-
beta_query_weight: torch.Tensor | None,
|
| 228 |
-
beta_query_source_part: torch.Tensor | None,
|
| 229 |
-
kappa: torch.Tensor,
|
| 230 |
-
query_base: torch.Tensor | None = None,
|
| 231 |
-
) -> dict[str, torch.Tensor]:
|
| 232 |
-
if beta_query_input_ids is None or beta_query_attention_mask is None or beta_query_owner is None:
|
| 233 |
-
raise ValueError(f"{self.uncha_entailment_loss} requires beta query tensors from the collator")
|
| 234 |
-
if beta_query_parent is None or beta_query_weight is None:
|
| 235 |
-
raise ValueError(f"{self.uncha_entailment_loss} requires beta query hierarchy metadata from the collator")
|
| 236 |
-
if self.uncha_entailment_loss == "hier_beta_sourcepart_argent" and beta_query_source_part is None:
|
| 237 |
-
raise ValueError("hier_beta_sourcepart_argent requires beta_query_source_part from the collator")
|
| 238 |
-
if beta_query_input_ids.shape[0] == 0:
|
| 239 |
-
source_part = (
|
| 240 |
-
beta_query_source_part.to(device=image_tokens.device, dtype=torch.long)
|
| 241 |
-
if beta_query_source_part is not None
|
| 242 |
-
else beta_query_owner.new_zeros((0,), device=image_tokens.device, dtype=torch.long)
|
| 243 |
-
)
|
| 244 |
-
return {
|
| 245 |
-
"beta_query_image_feats": image_tokens.new_zeros((0, self.embed_dim)),
|
| 246 |
-
"beta_query_text_feats": image_tokens.new_zeros((0, self.embed_dim)),
|
| 247 |
-
"beta_query_owner": beta_query_owner.to(device=image_tokens.device, dtype=torch.long),
|
| 248 |
-
"beta_query_parent": beta_query_parent.to(device=image_tokens.device, dtype=torch.long),
|
| 249 |
-
"beta_query_weight": beta_query_weight.to(device=image_tokens.device, dtype=torch.float32),
|
| 250 |
-
"beta_query_source_part": source_part,
|
| 251 |
-
}
|
| 252 |
-
|
| 253 |
-
query_owner = beta_query_owner.to(device=image_tokens.device, dtype=torch.long)
|
| 254 |
-
if query_base is None:
|
| 255 |
-
query_base = self.encode_text_base(beta_query_input_ids, beta_query_attention_mask)
|
| 256 |
-
conditioned_image_base = self._beta_clip_text_conditioned_pool(image_tokens, query_base, query_owner)
|
| 257 |
-
query_image_euc = self.image_proj(conditioned_image_base)
|
| 258 |
-
query_text_euc = self.text_proj(query_base)
|
| 259 |
-
return {
|
| 260 |
-
"beta_query_image_feats": self.project_image_features(query_image_euc),
|
| 261 |
-
"beta_query_text_feats": self.project_text_features(query_text_euc),
|
| 262 |
-
"beta_query_owner": query_owner,
|
| 263 |
-
"beta_query_parent": beta_query_parent.to(device=image_tokens.device, dtype=torch.long),
|
| 264 |
-
"beta_query_weight": beta_query_weight.to(device=image_tokens.device, dtype=torch.float32),
|
| 265 |
-
**(
|
| 266 |
-
{"beta_query_source_part": beta_query_source_part.to(device=image_tokens.device, dtype=torch.long)}
|
| 267 |
-
if beta_query_source_part is not None
|
| 268 |
-
else {}
|
| 269 |
-
),
|
| 270 |
-
}
|
| 271 |
-
|
| 272 |
-
def _beta_clip_auxiliary_loss(
|
| 273 |
-
self,
|
| 274 |
-
*,
|
| 275 |
-
image_tokens: torch.Tensor,
|
| 276 |
-
beta_query_input_ids: torch.Tensor | None,
|
| 277 |
-
beta_query_attention_mask: torch.Tensor | None,
|
| 278 |
-
beta_query_owner: torch.Tensor | None,
|
| 279 |
-
global_targets: torch.Tensor,
|
| 280 |
-
kappa: torch.Tensor,
|
| 281 |
-
) -> torch.Tensor:
|
| 282 |
-
if beta_query_input_ids is None or beta_query_attention_mask is None or beta_query_owner is None:
|
| 283 |
-
raise ValueError("beta-CLIP auxiliary requires beta query tensors from the collator")
|
| 284 |
-
if beta_query_input_ids.shape[0] == 0:
|
| 285 |
-
return image_tokens.new_zeros(())
|
| 286 |
-
|
| 287 |
-
beta_query_owner = beta_query_owner.to(device=image_tokens.device, dtype=torch.long)
|
| 288 |
-
query_base = self.encode_text_base(beta_query_input_ids, beta_query_attention_mask)
|
| 289 |
-
conditioned_image_base = self._beta_clip_text_conditioned_pool(image_tokens, query_base, beta_query_owner)
|
| 290 |
-
query_image_euc = self.image_proj(conditioned_image_base)
|
| 291 |
-
query_text_euc = self.text_proj(query_base)
|
| 292 |
-
|
| 293 |
-
if self.beta_clip_similarity == "dot":
|
| 294 |
-
query_image_feats = F.normalize(query_image_euc.float(), dim=-1)
|
| 295 |
-
query_text_feats = F.normalize(query_text_euc.float(), dim=-1)
|
| 296 |
-
else:
|
| 297 |
-
query_image_feats = self.project_image_features(query_image_euc)
|
| 298 |
-
query_text_feats = self.project_text_features(query_text_euc)
|
| 299 |
-
|
| 300 |
-
all_query_image_feats, query_counts = gather_variable_with_grad(query_image_feats)
|
| 301 |
-
all_query_text_feats, _ = gather_variable_with_grad(query_text_feats)
|
| 302 |
-
query_offset = query_counts[: get_rank()].sum() if query_counts.numel() > 1 else query_counts.new_zeros(())
|
| 303 |
-
query_targets = torch.arange(query_image_feats.size(0), device=query_image_feats.device) + query_offset
|
| 304 |
-
query_group_ids = global_targets.index_select(0, beta_query_owner)
|
| 305 |
-
all_query_group_ids, _ = gather_variable_with_grad(query_group_ids)
|
| 306 |
-
|
| 307 |
-
scale = self.beta_clip_logit_scale.exp().clamp(max=100.0)
|
| 308 |
-
if self.beta_clip_similarity == "dot":
|
| 309 |
-
logits_i_t = query_image_feats @ all_query_text_feats.T * scale
|
| 310 |
-
logits_t_i = query_text_feats @ all_query_image_feats.T * scale
|
| 311 |
-
else:
|
| 312 |
-
logits_i_t = -metric_pairwise_dist(
|
| 313 |
-
query_image_feats,
|
| 314 |
-
all_query_text_feats,
|
| 315 |
-
kappa,
|
| 316 |
-
product_metric=self.phyclip_product_metric,
|
| 317 |
-
) * scale
|
| 318 |
-
logits_t_i = -metric_pairwise_dist(
|
| 319 |
-
query_text_feats,
|
| 320 |
-
all_query_image_feats,
|
| 321 |
-
kappa,
|
| 322 |
-
product_metric=self.phyclip_product_metric,
|
| 323 |
-
) * scale
|
| 324 |
-
return 0.5 * (
|
| 325 |
-
beta_cal_loss(
|
| 326 |
-
logits_i_t,
|
| 327 |
-
targets=query_targets,
|
| 328 |
-
group_ids=query_group_ids,
|
| 329 |
-
all_group_ids=all_query_group_ids,
|
| 330 |
-
beta=self.beta_clip_beta,
|
| 331 |
-
variant=self.beta_clip_variant,
|
| 332 |
-
)
|
| 333 |
-
+ beta_cal_loss(
|
| 334 |
-
logits_t_i,
|
| 335 |
-
targets=query_targets,
|
| 336 |
-
group_ids=query_group_ids,
|
| 337 |
-
all_group_ids=all_query_group_ids,
|
| 338 |
-
beta=self.beta_clip_beta,
|
| 339 |
-
variant=self.beta_clip_variant,
|
| 340 |
-
)
|
| 341 |
-
)
|
| 342 |
-
|
| 343 |
-
def _beta_clip_text_conditioned_pool(
|
| 344 |
-
self,
|
| 345 |
-
image_tokens: torch.Tensor,
|
| 346 |
-
query_base: torch.Tensor,
|
| 347 |
-
query_owner: torch.Tensor,
|
| 348 |
-
) -> torch.Tensor:
|
| 349 |
-
if image_tokens.ndim != 3:
|
| 350 |
-
raise ValueError("beta-CLIP image tokens must have shape [batch, tokens, dim]")
|
| 351 |
-
if getattr(self, "group_beta_query_pooling", False):
|
| 352 |
-
return self._beta_clip_text_conditioned_pool_grouped(image_tokens, query_base, query_owner)
|
| 353 |
-
if self.beta_clip_drop_cls_token and image_tokens.size(1) > 1:
|
| 354 |
-
image_tokens = image_tokens[:, 1:, :]
|
| 355 |
-
selected_tokens = image_tokens.index_select(0, query_owner).to(dtype=query_base.dtype)
|
| 356 |
-
query = self.beta_clip_text_query_proj(query_base).unsqueeze(1)
|
| 357 |
-
attended, _ = self.beta_clip_cross_attention(query, selected_tokens, selected_tokens, need_weights=False)
|
| 358 |
-
pooled = attended.squeeze(1)
|
| 359 |
-
return pooled + self.beta_clip_pool_mlp(self.beta_clip_mlp_norm(pooled))
|
| 360 |
-
|
| 361 |
-
def _beta_clip_text_conditioned_pool_grouped(
|
| 362 |
-
self,
|
| 363 |
-
image_tokens: torch.Tensor,
|
| 364 |
-
query_base: torch.Tensor,
|
| 365 |
-
query_owner: torch.Tensor,
|
| 366 |
-
) -> torch.Tensor:
|
| 367 |
-
if query_owner.numel() == 0:
|
| 368 |
-
return query_base.new_zeros((0, self.vision_encoder.output_dim))
|
| 369 |
-
if query_owner.min().item() < 0 or query_owner.max().item() >= image_tokens.size(0):
|
| 370 |
-
raise IndexError("beta_query_owner contains an out-of-range image index")
|
| 371 |
-
|
| 372 |
-
tokens = image_tokens[:, 1:, :] if self.beta_clip_drop_cls_token and image_tokens.size(1) > 1 else image_tokens
|
| 373 |
-
tokens = tokens.to(dtype=query_base.dtype)
|
| 374 |
-
query_projected = self.beta_clip_text_query_proj(query_base)
|
| 375 |
-
counts = torch.bincount(query_owner, minlength=image_tokens.size(0))
|
| 376 |
-
max_queries = int(counts.max().item())
|
| 377 |
-
|
| 378 |
-
order = torch.argsort(query_owner)
|
| 379 |
-
sorted_owner = query_owner.index_select(0, order)
|
| 380 |
-
owner_offsets = torch.zeros_like(counts)
|
| 381 |
-
owner_offsets[1:] = counts.cumsum(0)[:-1]
|
| 382 |
-
sorted_positions = torch.arange(query_owner.numel(), device=query_owner.device) - owner_offsets.index_select(
|
| 383 |
-
0, sorted_owner
|
| 384 |
-
)
|
| 385 |
-
positions = torch.empty_like(sorted_positions)
|
| 386 |
-
positions[order] = sorted_positions
|
| 387 |
-
|
| 388 |
-
packed_query = query_projected.new_zeros((image_tokens.size(0), max_queries, query_projected.size(-1)))
|
| 389 |
-
packed_query[query_owner, positions] = query_projected
|
| 390 |
-
attended, _ = self.beta_clip_cross_attention(packed_query, tokens, tokens, need_weights=False)
|
| 391 |
-
pooled = attended[query_owner, positions]
|
| 392 |
-
return pooled + self.beta_clip_pool_mlp(self.beta_clip_mlp_norm(pooled))
|
| 393 |
-
|
| 394 |
-
def _tren_auxiliary_losses(
|
| 395 |
-
self,
|
| 396 |
-
*,
|
| 397 |
-
image_tokens: torch.Tensor,
|
| 398 |
-
part_owner: torch.Tensor,
|
| 399 |
-
part_image_base: torch.Tensor,
|
| 400 |
-
part_text_base: torch.Tensor,
|
| 401 |
-
) -> dict[str, torch.Tensor]:
|
| 402 |
-
zero = image_tokens.new_zeros(())
|
| 403 |
-
if part_owner.numel() == 0:
|
| 404 |
-
return {
|
| 405 |
-
"tren_loss": zero,
|
| 406 |
-
"tren_visual_distill_loss": zero,
|
| 407 |
-
"tren_text_distill_loss": zero,
|
| 408 |
-
"tren_region_text_contrastive_loss": zero,
|
| 409 |
-
"tren_assignment_count": part_owner.new_tensor(0),
|
| 410 |
-
}
|
| 411 |
-
|
| 412 |
-
tren_outputs = self.tren_region_encoder(image_tokens)
|
| 413 |
-
visual_tokens = tren_outputs["visual_tokens"].flatten(1, 2)
|
| 414 |
-
text_tokens = tren_outputs["text_aligned_tokens"].flatten(1, 2)
|
| 415 |
-
|
| 416 |
-
matched_visual: list[torch.Tensor] = []
|
| 417 |
-
matched_text: list[torch.Tensor] = []
|
| 418 |
-
target_visual: list[torch.Tensor] = []
|
| 419 |
-
target_text: list[torch.Tensor] = []
|
| 420 |
-
for owner in range(image_tokens.size(0)):
|
| 421 |
-
region_mask = part_owner == owner
|
| 422 |
-
if not bool(region_mask.any()):
|
| 423 |
-
continue
|
| 424 |
-
owner_target_visual = part_image_base[region_mask].detach()
|
| 425 |
-
owner_target_text = part_text_base[region_mask].detach()
|
| 426 |
-
owner_visual_tokens = visual_tokens[owner]
|
| 427 |
-
owner_text_tokens = text_tokens[owner]
|
| 428 |
-
pred_indices, target_indices = _greedy_region_assignment(owner_visual_tokens, owner_target_visual)
|
| 429 |
-
if pred_indices.numel() == 0:
|
| 430 |
-
continue
|
| 431 |
-
matched_visual.append(owner_visual_tokens.index_select(0, pred_indices))
|
| 432 |
-
matched_text.append(owner_text_tokens.index_select(0, pred_indices))
|
| 433 |
-
target_visual.append(owner_target_visual.index_select(0, target_indices))
|
| 434 |
-
target_text.append(owner_target_text.index_select(0, target_indices))
|
| 435 |
-
|
| 436 |
-
if not matched_visual:
|
| 437 |
-
return {
|
| 438 |
-
"tren_loss": zero,
|
| 439 |
-
"tren_visual_distill_loss": zero,
|
| 440 |
-
"tren_text_distill_loss": zero,
|
| 441 |
-
"tren_region_text_contrastive_loss": zero,
|
| 442 |
-
"tren_assignment_count": part_owner.new_tensor(0),
|
| 443 |
-
}
|
| 444 |
-
|
| 445 |
-
matched_visual_tensor = torch.cat(matched_visual, dim=0)
|
| 446 |
-
matched_text_tensor = torch.cat(matched_text, dim=0)
|
| 447 |
-
target_visual_tensor = torch.cat(target_visual, dim=0)
|
| 448 |
-
target_text_tensor = torch.cat(target_text, dim=0)
|
| 449 |
-
visual_distill = 1.0 - F.cosine_similarity(matched_visual_tensor, target_visual_tensor, dim=-1).mean()
|
| 450 |
-
text_distill = 1.0 - F.cosine_similarity(matched_text_tensor, target_text_tensor, dim=-1).mean()
|
| 451 |
-
region_text = _symmetric_dot_contrastive(
|
| 452 |
-
matched_text_tensor,
|
| 453 |
-
target_text_tensor,
|
| 454 |
-
scale=self.tren_logit_scale.exp().clamp(max=100.0),
|
| 455 |
-
)
|
| 456 |
-
total = (
|
| 457 |
-
self.tren_visual_distill_weight * visual_distill
|
| 458 |
-
+ self.tren_text_distill_weight * text_distill
|
| 459 |
-
+ self.tren_region_text_weight * region_text
|
| 460 |
-
)
|
| 461 |
-
return {
|
| 462 |
-
"tren_loss": total,
|
| 463 |
-
"tren_visual_distill_loss": visual_distill,
|
| 464 |
-
"tren_text_distill_loss": text_distill,
|
| 465 |
-
"tren_region_text_contrastive_loss": region_text,
|
| 466 |
-
"tren_assignment_count": part_owner.new_tensor(matched_visual_tensor.size(0)),
|
| 467 |
-
}
|
| 468 |
-
|
| 469 |
-
def _project_proclip_image_base(self, base_feats: torch.Tensor, hyperbolic: torch.Tensor) -> torch.Tensor:
|
| 470 |
-
if self.proclip_geometry == "clip":
|
| 471 |
-
return F.normalize(base_feats.float(), dim=-1)
|
| 472 |
-
if self.proclip_dedicated_hyperbolic:
|
| 473 |
-
hyperbolic = exp_map0(self.proclip_image_hyperbolic_proj(base_feats.float()), self._kappa().float())
|
| 474 |
-
return self._pack_proclip_features(
|
| 475 |
-
hyperbolic=hyperbolic,
|
| 476 |
-
euclidean=self.proclip_image_euclidean_proj(base_feats.float()),
|
| 477 |
-
spherical=self.proclip_image_spherical_proj(base_feats.float()),
|
| 478 |
-
)
|
| 479 |
-
|
| 480 |
-
def _project_proclip_text_base(self, base_feats: torch.Tensor, hyperbolic: torch.Tensor) -> torch.Tensor:
|
| 481 |
-
if self.proclip_geometry == "clip":
|
| 482 |
-
return F.normalize(base_feats.float(), dim=-1)
|
| 483 |
-
if self.proclip_dedicated_hyperbolic:
|
| 484 |
-
hyperbolic = exp_map0(self.proclip_text_hyperbolic_proj(base_feats.float()), self._kappa().float())
|
| 485 |
-
return self._pack_proclip_features(
|
| 486 |
-
hyperbolic=hyperbolic,
|
| 487 |
-
euclidean=self.proclip_text_euclidean_proj(base_feats.float()),
|
| 488 |
-
spherical=self.proclip_text_spherical_proj(base_feats.float()),
|
| 489 |
-
)
|
| 490 |
-
|
| 491 |
-
def _pack_proclip_features(self, hyperbolic: torch.Tensor, euclidean: torch.Tensor, spherical: torch.Tensor) -> torch.Tensor:
|
| 492 |
-
spherical = F.normalize(spherical.float(), dim=-1)
|
| 493 |
-
if self.proclip_geometry == "hyperbolic":
|
| 494 |
-
return hyperbolic.float()
|
| 495 |
-
if self.proclip_geometry == "euclidean":
|
| 496 |
-
return euclidean.float()
|
| 497 |
-
if self.proclip_geometry == "spherical":
|
| 498 |
-
return spherical
|
| 499 |
-
return torch.cat([hyperbolic.float(), euclidean.float(), spherical], dim=-1)
|
| 500 |
-
|
| 501 |
-
def _split_proclip_features(self, feats: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 502 |
-
hyperbolic_dim = self.embed_dim + 1
|
| 503 |
-
component_dim = self._proclip_component_dim
|
| 504 |
-
spherical_dim = self._proclip_spherical_ambient_dim
|
| 505 |
-
hyperbolic = feats[:, :hyperbolic_dim]
|
| 506 |
-
euclidean = feats[:, hyperbolic_dim : hyperbolic_dim + component_dim]
|
| 507 |
-
spherical = feats[:, hyperbolic_dim + component_dim : hyperbolic_dim + component_dim + spherical_dim]
|
| 508 |
-
return hyperbolic, euclidean, spherical
|
| 509 |
-
|
| 510 |
-
def _proclip_similarity_scores(self, image_feats: torch.Tensor, text_feats: torch.Tensor) -> torch.Tensor:
|
| 511 |
-
if self.proclip_geometry == "clip":
|
| 512 |
-
return image_feats.float() @ text_feats.float().T
|
| 513 |
-
if self.proclip_geometry == "hyperbolic":
|
| 514 |
-
return -metric_pairwise_dist(image_feats, text_feats, self._kappa()).square()
|
| 515 |
-
if self.proclip_geometry == "euclidean":
|
| 516 |
-
return -torch.cdist(image_feats.float(), text_feats.float(), p=2).square()
|
| 517 |
-
if self.proclip_geometry == "spherical":
|
| 518 |
-
dot = (image_feats.float() @ text_feats.float().T).clamp(min=-1.0 + 1e-6, max=1.0 - 1e-6)
|
| 519 |
-
return -torch.acos(dot).square()
|
| 520 |
-
image_hyp, image_euc, image_sph = self._split_proclip_features(image_feats)
|
| 521 |
-
text_hyp, text_euc, text_sph = self._split_proclip_features(text_feats)
|
| 522 |
-
weights = self.proclip_log_weights.exp().to(device=image_feats.device, dtype=torch.float32)
|
| 523 |
-
hyperbolic_dist2 = metric_pairwise_dist(image_hyp, text_hyp, self._kappa()).square()
|
| 524 |
-
euclidean_dist2 = torch.cdist(image_euc.float(), text_euc.float(), p=2).square()
|
| 525 |
-
spherical_dot = (image_sph.float() @ text_sph.float().T).clamp(min=-1.0 + 1e-6, max=1.0 - 1e-6)
|
| 526 |
-
spherical_dist2 = torch.acos(spherical_dot).square()
|
| 527 |
-
return -(weights[0] * hyperbolic_dist2 + weights[1] * euclidean_dist2 + weights[2] * spherical_dist2)
|
| 528 |
-
|
| 529 |
-
def _proclip_contrastive_loss(
|
| 530 |
-
self,
|
| 531 |
-
image_feats: torch.Tensor,
|
| 532 |
-
text_feats: torch.Tensor,
|
| 533 |
-
all_image_feats: torch.Tensor,
|
| 534 |
-
all_text_feats: torch.Tensor,
|
| 535 |
-
targets: torch.Tensor,
|
| 536 |
-
) -> torch.Tensor:
|
| 537 |
-
scale = self.proclip_logit_scale.exp().clamp(max=100.0)
|
| 538 |
-
logits_i_t = self._proclip_similarity_scores(image_feats, all_text_feats) * scale
|
| 539 |
-
logits_t_i = self._proclip_similarity_scores(text_feats, all_image_feats) * scale
|
| 540 |
-
return 0.5 * (F.cross_entropy(logits_i_t, targets) + F.cross_entropy(logits_t_i, targets))
|
| 541 |
-
|
| 542 |
-
def _detached_proclip_logs(self) -> dict[str, torch.Tensor]:
|
| 543 |
-
weights = self.proclip_log_weights.exp().detach()
|
| 544 |
-
return {
|
| 545 |
-
"proclip_logit_scale": self.proclip_logit_scale.exp().detach(),
|
| 546 |
-
"proclip_hyperbolic_weight": weights[0],
|
| 547 |
-
"proclip_euclidean_weight": weights[1],
|
| 548 |
-
"proclip_spherical_weight": weights[2],
|
| 549 |
-
}
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
def _greedy_region_assignment(pred_tokens: torch.Tensor, target_tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 553 |
-
if pred_tokens.numel() == 0 or target_tokens.numel() == 0:
|
| 554 |
-
empty = torch.zeros((0,), dtype=torch.long, device=pred_tokens.device)
|
| 555 |
-
return empty, empty
|
| 556 |
-
similarities = F.normalize(pred_tokens.float(), dim=-1) @ F.normalize(target_tokens.float(), dim=-1).T
|
| 557 |
-
pair_scores = similarities.flatten()
|
| 558 |
-
order = torch.argsort(pair_scores, descending=True)
|
| 559 |
-
used_pred = torch.zeros(pred_tokens.size(0), dtype=torch.bool, device=pred_tokens.device)
|
| 560 |
-
used_target = torch.zeros(target_tokens.size(0), dtype=torch.bool, device=pred_tokens.device)
|
| 561 |
-
pred_indices: list[torch.Tensor] = []
|
| 562 |
-
target_indices: list[torch.Tensor] = []
|
| 563 |
-
for flat_index in order:
|
| 564 |
-
pred_index = torch.div(flat_index, target_tokens.size(0), rounding_mode="floor")
|
| 565 |
-
target_index = flat_index % target_tokens.size(0)
|
| 566 |
-
if used_pred[pred_index] or used_target[target_index]:
|
| 567 |
-
continue
|
| 568 |
-
used_pred[pred_index] = True
|
| 569 |
-
used_target[target_index] = True
|
| 570 |
-
pred_indices.append(pred_index)
|
| 571 |
-
target_indices.append(target_index)
|
| 572 |
-
if len(target_indices) == target_tokens.size(0):
|
| 573 |
-
break
|
| 574 |
-
if not pred_indices:
|
| 575 |
-
empty = torch.zeros((0,), dtype=torch.long, device=pred_tokens.device)
|
| 576 |
-
return empty, empty
|
| 577 |
-
return torch.stack(pred_indices), torch.stack(target_indices)
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
def _symmetric_dot_contrastive(region_tokens: torch.Tensor, text_tokens: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
| 581 |
-
if region_tokens.size(0) == 1:
|
| 582 |
-
return region_tokens.new_zeros(())
|
| 583 |
-
region_tokens = F.normalize(region_tokens.float(), dim=-1)
|
| 584 |
-
text_tokens = F.normalize(text_tokens.float(), dim=-1)
|
| 585 |
-
logits = region_tokens @ text_tokens.T * scale
|
| 586 |
-
targets = torch.arange(logits.size(0), device=logits.device)
|
| 587 |
-
return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/models/himo.py
DELETED
|
@@ -1,55 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
from torch import Tensor
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
def hide_reconstruct_embeddings(
|
| 8 |
-
embeddings: Tensor,
|
| 9 |
-
*,
|
| 10 |
-
variance_threshold: float = 0.9,
|
| 11 |
-
detach_pca: bool = True,
|
| 12 |
-
eps: float = 1e-8,
|
| 13 |
-
) -> Tensor:
|
| 14 |
-
"""HiMo-CLIP HiDe: PCA-reconstruct embeddings using top principal components.
|
| 15 |
-
|
| 16 |
-
Given a batch of embeddings ``U ∈ R^{B×D}``, compute mean-centered embeddings,
|
| 17 |
-
perform SVD/PCA, choose the smallest number of components whose cumulative
|
| 18 |
-
explained variance exceeds ``variance_threshold``, and reconstruct each
|
| 19 |
-
embedding from this principal subspace:
|
| 20 |
-
|
| 21 |
-
u'_i = P^T (P (u_i - ū)) + ū
|
| 22 |
-
|
| 23 |
-
where P stacks the selected principal components as rows.
|
| 24 |
-
"""
|
| 25 |
-
if embeddings.ndim != 2:
|
| 26 |
-
raise ValueError("hide_reconstruct_embeddings expects a [batch, dim] tensor")
|
| 27 |
-
if not (0.0 < variance_threshold <= 1.0):
|
| 28 |
-
raise ValueError("variance_threshold must be in (0, 1]")
|
| 29 |
-
if embeddings.size(0) < 2:
|
| 30 |
-
return embeddings
|
| 31 |
-
|
| 32 |
-
u = embeddings.to(dtype=torch.float32)
|
| 33 |
-
mean = u.mean(dim=0, keepdim=True)
|
| 34 |
-
centered = u - mean
|
| 35 |
-
if detach_pca:
|
| 36 |
-
centered_for_pca = centered.detach()
|
| 37 |
-
else:
|
| 38 |
-
centered_for_pca = centered
|
| 39 |
-
|
| 40 |
-
# SVD: centered = U S Vh, principal components are rows of Vh.
|
| 41 |
-
_, s, vh = torch.linalg.svd(centered_for_pca, full_matrices=False)
|
| 42 |
-
if s.numel() == 0 or float((s.square().sum()).item()) <= eps:
|
| 43 |
-
return embeddings
|
| 44 |
-
|
| 45 |
-
explained = s.square()
|
| 46 |
-
cumulative = explained.cumsum(dim=0) / explained.sum().clamp_min(eps)
|
| 47 |
-
m = int((cumulative >= variance_threshold).to(dtype=torch.int64).argmax().item()) + 1
|
| 48 |
-
m = max(1, min(m, vh.size(0)))
|
| 49 |
-
p = vh[:m]
|
| 50 |
-
if detach_pca:
|
| 51 |
-
p = p.detach()
|
| 52 |
-
|
| 53 |
-
recon = (centered @ p.T) @ p + mean
|
| 54 |
-
return recon.to(dtype=embeddings.dtype)
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/models/hyper3_clip.py
DELETED
|
@@ -1,958 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
import torch.nn.functional as F
|
| 5 |
-
from torch import nn
|
| 6 |
-
|
| 7 |
-
from hyper3_clip.models.encoders import TextEncoder, VisionEncoder
|
| 8 |
-
from hyper3_clip.models.experimental import ExperimentalObjectiveMixin
|
| 9 |
-
from hyper3_clip.models.himo import hide_reconstruct_embeddings
|
| 10 |
-
from hyper3_clip.models.lorentz import exp_map0, metric_similarity
|
| 11 |
-
from hyper3_clip.models.objectives import build_objective
|
| 12 |
-
from hyper3_clip.training.distributed import (
|
| 13 |
-
gather_with_grad,
|
| 14 |
-
get_rank,
|
| 15 |
-
get_world_size,
|
| 16 |
-
local_target_indices,
|
| 17 |
-
)
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
class Hyper3CLIP(ExperimentalObjectiveMixin, nn.Module):
|
| 21 |
-
def __init__(
|
| 22 |
-
self,
|
| 23 |
-
vision_backbone: str,
|
| 24 |
-
text_model_name: str,
|
| 25 |
-
embed_dim: int,
|
| 26 |
-
curv_init: float,
|
| 27 |
-
learn_curv: bool,
|
| 28 |
-
entail_weight: float,
|
| 29 |
-
inter_aperture_scale: float,
|
| 30 |
-
intra_aperture_scale: float,
|
| 31 |
-
objective: str = "hycoclip",
|
| 32 |
-
uncha_piecewise_factor: float = 0.1,
|
| 33 |
-
uncha_calibration_alpha: float = 10.0,
|
| 34 |
-
uncha_stop_grad_calibration: bool = True,
|
| 35 |
-
vision_pretrained: bool = True,
|
| 36 |
-
text_pretrained: bool = True,
|
| 37 |
-
text_pooling: str = "auto",
|
| 38 |
-
freeze_vision_encoder: bool = False,
|
| 39 |
-
freeze_text_encoder: bool = False,
|
| 40 |
-
normalize_encoder_features: bool = False,
|
| 41 |
-
projection_hidden_dim: int | None = None,
|
| 42 |
-
uncha_entailment_geometry: str = "lorentz",
|
| 43 |
-
uncha_aggregate_weight: float = 0.0,
|
| 44 |
-
uncha_entailment_loss: str = "piecewise",
|
| 45 |
-
uncha_argent_beta: float = 1.0,
|
| 46 |
-
uncha_argent_norm_weight: float = 0.0,
|
| 47 |
-
uncha_argent_aux_weight: float = 0.5,
|
| 48 |
-
uncha_argent_aggregation: str = "uncha",
|
| 49 |
-
uncha_part_weight_power: float = 0.0,
|
| 50 |
-
uncha_contrastive_loss: str = "ce",
|
| 51 |
-
uncha_sigmoid_bias_init: float = -10.0,
|
| 52 |
-
uncha_sigmoid_negative_weight: float = 1.0,
|
| 53 |
-
uncha_part_quality_mode: str = "none",
|
| 54 |
-
uncha_part_quality_topk: int = 5,
|
| 55 |
-
uncha_part_quality_temperature: float = 4.0,
|
| 56 |
-
uncha_entailment_warmup_steps: int = 0,
|
| 57 |
-
uncha_contrastive_global_weight: float = 1.0,
|
| 58 |
-
uncha_contrastive_local_weight: float = 1.0,
|
| 59 |
-
uncha_contrastive_global_local_weight: float = 1.0,
|
| 60 |
-
uncha_global_local_mode: str = "repeat",
|
| 61 |
-
uncha_global_local_metric: str = "distance",
|
| 62 |
-
uncha_global_local_angle_aux_weight: float = 0.0,
|
| 63 |
-
uncha_global_local_angle_aux_mode: str = "contrastive",
|
| 64 |
-
uncha_global_local_angle_aux_scale: float = 5.5,
|
| 65 |
-
uncha_global_local_angle_aux_aperture_scale: float = 1.0,
|
| 66 |
-
uncha_beta_cal_beta: float = 0.0,
|
| 67 |
-
uncha_beta_cal_variant: str = "ce",
|
| 68 |
-
uncha_beta_cal_weight: float = 0.0,
|
| 69 |
-
uncha_himo_component_weight: float = 0.0,
|
| 70 |
-
uncha_himo_variance_threshold: float = 0.9,
|
| 71 |
-
uncha_himo_detach_pca: bool = True,
|
| 72 |
-
uncha_radius_order_weight: float = 0.0,
|
| 73 |
-
uncha_radius_order_margin: float = 0.0,
|
| 74 |
-
uncha_gramian_align_weight: float = 0.0,
|
| 75 |
-
phyclip_subspace_dim: int | None = None,
|
| 76 |
-
phyclip_product_metric: str = "l1",
|
| 77 |
-
proclip_weight: float = 0.0,
|
| 78 |
-
proclip_component_dim: int | None = None,
|
| 79 |
-
proclip_retrieval: bool = False,
|
| 80 |
-
proclip_geometry: str = "product",
|
| 81 |
-
proclip_dedicated_hyperbolic: bool = False,
|
| 82 |
-
proclip_projection_hidden_dim: int | None = None,
|
| 83 |
-
beta_clip_weight: float = 0.0,
|
| 84 |
-
beta_clip_global_weight: float = 0.0,
|
| 85 |
-
beta_clip_beta: float = 0.5,
|
| 86 |
-
beta_clip_variant: str = "ce",
|
| 87 |
-
beta_clip_similarity: str = "metric",
|
| 88 |
-
beta_clip_num_heads: int = 8,
|
| 89 |
-
beta_clip_mlp_ratio: float = 4.0,
|
| 90 |
-
beta_clip_drop_cls_token: bool = True,
|
| 91 |
-
tren_weight: float = 0.0,
|
| 92 |
-
tren_visual_distill_weight: float = 1.0,
|
| 93 |
-
tren_text_distill_weight: float = 1.0,
|
| 94 |
-
tren_region_text_weight: float = 1.0,
|
| 95 |
-
tren_num_region_tokens: int = 3,
|
| 96 |
-
tren_num_decoder_layers: int = 2,
|
| 97 |
-
tren_num_attention_heads: int = 8,
|
| 98 |
-
tren_prompt_grid_size: int = 7,
|
| 99 |
-
tren_dropout: float = 0.1,
|
| 100 |
-
fuse_whole_part_encoder_forwards: bool = False,
|
| 101 |
-
fuse_beta_query_encoder_forwards: bool = False,
|
| 102 |
-
group_beta_query_pooling: bool = False,
|
| 103 |
-
objective_autocast_dtype: str = "float32",
|
| 104 |
-
) -> None:
|
| 105 |
-
super().__init__()
|
| 106 |
-
if objective not in {"hycoclip", "uncha", "proclip"}:
|
| 107 |
-
raise ValueError(f"Unsupported objective {objective!r}; expected 'hycoclip', 'uncha', or 'proclip'")
|
| 108 |
-
if phyclip_product_metric not in {"l1", "l2"}:
|
| 109 |
-
raise ValueError("phyclip_product_metric must be 'l1' or 'l2'")
|
| 110 |
-
self._validate_experimental_options(
|
| 111 |
-
proclip_geometry=proclip_geometry,
|
| 112 |
-
proclip_projection_hidden_dim=proclip_projection_hidden_dim,
|
| 113 |
-
proclip_component_dim=proclip_component_dim,
|
| 114 |
-
beta_clip_weight=beta_clip_weight,
|
| 115 |
-
beta_clip_global_weight=beta_clip_global_weight,
|
| 116 |
-
beta_clip_beta=beta_clip_beta,
|
| 117 |
-
beta_clip_variant=beta_clip_variant,
|
| 118 |
-
beta_clip_similarity=beta_clip_similarity,
|
| 119 |
-
beta_clip_num_heads=beta_clip_num_heads,
|
| 120 |
-
beta_clip_mlp_ratio=beta_clip_mlp_ratio,
|
| 121 |
-
tren_weight=tren_weight,
|
| 122 |
-
tren_visual_distill_weight=tren_visual_distill_weight,
|
| 123 |
-
tren_text_distill_weight=tren_text_distill_weight,
|
| 124 |
-
tren_region_text_weight=tren_region_text_weight,
|
| 125 |
-
tren_num_region_tokens=tren_num_region_tokens,
|
| 126 |
-
tren_num_decoder_layers=tren_num_decoder_layers,
|
| 127 |
-
tren_num_attention_heads=tren_num_attention_heads,
|
| 128 |
-
tren_prompt_grid_size=tren_prompt_grid_size,
|
| 129 |
-
tren_dropout=tren_dropout,
|
| 130 |
-
)
|
| 131 |
-
if objective_autocast_dtype not in {"float32", "fp32", "float16", "fp16", "bfloat16", "bf16"}:
|
| 132 |
-
raise ValueError("objective_autocast_dtype must be one of 'float32', 'float16', or 'bfloat16'")
|
| 133 |
-
if uncha_contrastive_loss not in {"ce", "sigmoid", "siglip", "siglip_metric"}:
|
| 134 |
-
raise ValueError("uncha_contrastive_loss must be 'ce', 'sigmoid', 'siglip', or 'siglip_metric'")
|
| 135 |
-
if uncha_global_local_metric not in {"distance", "angle"}:
|
| 136 |
-
raise ValueError("uncha_global_local_metric must be 'distance' or 'angle'")
|
| 137 |
-
if uncha_global_local_angle_aux_mode not in {"contrastive", "positive_hinge"}:
|
| 138 |
-
raise ValueError("uncha_global_local_angle_aux_mode must be 'contrastive' or 'positive_hinge'")
|
| 139 |
-
if uncha_global_local_angle_aux_weight < 0.0:
|
| 140 |
-
raise ValueError("uncha_global_local_angle_aux_weight must be non-negative")
|
| 141 |
-
if uncha_global_local_angle_aux_scale <= 0.0:
|
| 142 |
-
raise ValueError("uncha_global_local_angle_aux_scale must be positive")
|
| 143 |
-
if uncha_global_local_angle_aux_aperture_scale <= 0.0:
|
| 144 |
-
raise ValueError("uncha_global_local_angle_aux_aperture_scale must be positive")
|
| 145 |
-
if uncha_entailment_warmup_steps < 0:
|
| 146 |
-
raise ValueError("uncha_entailment_warmup_steps must be non-negative")
|
| 147 |
-
self.objective_name = objective
|
| 148 |
-
self.uncha_contrastive_loss = uncha_contrastive_loss
|
| 149 |
-
self.uncha_entailment_loss = uncha_entailment_loss
|
| 150 |
-
self.uncha_entailment_warmup_steps = uncha_entailment_warmup_steps
|
| 151 |
-
self.uncha_himo_component_weight = float(uncha_himo_component_weight)
|
| 152 |
-
self.uncha_himo_variance_threshold = float(uncha_himo_variance_threshold)
|
| 153 |
-
self.uncha_himo_detach_pca = bool(uncha_himo_detach_pca)
|
| 154 |
-
self.proclip_weight = float(proclip_weight)
|
| 155 |
-
self.proclip_retrieval = bool(proclip_retrieval)
|
| 156 |
-
self.proclip_geometry = proclip_geometry
|
| 157 |
-
self.proclip_dedicated_hyperbolic = bool(proclip_dedicated_hyperbolic)
|
| 158 |
-
self.beta_clip_weight = float(beta_clip_weight)
|
| 159 |
-
self.beta_clip_global_weight = float(beta_clip_global_weight)
|
| 160 |
-
self.beta_clip_beta = float(beta_clip_beta)
|
| 161 |
-
self.beta_clip_variant = beta_clip_variant
|
| 162 |
-
self.beta_clip_similarity = beta_clip_similarity
|
| 163 |
-
self.beta_clip_drop_cls_token = bool(beta_clip_drop_cls_token)
|
| 164 |
-
self.tren_weight = float(tren_weight)
|
| 165 |
-
self.tren_visual_distill_weight = float(tren_visual_distill_weight)
|
| 166 |
-
self.tren_text_distill_weight = float(tren_text_distill_weight)
|
| 167 |
-
self.tren_region_text_weight = float(tren_region_text_weight)
|
| 168 |
-
self.fuse_whole_part_encoder_forwards = bool(fuse_whole_part_encoder_forwards)
|
| 169 |
-
self.fuse_beta_query_encoder_forwards = bool(fuse_beta_query_encoder_forwards)
|
| 170 |
-
self.group_beta_query_pooling = bool(group_beta_query_pooling)
|
| 171 |
-
self.objective_autocast_dtype = objective_autocast_dtype
|
| 172 |
-
self.freeze_vision_encoder = bool(freeze_vision_encoder)
|
| 173 |
-
self.freeze_text_encoder = bool(freeze_text_encoder)
|
| 174 |
-
self.normalize_encoder_features = bool(normalize_encoder_features)
|
| 175 |
-
self.phyclip_subspace_dim = phyclip_subspace_dim
|
| 176 |
-
self.phyclip_product_metric = phyclip_product_metric
|
| 177 |
-
self.proclip_component_dim = proclip_component_dim
|
| 178 |
-
if projection_hidden_dim is not None and projection_hidden_dim <= 0:
|
| 179 |
-
raise ValueError("projection_hidden_dim must be positive when set")
|
| 180 |
-
if self.proclip_enabled and phyclip_subspace_dim is not None:
|
| 181 |
-
raise ValueError("ProCLIP mixed-curvature proxy cannot be combined with PHyCLIP Lorentz factors")
|
| 182 |
-
if phyclip_subspace_dim is not None:
|
| 183 |
-
if phyclip_subspace_dim <= 0:
|
| 184 |
-
raise ValueError("phyclip_subspace_dim must be positive when set")
|
| 185 |
-
if embed_dim % phyclip_subspace_dim != 0:
|
| 186 |
-
raise ValueError("embed_dim must be divisible by phyclip_subspace_dim")
|
| 187 |
-
self.phyclip_num_factors = embed_dim // phyclip_subspace_dim
|
| 188 |
-
else:
|
| 189 |
-
self.phyclip_num_factors = 0
|
| 190 |
-
self.vision_encoder = VisionEncoder(vision_backbone, pretrained=vision_pretrained)
|
| 191 |
-
self.text_encoder = TextEncoder(text_model_name, pretrained=text_pretrained, pooling=text_pooling)
|
| 192 |
-
self.tokenizer = self.text_encoder.tokenizer
|
| 193 |
-
self.embed_dim = embed_dim
|
| 194 |
-
if self.freeze_vision_encoder:
|
| 195 |
-
self.vision_encoder.requires_grad_(False)
|
| 196 |
-
self.vision_encoder.eval()
|
| 197 |
-
if self.freeze_text_encoder:
|
| 198 |
-
self.text_encoder.requires_grad_(False)
|
| 199 |
-
self.text_encoder.eval()
|
| 200 |
-
|
| 201 |
-
self.image_proj = _projection_head(self.vision_encoder.output_dim, embed_dim, projection_hidden_dim)
|
| 202 |
-
self.text_proj = _projection_head(self.text_encoder.output_dim, embed_dim, projection_hidden_dim)
|
| 203 |
-
self._init_experimental_modules(
|
| 204 |
-
beta_clip_num_heads=beta_clip_num_heads,
|
| 205 |
-
beta_clip_mlp_ratio=beta_clip_mlp_ratio,
|
| 206 |
-
tren_num_region_tokens=tren_num_region_tokens,
|
| 207 |
-
tren_num_decoder_layers=tren_num_decoder_layers,
|
| 208 |
-
tren_num_attention_heads=tren_num_attention_heads,
|
| 209 |
-
tren_prompt_grid_size=tren_prompt_grid_size,
|
| 210 |
-
tren_dropout=tren_dropout,
|
| 211 |
-
projection_hidden_dim=projection_hidden_dim,
|
| 212 |
-
proclip_projection_hidden_dim=proclip_projection_hidden_dim,
|
| 213 |
-
projection_head=_projection_head,
|
| 214 |
-
)
|
| 215 |
-
|
| 216 |
-
if objective == "hycoclip":
|
| 217 |
-
self.logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log())
|
| 218 |
-
elif objective == "uncha":
|
| 219 |
-
self.global_logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log())
|
| 220 |
-
self.local_logit_scale = nn.Parameter(torch.tensor(1 / 0.05).log())
|
| 221 |
-
self.global_local_logit_scale = nn.Parameter(torch.tensor(1 / 0.06).log())
|
| 222 |
-
if uncha_contrastive_loss in {"sigmoid", "siglip", "siglip_metric"}:
|
| 223 |
-
self.global_logit_bias = nn.Parameter(torch.tensor(float(uncha_sigmoid_bias_init)))
|
| 224 |
-
self.local_logit_bias = nn.Parameter(torch.tensor(float(uncha_sigmoid_bias_init)))
|
| 225 |
-
self.global_local_logit_bias = nn.Parameter(torch.tensor(float(uncha_sigmoid_bias_init)))
|
| 226 |
-
alpha_dim = phyclip_subspace_dim or embed_dim
|
| 227 |
-
alpha_shape = (self.phyclip_num_factors,) if self.phyclip_enabled else ()
|
| 228 |
-
self.visual_alpha = nn.Parameter(torch.full(alpha_shape, alpha_dim**-0.5).log())
|
| 229 |
-
self.textual_alpha = nn.Parameter(torch.full(alpha_shape, alpha_dim**-0.5).log())
|
| 230 |
-
|
| 231 |
-
curv_shape = (self.phyclip_num_factors,) if self.phyclip_enabled else ()
|
| 232 |
-
log_curv = torch.full(curv_shape, curv_init).log()
|
| 233 |
-
self.log_curv = nn.Parameter(log_curv, requires_grad=learn_curv)
|
| 234 |
-
self.curv_min = curv_init / 10.0
|
| 235 |
-
self.curv_max = curv_init * 10.0
|
| 236 |
-
self.objective = None
|
| 237 |
-
if objective != "proclip":
|
| 238 |
-
self.objective = build_objective(
|
| 239 |
-
objective=objective,
|
| 240 |
-
entail_weight=entail_weight,
|
| 241 |
-
inter_aperture_scale=inter_aperture_scale,
|
| 242 |
-
intra_aperture_scale=intra_aperture_scale,
|
| 243 |
-
uncha_piecewise_factor=uncha_piecewise_factor,
|
| 244 |
-
uncha_calibration_alpha=uncha_calibration_alpha,
|
| 245 |
-
uncha_stop_grad_calibration=uncha_stop_grad_calibration,
|
| 246 |
-
uncha_entailment_geometry=uncha_entailment_geometry,
|
| 247 |
-
uncha_aggregate_weight=uncha_aggregate_weight,
|
| 248 |
-
uncha_entailment_loss=uncha_entailment_loss,
|
| 249 |
-
uncha_argent_beta=uncha_argent_beta,
|
| 250 |
-
uncha_argent_norm_weight=uncha_argent_norm_weight,
|
| 251 |
-
uncha_argent_aux_weight=uncha_argent_aux_weight,
|
| 252 |
-
uncha_argent_aggregation=uncha_argent_aggregation,
|
| 253 |
-
uncha_part_weight_power=uncha_part_weight_power,
|
| 254 |
-
uncha_contrastive_loss=uncha_contrastive_loss,
|
| 255 |
-
uncha_sigmoid_negative_weight=uncha_sigmoid_negative_weight,
|
| 256 |
-
uncha_part_quality_mode=uncha_part_quality_mode,
|
| 257 |
-
uncha_part_quality_topk=uncha_part_quality_topk,
|
| 258 |
-
uncha_part_quality_temperature=uncha_part_quality_temperature,
|
| 259 |
-
uncha_contrastive_global_weight=uncha_contrastive_global_weight,
|
| 260 |
-
uncha_contrastive_local_weight=uncha_contrastive_local_weight,
|
| 261 |
-
uncha_contrastive_global_local_weight=uncha_contrastive_global_local_weight,
|
| 262 |
-
uncha_global_local_mode=uncha_global_local_mode,
|
| 263 |
-
uncha_global_local_metric=uncha_global_local_metric,
|
| 264 |
-
uncha_global_local_angle_aux_weight=uncha_global_local_angle_aux_weight,
|
| 265 |
-
uncha_global_local_angle_aux_mode=uncha_global_local_angle_aux_mode,
|
| 266 |
-
uncha_global_local_angle_aux_scale=uncha_global_local_angle_aux_scale,
|
| 267 |
-
uncha_global_local_angle_aux_aperture_scale=uncha_global_local_angle_aux_aperture_scale,
|
| 268 |
-
uncha_beta_cal_beta=uncha_beta_cal_beta,
|
| 269 |
-
uncha_beta_cal_variant=uncha_beta_cal_variant,
|
| 270 |
-
uncha_beta_cal_weight=uncha_beta_cal_weight,
|
| 271 |
-
uncha_himo_component_weight=uncha_himo_component_weight,
|
| 272 |
-
uncha_radius_order_weight=uncha_radius_order_weight,
|
| 273 |
-
uncha_radius_order_margin=uncha_radius_order_margin,
|
| 274 |
-
uncha_gramian_align_weight=uncha_gramian_align_weight,
|
| 275 |
-
product_metric=phyclip_product_metric,
|
| 276 |
-
)
|
| 277 |
-
|
| 278 |
-
def train(self, mode: bool = True) -> Hyper3CLIP:
|
| 279 |
-
super().train(mode)
|
| 280 |
-
if self.freeze_vision_encoder:
|
| 281 |
-
self.vision_encoder.eval()
|
| 282 |
-
if self.freeze_text_encoder:
|
| 283 |
-
self.text_encoder.eval()
|
| 284 |
-
return self
|
| 285 |
-
|
| 286 |
-
@property
|
| 287 |
-
def phyclip_enabled(self) -> bool:
|
| 288 |
-
return self.phyclip_subspace_dim is not None
|
| 289 |
-
|
| 290 |
-
def _kappa(self) -> torch.Tensor:
|
| 291 |
-
return self.log_curv.exp().clamp(min=self.curv_min, max=self.curv_max)
|
| 292 |
-
|
| 293 |
-
def encode_image(self, image: torch.Tensor, project: bool = True) -> torch.Tensor:
|
| 294 |
-
feats = self.image_proj(self.encode_image_base(image))
|
| 295 |
-
if not project:
|
| 296 |
-
return feats
|
| 297 |
-
return self.project_image_features(feats)
|
| 298 |
-
|
| 299 |
-
def encode_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, project: bool = True) -> torch.Tensor:
|
| 300 |
-
feats = self.text_proj(self.encode_text_base(input_ids, attention_mask))
|
| 301 |
-
if not project:
|
| 302 |
-
return feats
|
| 303 |
-
return self.project_text_features(feats)
|
| 304 |
-
|
| 305 |
-
def encode_image_base(self, image: torch.Tensor) -> torch.Tensor:
|
| 306 |
-
with torch.set_grad_enabled(self.training and not self.freeze_vision_encoder):
|
| 307 |
-
feats = self.vision_encoder(image)
|
| 308 |
-
feats = feats.detach() if self.freeze_vision_encoder else feats
|
| 309 |
-
return F.normalize(feats.float(), dim=-1) if self.normalize_encoder_features else feats
|
| 310 |
-
|
| 311 |
-
def encode_image_base_with_tokens(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 312 |
-
with torch.set_grad_enabled(self.training and not self.freeze_vision_encoder):
|
| 313 |
-
feats, tokens = self.vision_encoder.forward_with_tokens(image)
|
| 314 |
-
if self.freeze_vision_encoder:
|
| 315 |
-
feats = feats.detach()
|
| 316 |
-
tokens = tokens.detach()
|
| 317 |
-
if self.normalize_encoder_features:
|
| 318 |
-
feats = F.normalize(feats.float(), dim=-1)
|
| 319 |
-
return feats, tokens
|
| 320 |
-
|
| 321 |
-
def encode_text_base(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
| 322 |
-
with torch.set_grad_enabled(self.training and not self.freeze_text_encoder):
|
| 323 |
-
feats = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
|
| 324 |
-
feats = feats.detach() if self.freeze_text_encoder else feats
|
| 325 |
-
return F.normalize(feats.float(), dim=-1) if self.normalize_encoder_features else feats
|
| 326 |
-
|
| 327 |
-
def project_image_features(self, feats: torch.Tensor) -> torch.Tensor:
|
| 328 |
-
if self.phyclip_enabled:
|
| 329 |
-
return self._project_product_features(feats, self.visual_alpha)
|
| 330 |
-
return exp_map0(feats.float() * self.visual_alpha.exp().float(), self._kappa().float())
|
| 331 |
-
|
| 332 |
-
def project_text_features(self, feats: torch.Tensor) -> torch.Tensor:
|
| 333 |
-
if self.phyclip_enabled:
|
| 334 |
-
return self._project_product_features(feats, self.textual_alpha)
|
| 335 |
-
return exp_map0(feats.float() * self.textual_alpha.exp().float(), self._kappa().float())
|
| 336 |
-
|
| 337 |
-
def similarity_scores(self, image_feats: torch.Tensor, text_feats: torch.Tensor) -> torch.Tensor:
|
| 338 |
-
return metric_similarity(image_feats, text_feats, self._kappa(), product_metric=self.phyclip_product_metric)
|
| 339 |
-
|
| 340 |
-
def encode_retrieval_image(self, image: torch.Tensor) -> torch.Tensor:
|
| 341 |
-
base = self.encode_image_base(image)
|
| 342 |
-
tangent = self.image_proj(base)
|
| 343 |
-
if self.proclip_retrieval:
|
| 344 |
-
return self._project_proclip_image_base(base, self.project_image_features(tangent))
|
| 345 |
-
return self.project_image_features(tangent)
|
| 346 |
-
|
| 347 |
-
def encode_retrieval_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
| 348 |
-
base = self.encode_text_base(input_ids, attention_mask)
|
| 349 |
-
tangent = self.text_proj(base)
|
| 350 |
-
if self.proclip_retrieval:
|
| 351 |
-
return self._project_proclip_text_base(base, self.project_text_features(tangent))
|
| 352 |
-
return self.project_text_features(tangent)
|
| 353 |
-
|
| 354 |
-
def retrieval_similarity_scores(self, image_feats: torch.Tensor, text_feats: torch.Tensor) -> torch.Tensor:
|
| 355 |
-
if self.proclip_retrieval:
|
| 356 |
-
return self._proclip_similarity_scores(image_feats, text_feats)
|
| 357 |
-
return self.similarity_scores(image_feats, text_feats)
|
| 358 |
-
|
| 359 |
-
@property
|
| 360 |
-
def retrieval_requires_chunking(self) -> bool:
|
| 361 |
-
return self.phyclip_enabled or self.proclip_retrieval
|
| 362 |
-
|
| 363 |
-
def _objective_autocast(self, device_type: str):
|
| 364 |
-
dtype = {
|
| 365 |
-
"float32": torch.float32,
|
| 366 |
-
"fp32": torch.float32,
|
| 367 |
-
"float16": torch.float16,
|
| 368 |
-
"fp16": torch.float16,
|
| 369 |
-
"bfloat16": torch.bfloat16,
|
| 370 |
-
"bf16": torch.bfloat16,
|
| 371 |
-
}[self.objective_autocast_dtype]
|
| 372 |
-
enabled = device_type != "cpu" and dtype is not torch.float32
|
| 373 |
-
return torch.autocast(device_type=device_type, dtype=dtype, enabled=enabled)
|
| 374 |
-
|
| 375 |
-
def forward(
|
| 376 |
-
self,
|
| 377 |
-
image: torch.Tensor,
|
| 378 |
-
part_images: torch.Tensor,
|
| 379 |
-
text_input_ids: torch.Tensor,
|
| 380 |
-
text_attention_mask: torch.Tensor,
|
| 381 |
-
part_text_input_ids: torch.Tensor,
|
| 382 |
-
part_text_attention_mask: torch.Tensor,
|
| 383 |
-
part_owner: torch.Tensor,
|
| 384 |
-
step: int | None = None,
|
| 385 |
-
beta_query_input_ids: torch.Tensor | None = None,
|
| 386 |
-
beta_query_attention_mask: torch.Tensor | None = None,
|
| 387 |
-
beta_query_owner: torch.Tensor | None = None,
|
| 388 |
-
beta_query_type: torch.Tensor | None = None,
|
| 389 |
-
beta_query_parent: torch.Tensor | None = None,
|
| 390 |
-
beta_query_weight: torch.Tensor | None = None,
|
| 391 |
-
beta_query_source_part: torch.Tensor | None = None,
|
| 392 |
-
) -> dict[str, torch.Tensor]:
|
| 393 |
-
with torch.no_grad():
|
| 394 |
-
self._clamp_logit_scales()
|
| 395 |
-
self.visual_alpha.clamp_(max=0.0)
|
| 396 |
-
self.textual_alpha.clamp_(max=0.0)
|
| 397 |
-
kappa = self._kappa()
|
| 398 |
-
|
| 399 |
-
feature_dim = self.embed_dim
|
| 400 |
-
beta_image_tokens = None
|
| 401 |
-
beta_query_base = None
|
| 402 |
-
part_image_base = part_images.new_zeros((0, self.vision_encoder.output_dim))
|
| 403 |
-
part_text_base = part_images.new_zeros((0, self.text_encoder.output_dim))
|
| 404 |
-
hier_beta_enabled = self.objective_name == "uncha" and self.uncha_entailment_loss in {
|
| 405 |
-
"hier_beta_argent",
|
| 406 |
-
"hier_beta_sourcepart_argent",
|
| 407 |
-
}
|
| 408 |
-
if (
|
| 409 |
-
hier_beta_enabled
|
| 410 |
-
and self.fuse_beta_query_encoder_forwards
|
| 411 |
-
and not self.tren_enabled
|
| 412 |
-
and beta_query_input_ids is not None
|
| 413 |
-
and beta_query_attention_mask is not None
|
| 414 |
-
and part_images.shape[0] > 0
|
| 415 |
-
):
|
| 416 |
-
(
|
| 417 |
-
image_base,
|
| 418 |
-
text_base,
|
| 419 |
-
image_euc,
|
| 420 |
-
text_euc,
|
| 421 |
-
image_feats,
|
| 422 |
-
text_feats,
|
| 423 |
-
part_image_feats,
|
| 424 |
-
part_text_feats,
|
| 425 |
-
part_image_euc,
|
| 426 |
-
part_text_euc,
|
| 427 |
-
part_image_base,
|
| 428 |
-
part_text_base,
|
| 429 |
-
beta_image_tokens,
|
| 430 |
-
beta_query_base,
|
| 431 |
-
) = self._encode_hier_beta_whole_parts_and_queries(
|
| 432 |
-
image=image,
|
| 433 |
-
part_images=part_images,
|
| 434 |
-
text_input_ids=text_input_ids,
|
| 435 |
-
text_attention_mask=text_attention_mask,
|
| 436 |
-
part_text_input_ids=part_text_input_ids,
|
| 437 |
-
part_text_attention_mask=part_text_attention_mask,
|
| 438 |
-
beta_query_input_ids=beta_query_input_ids,
|
| 439 |
-
beta_query_attention_mask=beta_query_attention_mask,
|
| 440 |
-
)
|
| 441 |
-
elif self.beta_query_pooling_enabled or self.tren_enabled:
|
| 442 |
-
image_base, beta_image_tokens = self.encode_image_base_with_tokens(image)
|
| 443 |
-
text_base = self.encode_text_base(text_input_ids, text_attention_mask)
|
| 444 |
-
image_euc = self.image_proj(image_base)
|
| 445 |
-
text_euc = self.text_proj(text_base)
|
| 446 |
-
image_feats = self.project_image_features(image_euc)
|
| 447 |
-
text_feats = self.project_text_features(text_euc)
|
| 448 |
-
(
|
| 449 |
-
part_image_feats,
|
| 450 |
-
part_text_feats,
|
| 451 |
-
part_image_euc,
|
| 452 |
-
part_text_euc,
|
| 453 |
-
part_image_base,
|
| 454 |
-
part_text_base,
|
| 455 |
-
) = self._encode_parts_with_base(
|
| 456 |
-
part_images=part_images,
|
| 457 |
-
part_text_input_ids=part_text_input_ids,
|
| 458 |
-
part_text_attention_mask=part_text_attention_mask,
|
| 459 |
-
feature_dim=feature_dim,
|
| 460 |
-
)
|
| 461 |
-
elif self.fuse_whole_part_encoder_forwards and self.objective_name != "proclip" and part_images.shape[0] > 0:
|
| 462 |
-
(
|
| 463 |
-
image_base,
|
| 464 |
-
text_base,
|
| 465 |
-
image_euc,
|
| 466 |
-
text_euc,
|
| 467 |
-
image_feats,
|
| 468 |
-
text_feats,
|
| 469 |
-
part_image_feats,
|
| 470 |
-
part_text_feats,
|
| 471 |
-
part_image_euc,
|
| 472 |
-
part_text_euc,
|
| 473 |
-
part_image_base,
|
| 474 |
-
part_text_base,
|
| 475 |
-
) = self._encode_whole_and_parts(
|
| 476 |
-
image=image,
|
| 477 |
-
part_images=part_images,
|
| 478 |
-
text_input_ids=text_input_ids,
|
| 479 |
-
text_attention_mask=text_attention_mask,
|
| 480 |
-
part_text_input_ids=part_text_input_ids,
|
| 481 |
-
part_text_attention_mask=part_text_attention_mask,
|
| 482 |
-
)
|
| 483 |
-
else:
|
| 484 |
-
image_base = self.encode_image_base(image)
|
| 485 |
-
text_base = self.encode_text_base(text_input_ids, text_attention_mask)
|
| 486 |
-
image_euc = self.image_proj(image_base)
|
| 487 |
-
text_euc = self.text_proj(text_base)
|
| 488 |
-
image_feats = self.project_image_features(image_euc)
|
| 489 |
-
text_feats = self.project_text_features(text_euc)
|
| 490 |
-
(
|
| 491 |
-
part_image_feats,
|
| 492 |
-
part_text_feats,
|
| 493 |
-
part_image_euc,
|
| 494 |
-
part_text_euc,
|
| 495 |
-
part_image_base,
|
| 496 |
-
part_text_base,
|
| 497 |
-
) = self._encode_parts_with_base(
|
| 498 |
-
part_images=part_images,
|
| 499 |
-
part_text_input_ids=part_text_input_ids,
|
| 500 |
-
part_text_attention_mask=part_text_attention_mask,
|
| 501 |
-
feature_dim=feature_dim,
|
| 502 |
-
)
|
| 503 |
-
targets = local_target_indices(image_feats.size(0), image_feats.device)
|
| 504 |
-
|
| 505 |
-
if self.objective_name == "proclip":
|
| 506 |
-
proclip_image_feats = self._project_proclip_image_base(image_base, image_feats)
|
| 507 |
-
proclip_text_feats = self._project_proclip_text_base(text_base, text_feats)
|
| 508 |
-
proclip_loss = self._proclip_contrastive_loss(
|
| 509 |
-
image_feats=proclip_image_feats,
|
| 510 |
-
text_feats=proclip_text_feats,
|
| 511 |
-
all_image_feats=gather_with_grad(proclip_image_feats),
|
| 512 |
-
all_text_feats=gather_with_grad(proclip_text_feats),
|
| 513 |
-
targets=targets,
|
| 514 |
-
)
|
| 515 |
-
zero = proclip_loss.new_zeros(())
|
| 516 |
-
return {
|
| 517 |
-
"loss": proclip_loss,
|
| 518 |
-
"contrastive_loss": proclip_loss,
|
| 519 |
-
"entailment_loss": zero,
|
| 520 |
-
"part_count": part_owner.new_tensor(0),
|
| 521 |
-
"proclip_contrastive_loss": proclip_loss,
|
| 522 |
-
**self._detached_kappa_logs(kappa),
|
| 523 |
-
**self._detached_logit_scales(),
|
| 524 |
-
}
|
| 525 |
-
|
| 526 |
-
himo_text_feats = None
|
| 527 |
-
all_himo_text_feats = None
|
| 528 |
-
if self.objective_name == "uncha" and self.uncha_himo_component_weight > 0.0:
|
| 529 |
-
all_text_euc = gather_with_grad(text_euc)
|
| 530 |
-
all_component_euc = hide_reconstruct_embeddings(
|
| 531 |
-
all_text_euc,
|
| 532 |
-
variance_threshold=self.uncha_himo_variance_threshold,
|
| 533 |
-
detach_pca=self.uncha_himo_detach_pca,
|
| 534 |
-
)
|
| 535 |
-
if get_world_size() > 1:
|
| 536 |
-
start = text_euc.size(0) * get_rank()
|
| 537 |
-
end = start + text_euc.size(0)
|
| 538 |
-
component_euc = all_component_euc[start:end]
|
| 539 |
-
else:
|
| 540 |
-
component_euc = all_component_euc
|
| 541 |
-
himo_text_feats = self.project_text_features(component_euc)
|
| 542 |
-
all_himo_text_feats = gather_with_grad(himo_text_feats)
|
| 543 |
-
all_image_feats = gather_with_grad(image_feats)
|
| 544 |
-
all_text_feats = gather_with_grad(text_feats)
|
| 545 |
-
all_image_euc = None
|
| 546 |
-
all_text_euc = None
|
| 547 |
-
if self.objective_name == "uncha" and self.uncha_contrastive_loss == "siglip":
|
| 548 |
-
all_image_euc = gather_with_grad(image_euc)
|
| 549 |
-
all_text_euc = gather_with_grad(text_euc)
|
| 550 |
-
part_owner = part_owner.to(device=image_feats.device, dtype=torch.long)
|
| 551 |
-
beta_query_embeddings = {}
|
| 552 |
-
if self.objective_name == "uncha" and self.uncha_entailment_loss in {
|
| 553 |
-
"hier_beta_argent",
|
| 554 |
-
"hier_beta_sourcepart_argent",
|
| 555 |
-
}:
|
| 556 |
-
if beta_image_tokens is None:
|
| 557 |
-
raise RuntimeError(f"{self.uncha_entailment_loss} requires image patch tokens")
|
| 558 |
-
with torch.autocast(device_type=image.device.type, enabled=False):
|
| 559 |
-
beta_query_embeddings = self._beta_query_entailment_embeddings(
|
| 560 |
-
image_tokens=beta_image_tokens.float(),
|
| 561 |
-
beta_query_input_ids=beta_query_input_ids,
|
| 562 |
-
beta_query_attention_mask=beta_query_attention_mask,
|
| 563 |
-
beta_query_owner=beta_query_owner,
|
| 564 |
-
beta_query_parent=beta_query_parent,
|
| 565 |
-
beta_query_weight=beta_query_weight,
|
| 566 |
-
beta_query_source_part=beta_query_source_part,
|
| 567 |
-
kappa=kappa.float(),
|
| 568 |
-
query_base=beta_query_base,
|
| 569 |
-
)
|
| 570 |
-
|
| 571 |
-
with self._objective_autocast(image.device.type):
|
| 572 |
-
if self.objective is None:
|
| 573 |
-
raise RuntimeError("Non-ProCLIP forward requires an objective module")
|
| 574 |
-
losses = self.objective(
|
| 575 |
-
{
|
| 576 |
-
"image_feats": image_feats,
|
| 577 |
-
"text_feats": text_feats,
|
| 578 |
-
"part_image_feats": part_image_feats,
|
| 579 |
-
"part_text_feats": part_text_feats,
|
| 580 |
-
"part_owner": part_owner,
|
| 581 |
-
"all_image_feats": all_image_feats,
|
| 582 |
-
"all_text_feats": all_text_feats,
|
| 583 |
-
**(
|
| 584 |
-
{
|
| 585 |
-
"image_euc_feats": image_euc,
|
| 586 |
-
"text_euc_feats": text_euc,
|
| 587 |
-
"part_image_euc_feats": part_image_euc,
|
| 588 |
-
"part_text_euc_feats": part_text_euc,
|
| 589 |
-
"all_image_euc_feats": all_image_euc,
|
| 590 |
-
"all_text_euc_feats": all_text_euc,
|
| 591 |
-
}
|
| 592 |
-
if all_image_euc is not None and all_text_euc is not None
|
| 593 |
-
else {}
|
| 594 |
-
),
|
| 595 |
-
"targets": targets,
|
| 596 |
-
"kappa": kappa,
|
| 597 |
-
"entail_weight_scale": self._entail_weight_scale(step, image_feats.device),
|
| 598 |
-
**beta_query_embeddings,
|
| 599 |
-
**(
|
| 600 |
-
{
|
| 601 |
-
"himo_text_feats": himo_text_feats,
|
| 602 |
-
"all_himo_text_feats": all_himo_text_feats,
|
| 603 |
-
}
|
| 604 |
-
if himo_text_feats is not None
|
| 605 |
-
else {}
|
| 606 |
-
),
|
| 607 |
-
},
|
| 608 |
-
self._objective_logit_scales(),
|
| 609 |
-
)
|
| 610 |
-
|
| 611 |
-
if self.beta_clip_global_weight > 0.0:
|
| 612 |
-
with torch.autocast(device_type=image.device.type, enabled=False):
|
| 613 |
-
beta_clip_global_loss = self._beta_clip_global_contrastive_loss(
|
| 614 |
-
image_euc=image_euc,
|
| 615 |
-
text_euc=text_euc,
|
| 616 |
-
targets=targets,
|
| 617 |
-
)
|
| 618 |
-
losses = {
|
| 619 |
-
**losses,
|
| 620 |
-
"loss": losses["loss"] + self.beta_clip_global_weight * beta_clip_global_loss,
|
| 621 |
-
"beta_clip_global_loss": beta_clip_global_loss,
|
| 622 |
-
}
|
| 623 |
-
|
| 624 |
-
if self.beta_clip_enabled:
|
| 625 |
-
if beta_image_tokens is None:
|
| 626 |
-
raise RuntimeError("beta-CLIP auxiliary requires image patch tokens")
|
| 627 |
-
with torch.autocast(device_type=image.device.type, enabled=False):
|
| 628 |
-
beta_clip_loss = self._beta_clip_auxiliary_loss(
|
| 629 |
-
image_tokens=beta_image_tokens.float(),
|
| 630 |
-
beta_query_input_ids=beta_query_input_ids,
|
| 631 |
-
beta_query_attention_mask=beta_query_attention_mask,
|
| 632 |
-
beta_query_owner=beta_query_owner,
|
| 633 |
-
global_targets=targets,
|
| 634 |
-
kappa=kappa.float(),
|
| 635 |
-
)
|
| 636 |
-
losses = {
|
| 637 |
-
**losses,
|
| 638 |
-
"loss": losses["loss"] + self.beta_clip_weight * beta_clip_loss,
|
| 639 |
-
"beta_clip_loss": beta_clip_loss,
|
| 640 |
-
}
|
| 641 |
-
|
| 642 |
-
if self.tren_enabled:
|
| 643 |
-
if beta_image_tokens is None:
|
| 644 |
-
raise RuntimeError("T-REN auxiliary requires image patch tokens")
|
| 645 |
-
with torch.autocast(device_type=image.device.type, enabled=False):
|
| 646 |
-
tren_losses = self._tren_auxiliary_losses(
|
| 647 |
-
image_tokens=beta_image_tokens.float(),
|
| 648 |
-
part_owner=part_owner,
|
| 649 |
-
part_image_base=part_image_base.float(),
|
| 650 |
-
part_text_base=part_text_base.float(),
|
| 651 |
-
)
|
| 652 |
-
losses = {
|
| 653 |
-
**losses,
|
| 654 |
-
"loss": losses["loss"] + self.tren_weight * tren_losses["tren_loss"],
|
| 655 |
-
**tren_losses,
|
| 656 |
-
}
|
| 657 |
-
|
| 658 |
-
if self.proclip_enabled and self.proclip_weight > 0.0:
|
| 659 |
-
proclip_image_feats = self._project_proclip_image_base(image_base, image_feats)
|
| 660 |
-
proclip_text_feats = self._project_proclip_text_base(text_base, text_feats)
|
| 661 |
-
proclip_loss = self._proclip_contrastive_loss(
|
| 662 |
-
image_feats=proclip_image_feats,
|
| 663 |
-
text_feats=proclip_text_feats,
|
| 664 |
-
all_image_feats=gather_with_grad(proclip_image_feats),
|
| 665 |
-
all_text_feats=gather_with_grad(proclip_text_feats),
|
| 666 |
-
targets=targets,
|
| 667 |
-
)
|
| 668 |
-
losses = {
|
| 669 |
-
**losses,
|
| 670 |
-
"loss": losses["loss"] + self.proclip_weight * proclip_loss,
|
| 671 |
-
"proclip_contrastive_loss": proclip_loss,
|
| 672 |
-
}
|
| 673 |
-
|
| 674 |
-
return {**losses, **self._detached_kappa_logs(kappa), **self._detached_logit_scales()}
|
| 675 |
-
|
| 676 |
-
def _encode_parts(
|
| 677 |
-
self,
|
| 678 |
-
part_images: torch.Tensor,
|
| 679 |
-
part_text_input_ids: torch.Tensor,
|
| 680 |
-
part_text_attention_mask: torch.Tensor,
|
| 681 |
-
feature_dim: int,
|
| 682 |
-
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 683 |
-
if part_images.shape[0] == 0:
|
| 684 |
-
empty = part_images.new_zeros((0, feature_dim))
|
| 685 |
-
return empty, empty, empty, empty
|
| 686 |
-
|
| 687 |
-
part_image_euc = self.image_proj(self.encode_image_base(part_images))
|
| 688 |
-
part_text_euc = self.text_proj(self.encode_text_base(part_text_input_ids, part_text_attention_mask))
|
| 689 |
-
part_image_feats = self.project_image_features(part_image_euc)
|
| 690 |
-
part_text_feats = self.project_text_features(part_text_euc)
|
| 691 |
-
return part_image_feats, part_text_feats, part_image_euc, part_text_euc
|
| 692 |
-
|
| 693 |
-
def _encode_parts_with_base(
|
| 694 |
-
self,
|
| 695 |
-
part_images: torch.Tensor,
|
| 696 |
-
part_text_input_ids: torch.Tensor,
|
| 697 |
-
part_text_attention_mask: torch.Tensor,
|
| 698 |
-
feature_dim: int,
|
| 699 |
-
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 700 |
-
if part_images.shape[0] == 0:
|
| 701 |
-
empty = part_images.new_zeros((0, feature_dim))
|
| 702 |
-
empty_image_base = part_images.new_zeros((0, self.vision_encoder.output_dim))
|
| 703 |
-
empty_text_base = part_images.new_zeros((0, self.text_encoder.output_dim))
|
| 704 |
-
return empty, empty, empty, empty, empty_image_base, empty_text_base
|
| 705 |
-
|
| 706 |
-
part_image_base = self.encode_image_base(part_images)
|
| 707 |
-
part_text_base = self.encode_text_base(part_text_input_ids, part_text_attention_mask)
|
| 708 |
-
part_image_euc = self.image_proj(part_image_base)
|
| 709 |
-
part_text_euc = self.text_proj(part_text_base)
|
| 710 |
-
part_image_feats = self.project_image_features(part_image_euc)
|
| 711 |
-
part_text_feats = self.project_text_features(part_text_euc)
|
| 712 |
-
return part_image_feats, part_text_feats, part_image_euc, part_text_euc, part_image_base, part_text_base
|
| 713 |
-
|
| 714 |
-
def _encode_whole_and_parts(
|
| 715 |
-
self,
|
| 716 |
-
image: torch.Tensor,
|
| 717 |
-
part_images: torch.Tensor,
|
| 718 |
-
text_input_ids: torch.Tensor,
|
| 719 |
-
text_attention_mask: torch.Tensor,
|
| 720 |
-
part_text_input_ids: torch.Tensor,
|
| 721 |
-
part_text_attention_mask: torch.Tensor,
|
| 722 |
-
) -> tuple[
|
| 723 |
-
torch.Tensor,
|
| 724 |
-
torch.Tensor,
|
| 725 |
-
torch.Tensor,
|
| 726 |
-
torch.Tensor,
|
| 727 |
-
torch.Tensor,
|
| 728 |
-
torch.Tensor,
|
| 729 |
-
torch.Tensor,
|
| 730 |
-
torch.Tensor,
|
| 731 |
-
torch.Tensor,
|
| 732 |
-
torch.Tensor,
|
| 733 |
-
torch.Tensor,
|
| 734 |
-
torch.Tensor,
|
| 735 |
-
]:
|
| 736 |
-
batch_size = image.shape[0]
|
| 737 |
-
part_count = part_images.shape[0]
|
| 738 |
-
image_base_all = self.encode_image_base(torch.cat([image, part_images], dim=0))
|
| 739 |
-
image_euc_all = self.image_proj(image_base_all)
|
| 740 |
-
image_feats_all = self.project_image_features(image_euc_all)
|
| 741 |
-
|
| 742 |
-
text_ids, text_mask = self._concat_text_batches(
|
| 743 |
-
text_input_ids,
|
| 744 |
-
text_attention_mask,
|
| 745 |
-
part_text_input_ids,
|
| 746 |
-
part_text_attention_mask,
|
| 747 |
-
)
|
| 748 |
-
text_base_all = self.encode_text_base(text_ids, text_mask)
|
| 749 |
-
text_euc_all = self.text_proj(text_base_all)
|
| 750 |
-
text_feats_all = self.project_text_features(text_euc_all)
|
| 751 |
-
|
| 752 |
-
image_base, part_image_base = image_base_all.split([batch_size, part_count], dim=0)
|
| 753 |
-
text_base, part_text_base = text_base_all.split([batch_size, part_count], dim=0)
|
| 754 |
-
image_euc, part_image_euc = image_euc_all.split([batch_size, part_count], dim=0)
|
| 755 |
-
text_euc, part_text_euc = text_euc_all.split([batch_size, part_count], dim=0)
|
| 756 |
-
image_feats, part_image_feats = image_feats_all.split([batch_size, part_count], dim=0)
|
| 757 |
-
text_feats, part_text_feats = text_feats_all.split([batch_size, part_count], dim=0)
|
| 758 |
-
return (
|
| 759 |
-
image_base,
|
| 760 |
-
text_base,
|
| 761 |
-
image_euc,
|
| 762 |
-
text_euc,
|
| 763 |
-
image_feats,
|
| 764 |
-
text_feats,
|
| 765 |
-
part_image_feats,
|
| 766 |
-
part_text_feats,
|
| 767 |
-
part_image_euc,
|
| 768 |
-
part_text_euc,
|
| 769 |
-
part_image_base,
|
| 770 |
-
part_text_base,
|
| 771 |
-
)
|
| 772 |
-
|
| 773 |
-
def _encode_hier_beta_whole_parts_and_queries(
|
| 774 |
-
self,
|
| 775 |
-
image: torch.Tensor,
|
| 776 |
-
part_images: torch.Tensor,
|
| 777 |
-
text_input_ids: torch.Tensor,
|
| 778 |
-
text_attention_mask: torch.Tensor,
|
| 779 |
-
part_text_input_ids: torch.Tensor,
|
| 780 |
-
part_text_attention_mask: torch.Tensor,
|
| 781 |
-
beta_query_input_ids: torch.Tensor,
|
| 782 |
-
beta_query_attention_mask: torch.Tensor,
|
| 783 |
-
) -> tuple[
|
| 784 |
-
torch.Tensor,
|
| 785 |
-
torch.Tensor,
|
| 786 |
-
torch.Tensor,
|
| 787 |
-
torch.Tensor,
|
| 788 |
-
torch.Tensor,
|
| 789 |
-
torch.Tensor,
|
| 790 |
-
torch.Tensor,
|
| 791 |
-
torch.Tensor,
|
| 792 |
-
torch.Tensor,
|
| 793 |
-
torch.Tensor,
|
| 794 |
-
torch.Tensor,
|
| 795 |
-
torch.Tensor,
|
| 796 |
-
torch.Tensor,
|
| 797 |
-
torch.Tensor,
|
| 798 |
-
]:
|
| 799 |
-
batch_size = image.shape[0]
|
| 800 |
-
part_count = part_images.shape[0]
|
| 801 |
-
query_count = beta_query_input_ids.shape[0]
|
| 802 |
-
|
| 803 |
-
image_base_all, image_tokens_all = self.encode_image_base_with_tokens(torch.cat([image, part_images], dim=0))
|
| 804 |
-
image_euc_all = self.image_proj(image_base_all)
|
| 805 |
-
image_feats_all = self.project_image_features(image_euc_all)
|
| 806 |
-
image_base, part_image_base = image_base_all.split([batch_size, part_count], dim=0)
|
| 807 |
-
image_euc, part_image_euc = image_euc_all.split([batch_size, part_count], dim=0)
|
| 808 |
-
image_feats, part_image_feats = image_feats_all.split([batch_size, part_count], dim=0)
|
| 809 |
-
beta_image_tokens = image_tokens_all[:batch_size]
|
| 810 |
-
|
| 811 |
-
text_ids, text_mask = self._concat_text_batch_list(
|
| 812 |
-
(text_input_ids, text_attention_mask),
|
| 813 |
-
(part_text_input_ids, part_text_attention_mask),
|
| 814 |
-
(beta_query_input_ids, beta_query_attention_mask),
|
| 815 |
-
)
|
| 816 |
-
text_base_all = self.encode_text_base(text_ids, text_mask)
|
| 817 |
-
text_euc_all = self.text_proj(text_base_all)
|
| 818 |
-
text_feats_all = self.project_text_features(text_euc_all)
|
| 819 |
-
text_base, part_text_base, beta_query_base = text_base_all.split([batch_size, part_count, query_count], dim=0)
|
| 820 |
-
text_euc, part_text_euc, _ = text_euc_all.split([batch_size, part_count, query_count], dim=0)
|
| 821 |
-
text_feats, part_text_feats, _ = text_feats_all.split([batch_size, part_count, query_count], dim=0)
|
| 822 |
-
|
| 823 |
-
return (
|
| 824 |
-
image_base,
|
| 825 |
-
text_base,
|
| 826 |
-
image_euc,
|
| 827 |
-
text_euc,
|
| 828 |
-
image_feats,
|
| 829 |
-
text_feats,
|
| 830 |
-
part_image_feats,
|
| 831 |
-
part_text_feats,
|
| 832 |
-
part_image_euc,
|
| 833 |
-
part_text_euc,
|
| 834 |
-
part_image_base,
|
| 835 |
-
part_text_base,
|
| 836 |
-
beta_image_tokens,
|
| 837 |
-
beta_query_base,
|
| 838 |
-
)
|
| 839 |
-
|
| 840 |
-
def _concat_text_batches(
|
| 841 |
-
self,
|
| 842 |
-
text_input_ids: torch.Tensor,
|
| 843 |
-
text_attention_mask: torch.Tensor,
|
| 844 |
-
part_text_input_ids: torch.Tensor,
|
| 845 |
-
part_text_attention_mask: torch.Tensor,
|
| 846 |
-
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 847 |
-
return self._concat_text_batch_list(
|
| 848 |
-
(text_input_ids, text_attention_mask),
|
| 849 |
-
(part_text_input_ids, part_text_attention_mask),
|
| 850 |
-
)
|
| 851 |
-
|
| 852 |
-
def _concat_text_batch_list(
|
| 853 |
-
self,
|
| 854 |
-
*batches: tuple[torch.Tensor, torch.Tensor],
|
| 855 |
-
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 856 |
-
target_length = max(input_ids.shape[1] for input_ids, _ in batches)
|
| 857 |
-
pad_token_id = self.text_encoder.tokenizer.pad_token_id
|
| 858 |
-
if pad_token_id is None:
|
| 859 |
-
pad_token_id = 0
|
| 860 |
-
return (
|
| 861 |
-
torch.cat([_pad_sequence_dim(input_ids, target_length, pad_token_id) for input_ids, _ in batches], dim=0),
|
| 862 |
-
torch.cat([_pad_sequence_dim(attention_mask, target_length, 0) for _, attention_mask in batches], dim=0),
|
| 863 |
-
)
|
| 864 |
-
|
| 865 |
-
def _clamp_logit_scales(self) -> None:
|
| 866 |
-
if self.objective_name == "proclip":
|
| 867 |
-
self.proclip_logit_scale.clamp_(max=4.6052)
|
| 868 |
-
self._clamp_experimental_logit_scales()
|
| 869 |
-
return
|
| 870 |
-
if self.objective_name == "hycoclip":
|
| 871 |
-
self.logit_scale.clamp_(max=4.6052)
|
| 872 |
-
self._clamp_experimental_logit_scales()
|
| 873 |
-
return
|
| 874 |
-
self.global_logit_scale.clamp_(max=4.6052)
|
| 875 |
-
self.local_logit_scale.clamp_(max=4.6052)
|
| 876 |
-
self.global_local_logit_scale.clamp_(max=4.6052)
|
| 877 |
-
self._clamp_experimental_logit_scales()
|
| 878 |
-
|
| 879 |
-
def _objective_logit_scales(self) -> torch.Tensor | dict[str, torch.Tensor]:
|
| 880 |
-
if self.objective_name == "hycoclip":
|
| 881 |
-
return self.logit_scale
|
| 882 |
-
if self.objective_name == "proclip":
|
| 883 |
-
return self.proclip_logit_scale
|
| 884 |
-
return {
|
| 885 |
-
"global": self.global_logit_scale,
|
| 886 |
-
"local": self.local_logit_scale,
|
| 887 |
-
"global_local": self.global_local_logit_scale,
|
| 888 |
-
**(
|
| 889 |
-
{
|
| 890 |
-
"global_bias": self.global_logit_bias,
|
| 891 |
-
"local_bias": self.local_logit_bias,
|
| 892 |
-
"global_local_bias": self.global_local_logit_bias,
|
| 893 |
-
}
|
| 894 |
-
if self.uncha_contrastive_loss in {"sigmoid", "siglip", "siglip_metric"}
|
| 895 |
-
else {}
|
| 896 |
-
),
|
| 897 |
-
}
|
| 898 |
-
|
| 899 |
-
def _detached_logit_scales(self) -> dict[str, torch.Tensor]:
|
| 900 |
-
if self.objective_name == "proclip":
|
| 901 |
-
return self._detached_experimental_logit_scales()
|
| 902 |
-
if self.objective_name == "hycoclip":
|
| 903 |
-
logs = {"logit_scale": self.logit_scale.exp().detach()}
|
| 904 |
-
logs.update(self._detached_experimental_logit_scales())
|
| 905 |
-
return logs
|
| 906 |
-
logs = {
|
| 907 |
-
"global_logit_scale": self.global_logit_scale.exp().detach(),
|
| 908 |
-
"local_logit_scale": self.local_logit_scale.exp().detach(),
|
| 909 |
-
"global_local_logit_scale": self.global_local_logit_scale.exp().detach(),
|
| 910 |
-
}
|
| 911 |
-
if self.uncha_contrastive_loss in {"sigmoid", "siglip", "siglip_metric"}:
|
| 912 |
-
logs.update(
|
| 913 |
-
{
|
| 914 |
-
"global_logit_bias": self.global_logit_bias.detach(),
|
| 915 |
-
"local_logit_bias": self.local_logit_bias.detach(),
|
| 916 |
-
"global_local_logit_bias": self.global_local_logit_bias.detach(),
|
| 917 |
-
}
|
| 918 |
-
)
|
| 919 |
-
logs.update(self._detached_experimental_logit_scales())
|
| 920 |
-
return logs
|
| 921 |
-
|
| 922 |
-
def _project_product_features(self, feats: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
|
| 923 |
-
product_feats = feats.float().reshape(feats.size(0), self.phyclip_num_factors, self.phyclip_subspace_dim)
|
| 924 |
-
product_feats = product_feats * alpha.exp().float().view(1, -1, 1)
|
| 925 |
-
return exp_map0(product_feats, self._kappa().float().view(1, -1, 1))
|
| 926 |
-
|
| 927 |
-
def _detached_kappa_logs(self, kappa: torch.Tensor) -> dict[str, torch.Tensor]:
|
| 928 |
-
detached = kappa.detach()
|
| 929 |
-
if detached.numel() == 1:
|
| 930 |
-
return {"kappa": detached.reshape(())}
|
| 931 |
-
return {
|
| 932 |
-
"kappa": detached.mean(),
|
| 933 |
-
"kappa_min": detached.min(),
|
| 934 |
-
"kappa_max": detached.max(),
|
| 935 |
-
}
|
| 936 |
-
|
| 937 |
-
def _entail_weight_scale(self, step: int | None, device: torch.device) -> torch.Tensor:
|
| 938 |
-
if self.uncha_entailment_warmup_steps <= 0 or step is None:
|
| 939 |
-
return torch.ones((), device=device)
|
| 940 |
-
scale = min(1.0, float(step + 1) / float(self.uncha_entailment_warmup_steps))
|
| 941 |
-
return torch.tensor(scale, device=device)
|
| 942 |
-
|
| 943 |
-
|
| 944 |
-
def _projection_head(input_dim: int, output_dim: int, hidden_dim: int | None) -> nn.Module:
|
| 945 |
-
if hidden_dim is None:
|
| 946 |
-
return nn.Linear(input_dim, output_dim)
|
| 947 |
-
return nn.Sequential(
|
| 948 |
-
nn.Linear(input_dim, hidden_dim),
|
| 949 |
-
nn.ReLU(),
|
| 950 |
-
nn.Linear(hidden_dim, output_dim),
|
| 951 |
-
)
|
| 952 |
-
|
| 953 |
-
|
| 954 |
-
def _pad_sequence_dim(tensor: torch.Tensor, target_length: int, value: int) -> torch.Tensor:
|
| 955 |
-
pad = target_length - tensor.shape[1]
|
| 956 |
-
if pad <= 0:
|
| 957 |
-
return tensor
|
| 958 |
-
return F.pad(tensor, (0, pad), value=value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/models/lorentz.py
DELETED
|
@@ -1,265 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import math
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
from torch import Tensor
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
def lorentz_inner(x: Tensor, y: Tensor) -> Tensor:
|
| 10 |
-
"""Compute batched Lorentzian inner product for matching rows."""
|
| 11 |
-
x = x.float()
|
| 12 |
-
y = y.float()
|
| 13 |
-
return -x[..., 0] * y[..., 0] + (x[..., 1:] * y[..., 1:]).sum(dim=-1)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def pairwise_lorentz_inner(x: Tensor, y: Tensor) -> Tensor:
|
| 17 |
-
"""Compute all-pairs Lorentzian inner products."""
|
| 18 |
-
x = x.float()
|
| 19 |
-
y = y.float()
|
| 20 |
-
time = -x[:, :1] @ y[:, :1].T
|
| 21 |
-
space = x[:, 1:] @ y[:, 1:].T
|
| 22 |
-
return time + space
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def exp_map0(u: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
|
| 26 |
-
"""Exponential map at the origin from tangent space to hyperboloid."""
|
| 27 |
-
u = u.float()
|
| 28 |
-
kappa = kappa.float()
|
| 29 |
-
sqrt_k = torch.sqrt(kappa)
|
| 30 |
-
norm_u = torch.linalg.norm(u, dim=-1, keepdim=True).clamp_min(eps)
|
| 31 |
-
scaled = sqrt_k * norm_u
|
| 32 |
-
clipped_scaled = scaled.clamp_max(math.asinh(2**15))
|
| 33 |
-
time = torch.cosh(clipped_scaled) / sqrt_k
|
| 34 |
-
space = torch.sinh(clipped_scaled) * u / scaled.clamp_min(eps)
|
| 35 |
-
return torch.cat([time, space], dim=-1)
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
def log_map0(x: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
|
| 39 |
-
"""Logarithmic map at the origin from hyperboloid to tangent space.
|
| 40 |
-
|
| 41 |
-
Inverts ``exp_map0`` for points on the Lorentz model hyperboloid. Returns
|
| 42 |
-
vectors in the Euclidean tangent space at the origin (no time coordinate).
|
| 43 |
-
"""
|
| 44 |
-
x = x.float()
|
| 45 |
-
dist_eps = max(eps, 16.0 * torch.finfo(x.dtype).eps)
|
| 46 |
-
kappa = kappa.to(dtype=torch.float32).flatten()
|
| 47 |
-
|
| 48 |
-
if x.dim() == 2:
|
| 49 |
-
if kappa.numel() != 1:
|
| 50 |
-
raise ValueError("log_map0 expects scalar kappa for non-product embeddings")
|
| 51 |
-
sqrt_k = torch.sqrt(kappa.reshape(()))
|
| 52 |
-
alpha = torch.acosh((sqrt_k * x[:, 0]).clamp_min(1.0 + dist_eps))
|
| 53 |
-
coef = alpha / torch.sinh(alpha).clamp_min(dist_eps)
|
| 54 |
-
return x[:, 1:] * coef.unsqueeze(-1)
|
| 55 |
-
|
| 56 |
-
if x.dim() == 3:
|
| 57 |
-
if kappa.numel() == 1:
|
| 58 |
-
kappa = kappa.expand(x.shape[1])
|
| 59 |
-
if kappa.numel() != x.shape[1]:
|
| 60 |
-
raise ValueError(f"Expected {x.shape[1]} curvatures for product space, got {kappa.numel()}")
|
| 61 |
-
sqrt_k = torch.sqrt(kappa).view(1, -1)
|
| 62 |
-
alpha = torch.acosh((sqrt_k * x[..., 0]).clamp_min(1.0 + dist_eps))
|
| 63 |
-
coef = alpha / torch.sinh(alpha).clamp_min(dist_eps)
|
| 64 |
-
return x[..., 1:] * coef.unsqueeze(-1)
|
| 65 |
-
|
| 66 |
-
raise ValueError("log_map0 expects [batch, dim + 1] or [batch, factors, dim + 1] tensors")
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
def pairwise_dist(x: Tensor, y: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
|
| 70 |
-
"""Pairwise geodesic distance on the Lorentz model."""
|
| 71 |
-
kappa = kappa.float()
|
| 72 |
-
dist_eps = max(eps, 16.0 * torch.finfo(x.dtype).eps)
|
| 73 |
-
prod = (-kappa) * pairwise_lorentz_inner(x, y)
|
| 74 |
-
prod = prod.clamp_min(1.0 + dist_eps)
|
| 75 |
-
return torch.acosh(prod) / torch.sqrt(kappa)
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
def product_pairwise_dist(
|
| 79 |
-
x: Tensor,
|
| 80 |
-
y: Tensor,
|
| 81 |
-
kappa: Tensor,
|
| 82 |
-
metric: str = "l1",
|
| 83 |
-
eps: float = 1e-8,
|
| 84 |
-
) -> Tensor:
|
| 85 |
-
"""Pairwise distance in an l1/l2 product of Lorentz factors.
|
| 86 |
-
|
| 87 |
-
Inputs have shape ``[batch, factors, dim + 1]``. For ``metric="l1"``, this
|
| 88 |
-
matches the official PHyCLIP implementation's mean distance over factors.
|
| 89 |
-
"""
|
| 90 |
-
if x.dim() != 3 or y.dim() != 3:
|
| 91 |
-
raise ValueError("product_pairwise_dist expects [batch, factors, dim + 1] tensors")
|
| 92 |
-
if x.shape[1] != y.shape[1] or x.shape[2] != y.shape[2]:
|
| 93 |
-
raise ValueError("Product Lorentz tensors must have matching factor and feature dimensions")
|
| 94 |
-
kappa = _product_kappa(kappa, x.shape[1], x.device).to(dtype=torch.float32)
|
| 95 |
-
dist_eps = max(eps, 16.0 * torch.finfo(x.dtype).eps)
|
| 96 |
-
x = x.float()
|
| 97 |
-
y = y.float()
|
| 98 |
-
inner = -x[:, None, :, 0] * y[None, :, :, 0] + torch.einsum("bkd,nkd->bnk", x[..., 1:], y[..., 1:])
|
| 99 |
-
prod = (-kappa.view(1, 1, -1)) * inner
|
| 100 |
-
dist = torch.acosh(prod.clamp_min(1.0 + dist_eps)) / torch.sqrt(kappa).view(1, 1, -1)
|
| 101 |
-
if metric == "l1":
|
| 102 |
-
return dist.mean(dim=-1)
|
| 103 |
-
if metric == "l2":
|
| 104 |
-
return dist.square().mean(dim=-1).sqrt()
|
| 105 |
-
raise ValueError(f"Unsupported product metric {metric!r}; expected 'l1' or 'l2'")
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
def metric_pairwise_dist(x: Tensor, y: Tensor, kappa: Tensor, product_metric: str = "l1") -> Tensor:
|
| 109 |
-
"""Pairwise distance for either a single Lorentz space or a product space."""
|
| 110 |
-
if x.dim() == 3 or y.dim() == 3:
|
| 111 |
-
return product_pairwise_dist(x, y, kappa, metric=product_metric)
|
| 112 |
-
return pairwise_dist(x, y, kappa)
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
def paired_dist(x: Tensor, y: Tensor, kappa: Tensor, product_metric: str = "l1", eps: float = 1e-8) -> Tensor:
|
| 116 |
-
"""Row-wise distance for either a single Lorentz space or a product space."""
|
| 117 |
-
if x.dim() == 3 or y.dim() == 3:
|
| 118 |
-
if x.shape != y.shape:
|
| 119 |
-
raise ValueError("Product paired_dist expects matching tensor shapes")
|
| 120 |
-
kappa = _product_kappa(kappa, x.shape[1], x.device).to(dtype=torch.float32)
|
| 121 |
-
dist_eps = max(eps, 16.0 * torch.finfo(x.dtype).eps)
|
| 122 |
-
x = x.float()
|
| 123 |
-
y = y.float()
|
| 124 |
-
inner = -x[..., 0] * y[..., 0] + (x[..., 1:] * y[..., 1:]).sum(dim=-1)
|
| 125 |
-
prod = (-kappa.view(1, -1)) * inner
|
| 126 |
-
dist = torch.acosh(prod.clamp_min(1.0 + dist_eps)) / torch.sqrt(kappa).view(1, -1)
|
| 127 |
-
if product_metric == "l1":
|
| 128 |
-
return dist.mean(dim=-1)
|
| 129 |
-
if product_metric == "l2":
|
| 130 |
-
return dist.square().mean(dim=-1).sqrt()
|
| 131 |
-
raise ValueError(f"Unsupported product metric {product_metric!r}; expected 'l1' or 'l2'")
|
| 132 |
-
kappa = kappa.float()
|
| 133 |
-
dist_eps = max(eps, 16.0 * torch.finfo(x.dtype).eps)
|
| 134 |
-
prod = (-kappa) * lorentz_inner(x, y)
|
| 135 |
-
prod = prod.clamp_min(1.0 + dist_eps)
|
| 136 |
-
return torch.acosh(prod) / torch.sqrt(kappa)
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
def radial_distance(x: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
|
| 140 |
-
"""Geodesic distance from the origin.
|
| 141 |
-
|
| 142 |
-
For points on the hyperboloid, the time coordinate satisfies
|
| 143 |
-
``x0 = cosh(sqrt(kappa) * r) / sqrt(kappa)``, so we can recover the radial
|
| 144 |
-
distance via ``r = arcosh(sqrt(kappa) * x0) / sqrt(kappa)``.
|
| 145 |
-
"""
|
| 146 |
-
dist_eps = max(eps, 16.0 * torch.finfo(x.dtype).eps)
|
| 147 |
-
x = x.float()
|
| 148 |
-
kappa = kappa.to(dtype=torch.float32).flatten()
|
| 149 |
-
if x.dim() == 2:
|
| 150 |
-
if kappa.numel() != 1:
|
| 151 |
-
raise ValueError("radial_distance expects scalar kappa for non-product embeddings")
|
| 152 |
-
sqrt_k = torch.sqrt(kappa.reshape(()))
|
| 153 |
-
arg = (sqrt_k * x[:, 0]).clamp_min(1.0 + dist_eps)
|
| 154 |
-
return torch.acosh(arg) / sqrt_k
|
| 155 |
-
if x.dim() == 3:
|
| 156 |
-
if kappa.numel() == 1:
|
| 157 |
-
kappa = kappa.expand(x.shape[1])
|
| 158 |
-
if kappa.numel() != x.shape[1]:
|
| 159 |
-
raise ValueError(f"Expected {x.shape[1]} curvatures for product space, got {kappa.numel()}")
|
| 160 |
-
sqrt_k = torch.sqrt(kappa).view(1, -1)
|
| 161 |
-
arg = (sqrt_k * x[..., 0]).clamp_min(1.0 + dist_eps)
|
| 162 |
-
dist = torch.acosh(arg) / sqrt_k
|
| 163 |
-
return dist.mean(dim=-1)
|
| 164 |
-
raise ValueError("radial_distance expects [batch, dim + 1] or [batch, factors, dim + 1] tensors")
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
def metric_similarity(x: Tensor, y: Tensor, kappa: Tensor, product_metric: str = "l1") -> Tensor:
|
| 168 |
-
"""Retrieval/classification similarity for single-space and PHyCLIP-style models."""
|
| 169 |
-
if x.dim() == 3 or y.dim() == 3:
|
| 170 |
-
return -product_pairwise_dist(x, y, kappa, metric=product_metric)
|
| 171 |
-
return pairwise_lorentz_inner(x, y)
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
def half_aperture(general: Tensor, kappa: Tensor, min_radius: float = 0.1, eps: float = 1e-8) -> Tensor:
|
| 175 |
-
"""Cone half-aperture for entailment cone centered at general concept."""
|
| 176 |
-
general = general.float()
|
| 177 |
-
kappa = kappa.float()
|
| 178 |
-
aperture_eps = max(eps, 16.0 * torch.finfo(general.dtype).eps)
|
| 179 |
-
general_norm = torch.linalg.norm(general[:, 1:], dim=-1)
|
| 180 |
-
ratio = (2.0 * min_radius) / (general_norm * torch.sqrt(kappa) + aperture_eps)
|
| 181 |
-
ratio = ratio.clamp(max=1.0 - aperture_eps)
|
| 182 |
-
return torch.asin(ratio)
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
def oxy_angle(specific: Tensor, general: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
|
| 186 |
-
"""Exterior angle between specific point and entailment cone at general point."""
|
| 187 |
-
specific = specific.float()
|
| 188 |
-
general = general.float()
|
| 189 |
-
kappa = kappa.float()
|
| 190 |
-
angle_eps = max(eps, 16.0 * torch.finfo(specific.dtype).eps)
|
| 191 |
-
inner = lorentz_inner(specific, general)
|
| 192 |
-
numerator = specific[:, 0] + kappa * inner * general[:, 0]
|
| 193 |
-
general_norm = torch.linalg.norm(general[:, 1:], dim=-1).clamp_min(angle_eps)
|
| 194 |
-
denom_term = (kappa * inner).pow(2) - 1.0
|
| 195 |
-
denom = general_norm * torch.sqrt(denom_term.clamp_min(angle_eps))
|
| 196 |
-
cosine = (numerator / denom).clamp(min=-1.0 + angle_eps, max=1.0 - angle_eps)
|
| 197 |
-
return torch.acos(cosine)
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
def pairwise_oxy_angle(specific: Tensor, general: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
|
| 201 |
-
"""All-pairs exterior angle between specific points and entailment cones at general points."""
|
| 202 |
-
specific = specific.float()
|
| 203 |
-
general = general.float()
|
| 204 |
-
kappa = kappa.to(dtype=torch.float32).flatten()
|
| 205 |
-
if kappa.numel() != 1:
|
| 206 |
-
raise ValueError("pairwise_oxy_angle expects scalar kappa for non-product embeddings")
|
| 207 |
-
kappa_scalar = kappa.reshape(())
|
| 208 |
-
angle_eps = max(eps, 16.0 * torch.finfo(specific.dtype).eps)
|
| 209 |
-
inner = -specific[:, None, 0] * general[None, :, 0] + torch.einsum("nd,md->nm", specific[:, 1:], general[:, 1:])
|
| 210 |
-
numerator = specific[:, None, 0] + kappa_scalar * inner * general[None, :, 0]
|
| 211 |
-
general_norm = torch.linalg.norm(general[:, 1:], dim=-1).clamp_min(angle_eps)
|
| 212 |
-
denom_term = (kappa_scalar * inner).pow(2) - 1.0
|
| 213 |
-
denom = general_norm[None, :] * torch.sqrt(denom_term.clamp_min(angle_eps))
|
| 214 |
-
cosine = (numerator / denom).clamp(min=-1.0 + angle_eps, max=1.0 - angle_eps)
|
| 215 |
-
return torch.acos(cosine)
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
def product_pairwise_oxy_angle(
|
| 219 |
-
specific: Tensor,
|
| 220 |
-
general: Tensor,
|
| 221 |
-
kappa: Tensor,
|
| 222 |
-
metric: str = "l1",
|
| 223 |
-
eps: float = 1e-8,
|
| 224 |
-
) -> Tensor:
|
| 225 |
-
"""All-pairs exterior angle in an l1/l2 product of Lorentz factors."""
|
| 226 |
-
if specific.dim() != 3 or general.dim() != 3:
|
| 227 |
-
raise ValueError("product_pairwise_oxy_angle expects [batch, factors, dim + 1] tensors")
|
| 228 |
-
if specific.shape[1] != general.shape[1] or specific.shape[2] != general.shape[2]:
|
| 229 |
-
raise ValueError("Product Lorentz tensors must have matching factor and feature dimensions")
|
| 230 |
-
kappa = _product_kappa(kappa, specific.shape[1], specific.device).to(dtype=torch.float32)
|
| 231 |
-
angle_eps = max(eps, 16.0 * torch.finfo(specific.dtype).eps)
|
| 232 |
-
specific = specific.float()
|
| 233 |
-
general = general.float()
|
| 234 |
-
inner = -specific[:, None, :, 0] * general[None, :, :, 0] + torch.einsum(
|
| 235 |
-
"nkd,mkd->nmk",
|
| 236 |
-
specific[..., 1:],
|
| 237 |
-
general[..., 1:],
|
| 238 |
-
)
|
| 239 |
-
numerator = specific[:, None, :, 0] + (kappa.view(1, 1, -1) * inner) * general[None, :, :, 0]
|
| 240 |
-
general_norm = torch.linalg.norm(general[..., 1:], dim=-1).clamp_min(angle_eps)
|
| 241 |
-
denom_term = (kappa.view(1, 1, -1) * inner).pow(2) - 1.0
|
| 242 |
-
denom = general_norm[None, :, :] * torch.sqrt(denom_term.clamp_min(angle_eps))
|
| 243 |
-
cosine = (numerator / denom).clamp(min=-1.0 + angle_eps, max=1.0 - angle_eps)
|
| 244 |
-
angles = torch.acos(cosine)
|
| 245 |
-
if metric == "l1":
|
| 246 |
-
return angles.mean(dim=-1)
|
| 247 |
-
if metric == "l2":
|
| 248 |
-
return angles.square().mean(dim=-1).sqrt()
|
| 249 |
-
raise ValueError(f"Unsupported product metric {metric!r}; expected 'l1' or 'l2'")
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
def metric_pairwise_oxy_angle(specific: Tensor, general: Tensor, kappa: Tensor, product_metric: str = "l1") -> Tensor:
|
| 253 |
-
"""All-pairs oxy-angle for either a single Lorentz space or a product space."""
|
| 254 |
-
if specific.dim() == 3 or general.dim() == 3:
|
| 255 |
-
return product_pairwise_oxy_angle(specific, general, kappa, metric=product_metric)
|
| 256 |
-
return pairwise_oxy_angle(specific, general, kappa)
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
def _product_kappa(kappa: Tensor, num_factors: int, device: torch.device) -> Tensor:
|
| 260 |
-
kappa = kappa.to(device=device, dtype=torch.float32).flatten()
|
| 261 |
-
if kappa.numel() == 1:
|
| 262 |
-
return kappa.expand(num_factors)
|
| 263 |
-
if kappa.numel() != num_factors:
|
| 264 |
-
raise ValueError(f"Expected {num_factors} curvatures for product space, got {kappa.numel()}")
|
| 265 |
-
return kappa
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/models/losses.py
DELETED
|
@@ -1,1400 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import math
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
from torch import Tensor
|
| 7 |
-
import torch.nn.functional as F
|
| 8 |
-
|
| 9 |
-
from hyper3_clip.models.lorentz import (
|
| 10 |
-
half_aperture,
|
| 11 |
-
metric_pairwise_dist,
|
| 12 |
-
metric_pairwise_oxy_angle,
|
| 13 |
-
oxy_angle,
|
| 14 |
-
paired_dist,
|
| 15 |
-
radial_distance,
|
| 16 |
-
)
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def contrastive_ce(logits: Tensor, targets: Tensor | None = None, weights: Tensor | None = None) -> Tensor:
|
| 20 |
-
if targets is None:
|
| 21 |
-
targets = torch.arange(logits.size(0), device=logits.device)
|
| 22 |
-
losses = F.cross_entropy(logits, targets, reduction="none")
|
| 23 |
-
return weighted_mean(losses, weights)
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def contrastive_sigmoid(
|
| 27 |
-
logits: Tensor,
|
| 28 |
-
targets: Tensor | None = None,
|
| 29 |
-
weights: Tensor | None = None,
|
| 30 |
-
negative_weight: float = 1.0,
|
| 31 |
-
) -> Tensor:
|
| 32 |
-
if targets is None:
|
| 33 |
-
targets = torch.arange(logits.size(0), device=logits.device)
|
| 34 |
-
labels = torch.zeros_like(logits)
|
| 35 |
-
labels[torch.arange(logits.size(0), device=logits.device), targets] = 1.0
|
| 36 |
-
losses = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
|
| 37 |
-
if negative_weight != 1.0:
|
| 38 |
-
element_weights = torch.where(labels > 0.0, torch.ones_like(labels), logits.new_full((), negative_weight))
|
| 39 |
-
losses = losses * element_weights
|
| 40 |
-
losses = losses.mean(dim=1)
|
| 41 |
-
return weighted_mean(losses, weights)
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def contrastive_siglip(
|
| 45 |
-
logits: Tensor,
|
| 46 |
-
targets: Tensor | None = None,
|
| 47 |
-
weights: Tensor | None = None,
|
| 48 |
-
negative_weight: float = 1.0,
|
| 49 |
-
) -> Tensor:
|
| 50 |
-
"""SigLIP pairwise sigmoid loss (Zhai et al., ICCV 2023).
|
| 51 |
-
|
| 52 |
-
Uses labels in {+1, -1} with a per-row sum (not mean) over pairs:
|
| 53 |
-
L_i = sum_j softplus(- y_ij * logit_ij)
|
| 54 |
-
"""
|
| 55 |
-
if logits.ndim != 2:
|
| 56 |
-
raise ValueError("contrastive_siglip expects a [batch, classes] logit matrix")
|
| 57 |
-
if targets is None:
|
| 58 |
-
targets = torch.arange(logits.size(0), device=logits.device)
|
| 59 |
-
labels = logits.new_full(logits.shape, -1.0)
|
| 60 |
-
labels[torch.arange(logits.size(0), device=logits.device), targets] = 1.0
|
| 61 |
-
losses = F.softplus(-(labels * logits))
|
| 62 |
-
if negative_weight != 1.0:
|
| 63 |
-
element_weights = torch.where(labels > 0.0, torch.ones_like(labels), logits.new_full((), negative_weight))
|
| 64 |
-
losses = losses * element_weights
|
| 65 |
-
row_losses = losses.sum(dim=1)
|
| 66 |
-
return weighted_mean(row_losses, weights)
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
def weighted_mean(values: Tensor, weights: Tensor | None = None) -> Tensor:
|
| 70 |
-
if weights is None:
|
| 71 |
-
return values.mean()
|
| 72 |
-
weights = weights.to(device=values.device, dtype=values.dtype)
|
| 73 |
-
while weights.dim() < values.dim():
|
| 74 |
-
weights = weights.unsqueeze(-1)
|
| 75 |
-
return (values * weights).sum() / weights.sum().clamp_min(torch.finfo(values.dtype).eps)
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
def gramian_volume_loss(vectors: Tensor, weights: Tensor | None = None, eps: float = 1e-4) -> Tensor:
|
| 79 |
-
"""GRAM-style volume loss for sets of vectors.
|
| 80 |
-
|
| 81 |
-
``vectors`` is expected to have shape ``[batch, k, dim]``. Each set of k
|
| 82 |
-
vectors is L2-normalized along ``dim``, then we compute the Gramian
|
| 83 |
-
``G = V V^T`` and return ``sqrt(det(G + eps I))`` averaged over the batch.
|
| 84 |
-
"""
|
| 85 |
-
if vectors.ndim != 3:
|
| 86 |
-
raise ValueError("gramian_volume_loss expects a [batch, k, dim] tensor")
|
| 87 |
-
if eps <= 0.0:
|
| 88 |
-
raise ValueError("gramian_volume_loss eps must be positive")
|
| 89 |
-
|
| 90 |
-
vectors = F.normalize(vectors.float(), dim=-1, eps=1e-8)
|
| 91 |
-
gram = vectors @ vectors.transpose(-1, -2)
|
| 92 |
-
k = gram.size(-1)
|
| 93 |
-
gram = gram + eps * torch.eye(k, device=gram.device, dtype=gram.dtype)
|
| 94 |
-
sign, logabsdet = torch.linalg.slogdet(gram)
|
| 95 |
-
volume = torch.exp(0.5 * logabsdet)
|
| 96 |
-
volume = torch.where(sign > 0, volume, volume.new_ones(volume.shape))
|
| 97 |
-
return weighted_mean(volume, weights)
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
def radius_order_hinge(
|
| 101 |
-
specific: Tensor,
|
| 102 |
-
general: Tensor,
|
| 103 |
-
kappa: Tensor,
|
| 104 |
-
margin: float,
|
| 105 |
-
weights: Tensor | None = None,
|
| 106 |
-
) -> Tensor:
|
| 107 |
-
if specific.shape[0] != general.shape[0]:
|
| 108 |
-
raise ValueError("radius_order_hinge expects matching batch dimensions")
|
| 109 |
-
if margin < 0.0:
|
| 110 |
-
raise ValueError("radius_order_hinge margin must be non-negative")
|
| 111 |
-
specific_radius = radial_distance(specific, kappa)
|
| 112 |
-
general_radius = radial_distance(general, kappa)
|
| 113 |
-
losses = F.relu(float(margin) + general_radius - specific_radius)
|
| 114 |
-
return weighted_mean(losses, weights)
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
def soft_contrastive_ce(logits: Tensor, target_weights: Tensor, weights: Tensor | None = None) -> Tensor:
|
| 118 |
-
if logits.ndim != 2 or target_weights.ndim != 2:
|
| 119 |
-
raise ValueError("soft_contrastive_ce expects [batch, classes] tensors")
|
| 120 |
-
if logits.shape != target_weights.shape:
|
| 121 |
-
raise ValueError("soft_contrastive_ce requires logits and target_weights to have matching shapes")
|
| 122 |
-
log_probs = F.log_softmax(logits, dim=1)
|
| 123 |
-
losses = -(target_weights.to(dtype=log_probs.dtype) * log_probs).sum(dim=1)
|
| 124 |
-
return weighted_mean(losses, weights)
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
def beta_cal_loss(
|
| 128 |
-
logits: Tensor,
|
| 129 |
-
*,
|
| 130 |
-
targets: Tensor,
|
| 131 |
-
group_ids: Tensor,
|
| 132 |
-
all_group_ids: Tensor,
|
| 133 |
-
beta: float,
|
| 134 |
-
variant: str,
|
| 135 |
-
weights: Tensor | None = None,
|
| 136 |
-
) -> Tensor:
|
| 137 |
-
if beta < 0.0:
|
| 138 |
-
raise ValueError("beta_cal_loss beta must be non-negative")
|
| 139 |
-
if variant not in {"ce", "bce"}:
|
| 140 |
-
raise ValueError("beta_cal_loss variant must be 'ce' or 'bce'")
|
| 141 |
-
if logits.ndim != 2:
|
| 142 |
-
raise ValueError("beta_cal_loss expects a [batch, classes] logit matrix")
|
| 143 |
-
if targets.shape != (logits.size(0),):
|
| 144 |
-
raise ValueError("beta_cal_loss targets must have shape [batch]")
|
| 145 |
-
if group_ids.shape != (logits.size(0),):
|
| 146 |
-
raise ValueError("beta_cal_loss group_ids must have shape [batch]")
|
| 147 |
-
if all_group_ids.shape != (logits.size(1),):
|
| 148 |
-
raise ValueError("beta_cal_loss all_group_ids must have shape [classes]")
|
| 149 |
-
|
| 150 |
-
same_group = group_ids[:, None] == all_group_ids[None, :]
|
| 151 |
-
same_pair = targets[:, None] == torch.arange(logits.size(1), device=logits.device)[None, :]
|
| 152 |
-
|
| 153 |
-
if variant == "ce":
|
| 154 |
-
target_weights = logits.new_zeros(logits.shape)
|
| 155 |
-
target_weights = torch.where(same_pair, logits.new_ones(()), target_weights)
|
| 156 |
-
target_weights = torch.where(same_group & ~same_pair, logits.new_full((), float(beta)), target_weights)
|
| 157 |
-
target_weights = target_weights / target_weights.sum(dim=1, keepdim=True).clamp_min(
|
| 158 |
-
torch.finfo(target_weights.dtype).eps
|
| 159 |
-
)
|
| 160 |
-
return soft_contrastive_ce(logits, target_weights, weights)
|
| 161 |
-
|
| 162 |
-
labels = same_group.to(dtype=logits.dtype)
|
| 163 |
-
element_weights = logits.new_ones(logits.shape)
|
| 164 |
-
element_weights = torch.where(same_group & ~same_pair, logits.new_full((), float(beta)), element_weights)
|
| 165 |
-
element_losses = F.binary_cross_entropy_with_logits(logits, labels, reduction="none") * element_weights
|
| 166 |
-
row_losses = element_losses.mean(dim=1)
|
| 167 |
-
return weighted_mean(row_losses, weights)
|
| 168 |
-
|
| 169 |
-
def compositional_contrastive_loss(
|
| 170 |
-
image_feats: Tensor,
|
| 171 |
-
text_feats: Tensor,
|
| 172 |
-
box_image_feats: Tensor,
|
| 173 |
-
box_text_feats: Tensor,
|
| 174 |
-
kappa: Tensor,
|
| 175 |
-
logit_scale: Tensor,
|
| 176 |
-
all_image_feats: Tensor | None = None,
|
| 177 |
-
all_text_feats: Tensor | None = None,
|
| 178 |
-
targets: Tensor | None = None,
|
| 179 |
-
) -> Tensor:
|
| 180 |
-
scale = logit_scale.exp().clamp(max=100.0)
|
| 181 |
-
all_image_feats = image_feats if all_image_feats is None else all_image_feats
|
| 182 |
-
all_text_feats = text_feats if all_text_feats is None else all_text_feats
|
| 183 |
-
|
| 184 |
-
logits_i_t = -metric_pairwise_dist(image_feats, all_text_feats, kappa) * scale
|
| 185 |
-
logits_t_i = -metric_pairwise_dist(text_feats, all_image_feats, kappa) * scale
|
| 186 |
-
logits_bi_t = -metric_pairwise_dist(box_image_feats, all_text_feats, kappa) * scale
|
| 187 |
-
logits_bt_i = -metric_pairwise_dist(box_text_feats, all_image_feats, kappa) * scale
|
| 188 |
-
|
| 189 |
-
return 0.25 * (
|
| 190 |
-
contrastive_ce(logits_i_t, targets)
|
| 191 |
-
+ contrastive_ce(logits_t_i, targets)
|
| 192 |
-
+ contrastive_ce(logits_bi_t, targets)
|
| 193 |
-
+ contrastive_ce(logits_bt_i, targets)
|
| 194 |
-
)
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
def multi_part_contrastive_loss(
|
| 198 |
-
image_feats: Tensor,
|
| 199 |
-
text_feats: Tensor,
|
| 200 |
-
part_image_feats: Tensor,
|
| 201 |
-
part_text_feats: Tensor,
|
| 202 |
-
part_mask: Tensor,
|
| 203 |
-
kappa: Tensor,
|
| 204 |
-
logit_scale: Tensor,
|
| 205 |
-
all_image_feats: Tensor | None = None,
|
| 206 |
-
all_text_feats: Tensor | None = None,
|
| 207 |
-
targets: Tensor | None = None,
|
| 208 |
-
) -> Tensor:
|
| 209 |
-
scale = logit_scale.exp().clamp(max=100.0)
|
| 210 |
-
all_image_feats = image_feats if all_image_feats is None else all_image_feats
|
| 211 |
-
all_text_feats = text_feats if all_text_feats is None else all_text_feats
|
| 212 |
-
if targets is None:
|
| 213 |
-
targets = torch.arange(image_feats.size(0), device=image_feats.device)
|
| 214 |
-
|
| 215 |
-
part_image_flat, part_text_flat, part_targets = _flatten_valid_parts(part_image_feats, part_text_feats, part_mask, targets)
|
| 216 |
-
|
| 217 |
-
logits_i_t = -metric_pairwise_dist(image_feats, all_text_feats, kappa) * scale
|
| 218 |
-
logits_t_i = -metric_pairwise_dist(text_feats, all_image_feats, kappa) * scale
|
| 219 |
-
logits_pi_t = -metric_pairwise_dist(part_image_flat, all_text_feats, kappa) * scale
|
| 220 |
-
logits_pt_i = -metric_pairwise_dist(part_text_flat, all_image_feats, kappa) * scale
|
| 221 |
-
|
| 222 |
-
return 0.25 * (
|
| 223 |
-
contrastive_ce(logits_i_t, targets)
|
| 224 |
-
+ contrastive_ce(logits_t_i, targets)
|
| 225 |
-
+ contrastive_ce(logits_pi_t, part_targets)
|
| 226 |
-
+ contrastive_ce(logits_pt_i, part_targets)
|
| 227 |
-
)
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
def packed_part_contrastive_loss(
|
| 231 |
-
image_feats: Tensor,
|
| 232 |
-
text_feats: Tensor,
|
| 233 |
-
part_image_feats: Tensor,
|
| 234 |
-
part_text_feats: Tensor,
|
| 235 |
-
part_owner: Tensor,
|
| 236 |
-
kappa: Tensor,
|
| 237 |
-
logit_scale: Tensor,
|
| 238 |
-
all_image_feats: Tensor | None = None,
|
| 239 |
-
all_text_feats: Tensor | None = None,
|
| 240 |
-
targets: Tensor | None = None,
|
| 241 |
-
) -> Tensor:
|
| 242 |
-
scale = logit_scale.exp().clamp(max=100.0)
|
| 243 |
-
all_image_feats = image_feats if all_image_feats is None else all_image_feats
|
| 244 |
-
all_text_feats = text_feats if all_text_feats is None else all_text_feats
|
| 245 |
-
if targets is None:
|
| 246 |
-
targets = torch.arange(image_feats.size(0), device=image_feats.device)
|
| 247 |
-
|
| 248 |
-
logits_i_t = -metric_pairwise_dist(image_feats, all_text_feats, kappa) * scale
|
| 249 |
-
logits_t_i = -metric_pairwise_dist(text_feats, all_image_feats, kappa) * scale
|
| 250 |
-
global_loss = 0.5 * (contrastive_ce(logits_i_t, targets) + contrastive_ce(logits_t_i, targets))
|
| 251 |
-
|
| 252 |
-
if part_image_feats.numel() == 0:
|
| 253 |
-
return global_loss
|
| 254 |
-
|
| 255 |
-
part_targets = targets[part_owner]
|
| 256 |
-
logits_pi_t = -metric_pairwise_dist(part_image_feats, all_text_feats, kappa) * scale
|
| 257 |
-
logits_pt_i = -metric_pairwise_dist(part_text_feats, all_image_feats, kappa) * scale
|
| 258 |
-
part_loss = 0.5 * (contrastive_ce(logits_pi_t, part_targets) + contrastive_ce(logits_pt_i, part_targets))
|
| 259 |
-
return 0.5 * (global_loss + part_loss)
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
def factor_oxy_angle(specific: Tensor, general: Tensor, kappa: Tensor) -> Tensor:
|
| 263 |
-
if specific.dim() != 3:
|
| 264 |
-
return oxy_angle(specific=specific, general=general, kappa=kappa)
|
| 265 |
-
batch_size, num_factors, feature_dim = specific.shape
|
| 266 |
-
kappa = _factor_kappa(kappa, num_factors, specific.device)
|
| 267 |
-
factor_kappa = kappa.view(1, num_factors).expand(batch_size, num_factors).reshape(-1)
|
| 268 |
-
return oxy_angle(
|
| 269 |
-
specific=specific.reshape(batch_size * num_factors, feature_dim),
|
| 270 |
-
general=general.reshape(batch_size * num_factors, feature_dim),
|
| 271 |
-
kappa=factor_kappa,
|
| 272 |
-
).reshape(batch_size, num_factors)
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
def factor_half_aperture(general: Tensor, kappa: Tensor) -> Tensor:
|
| 276 |
-
if general.dim() != 3:
|
| 277 |
-
return half_aperture(general=general, kappa=kappa)
|
| 278 |
-
batch_size, num_factors, feature_dim = general.shape
|
| 279 |
-
kappa = _factor_kappa(kappa, num_factors, general.device)
|
| 280 |
-
factor_kappa = kappa.view(1, num_factors).expand(batch_size, num_factors).reshape(-1)
|
| 281 |
-
return half_aperture(
|
| 282 |
-
general=general.reshape(batch_size * num_factors, feature_dim),
|
| 283 |
-
kappa=factor_kappa,
|
| 284 |
-
).reshape(batch_size, num_factors)
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
def _factor_kappa(kappa: Tensor, num_factors: int, device: torch.device) -> Tensor:
|
| 288 |
-
kappa = kappa.to(device=device, dtype=torch.float32).flatten()
|
| 289 |
-
if kappa.numel() == 1:
|
| 290 |
-
return kappa.expand(num_factors)
|
| 291 |
-
if kappa.numel() != num_factors:
|
| 292 |
-
raise ValueError(f"Expected {num_factors} curvatures for product space, got {kappa.numel()}")
|
| 293 |
-
return kappa
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
def entailment_residual(
|
| 297 |
-
specific: Tensor,
|
| 298 |
-
general: Tensor,
|
| 299 |
-
kappa: Tensor,
|
| 300 |
-
aperture_scale: float,
|
| 301 |
-
) -> Tensor:
|
| 302 |
-
angles = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
|
| 303 |
-
apertures = factor_half_aperture(general=general, kappa=kappa)
|
| 304 |
-
return torch.clamp(angles - (aperture_scale * apertures), min=0.0).mean()
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
def weighted_entailment_residual(
|
| 308 |
-
specific: Tensor,
|
| 309 |
-
general: Tensor,
|
| 310 |
-
kappa: Tensor,
|
| 311 |
-
aperture_scale: float,
|
| 312 |
-
weights: Tensor | None = None,
|
| 313 |
-
) -> Tensor:
|
| 314 |
-
angles = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
|
| 315 |
-
apertures = factor_half_aperture(general=general, kappa=kappa)
|
| 316 |
-
residuals = torch.clamp(angles - (aperture_scale * apertures), min=0.0)
|
| 317 |
-
if residuals.dim() == 2:
|
| 318 |
-
residuals = residuals.mean(dim=-1)
|
| 319 |
-
return weighted_mean(residuals, weights)
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
def compositional_entailment_loss(
|
| 323 |
-
image_feats: Tensor,
|
| 324 |
-
text_feats: Tensor,
|
| 325 |
-
box_image_feats: Tensor,
|
| 326 |
-
box_text_feats: Tensor,
|
| 327 |
-
kappa: Tensor,
|
| 328 |
-
inter_aperture_scale: float,
|
| 329 |
-
intra_aperture_scale: float,
|
| 330 |
-
) -> Tensor:
|
| 331 |
-
text_to_image = entailment_residual(
|
| 332 |
-
specific=image_feats,
|
| 333 |
-
general=text_feats,
|
| 334 |
-
kappa=kappa,
|
| 335 |
-
aperture_scale=inter_aperture_scale,
|
| 336 |
-
)
|
| 337 |
-
box_text_to_box_image = entailment_residual(
|
| 338 |
-
specific=box_image_feats,
|
| 339 |
-
general=box_text_feats,
|
| 340 |
-
kappa=kappa,
|
| 341 |
-
aperture_scale=inter_aperture_scale,
|
| 342 |
-
)
|
| 343 |
-
box_image_to_image = entailment_residual(
|
| 344 |
-
specific=image_feats,
|
| 345 |
-
general=box_image_feats,
|
| 346 |
-
kappa=kappa,
|
| 347 |
-
aperture_scale=intra_aperture_scale,
|
| 348 |
-
)
|
| 349 |
-
box_text_to_text = entailment_residual(
|
| 350 |
-
specific=text_feats,
|
| 351 |
-
general=box_text_feats,
|
| 352 |
-
kappa=kappa,
|
| 353 |
-
aperture_scale=intra_aperture_scale,
|
| 354 |
-
)
|
| 355 |
-
|
| 356 |
-
return 0.5 * (text_to_image + box_text_to_box_image + box_image_to_image + box_text_to_text)
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
def multi_part_entailment_loss(
|
| 360 |
-
image_feats: Tensor,
|
| 361 |
-
text_feats: Tensor,
|
| 362 |
-
part_image_feats: Tensor,
|
| 363 |
-
part_text_feats: Tensor,
|
| 364 |
-
part_mask: Tensor,
|
| 365 |
-
kappa: Tensor,
|
| 366 |
-
inter_aperture_scale: float,
|
| 367 |
-
intra_aperture_scale: float,
|
| 368 |
-
) -> Tensor:
|
| 369 |
-
part_image_flat = part_image_feats[part_mask]
|
| 370 |
-
part_text_flat = part_text_feats[part_mask]
|
| 371 |
-
image_for_parts = image_feats[:, None, :].expand_as(part_image_feats)[part_mask]
|
| 372 |
-
text_for_parts = text_feats[:, None, :].expand_as(part_text_feats)[part_mask]
|
| 373 |
-
|
| 374 |
-
text_to_image = entailment_residual(
|
| 375 |
-
specific=image_feats,
|
| 376 |
-
general=text_feats,
|
| 377 |
-
kappa=kappa,
|
| 378 |
-
aperture_scale=inter_aperture_scale,
|
| 379 |
-
)
|
| 380 |
-
part_text_to_part_image = entailment_residual(
|
| 381 |
-
specific=part_image_flat,
|
| 382 |
-
general=part_text_flat,
|
| 383 |
-
kappa=kappa,
|
| 384 |
-
aperture_scale=inter_aperture_scale,
|
| 385 |
-
)
|
| 386 |
-
part_image_to_image = entailment_residual(
|
| 387 |
-
specific=image_for_parts,
|
| 388 |
-
general=part_image_flat,
|
| 389 |
-
kappa=kappa,
|
| 390 |
-
aperture_scale=intra_aperture_scale,
|
| 391 |
-
)
|
| 392 |
-
part_text_to_text = entailment_residual(
|
| 393 |
-
specific=text_for_parts,
|
| 394 |
-
general=part_text_flat,
|
| 395 |
-
kappa=kappa,
|
| 396 |
-
aperture_scale=intra_aperture_scale,
|
| 397 |
-
)
|
| 398 |
-
|
| 399 |
-
return 0.5 * (text_to_image + part_text_to_part_image + part_image_to_image + part_text_to_text)
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
def packed_part_entailment_loss(
|
| 403 |
-
image_feats: Tensor,
|
| 404 |
-
text_feats: Tensor,
|
| 405 |
-
part_image_feats: Tensor,
|
| 406 |
-
part_text_feats: Tensor,
|
| 407 |
-
part_owner: Tensor,
|
| 408 |
-
kappa: Tensor,
|
| 409 |
-
inter_aperture_scale: float,
|
| 410 |
-
intra_aperture_scale: float,
|
| 411 |
-
) -> Tensor:
|
| 412 |
-
text_to_image = entailment_residual(
|
| 413 |
-
specific=image_feats,
|
| 414 |
-
general=text_feats,
|
| 415 |
-
kappa=kappa,
|
| 416 |
-
aperture_scale=inter_aperture_scale,
|
| 417 |
-
)
|
| 418 |
-
if part_image_feats.numel() == 0:
|
| 419 |
-
return text_to_image
|
| 420 |
-
|
| 421 |
-
image_for_parts = image_feats[part_owner]
|
| 422 |
-
text_for_parts = text_feats[part_owner]
|
| 423 |
-
part_text_to_part_image = entailment_residual(
|
| 424 |
-
specific=part_image_feats,
|
| 425 |
-
general=part_text_feats,
|
| 426 |
-
kappa=kappa,
|
| 427 |
-
aperture_scale=inter_aperture_scale,
|
| 428 |
-
)
|
| 429 |
-
part_image_to_image = entailment_residual(
|
| 430 |
-
specific=image_for_parts,
|
| 431 |
-
general=part_image_feats,
|
| 432 |
-
kappa=kappa,
|
| 433 |
-
aperture_scale=intra_aperture_scale,
|
| 434 |
-
)
|
| 435 |
-
part_text_to_text = entailment_residual(
|
| 436 |
-
specific=text_for_parts,
|
| 437 |
-
general=part_text_feats,
|
| 438 |
-
kappa=kappa,
|
| 439 |
-
aperture_scale=intra_aperture_scale,
|
| 440 |
-
)
|
| 441 |
-
|
| 442 |
-
return 0.5 * (text_to_image + part_text_to_part_image + part_image_to_image + part_text_to_text)
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
def uncha_contrastive_losses(
|
| 446 |
-
image_feats: Tensor,
|
| 447 |
-
text_feats: Tensor,
|
| 448 |
-
part_image_flat: Tensor,
|
| 449 |
-
part_text_flat: Tensor,
|
| 450 |
-
image_for_parts: Tensor,
|
| 451 |
-
text_for_parts: Tensor,
|
| 452 |
-
kappa: Tensor,
|
| 453 |
-
global_logit_scale: Tensor,
|
| 454 |
-
local_logit_scale: Tensor,
|
| 455 |
-
global_local_logit_scale: Tensor,
|
| 456 |
-
image_euc_feats: Tensor | None = None,
|
| 457 |
-
text_euc_feats: Tensor | None = None,
|
| 458 |
-
part_image_euc_flat: Tensor | None = None,
|
| 459 |
-
part_text_euc_flat: Tensor | None = None,
|
| 460 |
-
image_for_parts_euc: Tensor | None = None,
|
| 461 |
-
text_for_parts_euc: Tensor | None = None,
|
| 462 |
-
all_image_feats: Tensor | None = None,
|
| 463 |
-
all_text_feats: Tensor | None = None,
|
| 464 |
-
all_part_image_feats: Tensor | None = None,
|
| 465 |
-
all_part_text_feats: Tensor | None = None,
|
| 466 |
-
all_image_for_parts: Tensor | None = None,
|
| 467 |
-
all_text_for_parts: Tensor | None = None,
|
| 468 |
-
all_image_euc_feats: Tensor | None = None,
|
| 469 |
-
all_text_euc_feats: Tensor | None = None,
|
| 470 |
-
all_part_image_euc_feats: Tensor | None = None,
|
| 471 |
-
all_part_text_euc_feats: Tensor | None = None,
|
| 472 |
-
all_image_for_parts_euc: Tensor | None = None,
|
| 473 |
-
all_text_for_parts_euc: Tensor | None = None,
|
| 474 |
-
global_targets: Tensor | None = None,
|
| 475 |
-
part_targets: Tensor | None = None,
|
| 476 |
-
part_weights: Tensor | None = None,
|
| 477 |
-
product_metric: str = "l1",
|
| 478 |
-
loss_type: str = "ce",
|
| 479 |
-
contrastive_global_weight: float = 1.0,
|
| 480 |
-
contrastive_local_weight: float = 1.0,
|
| 481 |
-
contrastive_global_local_weight: float = 1.0,
|
| 482 |
-
beta_cal_beta: float = 0.0,
|
| 483 |
-
beta_cal_variant: str = "ce",
|
| 484 |
-
beta_cal_weight: float = 0.0,
|
| 485 |
-
part_group_ids: Tensor | None = None,
|
| 486 |
-
all_part_group_ids: Tensor | None = None,
|
| 487 |
-
global_logit_bias: Tensor | None = None,
|
| 488 |
-
local_logit_bias: Tensor | None = None,
|
| 489 |
-
global_local_logit_bias: Tensor | None = None,
|
| 490 |
-
sigmoid_negative_weight: float = 1.0,
|
| 491 |
-
global_local_mode: str = "repeat",
|
| 492 |
-
global_local_metric: str = "distance",
|
| 493 |
-
global_local_angle_aux_weight: float = 0.0,
|
| 494 |
-
global_local_angle_aux_mode: str = "contrastive",
|
| 495 |
-
global_local_angle_aux_scale: float = 5.5,
|
| 496 |
-
global_local_angle_aux_aperture_scale: float = 1.0,
|
| 497 |
-
) -> dict[str, Tensor]:
|
| 498 |
-
if loss_type not in {"ce", "sigmoid", "siglip", "siglip_metric"}:
|
| 499 |
-
raise ValueError(
|
| 500 |
-
f"Unsupported contrastive loss {loss_type!r}; expected 'ce', 'sigmoid', 'siglip', or 'siglip_metric'"
|
| 501 |
-
)
|
| 502 |
-
if global_local_mode not in {"repeat", "inbatch"}:
|
| 503 |
-
raise ValueError("global_local_mode must be 'repeat' or 'inbatch'")
|
| 504 |
-
if global_local_metric not in {"distance", "angle"}:
|
| 505 |
-
raise ValueError("global_local_metric must be 'distance' or 'angle'")
|
| 506 |
-
if global_local_angle_aux_mode not in {"contrastive", "positive_hinge"}:
|
| 507 |
-
raise ValueError("global_local_angle_aux_mode must be 'contrastive' or 'positive_hinge'")
|
| 508 |
-
if global_local_angle_aux_weight < 0.0:
|
| 509 |
-
raise ValueError("global_local_angle_aux_weight must be non-negative")
|
| 510 |
-
if global_local_angle_aux_scale <= 0.0:
|
| 511 |
-
raise ValueError("global_local_angle_aux_scale must be positive")
|
| 512 |
-
if global_local_angle_aux_aperture_scale <= 0.0:
|
| 513 |
-
raise ValueError("global_local_angle_aux_aperture_scale must be positive")
|
| 514 |
-
all_image_feats = image_feats if all_image_feats is None else all_image_feats
|
| 515 |
-
all_text_feats = text_feats if all_text_feats is None else all_text_feats
|
| 516 |
-
all_part_image_feats = part_image_flat if all_part_image_feats is None else all_part_image_feats
|
| 517 |
-
all_part_text_feats = part_text_flat if all_part_text_feats is None else all_part_text_feats
|
| 518 |
-
all_image_for_parts = image_for_parts if all_image_for_parts is None else all_image_for_parts
|
| 519 |
-
all_text_for_parts = text_for_parts if all_text_for_parts is None else all_text_for_parts
|
| 520 |
-
if global_targets is None:
|
| 521 |
-
global_targets = torch.arange(image_feats.size(0), device=image_feats.device)
|
| 522 |
-
if part_targets is None:
|
| 523 |
-
part_targets = torch.arange(part_image_flat.size(0), device=part_image_flat.device)
|
| 524 |
-
|
| 525 |
-
global_scale = global_logit_scale.exp().clamp(max=100.0)
|
| 526 |
-
local_scale = local_logit_scale.exp().clamp(max=100.0)
|
| 527 |
-
global_local_scale = global_local_logit_scale.exp().clamp(max=100.0)
|
| 528 |
-
|
| 529 |
-
if loss_type == "siglip":
|
| 530 |
-
if image_euc_feats is None or text_euc_feats is None:
|
| 531 |
-
raise ValueError("siglip contrastive requires image_euc_feats and text_euc_feats")
|
| 532 |
-
if image_feats.dim() != 2 or text_feats.dim() != 2:
|
| 533 |
-
raise ValueError("siglip contrastive is only supported for non-product features")
|
| 534 |
-
all_image_euc_feats = image_euc_feats if all_image_euc_feats is None else all_image_euc_feats
|
| 535 |
-
all_text_euc_feats = text_euc_feats if all_text_euc_feats is None else all_text_euc_feats
|
| 536 |
-
zimg = F.normalize(image_euc_feats.float(), dim=-1)
|
| 537 |
-
ztxt = F.normalize(text_euc_feats.float(), dim=-1)
|
| 538 |
-
zimg_all = F.normalize(all_image_euc_feats.float(), dim=-1)
|
| 539 |
-
ztxt_all = F.normalize(all_text_euc_feats.float(), dim=-1)
|
| 540 |
-
image_logits = (zimg @ ztxt_all.T) * global_scale
|
| 541 |
-
text_logits = (ztxt @ zimg_all.T) * global_scale
|
| 542 |
-
else:
|
| 543 |
-
image_logits = -metric_pairwise_dist(image_feats, all_text_feats, kappa, product_metric=product_metric) * global_scale
|
| 544 |
-
text_logits = -metric_pairwise_dist(text_feats, all_image_feats, kappa, product_metric=product_metric) * global_scale
|
| 545 |
-
|
| 546 |
-
if loss_type in {"sigmoid", "siglip", "siglip_metric"}:
|
| 547 |
-
bias = image_logits.new_zeros(()) if global_logit_bias is None else global_logit_bias.to(image_logits.device)
|
| 548 |
-
image_logits = image_logits + bias
|
| 549 |
-
text_logits = text_logits + bias
|
| 550 |
-
global_contrastive = 0.5 * (
|
| 551 |
-
_contrastive_loss(image_logits, global_targets, None, loss_type, sigmoid_negative_weight)
|
| 552 |
-
+ _contrastive_loss(text_logits, global_targets, None, loss_type, sigmoid_negative_weight)
|
| 553 |
-
)
|
| 554 |
-
|
| 555 |
-
if part_image_flat.numel() == 0:
|
| 556 |
-
zero = image_feats.new_zeros(())
|
| 557 |
-
contrastive = contrastive_global_weight * global_contrastive
|
| 558 |
-
return {
|
| 559 |
-
"contrastive_loss": contrastive,
|
| 560 |
-
"global_contrastive_loss": global_contrastive,
|
| 561 |
-
"local_contrastive_loss": zero,
|
| 562 |
-
"global_local_contrastive_loss": zero,
|
| 563 |
-
"global_local_angle_aux_loss": zero,
|
| 564 |
-
"beta_cal_loss": zero,
|
| 565 |
-
}
|
| 566 |
-
|
| 567 |
-
if loss_type == "siglip":
|
| 568 |
-
if part_image_euc_flat is None or part_text_euc_flat is None:
|
| 569 |
-
raise ValueError("siglip contrastive requires part_image_euc_flat and part_text_euc_flat when parts exist")
|
| 570 |
-
all_part_image_euc_feats = part_image_euc_flat if all_part_image_euc_feats is None else all_part_image_euc_feats
|
| 571 |
-
all_part_text_euc_feats = part_text_euc_flat if all_part_text_euc_feats is None else all_part_text_euc_feats
|
| 572 |
-
zpi = F.normalize(part_image_euc_flat.float(), dim=-1)
|
| 573 |
-
zpt = F.normalize(part_text_euc_flat.float(), dim=-1)
|
| 574 |
-
zpi_all = F.normalize(all_part_image_euc_feats.float(), dim=-1)
|
| 575 |
-
zpt_all = F.normalize(all_part_text_euc_feats.float(), dim=-1)
|
| 576 |
-
part_image_logits = (zpi @ zpt_all.T) * local_scale
|
| 577 |
-
part_text_logits = (zpt @ zpi_all.T) * local_scale
|
| 578 |
-
else:
|
| 579 |
-
part_image_logits = -metric_pairwise_dist(part_image_flat, all_part_text_feats, kappa, product_metric=product_metric) * local_scale
|
| 580 |
-
part_text_logits = -metric_pairwise_dist(part_text_flat, all_part_image_feats, kappa, product_metric=product_metric) * local_scale
|
| 581 |
-
|
| 582 |
-
if loss_type in {"sigmoid", "siglip", "siglip_metric"}:
|
| 583 |
-
bias = part_image_logits.new_zeros(()) if local_logit_bias is None else local_logit_bias.to(part_image_logits.device)
|
| 584 |
-
part_image_logits = part_image_logits + bias
|
| 585 |
-
part_text_logits = part_text_logits + bias
|
| 586 |
-
local_contrastive = 0.5 * (
|
| 587 |
-
_contrastive_loss(part_image_logits, part_targets, part_weights, loss_type, sigmoid_negative_weight)
|
| 588 |
-
+ _contrastive_loss(part_text_logits, part_targets, part_weights, loss_type, sigmoid_negative_weight)
|
| 589 |
-
)
|
| 590 |
-
|
| 591 |
-
global_local_contrastive = image_feats.new_zeros(())
|
| 592 |
-
global_local_angle_aux = image_feats.new_zeros(())
|
| 593 |
-
if contrastive_global_local_weight != 0.0:
|
| 594 |
-
if global_local_mode == "inbatch":
|
| 595 |
-
if part_group_ids is None:
|
| 596 |
-
raise ValueError("inbatch global-local contrastive requires part_group_ids to be provided")
|
| 597 |
-
global_local_targets = part_group_ids
|
| 598 |
-
all_text_for_global_local = all_text_feats
|
| 599 |
-
all_image_for_global_local = all_image_feats
|
| 600 |
-
all_text_for_global_local_euc = all_text_euc_feats
|
| 601 |
-
all_image_for_global_local_euc = all_image_euc_feats
|
| 602 |
-
else:
|
| 603 |
-
global_local_targets = part_targets
|
| 604 |
-
all_text_for_global_local = all_text_for_parts
|
| 605 |
-
all_image_for_global_local = all_image_for_parts
|
| 606 |
-
all_text_for_global_local_euc = all_text_for_parts_euc
|
| 607 |
-
all_image_for_global_local_euc = all_image_for_parts_euc
|
| 608 |
-
|
| 609 |
-
image_uncertainty = embedding_uncertainty(part_image_flat).detach()
|
| 610 |
-
text_uncertainty = embedding_uncertainty(part_text_flat).detach()
|
| 611 |
-
image_temp = torch.exp(-0.5 * image_uncertainty).clamp(min=0.1, max=10.0)
|
| 612 |
-
text_temp = torch.exp(-0.5 * text_uncertainty).clamp(min=0.1, max=10.0)
|
| 613 |
-
|
| 614 |
-
if loss_type == "siglip":
|
| 615 |
-
if part_image_euc_flat is None or part_text_euc_flat is None:
|
| 616 |
-
raise ValueError("siglip global-local contrastive requires part_image_euc_flat/part_text_euc_flat")
|
| 617 |
-
if all_text_for_global_local_euc is None or all_image_for_global_local_euc is None:
|
| 618 |
-
raise ValueError("siglip global-local contrastive requires all_image_euc_feats/all_text_euc_feats")
|
| 619 |
-
zpi = F.normalize(part_image_euc_flat.float(), dim=-1)
|
| 620 |
-
zpt = F.normalize(part_text_euc_flat.float(), dim=-1)
|
| 621 |
-
zimg_all = F.normalize(all_image_for_global_local_euc.float(), dim=-1)
|
| 622 |
-
ztxt_all = F.normalize(all_text_for_global_local_euc.float(), dim=-1)
|
| 623 |
-
part_image_to_whole_text = (zpi @ ztxt_all.T) * image_temp[:, None] * global_local_scale
|
| 624 |
-
part_text_to_whole_image = (zpt @ zimg_all.T) * text_temp[:, None] * global_local_scale
|
| 625 |
-
else:
|
| 626 |
-
if global_local_metric == "angle":
|
| 627 |
-
part_image_to_whole_text = -metric_pairwise_oxy_angle(
|
| 628 |
-
part_image_flat,
|
| 629 |
-
all_text_for_global_local,
|
| 630 |
-
kappa,
|
| 631 |
-
product_metric=product_metric,
|
| 632 |
-
)
|
| 633 |
-
part_text_to_whole_image = -metric_pairwise_oxy_angle(
|
| 634 |
-
part_text_flat,
|
| 635 |
-
all_image_for_global_local,
|
| 636 |
-
kappa,
|
| 637 |
-
product_metric=product_metric,
|
| 638 |
-
)
|
| 639 |
-
else:
|
| 640 |
-
part_image_to_whole_text = -metric_pairwise_dist(
|
| 641 |
-
part_image_flat, all_text_for_global_local, kappa, product_metric=product_metric
|
| 642 |
-
)
|
| 643 |
-
part_text_to_whole_image = -metric_pairwise_dist(
|
| 644 |
-
part_text_flat, all_image_for_global_local, kappa, product_metric=product_metric
|
| 645 |
-
)
|
| 646 |
-
part_image_to_whole_text = part_image_to_whole_text * image_temp[:, None] * global_local_scale
|
| 647 |
-
part_text_to_whole_image = part_text_to_whole_image * text_temp[:, None] * global_local_scale
|
| 648 |
-
|
| 649 |
-
if loss_type in {"sigmoid", "siglip", "siglip_metric"}:
|
| 650 |
-
bias = (
|
| 651 |
-
part_image_to_whole_text.new_zeros(())
|
| 652 |
-
if global_local_logit_bias is None
|
| 653 |
-
else global_local_logit_bias.to(part_image_to_whole_text.device)
|
| 654 |
-
)
|
| 655 |
-
part_image_to_whole_text = part_image_to_whole_text + bias
|
| 656 |
-
part_text_to_whole_image = part_text_to_whole_image + bias
|
| 657 |
-
|
| 658 |
-
global_local_contrastive = 0.5 * (
|
| 659 |
-
_contrastive_loss(part_image_to_whole_text, global_local_targets, part_weights, loss_type, sigmoid_negative_weight)
|
| 660 |
-
+ _contrastive_loss(part_text_to_whole_image, global_local_targets, part_weights, loss_type, sigmoid_negative_weight)
|
| 661 |
-
)
|
| 662 |
-
|
| 663 |
-
if global_local_angle_aux_weight > 0.0:
|
| 664 |
-
if global_local_angle_aux_mode == "positive_hinge":
|
| 665 |
-
positive_text = all_text_for_global_local.index_select(0, global_local_targets)
|
| 666 |
-
positive_image = all_image_for_global_local.index_select(0, global_local_targets)
|
| 667 |
-
global_local_angle_aux = 0.5 * (
|
| 668 |
-
weighted_entailment_residual(
|
| 669 |
-
specific=part_image_flat,
|
| 670 |
-
general=positive_text,
|
| 671 |
-
kappa=kappa,
|
| 672 |
-
aperture_scale=global_local_angle_aux_aperture_scale,
|
| 673 |
-
weights=part_weights,
|
| 674 |
-
)
|
| 675 |
-
+ weighted_entailment_residual(
|
| 676 |
-
specific=part_text_flat,
|
| 677 |
-
general=positive_image,
|
| 678 |
-
kappa=kappa,
|
| 679 |
-
aperture_scale=global_local_angle_aux_aperture_scale,
|
| 680 |
-
weights=part_weights,
|
| 681 |
-
)
|
| 682 |
-
)
|
| 683 |
-
elif loss_type != "siglip":
|
| 684 |
-
angle_scale = part_image_flat.new_tensor(float(global_local_angle_aux_scale))
|
| 685 |
-
part_image_to_whole_text_angle = -metric_pairwise_oxy_angle(
|
| 686 |
-
part_image_flat,
|
| 687 |
-
all_text_for_global_local,
|
| 688 |
-
kappa,
|
| 689 |
-
product_metric=product_metric,
|
| 690 |
-
) * image_temp[:, None] * angle_scale
|
| 691 |
-
part_text_to_whole_image_angle = -metric_pairwise_oxy_angle(
|
| 692 |
-
part_text_flat,
|
| 693 |
-
all_image_for_global_local,
|
| 694 |
-
kappa,
|
| 695 |
-
product_metric=product_metric,
|
| 696 |
-
) * text_temp[:, None] * angle_scale
|
| 697 |
-
if loss_type in {"sigmoid", "siglip_metric"}:
|
| 698 |
-
bias = (
|
| 699 |
-
part_image_to_whole_text_angle.new_zeros(())
|
| 700 |
-
if global_local_logit_bias is None
|
| 701 |
-
else global_local_logit_bias.to(part_image_to_whole_text_angle.device)
|
| 702 |
-
)
|
| 703 |
-
part_image_to_whole_text_angle = part_image_to_whole_text_angle + bias
|
| 704 |
-
part_text_to_whole_image_angle = part_text_to_whole_image_angle + bias
|
| 705 |
-
global_local_angle_aux = 0.5 * (
|
| 706 |
-
_contrastive_loss(
|
| 707 |
-
part_image_to_whole_text_angle,
|
| 708 |
-
global_local_targets,
|
| 709 |
-
part_weights,
|
| 710 |
-
loss_type,
|
| 711 |
-
sigmoid_negative_weight,
|
| 712 |
-
)
|
| 713 |
-
+ _contrastive_loss(
|
| 714 |
-
part_text_to_whole_image_angle,
|
| 715 |
-
global_local_targets,
|
| 716 |
-
part_weights,
|
| 717 |
-
loss_type,
|
| 718 |
-
sigmoid_negative_weight,
|
| 719 |
-
)
|
| 720 |
-
)
|
| 721 |
-
|
| 722 |
-
beta_cal = image_feats.new_zeros(())
|
| 723 |
-
if beta_cal_weight > 0.0 and beta_cal_beta > 0.0:
|
| 724 |
-
if part_group_ids is None or all_part_group_ids is None:
|
| 725 |
-
raise ValueError("beta_cal requires part_group_ids and all_part_group_ids to be provided")
|
| 726 |
-
beta_cal = 0.5 * (
|
| 727 |
-
beta_cal_loss(
|
| 728 |
-
part_image_logits,
|
| 729 |
-
targets=part_targets,
|
| 730 |
-
group_ids=part_group_ids,
|
| 731 |
-
all_group_ids=all_part_group_ids,
|
| 732 |
-
beta=beta_cal_beta,
|
| 733 |
-
variant=beta_cal_variant,
|
| 734 |
-
weights=part_weights,
|
| 735 |
-
)
|
| 736 |
-
+ beta_cal_loss(
|
| 737 |
-
part_text_logits,
|
| 738 |
-
targets=part_targets,
|
| 739 |
-
group_ids=part_group_ids,
|
| 740 |
-
all_group_ids=all_part_group_ids,
|
| 741 |
-
beta=beta_cal_beta,
|
| 742 |
-
variant=beta_cal_variant,
|
| 743 |
-
weights=part_weights,
|
| 744 |
-
)
|
| 745 |
-
)
|
| 746 |
-
|
| 747 |
-
contrastive = (
|
| 748 |
-
contrastive_global_weight * global_contrastive
|
| 749 |
-
+ contrastive_local_weight * local_contrastive
|
| 750 |
-
+ contrastive_global_local_weight * global_local_contrastive
|
| 751 |
-
+ global_local_angle_aux_weight * global_local_angle_aux
|
| 752 |
-
+ beta_cal_weight * beta_cal
|
| 753 |
-
)
|
| 754 |
-
return {
|
| 755 |
-
"contrastive_loss": contrastive,
|
| 756 |
-
"global_contrastive_loss": global_contrastive,
|
| 757 |
-
"local_contrastive_loss": local_contrastive,
|
| 758 |
-
"global_local_contrastive_loss": global_local_contrastive,
|
| 759 |
-
"global_local_angle_aux_loss": global_local_angle_aux,
|
| 760 |
-
"beta_cal_loss": beta_cal,
|
| 761 |
-
}
|
| 762 |
-
|
| 763 |
-
|
| 764 |
-
def _contrastive_loss(
|
| 765 |
-
logits: Tensor,
|
| 766 |
-
targets: Tensor,
|
| 767 |
-
weights: Tensor | None,
|
| 768 |
-
loss_type: str,
|
| 769 |
-
sigmoid_negative_weight: float,
|
| 770 |
-
) -> Tensor:
|
| 771 |
-
if loss_type == "ce":
|
| 772 |
-
return contrastive_ce(logits, targets, weights)
|
| 773 |
-
if loss_type == "sigmoid":
|
| 774 |
-
return contrastive_sigmoid(logits, targets, weights, negative_weight=sigmoid_negative_weight)
|
| 775 |
-
if loss_type in {"siglip", "siglip_metric"}:
|
| 776 |
-
return contrastive_siglip(logits, targets, weights, negative_weight=sigmoid_negative_weight)
|
| 777 |
-
raise ValueError(f"Unsupported contrastive loss {loss_type!r}")
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
def uncha_entailment_losses(
|
| 781 |
-
image_feats: Tensor,
|
| 782 |
-
text_feats: Tensor,
|
| 783 |
-
part_image_flat: Tensor,
|
| 784 |
-
part_text_flat: Tensor,
|
| 785 |
-
image_for_parts: Tensor,
|
| 786 |
-
text_for_parts: Tensor,
|
| 787 |
-
kappa: Tensor,
|
| 788 |
-
inter_aperture_scale: float,
|
| 789 |
-
intra_aperture_scale: float,
|
| 790 |
-
piecewise_factor: float = 0.1,
|
| 791 |
-
calibration_alpha: float = 10.0,
|
| 792 |
-
stop_grad_calibration: bool = True,
|
| 793 |
-
geometry: str = "lorentz",
|
| 794 |
-
part_weights: Tensor | None = None,
|
| 795 |
-
) -> dict[str, Tensor]:
|
| 796 |
-
text_image = piecewise_entailment_residual(
|
| 797 |
-
specific=image_feats,
|
| 798 |
-
general=text_feats,
|
| 799 |
-
kappa=kappa,
|
| 800 |
-
aperture_scale=inter_aperture_scale,
|
| 801 |
-
factor=piecewise_factor,
|
| 802 |
-
geometry=geometry,
|
| 803 |
-
)
|
| 804 |
-
text_image_entailment = 0.5 * text_image.mean()
|
| 805 |
-
|
| 806 |
-
if part_image_flat.numel() == 0:
|
| 807 |
-
zero = image_feats.new_zeros(())
|
| 808 |
-
return {
|
| 809 |
-
"entailment_loss": text_image_entailment,
|
| 810 |
-
"text_image_entailment_loss": text_image_entailment,
|
| 811 |
-
"part_text_image_entailment_loss": zero,
|
| 812 |
-
"cross_image_entailment_loss": zero,
|
| 813 |
-
"cross_text_entailment_loss": zero,
|
| 814 |
-
"cross_image_calibration_loss": zero,
|
| 815 |
-
"cross_text_calibration_loss": zero,
|
| 816 |
-
}
|
| 817 |
-
|
| 818 |
-
part_text_image = piecewise_entailment_residual(
|
| 819 |
-
specific=part_image_flat,
|
| 820 |
-
general=part_text_flat,
|
| 821 |
-
kappa=kappa,
|
| 822 |
-
aperture_scale=inter_aperture_scale,
|
| 823 |
-
factor=piecewise_factor,
|
| 824 |
-
geometry=geometry,
|
| 825 |
-
)
|
| 826 |
-
cross_image = piecewise_entailment_residual(
|
| 827 |
-
specific=image_for_parts,
|
| 828 |
-
general=part_image_flat,
|
| 829 |
-
kappa=kappa,
|
| 830 |
-
aperture_scale=intra_aperture_scale,
|
| 831 |
-
factor=piecewise_factor,
|
| 832 |
-
geometry=geometry,
|
| 833 |
-
)
|
| 834 |
-
cross_text = piecewise_entailment_residual(
|
| 835 |
-
specific=text_for_parts,
|
| 836 |
-
general=part_text_flat,
|
| 837 |
-
kappa=kappa,
|
| 838 |
-
aperture_scale=intra_aperture_scale,
|
| 839 |
-
factor=piecewise_factor,
|
| 840 |
-
geometry=geometry,
|
| 841 |
-
)
|
| 842 |
-
|
| 843 |
-
part_text_image_entailment = 0.5 * weighted_mean(part_text_image, part_weights)
|
| 844 |
-
cross_image_entailment, cross_image_calibration = uncertainty_calibrated_entailment_loss(
|
| 845 |
-
cross_image,
|
| 846 |
-
embedding_uncertainty(part_image_flat),
|
| 847 |
-
alpha=calibration_alpha,
|
| 848 |
-
stop_grad=stop_grad_calibration,
|
| 849 |
-
weights=part_weights,
|
| 850 |
-
)
|
| 851 |
-
cross_text_entailment, cross_text_calibration = uncertainty_calibrated_entailment_loss(
|
| 852 |
-
cross_text,
|
| 853 |
-
embedding_uncertainty(part_text_flat),
|
| 854 |
-
alpha=calibration_alpha,
|
| 855 |
-
stop_grad=stop_grad_calibration,
|
| 856 |
-
weights=part_weights,
|
| 857 |
-
)
|
| 858 |
-
|
| 859 |
-
entailment = (
|
| 860 |
-
text_image_entailment
|
| 861 |
-
+ part_text_image_entailment
|
| 862 |
-
+ 0.5 * (cross_image_entailment + cross_text_entailment)
|
| 863 |
-
+ cross_image_calibration
|
| 864 |
-
+ cross_text_calibration
|
| 865 |
-
)
|
| 866 |
-
return {
|
| 867 |
-
"entailment_loss": entailment,
|
| 868 |
-
"text_image_entailment_loss": text_image_entailment,
|
| 869 |
-
"part_text_image_entailment_loss": part_text_image_entailment,
|
| 870 |
-
"cross_image_entailment_loss": cross_image_entailment,
|
| 871 |
-
"cross_text_entailment_loss": cross_text_entailment,
|
| 872 |
-
"cross_image_calibration_loss": cross_image_calibration,
|
| 873 |
-
"cross_text_calibration_loss": cross_text_calibration,
|
| 874 |
-
}
|
| 875 |
-
|
| 876 |
-
|
| 877 |
-
def uncha_argent_entailment_losses(
|
| 878 |
-
image_feats: Tensor,
|
| 879 |
-
text_feats: Tensor,
|
| 880 |
-
part_image_flat: Tensor,
|
| 881 |
-
part_text_flat: Tensor,
|
| 882 |
-
image_for_parts: Tensor,
|
| 883 |
-
text_for_parts: Tensor,
|
| 884 |
-
kappa: Tensor,
|
| 885 |
-
beta: float = 1.0,
|
| 886 |
-
part_weights: Tensor | None = None,
|
| 887 |
-
product_metric: str = "l1",
|
| 888 |
-
aggregation: str = "uncha",
|
| 889 |
-
) -> dict[str, Tensor]:
|
| 890 |
-
if aggregation not in {"uncha", "equal"}:
|
| 891 |
-
raise ValueError("aggregation must be 'uncha' or 'equal'")
|
| 892 |
-
text_image = argent_adaptive_entailment_residual(
|
| 893 |
-
specific=image_feats,
|
| 894 |
-
general=text_feats,
|
| 895 |
-
kappa=kappa,
|
| 896 |
-
adaptive_weight=False,
|
| 897 |
-
beta=beta,
|
| 898 |
-
product_metric=product_metric,
|
| 899 |
-
)
|
| 900 |
-
text_image_entailment = 0.5 * text_image.mean()
|
| 901 |
-
|
| 902 |
-
if part_image_flat.numel() == 0:
|
| 903 |
-
zero = image_feats.new_zeros(())
|
| 904 |
-
norm_regularization = argent_norm_regularization_loss(image_feats, text_feats)
|
| 905 |
-
return {
|
| 906 |
-
"entailment_loss": text_image_entailment,
|
| 907 |
-
"text_image_entailment_loss": text_image_entailment,
|
| 908 |
-
"part_text_image_entailment_loss": zero,
|
| 909 |
-
"cross_image_entailment_loss": zero,
|
| 910 |
-
"cross_text_entailment_loss": zero,
|
| 911 |
-
"cross_image_calibration_loss": zero,
|
| 912 |
-
"cross_text_calibration_loss": zero,
|
| 913 |
-
"norm_regularization_loss": norm_regularization,
|
| 914 |
-
}
|
| 915 |
-
|
| 916 |
-
part_text_image = argent_adaptive_entailment_residual(
|
| 917 |
-
specific=part_image_flat,
|
| 918 |
-
general=part_text_flat,
|
| 919 |
-
kappa=kappa,
|
| 920 |
-
adaptive_weight=False,
|
| 921 |
-
beta=beta,
|
| 922 |
-
product_metric=product_metric,
|
| 923 |
-
)
|
| 924 |
-
cross_image = argent_adaptive_entailment_residual(
|
| 925 |
-
specific=image_for_parts,
|
| 926 |
-
general=part_image_flat,
|
| 927 |
-
kappa=kappa,
|
| 928 |
-
adaptive_weight=True,
|
| 929 |
-
beta=beta,
|
| 930 |
-
product_metric=product_metric,
|
| 931 |
-
)
|
| 932 |
-
cross_text = argent_adaptive_entailment_residual(
|
| 933 |
-
specific=text_for_parts,
|
| 934 |
-
general=part_text_flat,
|
| 935 |
-
kappa=kappa,
|
| 936 |
-
adaptive_weight=True,
|
| 937 |
-
beta=beta,
|
| 938 |
-
product_metric=product_metric,
|
| 939 |
-
)
|
| 940 |
-
|
| 941 |
-
part_text_image_entailment = 0.5 * weighted_mean(part_text_image, part_weights)
|
| 942 |
-
cross_image_entailment = 0.5 * weighted_mean(cross_image, part_weights)
|
| 943 |
-
cross_text_entailment = 0.5 * weighted_mean(cross_text, part_weights)
|
| 944 |
-
norm_regularization = argent_norm_regularization_loss(image_feats, text_feats, part_image_flat, part_text_flat)
|
| 945 |
-
if aggregation == "equal":
|
| 946 |
-
entailment = text_image_entailment + part_text_image_entailment + cross_image_entailment + cross_text_entailment
|
| 947 |
-
else:
|
| 948 |
-
entailment = text_image_entailment + part_text_image_entailment + 0.5 * (
|
| 949 |
-
cross_image_entailment + cross_text_entailment
|
| 950 |
-
)
|
| 951 |
-
diagnostics = argent_entailment_diagnostics(
|
| 952 |
-
image_feats=image_feats,
|
| 953 |
-
text_feats=text_feats,
|
| 954 |
-
part_image_flat=part_image_flat,
|
| 955 |
-
part_text_flat=part_text_flat,
|
| 956 |
-
image_for_parts=image_for_parts,
|
| 957 |
-
text_for_parts=text_for_parts,
|
| 958 |
-
kappa=kappa,
|
| 959 |
-
product_metric=product_metric,
|
| 960 |
-
)
|
| 961 |
-
|
| 962 |
-
return {
|
| 963 |
-
"entailment_loss": entailment,
|
| 964 |
-
"text_image_entailment_loss": text_image_entailment,
|
| 965 |
-
"part_text_image_entailment_loss": part_text_image_entailment,
|
| 966 |
-
"cross_image_entailment_loss": cross_image_entailment,
|
| 967 |
-
"cross_text_entailment_loss": cross_text_entailment,
|
| 968 |
-
"cross_image_calibration_loss": image_feats.new_zeros(()),
|
| 969 |
-
"cross_text_calibration_loss": image_feats.new_zeros(()),
|
| 970 |
-
"norm_regularization_loss": norm_regularization,
|
| 971 |
-
**diagnostics,
|
| 972 |
-
}
|
| 973 |
-
|
| 974 |
-
|
| 975 |
-
def hierarchical_beta_argent_entailment_losses(
|
| 976 |
-
image_feats: Tensor,
|
| 977 |
-
text_feats: Tensor,
|
| 978 |
-
part_image_flat: Tensor,
|
| 979 |
-
part_text_flat: Tensor,
|
| 980 |
-
image_for_parts: Tensor,
|
| 981 |
-
text_for_parts: Tensor,
|
| 982 |
-
beta_query_image_feats: Tensor,
|
| 983 |
-
beta_query_text_feats: Tensor,
|
| 984 |
-
beta_query_owner: Tensor,
|
| 985 |
-
beta_query_parent: Tensor,
|
| 986 |
-
beta_query_weight: Tensor,
|
| 987 |
-
kappa: Tensor,
|
| 988 |
-
beta_query_source_part: Tensor | None = None,
|
| 989 |
-
beta: float = 1.0,
|
| 990 |
-
part_weights: Tensor | None = None,
|
| 991 |
-
product_metric: str = "l1",
|
| 992 |
-
aggregation: str = "uncha",
|
| 993 |
-
) -> dict[str, Tensor]:
|
| 994 |
-
base = uncha_argent_entailment_losses(
|
| 995 |
-
image_feats=image_feats,
|
| 996 |
-
text_feats=text_feats,
|
| 997 |
-
part_image_flat=part_image_flat,
|
| 998 |
-
part_text_flat=part_text_flat,
|
| 999 |
-
image_for_parts=image_for_parts,
|
| 1000 |
-
text_for_parts=text_for_parts,
|
| 1001 |
-
kappa=kappa,
|
| 1002 |
-
beta=beta,
|
| 1003 |
-
part_weights=part_weights,
|
| 1004 |
-
product_metric=product_metric,
|
| 1005 |
-
aggregation=aggregation,
|
| 1006 |
-
)
|
| 1007 |
-
if beta_query_image_feats.numel() == 0:
|
| 1008 |
-
return {
|
| 1009 |
-
**base,
|
| 1010 |
-
"hier_beta_query_text_entailment_loss": image_feats.new_zeros(()),
|
| 1011 |
-
"hier_beta_visual_entailment_loss": image_feats.new_zeros(()),
|
| 1012 |
-
"hier_beta_text_entailment_loss": image_feats.new_zeros(()),
|
| 1013 |
-
"hier_beta_sourcepart_visual_entailment_loss": image_feats.new_zeros(()),
|
| 1014 |
-
"hier_beta_sourcepart_text_entailment_loss": image_feats.new_zeros(()),
|
| 1015 |
-
"hier_beta_query_count": beta_query_owner.new_tensor(0),
|
| 1016 |
-
"hier_beta_sourcepart_query_count": beta_query_owner.new_tensor(0),
|
| 1017 |
-
}
|
| 1018 |
-
|
| 1019 |
-
query_owner = beta_query_owner.to(device=image_feats.device, dtype=torch.long)
|
| 1020 |
-
query_weights = beta_query_weight.to(device=image_feats.device, dtype=torch.float32).clamp_min(0.0)
|
| 1021 |
-
if query_weights.numel() != beta_query_image_feats.size(0):
|
| 1022 |
-
raise ValueError("beta_query_weight must have one value per beta query")
|
| 1023 |
-
query_weights = query_weights / query_weights.mean().clamp_min(torch.finfo(query_weights.dtype).eps)
|
| 1024 |
-
|
| 1025 |
-
query_text = argent_adaptive_entailment_residual(
|
| 1026 |
-
specific=beta_query_image_feats,
|
| 1027 |
-
general=beta_query_text_feats,
|
| 1028 |
-
kappa=kappa,
|
| 1029 |
-
adaptive_weight=False,
|
| 1030 |
-
beta=beta,
|
| 1031 |
-
product_metric=product_metric,
|
| 1032 |
-
)
|
| 1033 |
-
visual_hierarchy = argent_adaptive_entailment_residual(
|
| 1034 |
-
specific=image_feats.index_select(0, query_owner),
|
| 1035 |
-
general=beta_query_image_feats,
|
| 1036 |
-
kappa=kappa,
|
| 1037 |
-
adaptive_weight=True,
|
| 1038 |
-
beta=beta,
|
| 1039 |
-
product_metric=product_metric,
|
| 1040 |
-
)
|
| 1041 |
-
query_text_entailment = 0.5 * weighted_mean(query_text, query_weights)
|
| 1042 |
-
visual_entailment = 0.5 * weighted_mean(visual_hierarchy, query_weights)
|
| 1043 |
-
|
| 1044 |
-
parent = beta_query_parent.to(device=image_feats.device, dtype=torch.long)
|
| 1045 |
-
parent_mask = (parent >= 0) & (parent < beta_query_text_feats.size(0)) & (query_weights > 0.0)
|
| 1046 |
-
if bool(parent_mask.any()):
|
| 1047 |
-
child_text = beta_query_text_feats[parent_mask]
|
| 1048 |
-
parent_text = beta_query_text_feats[parent[parent_mask]]
|
| 1049 |
-
text_hierarchy = argent_adaptive_entailment_residual(
|
| 1050 |
-
specific=parent_text,
|
| 1051 |
-
general=child_text,
|
| 1052 |
-
kappa=kappa,
|
| 1053 |
-
adaptive_weight=True,
|
| 1054 |
-
beta=beta,
|
| 1055 |
-
product_metric=product_metric,
|
| 1056 |
-
)
|
| 1057 |
-
text_entailment = 0.5 * weighted_mean(text_hierarchy, query_weights[parent_mask])
|
| 1058 |
-
else:
|
| 1059 |
-
text_entailment = image_feats.new_zeros(())
|
| 1060 |
-
|
| 1061 |
-
sourcepart_visual_entailment = image_feats.new_zeros(())
|
| 1062 |
-
sourcepart_text_entailment = image_feats.new_zeros(())
|
| 1063 |
-
sourcepart_query_count = beta_query_owner.new_tensor(0)
|
| 1064 |
-
if beta_query_source_part is not None and part_image_flat.numel() > 0:
|
| 1065 |
-
source_part = beta_query_source_part.to(device=image_feats.device, dtype=torch.long)
|
| 1066 |
-
if source_part.numel() != beta_query_image_feats.size(0):
|
| 1067 |
-
raise ValueError("beta_query_source_part must have one value per beta query")
|
| 1068 |
-
source_mask = (
|
| 1069 |
-
(source_part >= 0)
|
| 1070 |
-
& (source_part < part_image_flat.size(0))
|
| 1071 |
-
& (query_weights > 0.0)
|
| 1072 |
-
)
|
| 1073 |
-
if bool(source_mask.any()):
|
| 1074 |
-
source_indices = source_part[source_mask]
|
| 1075 |
-
sourcepart_visual = argent_adaptive_entailment_residual(
|
| 1076 |
-
specific=part_image_flat.index_select(0, source_indices),
|
| 1077 |
-
general=beta_query_image_feats[source_mask],
|
| 1078 |
-
kappa=kappa,
|
| 1079 |
-
adaptive_weight=True,
|
| 1080 |
-
beta=beta,
|
| 1081 |
-
product_metric=product_metric,
|
| 1082 |
-
)
|
| 1083 |
-
sourcepart_text = argent_adaptive_entailment_residual(
|
| 1084 |
-
specific=part_text_flat.index_select(0, source_indices),
|
| 1085 |
-
general=beta_query_text_feats[source_mask],
|
| 1086 |
-
kappa=kappa,
|
| 1087 |
-
adaptive_weight=True,
|
| 1088 |
-
beta=beta,
|
| 1089 |
-
product_metric=product_metric,
|
| 1090 |
-
)
|
| 1091 |
-
source_weights = query_weights[source_mask]
|
| 1092 |
-
sourcepart_visual_entailment = 0.5 * weighted_mean(sourcepart_visual, source_weights)
|
| 1093 |
-
sourcepart_text_entailment = 0.5 * weighted_mean(sourcepart_text, source_weights)
|
| 1094 |
-
sourcepart_query_count = beta_query_owner.new_tensor(int(source_mask.sum().item()))
|
| 1095 |
-
|
| 1096 |
-
norm_regularization = argent_norm_regularization_loss(
|
| 1097 |
-
image_feats,
|
| 1098 |
-
text_feats,
|
| 1099 |
-
part_image_flat,
|
| 1100 |
-
part_text_flat,
|
| 1101 |
-
beta_query_image_feats,
|
| 1102 |
-
beta_query_text_feats,
|
| 1103 |
-
)
|
| 1104 |
-
sourcepart_entailment = 0.5 * (sourcepart_visual_entailment + sourcepart_text_entailment)
|
| 1105 |
-
query_entailment = query_text_entailment + 0.5 * (visual_entailment + text_entailment) + sourcepart_entailment
|
| 1106 |
-
return {
|
| 1107 |
-
**base,
|
| 1108 |
-
"entailment_loss": base["entailment_loss"] + query_entailment,
|
| 1109 |
-
"norm_regularization_loss": norm_regularization,
|
| 1110 |
-
"hier_beta_query_text_entailment_loss": query_text_entailment,
|
| 1111 |
-
"hier_beta_visual_entailment_loss": visual_entailment,
|
| 1112 |
-
"hier_beta_text_entailment_loss": text_entailment,
|
| 1113 |
-
"hier_beta_sourcepart_visual_entailment_loss": sourcepart_visual_entailment,
|
| 1114 |
-
"hier_beta_sourcepart_text_entailment_loss": sourcepart_text_entailment,
|
| 1115 |
-
"hier_beta_query_count": beta_query_owner.new_tensor(beta_query_owner.numel()),
|
| 1116 |
-
"hier_beta_sourcepart_query_count": sourcepart_query_count,
|
| 1117 |
-
}
|
| 1118 |
-
|
| 1119 |
-
|
| 1120 |
-
def argent_entailment_diagnostics(
|
| 1121 |
-
image_feats: Tensor,
|
| 1122 |
-
text_feats: Tensor,
|
| 1123 |
-
part_image_flat: Tensor,
|
| 1124 |
-
part_text_flat: Tensor,
|
| 1125 |
-
image_for_parts: Tensor,
|
| 1126 |
-
text_for_parts: Tensor,
|
| 1127 |
-
kappa: Tensor,
|
| 1128 |
-
product_metric: str = "l1",
|
| 1129 |
-
) -> dict[str, Tensor]:
|
| 1130 |
-
zero = image_feats.new_zeros(())
|
| 1131 |
-
|
| 1132 |
-
def angle_mean(specific: Tensor, general: Tensor) -> Tensor:
|
| 1133 |
-
if specific.numel() == 0:
|
| 1134 |
-
return zero
|
| 1135 |
-
angles = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
|
| 1136 |
-
if angles.dim() == 2:
|
| 1137 |
-
angles = angles.mean(dim=-1)
|
| 1138 |
-
return angles.detach().mean()
|
| 1139 |
-
|
| 1140 |
-
def pent_mean(specific: Tensor, general: Tensor) -> Tensor:
|
| 1141 |
-
if specific.numel() == 0:
|
| 1142 |
-
return zero
|
| 1143 |
-
angles = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
|
| 1144 |
-
if angles.dim() == 2:
|
| 1145 |
-
angles = angles.mean(dim=-1)
|
| 1146 |
-
scores = torch.clamp(1.0 - (2.0 * angles / math.pi), min=0.0, max=1.0)
|
| 1147 |
-
return scores.detach().mean()
|
| 1148 |
-
|
| 1149 |
-
def distance_mean(specific: Tensor, general: Tensor) -> Tensor:
|
| 1150 |
-
if specific.numel() == 0:
|
| 1151 |
-
return zero
|
| 1152 |
-
return lorentz_dist(specific, general, kappa, product_metric=product_metric).detach().mean()
|
| 1153 |
-
|
| 1154 |
-
def adaptive_weight_mean(specific: Tensor, general: Tensor) -> Tensor:
|
| 1155 |
-
if specific.numel() == 0:
|
| 1156 |
-
return zero
|
| 1157 |
-
weights = 1.0 - torch.exp(-lorentz_dist(specific, general, kappa, product_metric=product_metric))
|
| 1158 |
-
return weights.detach().mean()
|
| 1159 |
-
|
| 1160 |
-
def space_norm_mean(embedding: Tensor) -> Tensor:
|
| 1161 |
-
if embedding.numel() == 0:
|
| 1162 |
-
return zero
|
| 1163 |
-
return torch.linalg.norm(_space_components(embedding).float(), dim=-1).detach().mean()
|
| 1164 |
-
|
| 1165 |
-
return {
|
| 1166 |
-
"argent_text_image_angle_mean": angle_mean(image_feats, text_feats),
|
| 1167 |
-
"argent_text_image_pent_mean": pent_mean(image_feats, text_feats),
|
| 1168 |
-
"argent_part_text_image_angle_mean": angle_mean(part_image_flat, part_text_flat),
|
| 1169 |
-
"argent_part_text_image_pent_mean": pent_mean(part_image_flat, part_text_flat),
|
| 1170 |
-
"argent_cross_image_angle_mean": angle_mean(image_for_parts, part_image_flat),
|
| 1171 |
-
"argent_cross_image_pent_mean": pent_mean(image_for_parts, part_image_flat),
|
| 1172 |
-
"argent_cross_image_distance_mean": distance_mean(image_for_parts, part_image_flat),
|
| 1173 |
-
"argent_cross_image_adaptive_weight_mean": adaptive_weight_mean(image_for_parts, part_image_flat),
|
| 1174 |
-
"argent_cross_text_angle_mean": angle_mean(text_for_parts, part_text_flat),
|
| 1175 |
-
"argent_cross_text_pent_mean": pent_mean(text_for_parts, part_text_flat),
|
| 1176 |
-
"argent_cross_text_distance_mean": distance_mean(text_for_parts, part_text_flat),
|
| 1177 |
-
"argent_cross_text_adaptive_weight_mean": adaptive_weight_mean(text_for_parts, part_text_flat),
|
| 1178 |
-
"argent_image_space_norm_mean": space_norm_mean(image_feats),
|
| 1179 |
-
"argent_text_space_norm_mean": space_norm_mean(text_feats),
|
| 1180 |
-
"argent_part_image_space_norm_mean": space_norm_mean(part_image_flat),
|
| 1181 |
-
"argent_part_text_space_norm_mean": space_norm_mean(part_text_flat),
|
| 1182 |
-
}
|
| 1183 |
-
|
| 1184 |
-
|
| 1185 |
-
def part_quality_weights(
|
| 1186 |
-
image_for_parts: Tensor,
|
| 1187 |
-
text_for_parts: Tensor,
|
| 1188 |
-
part_image_flat: Tensor,
|
| 1189 |
-
part_text_flat: Tensor,
|
| 1190 |
-
part_owner: Tensor,
|
| 1191 |
-
batch_size: int,
|
| 1192 |
-
kappa: Tensor,
|
| 1193 |
-
mode: str,
|
| 1194 |
-
topk: int = 5,
|
| 1195 |
-
temperature: float = 4.0,
|
| 1196 |
-
product_metric: str = "l1",
|
| 1197 |
-
) -> tuple[Tensor | None, Tensor, Tensor]:
|
| 1198 |
-
if mode not in {"none", "soft", "topk"}:
|
| 1199 |
-
raise ValueError(f"Unsupported part quality mode {mode!r}; expected 'none', 'soft', or 'topk'")
|
| 1200 |
-
if mode == "none" or part_image_flat.numel() == 0:
|
| 1201 |
-
empty = part_image_flat.new_zeros((part_image_flat.size(0),))
|
| 1202 |
-
return None, empty, empty
|
| 1203 |
-
|
| 1204 |
-
with torch.no_grad():
|
| 1205 |
-
image_parent = torch.exp(-lorentz_dist(part_image_flat, image_for_parts, kappa, product_metric=product_metric))
|
| 1206 |
-
text_parent = torch.exp(-lorentz_dist(part_text_flat, text_for_parts, kappa, product_metric=product_metric))
|
| 1207 |
-
image_text = torch.exp(-lorentz_dist(part_image_flat, part_text_flat, kappa, product_metric=product_metric))
|
| 1208 |
-
scores = torch.stack([image_parent, text_parent, image_text]).mean(dim=0).clamp_min(0.0)
|
| 1209 |
-
|
| 1210 |
-
if mode == "soft":
|
| 1211 |
-
weights = _owner_softmax_weights(scores, part_owner, batch_size, temperature)
|
| 1212 |
-
else:
|
| 1213 |
-
weights = _owner_topk_weights(scores, part_owner, batch_size, topk)
|
| 1214 |
-
weights = weights / weights.mean().clamp_min(torch.finfo(weights.dtype).eps)
|
| 1215 |
-
return weights, scores, (weights > 0.0).to(dtype=scores.dtype)
|
| 1216 |
-
|
| 1217 |
-
|
| 1218 |
-
def _owner_softmax_weights(scores: Tensor, part_owner: Tensor, batch_size: int, temperature: float) -> Tensor:
|
| 1219 |
-
weights = torch.zeros_like(scores)
|
| 1220 |
-
for owner in range(batch_size):
|
| 1221 |
-
mask = part_owner == owner
|
| 1222 |
-
if not bool(mask.any()):
|
| 1223 |
-
continue
|
| 1224 |
-
owner_scores = scores[mask]
|
| 1225 |
-
owner_weights = torch.softmax(owner_scores * temperature, dim=0) * owner_scores.numel()
|
| 1226 |
-
weights[mask] = owner_weights
|
| 1227 |
-
return weights
|
| 1228 |
-
|
| 1229 |
-
|
| 1230 |
-
def _owner_topk_weights(scores: Tensor, part_owner: Tensor, batch_size: int, topk: int) -> Tensor:
|
| 1231 |
-
if topk <= 0:
|
| 1232 |
-
raise ValueError("topk must be positive for top-k part quality weighting")
|
| 1233 |
-
weights = torch.zeros_like(scores)
|
| 1234 |
-
for owner in range(batch_size):
|
| 1235 |
-
indices = torch.nonzero(part_owner == owner, as_tuple=False).flatten()
|
| 1236 |
-
if indices.numel() == 0:
|
| 1237 |
-
continue
|
| 1238 |
-
keep = min(topk, indices.numel())
|
| 1239 |
-
selected = indices[scores[indices].topk(k=keep).indices]
|
| 1240 |
-
weights[selected] = 1.0
|
| 1241 |
-
return weights
|
| 1242 |
-
|
| 1243 |
-
|
| 1244 |
-
def argent_adaptive_entailment_residual(
|
| 1245 |
-
specific: Tensor,
|
| 1246 |
-
general: Tensor,
|
| 1247 |
-
kappa: Tensor,
|
| 1248 |
-
adaptive_weight: bool,
|
| 1249 |
-
beta: float = 1.0,
|
| 1250 |
-
product_metric: str = "l1",
|
| 1251 |
-
) -> Tensor:
|
| 1252 |
-
angles = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
|
| 1253 |
-
if angles.dim() == 2:
|
| 1254 |
-
angles = angles.mean(dim=-1)
|
| 1255 |
-
if adaptive_weight:
|
| 1256 |
-
weights = 1.0 - torch.exp(
|
| 1257 |
-
-lorentz_dist(specific=specific, general=general, kappa=kappa, product_metric=product_metric)
|
| 1258 |
-
)
|
| 1259 |
-
angles = angles * weights
|
| 1260 |
-
return F.huber_loss(angles, torch.zeros_like(angles), delta=beta, reduction="none")
|
| 1261 |
-
|
| 1262 |
-
|
| 1263 |
-
def lorentz_dist(specific: Tensor, general: Tensor, kappa: Tensor, product_metric: str = "l1") -> Tensor:
|
| 1264 |
-
return paired_dist(specific, general, kappa, product_metric=product_metric)
|
| 1265 |
-
|
| 1266 |
-
|
| 1267 |
-
def argent_norm_regularization_loss(*embeddings: Tensor, eps: float = 1e-6) -> Tensor:
|
| 1268 |
-
losses = []
|
| 1269 |
-
for embedding in embeddings:
|
| 1270 |
-
if embedding.numel() == 0:
|
| 1271 |
-
continue
|
| 1272 |
-
space = _space_components(embedding)
|
| 1273 |
-
space_norm = torch.linalg.norm(space.float(), dim=-1).clamp_min(eps)
|
| 1274 |
-
losses.append((space_norm.square() - torch.log(space_norm)).mean())
|
| 1275 |
-
if not losses:
|
| 1276 |
-
raise ValueError("argent_norm_regularization_loss requires at least one non-empty embedding tensor")
|
| 1277 |
-
return torch.stack(losses).mean()
|
| 1278 |
-
|
| 1279 |
-
|
| 1280 |
-
def piecewise_entailment_residual(
|
| 1281 |
-
specific: Tensor,
|
| 1282 |
-
general: Tensor,
|
| 1283 |
-
kappa: Tensor,
|
| 1284 |
-
aperture_scale: float,
|
| 1285 |
-
factor: float = 0.1,
|
| 1286 |
-
geometry: str = "lorentz",
|
| 1287 |
-
) -> Tensor:
|
| 1288 |
-
if geometry == "lorentz":
|
| 1289 |
-
angles = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
|
| 1290 |
-
apertures = factor_half_aperture(general=general, kappa=kappa)
|
| 1291 |
-
elif geometry == "euclidean":
|
| 1292 |
-
angles = euclidean_angle(specific=specific, general=general)
|
| 1293 |
-
apertures = euclidean_half_aperture(general=general, aperture_scale=aperture_scale)
|
| 1294 |
-
aperture_scale = 1.0
|
| 1295 |
-
else:
|
| 1296 |
-
raise ValueError(f"Unsupported entailment geometry {geometry!r}; expected 'lorentz' or 'euclidean'")
|
| 1297 |
-
residual = angles - aperture_scale * apertures
|
| 1298 |
-
loss = torch.where(residual > 0.0, residual + factor * angles, factor * angles)
|
| 1299 |
-
return loss.mean(dim=-1) if loss.dim() == 2 else loss
|
| 1300 |
-
|
| 1301 |
-
|
| 1302 |
-
def euclidean_angle(specific: Tensor, general: Tensor, eps: float = 1e-6) -> Tensor:
|
| 1303 |
-
specific_space = _space_components(specific).float()
|
| 1304 |
-
general_space = _space_components(general).float()
|
| 1305 |
-
numerator = (specific_space * general_space).sum(dim=-1)
|
| 1306 |
-
denominator = torch.linalg.norm(specific_space, dim=-1) * torch.linalg.norm(general_space, dim=-1)
|
| 1307 |
-
dtype_eps = torch.finfo(specific_space.dtype).eps
|
| 1308 |
-
angle_eps = max(eps, 16.0 * dtype_eps)
|
| 1309 |
-
cosine = (numerator / denominator.clamp_min(angle_eps)).clamp(min=-1.0 + angle_eps, max=1.0 - angle_eps)
|
| 1310 |
-
return torch.acos(cosine)
|
| 1311 |
-
|
| 1312 |
-
|
| 1313 |
-
def euclidean_half_aperture(general: Tensor, aperture_scale: float, eps: float = 1e-8) -> Tensor:
|
| 1314 |
-
general_norm = torch.linalg.norm(_space_components(general).float(), dim=-1).clamp_min(eps)
|
| 1315 |
-
return torch.atan(torch.as_tensor(aperture_scale, device=general.device, dtype=general.dtype) / general_norm)
|
| 1316 |
-
|
| 1317 |
-
|
| 1318 |
-
def aggregate_part_consistency_loss(
|
| 1319 |
-
image_feats: Tensor,
|
| 1320 |
-
text_feats: Tensor,
|
| 1321 |
-
part_image_flat: Tensor,
|
| 1322 |
-
part_text_flat: Tensor,
|
| 1323 |
-
part_owner: Tensor,
|
| 1324 |
-
part_weights: Tensor | None = None,
|
| 1325 |
-
) -> Tensor:
|
| 1326 |
-
if part_image_flat.numel() == 0:
|
| 1327 |
-
return image_feats.new_zeros(())
|
| 1328 |
-
|
| 1329 |
-
batch_size = image_feats.size(0)
|
| 1330 |
-
image_space = _space_components(image_feats).reshape(batch_size, -1).float()
|
| 1331 |
-
text_space = _space_components(text_feats).reshape(batch_size, -1).float()
|
| 1332 |
-
part_image_space = _space_components(part_image_flat).reshape(part_image_flat.size(0), -1).float()
|
| 1333 |
-
part_text_space = _space_components(part_text_flat).reshape(part_text_flat.size(0), -1).float()
|
| 1334 |
-
if part_weights is None:
|
| 1335 |
-
counts = torch.bincount(part_owner, minlength=batch_size).to(device=image_feats.device, dtype=image_space.dtype)
|
| 1336 |
-
denom = counts
|
| 1337 |
-
valid = counts > 0
|
| 1338 |
-
weights = part_image_space.new_ones((part_image_space.size(0),))
|
| 1339 |
-
else:
|
| 1340 |
-
weights = part_weights.to(device=image_feats.device, dtype=image_space.dtype).flatten()
|
| 1341 |
-
if weights.numel() != part_owner.numel():
|
| 1342 |
-
raise ValueError("part_weights must have the same number of elements as part_owner when provided")
|
| 1343 |
-
denom = torch.zeros(batch_size, device=image_feats.device, dtype=image_space.dtype)
|
| 1344 |
-
denom.index_add_(0, part_owner, weights)
|
| 1345 |
-
valid = denom > 0
|
| 1346 |
-
|
| 1347 |
-
image_agg = image_space.new_zeros(image_space.shape)
|
| 1348 |
-
text_agg = text_space.new_zeros(text_space.shape)
|
| 1349 |
-
image_agg.index_add_(0, part_owner, part_image_space * weights[:, None])
|
| 1350 |
-
text_agg.index_add_(0, part_owner, part_text_space * weights[:, None])
|
| 1351 |
-
image_agg = image_agg[valid] / denom[valid, None].clamp_min(1.0)
|
| 1352 |
-
text_agg = text_agg[valid] / denom[valid, None].clamp_min(1.0)
|
| 1353 |
-
|
| 1354 |
-
image_space = image_space[valid]
|
| 1355 |
-
text_space = text_space[valid]
|
| 1356 |
-
return 0.25 * (
|
| 1357 |
-
cosine_residual(image_agg, image_space)
|
| 1358 |
-
+ cosine_residual(text_agg, text_space)
|
| 1359 |
-
+ cosine_residual(image_agg, text_space)
|
| 1360 |
-
+ cosine_residual(text_agg, image_space)
|
| 1361 |
-
)
|
| 1362 |
-
|
| 1363 |
-
|
| 1364 |
-
def cosine_residual(x: Tensor, y: Tensor) -> Tensor:
|
| 1365 |
-
return (1.0 - F.cosine_similarity(x, y, dim=-1)).mean()
|
| 1366 |
-
|
| 1367 |
-
|
| 1368 |
-
def uncertainty_calibrated_entailment_loss(
|
| 1369 |
-
entail_residual: Tensor,
|
| 1370 |
-
log_uncertainty: Tensor,
|
| 1371 |
-
alpha: float = 10.0,
|
| 1372 |
-
stop_grad: bool = True,
|
| 1373 |
-
weights: Tensor | None = None,
|
| 1374 |
-
) -> tuple[Tensor, Tensor]:
|
| 1375 |
-
mean_loss = 0.5 * entail_residual
|
| 1376 |
-
uncertainty = torch.exp(log_uncertainty).clamp(min=1e-6, max=1e6)
|
| 1377 |
-
residual = entail_residual.detach() if stop_grad else entail_residual
|
| 1378 |
-
scaled_entail = residual / (uncertainty + 1e-6)
|
| 1379 |
-
calibration_term = 0.5 * scaled_entail + 0.5 * log_uncertainty
|
| 1380 |
-
prob = torch.softmax(log_uncertainty.flatten(), dim=0)
|
| 1381 |
-
entropy = -(prob * torch.log(prob + 1e-8)).sum()
|
| 1382 |
-
calibration_loss = alpha * (calibration_term + entropy)
|
| 1383 |
-
return weighted_mean(mean_loss, weights), weighted_mean(calibration_loss, weights)
|
| 1384 |
-
|
| 1385 |
-
|
| 1386 |
-
def embedding_uncertainty(x: Tensor) -> Tensor:
|
| 1387 |
-
space = _space_components(x)
|
| 1388 |
-
norm = torch.linalg.norm(space.float(), dim=-1)
|
| 1389 |
-
if norm.dim() > 1:
|
| 1390 |
-
norm = norm.mean(dim=-1)
|
| 1391 |
-
return F.softplus(-norm)
|
| 1392 |
-
|
| 1393 |
-
|
| 1394 |
-
def _space_components(x: Tensor) -> Tensor:
|
| 1395 |
-
return x[..., 1:] if x.shape[-1] > 1 else x
|
| 1396 |
-
|
| 1397 |
-
|
| 1398 |
-
def _flatten_valid_parts(part_image_feats: Tensor, part_text_feats: Tensor, part_mask: Tensor, targets: Tensor) -> tuple[Tensor, Tensor, Tensor]:
|
| 1399 |
-
part_targets = targets[:, None].expand_as(part_mask)[part_mask]
|
| 1400 |
-
return part_image_feats[part_mask], part_text_feats[part_mask], part_targets
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/models/objectives.py
DELETED
|
@@ -1,580 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from collections.abc import Mapping
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
from torch import Tensor, nn
|
| 7 |
-
|
| 8 |
-
from hyper3_clip.models.lorentz import log_map0, metric_pairwise_dist
|
| 9 |
-
from hyper3_clip.models.losses import (
|
| 10 |
-
aggregate_part_consistency_loss,
|
| 11 |
-
contrastive_ce,
|
| 12 |
-
gramian_volume_loss,
|
| 13 |
-
hierarchical_beta_argent_entailment_losses,
|
| 14 |
-
packed_part_contrastive_loss,
|
| 15 |
-
packed_part_entailment_loss,
|
| 16 |
-
part_quality_weights,
|
| 17 |
-
radius_order_hinge,
|
| 18 |
-
uncha_argent_entailment_losses,
|
| 19 |
-
uncha_contrastive_losses,
|
| 20 |
-
uncha_entailment_losses,
|
| 21 |
-
)
|
| 22 |
-
from hyper3_clip.training.distributed import gather_variable_many_with_grad, gather_variable_no_grad, get_rank
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
class HyCoCLIPObjective(nn.Module):
|
| 26 |
-
def __init__(
|
| 27 |
-
self,
|
| 28 |
-
entail_weight: float,
|
| 29 |
-
inter_aperture_scale: float,
|
| 30 |
-
intra_aperture_scale: float,
|
| 31 |
-
product_metric: str = "l1",
|
| 32 |
-
) -> None:
|
| 33 |
-
super().__init__()
|
| 34 |
-
self.entail_weight = entail_weight
|
| 35 |
-
self.inter_aperture_scale = inter_aperture_scale
|
| 36 |
-
self.intra_aperture_scale = intra_aperture_scale
|
| 37 |
-
self.product_metric = product_metric
|
| 38 |
-
|
| 39 |
-
def forward(self, embeddings: Mapping[str, Tensor], logit_scale: Tensor) -> dict[str, Tensor]:
|
| 40 |
-
part_owner = embeddings["part_owner"].long()
|
| 41 |
-
part_count = part_owner.new_tensor(part_owner.numel())
|
| 42 |
-
contrastive = packed_part_contrastive_loss(
|
| 43 |
-
image_feats=embeddings["image_feats"],
|
| 44 |
-
text_feats=embeddings["text_feats"],
|
| 45 |
-
part_image_feats=embeddings["part_image_feats"],
|
| 46 |
-
part_text_feats=embeddings["part_text_feats"],
|
| 47 |
-
part_owner=part_owner,
|
| 48 |
-
kappa=embeddings["kappa"],
|
| 49 |
-
logit_scale=logit_scale,
|
| 50 |
-
all_image_feats=embeddings.get("all_image_feats"),
|
| 51 |
-
all_text_feats=embeddings.get("all_text_feats"),
|
| 52 |
-
targets=embeddings.get("targets"),
|
| 53 |
-
)
|
| 54 |
-
entailment = packed_part_entailment_loss(
|
| 55 |
-
image_feats=embeddings["image_feats"],
|
| 56 |
-
text_feats=embeddings["text_feats"],
|
| 57 |
-
part_image_feats=embeddings["part_image_feats"],
|
| 58 |
-
part_text_feats=embeddings["part_text_feats"],
|
| 59 |
-
part_owner=part_owner,
|
| 60 |
-
kappa=embeddings["kappa"],
|
| 61 |
-
inter_aperture_scale=self.inter_aperture_scale,
|
| 62 |
-
intra_aperture_scale=self.intra_aperture_scale,
|
| 63 |
-
)
|
| 64 |
-
total = contrastive + self.entail_weight * entailment
|
| 65 |
-
return {
|
| 66 |
-
"loss": total,
|
| 67 |
-
"contrastive_loss": contrastive,
|
| 68 |
-
"entailment_loss": entailment,
|
| 69 |
-
"part_count": part_count,
|
| 70 |
-
}
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
class UNCHAObjective(nn.Module):
|
| 74 |
-
def __init__(
|
| 75 |
-
self,
|
| 76 |
-
entail_weight: float,
|
| 77 |
-
inter_aperture_scale: float,
|
| 78 |
-
intra_aperture_scale: float,
|
| 79 |
-
piecewise_factor: float = 0.1,
|
| 80 |
-
calibration_alpha: float = 10.0,
|
| 81 |
-
stop_grad_calibration: bool = True,
|
| 82 |
-
entailment_geometry: str = "lorentz",
|
| 83 |
-
aggregate_weight: float = 0.0,
|
| 84 |
-
entailment_loss: str = "piecewise",
|
| 85 |
-
argent_beta: float = 1.0,
|
| 86 |
-
argent_norm_weight: float = 0.0,
|
| 87 |
-
argent_aux_weight: float = 0.5,
|
| 88 |
-
argent_aggregation: str = "uncha",
|
| 89 |
-
part_weight_power: float = 0.0,
|
| 90 |
-
product_metric: str = "l1",
|
| 91 |
-
contrastive_loss: str = "ce",
|
| 92 |
-
sigmoid_negative_weight: float = 1.0,
|
| 93 |
-
part_quality_mode: str = "none",
|
| 94 |
-
part_quality_topk: int = 5,
|
| 95 |
-
part_quality_temperature: float = 4.0,
|
| 96 |
-
contrastive_global_weight: float = 1.0,
|
| 97 |
-
contrastive_local_weight: float = 1.0,
|
| 98 |
-
contrastive_global_local_weight: float = 1.0,
|
| 99 |
-
beta_cal_beta: float = 0.0,
|
| 100 |
-
beta_cal_variant: str = "ce",
|
| 101 |
-
beta_cal_weight: float = 0.0,
|
| 102 |
-
himo_component_weight: float = 0.0,
|
| 103 |
-
global_local_mode: str = "repeat",
|
| 104 |
-
global_local_metric: str = "distance",
|
| 105 |
-
global_local_angle_aux_weight: float = 0.0,
|
| 106 |
-
global_local_angle_aux_mode: str = "contrastive",
|
| 107 |
-
global_local_angle_aux_scale: float = 5.5,
|
| 108 |
-
global_local_angle_aux_aperture_scale: float = 1.0,
|
| 109 |
-
radius_order_weight: float = 0.0,
|
| 110 |
-
radius_order_margin: float = 0.0,
|
| 111 |
-
gramian_align_weight: float = 0.0,
|
| 112 |
-
) -> None:
|
| 113 |
-
super().__init__()
|
| 114 |
-
if entailment_loss not in {
|
| 115 |
-
"piecewise",
|
| 116 |
-
"argent",
|
| 117 |
-
"piecewise_argent",
|
| 118 |
-
"hier_beta_argent",
|
| 119 |
-
"hier_beta_sourcepart_argent",
|
| 120 |
-
}:
|
| 121 |
-
raise ValueError(
|
| 122 |
-
f"Unsupported UNCHA entailment loss {entailment_loss!r}; "
|
| 123 |
-
"expected 'piecewise', 'argent', 'piecewise_argent', 'hier_beta_argent', "
|
| 124 |
-
"or 'hier_beta_sourcepart_argent'"
|
| 125 |
-
)
|
| 126 |
-
if contrastive_loss not in {"ce", "sigmoid", "siglip", "siglip_metric"}:
|
| 127 |
-
raise ValueError("contrastive_loss must be 'ce', 'sigmoid', 'siglip', or 'siglip_metric'")
|
| 128 |
-
if beta_cal_variant not in {"ce", "bce"}:
|
| 129 |
-
raise ValueError("beta_cal_variant must be 'ce' or 'bce'")
|
| 130 |
-
if argent_aggregation not in {"uncha", "equal"}:
|
| 131 |
-
raise ValueError("argent_aggregation must be 'uncha' or 'equal'")
|
| 132 |
-
if part_quality_mode not in {"none", "soft", "topk"}:
|
| 133 |
-
raise ValueError("part_quality_mode must be 'none', 'soft', or 'topk'")
|
| 134 |
-
if global_local_mode not in {"repeat", "inbatch"}:
|
| 135 |
-
raise ValueError("global_local_mode must be 'repeat' or 'inbatch'")
|
| 136 |
-
if global_local_metric not in {"distance", "angle"}:
|
| 137 |
-
raise ValueError("global_local_metric must be 'distance' or 'angle'")
|
| 138 |
-
if global_local_angle_aux_mode not in {"contrastive", "positive_hinge"}:
|
| 139 |
-
raise ValueError("global_local_angle_aux_mode must be 'contrastive' or 'positive_hinge'")
|
| 140 |
-
if global_local_angle_aux_weight < 0.0:
|
| 141 |
-
raise ValueError("global_local_angle_aux_weight must be non-negative")
|
| 142 |
-
if global_local_angle_aux_scale <= 0.0:
|
| 143 |
-
raise ValueError("global_local_angle_aux_scale must be positive")
|
| 144 |
-
if global_local_angle_aux_aperture_scale <= 0.0:
|
| 145 |
-
raise ValueError("global_local_angle_aux_aperture_scale must be positive")
|
| 146 |
-
if part_quality_topk <= 0:
|
| 147 |
-
raise ValueError("part_quality_topk must be positive")
|
| 148 |
-
self.entail_weight = entail_weight
|
| 149 |
-
self.inter_aperture_scale = inter_aperture_scale
|
| 150 |
-
self.intra_aperture_scale = intra_aperture_scale
|
| 151 |
-
self.piecewise_factor = piecewise_factor
|
| 152 |
-
self.calibration_alpha = calibration_alpha
|
| 153 |
-
self.stop_grad_calibration = stop_grad_calibration
|
| 154 |
-
self.entailment_geometry = entailment_geometry
|
| 155 |
-
self.aggregate_weight = aggregate_weight
|
| 156 |
-
self.entailment_loss = entailment_loss
|
| 157 |
-
self.argent_beta = argent_beta
|
| 158 |
-
self.argent_norm_weight = argent_norm_weight
|
| 159 |
-
self.argent_aux_weight = argent_aux_weight
|
| 160 |
-
self.argent_aggregation = argent_aggregation
|
| 161 |
-
self.part_weight_power = part_weight_power
|
| 162 |
-
self.product_metric = product_metric
|
| 163 |
-
self.contrastive_loss = contrastive_loss
|
| 164 |
-
self.sigmoid_negative_weight = sigmoid_negative_weight
|
| 165 |
-
self.part_quality_mode = part_quality_mode
|
| 166 |
-
self.part_quality_topk = part_quality_topk
|
| 167 |
-
self.part_quality_temperature = part_quality_temperature
|
| 168 |
-
self.contrastive_global_weight = float(contrastive_global_weight)
|
| 169 |
-
self.contrastive_local_weight = float(contrastive_local_weight)
|
| 170 |
-
self.contrastive_global_local_weight = float(contrastive_global_local_weight)
|
| 171 |
-
self.beta_cal_beta = float(beta_cal_beta)
|
| 172 |
-
self.beta_cal_variant = beta_cal_variant
|
| 173 |
-
self.beta_cal_weight = float(beta_cal_weight)
|
| 174 |
-
self.himo_component_weight = float(himo_component_weight)
|
| 175 |
-
self.global_local_mode = global_local_mode
|
| 176 |
-
self.global_local_metric = global_local_metric
|
| 177 |
-
self.global_local_angle_aux_weight = float(global_local_angle_aux_weight)
|
| 178 |
-
self.global_local_angle_aux_mode = global_local_angle_aux_mode
|
| 179 |
-
self.global_local_angle_aux_scale = float(global_local_angle_aux_scale)
|
| 180 |
-
self.global_local_angle_aux_aperture_scale = float(global_local_angle_aux_aperture_scale)
|
| 181 |
-
self.radius_order_weight = float(radius_order_weight)
|
| 182 |
-
self.radius_order_margin = float(radius_order_margin)
|
| 183 |
-
self.gramian_align_weight = float(gramian_align_weight)
|
| 184 |
-
|
| 185 |
-
def forward(self, embeddings: Mapping[str, Tensor], logit_scales: Mapping[str, Tensor]) -> dict[str, Tensor]:
|
| 186 |
-
part_owner = embeddings["part_owner"].long()
|
| 187 |
-
part_count = part_owner.new_tensor(part_owner.numel())
|
| 188 |
-
part_image_flat = embeddings["part_image_feats"]
|
| 189 |
-
part_text_flat = embeddings["part_text_feats"]
|
| 190 |
-
image_feats = embeddings["image_feats"]
|
| 191 |
-
text_feats = embeddings["text_feats"]
|
| 192 |
-
|
| 193 |
-
if part_owner.numel() == 0:
|
| 194 |
-
image_for_parts = image_feats.new_zeros((0, image_feats.size(-1)))
|
| 195 |
-
text_for_parts = text_feats.new_zeros((0, text_feats.size(-1)))
|
| 196 |
-
else:
|
| 197 |
-
image_for_parts = image_feats[part_owner]
|
| 198 |
-
text_for_parts = text_feats[part_owner]
|
| 199 |
-
count_part_weights = _part_weights(part_owner, image_feats.size(0), self.part_weight_power)
|
| 200 |
-
quality_part_weights, quality_scores, quality_keep = part_quality_weights(
|
| 201 |
-
image_for_parts=image_for_parts,
|
| 202 |
-
text_for_parts=text_for_parts,
|
| 203 |
-
part_image_flat=part_image_flat,
|
| 204 |
-
part_text_flat=part_text_flat,
|
| 205 |
-
part_owner=part_owner,
|
| 206 |
-
batch_size=image_feats.size(0),
|
| 207 |
-
kappa=embeddings["kappa"],
|
| 208 |
-
mode=self.part_quality_mode,
|
| 209 |
-
topk=self.part_quality_topk,
|
| 210 |
-
temperature=self.part_quality_temperature,
|
| 211 |
-
product_metric=self.product_metric,
|
| 212 |
-
)
|
| 213 |
-
part_weights = _combine_part_weights(count_part_weights, quality_part_weights)
|
| 214 |
-
|
| 215 |
-
needs_repeated_global_local = self.global_local_mode == "repeat" and self.contrastive_global_local_weight != 0.0
|
| 216 |
-
part_feature_tensors = [part_image_flat, part_text_flat]
|
| 217 |
-
if needs_repeated_global_local:
|
| 218 |
-
part_feature_tensors.extend([image_for_parts, text_for_parts])
|
| 219 |
-
gathered_part_features, part_counts = gather_variable_many_with_grad(part_feature_tensors)
|
| 220 |
-
all_part_image_feats = gathered_part_features[0]
|
| 221 |
-
all_part_text_feats = gathered_part_features[1]
|
| 222 |
-
all_image_for_parts = gathered_part_features[2] if needs_repeated_global_local else None
|
| 223 |
-
all_text_for_parts = gathered_part_features[3] if needs_repeated_global_local else None
|
| 224 |
-
image_euc_feats = embeddings.get("image_euc_feats")
|
| 225 |
-
text_euc_feats = embeddings.get("text_euc_feats")
|
| 226 |
-
part_image_euc_flat = embeddings.get("part_image_euc_feats")
|
| 227 |
-
part_text_euc_flat = embeddings.get("part_text_euc_feats")
|
| 228 |
-
image_for_parts_euc = None
|
| 229 |
-
text_for_parts_euc = None
|
| 230 |
-
all_part_image_euc_feats = None
|
| 231 |
-
all_part_text_euc_feats = None
|
| 232 |
-
all_image_for_parts_euc = None
|
| 233 |
-
all_text_for_parts_euc = None
|
| 234 |
-
if (
|
| 235 |
-
image_euc_feats is not None
|
| 236 |
-
and text_euc_feats is not None
|
| 237 |
-
and part_owner.numel() > 0
|
| 238 |
-
and needs_repeated_global_local
|
| 239 |
-
):
|
| 240 |
-
image_for_parts_euc = image_euc_feats[part_owner]
|
| 241 |
-
text_for_parts_euc = text_euc_feats[part_owner]
|
| 242 |
-
if part_image_euc_flat is not None and part_text_euc_flat is not None:
|
| 243 |
-
euc_feature_tensors = [part_image_euc_flat, part_text_euc_flat]
|
| 244 |
-
if image_for_parts_euc is not None and text_for_parts_euc is not None:
|
| 245 |
-
euc_feature_tensors.extend([image_for_parts_euc, text_for_parts_euc])
|
| 246 |
-
gathered_euc_features, _ = gather_variable_many_with_grad(euc_feature_tensors)
|
| 247 |
-
all_part_image_euc_feats = gathered_euc_features[0]
|
| 248 |
-
all_part_text_euc_feats = gathered_euc_features[1]
|
| 249 |
-
if image_for_parts_euc is not None and text_for_parts_euc is not None:
|
| 250 |
-
all_image_for_parts_euc = gathered_euc_features[2]
|
| 251 |
-
all_text_for_parts_euc = gathered_euc_features[3]
|
| 252 |
-
if "targets" not in embeddings:
|
| 253 |
-
raise ValueError("UNCHAObjective requires 'targets' to compute group-aware losses")
|
| 254 |
-
global_targets = embeddings["targets"]
|
| 255 |
-
part_group_ids = global_targets[part_owner] if part_owner.numel() > 0 else part_owner.new_zeros((0,))
|
| 256 |
-
all_part_group_ids = None
|
| 257 |
-
if self.beta_cal_weight > 0.0 and self.beta_cal_beta > 0.0:
|
| 258 |
-
all_part_group_ids, _ = gather_variable_no_grad(part_group_ids)
|
| 259 |
-
part_offset = part_counts[: get_rank()].sum() if part_counts.numel() > 1 else part_counts.new_zeros(())
|
| 260 |
-
part_targets = torch.arange(part_image_flat.size(0), device=part_image_flat.device) + part_offset
|
| 261 |
-
|
| 262 |
-
contrastive = uncha_contrastive_losses(
|
| 263 |
-
image_feats=image_feats,
|
| 264 |
-
text_feats=text_feats,
|
| 265 |
-
part_image_flat=part_image_flat,
|
| 266 |
-
part_text_flat=part_text_flat,
|
| 267 |
-
image_for_parts=image_for_parts,
|
| 268 |
-
text_for_parts=text_for_parts,
|
| 269 |
-
image_euc_feats=image_euc_feats,
|
| 270 |
-
text_euc_feats=text_euc_feats,
|
| 271 |
-
part_image_euc_flat=part_image_euc_flat,
|
| 272 |
-
part_text_euc_flat=part_text_euc_flat,
|
| 273 |
-
image_for_parts_euc=image_for_parts_euc,
|
| 274 |
-
text_for_parts_euc=text_for_parts_euc,
|
| 275 |
-
kappa=embeddings["kappa"],
|
| 276 |
-
global_logit_scale=logit_scales["global"],
|
| 277 |
-
local_logit_scale=logit_scales["local"],
|
| 278 |
-
global_local_logit_scale=logit_scales["global_local"],
|
| 279 |
-
all_image_feats=embeddings.get("all_image_feats"),
|
| 280 |
-
all_text_feats=embeddings.get("all_text_feats"),
|
| 281 |
-
all_part_image_feats=all_part_image_feats,
|
| 282 |
-
all_part_text_feats=all_part_text_feats,
|
| 283 |
-
all_image_for_parts=all_image_for_parts,
|
| 284 |
-
all_text_for_parts=all_text_for_parts,
|
| 285 |
-
all_image_euc_feats=embeddings.get("all_image_euc_feats"),
|
| 286 |
-
all_text_euc_feats=embeddings.get("all_text_euc_feats"),
|
| 287 |
-
all_part_image_euc_feats=all_part_image_euc_feats,
|
| 288 |
-
all_part_text_euc_feats=all_part_text_euc_feats,
|
| 289 |
-
all_image_for_parts_euc=all_image_for_parts_euc,
|
| 290 |
-
all_text_for_parts_euc=all_text_for_parts_euc,
|
| 291 |
-
global_targets=global_targets,
|
| 292 |
-
part_targets=part_targets,
|
| 293 |
-
part_weights=part_weights,
|
| 294 |
-
product_metric=self.product_metric,
|
| 295 |
-
loss_type=self.contrastive_loss,
|
| 296 |
-
contrastive_global_weight=self.contrastive_global_weight,
|
| 297 |
-
contrastive_local_weight=self.contrastive_local_weight,
|
| 298 |
-
contrastive_global_local_weight=self.contrastive_global_local_weight,
|
| 299 |
-
beta_cal_beta=self.beta_cal_beta,
|
| 300 |
-
beta_cal_variant=self.beta_cal_variant,
|
| 301 |
-
beta_cal_weight=self.beta_cal_weight,
|
| 302 |
-
part_group_ids=part_group_ids,
|
| 303 |
-
all_part_group_ids=all_part_group_ids,
|
| 304 |
-
global_logit_bias=logit_scales.get("global_bias"),
|
| 305 |
-
local_logit_bias=logit_scales.get("local_bias"),
|
| 306 |
-
global_local_logit_bias=logit_scales.get("global_local_bias"),
|
| 307 |
-
sigmoid_negative_weight=self.sigmoid_negative_weight,
|
| 308 |
-
global_local_mode=self.global_local_mode,
|
| 309 |
-
global_local_metric=self.global_local_metric,
|
| 310 |
-
global_local_angle_aux_weight=self.global_local_angle_aux_weight,
|
| 311 |
-
global_local_angle_aux_mode=self.global_local_angle_aux_mode,
|
| 312 |
-
global_local_angle_aux_scale=self.global_local_angle_aux_scale,
|
| 313 |
-
global_local_angle_aux_aperture_scale=self.global_local_angle_aux_aperture_scale,
|
| 314 |
-
)
|
| 315 |
-
himo_component_loss = image_feats.new_zeros(())
|
| 316 |
-
if self.himo_component_weight > 0.0 and embeddings.get("himo_text_feats") is not None:
|
| 317 |
-
himo_text_feats = embeddings["himo_text_feats"]
|
| 318 |
-
all_himo_text_feats = embeddings.get("all_himo_text_feats")
|
| 319 |
-
if all_himo_text_feats is None:
|
| 320 |
-
raise ValueError("himo_text_feats requires all_himo_text_feats for distributed contrastive loss")
|
| 321 |
-
scale = logit_scales["global"].exp().clamp(max=100.0)
|
| 322 |
-
logits_i_t = -metric_pairwise_dist(image_feats, all_himo_text_feats, embeddings["kappa"], product_metric=self.product_metric) * scale
|
| 323 |
-
logits_t_i = -metric_pairwise_dist(himo_text_feats, embeddings["all_image_feats"], embeddings["kappa"], product_metric=self.product_metric) * scale
|
| 324 |
-
himo_component_loss = 0.5 * (contrastive_ce(logits_i_t, global_targets) + contrastive_ce(logits_t_i, global_targets))
|
| 325 |
-
if self.entailment_loss == "argent":
|
| 326 |
-
entailment = uncha_argent_entailment_losses(
|
| 327 |
-
image_feats=image_feats,
|
| 328 |
-
text_feats=text_feats,
|
| 329 |
-
part_image_flat=part_image_flat,
|
| 330 |
-
part_text_flat=part_text_flat,
|
| 331 |
-
image_for_parts=image_for_parts,
|
| 332 |
-
text_for_parts=text_for_parts,
|
| 333 |
-
kappa=embeddings["kappa"],
|
| 334 |
-
beta=self.argent_beta,
|
| 335 |
-
part_weights=part_weights,
|
| 336 |
-
product_metric=self.product_metric,
|
| 337 |
-
aggregation=self.argent_aggregation,
|
| 338 |
-
)
|
| 339 |
-
elif self.entailment_loss in {"hier_beta_argent", "hier_beta_sourcepart_argent"}:
|
| 340 |
-
required = (
|
| 341 |
-
"beta_query_image_feats",
|
| 342 |
-
"beta_query_text_feats",
|
| 343 |
-
"beta_query_owner",
|
| 344 |
-
"beta_query_parent",
|
| 345 |
-
"beta_query_weight",
|
| 346 |
-
)
|
| 347 |
-
if self.entailment_loss == "hier_beta_sourcepart_argent":
|
| 348 |
-
required = (*required, "beta_query_source_part")
|
| 349 |
-
missing = [key for key in required if embeddings.get(key) is None]
|
| 350 |
-
if missing:
|
| 351 |
-
raise ValueError(f"{self.entailment_loss} requires beta query embeddings: missing {missing}")
|
| 352 |
-
entailment = hierarchical_beta_argent_entailment_losses(
|
| 353 |
-
image_feats=image_feats,
|
| 354 |
-
text_feats=text_feats,
|
| 355 |
-
part_image_flat=part_image_flat,
|
| 356 |
-
part_text_flat=part_text_flat,
|
| 357 |
-
image_for_parts=image_for_parts,
|
| 358 |
-
text_for_parts=text_for_parts,
|
| 359 |
-
beta_query_image_feats=embeddings["beta_query_image_feats"],
|
| 360 |
-
beta_query_text_feats=embeddings["beta_query_text_feats"],
|
| 361 |
-
beta_query_owner=embeddings["beta_query_owner"],
|
| 362 |
-
beta_query_parent=embeddings["beta_query_parent"],
|
| 363 |
-
beta_query_weight=embeddings["beta_query_weight"],
|
| 364 |
-
beta_query_source_part=embeddings.get("beta_query_source_part")
|
| 365 |
-
if self.entailment_loss == "hier_beta_sourcepart_argent"
|
| 366 |
-
else None,
|
| 367 |
-
kappa=embeddings["kappa"],
|
| 368 |
-
beta=self.argent_beta,
|
| 369 |
-
part_weights=part_weights,
|
| 370 |
-
product_metric=self.product_metric,
|
| 371 |
-
aggregation=self.argent_aggregation,
|
| 372 |
-
)
|
| 373 |
-
else:
|
| 374 |
-
piecewise_entailment = uncha_entailment_losses(
|
| 375 |
-
image_feats=image_feats,
|
| 376 |
-
text_feats=text_feats,
|
| 377 |
-
part_image_flat=part_image_flat,
|
| 378 |
-
part_text_flat=part_text_flat,
|
| 379 |
-
image_for_parts=image_for_parts,
|
| 380 |
-
text_for_parts=text_for_parts,
|
| 381 |
-
kappa=embeddings["kappa"],
|
| 382 |
-
inter_aperture_scale=self.inter_aperture_scale,
|
| 383 |
-
intra_aperture_scale=self.intra_aperture_scale,
|
| 384 |
-
piecewise_factor=self.piecewise_factor,
|
| 385 |
-
calibration_alpha=self.calibration_alpha,
|
| 386 |
-
stop_grad_calibration=self.stop_grad_calibration,
|
| 387 |
-
geometry=self.entailment_geometry,
|
| 388 |
-
part_weights=part_weights,
|
| 389 |
-
)
|
| 390 |
-
if self.entailment_loss == "piecewise_argent":
|
| 391 |
-
argent_entailment = uncha_argent_entailment_losses(
|
| 392 |
-
image_feats=image_feats,
|
| 393 |
-
text_feats=text_feats,
|
| 394 |
-
part_image_flat=part_image_flat,
|
| 395 |
-
part_text_flat=part_text_flat,
|
| 396 |
-
image_for_parts=image_for_parts,
|
| 397 |
-
text_for_parts=text_for_parts,
|
| 398 |
-
kappa=embeddings["kappa"],
|
| 399 |
-
beta=self.argent_beta,
|
| 400 |
-
part_weights=part_weights,
|
| 401 |
-
product_metric=self.product_metric,
|
| 402 |
-
aggregation=self.argent_aggregation,
|
| 403 |
-
)
|
| 404 |
-
entailment = {
|
| 405 |
-
**piecewise_entailment,
|
| 406 |
-
"entailment_loss": piecewise_entailment["entailment_loss"]
|
| 407 |
-
+ self.argent_aux_weight * argent_entailment["entailment_loss"],
|
| 408 |
-
"piecewise_entailment_loss": piecewise_entailment["entailment_loss"],
|
| 409 |
-
"argent_entailment_loss": argent_entailment["entailment_loss"],
|
| 410 |
-
"norm_regularization_loss": argent_entailment["norm_regularization_loss"],
|
| 411 |
-
}
|
| 412 |
-
else:
|
| 413 |
-
entailment = piecewise_entailment
|
| 414 |
-
aggregate = aggregate_part_consistency_loss(
|
| 415 |
-
image_feats=image_feats,
|
| 416 |
-
text_feats=text_feats,
|
| 417 |
-
part_image_flat=part_image_flat,
|
| 418 |
-
part_text_flat=part_text_flat,
|
| 419 |
-
part_owner=part_owner,
|
| 420 |
-
part_weights=part_weights,
|
| 421 |
-
)
|
| 422 |
-
radius_order = image_feats.new_zeros(())
|
| 423 |
-
if self.radius_order_weight > 0.0:
|
| 424 |
-
radius_order = (
|
| 425 |
-
radius_order_hinge(image_feats, text_feats, embeddings["kappa"], self.radius_order_margin)
|
| 426 |
-
+ radius_order_hinge(part_image_flat, part_text_flat, embeddings["kappa"], self.radius_order_margin, part_weights)
|
| 427 |
-
+ radius_order_hinge(image_for_parts, part_image_flat, embeddings["kappa"], self.radius_order_margin, part_weights)
|
| 428 |
-
+ radius_order_hinge(text_for_parts, part_text_flat, embeddings["kappa"], self.radius_order_margin, part_weights)
|
| 429 |
-
)
|
| 430 |
-
gramian_align = image_feats.new_zeros(())
|
| 431 |
-
if self.gramian_align_weight > 0.0 and part_owner.numel() > 0:
|
| 432 |
-
def _tangent_flat(x: Tensor) -> Tensor:
|
| 433 |
-
tangent = log_map0(x, embeddings["kappa"])
|
| 434 |
-
return tangent.reshape(tangent.size(0), -1) if tangent.dim() == 3 else tangent
|
| 435 |
-
|
| 436 |
-
gramian_vectors = torch.stack(
|
| 437 |
-
[
|
| 438 |
-
_tangent_flat(image_for_parts),
|
| 439 |
-
_tangent_flat(text_for_parts),
|
| 440 |
-
_tangent_flat(part_image_flat),
|
| 441 |
-
_tangent_flat(part_text_flat),
|
| 442 |
-
],
|
| 443 |
-
dim=1,
|
| 444 |
-
)
|
| 445 |
-
gramian_align = gramian_volume_loss(gramian_vectors, part_weights)
|
| 446 |
-
entail_weight_scale = embeddings.get("entail_weight_scale", image_feats.new_ones(()))
|
| 447 |
-
total = (
|
| 448 |
-
contrastive["contrastive_loss"]
|
| 449 |
-
+ self.himo_component_weight * himo_component_loss
|
| 450 |
-
+ self.entail_weight * entail_weight_scale * entailment["entailment_loss"]
|
| 451 |
-
+ self.aggregate_weight * aggregate
|
| 452 |
-
+ self.radius_order_weight * radius_order
|
| 453 |
-
+ self.gramian_align_weight * gramian_align
|
| 454 |
-
+ self.argent_norm_weight * entailment.get(
|
| 455 |
-
"norm_regularization_loss",
|
| 456 |
-
image_feats.new_zeros(()),
|
| 457 |
-
)
|
| 458 |
-
)
|
| 459 |
-
return {
|
| 460 |
-
"loss": total,
|
| 461 |
-
**contrastive,
|
| 462 |
-
"himo_component_contrastive_loss": himo_component_loss,
|
| 463 |
-
**entailment,
|
| 464 |
-
"aggregate_consistency_loss": aggregate,
|
| 465 |
-
"radius_order_loss": radius_order,
|
| 466 |
-
"gramian_align_loss": gramian_align,
|
| 467 |
-
"part_count": part_count,
|
| 468 |
-
"entail_weight_scale": entail_weight_scale.detach(),
|
| 469 |
-
"part_quality_mean": (
|
| 470 |
-
image_feats.new_zeros(()) if quality_scores.numel() == 0 else quality_scores.mean().detach()
|
| 471 |
-
),
|
| 472 |
-
"part_quality_keep_fraction": (
|
| 473 |
-
image_feats.new_zeros(()) if quality_keep.numel() == 0 else quality_keep.mean().detach()
|
| 474 |
-
),
|
| 475 |
-
}
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
def build_objective(
|
| 479 |
-
objective: str,
|
| 480 |
-
entail_weight: float,
|
| 481 |
-
inter_aperture_scale: float,
|
| 482 |
-
intra_aperture_scale: float,
|
| 483 |
-
uncha_piecewise_factor: float = 0.1,
|
| 484 |
-
uncha_calibration_alpha: float = 10.0,
|
| 485 |
-
uncha_stop_grad_calibration: bool = True,
|
| 486 |
-
uncha_entailment_geometry: str = "lorentz",
|
| 487 |
-
uncha_aggregate_weight: float = 0.0,
|
| 488 |
-
uncha_entailment_loss: str = "piecewise",
|
| 489 |
-
uncha_argent_beta: float = 1.0,
|
| 490 |
-
uncha_argent_norm_weight: float = 0.0,
|
| 491 |
-
uncha_argent_aux_weight: float = 0.5,
|
| 492 |
-
uncha_argent_aggregation: str = "uncha",
|
| 493 |
-
uncha_part_weight_power: float = 0.0,
|
| 494 |
-
uncha_contrastive_loss: str = "ce",
|
| 495 |
-
uncha_sigmoid_negative_weight: float = 1.0,
|
| 496 |
-
uncha_part_quality_mode: str = "none",
|
| 497 |
-
uncha_part_quality_topk: int = 5,
|
| 498 |
-
uncha_part_quality_temperature: float = 4.0,
|
| 499 |
-
uncha_contrastive_global_weight: float = 1.0,
|
| 500 |
-
uncha_contrastive_local_weight: float = 1.0,
|
| 501 |
-
uncha_contrastive_global_local_weight: float = 1.0,
|
| 502 |
-
uncha_beta_cal_beta: float = 0.0,
|
| 503 |
-
uncha_beta_cal_variant: str = "ce",
|
| 504 |
-
uncha_beta_cal_weight: float = 0.0,
|
| 505 |
-
uncha_himo_component_weight: float = 0.0,
|
| 506 |
-
uncha_global_local_mode: str = "repeat",
|
| 507 |
-
uncha_global_local_metric: str = "distance",
|
| 508 |
-
uncha_global_local_angle_aux_weight: float = 0.0,
|
| 509 |
-
uncha_global_local_angle_aux_mode: str = "contrastive",
|
| 510 |
-
uncha_global_local_angle_aux_scale: float = 5.5,
|
| 511 |
-
uncha_global_local_angle_aux_aperture_scale: float = 1.0,
|
| 512 |
-
uncha_radius_order_weight: float = 0.0,
|
| 513 |
-
uncha_radius_order_margin: float = 0.0,
|
| 514 |
-
uncha_gramian_align_weight: float = 0.0,
|
| 515 |
-
product_metric: str = "l1",
|
| 516 |
-
) -> nn.Module:
|
| 517 |
-
if objective == "hycoclip":
|
| 518 |
-
return HyCoCLIPObjective(
|
| 519 |
-
entail_weight=entail_weight,
|
| 520 |
-
inter_aperture_scale=inter_aperture_scale,
|
| 521 |
-
intra_aperture_scale=intra_aperture_scale,
|
| 522 |
-
product_metric=product_metric,
|
| 523 |
-
)
|
| 524 |
-
if objective == "uncha":
|
| 525 |
-
return UNCHAObjective(
|
| 526 |
-
entail_weight=entail_weight,
|
| 527 |
-
inter_aperture_scale=inter_aperture_scale,
|
| 528 |
-
intra_aperture_scale=intra_aperture_scale,
|
| 529 |
-
piecewise_factor=uncha_piecewise_factor,
|
| 530 |
-
calibration_alpha=uncha_calibration_alpha,
|
| 531 |
-
stop_grad_calibration=uncha_stop_grad_calibration,
|
| 532 |
-
entailment_geometry=uncha_entailment_geometry,
|
| 533 |
-
aggregate_weight=uncha_aggregate_weight,
|
| 534 |
-
entailment_loss=uncha_entailment_loss,
|
| 535 |
-
argent_beta=uncha_argent_beta,
|
| 536 |
-
argent_norm_weight=uncha_argent_norm_weight,
|
| 537 |
-
argent_aux_weight=uncha_argent_aux_weight,
|
| 538 |
-
argent_aggregation=uncha_argent_aggregation,
|
| 539 |
-
part_weight_power=uncha_part_weight_power,
|
| 540 |
-
product_metric=product_metric,
|
| 541 |
-
contrastive_loss=uncha_contrastive_loss,
|
| 542 |
-
sigmoid_negative_weight=uncha_sigmoid_negative_weight,
|
| 543 |
-
part_quality_mode=uncha_part_quality_mode,
|
| 544 |
-
part_quality_topk=uncha_part_quality_topk,
|
| 545 |
-
part_quality_temperature=uncha_part_quality_temperature,
|
| 546 |
-
contrastive_global_weight=uncha_contrastive_global_weight,
|
| 547 |
-
contrastive_local_weight=uncha_contrastive_local_weight,
|
| 548 |
-
contrastive_global_local_weight=uncha_contrastive_global_local_weight,
|
| 549 |
-
beta_cal_beta=uncha_beta_cal_beta,
|
| 550 |
-
beta_cal_variant=uncha_beta_cal_variant,
|
| 551 |
-
beta_cal_weight=uncha_beta_cal_weight,
|
| 552 |
-
himo_component_weight=uncha_himo_component_weight,
|
| 553 |
-
global_local_mode=uncha_global_local_mode,
|
| 554 |
-
global_local_metric=uncha_global_local_metric,
|
| 555 |
-
global_local_angle_aux_weight=uncha_global_local_angle_aux_weight,
|
| 556 |
-
global_local_angle_aux_mode=uncha_global_local_angle_aux_mode,
|
| 557 |
-
global_local_angle_aux_scale=uncha_global_local_angle_aux_scale,
|
| 558 |
-
global_local_angle_aux_aperture_scale=uncha_global_local_angle_aux_aperture_scale,
|
| 559 |
-
radius_order_weight=uncha_radius_order_weight,
|
| 560 |
-
radius_order_margin=uncha_radius_order_margin,
|
| 561 |
-
gramian_align_weight=uncha_gramian_align_weight,
|
| 562 |
-
)
|
| 563 |
-
raise ValueError(f"Unsupported objective {objective!r}; expected 'hycoclip' or 'uncha'")
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
def _part_weights(part_owner: Tensor, batch_size: int, power: float) -> Tensor | None:
|
| 567 |
-
if power <= 0.0 or part_owner.numel() == 0:
|
| 568 |
-
return None
|
| 569 |
-
counts = torch.bincount(part_owner, minlength=batch_size).to(dtype=torch.float32, device=part_owner.device)
|
| 570 |
-
weights = counts[part_owner].clamp_min(1.0).pow(-power)
|
| 571 |
-
return weights / weights.mean().clamp_min(torch.finfo(weights.dtype).eps)
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
def _combine_part_weights(count_weights: Tensor | None, quality_weights: Tensor | None) -> Tensor | None:
|
| 575 |
-
if count_weights is None:
|
| 576 |
-
return quality_weights
|
| 577 |
-
if quality_weights is None:
|
| 578 |
-
return count_weights
|
| 579 |
-
weights = count_weights * quality_weights
|
| 580 |
-
return weights / weights.mean().clamp_min(torch.finfo(weights.dtype).eps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/models/tren.py
DELETED
|
@@ -1,255 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import math
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn.functional as F
|
| 7 |
-
from torch import Tensor, nn
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class FourierPositionEncoding2D(nn.Module):
|
| 11 |
-
def __init__(self, dim: int, scale: float = 1.0) -> None:
|
| 12 |
-
super().__init__()
|
| 13 |
-
if dim <= 0 or dim % 2 != 0:
|
| 14 |
-
raise ValueError("FourierPositionEncoding2D dim must be a positive even integer")
|
| 15 |
-
if scale <= 0.0:
|
| 16 |
-
raise ValueError("FourierPositionEncoding2D scale must be positive")
|
| 17 |
-
generator = torch.Generator()
|
| 18 |
-
generator.manual_seed(42)
|
| 19 |
-
self.register_buffer("gaussian_matrix", scale * torch.randn((2, dim // 2), generator=generator))
|
| 20 |
-
|
| 21 |
-
def forward(self, coords: Tensor) -> Tensor:
|
| 22 |
-
projected = (2.0 * coords.float() - 1.0) @ self.gaussian_matrix
|
| 23 |
-
projected = 2.0 * math.pi * projected
|
| 24 |
-
return torch.cat([torch.sin(projected), torch.cos(projected)], dim=-1)
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
class _MLPBlock(nn.Module):
|
| 28 |
-
def __init__(self, dim: int, hidden_dim: int, dropout: float) -> None:
|
| 29 |
-
super().__init__()
|
| 30 |
-
self.net = nn.Sequential(
|
| 31 |
-
nn.Linear(dim, hidden_dim),
|
| 32 |
-
nn.GELU(),
|
| 33 |
-
nn.Dropout(dropout),
|
| 34 |
-
nn.Linear(hidden_dim, dim),
|
| 35 |
-
)
|
| 36 |
-
|
| 37 |
-
def forward(self, x: Tensor) -> Tensor:
|
| 38 |
-
return self.net(x)
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
class _AttentionLayer(nn.Module):
|
| 42 |
-
def __init__(
|
| 43 |
-
self,
|
| 44 |
-
q_dim: int,
|
| 45 |
-
kv_dim: int,
|
| 46 |
-
hidden_dim: int,
|
| 47 |
-
*,
|
| 48 |
-
num_heads: int,
|
| 49 |
-
dropout: float,
|
| 50 |
-
use_bias: bool = False,
|
| 51 |
-
use_v_proj: bool = True,
|
| 52 |
-
use_out_proj: bool = True,
|
| 53 |
-
) -> None:
|
| 54 |
-
super().__init__()
|
| 55 |
-
if hidden_dim % num_heads != 0:
|
| 56 |
-
raise ValueError("hidden_dim must be divisible by num_heads")
|
| 57 |
-
if not use_v_proj and kv_dim != hidden_dim:
|
| 58 |
-
raise ValueError("kv_dim must equal hidden_dim when value projection is disabled")
|
| 59 |
-
self.hidden_dim = hidden_dim
|
| 60 |
-
self.num_heads = num_heads
|
| 61 |
-
self.head_dim = hidden_dim // num_heads
|
| 62 |
-
self.q_proj = nn.Linear(q_dim, hidden_dim, bias=use_bias)
|
| 63 |
-
self.k_proj = nn.Linear(kv_dim, hidden_dim, bias=use_bias)
|
| 64 |
-
self.v_proj = nn.Linear(kv_dim, hidden_dim, bias=use_bias) if use_v_proj else nn.Identity()
|
| 65 |
-
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=use_bias) if use_out_proj else nn.Identity()
|
| 66 |
-
self.q_norm = nn.LayerNorm(self.head_dim)
|
| 67 |
-
self.k_norm = nn.LayerNorm(self.head_dim)
|
| 68 |
-
self.dropout = nn.Dropout(dropout)
|
| 69 |
-
self.scale = self.head_dim**-0.5
|
| 70 |
-
|
| 71 |
-
nn.init.kaiming_normal_(self.q_proj.weight, mode="fan_in", nonlinearity="linear")
|
| 72 |
-
nn.init.kaiming_normal_(self.k_proj.weight, mode="fan_in", nonlinearity="linear")
|
| 73 |
-
if isinstance(self.v_proj, nn.Linear):
|
| 74 |
-
nn.init.kaiming_normal_(self.v_proj.weight, mode="fan_in", nonlinearity="linear")
|
| 75 |
-
if isinstance(self.out_proj, nn.Linear):
|
| 76 |
-
nn.init.kaiming_normal_(self.out_proj.weight, mode="fan_in", nonlinearity="linear")
|
| 77 |
-
|
| 78 |
-
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
| 79 |
-
batch_size, q_len, _ = q.shape
|
| 80 |
-
_, kv_len, _ = k.shape
|
| 81 |
-
query = self.q_proj(q).view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 82 |
-
key = self.k_proj(k).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 83 |
-
value = self.v_proj(v).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 84 |
-
|
| 85 |
-
query = self.q_norm(query)
|
| 86 |
-
key = self.k_norm(key)
|
| 87 |
-
attn_scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
|
| 88 |
-
attn_weights = self.dropout(F.softmax(attn_scores, dim=-1))
|
| 89 |
-
out = torch.matmul(attn_weights, value)
|
| 90 |
-
out = out.transpose(1, 2).contiguous().view(batch_size, q_len, self.hidden_dim)
|
| 91 |
-
return self.out_proj(out), attn_weights
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
class _CrossAttentionBlock(nn.Module):
|
| 95 |
-
def __init__(self, dim: int, *, num_heads: int, dropout: float) -> None:
|
| 96 |
-
super().__init__()
|
| 97 |
-
self.query_norm = nn.LayerNorm(dim)
|
| 98 |
-
self.cross_attn = _AttentionLayer(dim, dim, dim, num_heads=num_heads, dropout=dropout)
|
| 99 |
-
self.dropout = nn.Dropout(dropout)
|
| 100 |
-
self.mlp_norm = nn.LayerNorm(dim)
|
| 101 |
-
self.mlp = _MLPBlock(dim, 2 * dim, dropout)
|
| 102 |
-
self.out_norm = nn.LayerNorm(dim)
|
| 103 |
-
|
| 104 |
-
def forward(self, query: Tensor, context: Tensor) -> Tensor:
|
| 105 |
-
x, _ = self.cross_attn(self.query_norm(query), context, context)
|
| 106 |
-
x = query + self.dropout(x)
|
| 107 |
-
return self.out_norm(x + self.mlp(self.mlp_norm(x)))
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
class TRENRegionEncoder(nn.Module):
|
| 111 |
-
"""T-REN-style point-prompted region token encoder.
|
| 112 |
-
|
| 113 |
-
The module follows the public T-REN architecture: learned k-per-prompt
|
| 114 |
-
query tokens, Fourier 2D prompt/patch position encodings, alternating
|
| 115 |
-
cross-attention and per-prompt self-attention, then final single-head
|
| 116 |
-
attention that pools unprojected patch tokens into region tokens.
|
| 117 |
-
"""
|
| 118 |
-
|
| 119 |
-
def __init__(
|
| 120 |
-
self,
|
| 121 |
-
vision_dim: int,
|
| 122 |
-
text_dim: int,
|
| 123 |
-
*,
|
| 124 |
-
hidden_dim: int | None = None,
|
| 125 |
-
num_region_tokens: int = 3,
|
| 126 |
-
num_decoder_layers: int = 2,
|
| 127 |
-
num_attention_heads: int = 8,
|
| 128 |
-
prompt_grid_size: int = 7,
|
| 129 |
-
dropout: float = 0.1,
|
| 130 |
-
) -> None:
|
| 131 |
-
super().__init__()
|
| 132 |
-
if num_region_tokens <= 0:
|
| 133 |
-
raise ValueError("num_region_tokens must be positive")
|
| 134 |
-
if num_decoder_layers <= 0:
|
| 135 |
-
raise ValueError("num_decoder_layers must be positive")
|
| 136 |
-
if prompt_grid_size <= 0:
|
| 137 |
-
raise ValueError("prompt_grid_size must be positive")
|
| 138 |
-
hidden_dim = int(hidden_dim or vision_dim)
|
| 139 |
-
if hidden_dim != vision_dim:
|
| 140 |
-
raise ValueError("TRENRegionEncoder currently requires hidden_dim == vision_dim")
|
| 141 |
-
if hidden_dim % 2 != 0:
|
| 142 |
-
raise ValueError("TRENRegionEncoder hidden_dim must be even for Fourier features")
|
| 143 |
-
if hidden_dim % num_attention_heads != 0:
|
| 144 |
-
raise ValueError("TRENRegionEncoder hidden_dim must be divisible by num_attention_heads")
|
| 145 |
-
|
| 146 |
-
self.vision_dim = vision_dim
|
| 147 |
-
self.text_dim = text_dim
|
| 148 |
-
self.hidden_dim = hidden_dim
|
| 149 |
-
self.num_region_tokens = num_region_tokens
|
| 150 |
-
self.prompt_grid_size = prompt_grid_size
|
| 151 |
-
self.position_encoder = FourierPositionEncoding2D(hidden_dim)
|
| 152 |
-
self.region_token_embeddings = nn.Embedding(num_region_tokens, hidden_dim)
|
| 153 |
-
nn.init.normal_(self.region_token_embeddings.weight, std=0.02)
|
| 154 |
-
self.region_attention_layers = nn.ModuleList(
|
| 155 |
-
[_CrossAttentionBlock(hidden_dim, num_heads=num_attention_heads, dropout=dropout) for _ in range(num_decoder_layers)]
|
| 156 |
-
)
|
| 157 |
-
self.region_attention_norms = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_decoder_layers)])
|
| 158 |
-
self.prompt_attention_layers = nn.ModuleList(
|
| 159 |
-
[
|
| 160 |
-
_AttentionLayer(
|
| 161 |
-
hidden_dim,
|
| 162 |
-
hidden_dim,
|
| 163 |
-
hidden_dim,
|
| 164 |
-
num_heads=num_attention_heads,
|
| 165 |
-
dropout=dropout,
|
| 166 |
-
)
|
| 167 |
-
for _ in range(num_decoder_layers)
|
| 168 |
-
]
|
| 169 |
-
)
|
| 170 |
-
self.prompt_attention_norms = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_decoder_layers)])
|
| 171 |
-
self.token_prediction_head = _AttentionLayer(
|
| 172 |
-
hidden_dim,
|
| 173 |
-
hidden_dim,
|
| 174 |
-
hidden_dim,
|
| 175 |
-
num_heads=1,
|
| 176 |
-
dropout=0.0,
|
| 177 |
-
use_v_proj=False,
|
| 178 |
-
use_out_proj=False,
|
| 179 |
-
)
|
| 180 |
-
self.text_alignment_block = nn.Sequential(
|
| 181 |
-
nn.Linear(hidden_dim, 2 * hidden_dim),
|
| 182 |
-
nn.GELU(),
|
| 183 |
-
nn.Dropout(dropout),
|
| 184 |
-
nn.Linear(2 * hidden_dim, text_dim),
|
| 185 |
-
)
|
| 186 |
-
|
| 187 |
-
def forward(self, image_tokens: Tensor) -> dict[str, Tensor]:
|
| 188 |
-
patch_tokens, patch_grid = _patch_tokens_and_grid(image_tokens)
|
| 189 |
-
batch_size, patch_count, _ = patch_tokens.shape
|
| 190 |
-
patch_coords = _grid_coords(patch_grid, patch_grid, patch_tokens.device)
|
| 191 |
-
prompt_coords = _grid_coords(self.prompt_grid_size, self.prompt_grid_size, patch_tokens.device)
|
| 192 |
-
prompt_count = prompt_coords.size(0)
|
| 193 |
-
|
| 194 |
-
feature_pos = self.position_encoder(patch_coords).to(dtype=patch_tokens.dtype)
|
| 195 |
-
prompt_pos = self.position_encoder(prompt_coords).to(dtype=patch_tokens.dtype)
|
| 196 |
-
kv = patch_tokens + feature_pos.unsqueeze(0)
|
| 197 |
-
prompt_pos = prompt_pos.view(1, prompt_count, 1, self.hidden_dim)
|
| 198 |
-
|
| 199 |
-
q = self.region_token_embeddings.weight.to(dtype=patch_tokens.dtype)
|
| 200 |
-
q = q.view(1, 1, self.num_region_tokens, self.hidden_dim).expand(
|
| 201 |
-
batch_size,
|
| 202 |
-
prompt_count,
|
| 203 |
-
self.num_region_tokens,
|
| 204 |
-
self.hidden_dim,
|
| 205 |
-
)
|
| 206 |
-
for region_layer, region_norm, prompt_layer, prompt_norm in zip(
|
| 207 |
-
self.region_attention_layers,
|
| 208 |
-
self.region_attention_norms,
|
| 209 |
-
self.prompt_attention_layers,
|
| 210 |
-
self.prompt_attention_norms,
|
| 211 |
-
strict=True,
|
| 212 |
-
):
|
| 213 |
-
q = q + prompt_pos
|
| 214 |
-
q = q.reshape(batch_size, prompt_count * self.num_region_tokens, self.hidden_dim)
|
| 215 |
-
q = region_layer(q, kv)
|
| 216 |
-
q = q.reshape(batch_size, prompt_count, self.num_region_tokens, self.hidden_dim)
|
| 217 |
-
q = region_norm(q)
|
| 218 |
-
q = q.reshape(batch_size * prompt_count, self.num_region_tokens, self.hidden_dim)
|
| 219 |
-
q, _ = prompt_layer(q, q, q)
|
| 220 |
-
q = prompt_norm(q)
|
| 221 |
-
q = q.reshape(batch_size, prompt_count, self.num_region_tokens, self.hidden_dim)
|
| 222 |
-
|
| 223 |
-
flat_q = q.reshape(batch_size, prompt_count * self.num_region_tokens, self.hidden_dim)
|
| 224 |
-
visual_tokens, attn_weights = self.token_prediction_head(flat_q, kv, patch_tokens)
|
| 225 |
-
visual_tokens = visual_tokens.reshape(batch_size, prompt_count, self.num_region_tokens, self.hidden_dim)
|
| 226 |
-
attn_weights = attn_weights.squeeze(1).reshape(batch_size, prompt_count, self.num_region_tokens, patch_count)
|
| 227 |
-
region_masks = attn_weights / attn_weights.amax(dim=-1, keepdim=True).clamp_min(torch.finfo(attn_weights.dtype).eps)
|
| 228 |
-
region_masks = region_masks.reshape(batch_size, prompt_count, self.num_region_tokens, patch_grid, patch_grid)
|
| 229 |
-
text_aligned_tokens = self.text_alignment_block(visual_tokens)
|
| 230 |
-
return {
|
| 231 |
-
"visual_tokens": visual_tokens,
|
| 232 |
-
"text_aligned_tokens": text_aligned_tokens,
|
| 233 |
-
"region_masks": region_masks,
|
| 234 |
-
"prompt_coords": prompt_coords,
|
| 235 |
-
}
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
def _patch_tokens_and_grid(tokens: Tensor) -> tuple[Tensor, int]:
|
| 239 |
-
if tokens.ndim != 3:
|
| 240 |
-
raise ValueError("TRENRegionEncoder expects image tokens with shape [batch, tokens, dim]")
|
| 241 |
-
token_count = tokens.size(1)
|
| 242 |
-
grid = int(math.isqrt(token_count))
|
| 243 |
-
if grid * grid == token_count:
|
| 244 |
-
return tokens, grid
|
| 245 |
-
grid = int(math.isqrt(token_count - 1))
|
| 246 |
-
if grid * grid == token_count - 1:
|
| 247 |
-
return tokens[:, 1:, :], grid
|
| 248 |
-
raise ValueError(f"Cannot infer a square patch grid from {token_count} image tokens")
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
def _grid_coords(height: int, width: int, device: torch.device) -> Tensor:
|
| 252 |
-
y = torch.linspace(0.5 / height, 1.0 - 0.5 / height, height, device=device)
|
| 253 |
-
x = torch.linspace(0.5 / width, 1.0 - 0.5 / width, width, device=device)
|
| 254 |
-
yy, xx = torch.meshgrid(y, x, indexing="ij")
|
| 255 |
-
return torch.stack([xx, yy], dim=-1).reshape(-1, 2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip/training/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
__all__: list[str] = []
|
|
|
|
|
|
hyper3_clip/training/distributed.py
DELETED
|
@@ -1,149 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from collections.abc import Sequence
|
| 4 |
-
import os
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
import torch.distributed as dist
|
| 8 |
-
from torch.distributed.nn import all_gather as differentiable_all_gather
|
| 9 |
-
from torch import Tensor
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
def init_distributed() -> None:
|
| 13 |
-
if "RANK" in os.environ and "WORLD_SIZE" in os.environ and not dist.is_initialized():
|
| 14 |
-
backend = "nccl" if torch.cuda.is_available() else "gloo"
|
| 15 |
-
if torch.cuda.is_available():
|
| 16 |
-
torch.cuda.set_device(get_local_rank())
|
| 17 |
-
dist.init_process_group(backend=backend)
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def is_distributed() -> bool:
|
| 21 |
-
return dist.is_available() and dist.is_initialized()
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def barrier() -> None:
|
| 25 |
-
if is_distributed():
|
| 26 |
-
dist.barrier()
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def destroy_distributed() -> None:
|
| 30 |
-
if is_distributed():
|
| 31 |
-
dist.destroy_process_group()
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def get_rank() -> int:
|
| 35 |
-
return dist.get_rank() if is_distributed() else 0
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
def get_world_size() -> int:
|
| 39 |
-
return dist.get_world_size() if is_distributed() else 1
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def get_local_rank() -> int:
|
| 43 |
-
return int(os.environ.get("LOCAL_RANK", "0"))
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
def is_main_process() -> bool:
|
| 47 |
-
return get_rank() == 0
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def gather_with_grad(tensor: Tensor) -> Tensor:
|
| 51 |
-
world_size = get_world_size()
|
| 52 |
-
if world_size == 1:
|
| 53 |
-
return tensor
|
| 54 |
-
return torch.cat(list(differentiable_all_gather(tensor.contiguous())), dim=0)
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
def gather_variable_with_grad(tensor: Tensor) -> tuple[Tensor, Tensor]:
|
| 58 |
-
"""Gather tensors with variable first-dimension lengths across ranks."""
|
| 59 |
-
count_tensor, max_count, keep = _variable_gather_metadata(tensor)
|
| 60 |
-
if get_world_size() == 1:
|
| 61 |
-
return tensor, count_tensor
|
| 62 |
-
return _gather_variable_from_metadata(tensor, max_count, keep), count_tensor
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
def gather_variable_many_with_grad(tensors: Sequence[Tensor]) -> tuple[list[Tensor], Tensor]:
|
| 66 |
-
"""Gather same-length variable tensors while sharing count metadata.
|
| 67 |
-
|
| 68 |
-
Tensors with matching dtype/rank/trailing shape are packed along the last
|
| 69 |
-
dimension so a single differentiable all-gather can serve several feature
|
| 70 |
-
tensors with the same variable first dimension.
|
| 71 |
-
"""
|
| 72 |
-
if not tensors:
|
| 73 |
-
raise ValueError("gather_variable_many_with_grad requires at least one tensor")
|
| 74 |
-
first = tensors[0]
|
| 75 |
-
for tensor in tensors:
|
| 76 |
-
if tensor.device != first.device:
|
| 77 |
-
raise ValueError("all tensors must be on the same device")
|
| 78 |
-
if tensor.shape[0] != first.shape[0]:
|
| 79 |
-
raise ValueError("all tensors must have the same first dimension")
|
| 80 |
-
count_tensor, max_count, keep = _variable_gather_metadata(first)
|
| 81 |
-
if get_world_size() == 1:
|
| 82 |
-
return list(tensors), count_tensor
|
| 83 |
-
|
| 84 |
-
gathered: list[Tensor | None] = [None] * len(tensors)
|
| 85 |
-
groups: dict[tuple[torch.dtype, torch.Size, int], list[int]] = {}
|
| 86 |
-
for index, tensor in enumerate(tensors):
|
| 87 |
-
if tensor.dim() == 0:
|
| 88 |
-
raise ValueError("variable gather tensors must have at least one dimension")
|
| 89 |
-
key = (tensor.dtype, tensor.shape[1:-1], tensor.dim()) if tensor.dim() > 1 else (tensor.dtype, torch.Size(), 1)
|
| 90 |
-
groups.setdefault(key, []).append(index)
|
| 91 |
-
|
| 92 |
-
for indices in groups.values():
|
| 93 |
-
group_tensors = [tensors[index] for index in indices]
|
| 94 |
-
if len(group_tensors) == 1 or group_tensors[0].dim() == 1:
|
| 95 |
-
for index, tensor in zip(indices, group_tensors, strict=True):
|
| 96 |
-
gathered[index] = _gather_variable_from_metadata(tensor, max_count, keep)
|
| 97 |
-
continue
|
| 98 |
-
widths = [tensor.shape[-1] for tensor in group_tensors]
|
| 99 |
-
packed = torch.cat(group_tensors, dim=-1)
|
| 100 |
-
gathered_packed = _gather_variable_from_metadata(packed, max_count, keep)
|
| 101 |
-
for index, chunk in zip(indices, gathered_packed.split(widths, dim=-1), strict=True):
|
| 102 |
-
gathered[index] = chunk
|
| 103 |
-
|
| 104 |
-
if any(tensor is None for tensor in gathered):
|
| 105 |
-
raise RuntimeError("internal error while gathering variable tensors")
|
| 106 |
-
return [tensor for tensor in gathered if tensor is not None], count_tensor
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
def gather_variable_no_grad(tensor: Tensor) -> tuple[Tensor, Tensor]:
|
| 110 |
-
"""Gather variable-length tensors that do not require autograd."""
|
| 111 |
-
count_tensor, max_count, keep = _variable_gather_metadata(tensor)
|
| 112 |
-
if get_world_size() == 1:
|
| 113 |
-
return tensor, count_tensor
|
| 114 |
-
padded = tensor.new_zeros((max_count, *tensor.shape[1:]))
|
| 115 |
-
padded[: tensor.shape[0]] = tensor
|
| 116 |
-
gathered = [torch.zeros_like(padded) for _ in range(get_world_size())]
|
| 117 |
-
dist.all_gather(gathered, padded.contiguous())
|
| 118 |
-
return torch.cat(gathered, dim=0)[keep], count_tensor
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
def _variable_gather_metadata(tensor: Tensor) -> tuple[Tensor, int, Tensor]:
|
| 122 |
-
world_size = get_world_size()
|
| 123 |
-
local_count = torch.tensor([tensor.shape[0]], device=tensor.device, dtype=torch.long)
|
| 124 |
-
if world_size == 1:
|
| 125 |
-
keep = torch.ones(tensor.shape[0], device=tensor.device, dtype=torch.bool)
|
| 126 |
-
return local_count, tensor.shape[0], keep
|
| 127 |
-
|
| 128 |
-
counts = [torch.zeros_like(local_count) for _ in range(world_size)]
|
| 129 |
-
dist.all_gather(counts, local_count)
|
| 130 |
-
count_tensor = torch.cat(counts)
|
| 131 |
-
max_count = int(count_tensor.max().item())
|
| 132 |
-
keep = torch.zeros(world_size * max_count, device=tensor.device, dtype=torch.bool)
|
| 133 |
-
for rank, count in enumerate(count_tensor.tolist()):
|
| 134 |
-
start = rank * max_count
|
| 135 |
-
keep[start : start + count] = True
|
| 136 |
-
return count_tensor, max_count, keep
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
def _gather_variable_from_metadata(tensor: Tensor, max_count: int, keep: Tensor) -> Tensor:
|
| 140 |
-
padded_shape = (max_count, *tensor.shape[1:])
|
| 141 |
-
padded = tensor.new_zeros(padded_shape)
|
| 142 |
-
padded[: tensor.shape[0]] = tensor
|
| 143 |
-
|
| 144 |
-
gathered = torch.cat(list(differentiable_all_gather(padded.contiguous())), dim=0)
|
| 145 |
-
return gathered[keep]
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
def local_target_indices(batch_size: int, device: torch.device) -> Tensor:
|
| 149 |
-
return torch.arange(batch_size, device=device) + batch_size * get_rank()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyper3_clip_provider.py
DELETED
|
@@ -1,115 +0,0 @@
|
|
| 1 |
-
"""HyperView embedding provider for the Hyper3-CLIP v0.5 HF checkpoint."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import os
|
| 6 |
-
from pathlib import Path
|
| 7 |
-
from typing import Any
|
| 8 |
-
|
| 9 |
-
import numpy as np
|
| 10 |
-
import torch
|
| 11 |
-
import yaml
|
| 12 |
-
from huggingface_hub import snapshot_download
|
| 13 |
-
from lancedb.embeddings import EmbeddingFunction
|
| 14 |
-
from pydantic import PrivateAttr
|
| 15 |
-
from safetensors.torch import load_file
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class Hyper3ClipEmbeddings(EmbeddingFunction):
|
| 19 |
-
"""Image embeddings from Hyper3-CLIP v0.5 in Lorentz/hyperboloid space."""
|
| 20 |
-
|
| 21 |
-
name: str = "hyper3labs/hyper3-clip-v0.5"
|
| 22 |
-
batch_size: int = 8
|
| 23 |
-
device: str = "cpu"
|
| 24 |
-
|
| 25 |
-
_model: Any = PrivateAttr(default=None)
|
| 26 |
-
_transform: Any = PrivateAttr(default=None)
|
| 27 |
-
|
| 28 |
-
@property
|
| 29 |
-
def geometry(self) -> str:
|
| 30 |
-
return "hyperboloid"
|
| 31 |
-
|
| 32 |
-
@property
|
| 33 |
-
def curvature(self) -> float:
|
| 34 |
-
self._ensure_model()
|
| 35 |
-
return float(self._model._kappa().detach().cpu().reshape(-1)[0].item())
|
| 36 |
-
|
| 37 |
-
def ndims(self) -> int:
|
| 38 |
-
return 513
|
| 39 |
-
|
| 40 |
-
def _ensure_model(self) -> None:
|
| 41 |
-
if self._model is not None:
|
| 42 |
-
return
|
| 43 |
-
|
| 44 |
-
from hyper3_clip import Hyper3CLIP
|
| 45 |
-
from torchvision import transforms
|
| 46 |
-
|
| 47 |
-
token = os.environ.get("HF_TOKEN")
|
| 48 |
-
local_dir = snapshot_download(
|
| 49 |
-
self.name,
|
| 50 |
-
allow_patterns=["config.yaml", "model.safetensors"],
|
| 51 |
-
token=token,
|
| 52 |
-
)
|
| 53 |
-
root = Path(local_dir)
|
| 54 |
-
config = yaml.safe_load((root / "config.yaml").read_text(encoding="utf-8"))
|
| 55 |
-
|
| 56 |
-
model = Hyper3CLIP(**config["model"])
|
| 57 |
-
state = load_file(root / "model.safetensors", device="cpu")
|
| 58 |
-
model.load_state_dict(state)
|
| 59 |
-
model.to(torch.device(self.device))
|
| 60 |
-
model.eval()
|
| 61 |
-
|
| 62 |
-
self._model = model
|
| 63 |
-
image_size = int(config.get("data", {}).get("image_size", 224))
|
| 64 |
-
self._transform = transforms.Compose(
|
| 65 |
-
[
|
| 66 |
-
transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC),
|
| 67 |
-
transforms.CenterCrop(image_size),
|
| 68 |
-
transforms.ToTensor(),
|
| 69 |
-
transforms.Normalize(
|
| 70 |
-
mean=(0.485, 0.456, 0.406),
|
| 71 |
-
std=(0.229, 0.224, 0.225),
|
| 72 |
-
),
|
| 73 |
-
]
|
| 74 |
-
)
|
| 75 |
-
|
| 76 |
-
def compute_source_embeddings(
|
| 77 |
-
self,
|
| 78 |
-
inputs: Any,
|
| 79 |
-
*args: Any,
|
| 80 |
-
**kwargs: Any,
|
| 81 |
-
) -> list[np.ndarray | None]:
|
| 82 |
-
from PIL import Image
|
| 83 |
-
from hyperview.core.sample import Sample
|
| 84 |
-
|
| 85 |
-
self._ensure_model()
|
| 86 |
-
device = torch.device(self.device)
|
| 87 |
-
images = []
|
| 88 |
-
for item in self.sanitize_input(inputs):
|
| 89 |
-
if isinstance(item, Sample):
|
| 90 |
-
with item.load_image() as img:
|
| 91 |
-
images.append(img.convert("RGB"))
|
| 92 |
-
elif isinstance(item, str):
|
| 93 |
-
with Image.open(item) as img:
|
| 94 |
-
images.append(img.convert("RGB"))
|
| 95 |
-
elif isinstance(item, Image.Image):
|
| 96 |
-
images.append(item.convert("RGB"))
|
| 97 |
-
else:
|
| 98 |
-
raise TypeError(f"Unsupported input type: {type(item)}")
|
| 99 |
-
|
| 100 |
-
outputs: list[np.ndarray | None] = []
|
| 101 |
-
with torch.inference_mode():
|
| 102 |
-
for start in range(0, len(images), self.batch_size):
|
| 103 |
-
batch = images[start:start + self.batch_size]
|
| 104 |
-
tensor = torch.stack([self._transform(image) for image in batch]).to(device)
|
| 105 |
-
encoded = self._model.encode_image(tensor).detach().cpu().numpy().astype(np.float32)
|
| 106 |
-
outputs.extend(encoded)
|
| 107 |
-
return outputs
|
| 108 |
-
|
| 109 |
-
def compute_query_embeddings(
|
| 110 |
-
self,
|
| 111 |
-
query: Any,
|
| 112 |
-
*args: Any,
|
| 113 |
-
**kwargs: Any,
|
| 114 |
-
) -> list[np.ndarray | None]:
|
| 115 |
-
return self.compute_source_embeddings([query], *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|