github-actions[bot] committed on
Commit
ccf4b11
·
1 Parent(s): 1779c3f

Deploy hyper3labs/HyperView-ABO-Catalog from Hyper3Labs/hyperview-spaces@fd3578c

Browse files
Dockerfile CHANGED
@@ -20,7 +20,8 @@ WORKDIR $HOME/app
20
 
21
  RUN pip install --upgrade pip
22
 
23
- ARG HYPERVIEW_VERSION=0.6.0
 
24
 
25
  # Install CPU-only PyTorch first so the Space does not pull the default CUDA bundle.
26
  RUN pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
@@ -33,14 +34,9 @@ import hyperview as hv
33
  print("hyperview", hv.__version__, inspect.signature(hv.launch))
34
  PY
35
  RUN pip install \
 
36
  "datasets>=4.5.0" \
37
- "Pillow>=12.0.0" \
38
- "timm>=1.0.0" \
39
- "transformers==4.49.0" \
40
- "safetensors>=0.4.0" \
41
- "pyyaml>=6.0.0" \
42
- "sentencepiece>=0.2.0" \
43
- "protobuf>=4.25.0"
44
 
45
  COPY --chown=user . .
46
 
 
20
 
21
  RUN pip install --upgrade pip
22
 
23
+ ARG HYPERVIEW_VERSION=0.6.1
24
+ ARG HYPER_MODELS_VERSION=0.3.0
25
 
26
  # Install CPU-only PyTorch first so the Space does not pull the default CUDA bundle.
27
  RUN pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
 
34
  print("hyperview", hv.__version__, inspect.signature(hv.launch))
35
  PY
36
  RUN pip install \
37
+ "hyper-models[ml]==${HYPER_MODELS_VERSION}" \
38
  "datasets>=4.5.0" \
39
+ "Pillow>=12.0.0"
 
 
 
 
 
 
40
 
41
  COPY --chown=user . .
42
 
README.md CHANGED
@@ -14,7 +14,7 @@ This demo builds a small Amazon Berkeley Objects product-catalog subset and open
14
  HyperView with two pinned scatter panels plus a comparison readout:
15
 
16
  - CLIP ViT-B/32 in a Euclidean 2D layout
17
- - Hyper3-CLIP `hyper3labs/hyper3-clip-v0.5` in a Poincare 2D layout
18
 
19
  The right-side panel uses fixed product examples to compare nearest-neighbor
20
  behavior for the same query under each model.
@@ -45,7 +45,7 @@ variables or edit the second entry in `MODEL_SPECS`:
45
 
46
  ```bash
47
  ABO_CANDIDATE_DISPLAY_NAME="New Model" \
48
- ABO_CANDIDATE_PROVIDER="hyper3-clip" \
49
  ABO_CANDIDATE_MODEL="new-model-id" \
50
  ABO_CANDIDATE_LAYOUT="poincare:2d" \
51
  ABO_CANDIDATE_GEOMETRY="poincare" \
@@ -61,10 +61,11 @@ JavaScript.
61
  This folder is intended to deploy to `hyper3labs/HyperView-ABO-Catalog` from
62
  the `hyperview-spaces` deployment repository.
63
 
64
- The Dockerfile installs `hyperview==0.6.0` from PyPI. The released HyperView
65
- wheel includes the built frontend assets, so this Space does not carry a local
66
- `static/` bundle or copy frontend files into the installed package.
 
67
 
68
- Hyper3-CLIP weights are loaded from the gated
69
- `hyper3labs/hyper3-clip-v0.5` model repository at runtime. The Space needs an
70
- `HF_TOKEN` secret with access to that model.
 
14
  HyperView with two pinned scatter panels plus a comparison readout:
15
 
16
  - CLIP ViT-B/32 in a Euclidean 2D layout
17
+ - Hyper3-CLIP `hyper3-clip-v0.5` from `hyper-models` in a Poincare 2D layout
18
 
19
  The right-side panel uses fixed product examples to compare nearest-neighbor
20
  behavior for the same query under each model.
 
45
 
46
  ```bash
47
  ABO_CANDIDATE_DISPLAY_NAME="New Model" \
48
+ ABO_CANDIDATE_PROVIDER="hyper-models" \
49
  ABO_CANDIDATE_MODEL="new-model-id" \
50
  ABO_CANDIDATE_LAYOUT="poincare:2d" \
51
  ABO_CANDIDATE_GEOMETRY="poincare" \
 
61
  This folder is intended to deploy to `hyper3labs/HyperView-ABO-Catalog` from
62
  the `hyperview-spaces` deployment repository.
63
 
64
+ The Dockerfile installs `hyperview==0.6.1` and `hyper-models[ml]==0.3.0` from
65
+ PyPI. The released HyperView wheel includes the built frontend assets, so this
66
+ Space does not carry a local `static/` bundle or copy frontend files into the
67
+ installed package.
68
 
