# Chronos-MSK Bone Age Regressor
A LoRA-fine-tuned MedSigLIP-448 model for pediatric bone age estimation from hand/wrist X-rays. Achieves 8.81-month MAE on the RSNA Bone Age validation set (Pearson r = 0.963), surpassing typical human expert inter-reader variability of 10-13 months.
Part of the Chronos-MSK multi-agent bone age assessment system, built for the Google HAI-DEF competition.
## Model Description
This model predicts bone age in months from a single left hand/wrist radiograph. Instead of regressing to a single number, it uses Deep Label Distribution Learning (DLDL), predicting a full probability distribution over 228 monthly age bins. The final age estimate is the expected value of this distribution, which naturally captures prediction uncertainty.
Biological sex is incorporated as a conditioning signal through a learned embedding, reflecting the known differences in skeletal maturation rates between males and females.
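As a toy illustration of the expected-value decoding, here is a hypothetical 5-bin distribution (the actual model predicts over 228 monthly bins):

```python
import numpy as np

# Hypothetical 5-bin distribution, for illustration only;
# the model predicts over 228 monthly bins (0..227 months)
probs = np.array([0.05, 0.15, 0.40, 0.30, 0.10])
bin_ages = np.arange(len(probs), dtype=np.float64)  # bin k corresponds to k months
age_estimate = float(np.sum(probs * bin_ages))
# close to 2.25 (= 0*0.05 + 1*0.15 + 2*0.40 + 3*0.30 + 4*0.10)
```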
## Architecture

```
Input: Hand/Wrist X-Ray (448×448, letterbox padded)
        │
        ▼
MedSigLIP-448 Vision Encoder (frozen base + DoRA adapters)
        │
        ▼
last_hidden_state (sequence of patch embeddings)
        │
        ▼
Global Average Pooling + LayerNorm → 1152-D visual features
        │
        ├── Sex Input (0/1) → Linear(1,64) → GELU → Linear(64,128) → 128-D
        │
        ▼
Concatenate [visual_1152; sex_128] → 1280-D
        │
        ▼
LayerNorm → Linear(1280,512) → GELU → Dropout(0.1) → Linear(512,228)
        │
        ▼
Softmax → Probability distribution over 228 months
        │
        ▼
Expected Value: â = Σ(k × P(k)) for k = 0..227
```
### Key Design Choices

- DoRA (Weight-Decomposed LoRA): rank=16, alpha=32; adapts the frozen MedSigLIP encoder with only ~2M trainable parameters while preserving medical domain knowledge
- DLDL Loss: KL divergence between the predicted distribution and a Gaussian target centered on the true age (sigma = 12 months), giving smoother gradients than MSE regression
- Sex Conditioning: Learned embedding concatenated with visual features, enabling sex-specific age estimation without separate models
- Test-Time Augmentation: Average predictions from the original and horizontally flipped image for improved accuracy
## Results
Evaluated on 1,425 held-out RSNA Bone Age validation cases (strict no-leakage protocol):
| Metric | Value |
|---|---|
| Mean Absolute Error | 8.81 months (0.73 years) |
| Median Absolute Error | 7.27 months |
| Root Mean Square Error | 11.39 months |
| Pearson r | 0.963 |
| R² | 0.927 |
| Within ±6 months | 42.9% |
| Within ±12 months | 73.2% |
| Within ±24 months | 95.7% |
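For reference, the error metrics in the table can be computed from arrays of predicted and true ages; a minimal numpy sketch with made-up values:

```python
import numpy as np

# Made-up predictions and labels in months, for illustration only
y_true = np.array([60.0, 102.0, 141.0, 30.0])
y_pred = np.array([66.0, 95.0, 150.0, 31.0])

abs_err = np.abs(y_pred - y_true)
mae = float(abs_err.mean())                        # mean absolute error
median_ae = float(np.median(abs_err))              # median absolute error
rmse = float(np.sqrt(np.mean((y_pred - y_true) ** 2)))
within_12 = float(np.mean(abs_err <= 12.0))        # fraction within ±12 months
```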
### Age-Stratified Performance
| Age Range | Cases | MAE |
|---|---|---|
| 0-5 years | 94 | 9.54 months |
| 5-10 years | 394 | 10.04 months |
| 10-15 years | 809 | 8.07 months |
| 15-19 years | 124 | 8.45 months |
### Sex-Stratified Performance
| Sex | Cases | MAE |
|---|---|---|
| Male | 773 | 8.26 months |
| Female | 652 | 9.47 months |
## Usage
### Installation
```shell
pip install torch transformers peft numpy opencv-python huggingface_hub
```
### Inference
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import json
import numpy as np
from transformers import SiglipVisionModel
from peft import PeftModel
from huggingface_hub import hf_hub_download, snapshot_download


# ── Model Components ──
class GlobalAveragePooling(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        return self.norm(x.mean(dim=1))


def letterbox_resize(image, size=448):
    """Resize preserving aspect ratio with zero-padding."""
    h, w = image.shape[:2]
    scale = size / max(h, w)
    nh, nw = int(h * scale), int(w * scale)
    resized = cv2.resize(image, (nw, nh))
    padded = np.zeros((size, size, 3), dtype=np.uint8)
    top, left = (size - nh) // 2, (size - nw) // 2
    padded[top:top + nh, left:left + nw] = resized
    return padded


# ── Load Model ──
def load_regressor(repo_id="rohitrajesh/chronos-msk-regressor", device="cuda"):
    """Load the complete bone age regressor from HuggingFace."""
    # Download all files
    config_path = hf_hub_download(repo_id, "config.json")
    heads_path = hf_hub_download(repo_id, "heads.pth")
    adapter_dir = snapshot_download(repo_id, allow_patterns=["adapter/*"])

    # Load config
    with open(config_path) as f:
        config = json.load(f)

    # Load MedSigLIP base model with LoRA adapter
    base_model = SiglipVisionModel.from_pretrained(
        config["model_id"], torch_dtype=torch.float32
    )
    backbone = PeftModel.from_pretrained(base_model, f"{adapter_dir}/adapter")

    hidden = config["hidden_size"]
    num_bins = config["num_bins"]

    # Build prediction heads
    pooler = GlobalAveragePooling(hidden)
    gender_embed = nn.Sequential(
        nn.Linear(1, 64), nn.GELU(), nn.Linear(64, 128)
    )
    classifier = nn.Sequential(
        nn.LayerNorm(hidden + 128),
        nn.Linear(hidden + 128, 512),
        nn.GELU(),
        nn.Dropout(0.1),
        nn.Linear(512, num_bins),
    )

    # Load trained head weights
    state = torch.load(heads_path, map_location=device)
    pooler.load_state_dict(state["pooler"])
    gender_embed.load_state_dict(state["gender_embed"])
    classifier.load_state_dict(state["classifier"])

    # Move to device and set eval mode
    backbone.to(device).eval()
    pooler.to(device).eval()
    gender_embed.to(device).eval()
    classifier.to(device).eval()

    return backbone, pooler, gender_embed, classifier, config


# ── Predict ──
def predict_bone_age(image_path, is_male, backbone, pooler,
                     gender_embed, classifier, config, device="cuda"):
    """
    Predict bone age from a hand/wrist X-ray.

    Args:
        image_path: Path to X-ray image (any common format)
        is_male: Boolean, biological sex

    Returns:
        age_months: Predicted bone age in months (float)
    """
    size = config["image_size"]

    # Load and preprocess
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Cannot read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = letterbox_resize(img, size)

    # Normalize: (pixel / 255 - 0.5) / 0.5
    norm = (img.astype(np.float32) / 255.0 - 0.5) / 0.5
    t_img = torch.from_numpy(norm.transpose(2, 0, 1)).unsqueeze(0).to(device)
    t_sex = torch.tensor([[1.0 if is_male else 0.0]], device=device)

    with torch.no_grad():
        # Forward pass
        out = backbone(pixel_values=t_img)
        vis = pooler(out.last_hidden_state)
        gen = gender_embed(t_sex)
        logits = classifier(torch.cat([vis, gen], dim=-1))

        # Expected value from probability distribution
        probs = F.softmax(logits, dim=-1)
        ages = torch.arange(probs.shape[-1], device=device, dtype=torch.float32)
        predicted_months = torch.sum(probs * ages).item()

    return predicted_months


def predict_with_tta(image_path, is_male, backbone, pooler,
                     gender_embed, classifier, config, device="cuda"):
    """
    Predict with Test-Time Augmentation (original + horizontal flip).
    This is the recommended inference method.
    """
    # Original prediction
    pred_orig = predict_bone_age(
        image_path, is_male, backbone, pooler,
        gender_embed, classifier, config, device
    )

    # Flipped prediction
    img = cv2.imread(image_path)
    img_flipped = cv2.flip(img, 1)  # Horizontal flip
    tmp_path = "/tmp/chronos_flip.png"
    cv2.imwrite(tmp_path, img_flipped)
    pred_flip = predict_bone_age(
        tmp_path, is_male, backbone, pooler,
        gender_embed, classifier, config, device
    )

    # Average
    return (pred_orig + pred_flip) / 2.0


# ── Example ──
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load model (downloads from HuggingFace on first run)
    print("Loading model...")
    backbone, pooler, gender_embed, classifier, config = load_regressor(device=device)
    print("Model loaded!")

    # Single prediction
    age = predict_bone_age(
        "hand_xray.png",
        is_male=True,
        backbone=backbone,
        pooler=pooler,
        gender_embed=gender_embed,
        classifier=classifier,
        config=config,
        device=device,
    )
    print(f"Predicted bone age: {age:.1f} months ({age/12:.1f} years)")

    # With TTA (recommended)
    age_tta = predict_with_tta(
        "hand_xray.png",
        is_male=True,
        backbone=backbone,
        pooler=pooler,
        gender_embed=gender_embed,
        classifier=classifier,
        config=config,
        device=device,
    )
    print(f"Predicted bone age (TTA): {age_tta:.1f} months ({age_tta/12:.1f} years)")
```
## Training Details
### Data
- Dataset: RSNA Pediatric Bone Age (14,236 images)
- Split: 12,611 train / 1,625 validation (stratified by age bins, no leakage)
- Labels: Bone age in months (0-228), converted to Gaussian distributions for DLDL
### Hyperparameters
| Parameter | Value |
|---|---|
| Base model | google/medsiglip-448 |
| Adaptation | DoRA (use_dora=True) |
| LoRA rank | 16 |
| LoRA alpha | 32 |
| LoRA dropout | 0.05 |
| LoRA targets | q_proj, k_proj, v_proj, out_proj, fc1, fc2 |
| DLDL sigma | 12.0 months |
| Label smoothing | 0.01 |
| Optimizer | AdamW (fused) |
| Learning rate | 5e-5 |
| LR schedule | Cosine with warmup |
| Warmup ratio | 0.05 |
| Weight decay | 0.01 |
| Batch size | 8 × 2 gradient accumulation = 16 effective |
| Epochs | 20 |
| Precision | bfloat16 |
| Gradient checkpointing | Enabled |
### Augmentation
- Rotation: ±20° with constant border padding
- Random brightness/contrast: ±10%, probability 0.3
- Letterbox resize to 448×448 (aspect ratio preserved)
### Hardware
- GPU: NVIDIA RTX 4090 (24 GB VRAM)
- Training time: ~2 hours
- Peak VRAM usage: ~18 GB
## Loss Function
KL divergence between the predicted softmax distribution and a Gaussian target:
Target: N(age_true, σ²) discretized over 228 bins, label-smoothed
Loss: KL(target ‖ predicted) = Σ target[k] × log(target[k] / predicted[k])
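In code, the target construction and loss might look like the following sketch, using the sigma, label smoothing, and bin count from the hyperparameter table (the exact training code is not included in this card):

```python
import torch
import torch.nn.functional as F

def dldl_target(age_months: float, num_bins: int = 228,
                sigma: float = 12.0, smoothing: float = 0.01) -> torch.Tensor:
    """Gaussian soft label centered on the true age, discretized and smoothed."""
    bins = torch.arange(num_bins, dtype=torch.float32)
    target = torch.exp(-0.5 * ((bins - age_months) / sigma) ** 2)
    target = target / target.sum()  # normalize over the bins
    return (1.0 - smoothing) * target + smoothing / num_bins

def dldl_loss(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """KL(target ‖ softmax(logits)), averaged over the batch."""
    log_probs = F.log_softmax(logits, dim=-1)
    return F.kl_div(log_probs, target, reduction="batchmean")
```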
## Saved Artifacts
| File | Description | Size |
|---|---|---|
| `config.json` | Model configuration (image size, hidden dim, num bins) | <1 KB |
| `heads.pth` | Trained pooler, gender_embed, and classifier weights | ~2 MB |
| `adapter/adapter_config.json` | LoRA/DoRA adapter configuration | <1 KB |
| `adapter/adapter_model.safetensors` | LoRA adapter weights | ~5 MB |
The base MedSigLIP-448 model (~1.2 GB) is downloaded automatically from google/medsiglip-448 on first use.
## Limitations
- Trained primarily on the RSNA dataset, which lacks explicit racial/ethnic metadata. Cross-population generalization should be validated before clinical use.
- The model assumes a standard left hand PA radiograph. Performance may degrade on non-standard views, right hand images, or heavily rotated inputs.
- This is a research tool, not a medical device. All predictions require interpretation by qualified clinical professionals.
## Ethical Considerations
Bone age assessment has significant implications in pediatric endocrinology (growth disorder diagnosis), forensic anthropology, and legal age determination for asylum cases. Automated systems must be transparent about their accuracy limitations and should never be used as the sole basis for consequential decisions.
## Citation
```bibtex
@misc{chronosmsk2025,
  title={Chronos-MSK: Bias-Aware Skeletal Maturity Assessment at the Edge},
  author={Rohit Rajesh},
  year={2025},
  note={Google HAI-DEF Competition Submission}
}
```
## Links
- Full Pipeline: github.com/04RR/chronos-msk
- Model Weights (Kaggle): kaggle.com/datasets/rohitrajesh/model-checkpoints
- Base Model: google/medsiglip-448
- Dataset: RSNA Bone Age