Chronos-MSK Radiologist (TW3 Staging)

An SVM classifier trained on frozen MedSigLIP-448 embeddings for Tanner-Whitehouse 3 (TW3) skeletal maturity staging of the distal radius from hand/wrist X-rays.

Part of the Chronos-MSK multi-agent bone age assessment system, built for the Google HAI-DEF competition.

Model Description

This model classifies cropped distal radius X-ray images into TW3 maturity stages. It operates in two steps: MedSigLIP-448 extracts a 1152-dimensional embedding from the image, then a lightweight SVM classifier maps that embedding to a maturity stage.

The approach shows that MedSigLIP's medical pre-training captures bone morphology well enough for a simple linear classifier to perform meaningful skeletal staging, with no fine-tuning of the vision encoder.

TW3 Maturity Stages

The Tanner-Whitehouse 3 system grades the distal radius through a series of developmental stages based on ossification center appearance, epiphyseal shape, and fusion status:

| Stage | Biological description | Typical age (male) | Typical age (female) |
|---|---|---|---|
| B | Initial ossification center appears | 0-2 years | 0-1.5 years |
| C | Center enlarges, distinct shape forming | 2-6 years | 1-5 years |
| D | Epiphysis widens, clear shape | 4-8 years | 3-7 years |
| E | Epiphysis approaches metaphysis width | 6-10 years | 5-9 years |
| F | Epiphysis equals metaphysis width | 8-12 years | 7-11 years |
| G | Early fusion begins | 10-14 years | 9-13 years |
| H | Partial fusion, gap narrowing | 13-17 years | 11-15 years |
| I | Complete fusion (adult morphology) | 16+ years | 14+ years |

Architecture

Input: Cropped distal radius X-ray (any size)
  │
  ▼
MedSigLIP-448 Vision Encoder (completely frozen)
  │
  ▼
Pooled embedding → 1152-D vector
  │
  ▼
StandardScaler (zero mean, unit variance normalization)
  │
  ▼
Linear SVM (probability=True, class_weight="balanced")
  │
  ▼
TW3 Stage prediction + calibrated class probabilities

Test-Time Augmentation

For improved robustness, the model averages SVM probability outputs from the original image and its horizontal flip:

P_final = (P_original + P_flipped) / 2
Stage = argmax(P_final)

The original (un-flipped) embedding is preserved for downstream retrieval tasks to avoid "mirror world" artifacts in the embedding space.
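The averaging rule can be tried out without MedSigLIP. The sketch below is a minimal illustration under two assumptions. First, it trains a StandardScaler + linear-kernel SVC pipeline (configured like the one described here) on small random vectors that stand in for 1152-D embeddings. Second, it slightly perturbs the embedding to stand in for the flipped view. The helper name `tta_predict` is hypothetical. It sends both embeddings through a single `predict_proba` call, which gives the same per-row probabilities as two separate calls.

```python
import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC


def tta_predict(clf, emb_orig, emb_flip):
    """Average class probabilities of the original and flipped embeddings.

    Returns (stage, confidence, avg_probs). emb_orig itself is not modified,
    so it can still be handed to a retrieval step.
    """
    # predict_proba expects shape (n_samples, n_features): one row per view
    probs = clf.predict_proba(np.stack([emb_orig, emb_flip]))
    avg_probs = probs.mean(axis=0)  # P_final = (P_original + P_flipped) / 2
    best_idx = int(np.argmax(avg_probs))
    return clf.classes_[best_idx], float(avg_probs[best_idx]), avg_probs


# Toy data: 16-D random "embeddings" for three stages (stand-ins, not MedSigLIP)
rng = np.random.default_rng(0)
X = np.concatenate([rng.normal(loc=c, size=(30, 16)) for c in range(3)])
y = np.repeat(["B", "C", "D"], 30)
clf = make_pipeline(
    StandardScaler(),
    SVC(kernel="linear", probability=True, class_weight="balanced", random_state=0),
)
clf.fit(X, y)

emb = rng.normal(loc=2, size=16)
emb_flip = emb + rng.normal(scale=0.1, size=16)  # stand-in for the flipped view
stage, conf, avg = tta_predict(clf, emb, emb_flip)
print(stage, round(conf, 3))
```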

Usage

Installation

pip install torch transformers scikit-learn joblib opencv-python Pillow huggingface_hub
pip install ultralytics tqdm  # optional: only needed for the Scout and batch examples

Quick Start

import torch
import joblib
import numpy as np
import cv2
from PIL import Image
from transformers import SiglipVisionModel, AutoProcessor
from huggingface_hub import hf_hub_download


def load_radiologist(repo_id="rohitrajesh/chronos-msk-radiologist", device="cuda"):
    """Load the TW3 staging model from HuggingFace."""
    
    # Download SVM weights
    svm_path = hf_hub_download(repo_id, "radiologist_head.pkl")
    svm = joblib.load(svm_path)
    
    # Load MedSigLIP backbone (frozen, no fine-tuning)
    processor = AutoProcessor.from_pretrained("google/medsiglip-448")
    vision_model = SiglipVisionModel.from_pretrained(
        "google/medsiglip-448"
    ).to(device).eval()
    
    return vision_model, processor, svm


def get_embedding(image, vision_model, processor, device="cuda"):
    """Extract 1152-D MedSigLIP embedding from an image.
    
    Args:
        image: File path string, PIL Image, or numpy array. 3-channel arrays
            are treated as BGR (OpenCV order); 2-D arrays as grayscale.
    """
    # Normalize the input to an RGB PIL image
    if isinstance(image, str):
        img_bgr = cv2.imread(image)
        if img_bgr is None:
            raise ValueError(f"Cannot read image: {image}")
        pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    elif isinstance(image, np.ndarray):
        if image.ndim == 3 and image.shape[2] == 3:
            # 3-channel arrays are assumed to be OpenCV BGR
            pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        else:
            # Grayscale (H, W) or (H, W, 1): replicate to 3 channels
            pil_img = Image.fromarray(np.squeeze(image)).convert("RGB")
    elif isinstance(image, Image.Image):
        pil_img = image.convert("RGB")
    else:
        raise TypeError(f"Unsupported image type: {type(image)}")
    
    inputs = processor(images=pil_img, return_tensors="pt").to(device)
    with torch.no_grad():
        embedding = vision_model(**inputs).pooler_output.cpu().numpy()[0]
    
    return embedding


def predict_stage(image, vision_model, processor, svm, device="cuda"):
    """
    Predict TW3 maturity stage with Test-Time Augmentation.
    
    Args:
        image: File path, PIL Image, or numpy array
        
    Returns:
        stage: Predicted TW3 stage (str, e.g., "H")
        confidence: Probability of predicted stage (float)
        all_probs: Dict mapping each stage to its probability
        embedding: 1152-D embedding (for downstream use)
    """
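    # Normalize the input to an RGB PIL image (needed for the flip below)
    if isinstance(image, str):
        img_bgr = cv2.imread(image)
        if img_bgr is None:
            raise ValueError(f"Cannot read image: {image}")
        pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    elif isinstance(image, np.ndarray):
        if image.ndim == 3 and image.shape[2] == 3:
            pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))  # assume BGR
        else:
            pil_img = Image.fromarray(np.squeeze(image)).convert("RGB")  # grayscale
    elif isinstance(image, Image.Image):
        pil_img = image.convert("RGB")
    else:
        raise TypeError(f"Unsupported image type: {type(image)}")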
    
    pil_flipped = pil_img.transpose(Image.FLIP_LEFT_RIGHT)
    
    emb_orig = get_embedding(pil_img, vision_model, processor, device)
    emb_flip = get_embedding(pil_flipped, vision_model, processor, device)
    
    # Average SVM probabilities (TTA)
    probs_orig = svm.predict_proba([emb_orig])[0]
    probs_flip = svm.predict_proba([emb_flip])[0]
    avg_probs = (probs_orig + probs_flip) / 2.0
    
    # Get prediction
    best_idx = int(np.argmax(avg_probs))
    stage = svm.classes_[best_idx]
    confidence = float(avg_probs[best_idx])
    
    # Build probability dict
    all_probs = {
        svm.classes_[i]: float(avg_probs[i])
        for i in range(len(svm.classes_))
    }
    
    return stage, confidence, all_probs, emb_orig


# ── Example Usage ──

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    
    # Load model (downloads from HuggingFace on first run)
    print("Loading radiologist model...")
    vision_model, processor, svm = load_radiologist(device=device)
    print(f"Loaded! Classes: {list(svm.classes_)}")
    
    # Predict on a radius crop
    stage, confidence, probs, embedding = predict_stage(
        "radius_crop.png",  # Cropped distal radius image
        vision_model, processor, svm, device
    )
    
    print(f"\nTW3 Stage: {stage}")
    print(f"Confidence: {confidence:.1%}")
    print(f"\nAll stage probabilities:")
    for s, p in sorted(probs.items()):
        bar = "█" * int(p * 40)
        print(f"  Stage {s}: {p:6.1%} {bar}")
    
    print(f"\nEmbedding shape: {embedding.shape}")  # (1152,)

Using with the Scout Detector

In the full Chronos-MSK pipeline, the radiologist receives cropped images from the Scout agent:

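import cv2
from huggingface_hub import hf_hub_download
from ultralytics import YOLO

# Assumes the Quick Start functions are defined and that
# vision_model, processor, svm and device have already been loaded.

# Load Scout
scout_path = hf_hub_download("rohitrajesh/chronos-msk-scout", "best_scout.pt")
scout = YOLO(scout_path)

# Detect and crop the radius (class 0)
img = cv2.imread("full_hand_xray.png")
if img is None:
    raise ValueError("Cannot read full_hand_xray.png")
results = scout(img, verbose=False)[0]

crop = None
for box in results.boxes:
    if int(box.cls[0]) == 0:
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        # Add 15% padding on each side, clipped to the image bounds
        h, w = img.shape[:2]
        px, py = int((x2 - x1) * 0.15), int((y2 - y1) * 0.15)
        crop = img[max(0, y1 - py):min(h, y2 + py), max(0, x1 - px):min(w, x2 + px)]
        break

if crop is None:
    raise RuntimeError("Scout did not detect a distal radius in the image")

# Stage the crop (a BGR numpy array, as returned by cv2)
stage, confidence, probs, emb = predict_stage(
    crop, vision_model, processor, svm, device
)
print(f"Stage: {stage} ({confidence:.0%})")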
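The loop above applies 15% padding around the box and clips it to the image. That arithmetic can be moved into a small helper and tested without running the detector. `pad_box` is a hypothetical name. The box coordinates and the dummy image below are made up for illustration.

```python
import numpy as np


def pad_box(x1, y1, x2, y2, img_w, img_h, pad_frac=0.15):
    """Grow an (x1, y1, x2, y2) box by pad_frac of its size on each side.

    The padded box is clipped to the image bounds [0, img_w] x [0, img_h].
    """
    px = int((x2 - x1) * pad_frac)
    py = int((y2 - y1) * pad_frac)
    return (
        max(0, x1 - px),
        max(0, y1 - py),
        min(img_w, x2 + px),
        min(img_h, y2 + py),
    )


# Dummy 1000x1000 BGR image and a hypothetical detection box
img = np.zeros((1000, 1000, 3), dtype=np.uint8)
h, w = img.shape[:2]
xa, ya, xb, yb = pad_box(100, 100, 200, 300, w, h)
crop = img[ya:yb, xa:xb]
print((xa, ya, xb, yb), crop.shape)  # (85, 70, 215, 330) (260, 130, 3)
```

Isolating the padding this way lets you check the edge cases directly, for example a box touching the image border, where clipping keeps the crop inside the image.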

Batch Processing

import os
from collections import Counter
from tqdm import tqdm

image_dir = "radius_crops/"
results = []

# Sort for a reproducible processing order
for filename in tqdm(sorted(os.listdir(image_dir))):
    if not filename.lower().endswith((".png", ".jpg", ".jpeg")):
        continue

    path = os.path.join(image_dir, filename)
    stage, conf, probs, emb = predict_stage(
        path, vision_model, processor, svm, device
    )
    results.append({
        "filename": filename,
        "stage": stage,
        "confidence": conf,
        "embedding": emb,  # 1152-D, reusable for retrieval
    })

# Summary
stage_counts = Counter(r["stage"] for r in results)
print("\nStage distribution:")
for stage in sorted(stage_counts):
    print(f"  {stage}: {stage_counts[stage]}")
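To reuse the batch embeddings later (for example for retrieval), the per-image dicts can be stacked into one matrix and written to disk. The sketch below is one option, not part of the released code. It assumes NumPy's compressed `.npz` format is acceptable. `save_batch_results` and the file name are hypothetical, and random vectors stand in for real embeddings.

```python
import os
import tempfile

import numpy as np


def save_batch_results(results, out_path):
    """Save batch outputs to a compressed .npz file.

    Expects dicts with "filename", "stage", "confidence" and "embedding" keys,
    as built by the batch loop. Embeddings are stacked into an (N, D) matrix.
    """
    np.savez_compressed(
        out_path,
        filenames=np.array([r["filename"] for r in results]),
        stages=np.array([r["stage"] for r in results]),
        confidences=np.array([r["confidence"] for r in results], dtype=np.float32),
        embeddings=np.stack([r["embedding"] for r in results]).astype(np.float32),
    )


# Fake results: random 1152-D vectors stand in for MedSigLIP embeddings
rng = np.random.default_rng(0)
fake_results = [
    {"filename": f"img_{i}.png", "stage": s, "confidence": 0.9,
     "embedding": rng.normal(size=1152)}
    for i, s in enumerate(["E", "G", "H"])
]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "radius_batch.npz")
    save_batch_results(fake_results, path)
    with np.load(path) as data:
        loaded_stages = data["stages"].tolist()
        loaded_shape = data["embeddings"].shape

print(loaded_stages, loaded_shape)  # ['E', 'G', 'H'] (3, 1152)
```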

Extract Embeddings Only (No Classification)

The MedSigLIP embedding is useful beyond classification, for example for retrieval, clustering, or other downstream tasks:

# Just get the embedding, no SVM prediction
embedding = get_embedding("radius_crop.png", vision_model, processor, device)
print(f"Shape: {embedding.shape}")  # (1152,)
print(f"L2 norm: {np.linalg.norm(embedding):.4f}")
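As an example of retrieval, the sketch below finds the stored embeddings closest to a query by cosine similarity. It is a generic nearest-neighbour search. It is not necessarily how the Archivist agent retrieves cases. `top_k_cosine` is a hypothetical helper, and the "database" is random data with one planted near-duplicate.

```python
import numpy as np


def top_k_cosine(query, database, k=5):
    """Return (indices, similarities) of the k rows of database most similar to query."""
    q = query / np.linalg.norm(query)
    db = database / np.linalg.norm(database, axis=1, keepdims=True)
    sims = db @ q                   # cosine similarity per database row
    idx = np.argsort(-sims)[:k]     # highest similarity first
    return idx, sims[idx]


rng = np.random.default_rng(0)
database = rng.normal(size=(100, 1152))              # stand-in for stored embeddings
query = database[42] + 0.01 * rng.normal(size=1152)  # near-duplicate of row 42
idx, sims = top_k_cosine(query, database, k=3)
print(idx[0], round(float(sims[0]), 3))
```

Normalizing both sides first makes the dot product equal to the cosine similarity. Embedding magnitude therefore does not affect the ranking.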

Training Details

Label Generation

Public datasets such as RSNA Pediatric Bone Age provide bone age labels but no TW3 stage labels, so we used a teacher-student distillation approach:

  1. Teacher: Google Gemini 3 Pro (Vision) was prompted with strict TW3 criteria to label 14,000 RSNA images with maturity stages
  2. Validation: 88% agreement with human expert annotations on a held-out sample
  3. Student: The lightweight SVM was trained on these synthetic labels. This distills the labeling behavior of a large proprietary model into a small, deployable open-weight classifier.
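The card reports 88% agreement but does not say how it was computed. The sketch below assumes it is raw agreement, the fraction of images on which the teacher and the expert assign the same stage. It also computes Cohen's kappa, which corrects for chance agreement. The labels are a made-up ten-image example, and `label_agreement` is a hypothetical helper.

```python
from sklearn.metrics import accuracy_score, cohen_kappa_score


def label_agreement(teacher_labels, expert_labels):
    """Raw agreement (fraction of identical stages) and chance-corrected Cohen's kappa."""
    agreement = accuracy_score(expert_labels, teacher_labels)
    kappa = cohen_kappa_score(expert_labels, teacher_labels)
    return agreement, kappa


# Toy example: the teacher disagrees with the expert on 2 of 10 images
expert = list("BCDEFGHIGH")
teacher = list("BCDEFGHHGG")
agreement, kappa = label_agreement(teacher, expert)
print(f"agreement={agreement:.0%}, kappa={kappa:.2f}")
```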

Training Configuration

| Parameter | Value |
|---|---|
| Feature extractor | MedSigLIP-448 (frozen, 1152-D) |
| Classifier | scikit-learn SVC (kernel="linear") with Platt scaling |
| Kernel | Linear |
| Class weighting | Balanced (handles stage imbalance) |
| Probability calibration | Enabled (probability=True) |
| Preprocessing | StandardScaler (zero mean, unit variance) |
| Train/test split | 90/10, stratified by stage |
| Training samples | ~12,600 |
| Test samples | ~1,400 |
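The configuration above maps onto a few lines of scikit-learn. The sketch assumes the classifier is `SVC(kernel="linear", probability=True)`. `LinearSVC` has no `predict_proba`, which the Quick Start code calls. The function name is hypothetical, and the data is synthetic: 4 stages × 50 samples of 64-D vectors rather than real MedSigLIP embeddings.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC


def train_stage_head(embeddings, stages, seed=42):
    """Fit StandardScaler + linear SVC on a stratified 90/10 split; return (model, test accuracy)."""
    X_train, X_test, y_train, y_test = train_test_split(
        embeddings, stages, test_size=0.10, stratify=stages, random_state=seed
    )
    model = make_pipeline(
        StandardScaler(),
        SVC(kernel="linear", probability=True, class_weight="balanced",
            random_state=seed),
    )
    model.fit(X_train, y_train)
    return model, model.score(X_test, y_test)


# Synthetic stand-in data: class-dependent mean shift per stage
rng = np.random.default_rng(0)
stage_names = np.array(["E", "F", "G", "H"])
y = np.repeat(stage_names, 50)
X = rng.normal(size=(200, 64)) + np.repeat(np.arange(4), 50)[:, None]
model, test_acc = train_stage_head(X, y)
print(list(model.named_steps), f"test accuracy={test_acc:.2f}")
# joblib.dump(model, "radiologist_head.pkl") would serialize the Pipeline with joblib,
# as the released artifact is described.
```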

Why a Linear SVM?

MedSigLIP's pre-training on medical image-text pairs produces embeddings where bone morphology features are already well-separated. A linear classifier is sufficient because:

  • The 1152-D embedding space is high-dimensional, which makes it more likely that the stage classes are linearly separable
  • SVMs with balanced class weights handle the natural imbalance across stages (more samples in middle stages)
  • Platt scaling provides calibrated probabilities for downstream confidence estimation
  • Training takes seconds (vs hours for fine-tuning), enabling rapid iteration
  • The model file is about 5 MB (vs hundreds of MB for fine-tuned models)
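"Balanced" class weighting in scikit-learn gives each class the weight n_samples / (n_classes × count of that class). Rare stages therefore count more per sample. The worked example below checks this formula against `compute_class_weight`. The stage counts are hypothetical.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Hypothetical, imbalanced stage counts: few B, many E, some I
y = np.array(["B"] * 10 + ["E"] * 60 + ["I"] * 30)
classes = np.unique(y)

weights = compute_class_weight(class_weight="balanced", classes=classes, y=y)

# "balanced" means n_samples / (n_classes * count of that class)
counts = np.array([np.sum(y == c) for c in classes])
manual = len(y) / (len(classes) * counts)

print(dict(zip(classes.tolist(), np.round(weights, 3).tolist())))
# {'B': 3.333, 'E': 0.556, 'I': 1.111}
```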

Role in the Pipeline

In the full Chronos-MSK system, the Radiologist agent serves two purposes:

  1. TW3 Stage Classification: Provides a categorical maturity assessment that contextualizes the numeric age prediction (e.g., "Stage H means partial fusion with a narrowing gap, typically seen at 13-17 years in males and 11-15 years in females")

  2. Embedding Extraction: The same 1152-D MedSigLIP embedding used for classification is also passed to the Archivist agent for retrieval. This dual use of a single forward pass is efficient: one embedding serves two agents.

The stage classification is used for clinical reporting and sanity checking (e.g., if the regressor predicts 18 years but the radiologist sees Stage D, something is wrong). It does not directly influence the numeric bone age prediction.
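A sanity check of this kind can be sketched with the typical age ranges from the TW3 stage table. The ±2-year tolerance and the function name `stage_age_conflict` are assumptions for illustration. The card does not state how the pipeline actually decides that "something is wrong".

```python
# Typical age ranges in years from the TW3 stage table: (male, female).
# None as the upper bound means "and older".
TYPICAL_AGE = {
    "B": ((0, 2), (0, 1.5)),
    "C": ((2, 6), (1, 5)),
    "D": ((4, 8), (3, 7)),
    "E": ((6, 10), (5, 9)),
    "F": ((8, 12), (7, 11)),
    "G": ((10, 14), (9, 13)),
    "H": ((13, 17), (11, 15)),
    "I": ((16, None), (14, None)),
}


def stage_age_conflict(stage, predicted_age, sex, tolerance=2.0):
    """True if predicted_age lies outside the stage's typical range by more than tolerance years."""
    if sex not in ("M", "F"):
        raise ValueError("sex must be 'M' or 'F'")
    low, high = TYPICAL_AGE[stage][0 if sex == "M" else 1]
    if predicted_age < low - tolerance:
        return True
    return high is not None and predicted_age > high + tolerance


print(stage_age_conflict("D", 18.0, "M"))  # True: Stage D is typically 4-8 years in males
print(stage_age_conflict("H", 15.0, "M"))  # False
```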

Saved Artifacts

| File | Description | Size |
|---|---|---|
| radiologist_head.pkl | scikit-learn Pipeline (StandardScaler + SVM) serialized with joblib | ~5 MB |

The base MedSigLIP-448 model (~1.2 GB) is downloaded automatically from google/medsiglip-448 on first use.

Limitations

  • TW3 labels were generated synthetically by Gemini 3 Pro, not by human radiologists. While 88% agreement was achieved, systematic biases from the teacher model may propagate.
  • The model classifies the distal radius only. Full TW3 assessment involves multiple bones (radius, ulna, metacarpals, phalanges), so this is a single-bone simplification.
  • Stage boundaries are inherently fuzzy in biology. Cases near stage transitions (e.g., G/H boundary) will have lower confidence and higher classification uncertainty.
  • Performance depends on crop quality. Poor crops from the Scout detector (missed detection, wrong region) will degrade staging accuracy.
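One way to surface the uncertain cases near stage boundaries is to flag predictions where the top two stage probabilities are close. The sketch takes the `all_probs` dict returned by `predict_stage`, which must contain at least two stages. The 0.15 margin threshold and the helper name `flag_uncertain` are hypothetical, so tune the threshold on held-out data.

```python
def flag_uncertain(all_probs, min_margin=0.15):
    """Flag predictions whose two highest stage probabilities are close.

    all_probs: dict mapping stage -> probability (at least two entries).
    Returns (top_stage, runner_up, margin, is_uncertain).
    """
    ranked = sorted(all_probs.items(), key=lambda kv: kv[1], reverse=True)
    (top, p_top), (runner_up, p_second) = ranked[0], ranked[1]
    margin = p_top - p_second
    return top, runner_up, margin, margin < min_margin


# A G/H boundary case: G wins, but only narrowly
print(flag_uncertain({"F": 0.13, "G": 0.46, "H": 0.41}))
```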

Ethical Considerations

TW3 staging is used in clinical workflows for growth disorder diagnosis and in forensic/legal contexts for age determination. While this model provides a useful automated first assessment, skeletal maturity staging should always be confirmed by qualified radiologists, especially in consequential decisions. The synthetic labeling approach (using Gemini as teacher) means the model inherits any biases present in Gemini's training data.

Citation

@misc{chronosmsk2025,
    title={Chronos-MSK: Bias-Aware Skeletal Maturity Assessment at the Edge},
    author={Rohit Rajesh},
    year={2025},
    note={Google HAI-DEF Competition Submission}
}
