Feature Extraction
Transformers
Safetensors
jolia
medical
radiology
ct
3d
vision
foundation-model
self-supervised
custom_code
Instructions to use raidium/Jolia with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use raidium/Jolia with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="raidium/Jolia", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("raidium/Jolia", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """Jolia: a self-contained Atlas vision backbone with named organ queries. | |
| ``JoliaModel`` is the public Hugging Face entry point. It wraps a vendored | |
| ``MultiModalAtlas`` 3D backbone, a per-scale bank of organ-query | |
| cross-attention poolers, and a CLIP-style text-projection head used for | |
| zero-shot classification. | |
| Feature views: | |
| * :meth:`forward` / ``__call__`` — the pooled global (CLS-equivalent) embedding. | |
| * :meth:`encode_image` — L2-normalized image embedding in the shared CLIP space. | |
| * :meth:`encode_text` — L2-normalized text embedding in the shared CLIP space | |
| (input: pooled Qwen3 text features, output: 576-d). | |
| * :meth:`zero_shot_logits` / :meth:`zero_shot` — image-vs-text similarity with | |
| the trained temperature + bias. | |
| * :meth:`encode_organs` — per-organ embeddings keyed by organ **name**. | |
| * :meth:`extract_flat_feature` — the normalized ``[cls ⊕ organs]`` vector used | |
| for linear probing. | |
| Load it with:: | |
| from transformers import AutoModel | |
| model = AutoModel.from_pretrained("raidium/Jolia", trust_remote_code=True).eval() | |
| For zero-shot, pair it with the paired text encoder | |
| (see :class:`text_encoder_jolia.JoliaTextEncoder`). | |
| """ | |
| from __future__ import annotations | |
| import math | |
| from dataclasses import dataclass | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import PreTrainedModel | |
| from transformers.utils import ModelOutput | |
| from .configuration_jolia import JoliaConfig | |
| from .jolia_multimodal_atlas import MultiModalAtlas, MultiModalAtlasConfig | |
| from .jolia_organ_query_attention import OrganQueryAttention | |
| # HF's trust_remote_code loader copies only the *direct* relative imports of this | |
| # entry module. The lines below force every transitively-vendored backbone file | |
| # (and the text-encoder helper) into that copy set so the repo loads from the | |
| # hub without missing modules. | |
| from .jolia_atlas_encoders import ChestCTEmbed3D as _ensure_atlas_encoders # noqa: E402,F401 | |
| from .jolia_multimodal_msa import AtlasStage as _ensure_multimodal_msa # noqa: E402,F401 | |
| from .jolia_shim import BaseModel as _ensure_shim # noqa: E402,F401 | |
| # The text encoder is opt-in (loading Qwen3 is heavy); we still touch the | |
| # module so it's bundled in the dynamic-module copy set for snapshot_download | |
| # + sys.path.append users. | |
| from .text_encoder_jolia import JoliaTextEncoder as _ensure_text_encoder # noqa: E402,F401 | |
@dataclass
class JoliaOutput(ModelOutput):
    """Output of :meth:`JoliaModel.forward`.

    ``ModelOutput`` subclasses must be decorated with ``@dataclass`` so the
    annotated fields below become real output fields.

    Attributes:
        pooler_output: ``(B, embed_dim)`` global embedding.
        organ_queries: ``(B, num_organs, query_dim)`` per-organ embeddings, or
            ``None`` when ``forward`` is called with
            ``output_organ_queries=False``.
    """

    pooler_output: torch.FloatTensor | None = None
    organ_queries: torch.FloatTensor | None = None
class JoliaModel(PreTrainedModel):
    """Hugging Face entry point for Jolia.

    Sub-modules built in ``__init__`` (each optional part is ``None`` when the
    config disables it):

    * ``vision_model`` — the ``MultiModalAtlas`` backbone.
    * ``organ_query_attn_scales`` — one ``OrganQueryAttention`` per scale
      (built when ``config.has_queries``).
    * ``text_projection`` / ``logit_scale`` / ``text_bias`` — global text head
      (built when ``config.has_text_head``).
    * ``organ_text_projection`` / ``organ_logit_scale`` / ``organ_text_bias`` —
      per-organ text head (built when both flags are set).
    """

    config_class = JoliaConfig
    base_model_prefix = "jolia"
    main_input_name = "image"

    def __init__(self, config: JoliaConfig) -> None:
        """Build the backbone and whichever optional heads ``config`` enables."""
        super().__init__(config)
        atlas_config = MultiModalAtlasConfig(
            embed_dim=config.embed_dim,
            num_classes=0,
            multiscale_feats=config.multiscale_feats,
            atlas_config=config.atlas_config,
        )
        self.vision_model = MultiModalAtlas(atlas_config)

        if config.has_queries:
            self.organ_query_attn_scales: torch.nn.ModuleList | None = torch.nn.ModuleList(
                OrganQueryAttention(
                    num_organs=config.num_organs,
                    query_dim=config.patch_embed_dim,
                    num_heads=config.num_heads,
                )
                for _ in range(config.num_scales)
            )
        else:
            self.organ_query_attn_scales = None

        # CLIP-style text head — small (~10 MB) projection from the paired text
        # encoder's hidden size to the shared embedding space, plus the trained
        # temperature (`logit_scale`) and additive `bias` used by `zero_shot_logits`.
        if config.has_text_head:
            self.text_projection: nn.Linear | None = nn.Linear(
                config.text_embed_dim, config.embed_dim, bias=True
            )
            self.logit_scale = nn.Parameter(torch.zeros([]))
            self.text_bias = nn.Parameter(torch.zeros([]))
        else:
            self.text_projection = None
            # Register as absent parameters so the attributes still exist.
            self.register_parameter("logit_scale", None)
            self.register_parameter("text_bias", None)

        # ParallelOrganCLIP text head — used for organ-routed zero-shot (a text
        # prompt compared against a *specific* organ-query embedding). Same
        # shape as the global text head but trained against per-organ findings.
        if config.has_text_head and config.has_queries:
            self.organ_text_projection: nn.Linear | None = nn.Linear(
                config.text_embed_dim, config.embed_dim, bias=True
            )
            self.organ_logit_scale = nn.Parameter(torch.zeros(config.num_organs))
            self.organ_text_bias = nn.Parameter(torch.zeros(config.num_organs))
        else:
            self.organ_text_projection = None
            self.register_parameter("organ_logit_scale", None)
            self.register_parameter("organ_text_bias", None)

        self.post_init()

    # ------------------------------------------------------------------
    # Capability / naming helpers
    # ------------------------------------------------------------------
    # These are properties: the rest of the class reads them as plain
    # attributes (e.g. ``self.num_organs`` in slices and ``range``).

    @property
    def embed_dim(self) -> int:
        """``config.embed_dim`` as an ``int``."""
        return int(self.config.embed_dim)

    @property
    def query_dim(self) -> int:
        """``config.query_dim`` as an ``int``."""
        return int(self.config.query_dim)

    @property
    def num_organs(self) -> int:
        """``config.num_organs`` as an ``int``."""
        return int(self.config.num_organs)

    @property
    def has_queries(self) -> bool:
        """Whether the organ-query attention poolers were built."""
        return self.organ_query_attn_scales is not None

    @property
    def organ_slot_names(self) -> list[str]:
        """Names for organ-query slots ``0 .. len-1`` (trailing slots unused)."""
        return list(self.config.organ_slot_names)

    @property
    def has_text_head(self) -> bool:
        """Whether the global text projection head was built."""
        return self.text_projection is not None

    @property
    def has_organ_text_head(self) -> bool:
        """Whether the per-organ text projection head was built."""
        return self.organ_text_projection is not None

    @property
    def text_encoder_id(self) -> str:
        """HuggingFace id of the paired text encoder (e.g. ``Qwen/Qwen3-Embedding-8B``)."""
        return self.config.text_encoder_id

    def _organ_name_to_index(self) -> dict[str, int]:
        """Map each organ name to its query-slot index.

        Uses ``organ_slot_names`` when non-empty, otherwise ``"slot_<i>"`` for
        ``i`` in ``range(num_organs)``. Shared by :meth:`encode_organs` and the
        calibrated zero-shot paths so that both resolve names the same way.
        """
        names = self.organ_slot_names or [f"slot_{i}" for i in range(self.num_organs)]
        return {name: idx for idx, name in enumerate(names)}

    # ------------------------------------------------------------------
    # Forward views
    # ------------------------------------------------------------------
    def forward(
        self,
        image: torch.Tensor,
        output_organ_queries: bool = False,
        return_dict: bool = True,
    ) -> JoliaOutput | tuple:
        """Return the pooled global embedding (and optionally organ queries).

        Args:
            image: Input passed to the vision backbone.
            output_organ_queries: Also compute the organ queries via
                :meth:`forward_with_queries`.
            return_dict: Return a :class:`JoliaOutput` (default) instead of a tuple.

        Returns:
            ``JoliaOutput`` or, when ``return_dict=False``, ``(cls,)`` /
            ``(cls, organ_queries)``.

        Raises:
            RuntimeError: ``output_organ_queries=True`` on a model without
                organ-query attention.
        """
        if output_organ_queries:
            cls, organ = self.forward_with_queries(image)
        else:
            cls, organ = self.vision_model(image), None
        if not return_dict:
            return (cls,) if organ is None else (cls, organ)
        return JoliaOutput(pooler_output=cls, organ_queries=organ)

    def forward_with_patch_tokens(self, image: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Return ``(cls, multi_scale_patch_tokens)`` from the Atlas backbone."""
        cls, patch_scales = self.vision_model(image, with_patch_tokens=True)
        return cls, patch_scales

    def forward_with_queries(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cls, organ_queries)`` — ``(B, embed_dim)``, ``(B, num_organs, query_dim)``.

        Each scale's pooler output is truncated to the first ``num_organs``
        slots, the scales are concatenated along the last dimension, and the
        CLS embedding is added to every organ when ``config.use_cls_residual``.

        Raises:
            RuntimeError: The model has no organ-query attention.
        """
        self._require_queries()
        cls, patch_scales = self.forward_with_patch_tokens(image)
        scale_outputs = [
            attn(tokens)[:, : self.num_organs, :]
            for attn, tokens in zip(self.organ_query_attn_scales, patch_scales)  # type: ignore[arg-type]
        ]
        organ = torch.cat(scale_outputs, dim=-1)
        if self.config.use_cls_residual:
            organ = organ + cls.unsqueeze(1).expand(-1, self.num_organs, -1)
        return cls, organ

    def extract_flat_feature(self, image: torch.Tensor) -> torch.Tensor:
        """Normalized ``[F.normalize(cls) ⊕ F.normalize(organ).flatten(1)]`` feature.

        CLS and each organ embedding are L2-normalized independently (in
        float32) before the organ block is flattened and concatenated.
        """
        cls, organ = self.forward_with_queries(image)
        cls_n = F.normalize(cls.float(), dim=-1, eps=1e-6)
        organ_n = F.normalize(organ.float(), dim=-1, eps=1e-6)
        return torch.cat([cls_n, organ_n.reshape(organ_n.size(0), -1)], dim=-1)

    # ------------------------------------------------------------------
    # Zero-shot CLIP — image / text embeddings in a shared space
    # ------------------------------------------------------------------
    def encode_image(self, image: torch.Tensor, normalize: bool = True) -> torch.Tensor:
        """Image embedding ready for zero-shot — ``(B, embed_dim)``.

        The released checkpoint uses CLS-only (no vision projection) before
        L2-normalization, matching ``MultimodalCLSZeroShotCLIP`` in the
        training repo.

        With ``normalize=False`` the backbone output is returned as-is (no
        float32 cast).
        """
        cls = self.vision_model(image)
        return F.normalize(cls.float(), dim=-1, eps=1e-6) if normalize else cls

    def encode_text(self, text_features: torch.Tensor, normalize: bool = True) -> torch.Tensor:
        """Project pooled text features into the shared embedding space.

        Args:
            text_features: ``(N, text_embed_dim)`` last-token-pooled features
                produced by the paired text encoder (use
                :class:`text_encoder_jolia.JoliaTextEncoder` to obtain them).
            normalize: L2-normalize the projected vectors (default).

        Raises:
            RuntimeError: The model has no text head.
        """
        self._require_text_head()
        projected = self.text_projection(text_features.float())  # type: ignore[misc]
        return F.normalize(projected, dim=-1, eps=1e-6) if normalize else projected

    def zero_shot_logits(self, image_emb: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
        """Calibrated image-vs-text logits with the trained temperature + bias.

        ``image_emb`` and ``text_emb`` must come from :meth:`encode_image` and
        :meth:`encode_text` (both L2-normalized). The clamp matches the
        training-time ``max_logit_scale=ln(100)``.

        Raises:
            RuntimeError: The model has no text head.
        """
        self._require_text_head()
        scale = torch.clamp(self.logit_scale.float(), max=math.log(100.0)).exp()
        return image_emb @ text_emb.t() * scale + self.text_bias

    def zero_shot(
        self,
        image: torch.Tensor,
        text_features: torch.Tensor,
        calibrated: bool = True,
    ) -> torch.Tensor:
        """One-call zero-shot scoring: image volume + pooled text -> ``(B, N)``.

        Args:
            image: Preprocessed volume ``(B, 11, 192, 192, 192)``.
            text_features: Pooled text features ``(N, text_embed_dim)`` from
                the paired text encoder.
            calibrated: When ``True`` (default), returns the trained CLIP
                logits ``cosine * exp(logit_scale) + bias`` — same output as
                ``MultimodalCLSZeroShotCLIP.get_logits_per_image`` in rarm,
                and what you want for ``torch.sigmoid(...)`` /
                ``torch.softmax(...)``. Set ``calibrated=False`` for raw
                cosine similarity in ``[-1, 1]`` (handy when you only need
                ranking and don't want the bias offset).

        Raises:
            RuntimeError: The model has no text head.
        """
        img = self.encode_image(image)
        txt = self.encode_text(text_features)
        if calibrated:
            return self.zero_shot_logits(img, txt)
        return img @ txt.t()

    def _require_text_head(self) -> None:
        """Raise ``RuntimeError`` if the global text head is missing."""
        if self.text_projection is None:
            raise RuntimeError(
                "This Jolia checkpoint has no text head — zero-shot is unavailable. "
                "Re-export the model with text_embed_dim>0 from a CLIPObjective checkpoint."
            )

    # ------------------------------------------------------------------
    # Per-organ (query-routed) zero-shot
    # ------------------------------------------------------------------
    def encode_organ_text(self, text_features: torch.Tensor, normalize: bool = True) -> torch.Tensor:
        """Project pooled text features through the ParallelOrganCLIP text head.

        Use this for organ-routed zero-shot: the resulting embedding lives in
        the same space as the per-organ image queries (different head than the
        global :meth:`encode_text`).

        Args:
            text_features: ``(N, text_embed_dim)`` last-token-pooled Qwen3 features.
            normalize: L2-normalize the projection (default).

        Raises:
            RuntimeError: The model has no per-organ text head.
        """
        self._require_organ_text_head()
        projected = self.organ_text_projection(text_features.float())  # type: ignore[misc]
        return F.normalize(projected, dim=-1, eps=1e-6) if normalize else projected

    def zero_shot_organ(
        self,
        image: torch.Tensor,
        text_features: torch.Tensor,
        organ: str,
        calibrated: bool = True,
    ) -> torch.Tensor:
        """Score text prompts against a single organ's query embedding.

        Routes the image through the per-organ cross-attention pooler for
        ``organ`` and contrasts that query embedding with text features
        projected by the ParallelOrganCLIP text head. Returns ``(B, N)``.

        Args:
            image: ``(B, 11, 192, 192, 192)`` preprocessed CT volume.
            text_features: ``(N, text_embed_dim)`` pooled Qwen3 features.
            organ: Organ name (must be in :attr:`organ_slot_names`, or a
                ``"slot_<i>"`` fallback name when that list is empty).
            calibrated: When ``True`` (default), applies this organ's
                trained temperature + bias. Set ``False`` for raw cosine.

        Raises:
            RuntimeError: The model has no per-organ text head.
            KeyError: ``organ`` is not a known organ name.
        """
        # Single-organ case of the batched method: one code path for name
        # resolution and calibration.
        scores = self.zero_shot_organs(
            image, text_features, organs=[organ], calibrated=calibrated
        )
        return scores[organ]

    def zero_shot_organs(
        self,
        image: torch.Tensor,
        text_features: torch.Tensor,
        organs: list[str] | None = None,
        calibrated: bool = True,
    ) -> dict[str, torch.Tensor]:
        """Per-organ zero-shot for many organs at once.

        Returns ``{organ_name: (B, N)}``. ``organs=None`` runs every named slot.
        ``calibrated`` defaults to ``True`` (per-organ temperature + bias applied).

        Raises:
            RuntimeError: The model has no per-organ text head.
            KeyError: An entry of ``organs`` is not a known organ name.
        """
        self._require_organ_text_head()
        organ_embeds = self.encode_organs(image, organs=organs, normalize=True)
        txt_emb = self.encode_organ_text(text_features)
        # Resolve indices with the same mapping encode_organs uses, so the
        # "slot_<i>" fallback names also find their calibration entries.
        name_to_idx = self._organ_name_to_index() if calibrated else {}
        out: dict[str, torch.Tensor] = {}
        for name, emb in organ_embeds.items():
            cosine = emb @ txt_emb.t()
            if calibrated:
                idx = name_to_idx[name]
                # NOTE: unlike zero_shot_logits, no clamp is applied to the
                # per-organ logit scale.
                scale = self.organ_logit_scale[idx].float().exp()
                bias = self.organ_text_bias[idx].float()
                cosine = cosine * scale + bias
            out[name] = cosine
        return out

    def _require_organ_text_head(self) -> None:
        """Raise ``RuntimeError`` if the per-organ text head is missing."""
        if self.organ_text_projection is None:
            raise RuntimeError(
                "This Jolia checkpoint has no per-organ text head — organ-routed zero-shot "
                "is unavailable. The base text head (encode_text / zero_shot) may still work."
            )

    # ------------------------------------------------------------------
    # The easy, named organ-query API
    # ------------------------------------------------------------------
    def encode_organs(
        self,
        image: torch.Tensor,
        organs: list[str] | None = None,
        normalize: bool = False,
    ) -> dict[str, torch.Tensor]:
        """Per-organ embeddings keyed by organ name.

        Args:
            image: Preprocessed volume ``(B, 11, 192, 192, 192)``.
            organs: Subset of organ names to return. ``None`` returns every
                named slot. Unknown names raise ``KeyError`` (with the valid
                names listed).
            normalize: L2-normalize each organ embedding (cosine-ready).

        Returns:
            ``{organ_name: (B, query_dim)}``. If the model has no organ-slot
            names, keys fall back to ``"slot_<i>"``.

        Raises:
            RuntimeError: The model has no organ-query attention.
            KeyError: An entry of ``organs`` is not a known organ name.
        """
        self._require_queries()
        name_to_idx = self._organ_name_to_index()
        if organs is None:
            wanted = list(name_to_idx.items())
        else:
            # Validate before running the backbone so a typo fails fast.
            missing = [o for o in organs if o not in name_to_idx]
            if missing:
                raise KeyError(
                    f"Unknown organ(s) {missing}. Available organs: {sorted(name_to_idx)}"
                )
            wanted = [(o, name_to_idx[o]) for o in organs]

        _, organ = self.forward_with_queries(image)  # (B, num_organs, query_dim)
        out = {name: organ[:, idx, :] for name, idx in wanted}
        if normalize:
            out = {name: F.normalize(vec.float(), dim=-1, eps=1e-6) for name, vec in out.items()}
        return out

    def organ_similarity(self, image: torch.Tensor, organs: list[str] | None = None) -> torch.Tensor:
        """Cosine-similarity matrix between batch-averaged organ embeddings.

        Each organ's L2-normalized embeddings are averaged over the batch and
        the mean is re-normalized, so entries are true cosines (diagonal 1).

        Returns ``(N, N)`` for the ``N`` requested organs — handy for probing
        which organs the model represents similarly.
        """
        embeds = self.encode_organs(image, organs=organs, normalize=True)
        mat = torch.stack([vec.mean(0) for vec in embeds.values()], dim=0)  # (N, query_dim)
        # A mean of unit vectors is generally shorter than unit length;
        # re-normalize so the product below is a cosine, not a raw dot product.
        mat = F.normalize(mat, dim=-1, eps=1e-6)
        return mat @ mat.t()

    def _require_queries(self) -> None:
        """Raise ``RuntimeError`` if the organ-query attention is missing."""
        if self.organ_query_attn_scales is None:
            raise RuntimeError(
                "This Jolia checkpoint has no organ-query attention "
                "(num_organs=0); organ-level methods are unavailable."
            )