Jolia / modeling_jolia.py
SovanK's picture
Upload folder using huggingface_hub
6858e35 verified
Raw
History Blame Contribute Delete
18.4 kB
"""Jolia: a self-contained Atlas vision backbone with named organ queries.

``JoliaModel`` is the public Hugging Face entry point. It wraps a vendored
``MultiModalAtlas`` 3D backbone, a per-scale bank of organ-query
cross-attention poolers, and a CLIP-style text-projection head used for
zero-shot classification.

Feature views:

* :meth:`forward` / ``__call__`` — the pooled global (CLS-equivalent) embedding.
* :meth:`encode_image` — L2-normalized image embedding in the shared CLIP space.
* :meth:`encode_text` — L2-normalized text embedding in the shared CLIP space
  (input: pooled Qwen3 text features, output: 576-d).
* :meth:`zero_shot_logits` / :meth:`zero_shot` — image-vs-text similarity with
  the trained temperature + bias.
* :meth:`encode_organs` — per-organ embeddings keyed by organ **name**.
* :meth:`extract_flat_feature` — the normalized ``[cls ⊕ organs]`` vector used
  for linear probing.

Load it with::

    from transformers import AutoModel

    model = AutoModel.from_pretrained("raidium/Jolia", trust_remote_code=True).eval()

For zero-shot, use it together with the paired text encoder
(see :class:`text_encoder_jolia.JoliaTextEncoder`).
"""
from __future__ import annotations
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.utils import ModelOutput
from .configuration_jolia import JoliaConfig
from .jolia_multimodal_atlas import MultiModalAtlas, MultiModalAtlasConfig
from .jolia_organ_query_attention import OrganQueryAttention
# HF's trust_remote_code loader copies only the *direct* relative imports of this
# entry module. The lines below force every transitively-vendored backbone file
# (and the text-encoder helper) into that copy set so the repo loads from the
# hub without missing modules.
from .jolia_atlas_encoders import ChestCTEmbed3D as _ensure_atlas_encoders # noqa: E402,F401
from .jolia_multimodal_msa import AtlasStage as _ensure_multimodal_msa # noqa: E402,F401
from .jolia_shim import BaseModel as _ensure_shim # noqa: E402,F401
# The text encoder is opt-in (loading Qwen3 is heavy); we still touch the
# module so it's bundled in the dynamic-module copy set for snapshot_download
# + sys.path.append users.
from .text_encoder_jolia import JoliaTextEncoder as _ensure_text_encoder # noqa: E402,F401
@dataclass
class JoliaOutput(ModelOutput):
    """Structured return value of :meth:`JoliaModel.forward`.

    Attributes:
        pooler_output: Global embedding of shape ``(B, embed_dim)``.
        organ_queries: Per-organ embeddings of shape
            ``(B, num_organs, query_dim)``. Left as ``None`` when
            ``output_organ_queries=False`` or the model has no organ-query
            head.
    """

    pooler_output: torch.FloatTensor | None = None
    organ_queries: torch.FloatTensor | None = None