69
+ Hyper3-CLIP weights are loaded through the `hyper-models` catalog entry for the
70
+ gated `hyper3labs/hyper3-clip-v0.5` model repository at runtime. The Space needs
71
+ an `HF_TOKEN` secret with access to that model.
demo.py CHANGED
@@ -64,8 +64,8 @@ MODEL_SPECS = [
64
  "key": "candidate",
65
  "display_name": os.environ.get("ABO_CANDIDATE_DISPLAY_NAME", "Hyper3-CLIP"),
66
  "button_label": os.environ.get("ABO_CANDIDATE_BUTTON_LABEL", "Hyper3-CLIP query"),
67
- "provider": os.environ.get("ABO_CANDIDATE_PROVIDER", "hyper3-clip"),
68
- "model": os.environ.get("ABO_CANDIDATE_MODEL", "hyper3labs/hyper3-clip-v0.5"),
69
  "layout": os.environ.get("ABO_CANDIDATE_LAYOUT", "poincare:2d"),
70
  "geometry": os.environ.get("ABO_CANDIDATE_GEOMETRY", "poincare"),
71
  "layout_dimension": int(os.environ.get("ABO_CANDIDATE_LAYOUT_DIMENSION", "2")),
@@ -341,17 +341,6 @@ def supported_kwargs(func: Any, kwargs: dict[str, Any]) -> dict[str, Any]:
341
  return {key: value for key, value in kwargs.items() if key in params}
342
 
343
 
344
- def register_hyper3_clip_provider() -> None:
345
- from hyperview.runtime import ProviderRegistry
346
-
347
- ProviderRegistry().register_python(
348
- "hyper3-clip",
349
- "hyper3_clip_provider:Hyper3ClipEmbeddings",
350
- description="Hyper3-CLIP v0.5 image embeddings from hyper3labs/hyper3-clip-v0.5",
351
- overwrite=True,
352
- )
353
-
354
-
355
  def api_base_url() -> str:
356
  host = "127.0.0.1" if SPACE_HOST == "0.0.0.0" else SPACE_HOST
357
  return f"http://{host}:{SPACE_PORT}"
@@ -500,7 +489,6 @@ def launch_demo(dataset: hv.Dataset, layouts: dict[str, str]) -> hv.Session:
500
 
501
 
502
  def main() -> None:
503
- register_hyper3_clip_provider()
504
  dataset, layouts = build_dataset()
505
  print("Layouts:", flush=True)
506
  for spec in MODEL_SPECS:
 
64
  "key": "candidate",
65
  "display_name": os.environ.get("ABO_CANDIDATE_DISPLAY_NAME", "Hyper3-CLIP"),
66
  "button_label": os.environ.get("ABO_CANDIDATE_BUTTON_LABEL", "Hyper3-CLIP query"),
67
+ "provider": os.environ.get("ABO_CANDIDATE_PROVIDER", "hyper-models"),
68
+ "model": os.environ.get("ABO_CANDIDATE_MODEL", "hyper3-clip-v0.5"),
69
  "layout": os.environ.get("ABO_CANDIDATE_LAYOUT", "poincare:2d"),
70
  "geometry": os.environ.get("ABO_CANDIDATE_GEOMETRY", "poincare"),
71
  "layout_dimension": int(os.environ.get("ABO_CANDIDATE_LAYOUT_DIMENSION", "2")),
 
341
  return {key: value for key, value in kwargs.items() if key in params}
342
 
343
 
 
 
 
 
 
 
 
 
 
 
 
344
  def api_base_url() -> str:
345
  host = "127.0.0.1" if SPACE_HOST == "0.0.0.0" else SPACE_HOST
346
  return f"http://{host}:{SPACE_PORT}"
 
489
 
490
 
491
  def main() -> None:
 
492
  dataset, layouts = build_dataset()
493
  print("Layouts:", flush=True)
494
  for spec in MODEL_SPECS:
hyper3_clip/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from hyper3_clip.models.hyper3_clip import Hyper3CLIP
2
-
3
- __all__ = ["Hyper3CLIP"]
 
 
 
 
hyper3_clip/models/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from hyper3_clip.models.hyper3_clip import Hyper3CLIP
2
-
3
- __all__ = ["Hyper3CLIP"]
 
 
 
 
hyper3_clip/models/encoders.py DELETED
@@ -1,173 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import timm
4
- import torch
5
- from torch import nn
6
- from transformers import (
7
- AutoConfig,
8
- AutoModel,
9
- AutoTokenizer,
10
- CLIPTextConfig,
11
- CLIPTextModel,
12
- CLIPTextModelWithProjection,
13
- CLIPVisionConfig,
14
- CLIPVisionModel,
15
- CLIPVisionModelWithProjection,
16
- SiglipTextConfig,
17
- SiglipTextModel,
18
- SiglipVisionConfig,
19
- SiglipVisionModel,
20
- )
21
-
22
-
23
- class VisionEncoder(nn.Module):
24
- def __init__(self, backbone_name: str, pretrained: bool = True) -> None:
25
- super().__init__()
26
- self.kind = "timm"
27
- if backbone_name.startswith("hf_clip_projected:"):
28
- self.kind = "hf_clip_projected"
29
- model_name = backbone_name.removeprefix("hf_clip_projected:")
30
- self.backbone = (
31
- CLIPVisionModelWithProjection.from_pretrained(model_name)
32
- if pretrained
33
- else CLIPVisionModelWithProjection(CLIPVisionConfig.from_pretrained(model_name))
34
- )
35
- self.output_dim = self.backbone.config.projection_dim
36
- elif backbone_name.startswith("hf_clip:"):
37
- self.kind = "hf_vision"
38
- model_name = backbone_name.removeprefix("hf_clip:")
39
- self.backbone = (
40
- CLIPVisionModel.from_pretrained(model_name)
41
- if pretrained
42
- else CLIPVisionModel(CLIPVisionConfig.from_pretrained(model_name))
43
- )
44
- self.output_dim = self.backbone.config.hidden_size
45
- elif backbone_name.startswith("hf_siglip:"):
46
- self.kind = "hf_vision"
47
- model_name = backbone_name.removeprefix("hf_siglip:")
48
- self.backbone = (
49
- SiglipVisionModel.from_pretrained(model_name)
50
- if pretrained
51
- else SiglipVisionModel(SiglipVisionConfig.from_pretrained(model_name))
52
- )
53
- self.output_dim = self.backbone.config.hidden_size
54
- else:
55
- self.backbone = timm.create_model(
56
- backbone_name,
57
- pretrained=pretrained,
58
- num_classes=0,
59
- global_pool="avg",
60
- )
61
- self.output_dim = self.backbone.num_features
62
-
63
- def forward(self, image: torch.Tensor) -> torch.Tensor:
64
- if self.kind == "hf_clip_projected":
65
- return self.backbone(pixel_values=image).image_embeds
66
- if self.kind == "hf_vision":
67
- out = self.backbone(pixel_values=image)
68
- if hasattr(out, "pooler_output") and out.pooler_output is not None:
69
- return out.pooler_output
70
- return out.last_hidden_state[:, 0]
71
- return self.backbone(image)
72
-
73
- def forward_with_tokens(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
74
- if self.kind == "hf_clip_projected":
75
- out = self.backbone(pixel_values=image)
76
- tokens = getattr(out, "last_hidden_state", None)
77
- if tokens is None and hasattr(out, "vision_model_output"):
78
- tokens = out.vision_model_output.last_hidden_state
79
- if tokens is None:
80
- raise RuntimeError("Projected CLIP vision output did not include patch tokens")
81
- return out.image_embeds, tokens
82
- if self.kind == "hf_vision":
83
- out = self.backbone(pixel_values=image)
84
- if hasattr(out, "pooler_output") and out.pooler_output is not None:
85
- pooled = out.pooler_output
86
- else:
87
- pooled = out.last_hidden_state[:, 0]
88
- return pooled, out.last_hidden_state
89
-
90
- if not hasattr(self.backbone, "forward_features"):
91
- pooled = self.backbone(image)
92
- return pooled, pooled[:, None, :]
93
- features = self.backbone.forward_features(image)
94
- if hasattr(self.backbone, "forward_head"):
95
- pooled = self.backbone.forward_head(features, pre_logits=False)
96
- else:
97
- pooled = self.backbone(image)
98
- return pooled, _tokens_from_features(features)
99
-
100
-
101
- class TextEncoder(nn.Module):
102
- def __init__(self, model_name: str, pretrained: bool = True, pooling: str = "auto") -> None:
103
- super().__init__()
104
- if pooling not in {"auto", "pooler", "cls", "mean"}:
105
- raise ValueError(f"Unsupported text pooling {pooling!r}; expected auto, pooler, cls, or mean")
106
- self.kind = "hf_text"
107
- self.pooling = pooling
108
- tokenizer_name = model_name.removeprefix("hf_clip_projected:")
109
- self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
110
- model_name_lower = model_name.lower()
111
- if model_name.startswith("hf_clip_projected:"):
112
- self.kind = "hf_clip_projected"
113
- projected_model_name = model_name.removeprefix("hf_clip_projected:")
114
- if pretrained:
115
- self.backbone = CLIPTextModelWithProjection.from_pretrained(projected_model_name)
116
- else:
117
- self.backbone = CLIPTextModelWithProjection(CLIPTextConfig.from_pretrained(projected_model_name))
118
- self.output_dim = self.backbone.config.projection_dim
119
- elif "siglip" in model_name_lower:
120
- if pretrained:
121
- self.backbone = SiglipTextModel.from_pretrained(model_name)
122
- else:
123
- self.backbone = SiglipTextModel(SiglipTextConfig.from_pretrained(model_name))
124
- self.output_dim = self.backbone.config.hidden_size
125
- elif "clip" in model_name_lower:
126
- if pretrained:
127
- self.backbone = CLIPTextModel.from_pretrained(model_name)
128
- else:
129
- self.backbone = CLIPTextModel(CLIPTextConfig.from_pretrained(model_name))
130
- self.output_dim = self.backbone.config.hidden_size
131
- else:
132
- if pretrained:
133
- self.backbone = AutoModel.from_pretrained(model_name)
134
- else:
135
- self.backbone = AutoModel.from_config(AutoConfig.from_pretrained(model_name))
136
- hidden_size = getattr(self.backbone.config, "hidden_size", None)
137
- if hidden_size is None:
138
- raise ValueError(f"Unsupported text model config for {model_name}")
139
- self.output_dim = hidden_size
140
-
141
- def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
142
- out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
143
- if self.kind == "hf_clip_projected":
144
- return out.text_embeds
145
- if self.pooling == "mean":
146
- mask = attention_mask.to(dtype=out.last_hidden_state.dtype).unsqueeze(-1)
147
- summed = (out.last_hidden_state * mask).sum(dim=1)
148
- denom = mask.sum(dim=1).clamp_min(1.0)
149
- return summed / denom
150
- if self.pooling in {"auto", "pooler"} and hasattr(out, "pooler_output") and out.pooler_output is not None:
151
- return out.pooler_output
152
- return out.last_hidden_state[:, 0]
153
-
154
-
155
- def _tokens_from_features(features: torch.Tensor | dict | tuple | list) -> torch.Tensor:
156
- if isinstance(features, dict):
157
- for key in ("x", "last_hidden_state", "features"):
158
- if key in features:
159
- features = features[key]
160
- break
161
- else:
162
- features = next(iter(features.values()))
163
- if isinstance(features, tuple | list):
164
- features = features[0]
165
- if not torch.is_tensor(features):
166
- raise TypeError(f"Expected tensor features, got {type(features)!r}")
167
- if features.ndim == 4:
168
- return features.flatten(2).transpose(1, 2)
169
- if features.ndim == 3:
170
- return features
171
- if features.ndim == 2:
172
- return features[:, None, :]
173
- raise ValueError(f"Unsupported feature tensor shape {tuple(features.shape)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/models/experimental.py DELETED
@@ -1,587 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from collections.abc import Callable
4
-
5
- import torch
6
- import torch.nn.functional as F
7
- from torch import Tensor, nn
8
-
9
- from hyper3_clip.models.lorentz import exp_map0, metric_pairwise_dist
10
- from hyper3_clip.models.losses import beta_cal_loss
11
- from hyper3_clip.models.tren import TRENRegionEncoder
12
- from hyper3_clip.training.distributed import gather_variable_with_grad, gather_with_grad, get_rank
13
-
14
-
15
- ProjectionHeadFactory = Callable[[int, int, int | None], nn.Module]
16
-
17
-
18
- class ExperimentalObjectiveMixin:
19
- @staticmethod
20
- def _validate_experimental_options(
21
- *,
22
- proclip_geometry: str,
23
- proclip_projection_hidden_dim: int | None,
24
- proclip_component_dim: int | None,
25
- beta_clip_weight: float,
26
- beta_clip_global_weight: float,
27
- beta_clip_beta: float,
28
- beta_clip_variant: str,
29
- beta_clip_similarity: str,
30
- beta_clip_num_heads: int,
31
- beta_clip_mlp_ratio: float,
32
- tren_weight: float,
33
- tren_visual_distill_weight: float,
34
- tren_text_distill_weight: float,
35
- tren_region_text_weight: float,
36
- tren_num_region_tokens: int,
37
- tren_num_decoder_layers: int,
38
- tren_num_attention_heads: int,
39
- tren_prompt_grid_size: int,
40
- tren_dropout: float,
41
- ) -> None:
42
- if proclip_geometry not in {"product", "hyperbolic", "euclidean", "spherical", "clip"}:
43
- raise ValueError("proclip_geometry must be 'product', 'hyperbolic', 'euclidean', 'spherical', or 'clip'")
44
- if proclip_projection_hidden_dim is not None and proclip_projection_hidden_dim <= 0:
45
- raise ValueError("proclip_projection_hidden_dim must be positive when set")
46
- if proclip_component_dim is not None and proclip_component_dim <= 0:
47
- raise ValueError("proclip_component_dim must be positive when set")
48
- if beta_clip_variant not in {"ce", "bce"}:
49
- raise ValueError("beta_clip_variant must be 'ce' or 'bce'")
50
- if beta_clip_similarity not in {"metric", "dot"}:
51
- raise ValueError("beta_clip_similarity must be 'metric' or 'dot'")
52
- if beta_clip_weight < 0.0:
53
- raise ValueError("beta_clip_weight must be non-negative")
54
- if beta_clip_global_weight < 0.0:
55
- raise ValueError("beta_clip_global_weight must be non-negative")
56
- if beta_clip_beta < 0.0:
57
- raise ValueError("beta_clip_beta must be non-negative")
58
- if beta_clip_num_heads <= 0:
59
- raise ValueError("beta_clip_num_heads must be positive")
60
- if beta_clip_mlp_ratio <= 0.0:
61
- raise ValueError("beta_clip_mlp_ratio must be positive")
62
- if tren_weight < 0.0:
63
- raise ValueError("tren_weight must be non-negative")
64
- if tren_visual_distill_weight < 0.0 or tren_text_distill_weight < 0.0 or tren_region_text_weight < 0.0:
65
- raise ValueError("T-REN loss weights must be non-negative")
66
- if tren_num_region_tokens <= 0:
67
- raise ValueError("tren_num_region_tokens must be positive")
68
- if tren_num_decoder_layers <= 0:
69
- raise ValueError("tren_num_decoder_layers must be positive")
70
- if tren_num_attention_heads <= 0:
71
- raise ValueError("tren_num_attention_heads must be positive")
72
- if tren_prompt_grid_size <= 0:
73
- raise ValueError("tren_prompt_grid_size must be positive")
74
- if tren_dropout < 0.0:
75
- raise ValueError("tren_dropout must be non-negative")
76
-
77
- def _init_experimental_modules(
78
- self,
79
- *,
80
- beta_clip_num_heads: int,
81
- beta_clip_mlp_ratio: float,
82
- tren_num_region_tokens: int,
83
- tren_num_decoder_layers: int,
84
- tren_num_attention_heads: int,
85
- tren_prompt_grid_size: int,
86
- tren_dropout: float,
87
- projection_hidden_dim: int | None,
88
- proclip_projection_hidden_dim: int | None,
89
- projection_head: ProjectionHeadFactory,
90
- ) -> None:
91
- if self.beta_query_pooling_enabled:
92
- if self.vision_encoder.output_dim % beta_clip_num_heads != 0:
93
- raise ValueError("vision encoder output_dim must be divisible by beta_clip_num_heads")
94
- beta_clip_hidden_dim = max(1, int(round(self.vision_encoder.output_dim * beta_clip_mlp_ratio)))
95
- self.beta_clip_text_query_proj = nn.Linear(self.text_encoder.output_dim, self.vision_encoder.output_dim)
96
- self.beta_clip_cross_attention = nn.MultiheadAttention(
97
- self.vision_encoder.output_dim,
98
- beta_clip_num_heads,
99
- batch_first=True,
100
- )
101
- self.beta_clip_mlp_norm = nn.LayerNorm(self.vision_encoder.output_dim)
102
- self.beta_clip_pool_mlp = nn.Sequential(
103
- nn.Linear(self.vision_encoder.output_dim, beta_clip_hidden_dim),
104
- nn.GELU(),
105
- nn.Linear(beta_clip_hidden_dim, self.vision_encoder.output_dim),
106
- )
107
- if self.beta_clip_enabled:
108
- self.beta_clip_logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log())
109
- if self.tren_enabled:
110
- self.tren_region_encoder = TRENRegionEncoder(
111
- vision_dim=self.vision_encoder.output_dim,
112
- text_dim=self.text_encoder.output_dim,
113
- num_region_tokens=tren_num_region_tokens,
114
- num_decoder_layers=tren_num_decoder_layers,
115
- num_attention_heads=tren_num_attention_heads,
116
- prompt_grid_size=tren_prompt_grid_size,
117
- dropout=tren_dropout,
118
- )
119
- self.tren_logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log())
120
- if self.proclip_enabled:
121
- component_dim = self._proclip_component_dim
122
- spherical_dim = self._proclip_spherical_ambient_dim
123
- proclip_hidden_dim = proclip_projection_hidden_dim
124
- if proclip_hidden_dim is None:
125
- proclip_hidden_dim = projection_hidden_dim
126
- if self.proclip_dedicated_hyperbolic:
127
- self.proclip_image_hyperbolic_proj = projection_head(
128
- self.vision_encoder.output_dim, self.embed_dim, proclip_hidden_dim
129
- )
130
- self.proclip_text_hyperbolic_proj = projection_head(
131
- self.text_encoder.output_dim, self.embed_dim, proclip_hidden_dim
132
- )
133
- self.proclip_image_euclidean_proj = projection_head(
134
- self.vision_encoder.output_dim, component_dim, proclip_hidden_dim
135
- )
136
- self.proclip_text_euclidean_proj = projection_head(
137
- self.text_encoder.output_dim, component_dim, proclip_hidden_dim
138
- )
139
- self.proclip_image_spherical_proj = projection_head(
140
- self.vision_encoder.output_dim, spherical_dim, proclip_hidden_dim
141
- )
142
- self.proclip_text_spherical_proj = projection_head(
143
- self.text_encoder.output_dim, spherical_dim, proclip_hidden_dim
144
- )
145
- self.proclip_logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log())
146
- self.proclip_log_weights = nn.Parameter(torch.zeros(3))
147
-
148
- @property
149
- def proclip_enabled(self) -> bool:
150
- return (
151
- self.objective_name == "proclip"
152
- or self.proclip_component_dim is not None
153
- or self.proclip_weight > 0.0
154
- or self.proclip_retrieval
155
- )
156
-
157
- @property
158
- def beta_clip_enabled(self) -> bool:
159
- return self.beta_clip_weight > 0.0
160
-
161
- @property
162
- def beta_query_pooling_enabled(self) -> bool:
163
- return self.beta_clip_enabled or (
164
- self.objective_name == "uncha"
165
- and self.uncha_entailment_loss in {"hier_beta_argent", "hier_beta_sourcepart_argent"}
166
- )
167
-
168
- @property
169
- def tren_enabled(self) -> bool:
170
- return self.tren_weight > 0.0
171
-
172
- @property
173
- def _proclip_component_dim(self) -> int:
174
- return int(self.proclip_component_dim or self.embed_dim)
175
-
176
- @property
177
- def _proclip_spherical_ambient_dim(self) -> int:
178
- return self._proclip_component_dim + 1
179
-
180
- def _clamp_experimental_logit_scales(self) -> None:
181
- if self.proclip_enabled:
182
- self.proclip_logit_scale.clamp_(max=4.6052)
183
- if self.beta_clip_enabled:
184
- self.beta_clip_logit_scale.clamp_(max=4.6052)
185
- if self.tren_enabled:
186
- self.tren_logit_scale.clamp_(max=4.6052)
187
-
188
- def _detached_experimental_logit_scales(self) -> dict[str, torch.Tensor]:
189
- logs = {}
190
- if self.proclip_enabled:
191
- logs.update(self._detached_proclip_logs())
192
- if self.beta_clip_enabled:
193
- logs["beta_clip_logit_scale"] = self.beta_clip_logit_scale.exp().detach()
194
- if self.tren_enabled:
195
- logs["tren_logit_scale"] = self.tren_logit_scale.exp().detach()
196
- return logs
197
-
198
- def _beta_clip_global_contrastive_loss(
199
- self,
200
- *,
201
- image_euc: torch.Tensor,
202
- text_euc: torch.Tensor,
203
- targets: torch.Tensor,
204
- ) -> torch.Tensor:
205
- image_feats = F.normalize(image_euc.float(), dim=-1)
206
- text_feats = F.normalize(text_euc.float(), dim=-1)
207
- all_image_feats = gather_with_grad(image_feats)
208
- all_text_feats = gather_with_grad(text_feats)
209
- if self.objective_name == "hycoclip":
210
- scale = self.logit_scale.exp().clamp(max=100.0)
211
- elif self.objective_name == "proclip":
212
- scale = self.proclip_logit_scale.exp().clamp(max=100.0)
213
- else:
214
- scale = self.global_logit_scale.exp().clamp(max=100.0)
215
- logits_i_t = image_feats @ all_text_feats.T * scale
216
- logits_t_i = text_feats @ all_image_feats.T * scale
217
- return 0.5 * (F.cross_entropy(logits_i_t, targets) + F.cross_entropy(logits_t_i, targets))
218
-
219
- def _beta_query_entailment_embeddings(
220
- self,
221
- *,
222
- image_tokens: torch.Tensor,
223
- beta_query_input_ids: torch.Tensor | None,
224
- beta_query_attention_mask: torch.Tensor | None,
225
- beta_query_owner: torch.Tensor | None,
226
- beta_query_parent: torch.Tensor | None,
227
- beta_query_weight: torch.Tensor | None,
228
- beta_query_source_part: torch.Tensor | None,
229
- kappa: torch.Tensor,
230
- query_base: torch.Tensor | None = None,
231
- ) -> dict[str, torch.Tensor]:
232
- if beta_query_input_ids is None or beta_query_attention_mask is None or beta_query_owner is None:
233
- raise ValueError(f"{self.uncha_entailment_loss} requires beta query tensors from the collator")
234
- if beta_query_parent is None or beta_query_weight is None:
235
- raise ValueError(f"{self.uncha_entailment_loss} requires beta query hierarchy metadata from the collator")
236
- if self.uncha_entailment_loss == "hier_beta_sourcepart_argent" and beta_query_source_part is None:
237
- raise ValueError("hier_beta_sourcepart_argent requires beta_query_source_part from the collator")
238
- if beta_query_input_ids.shape[0] == 0:
239
- source_part = (
240
- beta_query_source_part.to(device=image_tokens.device, dtype=torch.long)
241
- if beta_query_source_part is not None
242
- else beta_query_owner.new_zeros((0,), device=image_tokens.device, dtype=torch.long)
243
- )
244
- return {
245
- "beta_query_image_feats": image_tokens.new_zeros((0, self.embed_dim)),
246
- "beta_query_text_feats": image_tokens.new_zeros((0, self.embed_dim)),
247
- "beta_query_owner": beta_query_owner.to(device=image_tokens.device, dtype=torch.long),
248
- "beta_query_parent": beta_query_parent.to(device=image_tokens.device, dtype=torch.long),
249
- "beta_query_weight": beta_query_weight.to(device=image_tokens.device, dtype=torch.float32),
250
- "beta_query_source_part": source_part,
251
- }
252
-
253
- query_owner = beta_query_owner.to(device=image_tokens.device, dtype=torch.long)
254
- if query_base is None:
255
- query_base = self.encode_text_base(beta_query_input_ids, beta_query_attention_mask)
256
- conditioned_image_base = self._beta_clip_text_conditioned_pool(image_tokens, query_base, query_owner)
257
- query_image_euc = self.image_proj(conditioned_image_base)
258
- query_text_euc = self.text_proj(query_base)
259
- return {
260
- "beta_query_image_feats": self.project_image_features(query_image_euc),
261
- "beta_query_text_feats": self.project_text_features(query_text_euc),
262
- "beta_query_owner": query_owner,
263
- "beta_query_parent": beta_query_parent.to(device=image_tokens.device, dtype=torch.long),
264
- "beta_query_weight": beta_query_weight.to(device=image_tokens.device, dtype=torch.float32),
265
- **(
266
- {"beta_query_source_part": beta_query_source_part.to(device=image_tokens.device, dtype=torch.long)}
267
- if beta_query_source_part is not None
268
- else {}
269
- ),
270
- }
271
-
272
- def _beta_clip_auxiliary_loss(
273
- self,
274
- *,
275
- image_tokens: torch.Tensor,
276
- beta_query_input_ids: torch.Tensor | None,
277
- beta_query_attention_mask: torch.Tensor | None,
278
- beta_query_owner: torch.Tensor | None,
279
- global_targets: torch.Tensor,
280
- kappa: torch.Tensor,
281
- ) -> torch.Tensor:
282
- if beta_query_input_ids is None or beta_query_attention_mask is None or beta_query_owner is None:
283
- raise ValueError("beta-CLIP auxiliary requires beta query tensors from the collator")
284
- if beta_query_input_ids.shape[0] == 0:
285
- return image_tokens.new_zeros(())
286
-
287
- beta_query_owner = beta_query_owner.to(device=image_tokens.device, dtype=torch.long)
288
- query_base = self.encode_text_base(beta_query_input_ids, beta_query_attention_mask)
289
- conditioned_image_base = self._beta_clip_text_conditioned_pool(image_tokens, query_base, beta_query_owner)
290
- query_image_euc = self.image_proj(conditioned_image_base)
291
- query_text_euc = self.text_proj(query_base)
292
-
293
- if self.beta_clip_similarity == "dot":
294
- query_image_feats = F.normalize(query_image_euc.float(), dim=-1)
295
- query_text_feats = F.normalize(query_text_euc.float(), dim=-1)
296
- else:
297
- query_image_feats = self.project_image_features(query_image_euc)
298
- query_text_feats = self.project_text_features(query_text_euc)
299
-
300
- all_query_image_feats, query_counts = gather_variable_with_grad(query_image_feats)
301
- all_query_text_feats, _ = gather_variable_with_grad(query_text_feats)
302
- query_offset = query_counts[: get_rank()].sum() if query_counts.numel() > 1 else query_counts.new_zeros(())
303
- query_targets = torch.arange(query_image_feats.size(0), device=query_image_feats.device) + query_offset
304
- query_group_ids = global_targets.index_select(0, beta_query_owner)
305
- all_query_group_ids, _ = gather_variable_with_grad(query_group_ids)
306
-
307
- scale = self.beta_clip_logit_scale.exp().clamp(max=100.0)
308
- if self.beta_clip_similarity == "dot":
309
- logits_i_t = query_image_feats @ all_query_text_feats.T * scale
310
- logits_t_i = query_text_feats @ all_query_image_feats.T * scale
311
- else:
312
- logits_i_t = -metric_pairwise_dist(
313
- query_image_feats,
314
- all_query_text_feats,
315
- kappa,
316
- product_metric=self.phyclip_product_metric,
317
- ) * scale
318
- logits_t_i = -metric_pairwise_dist(
319
- query_text_feats,
320
- all_query_image_feats,
321
- kappa,
322
- product_metric=self.phyclip_product_metric,
323
- ) * scale
324
- return 0.5 * (
325
- beta_cal_loss(
326
- logits_i_t,
327
- targets=query_targets,
328
- group_ids=query_group_ids,
329
- all_group_ids=all_query_group_ids,
330
- beta=self.beta_clip_beta,
331
- variant=self.beta_clip_variant,
332
- )
333
- + beta_cal_loss(
334
- logits_t_i,
335
- targets=query_targets,
336
- group_ids=query_group_ids,
337
- all_group_ids=all_query_group_ids,
338
- beta=self.beta_clip_beta,
339
- variant=self.beta_clip_variant,
340
- )
341
- )
342
-
343
- def _beta_clip_text_conditioned_pool(
344
- self,
345
- image_tokens: torch.Tensor,
346
- query_base: torch.Tensor,
347
- query_owner: torch.Tensor,
348
- ) -> torch.Tensor:
349
- if image_tokens.ndim != 3:
350
- raise ValueError("beta-CLIP image tokens must have shape [batch, tokens, dim]")
351
- if getattr(self, "group_beta_query_pooling", False):
352
- return self._beta_clip_text_conditioned_pool_grouped(image_tokens, query_base, query_owner)
353
- if self.beta_clip_drop_cls_token and image_tokens.size(1) > 1:
354
- image_tokens = image_tokens[:, 1:, :]
355
- selected_tokens = image_tokens.index_select(0, query_owner).to(dtype=query_base.dtype)
356
- query = self.beta_clip_text_query_proj(query_base).unsqueeze(1)
357
- attended, _ = self.beta_clip_cross_attention(query, selected_tokens, selected_tokens, need_weights=False)
358
- pooled = attended.squeeze(1)
359
- return pooled + self.beta_clip_pool_mlp(self.beta_clip_mlp_norm(pooled))
360
-
361
- def _beta_clip_text_conditioned_pool_grouped(
362
- self,
363
- image_tokens: torch.Tensor,
364
- query_base: torch.Tensor,
365
- query_owner: torch.Tensor,
366
- ) -> torch.Tensor:
367
- if query_owner.numel() == 0:
368
- return query_base.new_zeros((0, self.vision_encoder.output_dim))
369
- if query_owner.min().item() < 0 or query_owner.max().item() >= image_tokens.size(0):
370
- raise IndexError("beta_query_owner contains an out-of-range image index")
371
-
372
- tokens = image_tokens[:, 1:, :] if self.beta_clip_drop_cls_token and image_tokens.size(1) > 1 else image_tokens
373
- tokens = tokens.to(dtype=query_base.dtype)
374
- query_projected = self.beta_clip_text_query_proj(query_base)
375
- counts = torch.bincount(query_owner, minlength=image_tokens.size(0))
376
- max_queries = int(counts.max().item())
377
-
378
- order = torch.argsort(query_owner)
379
- sorted_owner = query_owner.index_select(0, order)
380
- owner_offsets = torch.zeros_like(counts)
381
- owner_offsets[1:] = counts.cumsum(0)[:-1]
382
- sorted_positions = torch.arange(query_owner.numel(), device=query_owner.device) - owner_offsets.index_select(
383
- 0, sorted_owner
384
- )
385
- positions = torch.empty_like(sorted_positions)
386
- positions[order] = sorted_positions
387
-
388
- packed_query = query_projected.new_zeros((image_tokens.size(0), max_queries, query_projected.size(-1)))
389
- packed_query[query_owner, positions] = query_projected
390
- attended, _ = self.beta_clip_cross_attention(packed_query, tokens, tokens, need_weights=False)
391
- pooled = attended[query_owner, positions]
392
- return pooled + self.beta_clip_pool_mlp(self.beta_clip_mlp_norm(pooled))
393
-
394
- def _tren_auxiliary_losses(
395
- self,
396
- *,
397
- image_tokens: torch.Tensor,
398
- part_owner: torch.Tensor,
399
- part_image_base: torch.Tensor,
400
- part_text_base: torch.Tensor,
401
- ) -> dict[str, torch.Tensor]:
402
- zero = image_tokens.new_zeros(())
403
- if part_owner.numel() == 0:
404
- return {
405
- "tren_loss": zero,
406
- "tren_visual_distill_loss": zero,
407
- "tren_text_distill_loss": zero,
408
- "tren_region_text_contrastive_loss": zero,
409
- "tren_assignment_count": part_owner.new_tensor(0),
410
- }
411
-
412
- tren_outputs = self.tren_region_encoder(image_tokens)
413
- visual_tokens = tren_outputs["visual_tokens"].flatten(1, 2)
414
- text_tokens = tren_outputs["text_aligned_tokens"].flatten(1, 2)
415
-
416
- matched_visual: list[torch.Tensor] = []
417
- matched_text: list[torch.Tensor] = []
418
- target_visual: list[torch.Tensor] = []
419
- target_text: list[torch.Tensor] = []
420
- for owner in range(image_tokens.size(0)):
421
- region_mask = part_owner == owner
422
- if not bool(region_mask.any()):
423
- continue
424
- owner_target_visual = part_image_base[region_mask].detach()
425
- owner_target_text = part_text_base[region_mask].detach()
426
- owner_visual_tokens = visual_tokens[owner]
427
- owner_text_tokens = text_tokens[owner]
428
- pred_indices, target_indices = _greedy_region_assignment(owner_visual_tokens, owner_target_visual)
429
- if pred_indices.numel() == 0:
430
- continue
431
- matched_visual.append(owner_visual_tokens.index_select(0, pred_indices))
432
- matched_text.append(owner_text_tokens.index_select(0, pred_indices))
433
- target_visual.append(owner_target_visual.index_select(0, target_indices))
434
- target_text.append(owner_target_text.index_select(0, target_indices))
435
-
436
- if not matched_visual:
437
- return {
438
- "tren_loss": zero,
439
- "tren_visual_distill_loss": zero,
440
- "tren_text_distill_loss": zero,
441
- "tren_region_text_contrastive_loss": zero,
442
- "tren_assignment_count": part_owner.new_tensor(0),
443
- }
444
-
445
- matched_visual_tensor = torch.cat(matched_visual, dim=0)
446
- matched_text_tensor = torch.cat(matched_text, dim=0)
447
- target_visual_tensor = torch.cat(target_visual, dim=0)
448
- target_text_tensor = torch.cat(target_text, dim=0)
449
- visual_distill = 1.0 - F.cosine_similarity(matched_visual_tensor, target_visual_tensor, dim=-1).mean()
450
- text_distill = 1.0 - F.cosine_similarity(matched_text_tensor, target_text_tensor, dim=-1).mean()
451
- region_text = _symmetric_dot_contrastive(
452
- matched_text_tensor,
453
- target_text_tensor,
454
- scale=self.tren_logit_scale.exp().clamp(max=100.0),
455
- )
456
- total = (
457
- self.tren_visual_distill_weight * visual_distill
458
- + self.tren_text_distill_weight * text_distill
459
- + self.tren_region_text_weight * region_text
460
- )
461
- return {
462
- "tren_loss": total,
463
- "tren_visual_distill_loss": visual_distill,
464
- "tren_text_distill_loss": text_distill,
465
- "tren_region_text_contrastive_loss": region_text,
466
- "tren_assignment_count": part_owner.new_tensor(matched_visual_tensor.size(0)),
467
- }
468
-
469
- def _project_proclip_image_base(self, base_feats: torch.Tensor, hyperbolic: torch.Tensor) -> torch.Tensor:
470
- if self.proclip_geometry == "clip":
471
- return F.normalize(base_feats.float(), dim=-1)
472
- if self.proclip_dedicated_hyperbolic:
473
- hyperbolic = exp_map0(self.proclip_image_hyperbolic_proj(base_feats.float()), self._kappa().float())
474
- return self._pack_proclip_features(
475
- hyperbolic=hyperbolic,
476
- euclidean=self.proclip_image_euclidean_proj(base_feats.float()),
477
- spherical=self.proclip_image_spherical_proj(base_feats.float()),
478
- )
479
-
480
- def _project_proclip_text_base(self, base_feats: torch.Tensor, hyperbolic: torch.Tensor) -> torch.Tensor:
481
- if self.proclip_geometry == "clip":
482
- return F.normalize(base_feats.float(), dim=-1)
483
- if self.proclip_dedicated_hyperbolic:
484
- hyperbolic = exp_map0(self.proclip_text_hyperbolic_proj(base_feats.float()), self._kappa().float())
485
- return self._pack_proclip_features(
486
- hyperbolic=hyperbolic,
487
- euclidean=self.proclip_text_euclidean_proj(base_feats.float()),
488
- spherical=self.proclip_text_spherical_proj(base_feats.float()),
489
- )
490
-
491
- def _pack_proclip_features(self, hyperbolic: torch.Tensor, euclidean: torch.Tensor, spherical: torch.Tensor) -> torch.Tensor:
492
- spherical = F.normalize(spherical.float(), dim=-1)
493
- if self.proclip_geometry == "hyperbolic":
494
- return hyperbolic.float()
495
- if self.proclip_geometry == "euclidean":
496
- return euclidean.float()
497
- if self.proclip_geometry == "spherical":
498
- return spherical
499
- return torch.cat([hyperbolic.float(), euclidean.float(), spherical], dim=-1)
500
-
501
- def _split_proclip_features(self, feats: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
502
- hyperbolic_dim = self.embed_dim + 1
503
- component_dim = self._proclip_component_dim
504
- spherical_dim = self._proclip_spherical_ambient_dim
505
- hyperbolic = feats[:, :hyperbolic_dim]
506
- euclidean = feats[:, hyperbolic_dim : hyperbolic_dim + component_dim]
507
- spherical = feats[:, hyperbolic_dim + component_dim : hyperbolic_dim + component_dim + spherical_dim]
508
- return hyperbolic, euclidean, spherical
509
-
510
- def _proclip_similarity_scores(self, image_feats: torch.Tensor, text_feats: torch.Tensor) -> torch.Tensor:
511
- if self.proclip_geometry == "clip":
512
- return image_feats.float() @ text_feats.float().T
513
- if self.proclip_geometry == "hyperbolic":
514
- return -metric_pairwise_dist(image_feats, text_feats, self._kappa()).square()
515
- if self.proclip_geometry == "euclidean":
516
- return -torch.cdist(image_feats.float(), text_feats.float(), p=2).square()
517
- if self.proclip_geometry == "spherical":
518
- dot = (image_feats.float() @ text_feats.float().T).clamp(min=-1.0 + 1e-6, max=1.0 - 1e-6)
519
- return -torch.acos(dot).square()
520
- image_hyp, image_euc, image_sph = self._split_proclip_features(image_feats)
521
- text_hyp, text_euc, text_sph = self._split_proclip_features(text_feats)
522
- weights = self.proclip_log_weights.exp().to(device=image_feats.device, dtype=torch.float32)
523
- hyperbolic_dist2 = metric_pairwise_dist(image_hyp, text_hyp, self._kappa()).square()
524
- euclidean_dist2 = torch.cdist(image_euc.float(), text_euc.float(), p=2).square()
525
- spherical_dot = (image_sph.float() @ text_sph.float().T).clamp(min=-1.0 + 1e-6, max=1.0 - 1e-6)
526
- spherical_dist2 = torch.acos(spherical_dot).square()
527
- return -(weights[0] * hyperbolic_dist2 + weights[1] * euclidean_dist2 + weights[2] * spherical_dist2)
528
-
529
- def _proclip_contrastive_loss(
530
- self,
531
- image_feats: torch.Tensor,
532
- text_feats: torch.Tensor,
533
- all_image_feats: torch.Tensor,
534
- all_text_feats: torch.Tensor,
535
- targets: torch.Tensor,
536
- ) -> torch.Tensor:
537
- scale = self.proclip_logit_scale.exp().clamp(max=100.0)
538
- logits_i_t = self._proclip_similarity_scores(image_feats, all_text_feats) * scale
539
- logits_t_i = self._proclip_similarity_scores(text_feats, all_image_feats) * scale
540
- return 0.5 * (F.cross_entropy(logits_i_t, targets) + F.cross_entropy(logits_t_i, targets))
541
-
542
- def _detached_proclip_logs(self) -> dict[str, torch.Tensor]:
543
- weights = self.proclip_log_weights.exp().detach()
544
- return {
545
- "proclip_logit_scale": self.proclip_logit_scale.exp().detach(),
546
- "proclip_hyperbolic_weight": weights[0],
547
- "proclip_euclidean_weight": weights[1],
548
- "proclip_spherical_weight": weights[2],
549
- }
550
-
551
-
552
- def _greedy_region_assignment(pred_tokens: torch.Tensor, target_tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
553
- if pred_tokens.numel() == 0 or target_tokens.numel() == 0:
554
- empty = torch.zeros((0,), dtype=torch.long, device=pred_tokens.device)
555
- return empty, empty
556
- similarities = F.normalize(pred_tokens.float(), dim=-1) @ F.normalize(target_tokens.float(), dim=-1).T
557
- pair_scores = similarities.flatten()
558
- order = torch.argsort(pair_scores, descending=True)
559
- used_pred = torch.zeros(pred_tokens.size(0), dtype=torch.bool, device=pred_tokens.device)
560
- used_target = torch.zeros(target_tokens.size(0), dtype=torch.bool, device=pred_tokens.device)
561
- pred_indices: list[torch.Tensor] = []
562
- target_indices: list[torch.Tensor] = []
563
- for flat_index in order:
564
- pred_index = torch.div(flat_index, target_tokens.size(0), rounding_mode="floor")
565
- target_index = flat_index % target_tokens.size(0)
566
- if used_pred[pred_index] or used_target[target_index]:
567
- continue
568
- used_pred[pred_index] = True
569
- used_target[target_index] = True
570
- pred_indices.append(pred_index)
571
- target_indices.append(target_index)
572
- if len(target_indices) == target_tokens.size(0):
573
- break
574
- if not pred_indices:
575
- empty = torch.zeros((0,), dtype=torch.long, device=pred_tokens.device)
576
- return empty, empty
577
- return torch.stack(pred_indices), torch.stack(target_indices)
578
-
579
-
580
- def _symmetric_dot_contrastive(region_tokens: torch.Tensor, text_tokens: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
581
- if region_tokens.size(0) == 1:
582
- return region_tokens.new_zeros(())
583
- region_tokens = F.normalize(region_tokens.float(), dim=-1)
584
- text_tokens = F.normalize(text_tokens.float(), dim=-1)
585
- logits = region_tokens @ text_tokens.T * scale
586
- targets = torch.arange(logits.size(0), device=logits.device)
587
- return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/models/himo.py DELETED
@@ -1,55 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import torch
4
- from torch import Tensor
5
-
6
-
7
def hide_reconstruct_embeddings(
    embeddings: Tensor,
    *,
    variance_threshold: float = 0.9,
    detach_pca: bool = True,
    eps: float = 1e-8,
) -> Tensor:
    """HiMo-CLIP HiDe: PCA-reconstruct embeddings using top principal components.

    Given a batch of embeddings ``U ∈ R^{B×D}``, compute mean-centered embeddings,
    perform SVD/PCA, choose the smallest number of components whose cumulative
    explained variance reaches ``variance_threshold``, and reconstruct each
    embedding from this principal subspace:

        u'_i = P^T (P (u_i - ū)) + ū

    where P stacks the selected principal components as rows.

    Args:
        embeddings: ``[batch, dim]`` tensor.
        variance_threshold: Fraction of explained variance to keep, in ``(0, 1]``.
        detach_pca: When true, the principal directions are computed from a
            detached copy, so gradients flow only through the projection of
            ``embeddings`` and not through the SVD.
        eps: Total-variance floor below which the input is returned unchanged.

    Returns:
        The reconstruction in the dtype of ``embeddings``. The input itself is
        returned when the batch has fewer than two rows or carries
        (numerically) no variance.

    Raises:
        ValueError: If ``embeddings`` is not 2-D or ``variance_threshold`` is
            outside ``(0, 1]``.
    """
    if embeddings.ndim != 2:
        raise ValueError("hide_reconstruct_embeddings expects a [batch, dim] tensor")
    if not (0.0 < variance_threshold <= 1.0):
        raise ValueError("variance_threshold must be in (0, 1]")
    if embeddings.size(0) < 2:
        return embeddings

    u = embeddings.to(dtype=torch.float32)
    mean = u.mean(dim=0, keepdim=True)
    centered = u - mean
    centered_for_pca = centered.detach() if detach_pca else centered

    # SVD: centered = U S Vh, principal components are rows of Vh.
    _, s, vh = torch.linalg.svd(centered_for_pca, full_matrices=False)
    explained = s.square()
    total = explained.sum()
    if s.numel() == 0 or float(total.item()) <= eps:
        return embeddings

    cumulative = explained.cumsum(dim=0) / total.clamp_min(eps)
    reached = cumulative >= variance_threshold
    if bool(reached.any()):
        # argmax returns the first maximal entry, i.e. the first prefix that
        # reaches the threshold.
        m = int(reached.to(dtype=torch.int64).argmax().item()) + 1
    else:
        # Rounding can leave the final cumulative ratio a hair below the
        # threshold (typically for threshold == 1.0). argmax over an all-False
        # mask would then pick index 0 and keep a single component, so keep
        # every component instead.
        m = vh.size(0)
    m = max(1, min(m, vh.size(0)))
    p = vh[:m]
    if detach_pca:
        p = p.detach()

    recon = (centered @ p.T) @ p + mean
    return recon.to(dtype=embeddings.dtype)
55
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/models/hyper3_clip.py DELETED
@@ -1,958 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import torch
4
- import torch.nn.functional as F
5
- from torch import nn
6
-
7
- from hyper3_clip.models.encoders import TextEncoder, VisionEncoder
8
- from hyper3_clip.models.experimental import ExperimentalObjectiveMixin
9
- from hyper3_clip.models.himo import hide_reconstruct_embeddings
10
- from hyper3_clip.models.lorentz import exp_map0, metric_similarity
11
- from hyper3_clip.models.objectives import build_objective
12
- from hyper3_clip.training.distributed import (
13
- gather_with_grad,
14
- get_rank,
15
- get_world_size,
16
- local_target_indices,
17
- )
18
-
19
-
20
- class Hyper3CLIP(ExperimentalObjectiveMixin, nn.Module):
21
def __init__(
    self,
    vision_backbone: str,
    text_model_name: str,
    embed_dim: int,
    curv_init: float,
    learn_curv: bool,
    entail_weight: float,
    inter_aperture_scale: float,
    intra_aperture_scale: float,
    objective: str = "hycoclip",
    uncha_piecewise_factor: float = 0.1,
    uncha_calibration_alpha: float = 10.0,
    uncha_stop_grad_calibration: bool = True,
    vision_pretrained: bool = True,
    text_pretrained: bool = True,
    text_pooling: str = "auto",
    freeze_vision_encoder: bool = False,
    freeze_text_encoder: bool = False,
    normalize_encoder_features: bool = False,
    projection_hidden_dim: int | None = None,
    uncha_entailment_geometry: str = "lorentz",
    uncha_aggregate_weight: float = 0.0,
    uncha_entailment_loss: str = "piecewise",
    uncha_argent_beta: float = 1.0,
    uncha_argent_norm_weight: float = 0.0,
    uncha_argent_aux_weight: float = 0.5,
    uncha_argent_aggregation: str = "uncha",
    uncha_part_weight_power: float = 0.0,
    uncha_contrastive_loss: str = "ce",
    uncha_sigmoid_bias_init: float = -10.0,
    uncha_sigmoid_negative_weight: float = 1.0,
    uncha_part_quality_mode: str = "none",
    uncha_part_quality_topk: int = 5,
    uncha_part_quality_temperature: float = 4.0,
    uncha_entailment_warmup_steps: int = 0,
    uncha_contrastive_global_weight: float = 1.0,
    uncha_contrastive_local_weight: float = 1.0,
    uncha_contrastive_global_local_weight: float = 1.0,
    uncha_global_local_mode: str = "repeat",
    uncha_global_local_metric: str = "distance",
    uncha_global_local_angle_aux_weight: float = 0.0,
    uncha_global_local_angle_aux_mode: str = "contrastive",
    uncha_global_local_angle_aux_scale: float = 5.5,
    uncha_global_local_angle_aux_aperture_scale: float = 1.0,
    uncha_beta_cal_beta: float = 0.0,
    uncha_beta_cal_variant: str = "ce",
    uncha_beta_cal_weight: float = 0.0,
    uncha_himo_component_weight: float = 0.0,
    uncha_himo_variance_threshold: float = 0.9,
    uncha_himo_detach_pca: bool = True,
    uncha_radius_order_weight: float = 0.0,
    uncha_radius_order_margin: float = 0.0,
    uncha_gramian_align_weight: float = 0.0,
    phyclip_subspace_dim: int | None = None,
    phyclip_product_metric: str = "l1",
    proclip_weight: float = 0.0,
    proclip_component_dim: int | None = None,
    proclip_retrieval: bool = False,
    proclip_geometry: str = "product",
    proclip_dedicated_hyperbolic: bool = False,
    proclip_projection_hidden_dim: int | None = None,
    beta_clip_weight: float = 0.0,
    beta_clip_global_weight: float = 0.0,
    beta_clip_beta: float = 0.5,
    beta_clip_variant: str = "ce",
    beta_clip_similarity: str = "metric",
    beta_clip_num_heads: int = 8,
    beta_clip_mlp_ratio: float = 4.0,
    beta_clip_drop_cls_token: bool = True,
    tren_weight: float = 0.0,
    tren_visual_distill_weight: float = 1.0,
    tren_text_distill_weight: float = 1.0,
    tren_region_text_weight: float = 1.0,
    tren_num_region_tokens: int = 3,
    tren_num_decoder_layers: int = 2,
    tren_num_attention_heads: int = 8,
    tren_prompt_grid_size: int = 7,
    tren_dropout: float = 0.1,
    fuse_whole_part_encoder_forwards: bool = False,
    fuse_beta_query_encoder_forwards: bool = False,
    group_beta_query_pooling: bool = False,
    objective_autocast_dtype: str = "float32",
) -> None:
    """Validate the configuration and build encoders, heads and the objective.

    The constructor proceeds in this order: option validation, storing the
    configuration on ``self``, building the vision/text encoders (optionally
    frozen), the projection heads and experimental modules, the
    objective-specific logit scales/biases, the alpha and curvature
    parameters, and finally the loss objective (skipped for ``"proclip"``).

    Raises:
        ValueError: If ``objective``, ``phyclip_product_metric``,
            ``objective_autocast_dtype``, ``uncha_contrastive_loss``,
            ``uncha_global_local_metric`` or
            ``uncha_global_local_angle_aux_mode`` is not one of the accepted
            strings; if an angle-aux weight/scale or the warm-up step count
            is out of range; if ``projection_hidden_dim`` or
            ``phyclip_subspace_dim`` is non-positive; if ``embed_dim`` is not
            divisible by ``phyclip_subspace_dim``; or if ProCLIP is enabled
            together with a PHyCLIP subspace dimension.
    """
    # NOTE(review): the statement order below also fixes the order in which
    # submodules and parameters are registered on the nn.Module; do not
    # reorder without checking checkpoint/optimizer compatibility.
    super().__init__()
    if objective not in {"hycoclip", "uncha", "proclip"}:
        raise ValueError(f"Unsupported objective {objective!r}; expected 'hycoclip', 'uncha', or 'proclip'")
    if phyclip_product_metric not in {"l1", "l2"}:
        raise ValueError("phyclip_product_metric must be 'l1' or 'l2'")
    # Validation of the ProCLIP / beta-CLIP / TREN options is delegated to
    # a method that is not defined in this class body (presumably provided
    # by ExperimentalObjectiveMixin — confirm).
    self._validate_experimental_options(
        proclip_geometry=proclip_geometry,
        proclip_projection_hidden_dim=proclip_projection_hidden_dim,
        proclip_component_dim=proclip_component_dim,
        beta_clip_weight=beta_clip_weight,
        beta_clip_global_weight=beta_clip_global_weight,
        beta_clip_beta=beta_clip_beta,
        beta_clip_variant=beta_clip_variant,
        beta_clip_similarity=beta_clip_similarity,
        beta_clip_num_heads=beta_clip_num_heads,
        beta_clip_mlp_ratio=beta_clip_mlp_ratio,
        tren_weight=tren_weight,
        tren_visual_distill_weight=tren_visual_distill_weight,
        tren_text_distill_weight=tren_text_distill_weight,
        tren_region_text_weight=tren_region_text_weight,
        tren_num_region_tokens=tren_num_region_tokens,
        tren_num_decoder_layers=tren_num_decoder_layers,
        tren_num_attention_heads=tren_num_attention_heads,
        tren_prompt_grid_size=tren_prompt_grid_size,
        tren_dropout=tren_dropout,
    )
    if objective_autocast_dtype not in {"float32", "fp32", "float16", "fp16", "bfloat16", "bf16"}:
        raise ValueError("objective_autocast_dtype must be one of 'float32', 'float16', or 'bfloat16'")
    if uncha_contrastive_loss not in {"ce", "sigmoid", "siglip", "siglip_metric"}:
        raise ValueError("uncha_contrastive_loss must be 'ce', 'sigmoid', 'siglip', or 'siglip_metric'")
    if uncha_global_local_metric not in {"distance", "angle"}:
        raise ValueError("uncha_global_local_metric must be 'distance' or 'angle'")
    if uncha_global_local_angle_aux_mode not in {"contrastive", "positive_hinge"}:
        raise ValueError("uncha_global_local_angle_aux_mode must be 'contrastive' or 'positive_hinge'")
    if uncha_global_local_angle_aux_weight < 0.0:
        raise ValueError("uncha_global_local_angle_aux_weight must be non-negative")
    if uncha_global_local_angle_aux_scale <= 0.0:
        raise ValueError("uncha_global_local_angle_aux_scale must be positive")
    if uncha_global_local_angle_aux_aperture_scale <= 0.0:
        raise ValueError("uncha_global_local_angle_aux_aperture_scale must be positive")
    if uncha_entailment_warmup_steps < 0:
        raise ValueError("uncha_entailment_warmup_steps must be non-negative")

    # Store the configuration, coercing numeric/boolean options to plain
    # Python floats and bools.
    self.objective_name = objective
    self.uncha_contrastive_loss = uncha_contrastive_loss
    self.uncha_entailment_loss = uncha_entailment_loss
    self.uncha_entailment_warmup_steps = uncha_entailment_warmup_steps
    self.uncha_himo_component_weight = float(uncha_himo_component_weight)
    self.uncha_himo_variance_threshold = float(uncha_himo_variance_threshold)
    self.uncha_himo_detach_pca = bool(uncha_himo_detach_pca)
    self.proclip_weight = float(proclip_weight)
    self.proclip_retrieval = bool(proclip_retrieval)
    self.proclip_geometry = proclip_geometry
    self.proclip_dedicated_hyperbolic = bool(proclip_dedicated_hyperbolic)
    self.beta_clip_weight = float(beta_clip_weight)
    self.beta_clip_global_weight = float(beta_clip_global_weight)
    self.beta_clip_beta = float(beta_clip_beta)
    self.beta_clip_variant = beta_clip_variant
    self.beta_clip_similarity = beta_clip_similarity
    self.beta_clip_drop_cls_token = bool(beta_clip_drop_cls_token)
    self.tren_weight = float(tren_weight)
    self.tren_visual_distill_weight = float(tren_visual_distill_weight)
    self.tren_text_distill_weight = float(tren_text_distill_weight)
    self.tren_region_text_weight = float(tren_region_text_weight)
    self.fuse_whole_part_encoder_forwards = bool(fuse_whole_part_encoder_forwards)
    self.fuse_beta_query_encoder_forwards = bool(fuse_beta_query_encoder_forwards)
    self.group_beta_query_pooling = bool(group_beta_query_pooling)
    self.objective_autocast_dtype = objective_autocast_dtype
    self.freeze_vision_encoder = bool(freeze_vision_encoder)
    self.freeze_text_encoder = bool(freeze_text_encoder)
    self.normalize_encoder_features = bool(normalize_encoder_features)
    self.phyclip_subspace_dim = phyclip_subspace_dim
    self.phyclip_product_metric = phyclip_product_metric
    self.proclip_component_dim = proclip_component_dim
    if projection_hidden_dim is not None and projection_hidden_dim <= 0:
        raise ValueError("projection_hidden_dim must be positive when set")
    # `proclip_enabled` is not defined in this class body; it is read only
    # after the attributes above have been assigned.
    if self.proclip_enabled and phyclip_subspace_dim is not None:
        raise ValueError("ProCLIP mixed-curvature proxy cannot be combined with PHyCLIP Lorentz factors")
    if phyclip_subspace_dim is not None:
        if phyclip_subspace_dim <= 0:
            raise ValueError("phyclip_subspace_dim must be positive when set")
        if embed_dim % phyclip_subspace_dim != 0:
            raise ValueError("embed_dim must be divisible by phyclip_subspace_dim")
        # The embedding is split into equally sized factors.
        self.phyclip_num_factors = embed_dim // phyclip_subspace_dim
    else:
        self.phyclip_num_factors = 0
    self.vision_encoder = VisionEncoder(vision_backbone, pretrained=vision_pretrained)
    self.text_encoder = TextEncoder(text_model_name, pretrained=text_pretrained, pooling=text_pooling)
    self.tokenizer = self.text_encoder.tokenizer
    self.embed_dim = embed_dim
    # Frozen encoders get no gradients and are switched to eval mode.
    if self.freeze_vision_encoder:
        self.vision_encoder.requires_grad_(False)
        self.vision_encoder.eval()
    if self.freeze_text_encoder:
        self.text_encoder.requires_grad_(False)
        self.text_encoder.eval()

    self.image_proj = _projection_head(self.vision_encoder.output_dim, embed_dim, projection_hidden_dim)
    self.text_proj = _projection_head(self.text_encoder.output_dim, embed_dim, projection_hidden_dim)
    self._init_experimental_modules(
        beta_clip_num_heads=beta_clip_num_heads,
        beta_clip_mlp_ratio=beta_clip_mlp_ratio,
        tren_num_region_tokens=tren_num_region_tokens,
        tren_num_decoder_layers=tren_num_decoder_layers,
        tren_num_attention_heads=tren_num_attention_heads,
        tren_prompt_grid_size=tren_prompt_grid_size,
        tren_dropout=tren_dropout,
        projection_hidden_dim=projection_hidden_dim,
        proclip_projection_hidden_dim=proclip_projection_hidden_dim,
        projection_head=_projection_head,
    )

    # Logit scales are stored in log space; the uncha objective keeps three
    # separate scales and, for the sigmoid-style losses, three biases.
    if objective == "hycoclip":
        self.logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log())
    elif objective == "uncha":
        self.global_logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log())
        self.local_logit_scale = nn.Parameter(torch.tensor(1 / 0.05).log())
        self.global_local_logit_scale = nn.Parameter(torch.tensor(1 / 0.06).log())
        if uncha_contrastive_loss in {"sigmoid", "siglip", "siglip_metric"}:
            self.global_logit_bias = nn.Parameter(torch.tensor(float(uncha_sigmoid_bias_init)))
            self.local_logit_bias = nn.Parameter(torch.tensor(float(uncha_sigmoid_bias_init)))
            self.global_local_logit_bias = nn.Parameter(torch.tensor(float(uncha_sigmoid_bias_init)))
    # NOTE(review): the alpha parameters are created for every objective;
    # forward() and project_*_features read them unconditionally.
    # Log-space feature scales initialised to log(dim ** -0.5), one per
    # factor when PHyCLIP is enabled, otherwise a scalar.
    alpha_dim = phyclip_subspace_dim or embed_dim
    alpha_shape = (self.phyclip_num_factors,) if self.phyclip_enabled else ()
    self.visual_alpha = nn.Parameter(torch.full(alpha_shape, alpha_dim**-0.5).log())
    self.textual_alpha = nn.Parameter(torch.full(alpha_shape, alpha_dim**-0.5).log())

    # Log-curvature (trainable only when `learn_curv`), with bounds one
    # order of magnitude either side of the initial value.
    curv_shape = (self.phyclip_num_factors,) if self.phyclip_enabled else ()
    log_curv = torch.full(curv_shape, curv_init).log()
    self.log_curv = nn.Parameter(log_curv, requires_grad=learn_curv)
    self.curv_min = curv_init / 10.0
    self.curv_max = curv_init * 10.0
    # No objective object is built for "proclip".
    self.objective = None
    if objective != "proclip":
        self.objective = build_objective(
            objective=objective,
            entail_weight=entail_weight,
            inter_aperture_scale=inter_aperture_scale,
            intra_aperture_scale=intra_aperture_scale,
            uncha_piecewise_factor=uncha_piecewise_factor,
            uncha_calibration_alpha=uncha_calibration_alpha,
            uncha_stop_grad_calibration=uncha_stop_grad_calibration,
            uncha_entailment_geometry=uncha_entailment_geometry,
            uncha_aggregate_weight=uncha_aggregate_weight,
            uncha_entailment_loss=uncha_entailment_loss,
            uncha_argent_beta=uncha_argent_beta,
            uncha_argent_norm_weight=uncha_argent_norm_weight,
            uncha_argent_aux_weight=uncha_argent_aux_weight,
            uncha_argent_aggregation=uncha_argent_aggregation,
            uncha_part_weight_power=uncha_part_weight_power,
            uncha_contrastive_loss=uncha_contrastive_loss,
            uncha_sigmoid_negative_weight=uncha_sigmoid_negative_weight,
            uncha_part_quality_mode=uncha_part_quality_mode,
            uncha_part_quality_topk=uncha_part_quality_topk,
            uncha_part_quality_temperature=uncha_part_quality_temperature,
            uncha_contrastive_global_weight=uncha_contrastive_global_weight,
            uncha_contrastive_local_weight=uncha_contrastive_local_weight,
            uncha_contrastive_global_local_weight=uncha_contrastive_global_local_weight,
            uncha_global_local_mode=uncha_global_local_mode,
            uncha_global_local_metric=uncha_global_local_metric,
            uncha_global_local_angle_aux_weight=uncha_global_local_angle_aux_weight,
            uncha_global_local_angle_aux_mode=uncha_global_local_angle_aux_mode,
            uncha_global_local_angle_aux_scale=uncha_global_local_angle_aux_scale,
            uncha_global_local_angle_aux_aperture_scale=uncha_global_local_angle_aux_aperture_scale,
            uncha_beta_cal_beta=uncha_beta_cal_beta,
            uncha_beta_cal_variant=uncha_beta_cal_variant,
            uncha_beta_cal_weight=uncha_beta_cal_weight,
            uncha_himo_component_weight=uncha_himo_component_weight,
            uncha_radius_order_weight=uncha_radius_order_weight,
            uncha_radius_order_margin=uncha_radius_order_margin,
            uncha_gramian_align_weight=uncha_gramian_align_weight,
            product_metric=phyclip_product_metric,
        )
277
-
278
def train(self, mode: bool = True) -> Hyper3CLIP:
    """Switch train/eval mode while keeping frozen encoders in eval mode.

    The base implementation is applied first; any encoder flagged as frozen
    is then forced back to ``eval()`` regardless of ``mode``.
    """
    super().train(mode)
    frozen_flags = (
        (self.freeze_vision_encoder, "vision_encoder"),
        (self.freeze_text_encoder, "text_encoder"),
    )
    for is_frozen, encoder_name in frozen_flags:
        if is_frozen:
            getattr(self, encoder_name).eval()
    return self
285
-
286
@property
def phyclip_enabled(self) -> bool:
    """Whether a PHyCLIP subspace dimension has been configured."""
    subspace_dim = self.phyclip_subspace_dim
    return subspace_dim is not None
289
-
290
- def _kappa(self) -> torch.Tensor:
291
- return self.log_curv.exp().clamp(min=self.curv_min, max=self.curv_max)
292
-
293
def encode_image(self, image: torch.Tensor, project: bool = True) -> torch.Tensor:
    """Encode images through the base encoder and the image projection head.

    When ``project`` is false the head output is returned as is; otherwise
    it is passed through ``project_image_features``.
    """
    head_output = self.image_proj(self.encode_image_base(image))
    return self.project_image_features(head_output) if project else head_output
298
-
299
def encode_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, project: bool = True) -> torch.Tensor:
    """Encode text through the base encoder and the text projection head.

    When ``project`` is false the head output is returned as is; otherwise
    it is passed through ``project_text_features``.
    """
    head_output = self.text_proj(self.encode_text_base(input_ids, attention_mask))
    return self.project_text_features(head_output) if project else head_output
