"""Jolia: a self-contained Atlas vision backbone with named organ queries. ``JoliaModel`` is the public Hugging Face entry point. It wraps a vendored ``MultiModalAtlas`` 3D backbone, a per-scale bank of organ-query cross-attention poolers, and a CLIP-style text-projection head used for zero-shot classification. Feature views: * :meth:`forward` / ``__call__`` — the pooled global (CLS-equivalent) embedding. * :meth:`encode_image` — L2-normalized image embedding in the shared CLIP space. * :meth:`encode_text` — L2-normalized text embedding in the shared CLIP space (input: pooled Qwen3 text features, output: 576-d). * :meth:`zero_shot_logits` / :meth:`zero_shot` — image-vs-text similarity with the trained temperature + bias. * :meth:`encode_organs` — per-organ embeddings keyed by organ **name**. * :meth:`extract_flat_feature` — the normalized ``[cls ⊕ organs]`` vector used for linear probing. Load it with:: from transformers import AutoModel model = AutoModel.from_pretrained("raidium/Jolia", trust_remote_code=True).eval() For zero-shot, pair it with the paired text encoder (see :class:`text_encoder_jolia.JoliaTextEncoder`). """ from __future__ import annotations import math from dataclasses import dataclass import torch import torch.nn as nn import torch.nn.functional as F from transformers import PreTrainedModel from transformers.utils import ModelOutput from .configuration_jolia import JoliaConfig from .jolia_multimodal_atlas import MultiModalAtlas, MultiModalAtlasConfig from .jolia_organ_query_attention import OrganQueryAttention # HF's trust_remote_code loader copies only the *direct* relative imports of this # entry module. The lines below force every transitively-vendored backbone file # (and the text-encoder helper) into that copy set so the repo loads from the # hub without missing modules. 
from .jolia_atlas_encoders import ChestCTEmbed3D as _ensure_atlas_encoders  # noqa: E402,F401
from .jolia_multimodal_msa import AtlasStage as _ensure_multimodal_msa  # noqa: E402,F401
from .jolia_shim import BaseModel as _ensure_shim  # noqa: E402,F401

# The text encoder is opt-in (loading Qwen3 is heavy); we still touch the
# module so it's bundled in the dynamic-module copy set for snapshot_download
# + sys.path.append users.
from .text_encoder_jolia import JoliaTextEncoder as _ensure_text_encoder  # noqa: E402,F401


@dataclass
class JoliaOutput(ModelOutput):
    """Output of :meth:`JoliaModel.forward`.

    Attributes:
        pooler_output: ``(B, embed_dim)`` global embedding.
        organ_queries: ``(B, num_organs, query_dim)`` per-organ embeddings, or
            ``None`` when ``output_organ_queries=False`` / no organ-query head.
    """

    pooler_output: torch.FloatTensor | None = None
    organ_queries: torch.FloatTensor | None = None


