---
license: mit
tags:
- medical
- bone-age
- radiology
- medsiglip
- lora
- pediatric
- skeletal-maturity
- hai-def
- dldl
- age-estimation
datasets:
- kmader/rsna-bone-age
base_model: google/medsiglip-448
pipeline_tag: image-classification
library_name: peft
---

# Chronos-MSK Bone Age Regressor

**Writeup**: [Download PDF](https://github.com/04RR/chronos-msk/blob/master/ChronosMSK_Writeup.pdf)
**Technical paper**: [Download PDF](https://github.com/04RR/chronos-msk/blob/master/ChronosMSK_technical_paper.pdf)

A LoRA-fine-tuned [MedSigLIP-448](https://huggingface.co/google/medsiglip-448) model for pediatric bone age estimation from hand/wrist X-rays. It achieves an **8.81-month MAE** on the RSNA Bone Age validation set (Pearson r = 0.963), surpassing the typical human expert inter-reader variability of 10-13 months.

Part of [Chronos-MSK](https://github.com/04RR/chronos-msk), a multi-agent bone age assessment system built for the Google HAI-DEF competition.
## Model Description

This model predicts bone age in months from a single left hand/wrist radiograph. Instead of regressing to a single number, it uses **Deep Label Distribution Learning (DLDL)**, predicting a full probability distribution over 228 monthly age bins. The final age estimate is the expected value of this distribution, which naturally captures prediction uncertainty.
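
The decoding step can be sketched in a few lines: a softmax over the 228 logits yields P(k), and the expected value over the bin indices k gives the age in months (the function name and toy logits below are illustrative, not taken from the training code):

```python
import numpy as np

def decode_age(logits: np.ndarray) -> float:
    """DLDL decoding: softmax over bins, then the expected value in months."""
    z = logits - logits.max()               # stabilize the softmax
    probs = np.exp(z) / np.exp(z).sum()     # P(k) over the 228 monthly bins
    bins = np.arange(len(logits))           # bin k corresponds to k months
    return float((probs * bins).sum())      # expected value = sum_k k * P(k)

# Toy example: Gaussian-shaped logits peaked at bin 120 decode to ~120 months.
toy_logits = -0.5 * ((np.arange(228) - 120) / 12.0) ** 2
print(decode_age(toy_logits))
```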

Biological sex is incorporated as a conditioning signal through a learned embedding, reflecting the known differences in skeletal maturation rates between males and females.

### Architecture

```
Input: Hand/Wrist X-Ray (448×448, letterbox padded)
        │
        ▼
MedSigLIP-448 Vision Encoder (frozen base + DoRA adapters)
        │
        ▼  last_hidden_state (sequence of patch embeddings)
        │
        ▼
Global Average Pooling + LayerNorm → 1152-D visual features
        │
        ├── Sex Input (0/1) → Linear(1,64) → GELU → Linear(64,128) → 128-D
        │
        ▼
Concatenate [visual_1152; sex_128] → 1280-D
        │
        ▼
LayerNorm → Linear(1280,512) → GELU → Dropout(0.1) → Linear(512,228)
        │
        ▼
Softmax → Probability distribution over 228 months
        │
        ▼
Expected Value: â = Σ(k × P(k)) for k = 0..227
```

### Key Design Choices

- **DoRA (Weight-Decomposed LoRA)**: rank=16, alpha=32; adapts the frozen MedSigLIP encoder with only ~2M trainable parameters while preserving its medical domain knowledge
- **DLDL Loss**: KL divergence between the predicted distribution and a Gaussian target centered on the true age (sigma = 12 months), giving smoother gradients than MSE regression
- **Sex Conditioning**: a learned embedding concatenated with the visual features, enabling sex-specific age estimation without separate models
- **Test-Time Augmentation**: predictions from the original and horizontally flipped images are averaged for improved accuracy

## Results

Evaluated on 1,425 held-out RSNA Bone Age validation cases (strict no-leakage protocol):

| Metric | Value |
|---|---|
| **Mean Absolute Error** | **8.81 months (0.73 years)** |
| Median Absolute Error | 7.27 months |
| Root Mean Square Error | 11.39 months |
| Pearson r | 0.963 |
| R² | 0.927 |
| Within ±6 months | 42.9% |
| Within ±12 months | 73.2% |
| Within ±24 months | 95.7% |
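
For reference, the metrics in this table can be computed from paired prediction/label arrays as follows (a generic sketch; the array contents below are made up):

```python
import numpy as np

def summarize_errors(pred_months: np.ndarray, true_months: np.ndarray) -> dict:
    """MAE, median AE, RMSE (months) and within-tolerance rates (%)."""
    err = np.abs(pred_months - true_months)
    return {
        "mae": float(err.mean()),
        "median_ae": float(np.median(err)),
        "rmse": float(np.sqrt(((pred_months - true_months) ** 2).mean())),
        "within_6": float((err <= 6).mean() * 100),
        "within_12": float((err <= 12).mean() * 100),
        "within_24": float((err <= 24).mean() * 100),
    }

# Made-up example with three cases; absolute errors are 5, 2, and 20 months.
m = summarize_errors(np.array([100.0, 50.0, 30.0]),
                     np.array([105.0, 48.0, 10.0]))
print(m["mae"])  # (5 + 2 + 20) / 3 = 9.0
```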

### Age-Stratified Performance

| Age Range | Cases | MAE |
|---|---|---|
| 0-5 years | 94 | 9.54 months |
| 5-10 years | 394 | 10.04 months |
| 10-15 years | 809 | 8.07 months |
| 15-19 years | 124 | 8.45 months |

### Sex-Stratified Performance

| Sex | Cases | MAE |
|---|---|---|
| Male | 773 | 8.26 months |
| Female | 652 | 9.47 months |
|
| | ## Usage |
| |
|
| | ### Installation |
| |
|
| | ```bash |
| | pip install torch transformers peft numpy opencv-python huggingface_hub |
| | ``` |

### Inference

```python
import json

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import SiglipVisionModel
from peft import PeftModel
from huggingface_hub import hf_hub_download, snapshot_download


# ── Model Components ──

class GlobalAveragePooling(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        return self.norm(x.mean(dim=1))


def letterbox_resize(image, size=448):
    """Resize preserving aspect ratio, then zero-pad to a square."""
    h, w = image.shape[:2]
    scale = size / max(h, w)
    nh, nw = int(h * scale), int(w * scale)
    resized = cv2.resize(image, (nw, nh))
    padded = np.zeros((size, size, 3), dtype=np.uint8)
    top, left = (size - nh) // 2, (size - nw) // 2
    padded[top:top + nh, left:left + nw] = resized
    return padded


# ── Load Model ──

def load_regressor(repo_id="rohitrajesh/chronos-msk-regressor", device="cuda"):
    """Load the complete bone age regressor from the Hugging Face Hub."""

    # Download all files
    config_path = hf_hub_download(repo_id, "config.json")
    heads_path = hf_hub_download(repo_id, "heads.pth")
    adapter_dir = snapshot_download(repo_id, allow_patterns=["adapter/*"])

    # Load config
    with open(config_path) as f:
        config = json.load(f)

    # Load the MedSigLIP base model and attach the LoRA adapter
    base_model = SiglipVisionModel.from_pretrained(
        config["model_id"], torch_dtype=torch.float32
    )
    backbone = PeftModel.from_pretrained(base_model, f"{adapter_dir}/adapter")

    hidden = config["hidden_size"]
    num_bins = config["num_bins"]

    # Build prediction heads
    pooler = GlobalAveragePooling(hidden)
    gender_embed = nn.Sequential(
        nn.Linear(1, 64), nn.GELU(), nn.Linear(64, 128)
    )
    classifier = nn.Sequential(
        nn.LayerNorm(hidden + 128),
        nn.Linear(hidden + 128, 512),
        nn.GELU(),
        nn.Dropout(0.1),
        nn.Linear(512, num_bins),
    )

    # Load trained head weights
    state = torch.load(heads_path, map_location=device)
    pooler.load_state_dict(state["pooler"])
    gender_embed.load_state_dict(state["gender_embed"])
    classifier.load_state_dict(state["classifier"])

    # Move to device and set eval mode
    backbone.to(device).eval()
    pooler.to(device).eval()
    gender_embed.to(device).eval()
    classifier.to(device).eval()

    return backbone, pooler, gender_embed, classifier, config


# ── Predict ──

def predict_bone_age(image_path, is_male, backbone, pooler,
                     gender_embed, classifier, config,
                     device="cuda", flip=False):
    """
    Predict bone age from a hand/wrist X-ray.

    Args:
        image_path: Path to the X-ray image (any common format)
        is_male: Boolean, biological sex
        flip: If True, horizontally flip the image before inference
              (used for test-time augmentation)

    Returns:
        age_months: Predicted bone age in months (float)
    """
    size = config["image_size"]

    # Load and preprocess
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Cannot read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    if flip:
        img = cv2.flip(img, 1)  # horizontal flip
    img = letterbox_resize(img, size)

    # Normalize: (pixel / 255 - 0.5) / 0.5, matching the SigLIP processor
    norm = (img.astype(np.float32) / 255.0 - 0.5) / 0.5
    t_img = torch.from_numpy(norm.transpose(2, 0, 1)).unsqueeze(0).to(device)
    t_sex = torch.tensor([[1.0 if is_male else 0.0]], device=device)

    with torch.no_grad():
        # Forward pass
        out = backbone(pixel_values=t_img)
        vis = pooler(out.last_hidden_state)
        gen = gender_embed(t_sex)
        logits = classifier(torch.cat([vis, gen], dim=-1))

        # Expected value of the predicted probability distribution
        probs = F.softmax(logits, dim=-1)
        ages = torch.arange(probs.shape[-1], device=device, dtype=torch.float32)
        predicted_months = torch.sum(probs * ages).item()

    return predicted_months


def predict_with_tta(image_path, is_male, backbone, pooler,
                     gender_embed, classifier, config, device="cuda"):
    """
    Predict with test-time augmentation (original + horizontal flip).
    This is the recommended inference method.
    """
    pred_orig = predict_bone_age(
        image_path, is_male, backbone, pooler,
        gender_embed, classifier, config, device, flip=False
    )
    pred_flip = predict_bone_age(
        image_path, is_male, backbone, pooler,
        gender_embed, classifier, config, device, flip=True
    )
    return (pred_orig + pred_flip) / 2.0


# ── Example ──

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load model (downloads from the Hugging Face Hub on first run)
    print("Loading model...")
    backbone, pooler, gender_embed, classifier, config = load_regressor(device=device)
    print("Model loaded!")

    # Single prediction
    age = predict_bone_age(
        "hand_xray.png",
        is_male=True,
        backbone=backbone,
        pooler=pooler,
        gender_embed=gender_embed,
        classifier=classifier,
        config=config,
        device=device,
    )
    print(f"Predicted bone age: {age:.1f} months ({age/12:.1f} years)")

    # With TTA (recommended)
    age_tta = predict_with_tta(
        "hand_xray.png",
        is_male=True,
        backbone=backbone,
        pooler=pooler,
        gender_embed=gender_embed,
        classifier=classifier,
        config=config,
        device=device,
    )
    print(f"Predicted bone age (TTA): {age_tta:.1f} months ({age_tta/12:.1f} years)")
```

## Training Details

### Data

- **Dataset**: [RSNA Pediatric Bone Age](https://www.kaggle.com/datasets/kmader/rsna-bone-age) (14,236 images)
- **Split**: 12,611 train / 1,625 validation (stratified by age bins, no leakage)
- **Labels**: bone age in months (0-227), converted to Gaussian distributions for DLDL

### Hyperparameters

| Parameter | Value |
|---|---|
| Base model | google/medsiglip-448 |
| Adaptation | DoRA (use_dora=True) |
| LoRA rank | 16 |
| LoRA alpha | 32 |
| LoRA dropout | 0.05 |
| LoRA targets | q_proj, k_proj, v_proj, out_proj, fc1, fc2 |
| DLDL sigma | 12.0 months |
| Label smoothing | 0.01 |
| Optimizer | AdamW (fused) |
| Learning rate | 5e-5 |
| LR schedule | Cosine with warmup |
| Warmup ratio | 0.05 |
| Weight decay | 0.01 |
| Batch size | 8 × 2 gradient accumulation = 16 effective |
| Epochs | 20 |
| Precision | bfloat16 |
| Gradient checkpointing | Enabled |

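The optimizer and schedule rows translate to a setup like the following (a sketch; the model and step counts are placeholders, and `fused=True` is left off so it runs on CPU):

```python
import torch
from transformers import get_cosine_schedule_with_warmup

model = torch.nn.Linear(10, 2)            # placeholder for the trainable params
total_steps = 1000                        # in practice roughly (12611 / 16) * 20
warmup_steps = int(0.05 * total_steps)    # warmup ratio 0.05

# AdamW with the table's learning rate and weight decay
# (the actual run used fused=True, which requires CUDA).
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)

# Cosine decay after a linear warmup
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
)

for step in range(60):                    # the training loop would go here
    optimizer.step()
    scheduler.step()
print(optimizer.param_groups[0]["lr"])    # ~5e-5 just after warmup ends
```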

### Augmentation

- Rotation: ±20° with constant border padding
- Random brightness/contrast: ±10%, probability 0.3
- Letterbox resize to 448×448 (aspect ratio preserved)

### Hardware

- GPU: NVIDIA RTX 4090 (24 GB VRAM)
- Training time: ~2 hours
- Peak VRAM usage: ~18 GB

### Loss Function

KL divergence between the predicted softmax distribution and a discretized Gaussian target:

```
Target: N(age_true, σ²) discretized over 228 bins, label-smoothed
Loss:   KL(target || predicted) = Σ target[k] × log(target[k] / predicted[k])
```
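
A minimal numeric sketch of this target and loss (function names are illustrative; sigma = 12 and smoothing = 0.01 match the hyperparameter table):

```python
import numpy as np

def gaussian_target(age_months: float, num_bins: int = 228,
                    sigma: float = 12.0, smoothing: float = 0.01) -> np.ndarray:
    """Discretized, label-smoothed Gaussian target distribution for DLDL."""
    k = np.arange(num_bins)
    target = np.exp(-0.5 * ((k - age_months) / sigma) ** 2)
    target /= target.sum()                                     # normalize over bins
    target = (1 - smoothing) * target + smoothing / num_bins   # label smoothing
    return target

def kl_loss(target: np.ndarray, predicted: np.ndarray) -> float:
    """KL(target || predicted) = sum_k target[k] * log(target[k] / predicted[k])."""
    eps = 1e-12  # guard against log(0)
    return float(np.sum(target * np.log((target + eps) / (predicted + eps))))

t = gaussian_target(120.0)
print(kl_loss(t, t))  # KL of a distribution with itself is 0
```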

## Saved Artifacts

| File | Description | Size |
|---|---|---|
| `config.json` | Model configuration (image size, hidden dim, num bins) | <1 KB |
| `heads.pth` | Trained pooler, gender_embed, and classifier weights | ~2 MB |
| `adapter/adapter_config.json` | LoRA/DoRA adapter configuration | <1 KB |
| `adapter/adapter_model.safetensors` | LoRA adapter weights | ~5 MB |

The base MedSigLIP-448 model (~1.2 GB) is downloaded automatically from `google/medsiglip-448` on first use.

## Limitations

- Trained primarily on the RSNA dataset, which lacks explicit racial/ethnic metadata. Cross-population generalization should be validated before clinical use.
- The model assumes a standard left hand PA radiograph. Performance may degrade on non-standard views, right hand images, or heavily rotated inputs.
- This is a research tool, **not a medical device**. All predictions require interpretation by qualified clinical professionals.

## Ethical Considerations

Bone age assessment has significant implications in pediatric endocrinology (growth disorder diagnosis), forensic anthropology, and legal age determination for asylum cases. Automated systems must be transparent about their accuracy limitations and should never be used as the sole basis for consequential decisions.

## Citation

```bibtex
@misc{chronosmsk2025,
  title={Chronos-MSK: Bias-Aware Skeletal Maturity Assessment at the Edge},
  author={Rohit Rajesh},
  year={2025},
  note={Google HAI-DEF Competition Submission}
}
```

## Links

- **Full Pipeline**: [github.com/04RR/chronos-msk](https://github.com/04RR/chronos-msk)
- **Model Weights (Kaggle)**: [kaggle.com/datasets/rohitrajesh/model-checkpoints](https://www.kaggle.com/datasets/rohitrajesh/model-checkpoints/)
- **Base Model**: [google/medsiglip-448](https://huggingface.co/google/medsiglip-448)
- **Dataset**: [RSNA Bone Age](https://www.kaggle.com/datasets/kmader/rsna-bone-age)