Chronos-MSK Demographic Projector

A SOTA metric learning model that projects MedSigLIP-448 embeddings (1152-D) into a compact 256-D space optimized for age-aware, demographic-stratified bone age retrieval. Enables "Visual Twin" matching: finding atlas cases with similar skeletal maturity from biologically relevant populations.

Part of the Chronos-MSK multi-agent bone age assessment system, built for the Google HAI-DEF competition.

Model Description

Bone age varies by sex and ethnicity, yet most retrieval systems use a single embedding space that conflates demographic variation with skeletal maturity differences. This projector creates a structured embedding space where:

  1. Proximity correlates with age similarity: closer embeddings mean closer bone ages
  2. Demographic structure is preserved: separate clusters for each sex/race combination
  3. Distance is a calibrated quality signal: retrieval distance correlates with prediction error (r = +0.26)

The model was trained on the USC Digital Hand Atlas, which provides explicit demographic metadata across four racial groups (Asian, Black, Caucasian, Hispanic) and both sexes.

Architecture

Input: 1152-D MedSigLIP-448 pooled embedding (L2-normalized)
  │
  ├── Trunk (main pathway):
  │     Linear(1152 → 768) → LayerNorm → GELU → Dropout(0.1)
  │     → Linear(768 → 768) → LayerNorm → GELU → Dropout(0.1)
  │
  ├── Skip (residual connection):
  │     Linear(1152 → 768)
  │
  └── Add: Trunk output + Skip output → 768-D hidden representation
       │
       ├── Projection Head:
       │     Linear(768 → 256) → L2 Normalize
       │     Output: 256-D unit-normalized embedding for FAISS retrieval
       │
       ├── Age Head (auxiliary, multi-task):
       │     Linear(768 → 256) → GELU → Linear(256 → 1)
       │     Output: Predicted age in months (training signal only)
       │
       └── Proxies (learnable):
             8 × 256-D vectors (one per demographic class)
             Used by Proxy-NCA loss during training

Training: Multi-Objective SOTA Approach

The projector was trained with four complementary loss functions:

1. Multi-Similarity Loss (Wang et al., CVPR 2019)

  • Mines ALL informative positive and negative pairs in each batch
  • Pairs are defined by both demographic class AND age proximity
  • Positive: same class, age gap < threshold
  • Negative: different class OR same class with large age gap
  • Far more efficient than random triplet mining
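As a sketch of how this mining scheme combines with the Multi-Similarity objective, the following minimal implementation uses the class-and-age pair definition above. The α, β, λ values are illustrative defaults from the paper, not the trained configuration:

```python
import torch

def multi_similarity_loss(emb, cls, age, pos_gap=6.0, neg_gap=18.0,
                          alpha=2.0, beta=50.0, lam=0.5):
    """Minimal Multi-Similarity loss sketch (Wang et al., 2019).

    A pair is positive when both samples share a demographic class AND
    their age gap is below `pos_gap` months; negative when classes
    differ OR the gap exceeds `neg_gap`. Hyperparameters here are
    illustrative, not the trained configuration.
    """
    sim = emb @ emb.t()                                    # cosine sims (emb is L2-normed)
    age_gap = (age[:, None] - age[None, :]).abs()
    same_cls = cls[:, None] == cls[None, :]
    eye = torch.eye(len(emb), dtype=torch.bool)
    pos = same_cls & (age_gap < pos_gap) & ~eye            # positives: same class, small age gap
    neg = ~same_cls | (age_gap > neg_gap)                  # negatives: other class, or large gap
    loss = emb.new_zeros(())
    for i in range(len(emb)):
        p, n = sim[i][pos[i]], sim[i][neg[i]]
        if p.numel():                                      # pull positives above margin lam
            loss = loss + torch.log1p(torch.exp(-alpha * (p - lam)).sum()) / alpha
        if n.numel():                                      # push negatives below margin lam
            loss = loss + torch.log1p(torch.exp(beta * (n - lam)).sum()) / beta
    return loss / len(emb)
```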

2. Proxy-NCA Loss (Movshovitz-Attias et al., ICCV 2017)

  • Learns a proxy centroid for each of 8 demographic classes
  • O(B×C) complexity instead of O(B²): faster and more stable gradients
  • Provides clean demographic cluster structure
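A minimal Proxy-NCA sketch with the 8 learnable proxies, illustrating the O(B×C) structure: samples are compared only against class proxies, never against each other. The exact form used in training is an assumption:

```python
import torch
import torch.nn.functional as F

def proxy_nca_loss(emb, labels, proxies):
    """Minimal Proxy-NCA sketch (Movshovitz-Attias et al., 2017).

    Each sample is pulled toward its own class proxy and pushed away
    from the other proxies; only B*C sample-proxy distances are ever
    computed. Illustrative, not the exact training implementation.
    """
    proxies = F.normalize(proxies, dim=1)            # keep proxies on the unit sphere
    d = torch.cdist(emb, proxies) ** 2               # squared L2 to each proxy, [B, C]
    pos = d.gather(1, labels[:, None]).squeeze(1)    # distance to own proxy
    neg = d.scatter(1, labels[:, None], float("inf"))  # exclude own proxy from denominator
    # -log( exp(-d_pos) / sum_z exp(-d_z) ) = d_pos + logsumexp(-d_neg)
    return (pos + torch.logsumexp(-neg, dim=1)).mean()
```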

3. Age-Continuous Soft Contrastive Loss

  • Instead of binary same/different labels, weights pair attraction by age distance
  • Pair weight: w_ij = exp(-|age_i - age_j| / σ), where σ = 12 months
  • Creates smooth age gradients within each demographic cluster
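The age-weighted attraction can be sketched as follows, with σ = 12 months as stated above; the exact loss form used in training may differ:

```python
import torch

def soft_contrastive_loss(emb, age, sigma=12.0):
    """Age-continuous soft contrastive sketch: attraction between every
    pair is weighted by w_ij = exp(-|age_i - age_j| / sigma), so pairs
    close in age are pulled together more strongly. Illustrative form.
    """
    sim = emb @ emb.t()                                     # cosine similarity (L2-normed inputs)
    w = torch.exp(-(age[:, None] - age[None, :]).abs() / sigma)
    off_diag = ~torch.eye(len(emb), dtype=torch.bool)       # ignore self-pairs
    # penalize dissimilarity in proportion to the age weight
    return (w[off_diag] * (1.0 - sim[off_diag])).mean()
```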

4. Auxiliary Age Regression (SmoothL1 Loss)

  • Multi-task signal from the age prediction head
  • Ensures the hidden representation encodes age information
  • Does not affect the projection head output at inference time

Curriculum Learning:

  • Age thresholds for positive/negative pair mining progressively tighten over training
  • Start: positives within 36 months, negatives beyond 108 months (easy)
  • End: positives within 6 months, negatives beyond 18 months (hard)
  • Transition: cosine schedule over 100 epochs
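One possible form of this cosine schedule (a sketch; the exact interpolation used in training is not published here):

```python
import math

def curriculum_thresholds(epoch, total=100,
                          pos_start=36.0, pos_end=6.0,
                          neg_start=108.0, neg_end=18.0):
    """Cosine tightening of the pair-mining age thresholds, following
    the start/end values above (36 -> 6 months positive, 108 -> 18
    months negative over 100 epochs). The interpolation form itself is
    an assumption.
    """
    t = min(epoch / total, 1.0)
    frac = 0.5 * (1.0 + math.cos(math.pi * t))   # decays 1 -> 0 over training
    pos = pos_end + (pos_start - pos_end) * frac
    neg = neg_end + (neg_start - neg_end) * frac
    return pos, neg
```

At epoch 50 this interpolation gives roughly 21/63-month thresholds, consistent with the Medium phase of the curriculum schedule under Training Details.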

Performance

Metric                                      Value
Retrieval MAE (Archivist-only prediction)   19.13 months
Distance ↔ Error Correlation                r = +0.26 (p < 1e-23)
Confidence Calibration                      r = +0.20 (monotonic across tiers)
Recall@1 (same class, age within 12m)       10.8%
Recall@5                                    21.6%
Retrieval Age MAE (val set)                 15.8 months

Confidence Calibration

When used in the Chronos-MSK pipeline, retrieval distance enables calibrated confidence tiers:

Tier       Cases         Pipeline MAE   Within ±12m
HIGH       535 (37.5%)   6.96 months    83.9%
MODERATE   644 (45.2%)   9.69 months    67.2%
LOW        245 (17.2%)   10.71 months   64.5%

HIGH-confidence predictions achieve 6.96-month MAE; this calibration is the projector's primary value in the pipeline.
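A tiering helper might look like the following. The two distance cut-offs are hypothetical placeholders (the pipeline's actual thresholds are not given here) and should be calibrated on a validation set so that tier MAE stays monotonic, as in the table above:

```python
def confidence_tier(distance, t_high=0.35, t_moderate=0.55):
    """Map a retrieval distance to a confidence tier.

    `t_high` and `t_moderate` are hypothetical placeholder cut-offs,
    not the pipeline's published thresholds; calibrate them on held-out
    data before use.
    """
    if distance < t_high:
        return "HIGH"
    if distance < t_moderate:
        return "MODERATE"
    return "LOW"
```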

Usage

Installation

pip install torch numpy faiss-cpu huggingface_hub transformers opencv-python

Load the Projector

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from huggingface_hub import hf_hub_download


class DemographicProjectorSOTA(nn.Module):
    """
    Projects 1152-D MedSigLIP embeddings to 256-D metric space.
    Must match the training architecture exactly.
    """
    def __init__(self, input_dim=1152, output_dim=256,
                 hidden_dim=768, num_classes=8):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
        )
        self.skip = nn.Linear(input_dim, hidden_dim)
        self.projector = nn.Sequential(nn.Linear(hidden_dim, output_dim))
        self.age_head = nn.Sequential(
            nn.Linear(hidden_dim, 256),
            nn.GELU(),
            nn.Linear(256, 1),
        )
        self.proxies = nn.Parameter(torch.randn(num_classes, output_dim) * 0.1)

    def forward(self, x, return_age=False):
        h = self.trunk(x) + self.skip(x)
        emb = self.projector(h)
        emb = F.normalize(emb, p=2, dim=1)
        if return_age:
            age_pred = self.age_head(h).squeeze(-1)
            return emb, age_pred
        return emb


def load_projector(repo_id="rohitrajesh/chronos-msk-projector", device="cpu"):
    """Load the trained projector from HuggingFace."""
    weights_path = hf_hub_download(repo_id, "projector_sota.pth")
    
    model = DemographicProjectorSOTA(
        input_dim=1152,
        output_dim=256,
        hidden_dim=768,
        num_classes=8,
    )
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.to(device).eval()
    
    return model


# Load
device = "cuda" if torch.cuda.is_available() else "cpu"
projector = load_projector(device=device)
print("Projector loaded!")

Project an Embedding

# Assume you have a 1152-D MedSigLIP embedding (from any image)
raw_embedding = np.random.randn(1152).astype(np.float32)  # Replace with real embedding
raw_embedding /= np.linalg.norm(raw_embedding)            # Projector expects an L2-normalized input

with torch.no_grad():
    tensor = torch.from_numpy(raw_embedding).unsqueeze(0).to(device)
    projected = projector(tensor)  # Shape: (1, 256), L2-normalized

print(f"Input shape:  {tensor.shape}")       # [1, 1152]
print(f"Output shape: {projected.shape}")     # [1, 256]
print(f"L2 norm:      {projected.norm():.4f}")  # ~1.0

End-to-End: Image → MedSigLIP → Projector → FAISS Search

import cv2
import faiss
from transformers import SiglipVisionModel, AutoProcessor


def letterbox_resize(image, size=448):
    h, w = image.shape[:2]
    scale = size / max(h, w)
    nh, nw = int(h * scale), int(w * scale)
    resized = cv2.resize(image, (nw, nh))
    padded = np.zeros((size, size, 3), dtype=np.uint8)
    top, left = (size - nh) // 2, (size - nw) // 2
    padded[top:top + nh, left:left + nw] = resized
    return padded


# Step 1: Load MedSigLIP encoder
encoder = SiglipVisionModel.from_pretrained("google/medsiglip-448").to(device).eval()
processor = AutoProcessor.from_pretrained("google/medsiglip-448")

# Step 2: Embed an X-ray image
img = cv2.imread("hand_xray.png")
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_resized = letterbox_resize(img_rgb, 448)

inputs = processor(images=img_resized, return_tensors="pt").to(device)
with torch.no_grad():
    raw_emb = encoder(**inputs).pooler_output                # [1, 1152]
    raw_emb = raw_emb / raw_emb.norm(p=2, dim=-1, keepdim=True)

# Step 3: Project to 256-D retrieval space
with torch.no_grad():
    projected = projector(raw_emb)                           # [1, 256]

# Step 4: Search a FAISS index
# (Build from atlas embeddings, or download pre-built from Kaggle)
# index = faiss.read_index("Male_Caucasian.index")
# distances, indices = index.search(projected.cpu().numpy(), k=5)
# print(f"Top 5 matches: {indices[0]}, distances: {distances[0]}")

Get Age Prediction from Auxiliary Head

# The projector also has an auxiliary age prediction head (used during training)
with torch.no_grad():
    projected_emb, age_pred = projector(raw_emb, return_age=True)

print(f"Projected embedding: {projected_emb.shape}")  # [1, 256]
print(f"Predicted age: {age_pred.item():.0f} months")  # Rough estimate
# Note: The age head is auxiliary; use the main regressor for accurate predictions

Batch Processing

# Project a batch of embeddings efficiently
batch_embeddings = np.random.randn(100, 1152).astype(np.float32)

with torch.no_grad():
    batch_tensor = torch.from_numpy(batch_embeddings).to(device)
    batch_projected = projector(batch_tensor)  # [100, 256]

print(f"Batch projected: {batch_projected.shape}")
print(f"All unit-normalized: {batch_projected.norm(dim=1).mean():.4f}")

Training Details

Data

  • Dataset: USC Digital Hand Atlas (1,390 images)
  • Demographics: Roughly balanced across the 8 sex/race partitions:

Partition          Count
Male Asian         167
Male Black         184
Male Caucasian     167
Male Hispanic      182
Female Asian       167
Female Black       174
Female Caucasian   166
Female Hispanic    183
  • Age Range: 0-18 years (0-216 months)
  • Split: 90% train (1,251), 10% validation (139)
  • Embeddings: Pre-computed with frozen MedSigLIP-448 (cached to disk)

Hyperparameters

Parameter              Value
Input dimension        1152
Output dimension       256
Hidden dimension       768
Trainable parameters   2,760,705
Optimizer              AdamW
Learning rate          1e-3
Weight decay           1e-4
LR schedule            Cosine with linear warmup (20 epochs)
Minimum LR             1e-6
Batch size             256
Sampling               Age-balanced weighted random sampling
Epochs                 146 (early stopping, patience=50)
Gradient clipping      max norm = 1.0
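The age-balanced sampling row can be sketched with PyTorch's WeightedRandomSampler. The 12-month binning below is an assumption, since only "age-balanced weighted random sampling" is stated:

```python
import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler

# Stand-in for the real per-sample training ages (months)
rng = np.random.default_rng(0)
ages_months = rng.uniform(0, 216, size=1251)

# Inverse-frequency weights over 12-month age bins (binning choice is
# an assumption, not the documented configuration)
bins = np.minimum((ages_months // 12).astype(int), 17)
counts = np.maximum(np.bincount(bins, minlength=18), 1)  # guard empty bins
weights = 1.0 / counts[bins]

# Rare ages are drawn more often, balancing the age distribution per epoch
sampler = WeightedRandomSampler(torch.as_tensor(weights, dtype=torch.double),
                                num_samples=len(weights), replacement=True)
```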

Loss Weights

Loss               Weight
Multi-Similarity   1.0
Proxy-NCA          1.0
Soft Contrastive   0.5
Age Regression     0.5
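Putting the weights together, the total objective is a weighted sum. This sketch takes the three metric losses as precomputed scalars and computes only the SmoothL1 age term:

```python
import torch
import torch.nn.functional as F

def combined_loss(ms, proxy_nca, soft_contrastive, age_pred, age_true,
                  w_ms=1.0, w_proxy=1.0, w_soft=0.5, w_age=0.5):
    """Weighted sum of the four objectives, using the weights from the
    table above. The metric losses are passed in already computed;
    only the auxiliary SmoothL1 age term is evaluated here.
    """
    age_loss = F.smooth_l1_loss(age_pred, age_true)
    return w_ms * ms + w_proxy * proxy_nca + w_soft * soft_contrastive + w_age * age_loss
```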

Curriculum Schedule

Phase    Epoch Range   Positive Threshold   Negative Threshold
Easy     0-25          36 months            108 months
Medium   25-75         21 months            63 months
Hard     75-100        9 months             27 months
Final    100+          6 months             18 months

Training Trajectory

Epoch   1: Loss=56.07, R@1=0.036, Age MAE=22.2m
Epoch  50: Loss=11.66, R@1=0.050, Age MAE=18.9m
Epoch 100: Loss= 9.01, R@1=0.108, Age MAE=17.5m
Epoch 146: Early stopping (best R@1=0.108, Age MAE=15.8m)

Pre-Built FAISS Indices

Pre-built indices for all 8 demographic partitions are available on Kaggle:

Download: kaggle.com/datasets/rohitrajesh/indices-projected-256d

Each partition contains:

  • {Sex}_{Race}.index: FAISS IndexFlatL2 with 256-D projected vectors
  • {Sex}_{Race}_meta.json: aligned metadata with age_months, age_years, and image paths

Design Rationale

Why Not Search in Raw 1152-D Space?

Raw MedSigLIP embeddings encode general medical image features, not age-specific information. The projector:

  • Reduces dimensionality (1152 → 256) for faster FAISS search
  • Creates age-aware clusters through contrastive training
  • Produces distance that correlates with retrieval quality (r = +0.26)
  • Enables calibrated confidence estimation

Why Demographic Partitioning?

Skeletal maturation rates differ by sex and ethnicity. A 12-year-old male Asian hand looks different from a 12-year-old female Hispanic hand. Partitioned indices ensure "Visual Twins" come from biologically relevant reference populations, addressing the Caucasian bias inherent in the traditional Greulich-Pyle atlas.

Why Not Use Retrieval for Prediction?

At 19.13-month MAE, the retrieval system is too noisy to improve the regressor (8.81-month MAE) numerically. Its value is:

  1. Explainability: show clinicians similar atlas cases
  2. Confidence: distance-based calibration identifies when predictions are reliable
  3. Fairness auditing: demographic partitions make bias visible

Limitations

  • The atlas contains only 1,390 reference cases. Retrieval quality scales with atlas size; larger atlases would improve both MAE and confidence calibration.
  • Training was limited by the small dataset size (1,251 training samples across 8 classes). This constrains the complexity of learnable representations.
  • The projector was trained on the USC Digital Hand Atlas, which may not fully represent all ethnic populations globally.

Citation

@misc{chronosmsk2025,
    title={Chronos-MSK: Bias-Aware Skeletal Maturity Assessment at the Edge},
    author={Rohit Rajesh},
    year={2025},
    note={Google HAI-DEF Competition Submission}
}

References

  • Wang, X., et al. Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning. CVPR, 2019.
  • Movshovitz-Attias, Y., et al. No Fuss Distance Metric Learning Using Proxies. ICCV, 2017.