class JoliaModel(PreTrainedModel):
    """Atlas vision backbone with optional organ-query poolers and CLIP text heads.

    Submodules, all built from :class:`JoliaConfig`:

    * ``vision_model`` — the :class:`MultiModalAtlas` backbone (always present).
    * ``organ_query_attn_scales`` — one :class:`OrganQueryAttention` per scale
      when ``config.has_queries``, otherwise ``None``.
    * ``text_projection`` / ``logit_scale`` / ``text_bias`` — the global CLIP
      text head when ``config.has_text_head``, otherwise ``None``.
    * ``organ_text_projection`` / ``organ_logit_scale`` / ``organ_text_bias`` —
      the per-organ text head when both ``has_text_head`` and ``has_queries``.
    """

    config_class = JoliaConfig
    base_model_prefix = "jolia"
    main_input_name = "image"

    def __init__(self, config: JoliaConfig) -> None:
        super().__init__(config)
        atlas_config = MultiModalAtlasConfig(
            embed_dim=config.embed_dim,
            num_classes=0,
            multiscale_feats=config.multiscale_feats,
            atlas_config=config.atlas_config,
        )
        self.vision_model = MultiModalAtlas(atlas_config)

        if config.has_queries:
            self.organ_query_attn_scales: torch.nn.ModuleList | None = torch.nn.ModuleList(
                OrganQueryAttention(
                    num_organs=config.num_organs,
                    query_dim=config.patch_embed_dim,
                    num_heads=config.num_heads,
                )
                for _ in range(config.num_scales)
            )
        else:
            self.organ_query_attn_scales = None

        # CLIP-style text head — small (~10 MB) projection from the paired text
        # encoder's hidden size to the shared embedding space, plus the trained
        # temperature (`logit_scale`) and additive `bias` used by `zero_shot_logits`.
        if config.has_text_head:
            self.text_projection: nn.Linear | None = nn.Linear(
                config.text_embed_dim, config.embed_dim, bias=True
            )
            self.logit_scale = nn.Parameter(torch.zeros([]))
            self.text_bias = nn.Parameter(torch.zeros([]))
        else:
            self.text_projection = None
            # Registered as None so the attribute exists and state-dict keys stay absent.
            self.register_parameter("logit_scale", None)
            self.register_parameter("text_bias", None)

        # ParallelOrganCLIP text head — used for organ-routed zero-shot (a text
        # prompt compared against a *specific* organ-query embedding). Same
        # shape as the global text head but trained against per-organ findings.
        if config.has_text_head and config.has_queries:
            self.organ_text_projection: nn.Linear | None = nn.Linear(
                config.text_embed_dim, config.embed_dim, bias=True
            )
            # One temperature and one bias per organ slot.
            self.organ_logit_scale = nn.Parameter(torch.zeros(config.num_organs))
            self.organ_text_bias = nn.Parameter(torch.zeros(config.num_organs))
        else:
            self.organ_text_projection = None
            self.register_parameter("organ_logit_scale", None)
            self.register_parameter("organ_text_bias", None)

        self.post_init()

    # ------------------------------------------------------------------
    # Capability / naming helpers
    # ------------------------------------------------------------------
    @property
    def embed_dim(self) -> int:
        return int(self.config.embed_dim)

    @property
    def query_dim(self) -> int:
        return int(self.config.query_dim)

    @property
    def num_organs(self) -> int:
        return int(self.config.num_organs)

    @property
    def has_queries(self) -> bool:
        return self.organ_query_attn_scales is not None

    @property
    def organ_slot_names(self) -> list[str]:
        """Names for organ-query slots ``0 .. len-1`` (trailing slots unused)."""
        return list(self.config.organ_slot_names)

    @property
    def has_text_head(self) -> bool:
        return self.text_projection is not None

    @property
    def has_organ_text_head(self) -> bool:
        return self.organ_text_projection is not None

    @property
    def text_encoder_id(self) -> str:
        """HuggingFace id of the paired text encoder (e.g. ``Qwen/Qwen3-Embedding-8B``)."""
        return self.config.text_encoder_id

    # ------------------------------------------------------------------
    # Forward views
    # ------------------------------------------------------------------
    def forward(
        self,
        image: torch.Tensor,
        output_organ_queries: bool = False,
        return_dict: bool = True,
    ) -> JoliaOutput | tuple:
        """Return the pooled global embedding (and optionally organ queries).

        Args:
            image: Preprocessed volume passed straight to the backbone.
            output_organ_queries: Also run the organ-query poolers (raises if the
                checkpoint has none).
            return_dict: Return a :class:`JoliaOutput` instead of a tuple.

        Returns:
            :class:`JoliaOutput`, or ``(cls,)`` / ``(cls, organ_queries)``.
        """
        if output_organ_queries:
            cls, organ = self.forward_with_queries(image)
        else:
            cls, organ = self.vision_model(image), None
        if not return_dict:
            return (cls,) if organ is None else (cls, organ)
        return JoliaOutput(pooler_output=cls, organ_queries=organ)

    def forward_with_patch_tokens(self, image: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Return ``(cls, multi_scale_patch_tokens)`` from the Atlas backbone."""
        cls, patch_scales = self.vision_model(image, with_patch_tokens=True)
        return cls, patch_scales

    def forward_with_queries(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cls, organ_queries)`` — ``(B, embed_dim)``, ``(B, num_organs, query_dim)``.

        Each per-scale pooler output is truncated to the first ``num_organs``
        slots, and the scales are concatenated along the feature axis. When
        ``config.use_cls_residual`` is set, the CLS embedding is added to every
        organ slot.
        """
        self._require_queries()
        cls, patch_scales = self.forward_with_patch_tokens(image)
        scale_outputs = [
            attn(tokens)[:, : self.num_organs, :]
            for attn, tokens in zip(self.organ_query_attn_scales, patch_scales)  # type: ignore[arg-type]
        ]
        organ = torch.cat(scale_outputs, dim=-1)
        if self.config.use_cls_residual:
            organ = organ + cls.unsqueeze(1).expand(-1, self.num_organs, -1)
        return cls, organ

    def extract_flat_feature(self, image: torch.Tensor) -> torch.Tensor:
        """Normalized ``[F.normalize(cls) ⊕ F.normalize(organ).flatten(1)]`` feature."""
        cls, organ = self.forward_with_queries(image)
        cls_n = F.normalize(cls.float(), dim=-1, eps=1e-6)
        organ_n = F.normalize(organ.float(), dim=-1, eps=1e-6)
        return torch.cat([cls_n, organ_n.reshape(organ_n.size(0), -1)], dim=-1)

    # ------------------------------------------------------------------
    # Zero-shot CLIP — image / text embeddings in a shared space
    # ------------------------------------------------------------------
    def encode_image(self, image: torch.Tensor, normalize: bool = True) -> torch.Tensor:
        """Image embedding ready for zero-shot — ``(B, embed_dim)``.

        The released checkpoint uses CLS-only (no vision projection) before
        L2-normalization, matching ``MultimodalCLSZeroShotCLIP`` in the training repo.
        """
        cls = self.vision_model(image)
        return F.normalize(cls.float(), dim=-1, eps=1e-6) if normalize else cls

    def encode_text(self, text_features: torch.Tensor, normalize: bool = True) -> torch.Tensor:
        """Project pooled text features into the shared embedding space.

        Args:
            text_features: ``(N, text_embed_dim)`` last-token-pooled features
                produced by the paired text encoder (use
                :class:`text_encoder_jolia.JoliaTextEncoder` to obtain them).
            normalize: L2-normalize the projected vectors (default).

        Raises:
            RuntimeError: If the checkpoint has no text head.
        """
        self._require_text_head()
        projected = self.text_projection(text_features.float())  # type: ignore[misc]
        return F.normalize(projected, dim=-1, eps=1e-6) if normalize else projected

    def zero_shot_logits(self, image_emb: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
        """Calibrated image-vs-text logits with the trained temperature + bias.

        ``image_emb`` and ``text_emb`` must come from :meth:`encode_image` and
        :meth:`encode_text` (both L2-normalized). The clamp matches the
        training-time ``max_logit_scale=ln(100)``.
        """
        self._require_text_head()
        scale = torch.clamp(self.logit_scale.float(), max=math.log(100.0)).exp()
        return image_emb @ text_emb.t() * scale + self.text_bias

    @torch.no_grad()
    def zero_shot(
        self,
        image: torch.Tensor,
        text_features: torch.Tensor,
        calibrated: bool = True,
    ) -> torch.Tensor:
        """One-call zero-shot scoring: image volume + pooled text -> ``(B, N)``.

        Args:
            image: Preprocessed volume ``(B, 11, 192, 192, 192)``.
            text_features: Pooled text features ``(N, text_embed_dim)`` from the
                paired text encoder.
            calibrated: When ``True`` (default), returns the trained CLIP logits
                ``cosine * exp(logit_scale) + bias`` — same output as
                ``MultimodalCLSZeroShotCLIP.get_logits_per_image`` in rarm, and
                what you want for ``torch.sigmoid(...)`` / ``torch.softmax(...)``.
                Set ``calibrated=False`` for raw cosine similarity in ``[-1, 1]``
                (handy when you only need ranking and don't want the bias offset).
        """
        img = self.encode_image(image)
        txt = self.encode_text(text_features)
        if calibrated:
            return self.zero_shot_logits(img, txt)
        return img @ txt.t()

    def _require_text_head(self) -> None:
        if self.text_projection is None:
            raise RuntimeError(
                "This Jolia checkpoint has no text head — zero-shot is unavailable. "
                "Re-export the model with text_embed_dim>0 from a CLIPObjective checkpoint."
            )

    # ------------------------------------------------------------------
    # Per-organ (query-routed) zero-shot
    # ------------------------------------------------------------------
    def encode_organ_text(self, text_features: torch.Tensor, normalize: bool = True) -> torch.Tensor:
        """Project pooled text features through the ParallelOrganCLIP text head.

        Use this for organ-routed zero-shot: the resulting embedding lives in
        the same space as the per-organ image queries (different head than the
        global :meth:`encode_text`).

        Args:
            text_features: ``(N, text_embed_dim)`` last-token-pooled Qwen3 features.
            normalize: L2-normalize the projection (default).

        Raises:
            RuntimeError: If the checkpoint has no per-organ text head.
        """
        self._require_organ_text_head()
        projected = self.organ_text_projection(text_features.float())  # type: ignore[misc]
        return F.normalize(projected, dim=-1, eps=1e-6) if normalize else projected

    @torch.no_grad()
    def zero_shot_organ(
        self,
        image: torch.Tensor,
        text_features: torch.Tensor,
        organ: str,
        calibrated: bool = True,
    ) -> torch.Tensor:
        """Score text prompts against a single organ's query embedding.

        Routes the image through the per-organ cross-attention pooler for
        ``organ`` and contrasts that query embedding with text features
        projected by the ParallelOrganCLIP text head. Returns ``(B, N)``.

        Args:
            image: ``(B, 11, 192, 192, 192)`` preprocessed CT volume.
            text_features: ``(N, text_embed_dim)`` pooled Qwen3 features.
            organ: Organ name as accepted by :meth:`encode_organs` (an entry of
                :attr:`organ_slot_names`, or ``"slot_<i>"`` when none are set).
            calibrated: When ``True`` (default), applies this organ's trained
                temperature + bias. Set ``False`` for raw cosine.

        Raises:
            RuntimeError: If the checkpoint has no per-organ text head.
            KeyError: If ``organ`` is not a known slot name.
        """
        self._require_organ_text_head()
        organ_emb = self.encode_organs(image, organs=[organ], normalize=True)[organ]  # (B, 576)
        txt_emb = self.encode_organ_text(text_features)  # (N, 576)
        cosine = organ_emb @ txt_emb.t()  # (B, N)
        if not calibrated:
            return cosine
        # Resolve the slot with the same mapping encode_organs used, so fallback
        # "slot_<i>" names work and duplicates pick the slot that was embedded.
        return self._calibrate_organ_logits(cosine, self._organ_name_to_index()[organ])

    @torch.no_grad()
    def zero_shot_organs(
        self,
        image: torch.Tensor,
        text_features: torch.Tensor,
        organs: list[str] | None = None,
        calibrated: bool = True,
    ) -> dict[str, torch.Tensor]:
        """Per-organ zero-shot for many organs at once.

        Returns ``{organ_name: (B, N)}``. ``organs=None`` runs every named slot.
        ``calibrated`` defaults to ``True`` (per-organ temperature + bias applied).

        Raises:
            RuntimeError: If the checkpoint has no per-organ text head.
            KeyError: If any requested organ is not a known slot name.
        """
        self._require_organ_text_head()
        organ_embeds = self.encode_organs(image, organs=organs, normalize=True)
        txt_emb = self.encode_organ_text(text_features)
        name_to_idx = self._organ_name_to_index()
        out: dict[str, torch.Tensor] = {}
        for name, emb in organ_embeds.items():
            cosine = emb @ txt_emb.t()
            if calibrated:
                cosine = self._calibrate_organ_logits(cosine, name_to_idx[name])
            out[name] = cosine
        return out

    def _calibrate_organ_logits(self, cosine: torch.Tensor, idx: int) -> torch.Tensor:
        """Apply organ slot ``idx``'s trained temperature and bias to cosine scores."""
        # NOTE(review): unlike zero_shot_logits, the per-organ temperature is not
        # clamped to ln(100); kept unchanged — confirm against ParallelOrganCLIP training.
        scale = self.organ_logit_scale[idx].float().exp()
        bias = self.organ_text_bias[idx].float()
        return cosine * scale + bias

    def _require_organ_text_head(self) -> None:
        if self.organ_text_projection is None:
            raise RuntimeError(
                "This Jolia checkpoint has no per-organ text head — organ-routed zero-shot "
                "is unavailable. The base text head (encode_text / zero_shot) may still work."
            )

    # ------------------------------------------------------------------
    # The easy, named organ-query API
    # ------------------------------------------------------------------
    def _organ_name_to_index(self) -> dict[str, int]:
        """Map every accepted organ name to its organ-query slot index.

        Uses :attr:`organ_slot_names`, falling back to ``"slot_<i>"`` for
        ``i in range(num_organs)`` when the config provides no names. All
        organ-level methods share this mapping so they agree on slot indices.
        """
        names = self.organ_slot_names or [f"slot_{i}" for i in range(self.num_organs)]
        return {name: i for i, name in enumerate(names)}

    def encode_organs(
        self,
        image: torch.Tensor,
        organs: list[str] | None = None,
        normalize: bool = False,
    ) -> dict[str, torch.Tensor]:
        """Per-organ embeddings keyed by organ name.

        Args:
            image: Preprocessed volume ``(B, 11, 192, 192, 192)``.
            organs: Subset of organ names to return. ``None`` returns every named
                slot. Unknown names raise ``KeyError`` (with the valid names listed).
            normalize: L2-normalize each organ embedding (cosine-ready).

        Returns:
            ``{organ_name: (B, query_dim)}``. If the model has no organ-slot
            names, keys fall back to ``"slot_<i>"``.

        Raises:
            RuntimeError: If the checkpoint has no organ-query attention.
            KeyError: If any requested organ name is unknown.
        """
        self._require_queries()
        _, organ = self.forward_with_queries(image)  # (B, num_organs, query_dim)
        name_to_idx = self._organ_name_to_index()

        if organs is None:
            wanted = list(name_to_idx.items())
        else:
            missing = [o for o in organs if o not in name_to_idx]
            if missing:
                raise KeyError(
                    f"Unknown organ(s) {missing}. Available organs: {sorted(name_to_idx)}"
                )
            wanted = [(o, name_to_idx[o]) for o in organs]

        out = {name: organ[:, idx, :] for name, idx in wanted}
        if normalize:
            out = {name: F.normalize(vec.float(), dim=-1, eps=1e-6) for name, vec in out.items()}
        return out

    def organ_similarity(self, image: torch.Tensor, organs: list[str] | None = None) -> torch.Tensor:
        """Batch-averaged cosine-similarity matrix between organ embeddings.

        For each sample, the ``(N, N)`` cosine matrix of its L2-normalized organ
        embeddings is computed, and the matrices are averaged over the batch.
        The diagonal is therefore 1 (up to the normalization eps). Returns
        ``(N, N)`` for the ``N`` requested organs, in the order returned by
        :meth:`encode_organs`. This is handy for probing which organs the model
        represents similarly.
        """
        embeds = self.encode_organs(image, organs=organs, normalize=True)
        stacked = torch.stack(list(embeds.values()), dim=0)  # (N, B, query_dim)
        # Sum of per-sample dot products over the batch, divided by B. This is
        # the mean cosine, not the dot product of mean vectors, which would
        # mix samples.
        return torch.einsum("nbd,mbd->nm", stacked, stacked) / stacked.shape[1]

    def _require_queries(self) -> None:
        if self.organ_query_attn_scales is None:
            raise RuntimeError(
                "This Jolia checkpoint has no organ-query attention "
                "(num_organs=0); organ-level methods are unavailable."
            )