---
license: mit
tags:
- medical
- bone-age
- radiology
- medsiglip
- lora
- pediatric
- skeletal-maturity
- hai-def
- dldl
- age-estimation
datasets:
- kmader/rsna-bone-age
base_model: google/medsiglip-448
pipeline_tag: image-classification
library_name: peft
---

# Chronos-MSK Bone Age Regressor

**Writeup here**: [Download PDF](https://github.com/04RR/chronos-msk/blob/master/ChronosMSK_Writeup.pdf)

**Technical paper here**: [Download PDF](https://github.com/04RR/chronos-msk/blob/master/ChronosMSK_technical_paper.pdf)

A LoRA-fine-tuned [MedSigLIP-448](https://huggingface.co/google/medsiglip-448) model for pediatric bone age estimation from hand/wrist X-rays. It achieves an **8.81-month MAE** on the RSNA Bone Age validation set (Pearson r = 0.963). This is below the 10-13 month inter-reader variability typically reported for human experts.

Part of the [Chronos-MSK](https://github.com/04RR/chronos-msk) multi-agent bone age assessment system, built for the Google HAI-DEF competition.

## Model Description

This model predicts bone age in months from a single left hand/wrist radiograph. Instead of regressing to a single number, it uses **Deep Label Distribution Learning (DLDL)** and predicts a full probability distribution over 228 monthly age bins (0-227 months). The final age estimate is the expected value of this distribution, and the spread of the distribution gives a natural measure of prediction uncertainty.

Biological sex is incorporated as a conditioning signal through a learned embedding. This reflects the known differences in skeletal maturation rates between males and females.
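
The DLDL pieces described above can be sketched in NumPy. This is a minimal illustration, not the project's training code. It assumes integer-month bins 0-227 and σ = 12 months, which are the values given in this card. The helper names (`gaussian_target`, `kl_target_pred`, `decode`) are hypothetical. Applying the 0.01 label smoothing as a mix with a uniform distribution is also an assumption.

```python
import numpy as np

NUM_BINS = 228  # monthly age bins, k = 0..227 (from the model card)
SIGMA = 12.0    # DLDL target width in months (from the model card)


def gaussian_target(age_months, num_bins=NUM_BINS, sigma=SIGMA, smoothing=0.01):
    """Discretized Gaussian label distribution centered on the true age.

    Assumption: label smoothing mixes in a uniform distribution over all bins.
    """
    k = np.arange(num_bins, dtype=np.float64)
    target = np.exp(-0.5 * ((k - age_months) / sigma) ** 2)
    target /= target.sum()
    return (1.0 - smoothing) * target + smoothing / num_bins


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def kl_target_pred(target, probs, eps=1e-12):
    """KL(target || predicted) = sum_k target[k] * log(target[k] / predicted[k])."""
    return float(np.sum(target * (np.log(target + eps) - np.log(probs + eps))))


def decode(probs):
    """Expected age (the point estimate) and the distribution's std, in months."""
    k = np.arange(probs.shape[-1], dtype=np.float64)
    mean = float(np.sum(probs * k))
    std = float(np.sqrt(np.sum(probs * (k - mean) ** 2)))
    return mean, std


if __name__ == "__main__":
    # Hypothetical logits: a slightly off-target prediction for a 120-month case.
    target = gaussian_target(120.0)
    logits = np.log(gaussian_target(114.0, smoothing=0.0))
    probs = softmax(logits)
    mean, std = decode(probs)
    print(f"estimate: {mean:.1f} months, spread: {std:.1f} months")
    print(f"KL(target || predicted): {kl_target_pred(target, probs):.3f}")
```

Taking the expectation rather than the most probable bin yields a continuous estimate between bin centers. The standard deviation is one simple way to summarize how spread out a prediction is.
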

### Architecture

```
Input: Hand/Wrist X-Ray (448×448, letterbox padded)
        │
        ▼
MedSigLIP-448 Vision Encoder (frozen base + DoRA adapters)
        │
        ▼
last_hidden_state (sequence of patch embeddings)
        │
        ▼
Global Average Pooling + LayerNorm → 1152-D visual features
        │
        ├── Sex Input (0/1) → Linear(1,64) → GELU → Linear(64,128) → 128-D
        │
        ▼
Concatenate [visual_1152; sex_128] → 1280-D
        │
        ▼
LayerNorm → Linear(1280,512) → GELU → Dropout(0.1) → Linear(512,228)
        │
        ▼
Softmax → Probability distribution over 228 months
        │
        ▼
Expected Value: â = Σ(k × P(k)) for k = 0..227
```

### Key Design Choices

- **DoRA (Weight-Decomposed Low-Rank Adaptation)**: rank=16, alpha=32. Adapts the frozen MedSigLIP encoder with only ~2M trainable parameters while preserving its medical domain knowledge.
- **DLDL Loss**: KL divergence between a Gaussian target centered on the true age (sigma=12 months) and the predicted distribution, which gives smoother gradients than MSE regression.
- **Sex Conditioning**: A learned embedding concatenated with the visual features, enabling sex-specific age estimation without separate models.
- **Test-Time Augmentation**: Predictions from the original and the horizontally flipped image are averaged for improved accuracy.

## Results

Evaluated on 1,425 held-out RSNA Bone Age validation cases (strict no-leakage protocol):

| Metric | Value |
|---|---|
| **Mean Absolute Error** | **8.81 months (0.73 years)** |
| Median Absolute Error | 7.27 months |
| Root Mean Square Error | 11.39 months |
| Pearson r | 0.963 |
| R² | 0.927 |
| Within ±6 months | 42.9% |
| Within ±12 months | 73.2% |
| Within ±24 months | 95.7% |

### Age-Stratified Performance

| Age Range | Cases | MAE |
|---|---|---|
| 0-5 years | 94 | 9.54 months |
| 5-10 years | 394 | 10.04 months |
| 10-15 years | 809 | 8.07 months |
| 15-19 years | 124 | 8.45 months |

### Sex-Stratified Performance

| Sex | Cases | MAE |
|---|---|---|
| Male | 773 | 8.26 months |
| Female | 652 | 9.47 months |

## Usage

### Installation

```bash
pip install torch transformers peft numpy opencv-python huggingface_hub
```

### Inference

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import json
import numpy as np
from transformers import SiglipVisionModel
from peft import PeftModel
from huggingface_hub import hf_hub_download, snapshot_download


# ── Model Components ──

class GlobalAveragePooling(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        return self.norm(x.mean(dim=1))


def letterbox_resize(image, size=448):
    """Resize preserving aspect ratio with zero-padding."""
    h, w = image.shape[:2]
    scale = size / max(h, w)
    nh, nw = int(h * scale), int(w * scale)
    resized = cv2.resize(image, (nw, nh))
    padded = np.zeros((size, size, 3), dtype=np.uint8)
    top, left = (size - nh) // 2, (size - nw) // 2
    padded[top:top + nh, left:left + nw] = resized
    return padded


# ── Load Model ──

def load_regressor(repo_id="rohitrajesh/chronos-msk-regressor", device="cuda"):
    """Load the complete bone age regressor from HuggingFace."""
    # Download all files
    config_path = hf_hub_download(repo_id, "config.json")
    heads_path = hf_hub_download(repo_id, "heads.pth")
    adapter_dir = snapshot_download(repo_id, allow_patterns=["adapter/*"])

    # Load config
    with open(config_path) as f:
        config = json.load(f)

    # Load MedSigLIP base model with LoRA adapter
    base_model = SiglipVisionModel.from_pretrained(
        config["model_id"], torch_dtype=torch.float32
    )
    backbone = PeftModel.from_pretrained(base_model, f"{adapter_dir}/adapter")

    hidden = config["hidden_size"]
    num_bins = config["num_bins"]

    # Build prediction heads
    pooler = GlobalAveragePooling(hidden)
    gender_embed = nn.Sequential(
        nn.Linear(1, 64), nn.GELU(), nn.Linear(64, 128)
    )
    classifier = nn.Sequential(
        nn.LayerNorm(hidden + 128),
        nn.Linear(hidden + 128, 512),
        nn.GELU(),
        nn.Dropout(0.1),
        nn.Linear(512, num_bins),
    )

    # Load trained head weights
    state = torch.load(heads_path, map_location=device)
    pooler.load_state_dict(state["pooler"])
    gender_embed.load_state_dict(state["gender_embed"])
    classifier.load_state_dict(state["classifier"])

    # Move to device and set eval mode
    backbone.to(device).eval()
    pooler.to(device).eval()
    gender_embed.to(device).eval()
    classifier.to(device).eval()

    return backbone, pooler, gender_embed, classifier, config


# ── Predict ──

def predict_bone_age(image_path, is_male, backbone, pooler, gender_embed,
                     classifier, config, device="cuda"):
    """
    Predict bone age from a hand/wrist X-ray.

    Args:
        image_path: Path to X-ray image (any common format)
        is_male: Boolean, biological sex

    Returns:
        age_months: Predicted bone age in months (float)
    """
    size = config["image_size"]

    # Load and preprocess
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Cannot read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = letterbox_resize(img, size)

    # Normalize: (pixel / 255 - 0.5) / 0.5
    norm = (img.astype(np.float32) / 255.0 - 0.5) / 0.5
    t_img = torch.from_numpy(norm.transpose(2, 0, 1)).unsqueeze(0).to(device)
    t_sex = torch.tensor([[1.0 if is_male else 0.0]], device=device)

    with torch.no_grad():
        # Forward pass
        out = backbone(pixel_values=t_img)
        vis = pooler(out.last_hidden_state)
        gen = gender_embed(t_sex)
        logits = classifier(torch.cat([vis, gen], dim=-1))

        # Expected value from probability distribution
        probs = F.softmax(logits, dim=-1)
        ages = torch.arange(probs.shape[-1], device=device, dtype=torch.float32)
        predicted_months = torch.sum(probs * ages).item()

    return predicted_months


def predict_with_tta(image_path, is_male, backbone, pooler, gender_embed,
                     classifier, config, device="cuda"):
    """
    Predict with Test-Time Augmentation (original + horizontal flip).
    This is the recommended inference method.
    """
    # Original prediction
    pred_orig = predict_bone_age(
        image_path, is_male, backbone, pooler, gender_embed,
        classifier, config, device
    )

    # Flipped prediction. predict_bone_age() takes a file path, so the flipped
    # image is written to a unique temporary directory rather than a fixed
    # /tmp path (which does not exist on Windows and is shared across runs).
    import os
    import tempfile

    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Cannot read image: {image_path}")
    img_flipped = cv2.flip(img, 1)  # Horizontal flip
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, "chronos_flip.png")  # PNG is lossless
        cv2.imwrite(tmp_path, img_flipped)
        pred_flip = predict_bone_age(
            tmp_path, is_male, backbone, pooler, gender_embed,
            classifier, config, device
        )

    # Average the two estimates
    return (pred_orig + pred_flip) / 2.0


# ── Example ──

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load model (downloads from HuggingFace on first run)
    print("Loading model...")
    backbone, pooler, gender_embed, classifier, config = load_regressor(device=device)
    print("Model loaded!")

    # Single prediction
    age = predict_bone_age(
        "hand_xray.png", is_male=True,
        backbone=backbone, pooler=pooler, gender_embed=gender_embed,
        classifier=classifier, config=config, device=device,
    )
    print(f"Predicted bone age: {age:.1f} months ({age/12:.1f} years)")

    # With TTA (recommended)
    age_tta = predict_with_tta(
        "hand_xray.png", is_male=True,
        backbone=backbone, pooler=pooler, gender_embed=gender_embed,
        classifier=classifier, config=config, device=device,
    )
    print(f"Predicted bone age (TTA): {age_tta:.1f} months ({age_tta/12:.1f} years)")
```

## Training Details

### Data

- **Dataset**: [RSNA Pediatric Bone Age](https://www.kaggle.com/datasets/kmader/rsna-bone-age) (14,236 images)
- **Split**: 12,611 train / 1,625 validation (stratified by age bins, no leakage)
- **Labels**: Bone age in months (0-228), converted to Gaussian distributions for DLDL

### Hyperparameters

| Parameter | Value |
|---|---|
| Base model | google/medsiglip-448 |
| Adaptation | DoRA (use_dora=True) |
| LoRA rank | 16 |
| LoRA alpha | 32 |
| LoRA dropout | 0.05 |
| LoRA targets | q_proj, k_proj, v_proj, out_proj, fc1, fc2 |
| DLDL sigma | 12.0 months |
| Label smoothing | 0.01 |
| Optimizer | AdamW (fused) |
| Learning rate | 5e-5 |
| LR schedule | Cosine with warmup |
| Warmup ratio | 0.05 |
| Weight decay | 0.01 |
| Batch size | 8 × 2 gradient accumulation = 16 effective |
| Epochs | 20 |
| Precision | bfloat16 |
| Gradient checkpointing | Enabled |

### Augmentation

- Rotation: ±20° with constant border padding
- Random brightness/contrast: ±10%, probability 0.3
- Letterbox resize to 448×448 (aspect ratio preserved)

### Hardware

- GPU: NVIDIA RTX 4090 (24 GB VRAM)
- Training time: ~2 hours
- Peak VRAM usage: ~18 GB

### Loss Function

KL divergence from the Gaussian target distribution to the predicted softmax distribution:

```
Target: N(age_true, σ²) discretized over 228 bins, label-smoothed
Loss:   KL(target || predicted) = Σ_k target[k] × log(target[k] / predicted[k])
```

## Saved Artifacts

| File | Description | Size |
|---|---|---|
| `config.json` | Model configuration (image size, hidden dim, num bins) | <1 KB |
| `heads.pth` | Trained pooler, gender_embed, and classifier weights | ~2 MB |
| `adapter/adapter_config.json` | LoRA/DoRA adapter configuration | <1 KB |
| `adapter/adapter_model.safetensors` | LoRA adapter weights | ~5 MB |

The base MedSigLIP-448 model (~1.2 GB) is downloaded automatically from `google/medsiglip-448` on first use.

## Limitations

- Trained primarily on the RSNA dataset, which lacks explicit racial/ethnic metadata. Cross-population generalization should be validated before clinical use.
- The model assumes a standard left hand PA radiograph. Performance may degrade on non-standard views, right hand images, or heavily rotated inputs.
- This is a research tool, **not a medical device**. All predictions require interpretation by qualified clinical professionals.

## Ethical Considerations

Bone age assessment has significant implications in pediatric endocrinology (growth disorder diagnosis), forensic anthropology, and legal age determination for asylum cases.
Automated systems must be transparent about their accuracy limitations and should never be used as the sole basis for consequential decisions.

## Citation

```bibtex
@misc{chronosmsk2025,
  title={Chronos-MSK: Bias-Aware Skeletal Maturity Assessment at the Edge},
  author={Rohit Rajesh},
  year={2025},
  note={Google HAI-DEF Competition Submission}
}
```

## Links

- **Full Pipeline**: [github.com/04RR/chronos-msk](https://github.com/04RR/chronos-msk)
- **Model Weights (Kaggle)**: [kaggle.com/datasets/rohitrajesh/model-checkpoints](https://www.kaggle.com/datasets/rohitrajesh/model-checkpoints/)
- **Base Model**: [google/medsiglip-448](https://huggingface.co/google/medsiglip-448)
- **Dataset**: [RSNA Bone Age](https://www.kaggle.com/datasets/kmader/rsna-bone-age)