Chronos-MSK Demographic Projector
A SOTA metric learning model that projects MedSigLIP-448 embeddings (1152-D) into a compact 256-D space optimized for age-aware, demographic-stratified bone age retrieval. Enables "Visual Twin" matching: finding atlas cases with similar skeletal maturity from biologically relevant populations.
Part of the Chronos-MSK multi-agent bone age assessment system, built for the Google HAI-DEF competition.
Model Description
Bone age varies by sex and ethnicity, yet most retrieval systems use a single embedding space that conflates demographic variation with skeletal maturity differences. This projector creates a structured embedding space where:
- Proximity correlates with age similarity: closer embeddings mean closer bone ages
- Demographic structure is preserved: separate clusters for each sex/race combination
- Distance is a calibrated quality signal: retrieval distance correlates with prediction error (r = +0.26)
The model was trained on the USC Digital Hand Atlas, which provides explicit demographic metadata across four racial groups (Asian, Black, Caucasian, Hispanic) and both sexes.
Architecture
```
Input: 1152-D MedSigLIP-448 pooled embedding (L2-normalized)
│
├── Trunk (main pathway):
│     Linear(1152 → 768) → LayerNorm → GELU → Dropout(0.1)
│     → Linear(768 → 768) → LayerNorm → GELU → Dropout(0.1)
│
├── Skip (residual connection):
│     Linear(1152 → 768)
│
├── Add: Trunk output + Skip output → 768-D hidden representation
│
├── Projection Head:
│     Linear(768 → 256) → L2 Normalize
│     Output: 256-D unit-normalized embedding for FAISS retrieval
│
├── Age Head (auxiliary, multi-task):
│     Linear(768 → 256) → GELU → Linear(256 → 1)
│     Output: Predicted age in months (training signal only)
│
└── Proxies (learnable):
      8 × 256-D vectors (one per demographic class)
      Used by Proxy-NCA loss during training
```
Training: Multi-Objective SOTA Approach
The projector was trained with four complementary loss functions:
1. Multi-Similarity Loss (Wang et al., CVPR 2019)
- Mines ALL informative positive and negative pairs in each batch
- Pairs are defined by both demographic class AND age proximity
- Positive: same class, age gap < threshold
- Negative: different class OR same class with large age gap
- Far more efficient than random triplet mining
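The pair definition above can be sketched as boolean masks over a batch. The class ids and ages below are hypothetical placeholders, and the default thresholds are the final curriculum values (6 / 18 months):

```python
import numpy as np

def mine_pairs(classes, ages, pos_thresh=6.0, neg_thresh=18.0):
    """Boolean pair masks as defined above:
    positive = same demographic class AND age gap < pos_thresh;
    negative = different class OR age gap > neg_thresh."""
    same_class = classes[:, None] == classes[None, :]
    age_gap = np.abs(ages[:, None] - ages[None, :])
    pos = same_class & (age_gap < pos_thresh)
    neg = (~same_class) | (age_gap > neg_thresh)
    np.fill_diagonal(pos, False)  # a sample is never its own pair
    np.fill_diagonal(neg, False)
    return pos, neg

classes = np.array([0, 0, 0, 1])             # demographic class ids (hypothetical)
ages = np.array([60.0, 63.0, 120.0, 61.0])   # ages in months (hypothetical)
pos, neg = mine_pairs(classes, ages)
print(pos[0, 1], neg[0, 2], neg[0, 3])  # True True True
```

Samples 0 and 1 share a class with a 3-month gap (positive); sample 2 shares the class but is 60 months away (negative); sample 3 is a different class (negative).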
2. Proxy-NCA Loss (Movshovitz-Attias et al., ICCV 2017)
- Learns a proxy centroid for each of 8 demographic classes
- O(B×C) complexity instead of O(B²): faster and more stable gradients
- Provides clean demographic cluster structure
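A minimal PyTorch sketch of a Proxy-NCA-style objective, assuming squared-L2 distances to normalized proxies and excluding the positive proxy from the denominator (published variants differ on both points); `emb` and `labels` here are random placeholders, not the trained weights:

```python
import torch
import torch.nn.functional as F

def proxy_nca_loss(emb, labels, proxies):
    """Attract each embedding to its class proxy and repel it from the
    other proxies: O(B*C) distance computations instead of O(B^2)."""
    emb = F.normalize(emb, dim=1)
    proxies = F.normalize(proxies, dim=1)
    logits = -torch.cdist(emb, proxies) ** 2      # [B, C] negative squared L2
    pos = logits.gather(1, labels.unsqueeze(1)).squeeze(1)
    mask = torch.ones_like(logits, dtype=torch.bool)
    mask.scatter_(1, labels.unsqueeze(1), False)  # drop the positive proxy
    neg = torch.logsumexp(logits.masked_fill(~mask, float("-inf")), dim=1)
    return (neg - pos).mean()

proxies = torch.randn(8, 256) * 0.1  # mirrors the 8 learnable proxies above
emb = torch.randn(4, 256)
labels = torch.tensor([0, 1, 2, 3])
loss = proxy_nca_loss(emb, labels, proxies)
```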
3. Age-Continuous Soft Contrastive Loss
- Instead of binary same/different labels, weights pair attraction by age distance
- Pair weight: w_ij = exp(-|age_i - age_j| / τ), where τ = 12 months
- Creates smooth age gradients within each demographic cluster
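The weighting formula in one line, with τ = 12 months as above:

```python
import numpy as np

TAU = 12.0  # months

def pair_weight(age_i, age_j, tau=TAU):
    """Age-continuous attraction weight: 1.0 for identical ages,
    decaying smoothly as the age gap grows."""
    return np.exp(-abs(age_i - age_j) / tau)

print(pair_weight(60, 60))            # 1.0
print(round(pair_weight(60, 72), 3))  # exp(-1) ~ 0.368
```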
4. Auxiliary Age Regression (SmoothL1 Loss)
- Multi-task signal from the age prediction head
- Ensures the hidden representation encodes age information
- Does not affect the projection head output at inference time
Curriculum Learning:
- Age thresholds for positive/negative pair mining progressively tighten over training
- Start: positives within 36 months, negatives beyond 108 months (easy)
- End: positives within 6 months, negatives beyond 18 months (hard)
- Transition: cosine schedule over 100 epochs
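A cosine interpolation consistent with these endpoints can be sketched as follows; at epoch 50 it yields 21 / 63 months, matching the Medium phase in the curriculum schedule table further down:

```python
import math

def threshold(epoch, start, end, total=100):
    """Cosine interpolation of a mining threshold from `start` to `end`
    over `total` epochs; held constant at `end` afterwards."""
    t = min(epoch, total) / total
    return end + (start - end) * 0.5 * (1.0 + math.cos(math.pi * t))

for e in (0, 50, 100):
    print(e, threshold(e, 36, 6), threshold(e, 108, 18))
```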
Performance
| Metric | Value |
|---|---|
| Retrieval MAE (Archivist-only prediction) | 19.13 months |
| Distance-Error Correlation | r = +0.26 (p < 1e-23) |
| Confidence Calibration | r = +0.20 (monotonic across tiers) |
| Recall@1 (same class, age within 12m) | 10.8% |
| Recall@5 | 21.6% |
| Retrieval Age MAE (val set) | 15.8 months |
Confidence Calibration
When used in the Chronos-MSK pipeline, retrieval distance enables calibrated confidence tiers:
| Tier | Cases | Pipeline MAE | Within Β±12m |
|---|---|---|---|
| HIGH | 535 (37.5%) | 6.96 months | 83.9% |
| MODERATE | 644 (45.2%) | 9.69 months | 67.2% |
| LOW | 245 (17.2%) | 10.71 months | 64.5% |
HIGH-confidence predictions achieve 6.96-month MAE; this calibration is the projector's primary value in the pipeline.
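In code, tier assignment reduces to distance thresholds. The cutoff values below are hypothetical placeholders; the pipeline's actual calibration thresholds are not published here:

```python
def confidence_tier(distance, t_high=0.8, t_moderate=1.1):
    """Map a retrieval distance to a confidence tier.
    t_high / t_moderate are illustrative placeholders only."""
    if distance < t_high:
        return "HIGH"
    if distance < t_moderate:
        return "MODERATE"
    return "LOW"

print(confidence_tier(0.5))  # HIGH
```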
Usage
Installation
```bash
pip install torch numpy faiss-cpu huggingface_hub
```
Load the Projector
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from huggingface_hub import hf_hub_download


class DemographicProjectorSOTA(nn.Module):
    """
    Projects 1152-D MedSigLIP embeddings to 256-D metric space.
    Must match the training architecture exactly.
    """

    def __init__(self, input_dim=1152, output_dim=256,
                 hidden_dim=768, num_classes=8):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
        )
        self.skip = nn.Linear(input_dim, hidden_dim)
        self.projector = nn.Sequential(nn.Linear(hidden_dim, output_dim))
        self.age_head = nn.Sequential(
            nn.Linear(hidden_dim, 256),
            nn.GELU(),
            nn.Linear(256, 1),
        )
        self.proxies = nn.Parameter(torch.randn(num_classes, output_dim) * 0.1)

    def forward(self, x, return_age=False):
        h = self.trunk(x) + self.skip(x)
        emb = self.projector(h)
        emb = F.normalize(emb, p=2, dim=1)
        if return_age:
            age_pred = self.age_head(h).squeeze(-1)
            return emb, age_pred
        return emb


def load_projector(repo_id="rohitrajesh/chronos-msk-projector", device="cpu"):
    """Load the trained projector from HuggingFace."""
    weights_path = hf_hub_download(repo_id, "projector_sota.pth")
    model = DemographicProjectorSOTA(
        input_dim=1152,
        output_dim=256,
        hidden_dim=768,
        num_classes=8,
    )
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.to(device).eval()
    return model


# Load
device = "cuda" if torch.cuda.is_available() else "cpu"
projector = load_projector(device=device)
print("Projector loaded!")
```
Project an Embedding
```python
# Assume you have a 1152-D MedSigLIP embedding (from any image)
raw_embedding = np.random.randn(1152).astype(np.float32)  # Replace with a real embedding

with torch.no_grad():
    tensor = torch.from_numpy(raw_embedding).unsqueeze(0).to(device)
    projected = projector(tensor)  # Shape: (1, 256), L2-normalized

print(f"Input shape: {tensor.shape}")      # [1, 1152]
print(f"Output shape: {projected.shape}")  # [1, 256]
print(f"L2 norm: {projected.norm():.4f}")  # ~1.0
```
End-to-End: Image → MedSigLIP → Projector → FAISS Search
```python
import cv2
import faiss
from transformers import SiglipVisionModel, AutoProcessor


def letterbox_resize(image, size=448):
    """Resize the longest side to `size`, then pad to a square canvas."""
    h, w = image.shape[:2]
    scale = size / max(h, w)
    nh, nw = int(h * scale), int(w * scale)
    resized = cv2.resize(image, (nw, nh))
    padded = np.zeros((size, size, 3), dtype=np.uint8)
    top, left = (size - nh) // 2, (size - nw) // 2
    padded[top:top + nh, left:left + nw] = resized
    return padded


# Step 1: Load MedSigLIP encoder
encoder = SiglipVisionModel.from_pretrained("google/medsiglip-448").to(device).eval()
processor = AutoProcessor.from_pretrained("google/medsiglip-448")

# Step 2: Embed an X-ray image
img = cv2.imread("hand_xray.png")
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_resized = letterbox_resize(img_rgb, 448)
inputs = processor(images=img_resized, return_tensors="pt").to(device)
with torch.no_grad():
    raw_emb = encoder(**inputs).pooler_output  # [1, 1152]
    raw_emb = raw_emb / raw_emb.norm(p=2, dim=-1, keepdim=True)

# Step 3: Project to 256-D retrieval space
with torch.no_grad():
    projected = projector(raw_emb)  # [1, 256]

# Step 4: Search a FAISS index
# (Build from atlas embeddings, or download pre-built from Kaggle)
# index = faiss.read_index("Male_Caucasian.index")
# distances, indices = index.search(projected.cpu().numpy(), k=5)
# print(f"Top 5 matches: {indices[0]}, distances: {distances[0]}")
```
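If you want to sanity-check retrieval without a pre-built index, exact search over a small in-memory atlas takes a few lines of numpy; FAISS's IndexFlatL2 computes the same squared-L2 ranking, just faster. The atlas vectors below are random placeholders:

```python
import numpy as np

def l2_topk(query, atlas, k=5):
    """Exact L2 nearest neighbours, equivalent in ranking to
    faiss.IndexFlatL2.search (which also returns squared distances)."""
    d2 = ((atlas - query) ** 2).sum(axis=1)
    idx = np.argsort(d2)[:k]
    return d2[idx], idx

rng = np.random.default_rng(0)
atlas = rng.standard_normal((200, 256)).astype(np.float32)
atlas /= np.linalg.norm(atlas, axis=1, keepdims=True)  # unit vectors, like projector output
query = atlas[7] + 0.01 * rng.standard_normal(256)     # slightly perturbed atlas case
dists, idx = l2_topk(query, atlas, k=5)
print(idx[0])  # 7 (the perturbed source is the nearest match)
```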
Get Age Prediction from Auxiliary Head
```python
# The projector also has an auxiliary age prediction head (used during training)
with torch.no_grad():
    projected_emb, age_pred = projector(raw_emb, return_age=True)

print(f"Projected embedding: {projected_emb.shape}")   # [1, 256]
print(f"Predicted age: {age_pred.item():.0f} months")  # Rough estimate

# Note: The age head is auxiliary; use the main regressor for accurate predictions
```
Batch Processing
```python
# Project a batch of embeddings efficiently
batch_embeddings = np.random.randn(100, 1152).astype(np.float32)

with torch.no_grad():
    batch_tensor = torch.from_numpy(batch_embeddings).to(device)
    batch_projected = projector(batch_tensor)  # [100, 256]

print(f"Batch projected: {batch_projected.shape}")
print(f"All unit-normalized: {batch_projected.norm(dim=1).mean():.4f}")
```
Training Details
Data
- Dataset: USC Digital Hand Atlas (1,390 images)
- Demographics: approximately evenly distributed across 8 sex/race partitions
| Partition | Count |
|---|---|
| Male Asian | 167 |
| Male Black | 184 |
| Male Caucasian | 167 |
| Male Hispanic | 182 |
| Female Asian | 167 |
| Female Black | 174 |
| Female Caucasian | 166 |
| Female Hispanic | 183 |
- Age Range: 0-18 years (0-216 months)
- Split: 90% train (1,251), 10% validation (139)
- Embeddings: Pre-computed with frozen MedSigLIP-448 (cached to disk)
Hyperparameters
| Parameter | Value |
|---|---|
| Input dimension | 1152 |
| Output dimension | 256 |
| Hidden dimension | 768 |
| Trainable parameters | 2,760,705 |
| Optimizer | AdamW |
| Learning rate | 1e-3 |
| Weight decay | 1e-4 |
| LR schedule | Cosine with linear warmup (20 epochs) |
| Minimum LR | 1e-6 |
| Batch size | 256 |
| Sampling | Age-balanced weighted random sampling |
| Epochs | 146 (early stopping, patience=50) |
| Gradient clipping | max norm = 1.0 |
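The trainable-parameter figure can be verified by summing layer shapes from the architecture above (each Linear contributes weights plus biases; each LayerNorm a weight and bias per feature):

```python
def linear(n_in, n_out):
    """Parameters of a Linear layer: weight matrix plus bias vector."""
    return n_in * n_out + n_out

trunk = linear(1152, 768) + 2 * 768 + linear(768, 768) + 2 * 768  # two Linear+LayerNorm blocks
skip = linear(1152, 768)
projection = linear(768, 256)
age_head = linear(768, 256) + linear(256, 1)
proxies = 8 * 256  # learnable class proxies

total = trunk + skip + projection + age_head + proxies
print(total)  # 2760705
```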
Loss Weights
| Loss | Weight |
|---|---|
| Multi-Similarity | 1.0 |
| Proxy-NCA | 1.0 |
| Soft Contrastive | 0.5 |
| Age Regression | 0.5 |
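The combined objective is a plain weighted sum of the four losses; the per-loss values in the example below are hypothetical:

```python
LOSS_WEIGHTS = {"multi_similarity": 1.0, "proxy_nca": 1.0,
                "soft_contrastive": 0.5, "age_regression": 0.5}

def total_loss(losses, weights=LOSS_WEIGHTS):
    """Weighted sum of the four training objectives."""
    return sum(weights[name] * value for name, value in losses.items())

print(total_loss({"multi_similarity": 2.0, "proxy_nca": 1.0,
                  "soft_contrastive": 4.0, "age_regression": 2.0}))  # 6.0
```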
Curriculum Schedule
| Phase | Epoch Range | Positive Threshold | Negative Threshold |
|---|---|---|---|
| Easy | 0-25 | 36 months | 108 months |
| Medium | 25-75 | 21 months | 63 months |
| Hard | 75-100 | 9 months | 27 months |
| Final | 100+ | 6 months | 18 months |
Training Trajectory
```
Epoch   1: Loss=56.07, R@1=0.036, Age MAE=22.2m
Epoch  50: Loss=11.66, R@1=0.050, Age MAE=18.9m
Epoch 100: Loss= 9.01, R@1=0.108, Age MAE=17.5m
Epoch 146: Early stopping (best R@1=0.108, Age MAE=15.8m)
```
Pre-Built FAISS Indices
Pre-built indices for all 8 demographic partitions are available on Kaggle:
Download: kaggle.com/datasets/rohitrajesh/indices-projected-256d
Each partition contains:
- `{Sex}_{Race}.index`: FAISS IndexFlatL2 with 256-D projected vectors
- `{Sex}_{Race}_meta.json`: aligned metadata with `age_months`, `age_years`, and image paths
Design Rationale
Why Not Search in Raw 1152-D Space?
Raw MedSigLIP embeddings encode general medical image features, not age-specific information. The projector:
- Reduces dimensionality (1152 → 256) for faster FAISS search
- Creates age-aware clusters through contrastive training
- Produces distance that correlates with retrieval quality (r = +0.26)
- Enables calibrated confidence estimation
Why Demographic Partitioning?
Skeletal maturation rates differ by sex and ethnicity. A 12-year-old male Asian hand looks different from a 12-year-old female Hispanic hand. Partitioned indices ensure "Visual Twins" come from biologically relevant reference populations, addressing the Caucasian bias inherent in the traditional Greulich-Pyle atlas.
Why Not Use Retrieval for Prediction?
At 19.13-month MAE, the retrieval system is too noisy to improve the regressor (8.81-month MAE) numerically. Its value is:
- Explainability: show clinicians similar atlas cases
- Confidence: distance-based calibration identifies when predictions are reliable
- Fairness auditing: demographic partitions make bias visible
Limitations
- The atlas contains only 1,390 reference cases. Retrieval quality scales with atlas size; larger atlases would improve both MAE and confidence calibration.
- Training was limited by the small dataset size (1,251 training samples across 8 classes). This constrains the complexity of learnable representations.
- The projector was trained on the USC Digital Hand Atlas, which may not fully represent all ethnic populations globally.
Citation
```bibtex
@misc{chronosmsk2025,
  title={Chronos-MSK: Bias-Aware Skeletal Maturity Assessment at the Edge},
  author={Rohit Rajesh},
  year={2025},
  note={Google HAI-DEF Competition Submission}
}
```
References
- Wang, X., et al. Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning. CVPR, 2019.
- Movshovitz-Attias, Y., et al. No Fuss Distance Metric Learning Using Proxies. ICCV, 2017.
Links
- Full Pipeline: github.com/04RR/chronos-msk
- Pre-built Indices: kaggle.com/datasets/rohitrajesh/indices-projected-256d
- All Model Weights: kaggle.com/datasets/rohitrajesh/model-checkpoints
- Base Model: google/medsiglip-448
- Digital Hand Atlas: ipilab.usc.edu/research/baaweb