class JoliaModel(PreTrainedModel):
    """Hugging Face wrapper around the vendored ``MultiModalAtlas`` backbone.

    Depending on the config, the model also owns:

    * ``organ_query_attn_scales`` — one :class:`OrganQueryAttention` pooler per
      scale (``config.has_queries``);
    * ``text_projection`` / ``logit_scale`` / ``text_bias`` — the global
      CLIP-style text head (``config.has_text_head``);
    * ``organ_text_projection`` / ``organ_logit_scale`` / ``organ_text_bias`` —
      the per-organ text head (requires both of the above).

    Heads that are disabled are stored as ``None``; the methods that need them
    raise ``RuntimeError``.
    """

    config_class = JoliaConfig
    base_model_prefix = "jolia"
    main_input_name = "image"

    def __init__(self, config: JoliaConfig) -> None:
        """Build the backbone and whichever optional heads ``config`` enables."""
        super().__init__(config)
        atlas_config = MultiModalAtlasConfig(
            embed_dim=config.embed_dim,
            num_classes=0,
            multiscale_feats=config.multiscale_feats,
            atlas_config=config.atlas_config,
        )
        self.vision_model = MultiModalAtlas(atlas_config)

        if config.has_queries:
            # One cross-attention pooler per backbone scale.
            self.organ_query_attn_scales: torch.nn.ModuleList | None = torch.nn.ModuleList(
                OrganQueryAttention(
                    num_organs=config.num_organs,
                    query_dim=config.patch_embed_dim,
                    num_heads=config.num_heads,
                )
                for _ in range(config.num_scales)
            )
        else:
            self.organ_query_attn_scales = None

        # CLIP-style text head — small (~10 MB) projection from the paired text
        # encoder's hidden size to the shared embedding space, plus the trained
        # temperature (`logit_scale`) and additive `bias` used by `zero_shot_logits`.
        if config.has_text_head:
            self.text_projection: nn.Linear | None = nn.Linear(
                config.text_embed_dim, config.embed_dim, bias=True
            )
            self.logit_scale = nn.Parameter(torch.zeros([]))
            self.text_bias = nn.Parameter(torch.zeros([]))
        else:
            self.text_projection = None
            self.register_parameter("logit_scale", None)
            self.register_parameter("text_bias", None)

        # ParallelOrganCLIP text head — used for organ-routed zero-shot (a text
        # prompt compared against a *specific* organ-query embedding). Same
        # shape as the global text head but trained against per-organ findings.
        if config.has_text_head and config.has_queries:
            self.organ_text_projection: nn.Linear | None = nn.Linear(
                config.text_embed_dim, config.embed_dim, bias=True
            )
            self.organ_logit_scale = nn.Parameter(torch.zeros(config.num_organs))
            self.organ_text_bias = nn.Parameter(torch.zeros(config.num_organs))
        else:
            self.organ_text_projection = None
            self.register_parameter("organ_logit_scale", None)
            self.register_parameter("organ_text_bias", None)

        self.post_init()

    # ------------------------------------------------------------------
    # Capability / naming helpers
    # ------------------------------------------------------------------
    @property
    def embed_dim(self) -> int:
        """``config.embed_dim`` as an ``int``."""
        return int(self.config.embed_dim)

    @property
    def query_dim(self) -> int:
        """``config.query_dim`` as an ``int``."""
        return int(self.config.query_dim)

    @property
    def num_organs(self) -> int:
        """``config.num_organs`` as an ``int``."""
        return int(self.config.num_organs)

    @property
    def has_queries(self) -> bool:
        """Whether the organ-query poolers were built."""
        return self.organ_query_attn_scales is not None

    @property
    def organ_slot_names(self) -> list[str]:
        """Names for organ-query slots ``0 .. len-1`` (trailing slots unused)."""
        return list(self.config.organ_slot_names)

    @property
    def has_text_head(self) -> bool:
        """Whether the global CLIP-style text head was built."""
        return self.text_projection is not None

    @property
    def has_organ_text_head(self) -> bool:
        """Whether the per-organ text head was built."""
        return self.organ_text_projection is not None

    @property
    def text_encoder_id(self) -> str:
        """HuggingFace id of the paired text encoder (e.g. ``Qwen/Qwen3-Embedding-8B``)."""
        return self.config.text_encoder_id

    def _organ_name_to_index(self) -> dict[str, int]:
        """Map every organ name to its slot index.

        Uses :attr:`organ_slot_names` when it is non-empty and the
        ``"slot_<i>"`` fallback otherwise. :meth:`encode_organs` and the
        organ-routed zero-shot methods all go through this helper, so the
        embedding slot and the calibration parameters are always selected with
        the same index (including for fallback names and duplicate names,
        where the last occurrence wins).
        """
        names = self.organ_slot_names or [f"slot_{i}" for i in range(self.num_organs)]
        return {name: i for i, name in enumerate(names)}

    # ------------------------------------------------------------------
    # Forward views
    # ------------------------------------------------------------------
    def forward(
        self,
        image: torch.Tensor,
        output_organ_queries: bool = False,
        return_dict: bool = True,
    ) -> JoliaOutput | tuple:
        """Return the pooled global embedding (and optionally organ queries).

        Args:
            image: Input volume, passed unchanged to the backbone.
            output_organ_queries: Also compute the organ-query embeddings
                (requires the organ-query head).
            return_dict: Return a :class:`JoliaOutput` when ``True``; otherwise
                a tuple ``(cls,)`` or ``(cls, organ_queries)``.

        Raises:
            RuntimeError: ``output_organ_queries=True`` on a model without
                organ-query attention.
        """
        if output_organ_queries:
            cls, organ = self.forward_with_queries(image)
        else:
            cls, organ = self.vision_model(image), None
        if not return_dict:
            return (cls,) if organ is None else (cls, organ)
        return JoliaOutput(pooler_output=cls, organ_queries=organ)

    def forward_with_patch_tokens(self, image: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Return ``(cls, multi_scale_patch_tokens)`` from the Atlas backbone."""
        cls, patch_scales = self.vision_model(image, with_patch_tokens=True)
        return cls, patch_scales

    def forward_with_queries(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cls, organ_queries)`` — ``(B, embed_dim)``, ``(B, num_organs, query_dim)``.

        Each scale's pooler is applied to that scale's patch tokens, the first
        ``num_organs`` query outputs are kept, and the per-scale results are
        concatenated along the feature axis. When ``config.use_cls_residual``
        is set, the CLS embedding is added to every organ slot.

        Raises:
            RuntimeError: The model has no organ-query attention.
        """
        self._require_queries()
        cls, patch_scales = self.forward_with_patch_tokens(image)
        scale_outputs = [
            attn(tokens)[:, : self.num_organs, :]
            for attn, tokens in zip(self.organ_query_attn_scales, patch_scales)  # type: ignore[arg-type]
        ]
        organ = torch.cat(scale_outputs, dim=-1)
        if self.config.use_cls_residual:
            organ = organ + cls.unsqueeze(1).expand(-1, self.num_organs, -1)
        return cls, organ

    def extract_flat_feature(self, image: torch.Tensor) -> torch.Tensor:
        """Normalized ``[F.normalize(cls) ⊕ F.normalize(organ).flatten(1)]`` feature."""
        cls, organ = self.forward_with_queries(image)
        cls_n = F.normalize(cls.float(), dim=-1, eps=1e-6)
        organ_n = F.normalize(organ.float(), dim=-1, eps=1e-6)
        return torch.cat([cls_n, organ_n.reshape(organ_n.size(0), -1)], dim=-1)

    # ------------------------------------------------------------------
    # Zero-shot CLIP — image / text embeddings in a shared space
    # ------------------------------------------------------------------
    def encode_image(self, image: torch.Tensor, normalize: bool = True) -> torch.Tensor:
        """Image embedding ready for zero-shot — ``(B, embed_dim)``.

        The released checkpoint uses CLS-only (no vision projection) before
        L2-normalization, matching ``MultimodalCLSZeroShotCLIP`` in the
        training repo.

        With ``normalize=False`` the backbone output is returned as-is
        (no float cast, no normalization).
        """
        cls = self.vision_model(image)
        return F.normalize(cls.float(), dim=-1, eps=1e-6) if normalize else cls

    def encode_text(self, text_features: torch.Tensor, normalize: bool = True) -> torch.Tensor:
        """Project pooled text features into the shared embedding space.

        Args:
            text_features: ``(N, text_embed_dim)`` last-token-pooled features
                produced by the paired text encoder (use
                :class:`text_encoder_jolia.JoliaTextEncoder` to obtain them).
            normalize: L2-normalize the projected vectors (default).

        Raises:
            RuntimeError: The model has no text head.
        """
        self._require_text_head()
        projected = self.text_projection(text_features.float())  # type: ignore[misc]
        return F.normalize(projected, dim=-1, eps=1e-6) if normalize else projected

    def zero_shot_logits(self, image_emb: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
        """Calibrated image-vs-text logits with the trained temperature + bias.

        ``image_emb`` and ``text_emb`` must come from :meth:`encode_image` and
        :meth:`encode_text` (both L2-normalized). The clamp matches the
        training-time ``max_logit_scale=ln(100)``.

        Raises:
            RuntimeError: The model has no text head.
        """
        self._require_text_head()
        scale = torch.clamp(self.logit_scale.float(), max=math.log(100.0)).exp()
        # Cast the bias like the scale so both calibration terms are float32
        # regardless of the dtype the model was loaded in.
        return image_emb @ text_emb.t() * scale + self.text_bias.float()

    @torch.no_grad()
    def zero_shot(
        self,
        image: torch.Tensor,
        text_features: torch.Tensor,
        calibrated: bool = True,
    ) -> torch.Tensor:
        """One-call zero-shot scoring: image volume + pooled text -> ``(B, N)``.

        Args:
            image: Preprocessed volume ``(B, 11, 192, 192, 192)``.
            text_features: Pooled text features ``(N, text_embed_dim)`` from
                the paired text encoder.
            calibrated: When ``True`` (default), returns the trained CLIP
                logits ``cosine * exp(logit_scale) + bias`` — same output as
                ``MultimodalCLSZeroShotCLIP.get_logits_per_image`` in rarm,
                and what you want for ``torch.sigmoid(...)`` /
                ``torch.softmax(...)``. Set ``calibrated=False`` for raw
                cosine similarity in ``[-1, 1]`` (handy when you only need
                ranking and don't want the bias offset).

        Raises:
            RuntimeError: The model has no text head.
        """
        img = self.encode_image(image)
        txt = self.encode_text(text_features)
        if calibrated:
            return self.zero_shot_logits(img, txt)
        return img @ txt.t()

    def _require_text_head(self) -> None:
        """Raise ``RuntimeError`` when the global text head is missing."""
        if self.text_projection is None:
            raise RuntimeError(
                "This Jolia checkpoint has no text head — zero-shot is unavailable. "
                "Re-export the model with text_embed_dim>0 from a CLIPObjective checkpoint."
            )

    # ------------------------------------------------------------------
    # Per-organ (query-routed) zero-shot
    # ------------------------------------------------------------------
    def encode_organ_text(self, text_features: torch.Tensor, normalize: bool = True) -> torch.Tensor:
        """Project pooled text features through the ParallelOrganCLIP text head.

        Use this for organ-routed zero-shot: the resulting embedding lives in
        the same space as the per-organ image queries (different head than the
        global :meth:`encode_text`).

        Args:
            text_features: ``(N, text_embed_dim)`` last-token-pooled Qwen3 features.
            normalize: L2-normalize the projection (default).

        Raises:
            RuntimeError: The model has no per-organ text head.
        """
        self._require_organ_text_head()
        projected = self.organ_text_projection(text_features.float())  # type: ignore[misc]
        return F.normalize(projected, dim=-1, eps=1e-6) if normalize else projected

    def _calibrate_organ_logits(self, cosine: torch.Tensor, idx: int) -> torch.Tensor:
        """Apply slot ``idx``'s trained temperature and bias to ``cosine``."""
        scale = self.organ_logit_scale[idx].float().exp()
        bias = self.organ_text_bias[idx].float()
        return cosine * scale + bias

    @torch.no_grad()
    def zero_shot_organ(
        self,
        image: torch.Tensor,
        text_features: torch.Tensor,
        organ: str,
        calibrated: bool = True,
    ) -> torch.Tensor:
        """Score text prompts against a single organ's query embedding.

        Routes the image through the per-organ cross-attention pooler for
        ``organ`` and contrasts that query embedding with text features
        projected by the ParallelOrganCLIP text head. Returns ``(B, N)``.

        Args:
            image: ``(B, 11, 192, 192, 192)`` preprocessed CT volume.
            text_features: ``(N, text_embed_dim)`` pooled Qwen3 features.
            organ: Organ name (same names as accepted by :meth:`encode_organs`).
            calibrated: When ``True`` (default), applies this organ's
                trained temperature + bias. Set ``False`` for raw cosine.

        Raises:
            RuntimeError: The model has no per-organ text head.
            KeyError: ``organ`` is not a known organ name.
        """
        self._require_organ_text_head()
        organ_emb = self.encode_organs(image, organs=[organ], normalize=True)[organ]
        txt_emb = self.encode_organ_text(text_features)
        cosine = organ_emb @ txt_emb.t()  # (B, N)
        if not calibrated:
            return cosine
        # Same name -> slot mapping as encode_organs, so the calibration
        # parameters always belong to the slot the embedding came from.
        idx = self._organ_name_to_index()[organ]
        return self._calibrate_organ_logits(cosine, idx)

    @torch.no_grad()
    def zero_shot_organs(
        self,
        image: torch.Tensor,
        text_features: torch.Tensor,
        organs: list[str] | None = None,
        calibrated: bool = True,
    ) -> dict[str, torch.Tensor]:
        """Per-organ zero-shot for many organs at once.

        Returns ``{organ_name: (B, N)}``. ``organs=None`` runs every named slot.
        ``calibrated`` defaults to ``True`` (per-organ temperature + bias applied).

        Raises:
            RuntimeError: The model has no per-organ text head.
            KeyError: One of ``organs`` is not a known organ name.
        """
        self._require_organ_text_head()
        organ_embeds = self.encode_organs(image, organs=organs, normalize=True)
        txt_emb = self.encode_organ_text(text_features)
        # Built once; also resolves the "slot_<i>" fallback names that
        # encode_organs produces when the config carries no organ names.
        name_to_idx = self._organ_name_to_index()
        out: dict[str, torch.Tensor] = {}
        for name, emb in organ_embeds.items():
            logits = emb @ txt_emb.t()
            if calibrated:
                logits = self._calibrate_organ_logits(logits, name_to_idx[name])
            out[name] = logits
        return out

    def _require_organ_text_head(self) -> None:
        """Raise ``RuntimeError`` when the per-organ text head is missing."""
        if self.organ_text_projection is None:
            raise RuntimeError(
                "This Jolia checkpoint has no per-organ text head — organ-routed zero-shot "
                "is unavailable. The base text head (encode_text / zero_shot) may still work."
            )

    # ------------------------------------------------------------------
    # The easy, named organ-query API
    # ------------------------------------------------------------------
    def encode_organs(
        self,
        image: torch.Tensor,
        organs: list[str] | None = None,
        normalize: bool = False,
    ) -> dict[str, torch.Tensor]:
        """Per-organ embeddings keyed by organ name.

        Args:
            image: Preprocessed volume ``(B, 11, 192, 192, 192)``.
            organs: Subset of organ names to return. ``None`` returns every
                named slot. Unknown names raise ``KeyError`` (with the valid
                names listed).
            normalize: L2-normalize each organ embedding (cosine-ready).

        Returns:
            ``{organ_name: (B, query_dim)}``. If the model has no organ-slot
            names, keys fall back to ``"slot_<i>"``.

        Raises:
            RuntimeError: The model has no organ-query attention.
            KeyError: One of ``organs`` is not a known organ name.
        """
        self._require_queries()
        _, organ = self.forward_with_queries(image)  # (B, num_organs, query_dim)
        name_to_idx = self._organ_name_to_index()
        if organs is None:
            wanted = list(name_to_idx.items())
        else:
            missing = [o for o in organs if o not in name_to_idx]
            if missing:
                raise KeyError(
                    f"Unknown organ(s) {missing}. Available organs: {sorted(name_to_idx)}"
                )
            wanted = [(o, name_to_idx[o]) for o in organs]
        out = {name: organ[:, idx, :] for name, idx in wanted}
        if normalize:
            out = {name: F.normalize(vec.float(), dim=-1, eps=1e-6) for name, vec in out.items()}
        return out

    def organ_similarity(self, image: torch.Tensor, organs: list[str] | None = None) -> torch.Tensor:
        """Cosine-similarity matrix between batch-averaged organ embeddings.

        Each organ's L2-normalized embeddings are averaged over the batch and
        the averages are re-normalized, so the result is a true cosine matrix
        (unit diagonal) for any batch size.

        Returns ``(N, N)`` for the ``N`` requested organs — handy for probing
        which organs the model represents similarly. An empty selection
        yields an empty ``(0, 0)`` tensor.
        """
        embeds = self.encode_organs(image, organs=organs, normalize=True)
        if not embeds:
            return image.new_zeros((0, 0), dtype=torch.float32)
        mat = torch.stack([emb.mean(0) for emb in embeds.values()], dim=0)  # (N, query_dim)
        # The mean of unit vectors is generally shorter than unit length;
        # without this step the product below is not a cosine similarity.
        mat = F.normalize(mat, dim=-1, eps=1e-6)
        return mat @ mat.t()

    def _require_queries(self) -> None:
        """Raise ``RuntimeError`` when the organ-query poolers are missing."""
        if self.organ_query_attn_scales is None:
            raise RuntimeError(
                "This Jolia checkpoint has no organ-query attention "
                "(num_organs=0); organ-level methods are unavailable."
            )