Chronos-MSK Bone Age Regressor

A MedSigLIP-448 model fine-tuned with DoRA (weight-decomposed LoRA) adapters for pediatric bone age estimation from hand/wrist X-rays. It achieves an 8.81-month MAE on the RSNA Bone Age validation set (Pearson r = 0.963), below the typical human expert inter-reader variability of 10-13 months.

Part of the Chronos-MSK multi-agent bone age assessment system, built for the Google HAI-DEF competition.

Model Description

This model predicts bone age in months from a single left hand/wrist radiograph. Instead of regressing to a single number, it uses Deep Label Distribution Learning (DLDL): it predicts a full probability distribution over 228 monthly age bins. The final age estimate is the expected value of this distribution, and the spread of the distribution gives a natural indication of prediction uncertainty.
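
To make the expected-value step concrete, here is a minimal sketch in plain Python. The helper name `summarize_distribution` and the toy distribution are illustrative assumptions rather than part of the released code, and using the standard deviation as the uncertainty measure is likewise an assumption.

```python
import math

def summarize_distribution(probs):
    """Return (expected age, standard deviation) in months for a
    probability vector in which bin k corresponds to k months."""
    total = sum(probs)
    probs = [p / total for p in probs]  # renormalize to guard against rounding drift
    mean = sum(k * p for k, p in enumerate(probs))
    var = sum(p * (k - mean) ** 2 for k, p in enumerate(probs))
    return mean, math.sqrt(var)

# Toy distribution (not model output): all mass on three bins around 120 months.
toy = [0.0] * 228
toy[119], toy[120], toy[121] = 0.25, 0.5, 0.25

mean, std = summarize_distribution(toy)
print(f"{mean:.1f} months, spread {std:.2f} months")
```

A wider predicted distribution yields a larger spread for the same expected age, which is one way to flag low-confidence cases.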

Biological sex is incorporated as a conditioning signal through a learned embedding, reflecting the known differences in skeletal maturation rates between males and females.

Architecture

Input: Hand/Wrist X-Ray (448×448, letterbox padded)
  │
  ▼
MedSigLIP-448 Vision Encoder (frozen base + DoRA adapters)
  │
  ▼ last_hidden_state (sequence of patch embeddings)
  │
  ▼
Global Average Pooling + LayerNorm → 1152-D visual features
  │
  ├── Sex Input (0/1) → Linear(1,64) → GELU → Linear(64,128) → 128-D
  │
  ▼
Concatenate [visual_1152; sex_128] → 1280-D
  │
  ▼
LayerNorm → Linear(1280,512) → GELU → Dropout(0.1) → Linear(512,228)
  │
  ▼
Softmax → Probability distribution over 228 months
  │
  ▼
Expected Value: â = Σ(k × P(k)) for k = 0..227

Key Design Choices

  • DoRA (Weight-Decomposed LoRA): rank=16, alpha=32. Adapts the frozen MedSigLIP encoder with only ~2M trainable parameters while preserving medical domain knowledge
  • DLDL Loss: KL divergence between predicted distribution and a Gaussian target centered on the true age (sigma=12 months), which gives smoother gradients than MSE regression
  • Sex Conditioning: Learned embedding concatenated with visual features, enabling sex-specific age estimation without separate models
  • Test-Time Augmentation: Average predictions from original and horizontally-flipped image for improved accuracy
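
The DoRA settings above, combined with the dropout and target modules listed under Hyperparameters, map onto a `peft` configuration roughly as follows. This is a sketch, not the training script: the variable names are illustrative, and the `use_dora` flag assumes a `peft` release recent enough to support DoRA.

```python
# Values are taken from this model card; the variable names are illustrative.
dora_kwargs = {
    "r": 16,                # LoRA rank
    "lora_alpha": 32,       # LoRA alpha
    "lora_dropout": 0.05,
    "target_modules": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
    "use_dora": True,       # weight-decomposed LoRA
}

try:
    from peft import LoraConfig
    dora_config = LoraConfig(**dora_kwargs)
except (ImportError, TypeError):
    # peft is missing or too old for use_dora; the kwargs still document the setup
    dora_config = None

# Standard LoRA scaling applied to the low-rank update.
scaling = dora_kwargs["lora_alpha"] / dora_kwargs["r"]
print(f"LoRA scaling factor alpha / r = {scaling}")
```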

Results

Evaluated on 1,425 held-out RSNA Bone Age validation cases (strict no-leakage protocol):

Metric                   Value
Mean Absolute Error      8.81 months (0.73 years)
Median Absolute Error    7.27 months
Root Mean Square Error   11.39 months
Pearson r                0.963
R²                       0.927
Within ±6 months         42.9%
Within ±12 months        73.2%
Within ±24 months        95.7%
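
The error metrics in this table follow their usual definitions. Below is a minimal sketch in plain Python; the helper name `bone_age_metrics` is hypothetical, the example values are made up (not RSNA data), and treating the "within" thresholds as inclusive is an assumption.

```python
import math
from statistics import median

def bone_age_metrics(y_true, y_pred, tolerances=(6, 12, 24)):
    """Compute MAE, median AE, RMSE and within-tolerance rates (all in months)."""
    errors = [abs(t - p) for t, p in zip(y_true, y_pred)]
    n = len(errors)
    metrics = {
        "mae": sum(errors) / n,
        "median_ae": median(errors),
        "rmse": math.sqrt(sum(e * e for e in errors) / n),
    }
    for tol in tolerances:
        # Fraction of cases whose absolute error is at most `tol` months
        metrics[f"within_{tol}"] = sum(e <= tol for e in errors) / n
    return metrics

# Made-up ground truth and predictions, in months.
truth = [24.0, 60.0, 120.0, 180.0]
preds = [30.0, 50.0, 121.0, 155.0]
m = bone_age_metrics(truth, preds)
for name, value in m.items():
    print(f"{name}: {value:.3f}")
```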

Age-Stratified Performance

Age Range     Cases   MAE
0-5 years     94      9.54 months
5-10 years    394     10.04 months
10-15 years   809     8.07 months
15-19 years   124     8.45 months

Sex-Stratified Performance

Sex      Cases   MAE
Male     773     8.26 months
Female   652     9.47 months
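
As a consistency check, the case-weighted mean of the per-sex MAEs should reproduce the overall MAE. The sketch below uses only the numbers from the table above.

```python
# (cases, MAE in months) per sex, copied from the table above
groups = {"Male": (773, 8.26), "Female": (652, 9.47)}

total_cases = sum(n for n, _ in groups.values())
weighted_mae = sum(n * mae for n, mae in groups.values()) / total_cases

print(total_cases, round(weighted_mae, 2))  # 1425 8.81
```

Both values agree with the 1,425 evaluation cases and the 8.81-month overall MAE reported under Results.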

Usage

Installation

pip install torch transformers peft numpy opencv-python huggingface_hub

Inference

import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import json
import numpy as np
from transformers import SiglipVisionModel
from peft import PeftModel
from huggingface_hub import hf_hub_download, snapshot_download


# ── Model Components ──

class GlobalAveragePooling(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        return self.norm(x.mean(dim=1))


def letterbox_resize(image, size=448):
    """Resize preserving aspect ratio with zero-padding."""
    h, w = image.shape[:2]
    scale = size / max(h, w)
    nh, nw = int(h * scale), int(w * scale)
    resized = cv2.resize(image, (nw, nh))
    padded = np.zeros((size, size, 3), dtype=np.uint8)
    top, left = (size - nh) // 2, (size - nw) // 2
    padded[top:top + nh, left:left + nw] = resized
    return padded


# ── Load Model ──

def load_regressor(repo_id="rohitrajesh/chronos-msk-regressor", device="cuda"):
    """Load the complete bone age regressor from HuggingFace."""
    
    # Download all files
    config_path = hf_hub_download(repo_id, "config.json")
    heads_path = hf_hub_download(repo_id, "heads.pth")
    adapter_dir = snapshot_download(repo_id, allow_patterns=["adapter/*"])
    
    # Load config
    with open(config_path) as f:
        config = json.load(f)
    
    # Load MedSigLIP base model with LoRA adapter
    base_model = SiglipVisionModel.from_pretrained(
        config["model_id"], torch_dtype=torch.float32
    )
    backbone = PeftModel.from_pretrained(base_model, f"{adapter_dir}/adapter")
    
    hidden = config["hidden_size"]
    num_bins = config["num_bins"]
    
    # Build prediction heads
    pooler = GlobalAveragePooling(hidden)
    gender_embed = nn.Sequential(
        nn.Linear(1, 64), nn.GELU(), nn.Linear(64, 128)
    )
    classifier = nn.Sequential(
        nn.LayerNorm(hidden + 128),
        nn.Linear(hidden + 128, 512),
        nn.GELU(),
        nn.Dropout(0.1),
        nn.Linear(512, num_bins),
    )
    
    # Load trained head weights
    state = torch.load(heads_path, map_location=device)
    pooler.load_state_dict(state["pooler"])
    gender_embed.load_state_dict(state["gender_embed"])
    classifier.load_state_dict(state["classifier"])
    
    # Move to device and set eval mode
    backbone.to(device).eval()
    pooler.to(device).eval()
    gender_embed.to(device).eval()
    classifier.to(device).eval()
    
    return backbone, pooler, gender_embed, classifier, config


# ── Predict ──

def predict_bone_age(image_path, is_male, backbone, pooler,
                     gender_embed, classifier, config, device="cuda"):
    """
    Predict bone age from a hand/wrist X-ray.
    
    Args:
        image_path: Path to X-ray image (any common format)
        is_male: Boolean, biological sex
        
    Returns:
        age_months: Predicted bone age in months (float)
    """
    size = config["image_size"]
    
    # Load and preprocess
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Cannot read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = letterbox_resize(img, size)
    
    # Normalize: (pixel / 255 - 0.5) / 0.5
    norm = (img.astype(np.float32) / 255.0 - 0.5) / 0.5
    t_img = torch.from_numpy(norm.transpose(2, 0, 1)).unsqueeze(0).to(device)
    t_sex = torch.tensor([[1.0 if is_male else 0.0]], device=device)
    
    with torch.no_grad():
        # Forward pass
        out = backbone(pixel_values=t_img)
        vis = pooler(out.last_hidden_state)
        gen = gender_embed(t_sex)
        logits = classifier(torch.cat([vis, gen], dim=-1))
        
        # Expected value from probability distribution
        probs = F.softmax(logits, dim=-1)
        ages = torch.arange(probs.shape[-1], device=device, dtype=torch.float32)
        predicted_months = torch.sum(probs * ages).item()
    
    return predicted_months


def predict_with_tta(image_path, is_male, backbone, pooler,
                     gender_embed, classifier, config, device="cuda"):
    """
    Predict with Test-Time Augmentation (original + horizontal flip).
    This is the recommended inference method.
    """
    import os
    import tempfile

    # Original prediction (raises ValueError if the image cannot be read)
    pred_orig = predict_bone_age(
        image_path, is_male, backbone, pooler,
        gender_embed, classifier, config, device
    )

    # Flipped prediction: write the mirrored image to a unique temporary
    # file so that concurrent calls do not overwrite each other
    img = cv2.imread(image_path)
    img_flipped = cv2.flip(img, 1)  # Horizontal flip
    fd, tmp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        cv2.imwrite(tmp_path, img_flipped)
        pred_flip = predict_bone_age(
            tmp_path, is_male, backbone, pooler,
            gender_embed, classifier, config, device
        )
    finally:
        os.remove(tmp_path)  # Clean up the temporary file

    # Average
    return (pred_orig + pred_flip) / 2.0


# ── Example ──

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    
    # Load model (downloads from HuggingFace on first run)
    print("Loading model...")
    backbone, pooler, gender_embed, classifier, config = load_regressor(device=device)
    print("Model loaded!")
    
    # Single prediction
    age = predict_bone_age(
        "hand_xray.png",
        is_male=True,
        backbone=backbone,
        pooler=pooler,
        gender_embed=gender_embed,
        classifier=classifier,
        config=config,
        device=device,
    )
    print(f"Predicted bone age: {age:.1f} months ({age/12:.1f} years)")
    
    # With TTA (recommended)
    age_tta = predict_with_tta(
        "hand_xray.png",
        is_male=True,
        backbone=backbone,
        pooler=pooler,
        gender_embed=gender_embed,
        classifier=classifier,
        config=config,
        device=device,
    )
    print(f"Predicted bone age (TTA): {age_tta:.1f} months ({age_tta/12:.1f} years)")

Training Details

Data

  • Dataset: RSNA Pediatric Bone Age (14,236 images)
  • Split: 12,611 train / 1,625 validation (stratified by age bins, no leakage)
  • Labels: Bone age in months (0-228), converted to Gaussian distributions for DLDL

Hyperparameters

Parameter                Value
Base model               google/medsiglip-448
Adaptation               DoRA (use_dora=True)
LoRA rank                16
LoRA alpha               32
LoRA dropout             0.05
LoRA targets             q_proj, k_proj, v_proj, out_proj, fc1, fc2
DLDL sigma               12.0 months
Label smoothing          0.01
Optimizer                AdamW (fused)
Learning rate            5e-5
LR schedule              Cosine with warmup
Warmup ratio             0.05
Weight decay             0.01
Batch size               8 × 2 gradient accumulation = 16 effective
Epochs                   20
Precision                bfloat16
Gradient checkpointing   Enabled
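
For orientation, here is a sketch of what a cosine schedule with a 5% warmup looks like under these settings. It assumes linear warmup from zero, cosine decay to zero, one optimizer step per effective batch of 16, and no dropped final batch; the actual training script may differ on any of these points, and the function name `lr_at_step` is hypothetical.

```python
import math

def lr_at_step(step, total_steps, peak_lr=5e-5, warmup_ratio=0.05):
    """Learning rate at a given optimizer step: linear warmup, then cosine decay."""
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

# Figures from this card: 12,611 training images, batch 8 x 2 accumulation, 20 epochs.
train_images, effective_batch, epochs = 12_611, 8 * 2, 20
steps_per_epoch = math.ceil(train_images / effective_batch)
total_steps = steps_per_epoch * epochs

for step in (0, total_steps // 20, total_steps // 2, total_steps):
    print(step, f"{lr_at_step(step, total_steps):.2e}")
```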

Augmentation

  • Rotation: ±20° with constant border padding
  • Random brightness/contrast: ±10%, probability 0.3
  • Letterbox resize to 448×448 (aspect ratio preserved)

Hardware

  • GPU: NVIDIA RTX 4090 (24 GB VRAM)
  • Training time: ~2 hours
  • Peak VRAM usage: ~18 GB

Loss Function

KL divergence between the predicted softmax distribution and a Gaussian target:

Target: N(age_true, σ²) discretized over 228 bins, label-smoothed
Loss: KL(target || predicted) = Σ_k target[k] × log(target[k] / predicted[k])
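
A minimal sketch of this loss in plain Python shows how the target is built and compared with the prediction, using the sum over bins of target × log(target / predicted). The function names are hypothetical, and mixing with a uniform distribution is an assumed form of label smoothing; this card does not specify how smoothing was applied.

```python
import math

NUM_BINS, SIGMA, SMOOTHING = 228, 12.0, 0.01  # values from this card

def gaussian_target(age_months):
    """Discretized Gaussian over monthly bins, then label-smoothed."""
    weights = [math.exp(-0.5 * ((k - age_months) / SIGMA) ** 2)
               for k in range(NUM_BINS)]
    total = sum(weights)
    probs = [w / total for w in weights]
    # Assumed smoothing scheme: blend with a uniform distribution
    return [(1 - SMOOTHING) * p + SMOOTHING / NUM_BINS for p in probs]

def softmax(logits):
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(target, predicted, eps=1e-12):
    """Sum over bins of target * log(target / predicted)."""
    return sum(t * math.log((t + eps) / (p + eps))
               for t, p in zip(target, predicted) if t > 0)

target = gaussian_target(120.0)
uniform_pred = softmax([0.0] * NUM_BINS)  # an uninformative prediction

loss_uniform = kl_divergence(target, uniform_pred)
loss_perfect = kl_divergence(target, target)
print(f"uniform prediction: {loss_uniform:.3f}, perfect prediction: {loss_perfect:.3f}")
```

The loss is zero only when the predicted distribution matches the target, and it penalizes near misses less than distant ones because the target assigns probability to neighboring months.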

Saved Artifacts

File                                Description                                              Size
config.json                         Model configuration (image size, hidden dim, num bins)   <1 KB
heads.pth                           Trained pooler, gender_embed, and classifier weights     ~2 MB
adapter/adapter_config.json         LoRA/DoRA adapter configuration                          <1 KB
adapter/adapter_model.safetensors   LoRA adapter weights                                     ~5 MB

The base MedSigLIP-448 model (~1.2 GB) is downloaded automatically from google/medsiglip-448 on first use.

Limitations

  • Trained primarily on the RSNA dataset, which lacks explicit racial/ethnic metadata. Cross-population generalization should be validated before clinical use.
  • The model assumes a standard left hand PA radiograph. Performance may degrade on non-standard views, right hand images, or heavily rotated inputs.
  • This is a research tool, not a medical device. All predictions require interpretation by qualified clinical professionals.

Ethical Considerations

Bone age assessment has significant implications in pediatric endocrinology (growth disorder diagnosis), forensic anthropology, and legal age determination for asylum cases. Automated systems must be transparent about their accuracy limitations and should never be used as the sole basis for consequential decisions.

Citation

@misc{chronosmsk2025,
    title={Chronos-MSK: Bias-Aware Skeletal Maturity Assessment at the Edge},
    author={Rohit Rajesh},
    year={2025},
    note={Google HAI-DEF Competition Submission}
}
