---
license: other
library_name: transformers
pipeline_tag: feature-extraction
tags:
- medical
- radiology
- ct
- 3d
- vision
- foundation-model
- self-supervised
---
# Jolia — A 3D CT foundation model with anatomical representations
**Jolia** is a 3D CT foundation model that encodes a whole 3D CT volume into vector representations:
- a **global embedding** (`embed_dim = 576`), and
- **per-organ embeddings** — 102 named organ slots produced by organ-query
cross-attention pooling, trained to align with per-organ report text.
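Organ-query cross-attention pooling can be pictured as a bank of learned query vectors, one per organ slot, each attending over the backbone's patch tokens and returning a weighted average of them. The NumPy sketch below is a minimal single-head illustration under our own assumptions: random weights stand in for trained ones, token counts are toy-sized, and each of the three scales gets its own 192-d queries whose outputs are concatenated to 576-d (as the Model details table suggests). It is not Jolia's actual pooling module.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def organ_query_pool(tokens, queries, w_k, w_v):
    """Single-head cross-attention pooling (illustrative sketch).

    tokens:   (B, T, D) patch tokens from one backbone scale
    queries:  (Q, D)    learned organ queries, one per slot
    w_k, w_v: (D, D)    key/value projections (hypothetical parameters)
    returns:  pooled (B, Q, D) and attention weights (B, Q, T)
    """
    keys = tokens @ w_k
    values = tokens @ w_v
    scores = np.einsum("qd,btd->bqt", queries, keys) / np.sqrt(queries.shape[-1])
    attn = softmax(scores, axis=-1)  # each query's weights over the T tokens sum to 1
    return np.einsum("bqt,btd->bqd", attn, values), attn

rng = np.random.default_rng(0)
B, D, NUM_SLOTS = 2, 192, 102
TOY_TOKENS_PER_SCALE = (512, 64, 8)  # toy sizes, not Jolia's real token counts

per_scale, attn_maps = [], []
for t in TOY_TOKENS_PER_SCALE:
    tokens = rng.standard_normal((B, t, D))
    queries = rng.standard_normal((NUM_SLOTS, D))  # separate queries per scale (assumption)
    w_k = rng.standard_normal((D, D)) / np.sqrt(D)
    w_v = rng.standard_normal((D, D)) / np.sqrt(D)
    pooled, attn = organ_query_pool(tokens, queries, w_k, w_v)
    per_scale.append(pooled)
    attn_maps.append(attn)

organ_embeddings = np.concatenate(per_scale, axis=-1)  # (B, 102, 3 * 192) = (2, 102, 576)
```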
## Installation
```bash
pip install torch transformers timm einops numpy safetensors
```
## Quick start
```python
import torch
from transformers import AutoModel
model = AutoModel.from_pretrained("raidium/Jolia", trust_remote_code=True).eval()
# image: a preprocessed CT volume, shape (B, 11, 192, 192, 192) — see Preprocessing
with torch.no_grad():
    cls = model(image).pooler_output  # (B, 576) global embedding
```
## Preprocessing
Raw CT volumes must be brought to the Atlas input format
(`(11, 192, 192, 192)`: 1.5 mm isotropic, 192³ crop, 11 CT windowing channels).
Grab the bundled preprocessor from the repo:
```python
from huggingface_hub import snapshot_download
import sys
repo = snapshot_download("raidium/Jolia")
sys.path.append(repo)
from preprocessing_jolia import JoliaPreprocessor
pre = JoliaPreprocessor()
# volume: (H, W, D) in Hounsfield units; resolution in mm (row, col, slice)
image = pre(volume, resolution=(0.7, 0.7, 1.0)).unsqueeze(0) # (1, 11, 192, 192, 192)
```
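To make the target format concrete, the NumPy sketch below shows the three ingredients named above: the voxel-grid size after resampling to 1.5 mm, a center crop/pad to 192³, and a stack of CT windows scaled to [0, 1]. The 11 `(level, width)` windows, the centered crop, and the -1024 HU padding value are hypothetical placeholders; Jolia's actual values are whatever `JoliaPreprocessor` uses, and that class should be used for real inference. The interpolation step of resampling is omitted.

```python
import numpy as np

TARGET_SPACING_MM = 1.5
CROP = 192
AIR_HU = -1024.0  # padding value (assumption)

# HYPOTHETICAL (level, width) windows in HU; not Jolia's actual 11 windows
WINDOWS = [
    (40, 400), (-600, 1500), (400, 1800), (50, 350), (40, 80), (60, 160),
    (-500, 2000), (0, 2000), (100, 700), (300, 1500), (-200, 600),
]

def resampled_shape(shape, spacing_mm, target=TARGET_SPACING_MM):
    """Voxel grid size after resampling to isotropic `target` mm spacing."""
    return tuple(int(round(n * s / target)) for n, s in zip(shape, spacing_mm))

def center_crop_or_pad(vol, size=CROP, fill=AIR_HU):
    """Center-crop or pad every axis of an HU volume to `size` voxels."""
    out = np.full((size,) * vol.ndim, fill, dtype=np.float32)
    src, dst = [], []
    for n in vol.shape:
        if n >= size:  # crop: keep the central `size` voxels
            start = (n - size) // 2
            src.append(slice(start, start + size))
            dst.append(slice(0, size))
        else:  # pad: place the volume in the middle of the output
            start = (size - n) // 2
            src.append(slice(0, n))
            dst.append(slice(start, start + n))
    out[tuple(dst)] = vol[tuple(src)]
    return out

def window_stack(hu, windows=WINDOWS):
    """One channel per window: clip HU to [level - width/2, level + width/2], scale to [0, 1]."""
    channels = []
    for level, width in windows:
        lo, hi = level - width / 2, level + width / 2
        channels.append((np.clip(hu, lo, hi) - lo) / (hi - lo))
    return np.stack(channels).astype(np.float32)

# A 512 x 512 x 300 scan at (0.7, 0.7, 1.0) mm maps to this 1.5 mm grid before cropping:
print(resampled_shape((512, 512, 300), (0.7, 0.7, 1.0)))  # (239, 239, 200)

# Toy demo on a small random volume (real inputs use size=192)
rng = np.random.default_rng(0)
toy_hu = rng.uniform(-1000, 1500, size=(40, 40, 20))
toy = window_stack(center_crop_or_pad(toy_hu, size=32))  # (11, 32, 32, 32)
```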
## Working with organ queries (the easy way)
Per-organ embeddings are addressed **by name**:
```python
# All 102 organs as {name: (B, 576)}
organs = model.encode_organs(image)
# A subset, L2-normalized (cosine-ready)
sub = model.encode_organs(image, organs=["liver", "spleen", "pancreas"], normalize=True)
print(model.organ_slot_names) # the 102 available organ names
```
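Because `normalize=True` returns unit-length vectors, cosine similarity reduces to a dot product, which turns organ-level retrieval into a matrix multiply. The NumPy sketch below uses random unit vectors as hypothetical stand-ins for a gallery of liver embeddings (the `(B, 576)` tensors returned by `encode_organs`), so it runs without the model.

```python
import numpy as np

def l2_normalize(x, axis=-1, eps=1e-12):
    return x / np.maximum(np.linalg.norm(x, axis=axis, keepdims=True), eps)

def top_k_cosine(query, gallery, k=3):
    """Indices and scores of the k gallery rows most similar to `query`.

    Both inputs are assumed L2-normalized, so the dot product is the cosine.
    """
    sims = gallery @ query           # (N,)
    idx = np.argsort(-sims)[:k]      # highest similarity first
    return idx, sims[idx]

rng = np.random.default_rng(0)
# Hypothetical stand-ins for per-scan liver embeddings from encode_organs(..., normalize=True)
gallery = l2_normalize(rng.standard_normal((100, 576)))
query = l2_normalize(gallery[42] + 0.01 * rng.standard_normal(576))  # near-duplicate of scan 42
idx, scores = top_k_cosine(query, gallery)
```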
For linear probing, a single call returns the concatenated, normalized feature vector:
```python
flat = model.extract_flat_feature(image) # (B, 576 * (1 + num_organs))
```
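With all 102 organs the flat feature is 576 × 103 = 59,328 wide. The NumPy sketch below assumes, without confirmation from the model code, that the global block comes first and the organ blocks follow in slot order. It shows how to split the vector back apart and fit a simple ridge-regression probe, a closed-form stand-in for the logistic regression often used in linear probing. The dual form keeps the solve at n × n, which matters when features far outnumber scans. Synthetic data and a 3-organ toy size keep it fast.

```python
import numpy as np

EMBED_DIM = 576

def unflatten(flat, embed_dim=EMBED_DIM):
    """Split (B, 576 * (1 + n)) into global (B, 576) and organs (B, n, 576).

    ASSUMPTION: global block first, then organ blocks in slot order.
    """
    b, d = flat.shape
    if d % embed_dim:
        raise ValueError(f"feature width {d} is not a multiple of {embed_dim}")
    blocks = flat.reshape(b, d // embed_dim, embed_dim)
    return blocks[:, 0], blocks[:, 1:]

def fit_ridge_probe(x, y, lam=0.1):
    """Ridge regression with an unpenalized intercept, solved in dual (n x n) form."""
    x_mean, y_mean = x.mean(axis=0), y.mean()
    xc, yc = x - x_mean, y - y_mean
    alpha = np.linalg.solve(xc @ xc.T + lam * np.eye(len(xc)), yc)
    w = xc.T @ alpha
    return w, y_mean - x_mean @ w

def l2_normalize(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

# Synthetic stand-in for model.extract_flat_feature(...) with 3 organs
rng = np.random.default_rng(0)
n, num_organs = 200, 3
blocks = l2_normalize(rng.standard_normal((n, 1 + num_organs, EMBED_DIM)))
flat = blocks.reshape(n, -1)  # (200, 576 * 4)
glob, organs = unflatten(flat)

# Toy binary label that depends only on the first organ block
direction = rng.standard_normal(EMBED_DIM)
y = np.sign(organs[:, 0] @ direction)
w, b = fit_ridge_probe(flat, y)
train_acc = np.mean(np.sign(flat @ w + b) == y)
```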
## Zero-shot classification
Jolia ships with the CLIP text-projection head it was trained with. Pair it
with the text encoder Jolia was trained against (`Qwen/Qwen3-Embedding-8B`)
to classify a CT against arbitrary text prompts with no fine-tuning.
The text encoder is the heavy piece (~18 GB), so loading it is opt-in.
Jolia bundles a small helper, `JoliaTextEncoder`, that handles tokenization
and the (attention-mask-aware) last-token pooling the model was trained with.
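Last-token pooling uses the hidden state of each prompt's final real (non-padding) token as its embedding; mask awareness matters because padding can sit on either side of the sequence. The NumPy sketch below mirrors the pattern commonly used with decoder-style embedding models such as Qwen3-Embedding. It is an illustration, not `JoliaTextEncoder`'s actual code, and the tiny `hidden` array stands in for the real `(N, L, 4096)` hidden states.

```python
import numpy as np

def last_token_pool(hidden, attention_mask):
    """Return each sequence's final non-padding hidden state.

    hidden:         (N, L, H) last-layer hidden states
    attention_mask: (N, L), 1 for real tokens and 0 for padding
    """
    if attention_mask[:, -1].all():
        # Left padding: the last position is a real token in every row
        return hidden[:, -1]
    # Right padding: real tokens form a prefix, so the last one sits at length - 1
    last_idx = attention_mask.sum(axis=1) - 1
    return hidden[np.arange(hidden.shape[0]), last_idx]

hidden = np.arange(2 * 4 * 3, dtype=np.float32).reshape(2, 4, 3)
right = np.array([[1, 1, 1, 0], [1, 1, 0, 0]])
left = np.array([[0, 1, 1, 1], [0, 0, 1, 1]])
print(last_token_pool(hidden, right))  # rows hidden[0, 2] and hidden[1, 1]
print(last_token_pool(hidden, left))   # rows hidden[0, 3] and hidden[1, 3]
```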
```python
import sys, torch
from huggingface_hub import snapshot_download
from transformers import AutoModel
# 1) Vision: Jolia from the Hub (self-contained, ~89 MB).
jolia = AutoModel.from_pretrained("raidium/Jolia", trust_remote_code=True).eval()
# 2) Text: Qwen3-Embedding-8B + Jolia's bundled JoliaTextEncoder helper.
repo = snapshot_download("raidium/Jolia"); sys.path.append(repo)
from text_encoder_jolia import JoliaTextEncoder
text_encoder = JoliaTextEncoder.from_pretrained(
    "Qwen/Qwen3-Embedding-8B",
    dtype=torch.bfloat16,  # bf16 weights; fp32 would roughly double the memory
    device_map="auto",     # or .to("cuda")
).eval()
# 3) Zero-shot classification on a preprocessed CT volume.
prompts = ["a CT showing a liver lesion", "a CT showing pneumonia", "a normal abdominal CT"]
with torch.no_grad():
    text_features = text_encoder(prompts)            # (N, 4096) last-token-pooled
    logits = jolia.zero_shot(image, text_features)   # (B, N) calibrated CLIP logits
    probs = torch.sigmoid(logits)                    # per-pair "is this a match?" probability
# Same output as `MultimodalCLSZeroShotCLIP.get_logits_per_image` in rarm.
# Pass `calibrated=False` if you want raw cosine in [-1, 1] (ranking-only):
cosine = jolia.zero_shot(image, text_features, calibrated=False)
```
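As a mental model for `calibrated=True` versus `calibrated=False`, the NumPy sketch below implements a generic CLIP-style head with a learned temperature and bias, read through a sigmoid as in SigLIP. The random projection, the log-space temperature, and the values `log(10)` and `-5` are assumptions for illustration, not the bundled weights or Jolia's exact formula. Because calibration is a positive affine map of the cosine, it never changes the ranking of prompts.

```python
import numpy as np

D_TEXT, D_EMB = 4096, 576

def l2_normalize(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def clip_logits(image_emb, text_feat, w_text, logit_scale, bias, calibrated=True):
    """Generic CLIP-style scoring head (illustrative, not Jolia's code).

    image_emb:   (B, 576) global image embedding
    text_feat:   (N, 4096) pooled text features
    w_text:      (4096, 576) text projection (hypothetical random stand-in)
    logit_scale: scalar temperature, ASSUMED stored in log space as in CLIP
    bias:        scalar additive offset, ASSUMED SigLIP-style
    """
    cosine = l2_normalize(image_emb) @ l2_normalize(text_feat @ w_text).T  # (B, N)
    if not calibrated:
        return cosine
    return np.exp(logit_scale) * cosine + bias

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
image_emb = rng.standard_normal((2, D_EMB))
text_feat = rng.standard_normal((3, D_TEXT))
w_text = rng.standard_normal((D_TEXT, D_EMB)) / np.sqrt(D_TEXT)

cosine = clip_logits(image_emb, text_feat, w_text, np.log(10.0), -5.0, calibrated=False)
logits = clip_logits(image_emb, text_feat, w_text, np.log(10.0), -5.0)
probs = sigmoid(logits)  # independent per-pair probabilities; rows need not sum to 1
```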
### Per-organ (query-routed) zero-shot
Jolia also ships the **ParallelOrganCLIP** text head, trained against
the per-organ findings of each report. It routes a text prompt to one
specific organ's query embedding — useful when you want to ask
*"is there a lesion in the **liver**?"* rather than scoring against the
whole-volume CLS.
```python
text_features = text_encoder(["a lesion", "looks normal"]) # (N, 4096)
# Score N prompts against a single organ — calibrated CLIP logits (B, N)
liver_logits = jolia.zero_shot_organ(image, text_features, organ="liver")
liver_probs = torch.sigmoid(liver_logits)
# Score N prompts against many organs at once -> {organ_name: (B, N)}
scores = jolia.zero_shot_organs(
    image, text_features, organs=["liver", "spleen", "kidneys", "pancreas"]
)
# Raw cosine if you only need ranking and don't want the bias offset:
cosine = jolia.zero_shot_organ(image, text_features, organ="liver", calibrated=False)
```
Each organ has its **own** trained temperature and bias (the
`(200,)`-shaped `organ_logit_scale` / `organ_text_bias`), automatically applied
when `calibrated=True`. `jolia.organ_slot_names` lists the 102 organs that can
be routed. The per-organ head uses a *different* text projection than the
global one (`encode_text` vs `encode_organ_text`), trained on per-organ
findings text.
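Routing several organs at once amounts to looking up each organ's slot index in the `(200,)` parameter vectors and broadcasting over prompts. The NumPy sketch below illustrates that indexing with a hypothetical four-organ name list, random embeddings, and random per-slot parameters (again assuming a log-space temperature); the shared text embeddings play the role of the per-organ text projection's output. It is not the code behind `zero_shot_organs`.

```python
import numpy as np

NUM_SLOTS, D_EMB = 200, 576

def l2_normalize(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def score_organs(organ_emb, text_emb, slot_idx, names, logit_scale, text_bias):
    """Calibrated per-organ logits as a {name: (B, N)} dict.

    organ_emb: (B, K, 576) query embeddings of the K routed organs
    text_emb:  (N, 576) prompts after the per-organ text projection
    slot_idx:  (K,) slot index of each routed organ
    logit_scale, text_bias: (NUM_SLOTS,) per-slot parameters
    """
    cosine = np.einsum("bkd,nd->bkn", l2_normalize(organ_emb), l2_normalize(text_emb))
    scale = np.exp(logit_scale[slot_idx])[None, :, None]  # (1, K, 1)
    bias = text_bias[slot_idx][None, :, None]              # (1, K, 1)
    logits = scale * cosine + bias                          # (B, K, N)
    return {name: logits[:, k] for k, name in enumerate(names)}

# HYPOTHETICAL slot names; the real list is model.organ_slot_names
slot_names = ["kidneys", "liver", "pancreas", "spleen"]
slot_of = {name: i for i, name in enumerate(slot_names)}

rng = np.random.default_rng(0)
routed = ["liver", "spleen"]
idx = np.array([slot_of[name] for name in routed])
scores = score_organs(
    rng.standard_normal((2, len(routed), D_EMB)),
    rng.standard_normal((3, D_EMB)),
    idx,
    routed,
    rng.normal(np.log(10.0), 0.5, NUM_SLOTS),
    rng.normal(-5.0, 1.0, NUM_SLOTS),
)
```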
A runnable, self-contained script is bundled as `example_zero_shot.py`.
## Model details
| | |
|---|---|
| Backbone | `MultiModalAtlas` — multi-scale 3D ViT, `dim=192`, heads `6`, stages `[2, 2, 8]` |
| Patch embed | `6×6×6`, 11 input channels (CT windowing), `merge_ratio = 4³` |
| Global embedding | 576-d |
| Organ queries | 102 slots × 192-d × 3 scales → 576-d |
| Parameters | ~22 M (89 MB `safetensors`) |
| Input | `(B, 11, 192, 192, 192)` float32 |
| Training data | INSPECT, CT-RATE, Stanford-Abdominal-CT (chest + abdomen CT) |
| Objectives | Volume–report CLIP + per-organ ParallelOrganCLIP |
| Paired text encoder | `Qwen/Qwen3-Embedding-8B` (last-token pooling, context length 512) |
| Global text projection | Linear `4096 → 576` (+ scalar temperature + bias) — global CLIP head |
| Per-organ text projection | Linear `4096 → 576` (+ per-organ temperature + bias, both `(200,)`) — ParallelOrganCLIP head |
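Several entries in the table are related by simple arithmetic, worked through below in Python. The per-scale token counts rest on our assumption that `merge_ratio = 4³` merges 4×4×4 token blocks between scales; the other figures follow directly from the table (192 / 6 = 32 patches per axis, 3 × 192 = 576, and ~22 M float32 parameters at 4 bytes each ≈ 88 MB, in line with the 89 MB checkpoint).

```python
INPUT_SIZE = 192      # voxels per axis
PATCH = 6             # 6x6x6 patch embedding
DIM = 192             # channel width per scale
NUM_SCALES = 3
MERGE_PER_AXIS = 4    # ASSUMPTION: merge_ratio = 4**3 merges 4x4x4 token blocks

grid = INPUT_SIZE // PATCH  # 32 patches per axis
tokens_per_scale = []
for _ in range(NUM_SCALES):
    tokens_per_scale.append(grid ** 3)
    grid //= MERGE_PER_AXIS

embed_dim = DIM * NUM_SCALES    # concatenating three 192-d scales
checkpoint_mb = 22e6 * 4 / 1e6  # float32 parameters, 4 bytes each

print(tokens_per_scale, embed_dim, checkpoint_mb)  # [32768, 512, 8] 576 88.0
```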
The 102 organ-slot names are the alphabetically sorted union of per-organ
report sections across the training datasets; the remaining slots (`102–199`, 0-indexed) are unused
padding. Methods like `encode_organs` expose only the named slots.
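The slot layout described above can be reproduced in a few lines: take the union of organ-section names across datasets, sort it, and pad to 200 slots. The section names below are hypothetical (the real 102 names are in `model.organ_slot_names`), and `None` marks a padding slot; Python's default string sort matches alphabetical order for lowercase ASCII names like these.

```python
NUM_SLOTS = 200

def build_slot_names(sections_per_dataset, num_slots=NUM_SLOTS):
    """Sorted union of per-dataset organ sections, padded with None to num_slots."""
    names = sorted(set().union(*sections_per_dataset))
    if len(names) > num_slots:
        raise ValueError(f"{len(names)} organs do not fit in {num_slots} slots")
    return names, names + [None] * (num_slots - len(names))

# HYPOTHETICAL per-dataset organ sections
datasets = [{"heart", "liver", "lungs"}, {"kidneys", "liver", "spleen"}]
named, slots = build_slot_names(datasets)
slot_of = {name: i for i, name in enumerate(named)}
print(named)  # ['heart', 'kidneys', 'liver', 'lungs', 'spleen']
```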
## Outputs
`model(image)` returns a `JoliaOutput` with:
- `pooler_output`: `(B, 576)` global embedding;
- `organ_queries`: `(B, num_organs, 576)`, populated when called with
  `output_organ_queries=True`.
## Intended use & limitations
> ⚠️ Research preview. Not a medical device; not for clinical use.
Jolia is a **feature extractor** for downstream radiology tasks (classification,
retrieval, per-organ analysis) via linear probing or fine-tuning. It is trained
on adult chest/abdominal CT and will not generalize to other modalities or
unusual acquisition protocols. **It does not produce diagnoses** and must not be
used for clinical decision-making.
## Citation
```bibtex
@misc{raidium_jolia,
title = {Jolia: a 3D CT Atlas foundation model with per-organ queries},
author = {Raidium},
year = {2026},
howpublished = {\url{https://huggingface.co/raidium/Jolia}}
}
```