304
-
305
def encode_image_base(self, image: torch.Tensor) -> torch.Tensor:
    """Run the vision encoder, honouring the freeze and normalise options.

    Gradients are recorded only in training mode with an unfrozen encoder.
    Features of a frozen encoder are detached, and when
    ``normalize_encoder_features`` is set the result is L2-normalised in
    float32.
    """
    track_grad = self.training and not self.freeze_vision_encoder
    with torch.set_grad_enabled(track_grad):
        feats = self.vision_encoder(image)
    if self.freeze_vision_encoder:
        feats = feats.detach()
    if self.normalize_encoder_features:
        return F.normalize(feats.float(), dim=-1)
    return feats
310
-
311
def encode_image_base_with_tokens(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the vision encoder and return ``(features, tokens)``.

    Gradients are recorded only in training mode with an unfrozen encoder;
    both outputs of a frozen encoder are detached. Only the features are
    L2-normalised when ``normalize_encoder_features`` is set; the tokens are
    returned untouched.
    """
    track_grad = self.training and not self.freeze_vision_encoder
    with torch.set_grad_enabled(track_grad):
        pooled, token_grid = self.vision_encoder.forward_with_tokens(image)
    if self.freeze_vision_encoder:
        pooled, token_grid = pooled.detach(), token_grid.detach()
    if self.normalize_encoder_features:
        pooled = F.normalize(pooled.float(), dim=-1)
    return pooled, token_grid
320
-
321
def encode_text_base(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Run the text encoder, honouring the freeze and normalise options.

    Gradients are recorded only in training mode with an unfrozen encoder.
    Features of a frozen encoder are detached, and when
    ``normalize_encoder_features`` is set the result is L2-normalised in
    float32.
    """
    track_grad = self.training and not self.freeze_text_encoder
    with torch.set_grad_enabled(track_grad):
        feats = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
    if self.freeze_text_encoder:
        feats = feats.detach()
    if self.normalize_encoder_features:
        return F.normalize(feats.float(), dim=-1)
    return feats
326
-
327
def project_image_features(self, feats: torch.Tensor) -> torch.Tensor:
    """Map image head outputs into the model's embedding geometry.

    Without PHyCLIP the features are scaled by ``exp(visual_alpha)`` and
    passed to ``exp_map0`` with the current curvature, all in float32. With
    PHyCLIP the product-space projection is used instead.
    """
    if not self.phyclip_enabled:
        scaled = feats.float() * self.visual_alpha.exp().float()
        return exp_map0(scaled, self._kappa().float())
    return self._project_product_features(feats, self.visual_alpha)
331
-
332
def project_text_features(self, feats: torch.Tensor) -> torch.Tensor:
    """Map text head outputs into the model's embedding geometry.

    Without PHyCLIP the features are scaled by ``exp(textual_alpha)`` and
    passed to ``exp_map0`` with the current curvature, all in float32. With
    PHyCLIP the product-space projection is used instead.
    """
    if not self.phyclip_enabled:
        scaled = feats.float() * self.textual_alpha.exp().float()
        return exp_map0(scaled, self._kappa().float())
    return self._project_product_features(feats, self.textual_alpha)
336
-
337
def similarity_scores(self, image_feats: torch.Tensor, text_feats: torch.Tensor) -> torch.Tensor:
    """Score image features against text features with ``metric_similarity``.

    The current clamped curvature and the configured product metric are
    forwarded to the helper.
    """
    curvature = self._kappa()
    return metric_similarity(
        image_feats,
        text_feats,
        curvature,
        product_metric=self.phyclip_product_metric,
    )
339
-
340
def encode_retrieval_image(self, image: torch.Tensor) -> torch.Tensor:
    """Encode images into the representation used for retrieval.

    The projected head output is returned directly unless
    ``proclip_retrieval`` is set, in which case it is handed, together with
    the base features, to ``_project_proclip_image_base``.
    """
    base = self.encode_image_base(image)
    projected = self.project_image_features(self.image_proj(base))
    if not self.proclip_retrieval:
        return projected
    return self._project_proclip_image_base(base, projected)
346
-
347
def encode_retrieval_text(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Encode text into the representation used for retrieval.

    The projected head output is returned directly unless
    ``proclip_retrieval`` is set, in which case it is handed, together with
    the base features, to ``_project_proclip_text_base``.
    """
    base = self.encode_text_base(input_ids, attention_mask)
    projected = self.project_text_features(self.text_proj(base))
    if not self.proclip_retrieval:
        return projected
    return self._project_proclip_text_base(base, projected)
353
-
354
def retrieval_similarity_scores(self, image_feats: torch.Tensor, text_feats: torch.Tensor) -> torch.Tensor:
    """Score retrieval features with the scorer matching the retrieval mode.

    ``_proclip_similarity_scores`` is used when ``proclip_retrieval`` is
    set, ``similarity_scores`` otherwise.
    """
    scorer = self._proclip_similarity_scores if self.proclip_retrieval else self.similarity_scores
    return scorer(image_feats, text_feats)
358
-
359
@property
def retrieval_requires_chunking(self) -> bool:
    """True when PHyCLIP is enabled or ProCLIP features are used for retrieval."""
    uses_product_space = self.phyclip_enabled
    return uses_product_space if uses_product_space else self.proclip_retrieval
362
-
363
- def _objective_autocast(self, device_type: str):
364
- dtype = {
365
- "float32": torch.float32,
366
- "fp32": torch.float32,
367
- "float16": torch.float16,
368
- "fp16": torch.float16,
369
- "bfloat16": torch.bfloat16,
370
- "bf16": torch.bfloat16,
371
- }[self.objective_autocast_dtype]
372
- enabled = device_type != "cpu" and dtype is not torch.float32
373
- return torch.autocast(device_type=device_type, dtype=dtype, enabled=enabled)
374
-
375
def forward(
    self,
    image: torch.Tensor,
    part_images: torch.Tensor,
    text_input_ids: torch.Tensor,
    text_attention_mask: torch.Tensor,
    part_text_input_ids: torch.Tensor,
    part_text_attention_mask: torch.Tensor,
    part_owner: torch.Tensor,
    step: int | None = None,
    beta_query_input_ids: torch.Tensor | None = None,
    beta_query_attention_mask: torch.Tensor | None = None,
    beta_query_owner: torch.Tensor | None = None,
    beta_query_type: torch.Tensor | None = None,
    beta_query_parent: torch.Tensor | None = None,
    beta_query_weight: torch.Tensor | None = None,
    beta_query_source_part: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
    """Encode whole/part images and texts and compute the training losses.

    The method (1) clamps the learnable scales, (2) encodes whole and part
    inputs through one of four encoder paths chosen from the model flags,
    (3) either returns early with the ProCLIP-only loss or evaluates
    ``self.objective``, and (4) adds the optional auxiliary losses
    (beta-CLIP global, beta-CLIP, T-REN, ProCLIP) weighted by their
    respective ``*_weight`` attributes.

    Args:
        image, text_input_ids, text_attention_mask: whole-sample inputs.
        part_images, part_text_input_ids, part_text_attention_mask: part
            inputs; ``part_images.shape[0] == 0`` is handled by the helpers.
        part_owner: moved to the image-feature device as ``torch.long``
            before being handed to the objective.
        step: forwarded to ``_entail_weight_scale``.
        beta_query_*: optional query inputs forwarded to the beta-query /
            beta-CLIP helpers. ``beta_query_type`` is not referenced in
            this method body.

    Returns:
        A dict of losses (always containing ``"loss"``) merged with the
        outputs of ``_detached_kappa_logs`` and ``_detached_logit_scales``.

    Raises:
        RuntimeError: if image patch tokens are required (hier-beta
            entailment, beta-CLIP auxiliary, T-REN auxiliary) but were not
            produced by the chosen encoder path, or if ``self.objective``
            is ``None`` on a non-ProCLIP forward.

    NOTE(review): this listing was recovered from a flattened diff, so the
    block nesting below is reconstructed. In particular ``kappa`` is assumed
    to be computed outside the ``torch.no_grad()`` block — confirm against
    the original file.
    """
    # In-place clamps on the learnable scales are done without autograd tracking.
    with torch.no_grad():
        self._clamp_logit_scales()
        # Keep both alpha parameters non-positive.
        self.visual_alpha.clamp_(max=0.0)
        self.textual_alpha.clamp_(max=0.0)
    kappa = self._kappa()

    feature_dim = self.embed_dim
    # Defaults for values only some encoder paths produce.
    beta_image_tokens = None
    beta_query_base = None
    part_image_base = part_images.new_zeros((0, self.vision_encoder.output_dim))
    part_text_base = part_images.new_zeros((0, self.text_encoder.output_dim))
    hier_beta_enabled = self.objective_name == "uncha" and self.uncha_entailment_loss in {
        "hier_beta_argent",
        "hier_beta_sourcepart_argent",
    }
    if (
        hier_beta_enabled
        and self.fuse_beta_query_encoder_forwards
        and not self.tren_enabled
        and beta_query_input_ids is not None
        and beta_query_attention_mask is not None
        and part_images.shape[0] > 0
    ):
        # Path 1: wholes, parts and beta queries share fused encoder calls.
        (
            image_base,
            text_base,
            image_euc,
            text_euc,
            image_feats,
            text_feats,
            part_image_feats,
            part_text_feats,
            part_image_euc,
            part_text_euc,
            part_image_base,
            part_text_base,
            beta_image_tokens,
            beta_query_base,
        ) = self._encode_hier_beta_whole_parts_and_queries(
            image=image,
            part_images=part_images,
            text_input_ids=text_input_ids,
            text_attention_mask=text_attention_mask,
            part_text_input_ids=part_text_input_ids,
            part_text_attention_mask=part_text_attention_mask,
            beta_query_input_ids=beta_query_input_ids,
            beta_query_attention_mask=beta_query_attention_mask,
        )
    elif self.beta_query_pooling_enabled or self.tren_enabled:
        # Path 2: the whole-image encoder also returns tokens; parts are encoded separately.
        image_base, beta_image_tokens = self.encode_image_base_with_tokens(image)
        text_base = self.encode_text_base(text_input_ids, text_attention_mask)
        image_euc = self.image_proj(image_base)
        text_euc = self.text_proj(text_base)
        image_feats = self.project_image_features(image_euc)
        text_feats = self.project_text_features(text_euc)
        (
            part_image_feats,
            part_text_feats,
            part_image_euc,
            part_text_euc,
            part_image_base,
            part_text_base,
        ) = self._encode_parts_with_base(
            part_images=part_images,
            part_text_input_ids=part_text_input_ids,
            part_text_attention_mask=part_text_attention_mask,
            feature_dim=feature_dim,
        )
    elif self.fuse_whole_part_encoder_forwards and self.objective_name != "proclip" and part_images.shape[0] > 0:
        # Path 3: wholes and parts share fused encoder calls (no tokens produced).
        (
            image_base,
            text_base,
            image_euc,
            text_euc,
            image_feats,
            text_feats,
            part_image_feats,
            part_text_feats,
            part_image_euc,
            part_text_euc,
            part_image_base,
            part_text_base,
        ) = self._encode_whole_and_parts(
            image=image,
            part_images=part_images,
            text_input_ids=text_input_ids,
            text_attention_mask=text_attention_mask,
            part_text_input_ids=part_text_input_ids,
            part_text_attention_mask=part_text_attention_mask,
        )
    else:
        # Path 4: wholes and parts are encoded separately (no tokens produced).
        image_base = self.encode_image_base(image)
        text_base = self.encode_text_base(text_input_ids, text_attention_mask)
        image_euc = self.image_proj(image_base)
        text_euc = self.text_proj(text_base)
        image_feats = self.project_image_features(image_euc)
        text_feats = self.project_text_features(text_euc)
        (
            part_image_feats,
            part_text_feats,
            part_image_euc,
            part_text_euc,
            part_image_base,
            part_text_base,
        ) = self._encode_parts_with_base(
            part_images=part_images,
            part_text_input_ids=part_text_input_ids,
            part_text_attention_mask=part_text_attention_mask,
            feature_dim=feature_dim,
        )
    targets = local_target_indices(image_feats.size(0), image_feats.device)

    if self.objective_name == "proclip":
        # ProCLIP-only objective: return early with just the contrastive term.
        proclip_image_feats = self._project_proclip_image_base(image_base, image_feats)
        proclip_text_feats = self._project_proclip_text_base(text_base, text_feats)
        proclip_loss = self._proclip_contrastive_loss(
            image_feats=proclip_image_feats,
            text_feats=proclip_text_feats,
            all_image_feats=gather_with_grad(proclip_image_feats),
            all_text_feats=gather_with_grad(proclip_text_feats),
            targets=targets,
        )
        zero = proclip_loss.new_zeros(())
        return {
            "loss": proclip_loss,
            "contrastive_loss": proclip_loss,
            "entailment_loss": zero,
            "part_count": part_owner.new_tensor(0),
            "proclip_contrastive_loss": proclip_loss,
            **self._detached_kappa_logs(kappa),
            **self._detached_logit_scales(),
        }

    himo_text_feats = None
    all_himo_text_feats = None
    if self.objective_name == "uncha" and self.uncha_himo_component_weight > 0.0:
        # Reconstruct text embeddings over the gathered batch, then keep this rank's slice.
        all_text_euc = gather_with_grad(text_euc)
        all_component_euc = hide_reconstruct_embeddings(
            all_text_euc,
            variance_threshold=self.uncha_himo_variance_threshold,
            detach_pca=self.uncha_himo_detach_pca,
        )
        if get_world_size() > 1:
            start = text_euc.size(0) * get_rank()
            end = start + text_euc.size(0)
            component_euc = all_component_euc[start:end]
        else:
            component_euc = all_component_euc
        himo_text_feats = self.project_text_features(component_euc)
        all_himo_text_feats = gather_with_grad(himo_text_feats)
    all_image_feats = gather_with_grad(image_feats)
    all_text_feats = gather_with_grad(text_feats)
    # Euclidean gathers are only populated for the "siglip" contrastive loss.
    all_image_euc = None
    all_text_euc = None
    if self.objective_name == "uncha" and self.uncha_contrastive_loss == "siglip":
        all_image_euc = gather_with_grad(image_euc)
        all_text_euc = gather_with_grad(text_euc)
    part_owner = part_owner.to(device=image_feats.device, dtype=torch.long)
    beta_query_embeddings = {}
    if self.objective_name == "uncha" and self.uncha_entailment_loss in {
        "hier_beta_argent",
        "hier_beta_sourcepart_argent",
    }:
        if beta_image_tokens is None:
            raise RuntimeError(f"{self.uncha_entailment_loss} requires image patch tokens")
        # Autocast is explicitly disabled and inputs are cast to float here.
        with torch.autocast(device_type=image.device.type, enabled=False):
            beta_query_embeddings = self._beta_query_entailment_embeddings(
                image_tokens=beta_image_tokens.float(),
                beta_query_input_ids=beta_query_input_ids,
                beta_query_attention_mask=beta_query_attention_mask,
                beta_query_owner=beta_query_owner,
                beta_query_parent=beta_query_parent,
                beta_query_weight=beta_query_weight,
                beta_query_source_part=beta_query_source_part,
                kappa=kappa.float(),
                query_base=beta_query_base,
            )

    with self._objective_autocast(image.device.type):
        if self.objective is None:
            raise RuntimeError("Non-ProCLIP forward requires an objective module")
        losses = self.objective(
            {
                "image_feats": image_feats,
                "text_feats": text_feats,
                "part_image_feats": part_image_feats,
                "part_text_feats": part_text_feats,
                "part_owner": part_owner,
                "all_image_feats": all_image_feats,
                "all_text_feats": all_text_feats,
                # Euclidean features are only included when both gathers exist.
                **(
                    {
                        "image_euc_feats": image_euc,
                        "text_euc_feats": text_euc,
                        "part_image_euc_feats": part_image_euc,
                        "part_text_euc_feats": part_text_euc,
                        "all_image_euc_feats": all_image_euc,
                        "all_text_euc_feats": all_text_euc,
                    }
                    if all_image_euc is not None and all_text_euc is not None
                    else {}
                ),
                "targets": targets,
                "kappa": kappa,
                "entail_weight_scale": self._entail_weight_scale(step, image_feats.device),
                **beta_query_embeddings,
                **(
                    {
                        "himo_text_feats": himo_text_feats,
                        "all_himo_text_feats": all_himo_text_feats,
                    }
                    if himo_text_feats is not None
                    else {}
                ),
            },
            self._objective_logit_scales(),
        )

    # Each auxiliary below adds its weighted term to "loss" and logs the raw value.
    if self.beta_clip_global_weight > 0.0:
        with torch.autocast(device_type=image.device.type, enabled=False):
            beta_clip_global_loss = self._beta_clip_global_contrastive_loss(
                image_euc=image_euc,
                text_euc=text_euc,
                targets=targets,
            )
        losses = {
            **losses,
            "loss": losses["loss"] + self.beta_clip_global_weight * beta_clip_global_loss,
            "beta_clip_global_loss": beta_clip_global_loss,
        }

    if self.beta_clip_enabled:
        if beta_image_tokens is None:
            raise RuntimeError("beta-CLIP auxiliary requires image patch tokens")
        with torch.autocast(device_type=image.device.type, enabled=False):
            beta_clip_loss = self._beta_clip_auxiliary_loss(
                image_tokens=beta_image_tokens.float(),
                beta_query_input_ids=beta_query_input_ids,
                beta_query_attention_mask=beta_query_attention_mask,
                beta_query_owner=beta_query_owner,
                global_targets=targets,
                kappa=kappa.float(),
            )
        losses = {
            **losses,
            "loss": losses["loss"] + self.beta_clip_weight * beta_clip_loss,
            "beta_clip_loss": beta_clip_loss,
        }

    if self.tren_enabled:
        if beta_image_tokens is None:
            raise RuntimeError("T-REN auxiliary requires image patch tokens")
        with torch.autocast(device_type=image.device.type, enabled=False):
            tren_losses = self._tren_auxiliary_losses(
                image_tokens=beta_image_tokens.float(),
                part_owner=part_owner,
                part_image_base=part_image_base.float(),
                part_text_base=part_text_base.float(),
            )
        losses = {
            **losses,
            "loss": losses["loss"] + self.tren_weight * tren_losses["tren_loss"],
            **tren_losses,
        }

    if self.proclip_enabled and self.proclip_weight > 0.0:
        proclip_image_feats = self._project_proclip_image_base(image_base, image_feats)
        proclip_text_feats = self._project_proclip_text_base(text_base, text_feats)
        proclip_loss = self._proclip_contrastive_loss(
            image_feats=proclip_image_feats,
            text_feats=proclip_text_feats,
            all_image_feats=gather_with_grad(proclip_image_feats),
            all_text_feats=gather_with_grad(proclip_text_feats),
            targets=targets,
        )
        losses = {
            **losses,
            "loss": losses["loss"] + self.proclip_weight * proclip_loss,
            "proclip_contrastive_loss": proclip_loss,
        }

    return {**losses, **self._detached_kappa_logs(kappa), **self._detached_logit_scales()}
675
-
676
def _encode_parts(
    self,
    part_images: torch.Tensor,
    part_text_input_ids: torch.Tensor,
    part_text_attention_mask: torch.Tensor,
    feature_dim: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Encode part images and part texts.

    Returns ``(part_image_feats, part_text_feats, part_image_euc,
    part_text_euc)``. When there are no parts, the same empty
    ``(0, feature_dim)`` tensor is returned in all four positions and no
    encoder is called.
    """
    has_parts = part_images.shape[0] != 0
    if has_parts:
        image_euc = self.image_proj(self.encode_image_base(part_images))
        text_base = self.encode_text_base(part_text_input_ids, part_text_attention_mask)
        text_euc = self.text_proj(text_base)
        image_feats = self.project_image_features(image_euc)
        text_feats = self.project_text_features(text_euc)
        return image_feats, text_feats, image_euc, text_euc

    placeholder = part_images.new_zeros((0, feature_dim))
    return placeholder, placeholder, placeholder, placeholder
692
-
693
def _encode_parts_with_base(
    self,
    part_images: torch.Tensor,
    part_text_input_ids: torch.Tensor,
    part_text_attention_mask: torch.Tensor,
    feature_dim: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Encode part images and texts, also returning the raw encoder outputs.

    Returns ``(part_image_feats, part_text_feats, part_image_euc,
    part_text_euc, part_image_base, part_text_base)``. When there are no
    parts, empty tensors are returned: ``(0, feature_dim)`` for the first
    four entries and ``(0, <encoder output_dim>)`` for the two base entries.
    """
    has_parts = part_images.shape[0] != 0
    if has_parts:
        image_base = self.encode_image_base(part_images)
        text_base = self.encode_text_base(part_text_input_ids, part_text_attention_mask)
        image_euc = self.image_proj(image_base)
        text_euc = self.text_proj(text_base)
        image_feats = self.project_image_features(image_euc)
        text_feats = self.project_text_features(text_euc)
        return image_feats, text_feats, image_euc, text_euc, image_base, text_base

    placeholder = part_images.new_zeros((0, feature_dim))
    no_image_base = part_images.new_zeros((0, self.vision_encoder.output_dim))
    no_text_base = part_images.new_zeros((0, self.text_encoder.output_dim))
    return placeholder, placeholder, placeholder, placeholder, no_image_base, no_text_base
713
-
714
def _encode_whole_and_parts(
    self,
    image: torch.Tensor,
    part_images: torch.Tensor,
    text_input_ids: torch.Tensor,
    text_attention_mask: torch.Tensor,
    part_text_input_ids: torch.Tensor,
    part_text_attention_mask: torch.Tensor,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Encode whole and part inputs with one encoder call per modality.

    Whole and part images are concatenated along dim 0 and encoded together;
    the same is done for the texts (via ``_concat_text_batches``). The fused
    outputs are then split back into whole and part portions.

    Returns, in order: ``image_base, text_base, image_euc, text_euc,
    image_feats, text_feats, part_image_feats, part_text_feats,
    part_image_euc, part_text_euc, part_image_base, part_text_base``.
    """
    # Row counts used to undo the concatenation afterwards.
    sizes = [image.shape[0], part_images.shape[0]]

    stacked_images = torch.cat([image, part_images], dim=0)
    fused_image_base = self.encode_image_base(stacked_images)
    fused_image_euc = self.image_proj(fused_image_base)
    fused_image_feats = self.project_image_features(fused_image_euc)

    fused_ids, fused_mask = self._concat_text_batches(
        text_input_ids,
        text_attention_mask,
        part_text_input_ids,
        part_text_attention_mask,
    )
    fused_text_base = self.encode_text_base(fused_ids, fused_mask)
    fused_text_euc = self.text_proj(fused_text_base)
    fused_text_feats = self.project_text_features(fused_text_euc)

    whole_image_base, part_image_base = fused_image_base.split(sizes, dim=0)
    whole_text_base, part_text_base = fused_text_base.split(sizes, dim=0)
    whole_image_euc, part_image_euc = fused_image_euc.split(sizes, dim=0)
    whole_text_euc, part_text_euc = fused_text_euc.split(sizes, dim=0)
    whole_image_feats, part_image_feats = fused_image_feats.split(sizes, dim=0)
    whole_text_feats, part_text_feats = fused_text_feats.split(sizes, dim=0)

    return (
        whole_image_base,
        whole_text_base,
        whole_image_euc,
        whole_text_euc,
        whole_image_feats,
        whole_text_feats,
        part_image_feats,
        part_text_feats,
        part_image_euc,
        part_text_euc,
        part_image_base,
        part_text_base,
    )
772
-
773
def _encode_hier_beta_whole_parts_and_queries(
    self,
    image: torch.Tensor,
    part_images: torch.Tensor,
    text_input_ids: torch.Tensor,
    text_attention_mask: torch.Tensor,
    part_text_input_ids: torch.Tensor,
    part_text_attention_mask: torch.Tensor,
    beta_query_input_ids: torch.Tensor,
    beta_query_attention_mask: torch.Tensor,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Encode wholes, parts and beta queries with fused encoder calls.

    Whole and part images go through a single
    ``encode_image_base_with_tokens`` call; whole, part and beta-query texts
    go through a single ``encode_text_base`` call. Only the whole-image rows
    of the token output are kept, and only the base (pre-projection) text
    output is kept for the beta queries.

    Returns, in order: ``image_base, text_base, image_euc, text_euc,
    image_feats, text_feats, part_image_feats, part_text_feats,
    part_image_euc, part_text_euc, part_image_base, part_text_base,
    beta_image_tokens, beta_query_base``.
    """
    num_whole = image.shape[0]
    num_parts = part_images.shape[0]
    num_queries = beta_query_input_ids.shape[0]
    image_sizes = [num_whole, num_parts]
    text_sizes = [num_whole, num_parts, num_queries]

    stacked_images = torch.cat([image, part_images], dim=0)
    fused_image_base, fused_image_tokens = self.encode_image_base_with_tokens(stacked_images)
    fused_image_euc = self.image_proj(fused_image_base)
    fused_image_feats = self.project_image_features(fused_image_euc)
    whole_image_base, part_image_base = fused_image_base.split(image_sizes, dim=0)
    whole_image_euc, part_image_euc = fused_image_euc.split(image_sizes, dim=0)
    whole_image_feats, part_image_feats = fused_image_feats.split(image_sizes, dim=0)
    # Tokens are kept for the whole images only (the leading rows).
    whole_image_tokens = fused_image_tokens[:num_whole]

    fused_ids, fused_mask = self._concat_text_batch_list(
        (text_input_ids, text_attention_mask),
        (part_text_input_ids, part_text_attention_mask),
        (beta_query_input_ids, beta_query_attention_mask),
    )
    fused_text_base = self.encode_text_base(fused_ids, fused_mask)
    fused_text_euc = self.text_proj(fused_text_base)
    fused_text_feats = self.project_text_features(fused_text_euc)
    whole_text_base, part_text_base, query_base = fused_text_base.split(text_sizes, dim=0)
    # Projected query rows are discarded; only their base output is returned.
    whole_text_euc, part_text_euc, _ = fused_text_euc.split(text_sizes, dim=0)
    whole_text_feats, part_text_feats, _ = fused_text_feats.split(text_sizes, dim=0)

    return (
        whole_image_base,
        whole_text_base,
        whole_image_euc,
        whole_text_euc,
        whole_image_feats,
        whole_text_feats,
        part_image_feats,
        part_text_feats,
        part_image_euc,
        part_text_euc,
        part_image_base,
        part_text_base,
        whole_image_tokens,
        query_base,
    )
839
-
840
def _concat_text_batches(
    self,
    text_input_ids: torch.Tensor,
    text_attention_mask: torch.Tensor,
    part_text_input_ids: torch.Tensor,
    part_text_attention_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Concatenate whole-text and part-text batches.

    Two-batch convenience wrapper around ``_concat_text_batch_list``.
    """
    whole_batch = (text_input_ids, text_attention_mask)
    part_batch = (part_text_input_ids, part_text_attention_mask)
    return self._concat_text_batch_list(whole_batch, part_batch)
851
-
852
def _concat_text_batch_list(
    self,
    *batches: tuple[torch.Tensor, torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pad ``(input_ids, attention_mask)`` batches to a common length and stack them.

    Every batch is right-padded along dim 1 to the longest sequence length
    among the batches, then the batches are concatenated along dim 0.
    Input ids are padded with the tokenizer's ``pad_token_id`` (``0`` when
    the tokenizer has none); attention masks are padded with ``0``.
    """
    longest = max(ids.shape[1] for ids, _ in batches)
    pad_id = self.text_encoder.tokenizer.pad_token_id
    pad_id = 0 if pad_id is None else pad_id

    padded_ids = [_pad_sequence_dim(ids, longest, pad_id) for ids, _ in batches]
    padded_masks = [_pad_sequence_dim(mask, longest, 0) for _, mask in batches]
    return torch.cat(padded_ids, dim=0), torch.cat(padded_masks, dim=0)
864
-
865
def _clamp_logit_scales(self) -> None:
    """Clamp the active logit-scale parameters in place to at most 4.6052.

    Which parameters are clamped depends on ``self.objective_name``:
    ``proclip_logit_scale`` for "proclip", ``logit_scale`` for "hycoclip",
    and the global / local / global-local scales otherwise. The
    experimental scales are clamped afterwards in every case.
    """
    objective = self.objective_name
    if objective == "proclip":
        active_scales = (self.proclip_logit_scale,)
    elif objective == "hycoclip":
        active_scales = (self.logit_scale,)
    else:
        active_scales = (
            self.global_logit_scale,
            self.local_logit_scale,
            self.global_local_logit_scale,
        )
    for scale in active_scales:
        scale.clamp_(max=4.6052)
    self._clamp_experimental_logit_scales()
878
-
879
def _objective_logit_scales(self) -> torch.Tensor | dict[str, torch.Tensor]:
    """Return the logit scale(s) handed to the objective.

    A single tensor is returned for the "hycoclip" and "proclip"
    objectives. Otherwise a dict with the global / local / global-local
    scales is returned, extended with the three matching biases when
    ``uncha_contrastive_loss`` is "sigmoid", "siglip" or "siglip_metric".
    """
    objective = self.objective_name
    if objective == "hycoclip":
        return self.logit_scale
    if objective == "proclip":
        return self.proclip_logit_scale

    scales = {
        "global": self.global_logit_scale,
        "local": self.local_logit_scale,
        "global_local": self.global_local_logit_scale,
    }
    if self.uncha_contrastive_loss in {"sigmoid", "siglip", "siglip_metric"}:
        scales["global_bias"] = self.global_logit_bias
        scales["local_bias"] = self.local_logit_bias
        scales["global_local_bias"] = self.global_local_logit_bias
    return scales
898
-
899
def _detached_logit_scales(self) -> dict[str, torch.Tensor]:
    """Collect detached logit-scale values for logging.

    Scales are exponentiated before being detached; biases are detached
    as-is. The experimental scales from
    ``_detached_experimental_logit_scales`` are merged in last, and are the
    only entries for the "proclip" objective.
    """
    objective = self.objective_name
    if objective == "proclip":
        return self._detached_experimental_logit_scales()
    if objective == "hycoclip":
        return {
            "logit_scale": self.logit_scale.exp().detach(),
            **self._detached_experimental_logit_scales(),
        }

    scale_logs = {
        "global_logit_scale": self.global_logit_scale.exp().detach(),
        "local_logit_scale": self.local_logit_scale.exp().detach(),
        "global_local_logit_scale": self.global_local_logit_scale.exp().detach(),
    }
    bias_logs = {}
    if self.uncha_contrastive_loss in {"sigmoid", "siglip", "siglip_metric"}:
        bias_logs = {
            "global_logit_bias": self.global_logit_bias.detach(),
            "local_logit_bias": self.local_logit_bias.detach(),
            "global_local_logit_bias": self.global_local_logit_bias.detach(),
        }
    return {**scale_logs, **bias_logs, **self._detached_experimental_logit_scales()}
921
-
922
def _project_product_features(self, feats: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    """Scale per-factor features by ``exp(alpha)`` and apply ``exp_map0`` per factor.

    ``feats`` is reshaped to ``(batch, phyclip_num_factors,
    phyclip_subspace_dim)``; ``alpha`` and the curvature from ``_kappa()``
    are broadcast along the factor axis via ``view(1, -1, 1)``.
    """
    batch = feats.size(0)
    factored = feats.float().reshape(batch, self.phyclip_num_factors, self.phyclip_subspace_dim)
    factor_scale = alpha.exp().float().view(1, -1, 1)
    factored = factored * factor_scale
    factor_kappa = self._kappa().float().view(1, -1, 1)
    return exp_map0(factored, factor_kappa)
926
-
927
def _detached_kappa_logs(self, kappa: torch.Tensor) -> dict[str, torch.Tensor]:
    """Summarize the curvature tensor for logging, detached from autograd.

    A single-element ``kappa`` is logged as a 0-dim ``"kappa"`` entry; a
    multi-element ``kappa`` is logged as its mean plus ``"kappa_min"`` and
    ``"kappa_max"``.
    """
    values = kappa.detach()
    if values.numel() != 1:
        return {
            "kappa": values.mean(),
            "kappa_min": values.min(),
            "kappa_max": values.max(),
        }
    return {"kappa": values.reshape(())}
936
-
937
def _entail_weight_scale(self, step: int | None, device: torch.device) -> torch.Tensor:
    """Return the linear warmup factor for the entailment weight as a 0-dim tensor.

    The factor is ``min(1, (step + 1) / uncha_entailment_warmup_steps)``.
    It is ``1`` when warmup is disabled (``uncha_entailment_warmup_steps <= 0``)
    or when ``step`` is ``None``.
    """
    warmup_steps = self.uncha_entailment_warmup_steps
    if warmup_steps > 0 and step is not None:
        progress = float(step + 1) / float(warmup_steps)
        return torch.tensor(min(1.0, progress), device=device)
    return torch.ones((), device=device)
942
-
943
-
944
def _projection_head(input_dim: int, output_dim: int, hidden_dim: int | None) -> nn.Module:
    """Build a projection head from ``input_dim`` to ``output_dim``.

    With ``hidden_dim`` set, the head is Linear -> ReLU -> Linear; with
    ``hidden_dim=None`` it is a single linear layer.
    """
    if hidden_dim is not None:
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        return nn.Sequential(*layers)
    return nn.Linear(input_dim, output_dim)
952
-
953
-
954
def _pad_sequence_dim(tensor: torch.Tensor, target_length: int, value: int) -> torch.Tensor:
    """Right-pad dim 1 of ``tensor`` with ``value`` up to ``target_length``.

    The input tensor itself is returned when dim 1 is already at least
    ``target_length`` long.
    """
    missing = target_length - tensor.shape[1]
    return F.pad(tensor, (0, missing), value=value) if missing > 0 else tensor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/models/lorentz.py DELETED
@@ -1,265 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import math
4
-
5
- import torch
6
- from torch import Tensor
7
-
8
-
9
- def lorentz_inner(x: Tensor, y: Tensor) -> Tensor:
10
- """Compute batched Lorentzian inner product for matching rows."""
11
- x = x.float()
12
- y = y.float()
13
- return -x[..., 0] * y[..., 0] + (x[..., 1:] * y[..., 1:]).sum(dim=-1)
14
-
15
-
16
- def pairwise_lorentz_inner(x: Tensor, y: Tensor) -> Tensor:
17
- """Compute all-pairs Lorentzian inner products."""
18
- x = x.float()
19
- y = y.float()
20
- time = -x[:, :1] @ y[:, :1].T
21
- space = x[:, 1:] @ y[:, 1:].T
22
- return time + space
23
-
24
-
25
- def exp_map0(u: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
26
- """Exponential map at the origin from tangent space to hyperboloid."""
27
- u = u.float()
28
- kappa = kappa.float()
29
- sqrt_k = torch.sqrt(kappa)
30
- norm_u = torch.linalg.norm(u, dim=-1, keepdim=True).clamp_min(eps)
31
- scaled = sqrt_k * norm_u
32
- clipped_scaled = scaled.clamp_max(math.asinh(2**15))
33
- time = torch.cosh(clipped_scaled) / sqrt_k
34
- space = torch.sinh(clipped_scaled) * u / scaled.clamp_min(eps)
35
- return torch.cat([time, space], dim=-1)
36
-
37
-
38
- def log_map0(x: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
39
- """Logarithmic map at the origin from hyperboloid to tangent space.
40
-
41
- Inverts ``exp_map0`` for points on the Lorentz model hyperboloid. Returns
42
- vectors in the Euclidean tangent space at the origin (no time coordinate).
43
- """
44
- x = x.float()
45
- dist_eps = max(eps, 16.0 * torch.finfo(x.dtype).eps)
46
- kappa = kappa.to(dtype=torch.float32).flatten()
47
-
48
- if x.dim() == 2:
49
- if kappa.numel() != 1:
50
- raise ValueError("log_map0 expects scalar kappa for non-product embeddings")
51
- sqrt_k = torch.sqrt(kappa.reshape(()))
52
- alpha = torch.acosh((sqrt_k * x[:, 0]).clamp_min(1.0 + dist_eps))
53
- coef = alpha / torch.sinh(alpha).clamp_min(dist_eps)
54
- return x[:, 1:] * coef.unsqueeze(-1)
55
-
56
- if x.dim() == 3:
57
- if kappa.numel() == 1:
58
- kappa = kappa.expand(x.shape[1])
59
- if kappa.numel() != x.shape[1]:
60
- raise ValueError(f"Expected {x.shape[1]} curvatures for product space, got {kappa.numel()}")
61
- sqrt_k = torch.sqrt(kappa).view(1, -1)
62
- alpha = torch.acosh((sqrt_k * x[..., 0]).clamp_min(1.0 + dist_eps))
63
- coef = alpha / torch.sinh(alpha).clamp_min(dist_eps)
64
- return x[..., 1:] * coef.unsqueeze(-1)
65
-
66
- raise ValueError("log_map0 expects [batch, dim + 1] or [batch, factors, dim + 1] tensors")
67
-
68
-
69
- def pairwise_dist(x: Tensor, y: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
70
- """Pairwise geodesic distance on the Lorentz model."""
71
- kappa = kappa.float()
72
- dist_eps = max(eps, 16.0 * torch.finfo(x.dtype).eps)
73
- prod = (-kappa) * pairwise_lorentz_inner(x, y)
74
- prod = prod.clamp_min(1.0 + dist_eps)
75
- return torch.acosh(prod) / torch.sqrt(kappa)
76
-
77
-
78
- def product_pairwise_dist(
79
- x: Tensor,
80
- y: Tensor,
81
- kappa: Tensor,
82
- metric: str = "l1",
83
- eps: float = 1e-8,
84
- ) -> Tensor:
85
- """Pairwise distance in an l1/l2 product of Lorentz factors.
86
-
87
- Inputs have shape ``[batch, factors, dim + 1]``. For ``metric="l1"``, this
88
- matches the official PHyCLIP implementation's mean distance over factors.
89
- """
90
- if x.dim() != 3 or y.dim() != 3:
91
- raise ValueError("product_pairwise_dist expects [batch, factors, dim + 1] tensors")
92
- if x.shape[1] != y.shape[1] or x.shape[2] != y.shape[2]:
93
- raise ValueError("Product Lorentz tensors must have matching factor and feature dimensions")
94
- kappa = _product_kappa(kappa, x.shape[1], x.device).to(dtype=torch.float32)
95
- dist_eps = max(eps, 16.0 * torch.finfo(x.dtype).eps)
96
- x = x.float()
97
- y = y.float()
98
- inner = -x[:, None, :, 0] * y[None, :, :, 0] + torch.einsum("bkd,nkd->bnk", x[..., 1:], y[..., 1:])
99
- prod = (-kappa.view(1, 1, -1)) * inner
100
- dist = torch.acosh(prod.clamp_min(1.0 + dist_eps)) / torch.sqrt(kappa).view(1, 1, -1)
101
- if metric == "l1":
102
- return dist.mean(dim=-1)
103
- if metric == "l2":
104
- return dist.square().mean(dim=-1).sqrt()
105
- raise ValueError(f"Unsupported product metric {metric!r}; expected 'l1' or 'l2'")
106
-
107
-
108
- def metric_pairwise_dist(x: Tensor, y: Tensor, kappa: Tensor, product_metric: str = "l1") -> Tensor:
109
- """Pairwise distance for either a single Lorentz space or a product space."""
110
- if x.dim() == 3 or y.dim() == 3:
111
- return product_pairwise_dist(x, y, kappa, metric=product_metric)
112
- return pairwise_dist(x, y, kappa)
113
-
114
-
115
- def paired_dist(x: Tensor, y: Tensor, kappa: Tensor, product_metric: str = "l1", eps: float = 1e-8) -> Tensor:
116
- """Row-wise distance for either a single Lorentz space or a product space."""
117
- if x.dim() == 3 or y.dim() == 3:
118
- if x.shape != y.shape:
119
- raise ValueError("Product paired_dist expects matching tensor shapes")
120
- kappa = _product_kappa(kappa, x.shape[1], x.device).to(dtype=torch.float32)
121
- dist_eps = max(eps, 16.0 * torch.finfo(x.dtype).eps)
122
- x = x.float()
123
- y = y.float()
124
- inner = -x[..., 0] * y[..., 0] + (x[..., 1:] * y[..., 1:]).sum(dim=-1)
125
- prod = (-kappa.view(1, -1)) * inner
126
- dist = torch.acosh(prod.clamp_min(1.0 + dist_eps)) / torch.sqrt(kappa).view(1, -1)
127
- if product_metric == "l1":
128
- return dist.mean(dim=-1)
129
- if product_metric == "l2":
130
- return dist.square().mean(dim=-1).sqrt()
131
- raise ValueError(f"Unsupported product metric {product_metric!r}; expected 'l1' or 'l2'")
132
- kappa = kappa.float()
133
- dist_eps = max(eps, 16.0 * torch.finfo(x.dtype).eps)
134
- prod = (-kappa) * lorentz_inner(x, y)
135
- prod = prod.clamp_min(1.0 + dist_eps)
136
- return torch.acosh(prod) / torch.sqrt(kappa)
137
-
138
-
139
- def radial_distance(x: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
140
- """Geodesic distance from the origin.
141
-
142
- For points on the hyperboloid, the time coordinate satisfies
143
- ``x0 = cosh(sqrt(kappa) * r) / sqrt(kappa)``, so we can recover the radial
144
- distance via ``r = arcosh(sqrt(kappa) * x0) / sqrt(kappa)``.
145
- """
146
- dist_eps = max(eps, 16.0 * torch.finfo(x.dtype).eps)
147
- x = x.float()
148
- kappa = kappa.to(dtype=torch.float32).flatten()
149
- if x.dim() == 2:
150
- if kappa.numel() != 1:
151
- raise ValueError("radial_distance expects scalar kappa for non-product embeddings")
152
- sqrt_k = torch.sqrt(kappa.reshape(()))
153
- arg = (sqrt_k * x[:, 0]).clamp_min(1.0 + dist_eps)
154
- return torch.acosh(arg) / sqrt_k
155
- if x.dim() == 3:
156
- if kappa.numel() == 1:
157
- kappa = kappa.expand(x.shape[1])
158
- if kappa.numel() != x.shape[1]:
159
- raise ValueError(f"Expected {x.shape[1]} curvatures for product space, got {kappa.numel()}")
160
- sqrt_k = torch.sqrt(kappa).view(1, -1)
161
- arg = (sqrt_k * x[..., 0]).clamp_min(1.0 + dist_eps)
162
- dist = torch.acosh(arg) / sqrt_k
163
- return dist.mean(dim=-1)
164
- raise ValueError("radial_distance expects [batch, dim + 1] or [batch, factors, dim + 1] tensors")
165
-
166
-
167
- def metric_similarity(x: Tensor, y: Tensor, kappa: Tensor, product_metric: str = "l1") -> Tensor:
168
- """Retrieval/classification similarity for single-space and PHyCLIP-style models."""
169
- if x.dim() == 3 or y.dim() == 3:
170
- return -product_pairwise_dist(x, y, kappa, metric=product_metric)
171
- return pairwise_lorentz_inner(x, y)
172
-
173
-
174
- def half_aperture(general: Tensor, kappa: Tensor, min_radius: float = 0.1, eps: float = 1e-8) -> Tensor:
175
- """Cone half-aperture for entailment cone centered at general concept."""
176
- general = general.float()
177
- kappa = kappa.float()
178
- aperture_eps = max(eps, 16.0 * torch.finfo(general.dtype).eps)
179
- general_norm = torch.linalg.norm(general[:, 1:], dim=-1)
180
- ratio = (2.0 * min_radius) / (general_norm * torch.sqrt(kappa) + aperture_eps)
181
- ratio = ratio.clamp(max=1.0 - aperture_eps)
182
- return torch.asin(ratio)
183
-
184
-
185
- def oxy_angle(specific: Tensor, general: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
186
- """Exterior angle between specific point and entailment cone at general point."""
187
- specific = specific.float()
188
- general = general.float()
189
- kappa = kappa.float()
190
- angle_eps = max(eps, 16.0 * torch.finfo(specific.dtype).eps)
191
- inner = lorentz_inner(specific, general)
192
- numerator = specific[:, 0] + kappa * inner * general[:, 0]
193
- general_norm = torch.linalg.norm(general[:, 1:], dim=-1).clamp_min(angle_eps)
194
- denom_term = (kappa * inner).pow(2) - 1.0
195
- denom = general_norm * torch.sqrt(denom_term.clamp_min(angle_eps))
196
- cosine = (numerator / denom).clamp(min=-1.0 + angle_eps, max=1.0 - angle_eps)
197
- return torch.acos(cosine)
198
-
199
-
200
- def pairwise_oxy_angle(specific: Tensor, general: Tensor, kappa: Tensor, eps: float = 1e-8) -> Tensor:
201
- """All-pairs exterior angle between specific points and entailment cones at general points."""
202
- specific = specific.float()
203
- general = general.float()
204
- kappa = kappa.to(dtype=torch.float32).flatten()
205
- if kappa.numel() != 1:
206
- raise ValueError("pairwise_oxy_angle expects scalar kappa for non-product embeddings")
207
- kappa_scalar = kappa.reshape(())
208
- angle_eps = max(eps, 16.0 * torch.finfo(specific.dtype).eps)
209
- inner = -specific[:, None, 0] * general[None, :, 0] + torch.einsum("nd,md->nm", specific[:, 1:], general[:, 1:])
210
- numerator = specific[:, None, 0] + kappa_scalar * inner * general[None, :, 0]
211
- general_norm = torch.linalg.norm(general[:, 1:], dim=-1).clamp_min(angle_eps)
212
- denom_term = (kappa_scalar * inner).pow(2) - 1.0
213
- denom = general_norm[None, :] * torch.sqrt(denom_term.clamp_min(angle_eps))
214
- cosine = (numerator / denom).clamp(min=-1.0 + angle_eps, max=1.0 - angle_eps)
215
- return torch.acos(cosine)
216
-
217
-
218
- def product_pairwise_oxy_angle(
219
- specific: Tensor,
220
- general: Tensor,
221
- kappa: Tensor,
222
- metric: str = "l1",
223
- eps: float = 1e-8,
224
- ) -> Tensor:
225
- """All-pairs exterior angle in an l1/l2 product of Lorentz factors."""
226
- if specific.dim() != 3 or general.dim() != 3:
227
- raise ValueError("product_pairwise_oxy_angle expects [batch, factors, dim + 1] tensors")
228
- if specific.shape[1] != general.shape[1] or specific.shape[2] != general.shape[2]:
229
- raise ValueError("Product Lorentz tensors must have matching factor and feature dimensions")
230
- kappa = _product_kappa(kappa, specific.shape[1], specific.device).to(dtype=torch.float32)
231
- angle_eps = max(eps, 16.0 * torch.finfo(specific.dtype).eps)
232
- specific = specific.float()
233
- general = general.float()
234
- inner = -specific[:, None, :, 0] * general[None, :, :, 0] + torch.einsum(
235
- "nkd,mkd->nmk",
236
- specific[..., 1:],
237
- general[..., 1:],
238
- )
239
- numerator = specific[:, None, :, 0] + (kappa.view(1, 1, -1) * inner) * general[None, :, :, 0]
240
- general_norm = torch.linalg.norm(general[..., 1:], dim=-1).clamp_min(angle_eps)
241
- denom_term = (kappa.view(1, 1, -1) * inner).pow(2) - 1.0
242
- denom = general_norm[None, :, :] * torch.sqrt(denom_term.clamp_min(angle_eps))
243
- cosine = (numerator / denom).clamp(min=-1.0 + angle_eps, max=1.0 - angle_eps)
244
- angles = torch.acos(cosine)
245
- if metric == "l1":
246
- return angles.mean(dim=-1)
247
- if metric == "l2":
248
- return angles.square().mean(dim=-1).sqrt()
249
- raise ValueError(f"Unsupported product metric {metric!r}; expected 'l1' or 'l2'")
250
-
251
-
252
- def metric_pairwise_oxy_angle(specific: Tensor, general: Tensor, kappa: Tensor, product_metric: str = "l1") -> Tensor:
253
- """All-pairs oxy-angle for either a single Lorentz space or a product space."""
254
- if specific.dim() == 3 or general.dim() == 3:
255
- return product_pairwise_oxy_angle(specific, general, kappa, metric=product_metric)
256
- return pairwise_oxy_angle(specific, general, kappa)
257
-
258
-
259
- def _product_kappa(kappa: Tensor, num_factors: int, device: torch.device) -> Tensor:
260
- kappa = kappa.to(device=device, dtype=torch.float32).flatten()
261
- if kappa.numel() == 1:
262
- return kappa.expand(num_factors)
263
- if kappa.numel() != num_factors:
264
- raise ValueError(f"Expected {num_factors} curvatures for product space, got {kappa.numel()}")
265
- return kappa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/models/losses.py DELETED
@@ -1,1400 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import math
4
-
5
- import torch
6
- from torch import Tensor
7
- import torch.nn.functional as F
8
-
9
- from hyper3_clip.models.lorentz import (
10
- half_aperture,
11
- metric_pairwise_dist,
12
- metric_pairwise_oxy_angle,
13
- oxy_angle,
14
- paired_dist,
15
- radial_distance,
16
- )
17
-
18
-
19
- def contrastive_ce(logits: Tensor, targets: Tensor | None = None, weights: Tensor | None = None) -> Tensor:
20
- if targets is None:
21
- targets = torch.arange(logits.size(0), device=logits.device)
22
- losses = F.cross_entropy(logits, targets, reduction="none")
23
- return weighted_mean(losses, weights)
24
-
25
-
26
- def contrastive_sigmoid(
27
- logits: Tensor,
28
- targets: Tensor | None = None,
29
- weights: Tensor | None = None,
30
- negative_weight: float = 1.0,
31
- ) -> Tensor:
32
- if targets is None:
33
- targets = torch.arange(logits.size(0), device=logits.device)
34
- labels = torch.zeros_like(logits)
35
- labels[torch.arange(logits.size(0), device=logits.device), targets] = 1.0
36
- losses = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
37
- if negative_weight != 1.0:
38
- element_weights = torch.where(labels > 0.0, torch.ones_like(labels), logits.new_full((), negative_weight))
39
- losses = losses * element_weights
40
- losses = losses.mean(dim=1)
41
- return weighted_mean(losses, weights)
42
-
43
-
44
- def contrastive_siglip(
45
- logits: Tensor,
46
- targets: Tensor | None = None,
47
- weights: Tensor | None = None,
48
- negative_weight: float = 1.0,
49
- ) -> Tensor:
50
- """SigLIP pairwise sigmoid loss (Zhai et al., ICCV 2023).
51
-
52
- Uses labels in {+1, -1} with a per-row sum (not mean) over pairs:
53
- L_i = sum_j softplus(- y_ij * logit_ij)
54
- """
55
- if logits.ndim != 2:
56
- raise ValueError("contrastive_siglip expects a [batch, classes] logit matrix")
57
- if targets is None:
58
- targets = torch.arange(logits.size(0), device=logits.device)
59
- labels = logits.new_full(logits.shape, -1.0)
60
- labels[torch.arange(logits.size(0), device=logits.device), targets] = 1.0
61
- losses = F.softplus(-(labels * logits))
62
- if negative_weight != 1.0:
63
- element_weights = torch.where(labels > 0.0, torch.ones_like(labels), logits.new_full((), negative_weight))
64
- losses = losses * element_weights
65
- row_losses = losses.sum(dim=1)
66
- return weighted_mean(row_losses, weights)
67
-
68
-
69
- def weighted_mean(values: Tensor, weights: Tensor | None = None) -> Tensor:
70
- if weights is None:
71
- return values.mean()
72
- weights = weights.to(device=values.device, dtype=values.dtype)
73
- while weights.dim() < values.dim():
74
- weights = weights.unsqueeze(-1)
75
- return (values * weights).sum() / weights.sum().clamp_min(torch.finfo(values.dtype).eps)
76
-
77
-
78
- def gramian_volume_loss(vectors: Tensor, weights: Tensor | None = None, eps: float = 1e-4) -> Tensor:
79
- """GRAM-style volume loss for sets of vectors.
80
-
81
- ``vectors`` is expected to have shape ``[batch, k, dim]``. Each set of k
82
- vectors is L2-normalized along ``dim``, then we compute the Gramian
83
- ``G = V V^T`` and return ``sqrt(det(G + eps I))`` averaged over the batch.
84
- """
85
- if vectors.ndim != 3:
86
- raise ValueError("gramian_volume_loss expects a [batch, k, dim] tensor")
87
- if eps <= 0.0:
88
- raise ValueError("gramian_volume_loss eps must be positive")
89
-
90
- vectors = F.normalize(vectors.float(), dim=-1, eps=1e-8)
91
- gram = vectors @ vectors.transpose(-1, -2)
92
- k = gram.size(-1)
93
- gram = gram + eps * torch.eye(k, device=gram.device, dtype=gram.dtype)
94
- sign, logabsdet = torch.linalg.slogdet(gram)
95
- volume = torch.exp(0.5 * logabsdet)
96
- volume = torch.where(sign > 0, volume, volume.new_ones(volume.shape))
97
- return weighted_mean(volume, weights)
98
-
99
-
100
- def radius_order_hinge(
101
- specific: Tensor,
102
- general: Tensor,
103
- kappa: Tensor,
104
- margin: float,
105
- weights: Tensor | None = None,
106
- ) -> Tensor:
107
- if specific.shape[0] != general.shape[0]:
108
- raise ValueError("radius_order_hinge expects matching batch dimensions")
109
- if margin < 0.0:
110
- raise ValueError("radius_order_hinge margin must be non-negative")
111
- specific_radius = radial_distance(specific, kappa)
112
- general_radius = radial_distance(general, kappa)
113
- losses = F.relu(float(margin) + general_radius - specific_radius)
114
- return weighted_mean(losses, weights)
115
-
116
-
117
- def soft_contrastive_ce(logits: Tensor, target_weights: Tensor, weights: Tensor | None = None) -> Tensor:
118
- if logits.ndim != 2 or target_weights.ndim != 2:
119
- raise ValueError("soft_contrastive_ce expects [batch, classes] tensors")
120
- if logits.shape != target_weights.shape:
121
- raise ValueError("soft_contrastive_ce requires logits and target_weights to have matching shapes")
122
- log_probs = F.log_softmax(logits, dim=1)
123
- losses = -(target_weights.to(dtype=log_probs.dtype) * log_probs).sum(dim=1)
124
- return weighted_mean(losses, weights)
125
-
126
-
127
- def beta_cal_loss(
128
- logits: Tensor,
129
- *,
130
- targets: Tensor,
131
- group_ids: Tensor,
132
- all_group_ids: Tensor,
133
- beta: float,
134
- variant: str,
135
- weights: Tensor | None = None,
136
- ) -> Tensor:
137
- if beta < 0.0:
138
- raise ValueError("beta_cal_loss beta must be non-negative")
139
- if variant not in {"ce", "bce"}:
140
- raise ValueError("beta_cal_loss variant must be 'ce' or 'bce'")
141
- if logits.ndim != 2:
142
- raise ValueError("beta_cal_loss expects a [batch, classes] logit matrix")
143
- if targets.shape != (logits.size(0),):
144
- raise ValueError("beta_cal_loss targets must have shape [batch]")
145
- if group_ids.shape != (logits.size(0),):
146
- raise ValueError("beta_cal_loss group_ids must have shape [batch]")
147
- if all_group_ids.shape != (logits.size(1),):
148
- raise ValueError("beta_cal_loss all_group_ids must have shape [classes]")
149
-
150
- same_group = group_ids[:, None] == all_group_ids[None, :]
151
- same_pair = targets[:, None] == torch.arange(logits.size(1), device=logits.device)[None, :]
152
-
153
- if variant == "ce":
154
- target_weights = logits.new_zeros(logits.shape)
155
- target_weights = torch.where(same_pair, logits.new_ones(()), target_weights)
156
- target_weights = torch.where(same_group & ~same_pair, logits.new_full((), float(beta)), target_weights)
157
- target_weights = target_weights / target_weights.sum(dim=1, keepdim=True).clamp_min(
158
- torch.finfo(target_weights.dtype).eps
159
- )
160
- return soft_contrastive_ce(logits, target_weights, weights)
161
-
162
- labels = same_group.to(dtype=logits.dtype)
163
- element_weights = logits.new_ones(logits.shape)
164
- element_weights = torch.where(same_group & ~same_pair, logits.new_full((), float(beta)), element_weights)
165
- element_losses = F.binary_cross_entropy_with_logits(logits, labels, reduction="none") * element_weights
166
- row_losses = element_losses.mean(dim=1)
167
- return weighted_mean(row_losses, weights)
168
-
169
- def compositional_contrastive_loss(
170
- image_feats: Tensor,
171
- text_feats: Tensor,
172
- box_image_feats: Tensor,
173
- box_text_feats: Tensor,
174
- kappa: Tensor,
175
- logit_scale: Tensor,
176
- all_image_feats: Tensor | None = None,
177
- all_text_feats: Tensor | None = None,
178
- targets: Tensor | None = None,
179
- ) -> Tensor:
180
- scale = logit_scale.exp().clamp(max=100.0)
181
- all_image_feats = image_feats if all_image_feats is None else all_image_feats
182
- all_text_feats = text_feats if all_text_feats is None else all_text_feats
183
-
184
- logits_i_t = -metric_pairwise_dist(image_feats, all_text_feats, kappa) * scale
185
- logits_t_i = -metric_pairwise_dist(text_feats, all_image_feats, kappa) * scale
186
- logits_bi_t = -metric_pairwise_dist(box_image_feats, all_text_feats, kappa) * scale
187
- logits_bt_i = -metric_pairwise_dist(box_text_feats, all_image_feats, kappa) * scale
188
-
189
- return 0.25 * (
190
- contrastive_ce(logits_i_t, targets)
191
- + contrastive_ce(logits_t_i, targets)
192
- + contrastive_ce(logits_bi_t, targets)
193
- + contrastive_ce(logits_bt_i, targets)
194
- )
195
-
196
-
197
- def multi_part_contrastive_loss(
198
- image_feats: Tensor,
199
- text_feats: Tensor,
200
- part_image_feats: Tensor,
201
- part_text_feats: Tensor,
202
- part_mask: Tensor,
203
- kappa: Tensor,
204
- logit_scale: Tensor,
205
- all_image_feats: Tensor | None = None,
206
- all_text_feats: Tensor | None = None,
207
- targets: Tensor | None = None,
208
- ) -> Tensor:
209
- scale = logit_scale.exp().clamp(max=100.0)
210
- all_image_feats = image_feats if all_image_feats is None else all_image_feats
211
- all_text_feats = text_feats if all_text_feats is None else all_text_feats
212
- if targets is None:
213
- targets = torch.arange(image_feats.size(0), device=image_feats.device)
214
-
215
- part_image_flat, part_text_flat, part_targets = _flatten_valid_parts(part_image_feats, part_text_feats, part_mask, targets)
216
-
217
- logits_i_t = -metric_pairwise_dist(image_feats, all_text_feats, kappa) * scale
218
- logits_t_i = -metric_pairwise_dist(text_feats, all_image_feats, kappa) * scale
219
- logits_pi_t = -metric_pairwise_dist(part_image_flat, all_text_feats, kappa) * scale
220
- logits_pt_i = -metric_pairwise_dist(part_text_flat, all_image_feats, kappa) * scale
221
-
222
- return 0.25 * (
223
- contrastive_ce(logits_i_t, targets)
224
- + contrastive_ce(logits_t_i, targets)
225
- + contrastive_ce(logits_pi_t, part_targets)
226
- + contrastive_ce(logits_pt_i, part_targets)
227
- )
228
-
229
-
230
- def packed_part_contrastive_loss(
231
- image_feats: Tensor,
232
- text_feats: Tensor,
233
- part_image_feats: Tensor,
234
- part_text_feats: Tensor,
235
- part_owner: Tensor,
236
- kappa: Tensor,
237
- logit_scale: Tensor,
238
- all_image_feats: Tensor | None = None,
239
- all_text_feats: Tensor | None = None,
240
- targets: Tensor | None = None,
241
- ) -> Tensor:
242
- scale = logit_scale.exp().clamp(max=100.0)
243
- all_image_feats = image_feats if all_image_feats is None else all_image_feats
244
- all_text_feats = text_feats if all_text_feats is None else all_text_feats
245
- if targets is None:
246
- targets = torch.arange(image_feats.size(0), device=image_feats.device)
247
-
248
- logits_i_t = -metric_pairwise_dist(image_feats, all_text_feats, kappa) * scale
249
- logits_t_i = -metric_pairwise_dist(text_feats, all_image_feats, kappa) * scale
250
- global_loss = 0.5 * (contrastive_ce(logits_i_t, targets) + contrastive_ce(logits_t_i, targets))
251
-
252
- if part_image_feats.numel() == 0:
253
- return global_loss
254
-
255
- part_targets = targets[part_owner]
256
- logits_pi_t = -metric_pairwise_dist(part_image_feats, all_text_feats, kappa) * scale
257
- logits_pt_i = -metric_pairwise_dist(part_text_feats, all_image_feats, kappa) * scale
258
- part_loss = 0.5 * (contrastive_ce(logits_pi_t, part_targets) + contrastive_ce(logits_pt_i, part_targets))
259
- return 0.5 * (global_loss + part_loss)
260
-
261
-
262
- def factor_oxy_angle(specific: Tensor, general: Tensor, kappa: Tensor) -> Tensor:
263
- if specific.dim() != 3:
264
- return oxy_angle(specific=specific, general=general, kappa=kappa)
265
- batch_size, num_factors, feature_dim = specific.shape
266
- kappa = _factor_kappa(kappa, num_factors, specific.device)
267
- factor_kappa = kappa.view(1, num_factors).expand(batch_size, num_factors).reshape(-1)
268
- return oxy_angle(
269
- specific=specific.reshape(batch_size * num_factors, feature_dim),
270
- general=general.reshape(batch_size * num_factors, feature_dim),
271
- kappa=factor_kappa,
272
- ).reshape(batch_size, num_factors)
273
-
274
-
275
- def factor_half_aperture(general: Tensor, kappa: Tensor) -> Tensor:
276
- if general.dim() != 3:
277
- return half_aperture(general=general, kappa=kappa)
278
- batch_size, num_factors, feature_dim = general.shape
279
- kappa = _factor_kappa(kappa, num_factors, general.device)
280
- factor_kappa = kappa.view(1, num_factors).expand(batch_size, num_factors).reshape(-1)
281
- return half_aperture(
282
- general=general.reshape(batch_size * num_factors, feature_dim),
283
- kappa=factor_kappa,
284
- ).reshape(batch_size, num_factors)
285
-
286
-
287
- def _factor_kappa(kappa: Tensor, num_factors: int, device: torch.device) -> Tensor:
288
- kappa = kappa.to(device=device, dtype=torch.float32).flatten()
289
- if kappa.numel() == 1:
290
- return kappa.expand(num_factors)
291
- if kappa.numel() != num_factors:
292
- raise ValueError(f"Expected {num_factors} curvatures for product space, got {kappa.numel()}")
293
- return kappa
294
-
295
-
296
- def entailment_residual(
297
- specific: Tensor,
298
- general: Tensor,
299
- kappa: Tensor,
300
- aperture_scale: float,
301
- ) -> Tensor:
302
- angles = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
303
- apertures = factor_half_aperture(general=general, kappa=kappa)
304
- return torch.clamp(angles - (aperture_scale * apertures), min=0.0).mean()
305
-
306
-
307
- def weighted_entailment_residual(
308
- specific: Tensor,
309
- general: Tensor,
310
- kappa: Tensor,
311
- aperture_scale: float,
312
- weights: Tensor | None = None,
313
- ) -> Tensor:
314
- angles = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
315
- apertures = factor_half_aperture(general=general, kappa=kappa)
316
- residuals = torch.clamp(angles - (aperture_scale * apertures), min=0.0)
317
- if residuals.dim() == 2:
318
- residuals = residuals.mean(dim=-1)
319
- return weighted_mean(residuals, weights)
320
-
321
-
322
- def compositional_entailment_loss(
323
- image_feats: Tensor,
324
- text_feats: Tensor,
325
- box_image_feats: Tensor,
326
- box_text_feats: Tensor,
327
- kappa: Tensor,
328
- inter_aperture_scale: float,
329
- intra_aperture_scale: float,
330
- ) -> Tensor:
331
- text_to_image = entailment_residual(
332
- specific=image_feats,
333
- general=text_feats,
334
- kappa=kappa,
335
- aperture_scale=inter_aperture_scale,
336
- )
337
- box_text_to_box_image = entailment_residual(
338
- specific=box_image_feats,
339
- general=box_text_feats,
340
- kappa=kappa,
341
- aperture_scale=inter_aperture_scale,
342
- )
343
- box_image_to_image = entailment_residual(
344
- specific=image_feats,
345
- general=box_image_feats,
346
- kappa=kappa,
347
- aperture_scale=intra_aperture_scale,
348
- )
349
- box_text_to_text = entailment_residual(
350
- specific=text_feats,
351
- general=box_text_feats,
352
- kappa=kappa,
353
- aperture_scale=intra_aperture_scale,
354
- )
355
-
356
- return 0.5 * (text_to_image + box_text_to_box_image + box_image_to_image + box_text_to_text)
357
-
358
-
359
- def multi_part_entailment_loss(
360
- image_feats: Tensor,
361
- text_feats: Tensor,
362
- part_image_feats: Tensor,
363
- part_text_feats: Tensor,
364
- part_mask: Tensor,
365
- kappa: Tensor,
366
- inter_aperture_scale: float,
367
- intra_aperture_scale: float,
368
- ) -> Tensor:
369
- part_image_flat = part_image_feats[part_mask]
370
- part_text_flat = part_text_feats[part_mask]
371
- image_for_parts = image_feats[:, None, :].expand_as(part_image_feats)[part_mask]
372
- text_for_parts = text_feats[:, None, :].expand_as(part_text_feats)[part_mask]
373
-
374
- text_to_image = entailment_residual(
375
- specific=image_feats,
376
- general=text_feats,
377
- kappa=kappa,
378
- aperture_scale=inter_aperture_scale,
379
- )
380
- part_text_to_part_image = entailment_residual(
381
- specific=part_image_flat,
382
- general=part_text_flat,
383
- kappa=kappa,
384
- aperture_scale=inter_aperture_scale,
385
- )
386
- part_image_to_image = entailment_residual(
387
- specific=image_for_parts,
388
- general=part_image_flat,
389
- kappa=kappa,
390
- aperture_scale=intra_aperture_scale,
391
- )
392
- part_text_to_text = entailment_residual(
393
- specific=text_for_parts,
394
- general=part_text_flat,
395
- kappa=kappa,
396
- aperture_scale=intra_aperture_scale,
397
- )
398
-
399
- return 0.5 * (text_to_image + part_text_to_part_image + part_image_to_image + part_text_to_text)
400
-
401
-
402
- def packed_part_entailment_loss(
403
- image_feats: Tensor,
404
- text_feats: Tensor,
405
- part_image_feats: Tensor,
406
- part_text_feats: Tensor,
407
- part_owner: Tensor,
408
- kappa: Tensor,
409
- inter_aperture_scale: float,
410
- intra_aperture_scale: float,
411
- ) -> Tensor:
412
- text_to_image = entailment_residual(
413
- specific=image_feats,
414
- general=text_feats,
415
- kappa=kappa,
416
- aperture_scale=inter_aperture_scale,
417
- )
418
- if part_image_feats.numel() == 0:
419
- return text_to_image
420
-
421
- image_for_parts = image_feats[part_owner]
422
- text_for_parts = text_feats[part_owner]
423
- part_text_to_part_image = entailment_residual(
424
- specific=part_image_feats,
425
- general=part_text_feats,
426
- kappa=kappa,
427
- aperture_scale=inter_aperture_scale,
428
- )
429
- part_image_to_image = entailment_residual(
430
- specific=image_for_parts,
431
- general=part_image_feats,
432
- kappa=kappa,
433
- aperture_scale=intra_aperture_scale,
434
- )
435
- part_text_to_text = entailment_residual(
436
- specific=text_for_parts,
437
- general=part_text_feats,
438
- kappa=kappa,
439
- aperture_scale=intra_aperture_scale,
440
- )
441
-
442
- return 0.5 * (text_to_image + part_text_to_part_image + part_image_to_image + part_text_to_text)
443
-
444
-
445
def uncha_contrastive_losses(
    image_feats: Tensor,
    text_feats: Tensor,
    part_image_flat: Tensor,
    part_text_flat: Tensor,
    image_for_parts: Tensor,
    text_for_parts: Tensor,
    kappa: Tensor,
    global_logit_scale: Tensor,
    local_logit_scale: Tensor,
    global_local_logit_scale: Tensor,
    image_euc_feats: Tensor | None = None,
    text_euc_feats: Tensor | None = None,
    part_image_euc_flat: Tensor | None = None,
    part_text_euc_flat: Tensor | None = None,
    image_for_parts_euc: Tensor | None = None,
    text_for_parts_euc: Tensor | None = None,
    all_image_feats: Tensor | None = None,
    all_text_feats: Tensor | None = None,
    all_part_image_feats: Tensor | None = None,
    all_part_text_feats: Tensor | None = None,
    all_image_for_parts: Tensor | None = None,
    all_text_for_parts: Tensor | None = None,
    all_image_euc_feats: Tensor | None = None,
    all_text_euc_feats: Tensor | None = None,
    all_part_image_euc_feats: Tensor | None = None,
    all_part_text_euc_feats: Tensor | None = None,
    all_image_for_parts_euc: Tensor | None = None,
    all_text_for_parts_euc: Tensor | None = None,
    global_targets: Tensor | None = None,
    part_targets: Tensor | None = None,
    part_weights: Tensor | None = None,
    product_metric: str = "l1",
    loss_type: str = "ce",
    contrastive_global_weight: float = 1.0,
    contrastive_local_weight: float = 1.0,
    contrastive_global_local_weight: float = 1.0,
    beta_cal_beta: float = 0.0,
    beta_cal_variant: str = "ce",
    beta_cal_weight: float = 0.0,
    part_group_ids: Tensor | None = None,
    all_part_group_ids: Tensor | None = None,
    global_logit_bias: Tensor | None = None,
    local_logit_bias: Tensor | None = None,
    global_local_logit_bias: Tensor | None = None,
    sigmoid_negative_weight: float = 1.0,
    global_local_mode: str = "repeat",
    global_local_metric: str = "distance",
    global_local_angle_aux_weight: float = 0.0,
    global_local_angle_aux_mode: str = "contrastive",
    global_local_angle_aux_scale: float = 5.5,
    global_local_angle_aux_aperture_scale: float = 1.0,
) -> dict[str, Tensor]:
    """Compute whole-sample, part-level and part-to-whole contrastive losses.

    Three logit families are built and each is scored in both directions
    (image->text and text->image) with the loss selected by ``loss_type``:

    * global: ``image_feats`` against ``all_text_feats`` (and vice versa);
    * local: ``part_image_flat`` against ``all_part_text_feats`` (and vice versa);
    * global-local: every part against whole-sample candidates of the other
      modality, scaled per row by ``exp(-0.5 * embedding_uncertainty(part))``
      clamped to ``[0.1, 10]`` (computed on detached uncertainties).

    For every ``loss_type`` except ``"siglip"`` the logits are negated outputs
    of the metric helpers (``metric_pairwise_dist`` / ``metric_pairwise_oxy_angle``);
    for ``"siglip"`` they are dot products of L2-normalised ``*_euc*`` features.
    Each ``*_logit_scale`` is exponentiated and clamped to at most 100 before use.

    Args:
        image_feats, text_feats: Whole-sample embeddings. Without
            ``global_targets`` row ``i`` is matched to candidate ``i``.
        part_image_flat, part_text_flat: Part embeddings, one row per part.
            When ``part_image_flat`` is empty only the global term is computed.
        image_for_parts, text_for_parts: Default whole-sample candidates for the
            ``"repeat"`` global-local mode (presumably the owner embedding
            repeated per part -- confirm against the caller).
        kappa: Forwarded unchanged to the metric helpers.
        global_logit_scale, local_logit_scale, global_local_logit_scale:
            Log-scale temperatures for the three logit families.
        image_euc_feats, text_euc_feats, part_image_euc_flat,
        part_text_euc_flat, image_for_parts_euc, text_for_parts_euc:
            Euclidean counterparts used only when ``loss_type == "siglip"``.
        all_*: Candidate sets for each family; every one falls back to the
            matching local tensor when omitted.
        global_targets, part_targets: Positive indices into the candidate sets;
            default to ``arange`` over the local rows.
        part_weights: Optional per-part weights forwarded to the part losses.
        product_metric: Forwarded to the metric helpers.
        loss_type: ``"ce"``, ``"sigmoid"``, ``"siglip"`` or ``"siglip_metric"``.
        contrastive_global_weight, contrastive_local_weight,
        contrastive_global_local_weight: Mixing weights. A global-local weight
            of exactly 0 skips that term and the angle auxiliary altogether.
        beta_cal_beta, beta_cal_variant, beta_cal_weight: Settings of the
            optional ``beta_cal_loss`` term; active only when both the weight
            and beta are positive.
        part_group_ids, all_part_group_ids: Group ids of the local / candidate
            parts. ``part_group_ids`` doubles as the target vector in
            ``"inbatch"`` mode.
        global_logit_bias, local_logit_bias, global_local_logit_bias: Optional
            additive logit biases, applied for every loss type except ``"ce"``.
        sigmoid_negative_weight: Forwarded to the sigmoid-style losses.
        global_local_mode: ``"repeat"`` scores parts against the
            ``all_*_for_parts`` candidates with ``part_targets``; ``"inbatch"``
            scores them against ``all_*_feats`` with ``part_group_ids``.
        global_local_metric: ``"distance"`` or ``"angle"`` for the non-siglip
            global-local logits.
        global_local_angle_aux_weight, global_local_angle_aux_mode,
        global_local_angle_aux_scale, global_local_angle_aux_aperture_scale:
            Settings of the optional angle-based auxiliary term.

    Returns:
        Dict with ``contrastive_loss`` (the weighted sum) and its components
        ``global_contrastive_loss``, ``local_contrastive_loss``,
        ``global_local_contrastive_loss``, ``global_local_angle_aux_loss`` and
        ``beta_cal_loss``.

    Raises:
        ValueError: On an unsupported option value, a non-positive auxiliary
            scale, or when inputs required by the selected mode are missing.
    """
    if loss_type not in {"ce", "sigmoid", "siglip", "siglip_metric"}:
        raise ValueError(
            f"Unsupported contrastive loss {loss_type!r}; expected 'ce', 'sigmoid', 'siglip', or 'siglip_metric'"
        )
    if global_local_mode not in {"repeat", "inbatch"}:
        raise ValueError("global_local_mode must be 'repeat' or 'inbatch'")
    if global_local_metric not in {"distance", "angle"}:
        raise ValueError("global_local_metric must be 'distance' or 'angle'")
    if global_local_angle_aux_mode not in {"contrastive", "positive_hinge"}:
        raise ValueError("global_local_angle_aux_mode must be 'contrastive' or 'positive_hinge'")
    if global_local_angle_aux_weight < 0.0:
        raise ValueError("global_local_angle_aux_weight must be non-negative")
    if global_local_angle_aux_scale <= 0.0:
        raise ValueError("global_local_angle_aux_scale must be positive")
    if global_local_angle_aux_aperture_scale <= 0.0:
        raise ValueError("global_local_angle_aux_aperture_scale must be positive")

    # Candidate sets default to the local tensors.
    all_image_feats = image_feats if all_image_feats is None else all_image_feats
    all_text_feats = text_feats if all_text_feats is None else all_text_feats
    all_part_image_feats = part_image_flat if all_part_image_feats is None else all_part_image_feats
    all_part_text_feats = part_text_flat if all_part_text_feats is None else all_part_text_feats
    all_image_for_parts = image_for_parts if all_image_for_parts is None else all_image_for_parts
    all_text_for_parts = text_for_parts if all_text_for_parts is None else all_text_for_parts
    # Same fallback for the Euclidean parent features. Without it the local
    # image_for_parts_euc / text_for_parts_euc arguments were never read and the
    # "siglip" + "repeat" path raised even though they had been supplied.
    all_image_for_parts_euc = image_for_parts_euc if all_image_for_parts_euc is None else all_image_for_parts_euc
    all_text_for_parts_euc = text_for_parts_euc if all_text_for_parts_euc is None else all_text_for_parts_euc
    if global_targets is None:
        global_targets = torch.arange(image_feats.size(0), device=image_feats.device)
    if part_targets is None:
        part_targets = torch.arange(part_image_flat.size(0), device=part_image_flat.device)

    global_scale = global_logit_scale.exp().clamp(max=100.0)
    local_scale = local_logit_scale.exp().clamp(max=100.0)
    global_local_scale = global_local_logit_scale.exp().clamp(max=100.0)

    # ---- global (whole image <-> whole text) ----
    if loss_type == "siglip":
        if image_euc_feats is None or text_euc_feats is None:
            raise ValueError("siglip contrastive requires image_euc_feats and text_euc_feats")
        if image_feats.dim() != 2 or text_feats.dim() != 2:
            raise ValueError("siglip contrastive is only supported for non-product features")
        all_image_euc_feats = image_euc_feats if all_image_euc_feats is None else all_image_euc_feats
        all_text_euc_feats = text_euc_feats if all_text_euc_feats is None else all_text_euc_feats
        zimg = F.normalize(image_euc_feats.float(), dim=-1)
        ztxt = F.normalize(text_euc_feats.float(), dim=-1)
        zimg_all = F.normalize(all_image_euc_feats.float(), dim=-1)
        ztxt_all = F.normalize(all_text_euc_feats.float(), dim=-1)
        image_logits = (zimg @ ztxt_all.T) * global_scale
        text_logits = (ztxt @ zimg_all.T) * global_scale
    else:
        image_logits = -metric_pairwise_dist(image_feats, all_text_feats, kappa, product_metric=product_metric) * global_scale
        text_logits = -metric_pairwise_dist(text_feats, all_image_feats, kappa, product_metric=product_metric) * global_scale

    if loss_type in {"sigmoid", "siglip", "siglip_metric"}:
        bias = image_logits.new_zeros(()) if global_logit_bias is None else global_logit_bias.to(image_logits.device)
        image_logits = image_logits + bias
        text_logits = text_logits + bias
    global_contrastive = 0.5 * (
        _contrastive_loss(image_logits, global_targets, None, loss_type, sigmoid_negative_weight)
        + _contrastive_loss(text_logits, global_targets, None, loss_type, sigmoid_negative_weight)
    )

    # No parts in this batch: every part-dependent term is reported as zero.
    if part_image_flat.numel() == 0:
        zero = image_feats.new_zeros(())
        contrastive = contrastive_global_weight * global_contrastive
        return {
            "contrastive_loss": contrastive,
            "global_contrastive_loss": global_contrastive,
            "local_contrastive_loss": zero,
            "global_local_contrastive_loss": zero,
            "global_local_angle_aux_loss": zero,
            "beta_cal_loss": zero,
        }

    # ---- local (part image <-> part text) ----
    if loss_type == "siglip":
        if part_image_euc_flat is None or part_text_euc_flat is None:
            raise ValueError("siglip contrastive requires part_image_euc_flat and part_text_euc_flat when parts exist")
        all_part_image_euc_feats = part_image_euc_flat if all_part_image_euc_feats is None else all_part_image_euc_feats
        all_part_text_euc_feats = part_text_euc_flat if all_part_text_euc_feats is None else all_part_text_euc_feats
        zpi = F.normalize(part_image_euc_flat.float(), dim=-1)
        zpt = F.normalize(part_text_euc_flat.float(), dim=-1)
        zpi_all = F.normalize(all_part_image_euc_feats.float(), dim=-1)
        zpt_all = F.normalize(all_part_text_euc_feats.float(), dim=-1)
        part_image_logits = (zpi @ zpt_all.T) * local_scale
        part_text_logits = (zpt @ zpi_all.T) * local_scale
    else:
        part_image_logits = -metric_pairwise_dist(part_image_flat, all_part_text_feats, kappa, product_metric=product_metric) * local_scale
        part_text_logits = -metric_pairwise_dist(part_text_flat, all_part_image_feats, kappa, product_metric=product_metric) * local_scale

    if loss_type in {"sigmoid", "siglip", "siglip_metric"}:
        bias = part_image_logits.new_zeros(()) if local_logit_bias is None else local_logit_bias.to(part_image_logits.device)
        part_image_logits = part_image_logits + bias
        part_text_logits = part_text_logits + bias
    local_contrastive = 0.5 * (
        _contrastive_loss(part_image_logits, part_targets, part_weights, loss_type, sigmoid_negative_weight)
        + _contrastive_loss(part_text_logits, part_targets, part_weights, loss_type, sigmoid_negative_weight)
    )

    # ---- global-local (part <-> whole of the other modality) ----
    global_local_contrastive = image_feats.new_zeros(())
    global_local_angle_aux = image_feats.new_zeros(())
    if contrastive_global_local_weight != 0.0:
        if global_local_mode == "inbatch":
            if part_group_ids is None:
                raise ValueError("inbatch global-local contrastive requires part_group_ids to be provided")
            global_local_targets = part_group_ids
            all_text_for_global_local = all_text_feats
            all_image_for_global_local = all_image_feats
            all_text_for_global_local_euc = all_text_euc_feats
            all_image_for_global_local_euc = all_image_euc_feats
        else:
            global_local_targets = part_targets
            all_text_for_global_local = all_text_for_parts
            all_image_for_global_local = all_image_for_parts
            all_text_for_global_local_euc = all_text_for_parts_euc
            all_image_for_global_local_euc = all_image_for_parts_euc

        # Per-part temperature from the detached uncertainty, so it rescales
        # the logits without receiving gradient itself.
        image_uncertainty = embedding_uncertainty(part_image_flat).detach()
        text_uncertainty = embedding_uncertainty(part_text_flat).detach()
        image_temp = torch.exp(-0.5 * image_uncertainty).clamp(min=0.1, max=10.0)
        text_temp = torch.exp(-0.5 * text_uncertainty).clamp(min=0.1, max=10.0)

        if loss_type == "siglip":
            if part_image_euc_flat is None or part_text_euc_flat is None:
                raise ValueError("siglip global-local contrastive requires part_image_euc_flat/part_text_euc_flat")
            if all_text_for_global_local_euc is None or all_image_for_global_local_euc is None:
                raise ValueError("siglip global-local contrastive requires all_image_euc_feats/all_text_euc_feats")
            zpi = F.normalize(part_image_euc_flat.float(), dim=-1)
            zpt = F.normalize(part_text_euc_flat.float(), dim=-1)
            zimg_all = F.normalize(all_image_for_global_local_euc.float(), dim=-1)
            ztxt_all = F.normalize(all_text_for_global_local_euc.float(), dim=-1)
            part_image_to_whole_text = (zpi @ ztxt_all.T) * image_temp[:, None] * global_local_scale
            part_text_to_whole_image = (zpt @ zimg_all.T) * text_temp[:, None] * global_local_scale
        else:
            if global_local_metric == "angle":
                part_image_to_whole_text = -metric_pairwise_oxy_angle(
                    part_image_flat,
                    all_text_for_global_local,
                    kappa,
                    product_metric=product_metric,
                )
                part_text_to_whole_image = -metric_pairwise_oxy_angle(
                    part_text_flat,
                    all_image_for_global_local,
                    kappa,
                    product_metric=product_metric,
                )
            else:
                part_image_to_whole_text = -metric_pairwise_dist(
                    part_image_flat, all_text_for_global_local, kappa, product_metric=product_metric
                )
                part_text_to_whole_image = -metric_pairwise_dist(
                    part_text_flat, all_image_for_global_local, kappa, product_metric=product_metric
                )
            part_image_to_whole_text = part_image_to_whole_text * image_temp[:, None] * global_local_scale
            part_text_to_whole_image = part_text_to_whole_image * text_temp[:, None] * global_local_scale

        if loss_type in {"sigmoid", "siglip", "siglip_metric"}:
            bias = (
                part_image_to_whole_text.new_zeros(())
                if global_local_logit_bias is None
                else global_local_logit_bias.to(part_image_to_whole_text.device)
            )
            part_image_to_whole_text = part_image_to_whole_text + bias
            part_text_to_whole_image = part_text_to_whole_image + bias

        global_local_contrastive = 0.5 * (
            _contrastive_loss(part_image_to_whole_text, global_local_targets, part_weights, loss_type, sigmoid_negative_weight)
            + _contrastive_loss(part_text_to_whole_image, global_local_targets, part_weights, loss_type, sigmoid_negative_weight)
        )

        # Optional angle auxiliary. It reuses the candidates, targets and
        # temperatures prepared above, so it only exists inside this branch.
        if global_local_angle_aux_weight > 0.0:
            if global_local_angle_aux_mode == "positive_hinge":
                # Penalise only the positive (part, whole) pairs.
                positive_text = all_text_for_global_local.index_select(0, global_local_targets)
                positive_image = all_image_for_global_local.index_select(0, global_local_targets)
                global_local_angle_aux = 0.5 * (
                    weighted_entailment_residual(
                        specific=part_image_flat,
                        general=positive_text,
                        kappa=kappa,
                        aperture_scale=global_local_angle_aux_aperture_scale,
                        weights=part_weights,
                    )
                    + weighted_entailment_residual(
                        specific=part_text_flat,
                        general=positive_image,
                        kappa=kappa,
                        aperture_scale=global_local_angle_aux_aperture_scale,
                        weights=part_weights,
                    )
                )
            elif loss_type != "siglip":
                # "contrastive" mode: a second contrastive loss over angle
                # logits with a fixed scale; skipped for "siglip".
                angle_scale = part_image_flat.new_tensor(float(global_local_angle_aux_scale))
                part_image_to_whole_text_angle = -metric_pairwise_oxy_angle(
                    part_image_flat,
                    all_text_for_global_local,
                    kappa,
                    product_metric=product_metric,
                ) * image_temp[:, None] * angle_scale
                part_text_to_whole_image_angle = -metric_pairwise_oxy_angle(
                    part_text_flat,
                    all_image_for_global_local,
                    kappa,
                    product_metric=product_metric,
                ) * text_temp[:, None] * angle_scale
                if loss_type in {"sigmoid", "siglip_metric"}:
                    bias = (
                        part_image_to_whole_text_angle.new_zeros(())
                        if global_local_logit_bias is None
                        else global_local_logit_bias.to(part_image_to_whole_text_angle.device)
                    )
                    part_image_to_whole_text_angle = part_image_to_whole_text_angle + bias
                    part_text_to_whole_image_angle = part_text_to_whole_image_angle + bias
                global_local_angle_aux = 0.5 * (
                    _contrastive_loss(
                        part_image_to_whole_text_angle,
                        global_local_targets,
                        part_weights,
                        loss_type,
                        sigmoid_negative_weight,
                    )
                    + _contrastive_loss(
                        part_text_to_whole_image_angle,
                        global_local_targets,
                        part_weights,
                        loss_type,
                        sigmoid_negative_weight,
                    )
                )

    # ---- optional beta calibration on the local logits ----
    beta_cal = image_feats.new_zeros(())
    if beta_cal_weight > 0.0 and beta_cal_beta > 0.0:
        if part_group_ids is None or all_part_group_ids is None:
            raise ValueError("beta_cal requires part_group_ids and all_part_group_ids to be provided")
        beta_cal = 0.5 * (
            beta_cal_loss(
                part_image_logits,
                targets=part_targets,
                group_ids=part_group_ids,
                all_group_ids=all_part_group_ids,
                beta=beta_cal_beta,
                variant=beta_cal_variant,
                weights=part_weights,
            )
            + beta_cal_loss(
                part_text_logits,
                targets=part_targets,
                group_ids=part_group_ids,
                all_group_ids=all_part_group_ids,
                beta=beta_cal_beta,
                variant=beta_cal_variant,
                weights=part_weights,
            )
        )

    contrastive = (
        contrastive_global_weight * global_contrastive
        + contrastive_local_weight * local_contrastive
        + contrastive_global_local_weight * global_local_contrastive
        + global_local_angle_aux_weight * global_local_angle_aux
        + beta_cal_weight * beta_cal
    )
    return {
        "contrastive_loss": contrastive,
        "global_contrastive_loss": global_contrastive,
        "local_contrastive_loss": local_contrastive,
        "global_local_contrastive_loss": global_local_contrastive,
        "global_local_angle_aux_loss": global_local_angle_aux,
        "beta_cal_loss": beta_cal,
    }
762
-
763
-
764
- def _contrastive_loss(
765
- logits: Tensor,
766
- targets: Tensor,
767
- weights: Tensor | None,
768
- loss_type: str,
769
- sigmoid_negative_weight: float,
770
- ) -> Tensor:
771
- if loss_type == "ce":
772
- return contrastive_ce(logits, targets, weights)
773
- if loss_type == "sigmoid":
774
- return contrastive_sigmoid(logits, targets, weights, negative_weight=sigmoid_negative_weight)
775
- if loss_type in {"siglip", "siglip_metric"}:
776
- return contrastive_siglip(logits, targets, weights, negative_weight=sigmoid_negative_weight)
777
- raise ValueError(f"Unsupported contrastive loss {loss_type!r}")
778
-
779
-
780
def uncha_entailment_losses(
    image_feats: Tensor,
    text_feats: Tensor,
    part_image_flat: Tensor,
    part_text_flat: Tensor,
    image_for_parts: Tensor,
    text_for_parts: Tensor,
    kappa: Tensor,
    inter_aperture_scale: float,
    intra_aperture_scale: float,
    piecewise_factor: float = 0.1,
    calibration_alpha: float = 10.0,
    stop_grad_calibration: bool = True,
    geometry: str = "lorentz",
    part_weights: Tensor | None = None,
) -> dict[str, Tensor]:
    """Compute piecewise entailment losses for whole samples and their parts.

    Four residuals come from ``piecewise_entailment_residual``:

    * image (specific) vs. text (general), with ``inter_aperture_scale``;
    * part image vs. part text, with ``inter_aperture_scale``;
    * whole image vs. its part image, with ``intra_aperture_scale``;
    * whole text vs. its part text, with ``intra_aperture_scale``.

    The two cross (whole vs. part) residuals additionally pass through
    ``uncertainty_calibrated_entailment_loss`` together with
    ``embedding_uncertainty`` of the part, which yields an entailment value and
    a calibration value each.

    Returns:
        Dict with the combined ``entailment_loss`` and its components. When
        ``part_image_flat`` is empty, every part-dependent entry is a zero
        scalar and ``entailment_loss`` equals the text/image term.
    """

    def _residual(specific: Tensor, general: Tensor, aperture_scale: float) -> Tensor:
        # All four residuals share curvature, piecewise factor and geometry.
        return piecewise_entailment_residual(
            specific=specific,
            general=general,
            kappa=kappa,
            aperture_scale=aperture_scale,
            factor=piecewise_factor,
            geometry=geometry,
        )

    def _calibrated(residual: Tensor, part_embeddings: Tensor) -> tuple[Tensor, Tensor]:
        # Returns the (entailment, calibration) pair for one cross residual.
        return uncertainty_calibrated_entailment_loss(
            residual,
            embedding_uncertainty(part_embeddings),
            alpha=calibration_alpha,
            stop_grad=stop_grad_calibration,
            weights=part_weights,
        )

    text_image_entailment = 0.5 * _residual(image_feats, text_feats, inter_aperture_scale).mean()

    if part_image_flat.numel() == 0:
        zero = image_feats.new_zeros(())
        losses = {
            "entailment_loss": text_image_entailment,
            "text_image_entailment_loss": text_image_entailment,
        }
        losses.update(
            dict.fromkeys(
                (
                    "part_text_image_entailment_loss",
                    "cross_image_entailment_loss",
                    "cross_text_entailment_loss",
                    "cross_image_calibration_loss",
                    "cross_text_calibration_loss",
                ),
                zero,
            )
        )
        return losses

    part_text_image_residual = _residual(part_image_flat, part_text_flat, inter_aperture_scale)
    cross_image_residual = _residual(image_for_parts, part_image_flat, intra_aperture_scale)
    cross_text_residual = _residual(text_for_parts, part_text_flat, intra_aperture_scale)

    part_text_image_entailment = 0.5 * weighted_mean(part_text_image_residual, part_weights)
    cross_image_entailment, cross_image_calibration = _calibrated(cross_image_residual, part_image_flat)
    cross_text_entailment, cross_text_calibration = _calibrated(cross_text_residual, part_text_flat)

    total = (
        text_image_entailment
        + part_text_image_entailment
        + 0.5 * (cross_image_entailment + cross_text_entailment)
        + cross_image_calibration
        + cross_text_calibration
    )
    return {
        "entailment_loss": total,
        "text_image_entailment_loss": text_image_entailment,
        "part_text_image_entailment_loss": part_text_image_entailment,
        "cross_image_entailment_loss": cross_image_entailment,
        "cross_text_entailment_loss": cross_text_entailment,
        "cross_image_calibration_loss": cross_image_calibration,
        "cross_text_calibration_loss": cross_text_calibration,
    }
875
-
876
-
877
def uncha_argent_entailment_losses(
    image_feats: Tensor,
    text_feats: Tensor,
    part_image_flat: Tensor,
    part_text_flat: Tensor,
    image_for_parts: Tensor,
    text_for_parts: Tensor,
    kappa: Tensor,
    beta: float = 1.0,
    part_weights: Tensor | None = None,
    product_metric: str = "l1",
    aggregation: str = "uncha",
) -> dict[str, Tensor]:
    """Compute entailment losses from ``argent_adaptive_entailment_residual``.

    The image/text and part-image/part-text residuals are computed with
    ``adaptive_weight=False``; the two cross residuals (whole vs. its part, per
    modality) use ``adaptive_weight=True``. ``aggregation="equal"`` sums all
    four terms as they are; ``"uncha"`` halves the sum of the two cross terms.

    Returns:
        Dict with ``entailment_loss``, its components, zero-valued
        ``cross_*_calibration_loss`` entries and ``norm_regularization_loss``.
        When parts are present the output of ``argent_entailment_diagnostics``
        is merged in as well; the empty-parts return omits those entries.

    Raises:
        ValueError: If ``aggregation`` is neither ``"uncha"`` nor ``"equal"``.
    """
    if aggregation not in {"uncha", "equal"}:
        raise ValueError("aggregation must be 'uncha' or 'equal'")

    def _residual(specific: Tensor, general: Tensor, adaptive: bool) -> Tensor:
        return argent_adaptive_entailment_residual(
            specific=specific,
            general=general,
            kappa=kappa,
            adaptive_weight=adaptive,
            beta=beta,
            product_metric=product_metric,
        )

    text_image_entailment = 0.5 * _residual(image_feats, text_feats, False).mean()

    if part_image_flat.numel() == 0:
        zero = image_feats.new_zeros(())
        # Only the whole-sample embeddings exist to regularise here.
        norm_regularization = argent_norm_regularization_loss(image_feats, text_feats)
        return {
            "entailment_loss": text_image_entailment,
            "text_image_entailment_loss": text_image_entailment,
            "part_text_image_entailment_loss": zero,
            "cross_image_entailment_loss": zero,
            "cross_text_entailment_loss": zero,
            "cross_image_calibration_loss": zero,
            "cross_text_calibration_loss": zero,
            "norm_regularization_loss": norm_regularization,
        }

    part_text_image_residual = _residual(part_image_flat, part_text_flat, False)
    cross_image_residual = _residual(image_for_parts, part_image_flat, True)
    cross_text_residual = _residual(text_for_parts, part_text_flat, True)

    part_text_image_entailment = 0.5 * weighted_mean(part_text_image_residual, part_weights)
    cross_image_entailment = 0.5 * weighted_mean(cross_image_residual, part_weights)
    cross_text_entailment = 0.5 * weighted_mean(cross_text_residual, part_weights)
    norm_regularization = argent_norm_regularization_loss(image_feats, text_feats, part_image_flat, part_text_flat)

    if aggregation == "equal":
        total = text_image_entailment + part_text_image_entailment + cross_image_entailment + cross_text_entailment
    else:
        total = text_image_entailment + part_text_image_entailment + 0.5 * (
            cross_image_entailment + cross_text_entailment
        )

    diagnostics = argent_entailment_diagnostics(
        image_feats=image_feats,
        text_feats=text_feats,
        part_image_flat=part_image_flat,
        part_text_flat=part_text_flat,
        image_for_parts=image_for_parts,
        text_for_parts=text_for_parts,
        kappa=kappa,
        product_metric=product_metric,
    )

    losses = {
        "entailment_loss": total,
        "text_image_entailment_loss": text_image_entailment,
        "part_text_image_entailment_loss": part_text_image_entailment,
        "cross_image_entailment_loss": cross_image_entailment,
        "cross_text_entailment_loss": cross_text_entailment,
        # This variant has no calibration term; the keys are kept as zeros.
        "cross_image_calibration_loss": image_feats.new_zeros(()),
        "cross_text_calibration_loss": image_feats.new_zeros(()),
        "norm_regularization_loss": norm_regularization,
    }
    losses.update(diagnostics)
    return losses
973
-
974
-
975
def hierarchical_beta_argent_entailment_losses(
    image_feats: Tensor,
    text_feats: Tensor,
    part_image_flat: Tensor,
    part_text_flat: Tensor,
    image_for_parts: Tensor,
    text_for_parts: Tensor,
    beta_query_image_feats: Tensor,
    beta_query_text_feats: Tensor,
    beta_query_owner: Tensor,
    beta_query_parent: Tensor,
    beta_query_weight: Tensor,
    kappa: Tensor,
    beta_query_source_part: Tensor | None = None,
    beta: float = 1.0,
    part_weights: Tensor | None = None,
    product_metric: str = "l1",
    aggregation: str = "uncha",
) -> dict[str, Tensor]:
    """Extend ``uncha_argent_entailment_losses`` with per-query entailment terms.

    On top of the base losses, four query-level terms are added:

    * query image (specific) vs. query text (general), non-adaptive;
    * owning whole image vs. query image, adaptive;
    * parent query text vs. child query text, adaptive, for queries whose
      ``beta_query_parent`` is a valid row of ``beta_query_text_feats``;
    * source part image/text vs. query image/text, adaptive, for queries whose
      ``beta_query_source_part`` is a valid row of ``part_image_flat``.

    Args:
        beta_query_owner: Row of ``image_feats`` that owns each query.
        beta_query_parent: Row of ``beta_query_text_feats`` acting as each
            query's parent; out-of-range values mean "no parent".
        beta_query_weight: Per-query weight; negatives are clamped to 0, the
            rest is rescaled to mean 1, and zero-weight queries are dropped from
            the parent and source-part terms.
        beta_query_source_part: Optional row of ``part_image_flat`` each query
            came from; out-of-range values mean "no source part".
        Remaining arguments are forwarded to ``uncha_argent_entailment_losses``.

    Returns:
        The base dict with ``entailment_loss`` and ``norm_regularization_loss``
        overridden (the latter only when queries exist) plus the
        ``hier_beta_*`` losses and counts.

    Raises:
        ValueError: If ``beta_query_weight`` or ``beta_query_source_part`` does
            not hold one value per query.
    """

    def _residual(specific: Tensor, general: Tensor, adaptive: bool) -> Tensor:
        return argent_adaptive_entailment_residual(
            specific=specific,
            general=general,
            kappa=kappa,
            adaptive_weight=adaptive,
            beta=beta,
            product_metric=product_metric,
        )

    base = uncha_argent_entailment_losses(
        image_feats=image_feats,
        text_feats=text_feats,
        part_image_flat=part_image_flat,
        part_text_flat=part_text_flat,
        image_for_parts=image_for_parts,
        text_for_parts=text_for_parts,
        kappa=kappa,
        beta=beta,
        part_weights=part_weights,
        product_metric=product_metric,
        aggregation=aggregation,
    )
    if beta_query_image_feats.numel() == 0:
        # No queries: keep the base losses and report zeros for the extras.
        return {
            **base,
            "hier_beta_query_text_entailment_loss": image_feats.new_zeros(()),
            "hier_beta_visual_entailment_loss": image_feats.new_zeros(()),
            "hier_beta_text_entailment_loss": image_feats.new_zeros(()),
            "hier_beta_sourcepart_visual_entailment_loss": image_feats.new_zeros(()),
            "hier_beta_sourcepart_text_entailment_loss": image_feats.new_zeros(()),
            "hier_beta_query_count": beta_query_owner.new_tensor(0),
            "hier_beta_sourcepart_query_count": beta_query_owner.new_tensor(0),
        }

    device = image_feats.device
    num_queries = beta_query_image_feats.size(0)

    query_owner = beta_query_owner.to(device=device, dtype=torch.long)
    query_weights = beta_query_weight.to(device=device, dtype=torch.float32).clamp_min(0.0)
    if query_weights.numel() != num_queries:
        raise ValueError("beta_query_weight must have one value per beta query")
    # Rescale to mean 1; the eps floor guards the all-zero case.
    query_weights = query_weights / query_weights.mean().clamp_min(torch.finfo(query_weights.dtype).eps)
    has_weight = query_weights > 0.0

    query_text_residual = _residual(beta_query_image_feats, beta_query_text_feats, False)
    visual_residual = _residual(image_feats.index_select(0, query_owner), beta_query_image_feats, True)
    query_text_entailment = 0.5 * weighted_mean(query_text_residual, query_weights)
    visual_entailment = 0.5 * weighted_mean(visual_residual, query_weights)

    # Text hierarchy between each query and its parent query.
    parent = beta_query_parent.to(device=device, dtype=torch.long)
    parent_mask = (parent >= 0) & (parent < beta_query_text_feats.size(0)) & has_weight
    if bool(parent_mask.any()):
        child_text = beta_query_text_feats[parent_mask]
        parent_text = beta_query_text_feats[parent[parent_mask]]
        text_residual = _residual(parent_text, child_text, True)
        text_entailment = 0.5 * weighted_mean(text_residual, query_weights[parent_mask])
    else:
        text_entailment = image_feats.new_zeros(())

    # Optional tie between each query and the part it was derived from.
    sourcepart_visual_entailment = image_feats.new_zeros(())
    sourcepart_text_entailment = image_feats.new_zeros(())
    sourcepart_query_count = beta_query_owner.new_tensor(0)
    if beta_query_source_part is not None and part_image_flat.numel() > 0:
        source_part = beta_query_source_part.to(device=device, dtype=torch.long)
        if source_part.numel() != num_queries:
            raise ValueError("beta_query_source_part must have one value per beta query")
        source_mask = (
            (source_part >= 0)
            & (source_part < part_image_flat.size(0))
            & has_weight
        )
        if bool(source_mask.any()):
            source_indices = source_part[source_mask]
            source_visual_residual = _residual(
                part_image_flat.index_select(0, source_indices),
                beta_query_image_feats[source_mask],
                True,
            )
            source_text_residual = _residual(
                part_text_flat.index_select(0, source_indices),
                beta_query_text_feats[source_mask],
                True,
            )
            source_weights = query_weights[source_mask]
            sourcepart_visual_entailment = 0.5 * weighted_mean(source_visual_residual, source_weights)
            sourcepart_text_entailment = 0.5 * weighted_mean(source_text_residual, source_weights)
            sourcepart_query_count = beta_query_owner.new_tensor(int(source_mask.sum().item()))

    # Recomputed so that the query embeddings are regularised as well.
    norm_regularization = argent_norm_regularization_loss(
        image_feats,
        text_feats,
        part_image_flat,
        part_text_flat,
        beta_query_image_feats,
        beta_query_text_feats,
    )
    sourcepart_entailment = 0.5 * (sourcepart_visual_entailment + sourcepart_text_entailment)
    query_entailment = query_text_entailment + 0.5 * (visual_entailment + text_entailment) + sourcepart_entailment
    return {
        **base,
        "entailment_loss": base["entailment_loss"] + query_entailment,
        "norm_regularization_loss": norm_regularization,
        "hier_beta_query_text_entailment_loss": query_text_entailment,
        "hier_beta_visual_entailment_loss": visual_entailment,
        "hier_beta_text_entailment_loss": text_entailment,
        "hier_beta_sourcepart_visual_entailment_loss": sourcepart_visual_entailment,
        "hier_beta_sourcepart_text_entailment_loss": sourcepart_text_entailment,
        "hier_beta_query_count": beta_query_owner.new_tensor(beta_query_owner.numel()),
        "hier_beta_sourcepart_query_count": sourcepart_query_count,
    }
1118
-
1119
-
1120
def argent_entailment_diagnostics(
    image_feats: Tensor,
    text_feats: Tensor,
    part_image_flat: Tensor,
    part_text_flat: Tensor,
    image_for_parts: Tensor,
    text_for_parts: Tensor,
    kappa: Tensor,
    product_metric: str = "l1",
) -> dict[str, Tensor]:
    """Summarise angles, distances and norms of the embeddings as scalars.

    Every value is the mean of a detached quantity, so nothing returned here
    carries gradient. Whenever the first argument of a statistic is empty, a
    shared zero scalar is returned in its place.

    Returns:
        Dict of ``argent_*_mean`` scalars: angle and "pent" score
        (``1 - 2 * angle / pi`` clamped to ``[0, 1]``) for the four
        (specific, general) pairs, distance and adaptive weight
        (``1 - exp(-distance)``) for the two cross pairs, and the norm of
        ``_space_components`` for each of the four embedding sets.
    """
    zero = image_feats.new_zeros(())

    def _angles(specific: Tensor, general: Tensor) -> Tensor:
        # 2-D outputs are averaged over their last axis.
        raw = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
        return raw.mean(dim=-1) if raw.dim() == 2 else raw

    def _distance(specific: Tensor, general: Tensor) -> Tensor:
        return lorentz_dist(specific, general, kappa, product_metric=product_metric)

    def angle_mean(specific: Tensor, general: Tensor) -> Tensor:
        if specific.numel() == 0:
            return zero
        return _angles(specific, general).detach().mean()

    def pent_mean(specific: Tensor, general: Tensor) -> Tensor:
        if specific.numel() == 0:
            return zero
        angles = _angles(specific, general)
        scores = torch.clamp(1.0 - (2.0 * angles / math.pi), min=0.0, max=1.0)
        return scores.detach().mean()

    def distance_mean(specific: Tensor, general: Tensor) -> Tensor:
        if specific.numel() == 0:
            return zero
        return _distance(specific, general).detach().mean()

    def adaptive_weight_mean(specific: Tensor, general: Tensor) -> Tensor:
        if specific.numel() == 0:
            return zero
        return (1.0 - torch.exp(-_distance(specific, general))).detach().mean()

    def space_norm_mean(embedding: Tensor) -> Tensor:
        if embedding.numel() == 0:
            return zero
        components = _space_components(embedding).float()
        return torch.linalg.norm(components, dim=-1).detach().mean()

    return {
        "argent_text_image_angle_mean": angle_mean(image_feats, text_feats),
        "argent_text_image_pent_mean": pent_mean(image_feats, text_feats),
        "argent_part_text_image_angle_mean": angle_mean(part_image_flat, part_text_flat),
        "argent_part_text_image_pent_mean": pent_mean(part_image_flat, part_text_flat),
        "argent_cross_image_angle_mean": angle_mean(image_for_parts, part_image_flat),
        "argent_cross_image_pent_mean": pent_mean(image_for_parts, part_image_flat),
        "argent_cross_image_distance_mean": distance_mean(image_for_parts, part_image_flat),
        "argent_cross_image_adaptive_weight_mean": adaptive_weight_mean(image_for_parts, part_image_flat),
        "argent_cross_text_angle_mean": angle_mean(text_for_parts, part_text_flat),
        "argent_cross_text_pent_mean": pent_mean(text_for_parts, part_text_flat),
        "argent_cross_text_distance_mean": distance_mean(text_for_parts, part_text_flat),
        "argent_cross_text_adaptive_weight_mean": adaptive_weight_mean(text_for_parts, part_text_flat),
        "argent_image_space_norm_mean": space_norm_mean(image_feats),
        "argent_text_space_norm_mean": space_norm_mean(text_feats),
        "argent_part_image_space_norm_mean": space_norm_mean(part_image_flat),
        "argent_part_text_space_norm_mean": space_norm_mean(part_text_flat),
    }
1183
-
1184
-
1185
- def part_quality_weights(
1186
- image_for_parts: Tensor,
1187
- text_for_parts: Tensor,
1188
- part_image_flat: Tensor,
1189
- part_text_flat: Tensor,
1190
- part_owner: Tensor,
1191
- batch_size: int,
1192
- kappa: Tensor,
1193
- mode: str,
1194
- topk: int = 5,
1195
- temperature: float = 4.0,
1196
- product_metric: str = "l1",
1197
- ) -> tuple[Tensor | None, Tensor, Tensor]:
1198
- if mode not in {"none", "soft", "topk"}:
1199
- raise ValueError(f"Unsupported part quality mode {mode!r}; expected 'none', 'soft', or 'topk'")
1200
- if mode == "none" or part_image_flat.numel() == 0:
1201
- empty = part_image_flat.new_zeros((part_image_flat.size(0),))
1202
- return None, empty, empty
1203
-
1204
- with torch.no_grad():
1205
- image_parent = torch.exp(-lorentz_dist(part_image_flat, image_for_parts, kappa, product_metric=product_metric))
1206
- text_parent = torch.exp(-lorentz_dist(part_text_flat, text_for_parts, kappa, product_metric=product_metric))
1207
- image_text = torch.exp(-lorentz_dist(part_image_flat, part_text_flat, kappa, product_metric=product_metric))
1208
- scores = torch.stack([image_parent, text_parent, image_text]).mean(dim=0).clamp_min(0.0)
1209
-
1210
- if mode == "soft":
1211
- weights = _owner_softmax_weights(scores, part_owner, batch_size, temperature)
1212
- else:
1213
- weights = _owner_topk_weights(scores, part_owner, batch_size, topk)
1214
- weights = weights / weights.mean().clamp_min(torch.finfo(weights.dtype).eps)
1215
- return weights, scores, (weights > 0.0).to(dtype=scores.dtype)
1216
-
1217
-
1218
- def _owner_softmax_weights(scores: Tensor, part_owner: Tensor, batch_size: int, temperature: float) -> Tensor:
1219
- weights = torch.zeros_like(scores)
1220
- for owner in range(batch_size):
1221
- mask = part_owner == owner
1222
- if not bool(mask.any()):
1223
- continue
1224
- owner_scores = scores[mask]
1225
- owner_weights = torch.softmax(owner_scores * temperature, dim=0) * owner_scores.numel()
1226
- weights[mask] = owner_weights
1227
- return weights
1228
-
1229
-
1230
- def _owner_topk_weights(scores: Tensor, part_owner: Tensor, batch_size: int, topk: int) -> Tensor:
1231
- if topk <= 0:
1232
- raise ValueError("topk must be positive for top-k part quality weighting")
1233
- weights = torch.zeros_like(scores)
1234
- for owner in range(batch_size):
1235
- indices = torch.nonzero(part_owner == owner, as_tuple=False).flatten()
1236
- if indices.numel() == 0:
1237
- continue
1238
- keep = min(topk, indices.numel())
1239
- selected = indices[scores[indices].topk(k=keep).indices]
1240
- weights[selected] = 1.0
1241
- return weights
1242
-
1243
-
1244
def argent_adaptive_entailment_residual(
    specific: Tensor,
    general: Tensor,
    kappa: Tensor,
    adaptive_weight: bool,
    beta: float = 1.0,
    product_metric: str = "l1",
) -> Tensor:
    """Per-pair Huber penalty on the angle returned by ``factor_oxy_angle``.

    A 2-D angle tensor is first averaged over its last dimension. When
    ``adaptive_weight`` is set, each angle is multiplied by
    ``1 - exp(-distance)`` with the distance taken from ``lorentz_dist``.
    The result is the unreduced Huber loss of the (weighted) angle against
    zero, using ``beta`` as the Huber ``delta``.
    """
    angle = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
    if angle.dim() == 2:
        angle = angle.mean(dim=-1)
    if adaptive_weight:
        distance = lorentz_dist(specific=specific, general=general, kappa=kappa, product_metric=product_metric)
        angle = angle * (1.0 - torch.exp(-distance))
    zero_target = torch.zeros_like(angle)
    return F.huber_loss(angle, zero_target, delta=beta, reduction="none")
1261
-
1262
-
1263
def lorentz_dist(specific: Tensor, general: Tensor, kappa: Tensor, product_metric: str = "l1") -> Tensor:
    """Return ``paired_dist(specific, general, kappa)`` for the given ``product_metric``.

    This is a thin alias: all computation is delegated to ``paired_dist``.
    """
    distance = paired_dist(specific, general, kappa, product_metric=product_metric)
    return distance
1265
-
1266
-
1267
def argent_norm_regularization_loss(*embeddings: Tensor, eps: float = 1e-6) -> Tensor:
    """Average ``norm**2 - log(norm)`` of the space components over all inputs.

    Tensors without elements are ignored. For each remaining tensor the norm
    of its space components (see ``_space_components``) is taken in float32
    over the last dimension and clamped from below by ``eps``. The
    per-tensor means are then averaged.

    Raises:
        ValueError: If every given tensor is empty (or none is given).
    """
    populated = [embedding for embedding in embeddings if embedding.numel() != 0]
    if not populated:
        raise ValueError("argent_norm_regularization_loss requires at least one non-empty embedding tensor")

    per_tensor = []
    for embedding in populated:
        space_part = _space_components(embedding)
        radius = torch.linalg.norm(space_part.float(), dim=-1).clamp_min(eps)
        per_tensor.append((radius.square() - torch.log(radius)).mean())
    return torch.stack(per_tensor).mean()
1278
-
1279
-
1280
def piecewise_entailment_residual(
    specific: Tensor,
    general: Tensor,
    kappa: Tensor,
    aperture_scale: float,
    factor: float = 0.1,
    geometry: str = "lorentz",
) -> Tensor:
    """Piecewise entailment penalty built from an angle and a half-aperture.

    With ``residual = angle - scale * aperture`` the penalty is
    ``residual + factor * angle`` where the residual is positive and
    ``factor * angle`` elsewhere. A 2-D penalty is averaged over its last
    dimension.

    For ``geometry="lorentz"`` the angle and aperture come from
    ``factor_oxy_angle`` / ``factor_half_aperture`` and the aperture is
    multiplied by ``aperture_scale``. For ``geometry="euclidean"`` they come
    from ``euclidean_angle`` / ``euclidean_half_aperture``; there
    ``aperture_scale`` is consumed by the aperture itself, so the multiplier
    is fixed to 1.

    Raises:
        ValueError: If ``geometry`` is neither ``"lorentz"`` nor ``"euclidean"``.
    """
    if geometry == "lorentz":
        angle = factor_oxy_angle(specific=specific, general=general, kappa=kappa)
        aperture = factor_half_aperture(general=general, kappa=kappa)
        multiplier = aperture_scale
    elif geometry == "euclidean":
        angle = euclidean_angle(specific=specific, general=general)
        aperture = euclidean_half_aperture(general=general, aperture_scale=aperture_scale)
        multiplier = 1.0
    else:
        raise ValueError(f"Unsupported entailment geometry {geometry!r}; expected 'lorentz' or 'euclidean'")

    excess = angle - multiplier * aperture
    penalty = torch.where(excess > 0.0, excess + factor * angle, factor * angle)
    if penalty.dim() == 2:
        return penalty.mean(dim=-1)
    return penalty
1300
-
1301
-
1302
def euclidean_angle(specific: Tensor, general: Tensor, eps: float = 1e-6) -> Tensor:
    """Angle between the space components of ``specific`` and ``general``.

    Both inputs are reduced with ``_space_components`` and cast to float32.
    The cosine is the dot product over the product of norms. Denominator and
    cosine are clamped with a margin of ``max(eps, 16 * float32 epsilon)``
    so that ``acos`` never sees exactly +/-1.
    """
    lhs = _space_components(specific).float()
    rhs = _space_components(general).float()
    dot = (lhs * rhs).sum(dim=-1)
    norm_product = torch.linalg.norm(lhs, dim=-1) * torch.linalg.norm(rhs, dim=-1)
    margin = max(eps, 16.0 * torch.finfo(lhs.dtype).eps)
    cosine = dot / norm_product.clamp_min(margin)
    cosine = cosine.clamp(min=-1.0 + margin, max=1.0 - margin)
    return torch.acos(cosine)
1311
-
1312
-
1313
def euclidean_half_aperture(general: Tensor, aperture_scale: float, eps: float = 1e-8) -> Tensor:
    """Return ``atan(aperture_scale / norm)`` of the space components of ``general``.

    The norm is taken in float32 over the last dimension and clamped from
    below by ``eps``.
    """
    space_part = _space_components(general).float()
    radius = torch.linalg.norm(space_part, dim=-1).clamp_min(eps)
    scale = torch.as_tensor(aperture_scale, device=general.device, dtype=general.dtype)
    return torch.atan(scale / radius)
1316
-
1317
-
1318
- def aggregate_part_consistency_loss(
1319
- image_feats: Tensor,
1320
- text_feats: Tensor,
1321
- part_image_flat: Tensor,
1322
- part_text_flat: Tensor,
1323
- part_owner: Tensor,
1324
- part_weights: Tensor | None = None,
1325
- ) -> Tensor:
1326
- if part_image_flat.numel() == 0:
1327
- return image_feats.new_zeros(())
1328
-
1329
- batch_size = image_feats.size(0)
1330
- image_space = _space_components(image_feats).reshape(batch_size, -1).float()
1331
- text_space = _space_components(text_feats).reshape(batch_size, -1).float()
1332
- part_image_space = _space_components(part_image_flat).reshape(part_image_flat.size(0), -1).float()
1333
- part_text_space = _space_components(part_text_flat).reshape(part_text_flat.size(0), -1).float()
1334
- if part_weights is None:
1335
- counts = torch.bincount(part_owner, minlength=batch_size).to(device=image_feats.device, dtype=image_space.dtype)
1336
- denom = counts
1337
- valid = counts > 0
1338
- weights = part_image_space.new_ones((part_image_space.size(0),))
1339
- else:
1340
- weights = part_weights.to(device=image_feats.device, dtype=image_space.dtype).flatten()
1341
- if weights.numel() != part_owner.numel():
1342
- raise ValueError("part_weights must have the same number of elements as part_owner when provided")
1343
- denom = torch.zeros(batch_size, device=image_feats.device, dtype=image_space.dtype)
1344
- denom.index_add_(0, part_owner, weights)
1345
- valid = denom > 0
1346
-
1347
- image_agg = image_space.new_zeros(image_space.shape)
1348
- text_agg = text_space.new_zeros(text_space.shape)
1349
- image_agg.index_add_(0, part_owner, part_image_space * weights[:, None])
1350
- text_agg.index_add_(0, part_owner, part_text_space * weights[:, None])
1351
- image_agg = image_agg[valid] / denom[valid, None].clamp_min(1.0)
1352
- text_agg = text_agg[valid] / denom[valid, None].clamp_min(1.0)
1353
-
1354
- image_space = image_space[valid]
1355
- text_space = text_space[valid]
1356
- return 0.25 * (
1357
- cosine_residual(image_agg, image_space)
1358
- + cosine_residual(text_agg, text_space)
1359
- + cosine_residual(image_agg, text_space)
1360
- + cosine_residual(text_agg, image_space)
1361
- )
1362
-
1363
-
1364
def cosine_residual(x: Tensor, y: Tensor) -> Tensor:
    """Return the mean of ``1 - cosine_similarity(x, y)`` along the last dimension."""
    similarity = F.cosine_similarity(x, y, dim=-1)
    residual = 1.0 - similarity
    return residual.mean()
1366
-
1367
-
1368
def uncertainty_calibrated_entailment_loss(
    entail_residual: Tensor,
    log_uncertainty: Tensor,
    alpha: float = 10.0,
    stop_grad: bool = True,
    weights: Tensor | None = None,
) -> tuple[Tensor, Tensor]:
    """Return the mean entailment loss and an uncertainty calibration loss.

    The mean loss is ``0.5 * entail_residual``. The calibration loss is
    ``alpha * (0.5 * residual / (uncertainty + 1e-6) + 0.5 * log_uncertainty
    + entropy)``, where ``uncertainty = exp(log_uncertainty)`` clamped to
    ``[1e-6, 1e6]`` and ``entropy`` is the entropy of a softmax taken over
    all flattened ``log_uncertainty`` values. With ``stop_grad`` the residual
    inside the calibration term is detached. Both losses are reduced with
    ``weighted_mean`` using ``weights``.
    """
    # Entropy of the softmax over every log-uncertainty in the batch.
    distribution = torch.softmax(log_uncertainty.flatten(), dim=0)
    entropy = -(distribution * torch.log(distribution + 1e-8)).sum()

    if stop_grad:
        # Keep the calibration term from back-propagating into the residual.
        residual_for_calibration = entail_residual.detach()
    else:
        residual_for_calibration = entail_residual
    uncertainty = torch.exp(log_uncertainty).clamp(min=1e-6, max=1e6)
    scaled_residual = residual_for_calibration / (uncertainty + 1e-6)
    calibration_term = 0.5 * scaled_residual + 0.5 * log_uncertainty
    calibration_loss = alpha * (calibration_term + entropy)

    mean_loss = 0.5 * entail_residual
    return weighted_mean(mean_loss, weights), weighted_mean(calibration_loss, weights)
1384
-
1385
-
1386
def embedding_uncertainty(x: Tensor) -> Tensor:
    """Return ``softplus(-norm)`` of the space components of ``x``.

    The float32 norm is taken over the last dimension. If the result still
    has more than one dimension it is averaged over its last dimension first.
    """
    radius = torch.linalg.norm(_space_components(x).float(), dim=-1)
    if radius.dim() > 1:
        radius = radius.mean(dim=-1)
    return F.softplus(-radius)
1392
-
1393
-
1394
- def _space_components(x: Tensor) -> Tensor:
1395
- return x[..., 1:] if x.shape[-1] > 1 else x
1396
-
1397
-
1398
- def _flatten_valid_parts(part_image_feats: Tensor, part_text_feats: Tensor, part_mask: Tensor, targets: Tensor) -> tuple[Tensor, Tensor, Tensor]:
1399
- part_targets = targets[:, None].expand_as(part_mask)[part_mask]
1400
- return part_image_feats[part_mask], part_text_feats[part_mask], part_targets
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/models/objectives.py DELETED
@@ -1,580 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from collections.abc import Mapping
4
-
5
- import torch
6
- from torch import Tensor, nn
7
-
8
- from hyper3_clip.models.lorentz import log_map0, metric_pairwise_dist
9
- from hyper3_clip.models.losses import (
10
- aggregate_part_consistency_loss,
11
- contrastive_ce,
12
- gramian_volume_loss,
13
- hierarchical_beta_argent_entailment_losses,
14
- packed_part_contrastive_loss,
15
- packed_part_entailment_loss,
16
- part_quality_weights,
17
- radius_order_hinge,
18
- uncha_argent_entailment_losses,
19
- uncha_contrastive_losses,
20
- uncha_entailment_losses,
21
- )
22
- from hyper3_clip.training.distributed import gather_variable_many_with_grad, gather_variable_no_grad, get_rank
23
-
24
-
25
class HyCoCLIPObjective(nn.Module):
    """Contrastive plus entailment objective over packed part embeddings.

    The total loss is ``contrastive + entail_weight * entailment``, where the
    two terms come from ``packed_part_contrastive_loss`` and
    ``packed_part_entailment_loss``.
    """

    def __init__(
        self,
        entail_weight: float,
        inter_aperture_scale: float,
        intra_aperture_scale: float,
        product_metric: str = "l1",
    ) -> None:
        """Store the loss weight, the two aperture scales and the product metric."""
        super().__init__()
        self.entail_weight = entail_weight
        self.inter_aperture_scale = inter_aperture_scale
        self.intra_aperture_scale = intra_aperture_scale
        # NOTE: stored for configuration, but not forwarded by ``forward`` below.
        self.product_metric = product_metric

    def forward(self, embeddings: Mapping[str, Tensor], logit_scale: Tensor) -> dict[str, Tensor]:
        """Compute the combined loss and its components.

        Args:
            embeddings: Must provide ``part_owner``, ``image_feats``,
                ``text_feats``, ``part_image_feats``, ``part_text_feats`` and
                ``kappa``. ``all_image_feats``, ``all_text_feats`` and
                ``targets`` are optional and passed through as ``None`` when
                absent.
            logit_scale: Forwarded to the contrastive loss.

        Returns:
            A dict with ``loss``, ``contrastive_loss``, ``entailment_loss``
            and ``part_count`` (the number of entries in ``part_owner``).
        """
        part_owner = embeddings["part_owner"].long()
        part_count = part_owner.new_tensor(part_owner.numel())

        # Arguments common to both loss terms.
        shared = {
            "image_feats": embeddings["image_feats"],
            "text_feats": embeddings["text_feats"],
            "part_image_feats": embeddings["part_image_feats"],
            "part_text_feats": embeddings["part_text_feats"],
            "part_owner": part_owner,
            "kappa": embeddings["kappa"],
        }
        contrastive = packed_part_contrastive_loss(
            **shared,
            logit_scale=logit_scale,
            all_image_feats=embeddings.get("all_image_feats"),
            all_text_feats=embeddings.get("all_text_feats"),
            targets=embeddings.get("targets"),
        )
        entailment = packed_part_entailment_loss(
            **shared,
            inter_aperture_scale=self.inter_aperture_scale,
            intra_aperture_scale=self.intra_aperture_scale,
        )
        return {
            "loss": contrastive + self.entail_weight * entailment,
            "contrastive_loss": contrastive,
            "entailment_loss": entailment,
            "part_count": part_count,
        }
71
-
72
-
73
class UNCHAObjective(nn.Module):
    """Configurable objective combining contrastive, entailment and auxiliary terms.

    ``forward`` sums the contrastive loss with optionally weighted HiMo
    component, entailment, aggregate-consistency, radius-order, Gramian
    alignment and norm-regularisation terms. The constructor validates the
    string-valued options and a few numeric ranges.
    """

    def __init__(
        self,
        entail_weight: float,
        inter_aperture_scale: float,
        intra_aperture_scale: float,
        piecewise_factor: float = 0.1,
        calibration_alpha: float = 10.0,
        stop_grad_calibration: bool = True,
        entailment_geometry: str = "lorentz",
        aggregate_weight: float = 0.0,
        entailment_loss: str = "piecewise",
        argent_beta: float = 1.0,
        argent_norm_weight: float = 0.0,
        argent_aux_weight: float = 0.5,
        argent_aggregation: str = "uncha",
        part_weight_power: float = 0.0,
        product_metric: str = "l1",
        contrastive_loss: str = "ce",
        sigmoid_negative_weight: float = 1.0,
        part_quality_mode: str = "none",
        part_quality_topk: int = 5,
        part_quality_temperature: float = 4.0,
        contrastive_global_weight: float = 1.0,
        contrastive_local_weight: float = 1.0,
        contrastive_global_local_weight: float = 1.0,
        beta_cal_beta: float = 0.0,
        beta_cal_variant: str = "ce",
        beta_cal_weight: float = 0.0,
        himo_component_weight: float = 0.0,
        global_local_mode: str = "repeat",
        global_local_metric: str = "distance",
        global_local_angle_aux_weight: float = 0.0,
        global_local_angle_aux_mode: str = "contrastive",
        global_local_angle_aux_scale: float = 5.5,
        global_local_angle_aux_aperture_scale: float = 1.0,
        radius_order_weight: float = 0.0,
        radius_order_margin: float = 0.0,
        gramian_align_weight: float = 0.0,
    ) -> None:
        """Validate the options and store them as attributes.

        Raises:
            ValueError: If a string option is not one of its allowed values,
                if ``global_local_angle_aux_weight`` is negative, or if
                ``global_local_angle_aux_scale``,
                ``global_local_angle_aux_aperture_scale`` or
                ``part_quality_topk`` is not positive.
        """
        super().__init__()

        # (value, allowed values, error message), checked in this order.
        choice_checks = (
            (
                entailment_loss,
                {"piecewise", "argent", "piecewise_argent", "hier_beta_argent", "hier_beta_sourcepart_argent"},
                f"Unsupported UNCHA entailment loss {entailment_loss!r}; "
                "expected 'piecewise', 'argent', 'piecewise_argent', 'hier_beta_argent', "
                "or 'hier_beta_sourcepart_argent'",
            ),
            (
                contrastive_loss,
                {"ce", "sigmoid", "siglip", "siglip_metric"},
                "contrastive_loss must be 'ce', 'sigmoid', 'siglip', or 'siglip_metric'",
            ),
            (beta_cal_variant, {"ce", "bce"}, "beta_cal_variant must be 'ce' or 'bce'"),
            (argent_aggregation, {"uncha", "equal"}, "argent_aggregation must be 'uncha' or 'equal'"),
            (
                part_quality_mode,
                {"none", "soft", "topk"},
                "part_quality_mode must be 'none', 'soft', or 'topk'",
            ),
            (global_local_mode, {"repeat", "inbatch"}, "global_local_mode must be 'repeat' or 'inbatch'"),
            (global_local_metric, {"distance", "angle"}, "global_local_metric must be 'distance' or 'angle'"),
            (
                global_local_angle_aux_mode,
                {"contrastive", "positive_hinge"},
                "global_local_angle_aux_mode must be 'contrastive' or 'positive_hinge'",
            ),
        )
        for value, allowed, message in choice_checks:
            if value not in allowed:
                raise ValueError(message)

        if global_local_angle_aux_weight < 0.0:
            raise ValueError("global_local_angle_aux_weight must be non-negative")
        if global_local_angle_aux_scale <= 0.0:
            raise ValueError("global_local_angle_aux_scale must be positive")
        if global_local_angle_aux_aperture_scale <= 0.0:
            raise ValueError("global_local_angle_aux_aperture_scale must be positive")
        if part_quality_topk <= 0:
            raise ValueError("part_quality_topk must be positive")

        # Options stored exactly as given.
        self.entail_weight = entail_weight
        self.inter_aperture_scale = inter_aperture_scale
        self.intra_aperture_scale = intra_aperture_scale
        self.piecewise_factor = piecewise_factor
        self.calibration_alpha = calibration_alpha
        self.stop_grad_calibration = stop_grad_calibration
        self.entailment_geometry = entailment_geometry
        self.aggregate_weight = aggregate_weight
        self.entailment_loss = entailment_loss
        self.argent_beta = argent_beta
        self.argent_norm_weight = argent_norm_weight
        self.argent_aux_weight = argent_aux_weight
        self.argent_aggregation = argent_aggregation
        self.part_weight_power = part_weight_power
        self.product_metric = product_metric
        self.contrastive_loss = contrastive_loss
        self.sigmoid_negative_weight = sigmoid_negative_weight
        self.part_quality_mode = part_quality_mode
        self.part_quality_topk = part_quality_topk
        self.part_quality_temperature = part_quality_temperature

        # Options coerced to float (string options in between are stored as given).
        self.contrastive_global_weight = float(contrastive_global_weight)
        self.contrastive_local_weight = float(contrastive_local_weight)
        self.contrastive_global_local_weight = float(contrastive_global_local_weight)
        self.beta_cal_beta = float(beta_cal_beta)
        self.beta_cal_variant = beta_cal_variant
        self.beta_cal_weight = float(beta_cal_weight)
        self.himo_component_weight = float(himo_component_weight)
        self.global_local_mode = global_local_mode
        self.global_local_metric = global_local_metric
        self.global_local_angle_aux_weight = float(global_local_angle_aux_weight)
        self.global_local_angle_aux_mode = global_local_angle_aux_mode
        self.global_local_angle_aux_scale = float(global_local_angle_aux_scale)
        self.global_local_angle_aux_aperture_scale = float(global_local_angle_aux_aperture_scale)
        self.radius_order_weight = float(radius_order_weight)
        self.radius_order_margin = float(radius_order_margin)
        self.gramian_align_weight = float(gramian_align_weight)

    def forward(self, embeddings: Mapping[str, Tensor], logit_scales: Mapping[str, Tensor]) -> dict[str, Tensor]:
        """Compute the total loss and all logged components.

        Args:
            embeddings: Must provide ``part_owner``, ``part_image_feats``,
                ``part_text_feats``, ``image_feats``, ``text_feats``,
                ``kappa`` and ``targets``. Euclidean features, gathered
                ``all_*`` features, HiMo text features, beta-query entries and
                ``entail_weight_scale`` are read when present or required by
                the configured mode.
            logit_scales: Must provide ``global``, ``local`` and
                ``global_local``; the ``*_bias`` entries are optional.

        Returns:
            A dict containing ``loss``, every entry returned by the
            contrastive and entailment helpers, and the auxiliary terms and
            part statistics.

        Raises:
            ValueError: If ``targets`` is missing, if HiMo text features are
                given without ``all_himo_text_feats``, or if a hierarchical
                beta mode lacks one of its required embeddings.
        """
        part_owner = embeddings["part_owner"].long()
        part_count = part_owner.new_tensor(part_owner.numel())
        part_image_flat = embeddings["part_image_feats"]
        part_text_flat = embeddings["part_text_feats"]
        image_feats = embeddings["image_feats"]
        text_feats = embeddings["text_feats"]
        has_parts = part_owner.numel() > 0

        # Owner embeddings repeated per part (empty placeholders without parts).
        if has_parts:
            image_for_parts = image_feats[part_owner]
            text_for_parts = text_feats[part_owner]
        else:
            image_for_parts = image_feats.new_zeros((0, image_feats.size(-1)))
            text_for_parts = text_feats.new_zeros((0, text_feats.size(-1)))

        count_part_weights = _part_weights(part_owner, image_feats.size(0), self.part_weight_power)
        kappa = embeddings["kappa"]
        quality_part_weights, quality_scores, quality_keep = part_quality_weights(
            image_for_parts=image_for_parts,
            text_for_parts=text_for_parts,
            part_image_flat=part_image_flat,
            part_text_flat=part_text_flat,
            part_owner=part_owner,
            batch_size=image_feats.size(0),
            kappa=kappa,
            mode=self.part_quality_mode,
            topk=self.part_quality_topk,
            temperature=self.part_quality_temperature,
            product_metric=self.product_metric,
        )
        part_weights = _combine_part_weights(count_part_weights, quality_part_weights)

        # Gather part features; the repeated owners are gathered only when the
        # "repeat" global-local term is active.
        needs_repeated_global_local = self.global_local_mode == "repeat" and self.contrastive_global_local_weight != 0.0
        tensors_to_gather = [part_image_flat, part_text_flat]
        if needs_repeated_global_local:
            tensors_to_gather += [image_for_parts, text_for_parts]
        gathered_part_features, part_counts = gather_variable_many_with_grad(tensors_to_gather)
        all_part_image_feats, all_part_text_feats = gathered_part_features[0], gathered_part_features[1]
        if needs_repeated_global_local:
            all_image_for_parts, all_text_for_parts = gathered_part_features[2], gathered_part_features[3]
        else:
            all_image_for_parts = all_text_for_parts = None

        # Optional Euclidean counterparts of the same tensors.
        image_euc_feats = embeddings.get("image_euc_feats")
        text_euc_feats = embeddings.get("text_euc_feats")
        part_image_euc_flat = embeddings.get("part_image_euc_feats")
        part_text_euc_flat = embeddings.get("part_text_euc_feats")
        image_for_parts_euc = text_for_parts_euc = None
        all_part_image_euc_feats = all_part_text_euc_feats = None
        all_image_for_parts_euc = all_text_for_parts_euc = None
        if (
            image_euc_feats is not None
            and text_euc_feats is not None
            and has_parts
            and needs_repeated_global_local
        ):
            image_for_parts_euc = image_euc_feats[part_owner]
            text_for_parts_euc = text_euc_feats[part_owner]
        if part_image_euc_flat is not None and part_text_euc_flat is not None:
            has_owner_euc = image_for_parts_euc is not None and text_for_parts_euc is not None
            euc_to_gather = [part_image_euc_flat, part_text_euc_flat]
            if has_owner_euc:
                euc_to_gather += [image_for_parts_euc, text_for_parts_euc]
            gathered_euc_features, _ = gather_variable_many_with_grad(euc_to_gather)
            all_part_image_euc_feats = gathered_euc_features[0]
            all_part_text_euc_feats = gathered_euc_features[1]
            if has_owner_euc:
                all_image_for_parts_euc = gathered_euc_features[2]
                all_text_for_parts_euc = gathered_euc_features[3]

        if "targets" not in embeddings:
            raise ValueError("UNCHAObjective requires 'targets' to compute group-aware losses")
        global_targets = embeddings["targets"]
        if has_parts:
            part_group_ids = global_targets[part_owner]
        else:
            part_group_ids = part_owner.new_zeros((0,))
        all_part_group_ids = None
        if self.beta_cal_weight > 0.0 and self.beta_cal_beta > 0.0:
            all_part_group_ids, _ = gather_variable_no_grad(part_group_ids)
        # Offset local part indices by the number of parts on lower ranks.
        if part_counts.numel() > 1:
            part_offset = part_counts[: get_rank()].sum()
        else:
            part_offset = part_counts.new_zeros(())
        part_targets = torch.arange(part_image_flat.size(0), device=part_image_flat.device) + part_offset

        # Keyword arguments shared by the contrastive and entailment helpers.
        pair_kwargs = {
            "image_feats": image_feats,
            "text_feats": text_feats,
            "part_image_flat": part_image_flat,
            "part_text_flat": part_text_flat,
            "image_for_parts": image_for_parts,
            "text_for_parts": text_for_parts,
            "kappa": kappa,
        }

        contrastive = uncha_contrastive_losses(
            **pair_kwargs,
            image_euc_feats=image_euc_feats,
            text_euc_feats=text_euc_feats,
            part_image_euc_flat=part_image_euc_flat,
            part_text_euc_flat=part_text_euc_flat,
            image_for_parts_euc=image_for_parts_euc,
            text_for_parts_euc=text_for_parts_euc,
            global_logit_scale=logit_scales["global"],
            local_logit_scale=logit_scales["local"],
            global_local_logit_scale=logit_scales["global_local"],
            all_image_feats=embeddings.get("all_image_feats"),
            all_text_feats=embeddings.get("all_text_feats"),
            all_part_image_feats=all_part_image_feats,
            all_part_text_feats=all_part_text_feats,
            all_image_for_parts=all_image_for_parts,
            all_text_for_parts=all_text_for_parts,
            all_image_euc_feats=embeddings.get("all_image_euc_feats"),
            all_text_euc_feats=embeddings.get("all_text_euc_feats"),
            all_part_image_euc_feats=all_part_image_euc_feats,
            all_part_text_euc_feats=all_part_text_euc_feats,
            all_image_for_parts_euc=all_image_for_parts_euc,
            all_text_for_parts_euc=all_text_for_parts_euc,
            global_targets=global_targets,
            part_targets=part_targets,
            part_weights=part_weights,
            product_metric=self.product_metric,
            loss_type=self.contrastive_loss,
            contrastive_global_weight=self.contrastive_global_weight,
            contrastive_local_weight=self.contrastive_local_weight,
            contrastive_global_local_weight=self.contrastive_global_local_weight,
            beta_cal_beta=self.beta_cal_beta,
            beta_cal_variant=self.beta_cal_variant,
            beta_cal_weight=self.beta_cal_weight,
            part_group_ids=part_group_ids,
            all_part_group_ids=all_part_group_ids,
            global_logit_bias=logit_scales.get("global_bias"),
            local_logit_bias=logit_scales.get("local_bias"),
            global_local_logit_bias=logit_scales.get("global_local_bias"),
            sigmoid_negative_weight=self.sigmoid_negative_weight,
            global_local_mode=self.global_local_mode,
            global_local_metric=self.global_local_metric,
            global_local_angle_aux_weight=self.global_local_angle_aux_weight,
            global_local_angle_aux_mode=self.global_local_angle_aux_mode,
            global_local_angle_aux_scale=self.global_local_angle_aux_scale,
            global_local_angle_aux_aperture_scale=self.global_local_angle_aux_aperture_scale,
        )

        # Optional HiMo component contrastive term.
        himo_component_loss = image_feats.new_zeros(())
        if self.himo_component_weight > 0.0 and embeddings.get("himo_text_feats") is not None:
            himo_text_feats = embeddings["himo_text_feats"]
            all_himo_text_feats = embeddings.get("all_himo_text_feats")
            if all_himo_text_feats is None:
                raise ValueError("himo_text_feats requires all_himo_text_feats for distributed contrastive loss")
            scale = logit_scales["global"].exp().clamp(max=100.0)
            logits_i_t = -metric_pairwise_dist(image_feats, all_himo_text_feats, kappa, product_metric=self.product_metric) * scale
            logits_t_i = -metric_pairwise_dist(himo_text_feats, embeddings["all_image_feats"], kappa, product_metric=self.product_metric) * scale
            himo_component_loss = 0.5 * (contrastive_ce(logits_i_t, global_targets) + contrastive_ce(logits_t_i, global_targets))

        # Entailment term, selected by ``self.entailment_loss``.
        argent_kwargs = {
            "beta": self.argent_beta,
            "part_weights": part_weights,
            "product_metric": self.product_metric,
            "aggregation": self.argent_aggregation,
        }
        if self.entailment_loss == "argent":
            entailment = uncha_argent_entailment_losses(**pair_kwargs, **argent_kwargs)
        elif self.entailment_loss in {"hier_beta_argent", "hier_beta_sourcepart_argent"}:
            uses_source_part = self.entailment_loss == "hier_beta_sourcepart_argent"
            required = (
                "beta_query_image_feats",
                "beta_query_text_feats",
                "beta_query_owner",
                "beta_query_parent",
                "beta_query_weight",
            )
            if uses_source_part:
                required = (*required, "beta_query_source_part")
            missing = [key for key in required if embeddings.get(key) is None]
            if missing:
                raise ValueError(f"{self.entailment_loss} requires beta query embeddings: missing {missing}")
            entailment = hierarchical_beta_argent_entailment_losses(
                **pair_kwargs,
                beta_query_image_feats=embeddings["beta_query_image_feats"],
                beta_query_text_feats=embeddings["beta_query_text_feats"],
                beta_query_owner=embeddings["beta_query_owner"],
                beta_query_parent=embeddings["beta_query_parent"],
                beta_query_weight=embeddings["beta_query_weight"],
                beta_query_source_part=embeddings.get("beta_query_source_part") if uses_source_part else None,
                **argent_kwargs,
            )
        else:
            piecewise_entailment = uncha_entailment_losses(
                **pair_kwargs,
                inter_aperture_scale=self.inter_aperture_scale,
                intra_aperture_scale=self.intra_aperture_scale,
                piecewise_factor=self.piecewise_factor,
                calibration_alpha=self.calibration_alpha,
                stop_grad_calibration=self.stop_grad_calibration,
                geometry=self.entailment_geometry,
                part_weights=part_weights,
            )
            if self.entailment_loss == "piecewise_argent":
                argent_entailment = uncha_argent_entailment_losses(**pair_kwargs, **argent_kwargs)
                # Piecewise loss plus a weighted ARGENT auxiliary term.
                entailment = {
                    **piecewise_entailment,
                    "entailment_loss": piecewise_entailment["entailment_loss"]
                    + self.argent_aux_weight * argent_entailment["entailment_loss"],
                    "piecewise_entailment_loss": piecewise_entailment["entailment_loss"],
                    "argent_entailment_loss": argent_entailment["entailment_loss"],
                    "norm_regularization_loss": argent_entailment["norm_regularization_loss"],
                }
            else:
                entailment = piecewise_entailment

        aggregate = aggregate_part_consistency_loss(
            image_feats=image_feats,
            text_feats=text_feats,
            part_image_flat=part_image_flat,
            part_text_flat=part_text_flat,
            part_owner=part_owner,
            part_weights=part_weights,
        )

        radius_order = image_feats.new_zeros(())
        if self.radius_order_weight > 0.0:
            margin = self.radius_order_margin
            radius_order = (
                radius_order_hinge(image_feats, text_feats, kappa, margin)
                + radius_order_hinge(part_image_flat, part_text_flat, kappa, margin, part_weights)
                + radius_order_hinge(image_for_parts, part_image_flat, kappa, margin, part_weights)
                + radius_order_hinge(text_for_parts, part_text_flat, kappa, margin, part_weights)
            )

        gramian_align = image_feats.new_zeros(())
        if self.gramian_align_weight > 0.0 and has_parts:

            def _tangent_flat(x: Tensor) -> Tensor:
                # Map through ``log_map0`` and flatten 3-D results to 2-D.
                tangent = log_map0(x, kappa)
                if tangent.dim() == 3:
                    return tangent.reshape(tangent.size(0), -1)
                return tangent

            quadruple = (image_for_parts, text_for_parts, part_image_flat, part_text_flat)
            gramian_vectors = torch.stack([_tangent_flat(member) for member in quadruple], dim=1)
            gramian_align = gramian_volume_loss(gramian_vectors, part_weights)

        entail_weight_scale = embeddings.get("entail_weight_scale", image_feats.new_ones(()))
        norm_regularization = entailment.get("norm_regularization_loss", image_feats.new_zeros(()))
        total = (
            contrastive["contrastive_loss"]
            + self.himo_component_weight * himo_component_loss
            + self.entail_weight * entail_weight_scale * entailment["entailment_loss"]
            + self.aggregate_weight * aggregate
            + self.radius_order_weight * radius_order
            + self.gramian_align_weight * gramian_align
            + self.argent_norm_weight * norm_regularization
        )

        if quality_scores.numel() == 0:
            part_quality_mean = image_feats.new_zeros(())
        else:
            part_quality_mean = quality_scores.mean().detach()
        if quality_keep.numel() == 0:
            part_quality_keep_fraction = image_feats.new_zeros(())
        else:
            part_quality_keep_fraction = quality_keep.mean().detach()

        return {
            "loss": total,
            **contrastive,
            "himo_component_contrastive_loss": himo_component_loss,
            **entailment,
            "aggregate_consistency_loss": aggregate,
            "radius_order_loss": radius_order,
            "gramian_align_loss": gramian_align,
            "part_count": part_count,
            "entail_weight_scale": entail_weight_scale.detach(),
            "part_quality_mean": part_quality_mean,
            "part_quality_keep_fraction": part_quality_keep_fraction,
        }
476
-
477
-
478
- def build_objective(
479
- objective: str,
480
- entail_weight: float,
481
- inter_aperture_scale: float,
482
- intra_aperture_scale: float,
483
- uncha_piecewise_factor: float = 0.1,
484
- uncha_calibration_alpha: float = 10.0,
485
- uncha_stop_grad_calibration: bool = True,
486
- uncha_entailment_geometry: str = "lorentz",
487
- uncha_aggregate_weight: float = 0.0,
488
- uncha_entailment_loss: str = "piecewise",
489
- uncha_argent_beta: float = 1.0,
490
- uncha_argent_norm_weight: float = 0.0,
491
- uncha_argent_aux_weight: float = 0.5,
492
- uncha_argent_aggregation: str = "uncha",
493
- uncha_part_weight_power: float = 0.0,
494
- uncha_contrastive_loss: str = "ce",
495
- uncha_sigmoid_negative_weight: float = 1.0,
496
- uncha_part_quality_mode: str = "none",
497
- uncha_part_quality_topk: int = 5,
498
- uncha_part_quality_temperature: float = 4.0,
499
- uncha_contrastive_global_weight: float = 1.0,
500
- uncha_contrastive_local_weight: float = 1.0,
501
- uncha_contrastive_global_local_weight: float = 1.0,
502
- uncha_beta_cal_beta: float = 0.0,
503
- uncha_beta_cal_variant: str = "ce",
504
- uncha_beta_cal_weight: float = 0.0,
505
- uncha_himo_component_weight: float = 0.0,
506
- uncha_global_local_mode: str = "repeat",
507
- uncha_global_local_metric: str = "distance",
508
- uncha_global_local_angle_aux_weight: float = 0.0,
509
- uncha_global_local_angle_aux_mode: str = "contrastive",
510
- uncha_global_local_angle_aux_scale: float = 5.5,
511
- uncha_global_local_angle_aux_aperture_scale: float = 1.0,
512
- uncha_radius_order_weight: float = 0.0,
513
- uncha_radius_order_margin: float = 0.0,
514
- uncha_gramian_align_weight: float = 0.0,
515
- product_metric: str = "l1",
516
- ) -> nn.Module:
517
- if objective == "hycoclip":
518
- return HyCoCLIPObjective(
519
- entail_weight=entail_weight,
520
- inter_aperture_scale=inter_aperture_scale,
521
- intra_aperture_scale=intra_aperture_scale,
522
- product_metric=product_metric,
523
- )
524
- if objective == "uncha":
525
- return UNCHAObjective(
526
- entail_weight=entail_weight,
527
- inter_aperture_scale=inter_aperture_scale,
528
- intra_aperture_scale=intra_aperture_scale,
529
- piecewise_factor=uncha_piecewise_factor,
530
- calibration_alpha=uncha_calibration_alpha,
531
- stop_grad_calibration=uncha_stop_grad_calibration,
532
- entailment_geometry=uncha_entailment_geometry,
533
- aggregate_weight=uncha_aggregate_weight,
534
- entailment_loss=uncha_entailment_loss,
535
- argent_beta=uncha_argent_beta,
536
- argent_norm_weight=uncha_argent_norm_weight,
537
- argent_aux_weight=uncha_argent_aux_weight,
538
- argent_aggregation=uncha_argent_aggregation,
539
- part_weight_power=uncha_part_weight_power,
540
- product_metric=product_metric,
541
- contrastive_loss=uncha_contrastive_loss,
542
- sigmoid_negative_weight=uncha_sigmoid_negative_weight,
543
- part_quality_mode=uncha_part_quality_mode,
544
- part_quality_topk=uncha_part_quality_topk,
545
- part_quality_temperature=uncha_part_quality_temperature,
546
- contrastive_global_weight=uncha_contrastive_global_weight,
547
- contrastive_local_weight=uncha_contrastive_local_weight,
548
- contrastive_global_local_weight=uncha_contrastive_global_local_weight,
549
- beta_cal_beta=uncha_beta_cal_beta,
550
- beta_cal_variant=uncha_beta_cal_variant,
551
- beta_cal_weight=uncha_beta_cal_weight,
552
- himo_component_weight=uncha_himo_component_weight,
553
- global_local_mode=uncha_global_local_mode,
554
- global_local_metric=uncha_global_local_metric,
555
- global_local_angle_aux_weight=uncha_global_local_angle_aux_weight,
556
- global_local_angle_aux_mode=uncha_global_local_angle_aux_mode,
557
- global_local_angle_aux_scale=uncha_global_local_angle_aux_scale,
558
- global_local_angle_aux_aperture_scale=uncha_global_local_angle_aux_aperture_scale,
559
- radius_order_weight=uncha_radius_order_weight,
560
- radius_order_margin=uncha_radius_order_margin,
561
- gramian_align_weight=uncha_gramian_align_weight,
562
- )
563
- raise ValueError(f"Unsupported objective {objective!r}; expected 'hycoclip' or 'uncha'")
564
-
565
-
566
- def _part_weights(part_owner: Tensor, batch_size: int, power: float) -> Tensor | None:
567
- if power <= 0.0 or part_owner.numel() == 0:
568
- return None
569
- counts = torch.bincount(part_owner, minlength=batch_size).to(dtype=torch.float32, device=part_owner.device)
570
- weights = counts[part_owner].clamp_min(1.0).pow(-power)
571
- return weights / weights.mean().clamp_min(torch.finfo(weights.dtype).eps)
572
-
573
-
574
- def _combine_part_weights(count_weights: Tensor | None, quality_weights: Tensor | None) -> Tensor | None:
575
- if count_weights is None:
576
- return quality_weights
577
- if quality_weights is None:
578
- return count_weights
579
- weights = count_weights * quality_weights
580
- return weights / weights.mean().clamp_min(torch.finfo(weights.dtype).eps)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/models/tren.py DELETED
@@ -1,255 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import math
4
-
5
- import torch
6
- import torch.nn.functional as F
7
- from torch import Tensor, nn
8
-
9
-
10
- class FourierPositionEncoding2D(nn.Module):
11
- def __init__(self, dim: int, scale: float = 1.0) -> None:
12
- super().__init__()
13
- if dim <= 0 or dim % 2 != 0:
14
- raise ValueError("FourierPositionEncoding2D dim must be a positive even integer")
15
- if scale <= 0.0:
16
- raise ValueError("FourierPositionEncoding2D scale must be positive")
17
- generator = torch.Generator()
18
- generator.manual_seed(42)
19
- self.register_buffer("gaussian_matrix", scale * torch.randn((2, dim // 2), generator=generator))
20
-
21
- def forward(self, coords: Tensor) -> Tensor:
22
- projected = (2.0 * coords.float() - 1.0) @ self.gaussian_matrix
23
- projected = 2.0 * math.pi * projected
24
- return torch.cat([torch.sin(projected), torch.cos(projected)], dim=-1)
25
-
26
-
27
- class _MLPBlock(nn.Module):
28
- def __init__(self, dim: int, hidden_dim: int, dropout: float) -> None:
29
- super().__init__()
30
- self.net = nn.Sequential(
31
- nn.Linear(dim, hidden_dim),
32
- nn.GELU(),
33
- nn.Dropout(dropout),
34
- nn.Linear(hidden_dim, dim),
35
- )
36
-
37
- def forward(self, x: Tensor) -> Tensor:
38
- return self.net(x)
39
-
40
-
41
- class _AttentionLayer(nn.Module):
42
- def __init__(
43
- self,
44
- q_dim: int,
45
- kv_dim: int,
46
- hidden_dim: int,
47
- *,
48
- num_heads: int,
49
- dropout: float,
50
- use_bias: bool = False,
51
- use_v_proj: bool = True,
52
- use_out_proj: bool = True,
53
- ) -> None:
54
- super().__init__()
55
- if hidden_dim % num_heads != 0:
56
- raise ValueError("hidden_dim must be divisible by num_heads")
57
- if not use_v_proj and kv_dim != hidden_dim:
58
- raise ValueError("kv_dim must equal hidden_dim when value projection is disabled")
59
- self.hidden_dim = hidden_dim
60
- self.num_heads = num_heads
61
- self.head_dim = hidden_dim // num_heads
62
- self.q_proj = nn.Linear(q_dim, hidden_dim, bias=use_bias)
63
- self.k_proj = nn.Linear(kv_dim, hidden_dim, bias=use_bias)
64
- self.v_proj = nn.Linear(kv_dim, hidden_dim, bias=use_bias) if use_v_proj else nn.Identity()
65
- self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=use_bias) if use_out_proj else nn.Identity()
66
- self.q_norm = nn.LayerNorm(self.head_dim)
67
- self.k_norm = nn.LayerNorm(self.head_dim)
68
- self.dropout = nn.Dropout(dropout)
69
- self.scale = self.head_dim**-0.5
70
-
71
- nn.init.kaiming_normal_(self.q_proj.weight, mode="fan_in", nonlinearity="linear")
72
- nn.init.kaiming_normal_(self.k_proj.weight, mode="fan_in", nonlinearity="linear")
73
- if isinstance(self.v_proj, nn.Linear):
74
- nn.init.kaiming_normal_(self.v_proj.weight, mode="fan_in", nonlinearity="linear")
75
- if isinstance(self.out_proj, nn.Linear):
76
- nn.init.kaiming_normal_(self.out_proj.weight, mode="fan_in", nonlinearity="linear")
77
-
78
- def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
79
- batch_size, q_len, _ = q.shape
80
- _, kv_len, _ = k.shape
81
- query = self.q_proj(q).view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
82
- key = self.k_proj(k).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
83
- value = self.v_proj(v).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
84
-
85
- query = self.q_norm(query)
86
- key = self.k_norm(key)
87
- attn_scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
88
- attn_weights = self.dropout(F.softmax(attn_scores, dim=-1))
89
- out = torch.matmul(attn_weights, value)
90
- out = out.transpose(1, 2).contiguous().view(batch_size, q_len, self.hidden_dim)
91
- return self.out_proj(out), attn_weights
92
-
93
-
94
- class _CrossAttentionBlock(nn.Module):
95
- def __init__(self, dim: int, *, num_heads: int, dropout: float) -> None:
96
- super().__init__()
97
- self.query_norm = nn.LayerNorm(dim)
98
- self.cross_attn = _AttentionLayer(dim, dim, dim, num_heads=num_heads, dropout=dropout)
99
- self.dropout = nn.Dropout(dropout)
100
- self.mlp_norm = nn.LayerNorm(dim)
101
- self.mlp = _MLPBlock(dim, 2 * dim, dropout)
102
- self.out_norm = nn.LayerNorm(dim)
103
-
104
- def forward(self, query: Tensor, context: Tensor) -> Tensor:
105
- x, _ = self.cross_attn(self.query_norm(query), context, context)
106
- x = query + self.dropout(x)
107
- return self.out_norm(x + self.mlp(self.mlp_norm(x)))
108
-
109
-
110
- class TRENRegionEncoder(nn.Module):
111
- """T-REN-style point-prompted region token encoder.
112
-
113
- The module follows the public T-REN architecture: learned k-per-prompt
114
- query tokens, Fourier 2D prompt/patch position encodings, alternating
115
- cross-attention and per-prompt self-attention, then final single-head
116
- attention that pools unprojected patch tokens into region tokens.
117
- """
118
-
119
- def __init__(
120
- self,
121
- vision_dim: int,
122
- text_dim: int,
123
- *,
124
- hidden_dim: int | None = None,
125
- num_region_tokens: int = 3,
126
- num_decoder_layers: int = 2,
127
- num_attention_heads: int = 8,
128
- prompt_grid_size: int = 7,
129
- dropout: float = 0.1,
130
- ) -> None:
131
- super().__init__()
132
- if num_region_tokens <= 0:
133
- raise ValueError("num_region_tokens must be positive")
134
- if num_decoder_layers <= 0:
135
- raise ValueError("num_decoder_layers must be positive")
136
- if prompt_grid_size <= 0:
137
- raise ValueError("prompt_grid_size must be positive")
138
- hidden_dim = int(hidden_dim or vision_dim)
139
- if hidden_dim != vision_dim:
140
- raise ValueError("TRENRegionEncoder currently requires hidden_dim == vision_dim")
141
- if hidden_dim % 2 != 0:
142
- raise ValueError("TRENRegionEncoder hidden_dim must be even for Fourier features")
143
- if hidden_dim % num_attention_heads != 0:
144
- raise ValueError("TRENRegionEncoder hidden_dim must be divisible by num_attention_heads")
145
-
146
- self.vision_dim = vision_dim
147
- self.text_dim = text_dim
148
- self.hidden_dim = hidden_dim
149
- self.num_region_tokens = num_region_tokens
150
- self.prompt_grid_size = prompt_grid_size
151
- self.position_encoder = FourierPositionEncoding2D(hidden_dim)
152
- self.region_token_embeddings = nn.Embedding(num_region_tokens, hidden_dim)
153
- nn.init.normal_(self.region_token_embeddings.weight, std=0.02)
154
- self.region_attention_layers = nn.ModuleList(
155
- [_CrossAttentionBlock(hidden_dim, num_heads=num_attention_heads, dropout=dropout) for _ in range(num_decoder_layers)]
156
- )
157
- self.region_attention_norms = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_decoder_layers)])
158
- self.prompt_attention_layers = nn.ModuleList(
159
- [
160
- _AttentionLayer(
161
- hidden_dim,
162
- hidden_dim,
163
- hidden_dim,
164
- num_heads=num_attention_heads,
165
- dropout=dropout,
166
- )
167
- for _ in range(num_decoder_layers)
168
- ]
169
- )
170
- self.prompt_attention_norms = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_decoder_layers)])
171
- self.token_prediction_head = _AttentionLayer(
172
- hidden_dim,
173
- hidden_dim,
174
- hidden_dim,
175
- num_heads=1,
176
- dropout=0.0,
177
- use_v_proj=False,
178
- use_out_proj=False,
179
- )
180
- self.text_alignment_block = nn.Sequential(
181
- nn.Linear(hidden_dim, 2 * hidden_dim),
182
- nn.GELU(),
183
- nn.Dropout(dropout),
184
- nn.Linear(2 * hidden_dim, text_dim),
185
- )
186
-
187
- def forward(self, image_tokens: Tensor) -> dict[str, Tensor]:
188
- patch_tokens, patch_grid = _patch_tokens_and_grid(image_tokens)
189
- batch_size, patch_count, _ = patch_tokens.shape
190
- patch_coords = _grid_coords(patch_grid, patch_grid, patch_tokens.device)
191
- prompt_coords = _grid_coords(self.prompt_grid_size, self.prompt_grid_size, patch_tokens.device)
192
- prompt_count = prompt_coords.size(0)
193
-
194
- feature_pos = self.position_encoder(patch_coords).to(dtype=patch_tokens.dtype)
195
- prompt_pos = self.position_encoder(prompt_coords).to(dtype=patch_tokens.dtype)
196
- kv = patch_tokens + feature_pos.unsqueeze(0)
197
- prompt_pos = prompt_pos.view(1, prompt_count, 1, self.hidden_dim)
198
-
199
- q = self.region_token_embeddings.weight.to(dtype=patch_tokens.dtype)
200
- q = q.view(1, 1, self.num_region_tokens, self.hidden_dim).expand(
201
- batch_size,
202
- prompt_count,
203
- self.num_region_tokens,
204
- self.hidden_dim,
205
- )
206
- for region_layer, region_norm, prompt_layer, prompt_norm in zip(
207
- self.region_attention_layers,
208
- self.region_attention_norms,
209
- self.prompt_attention_layers,
210
- self.prompt_attention_norms,
211
- strict=True,
212
- ):
213
- q = q + prompt_pos
214
- q = q.reshape(batch_size, prompt_count * self.num_region_tokens, self.hidden_dim)
215
- q = region_layer(q, kv)
216
- q = q.reshape(batch_size, prompt_count, self.num_region_tokens, self.hidden_dim)
217
- q = region_norm(q)
218
- q = q.reshape(batch_size * prompt_count, self.num_region_tokens, self.hidden_dim)
219
- q, _ = prompt_layer(q, q, q)
220
- q = prompt_norm(q)
221
- q = q.reshape(batch_size, prompt_count, self.num_region_tokens, self.hidden_dim)
222
-
223
- flat_q = q.reshape(batch_size, prompt_count * self.num_region_tokens, self.hidden_dim)
224
- visual_tokens, attn_weights = self.token_prediction_head(flat_q, kv, patch_tokens)
225
- visual_tokens = visual_tokens.reshape(batch_size, prompt_count, self.num_region_tokens, self.hidden_dim)
226
- attn_weights = attn_weights.squeeze(1).reshape(batch_size, prompt_count, self.num_region_tokens, patch_count)
227
- region_masks = attn_weights / attn_weights.amax(dim=-1, keepdim=True).clamp_min(torch.finfo(attn_weights.dtype).eps)
228
- region_masks = region_masks.reshape(batch_size, prompt_count, self.num_region_tokens, patch_grid, patch_grid)
229
- text_aligned_tokens = self.text_alignment_block(visual_tokens)
230
- return {
231
- "visual_tokens": visual_tokens,
232
- "text_aligned_tokens": text_aligned_tokens,
233
- "region_masks": region_masks,
234
- "prompt_coords": prompt_coords,
235
- }
236
-
237
-
238
- def _patch_tokens_and_grid(tokens: Tensor) -> tuple[Tensor, int]:
239
- if tokens.ndim != 3:
240
- raise ValueError("TRENRegionEncoder expects image tokens with shape [batch, tokens, dim]")
241
- token_count = tokens.size(1)
242
- grid = int(math.isqrt(token_count))
243
- if grid * grid == token_count:
244
- return tokens, grid
245
- grid = int(math.isqrt(token_count - 1))
246
- if grid * grid == token_count - 1:
247
- return tokens[:, 1:, :], grid
248
- raise ValueError(f"Cannot infer a square patch grid from {token_count} image tokens")
249
-
250
-
251
- def _grid_coords(height: int, width: int, device: torch.device) -> Tensor:
252
- y = torch.linspace(0.5 / height, 1.0 - 0.5 / height, height, device=device)
253
- x = torch.linspace(0.5 / width, 1.0 - 0.5 / width, width, device=device)
254
- yy, xx = torch.meshgrid(y, x, indexing="ij")
255
- return torch.stack([xx, yy], dim=-1).reshape(-1, 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip/training/__init__.py DELETED
@@ -1 +0,0 @@
1
- __all__: list[str] = []
 
 
hyper3_clip/training/distributed.py DELETED
@@ -1,149 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from collections.abc import Sequence
4
- import os
5
-
6
- import torch
7
- import torch.distributed as dist
8
- from torch.distributed.nn import all_gather as differentiable_all_gather
9
- from torch import Tensor
10
-
11
-
12
- def init_distributed() -> None:
13
- if "RANK" in os.environ and "WORLD_SIZE" in os.environ and not dist.is_initialized():
14
- backend = "nccl" if torch.cuda.is_available() else "gloo"
15
- if torch.cuda.is_available():
16
- torch.cuda.set_device(get_local_rank())
17
- dist.init_process_group(backend=backend)
18
-
19
-
20
- def is_distributed() -> bool:
21
- return dist.is_available() and dist.is_initialized()
22
-
23
-
24
- def barrier() -> None:
25
- if is_distributed():
26
- dist.barrier()
27
-
28
-
29
- def destroy_distributed() -> None:
30
- if is_distributed():
31
- dist.destroy_process_group()
32
-
33
-
34
- def get_rank() -> int:
35
- return dist.get_rank() if is_distributed() else 0
36
-
37
-
38
- def get_world_size() -> int:
39
- return dist.get_world_size() if is_distributed() else 1
40
-
41
-
42
- def get_local_rank() -> int:
43
- return int(os.environ.get("LOCAL_RANK", "0"))
44
-
45
-
46
- def is_main_process() -> bool:
47
- return get_rank() == 0
48
-
49
-
50
- def gather_with_grad(tensor: Tensor) -> Tensor:
51
- world_size = get_world_size()
52
- if world_size == 1:
53
- return tensor
54
- return torch.cat(list(differentiable_all_gather(tensor.contiguous())), dim=0)
55
-
56
-
57
- def gather_variable_with_grad(tensor: Tensor) -> tuple[Tensor, Tensor]:
58
- """Gather tensors with variable first-dimension lengths across ranks."""
59
- count_tensor, max_count, keep = _variable_gather_metadata(tensor)
60
- if get_world_size() == 1:
61
- return tensor, count_tensor
62
- return _gather_variable_from_metadata(tensor, max_count, keep), count_tensor
63
-
64
-
65
- def gather_variable_many_with_grad(tensors: Sequence[Tensor]) -> tuple[list[Tensor], Tensor]:
66
- """Gather same-length variable tensors while sharing count metadata.
67
-
68
- Tensors with matching dtype/rank/trailing shape are packed along the last
69
- dimension so a single differentiable all-gather can serve several feature
70
- tensors with the same variable first dimension.
71
- """
72
- if not tensors:
73
- raise ValueError("gather_variable_many_with_grad requires at least one tensor")
74
- first = tensors[0]
75
- for tensor in tensors:
76
- if tensor.device != first.device:
77
- raise ValueError("all tensors must be on the same device")
78
- if tensor.shape[0] != first.shape[0]:
79
- raise ValueError("all tensors must have the same first dimension")
80
- count_tensor, max_count, keep = _variable_gather_metadata(first)
81
- if get_world_size() == 1:
82
- return list(tensors), count_tensor
83
-
84
- gathered: list[Tensor | None] = [None] * len(tensors)
85
- groups: dict[tuple[torch.dtype, torch.Size, int], list[int]] = {}
86
- for index, tensor in enumerate(tensors):
87
- if tensor.dim() == 0:
88
- raise ValueError("variable gather tensors must have at least one dimension")
89
- key = (tensor.dtype, tensor.shape[1:-1], tensor.dim()) if tensor.dim() > 1 else (tensor.dtype, torch.Size(), 1)
90
- groups.setdefault(key, []).append(index)
91
-
92
- for indices in groups.values():
93
- group_tensors = [tensors[index] for index in indices]
94
- if len(group_tensors) == 1 or group_tensors[0].dim() == 1:
95
- for index, tensor in zip(indices, group_tensors, strict=True):
96
- gathered[index] = _gather_variable_from_metadata(tensor, max_count, keep)
97
- continue
98
- widths = [tensor.shape[-1] for tensor in group_tensors]
99
- packed = torch.cat(group_tensors, dim=-1)
100
- gathered_packed = _gather_variable_from_metadata(packed, max_count, keep)
101
- for index, chunk in zip(indices, gathered_packed.split(widths, dim=-1), strict=True):
102
- gathered[index] = chunk
103
-
104
- if any(tensor is None for tensor in gathered):
105
- raise RuntimeError("internal error while gathering variable tensors")
106
- return [tensor for tensor in gathered if tensor is not None], count_tensor
107
-
108
-
109
- def gather_variable_no_grad(tensor: Tensor) -> tuple[Tensor, Tensor]:
110
- """Gather variable-length tensors that do not require autograd."""
111
- count_tensor, max_count, keep = _variable_gather_metadata(tensor)
112
- if get_world_size() == 1:
113
- return tensor, count_tensor
114
- padded = tensor.new_zeros((max_count, *tensor.shape[1:]))
115
- padded[: tensor.shape[0]] = tensor
116
- gathered = [torch.zeros_like(padded) for _ in range(get_world_size())]
117
- dist.all_gather(gathered, padded.contiguous())
118
- return torch.cat(gathered, dim=0)[keep], count_tensor
119
-
120
-
121
- def _variable_gather_metadata(tensor: Tensor) -> tuple[Tensor, int, Tensor]:
122
- world_size = get_world_size()
123
- local_count = torch.tensor([tensor.shape[0]], device=tensor.device, dtype=torch.long)
124
- if world_size == 1:
125
- keep = torch.ones(tensor.shape[0], device=tensor.device, dtype=torch.bool)
126
- return local_count, tensor.shape[0], keep
127
-
128
- counts = [torch.zeros_like(local_count) for _ in range(world_size)]
129
- dist.all_gather(counts, local_count)
130
- count_tensor = torch.cat(counts)
131
- max_count = int(count_tensor.max().item())
132
- keep = torch.zeros(world_size * max_count, device=tensor.device, dtype=torch.bool)
133
- for rank, count in enumerate(count_tensor.tolist()):
134
- start = rank * max_count
135
- keep[start : start + count] = True
136
- return count_tensor, max_count, keep
137
-
138
-
139
- def _gather_variable_from_metadata(tensor: Tensor, max_count: int, keep: Tensor) -> Tensor:
140
- padded_shape = (max_count, *tensor.shape[1:])
141
- padded = tensor.new_zeros(padded_shape)
142
- padded[: tensor.shape[0]] = tensor
143
-
144
- gathered = torch.cat(list(differentiable_all_gather(padded.contiguous())), dim=0)
145
- return gathered[keep]
146
-
147
-
148
- def local_target_indices(batch_size: int, device: torch.device) -> Tensor:
149
- return torch.arange(batch_size, device=device) + batch_size * get_rank()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyper3_clip_provider.py DELETED
@@ -1,115 +0,0 @@
1
- """HyperView embedding provider for the Hyper3-CLIP v0.5 HF checkpoint."""
2
-
3
- from __future__ import annotations
4
-
5
- import os
6
- from pathlib import Path
7
- from typing import Any
8
-
9
- import numpy as np
10
- import torch
11
- import yaml
12
- from huggingface_hub import snapshot_download
13
- from lancedb.embeddings import EmbeddingFunction
14
- from pydantic import PrivateAttr
15
- from safetensors.torch import load_file
16
-
17
-
18
- class Hyper3ClipEmbeddings(EmbeddingFunction):
19
- """Image embeddings from Hyper3-CLIP v0.5 in Lorentz/hyperboloid space."""
20
-
21
- name: str = "hyper3labs/hyper3-clip-v0.5"
22
- batch_size: int = 8
23
- device: str = "cpu"
24
-
25
- _model: Any = PrivateAttr(default=None)
26
- _transform: Any = PrivateAttr(default=None)
27
-
28
- @property
29
- def geometry(self) -> str:
30
- return "hyperboloid"
31
-
32
- @property
33
- def curvature(self) -> float:
34
- self._ensure_model()
35
- return float(self._model._kappa().detach().cpu().reshape(-1)[0].item())
36
-
37
- def ndims(self) -> int:
38
- return 513
39
-
40
- def _ensure_model(self) -> None:
41
- if self._model is not None:
42
- return
43
-
44
- from hyper3_clip import Hyper3CLIP
45
- from torchvision import transforms
46
-
47
- token = os.environ.get("HF_TOKEN")
48
- local_dir = snapshot_download(
49
- self.name,
50
- allow_patterns=["config.yaml", "model.safetensors"],
51
- token=token,
52
- )
53
- root = Path(local_dir)
54
- config = yaml.safe_load((root / "config.yaml").read_text(encoding="utf-8"))
55
-
56
- model = Hyper3CLIP(**config["model"])
57
- state = load_file(root / "model.safetensors", device="cpu")
58
- model.load_state_dict(state)
59
- model.to(torch.device(self.device))
60
- model.eval()
61
-
62
- self._model = model
63
- image_size = int(config.get("data", {}).get("image_size", 224))
64
- self._transform = transforms.Compose(
65
- [
66
- transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BICUBIC),
67
- transforms.CenterCrop(image_size),
68
- transforms.ToTensor(),
69
- transforms.Normalize(
70
- mean=(0.485, 0.456, 0.406),
71
- std=(0.229, 0.224, 0.225),
72
- ),
73
- ]
74
- )
75
-
76
- def compute_source_embeddings(
77
- self,
78
- inputs: Any,
79
- *args: Any,
80
- **kwargs: Any,
81
- ) -> list[np.ndarray | None]:
82
- from PIL import Image
83
- from hyperview.core.sample import Sample
84
-
85
- self._ensure_model()
86
- device = torch.device(self.device)
87
- images = []
88
- for item in self.sanitize_input(inputs):
89
- if isinstance(item, Sample):
90
- with item.load_image() as img:
91
- images.append(img.convert("RGB"))
92
- elif isinstance(item, str):
93
- with Image.open(item) as img:
94
- images.append(img.convert("RGB"))
95
- elif isinstance(item, Image.Image):
96
- images.append(item.convert("RGB"))
97
- else:
98
- raise TypeError(f"Unsupported input type: {type(item)}")
99
-
100
- outputs: list[np.ndarray | None] = []
101
- with torch.inference_mode():
102
- for start in range(0, len(images), self.batch_size):
103
- batch = images[start:start + self.batch_size]
104
- tensor = torch.stack([self._transform(image) for image in batch]).to(device)
105
- encoded = self._model.encode_image(tensor).detach().cpu().numpy().astype(np.float32)
106
- outputs.extend(encoded)
107
- return outputs
108
-
109
- def compute_query_embeddings(
110
- self,
111
- query: Any,
112
- *args: Any,
113
- **kwargs: Any,
114
- ) -> list[np.ndarray | None]:
115
- return self.compute_source_embeddings([query], *args, **kwargs)