import torch
import open_clip
import json
import numpy as np
from PIL import Image
from pathlib import Path
from huggingface_hub import hf_hub_download
from typing import List, Dict, Tuple
from .preprocessing import preprocess_file, preprocess_image
# Constants
MODEL_NAME = "numansaeed/fetalclip-model"  # Hugging Face repo id hosting the model weights
INPUT_SIZE = 224  # model input resolution
TOP_N_PROBS = 15  # number of top-scoring GA candidates used for the median estimate
# GA Text prompts - view-specific prompts for brain, abdomen, and femur
GA_TEXT_PROMPTS = {
"brain": [
"Ultrasound image at {weeks} weeks and {day} days gestation focusing on the fetal brain, highlighting anatomical structures with a pixel spacing of {pixel_spacing} mm/pixel.",
"Fetal ultrasound image at {weeks} weeks, {day} days of gestation, focusing on the developing brain, with a pixel spacing of {pixel_spacing} mm/pixel, highlighting the structures of the fetal brain.",
"Fetal ultrasound image at {weeks} weeks and {day} days gestational age, highlighting the developing brain structures with a pixel spacing of {pixel_spacing} mm/pixel, providing important visual insights for ongoing prenatal assessments.",
"Ultrasound image at {weeks} weeks and {day} days gestation, highlighting the fetal brain structures with a pixel spacing of {pixel_spacing} mm/pixel.",
"Fetal ultrasound at {weeks} weeks and {day} days, showing a clear view of the developing brain, with an image pixel spacing of {pixel_spacing} mm/pixel."
],
"abdomen": [
"Fetal ultrasound at {weeks} weeks and {day} days gestation, focusing on the abdominal area, highlighting structural development with a pixel spacing of {pixel_spacing} mm/pixel.",
"Ultrasound image at {weeks} weeks and {day} days gestation, focusing on the fetal abdomen, with pixel spacing of {pixel_spacing} mm/pixel, highlighting the structural development in this stage of gestation.",
"Ultrasound image of the fetal abdomen at {weeks} weeks and {day} days gestational age, highlighting anatomical structures with a pixel spacing of {pixel_spacing} mm/pixel.",
"Ultrasound image of the fetal abdomen at {weeks} weeks and {day} days gestational age, highlighting the development of abdominal structures, with a pixel spacing of {pixel_spacing} mm/pixel.",
"Fetal ultrasound image at {weeks} weeks and {day} days gestational age, focusing on the abdomen with a pixel spacing of {pixel_spacing} mm/pixel."
],
"femur": [
"Ultrasound image at {weeks} weeks and {day} days gestation, focusing on the developing fetal femur, with a pixel spacing of {pixel_spacing} mm/pixel, highlighting bone length and structure.",
"The ultrasound image highlights the fetal femur at {weeks} weeks and {day} days of gestation, with a pixel spacing of {pixel_spacing} mm/pixel, providing a detailed view of the developing bone.",
"Ultrasound image at {weeks} weeks and {day} days gestation, focusing on the fetal femur, highlighting skeletal development at a pixel spacing of {pixel_spacing} mm/pixel.",
"Fetal ultrasound image at {weeks} weeks and {day} days gestation, highlighting the femur with a pixel spacing of {pixel_spacing} mm/pixel, providing a detailed view of bone development.",
"Ultrasound image at {weeks} weeks and {day} days gestation, highlighting the fetal femur with a pixel spacing of {pixel_spacing} mm/pixel."
]
}
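# One candidate gestational age per day, from 14w0d (98 days) to 38w6d (272 days): 25 weeks * 7 = 175 entries.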
LIST_GA_IN_DAYS = [weeks * 7 + days for weeks in range(14, 39) for days in range(0, 7)]
class FetalCLIPService:
_instance = None
_initialized = False
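    # Singleton: __new__ returns one shared instance so the model is loaded only once per process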
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if FetalCLIPService._initialized:
return
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = None
self.preprocess = None
self.tokenizer = None
self.text_features = None
self.list_plane = []
FetalCLIPService._initialized = True
def load_model(self, assets_dir: Path):
"""Load the FetalCLIP model and precompute text features."""
config_path = assets_dir / "FetalCLIP_config.json"
prompts_path = assets_dir / "prompt_fetal_view.json"
        # Load the model config and register it in open_clip's (private) config
        # registry so create_model_and_transforms can find the custom architecture
        with open(config_path, "r") as f:
            config = json.load(f)
        open_clip.factory._MODEL_CONFIGS["FetalCLIP"] = config
# Download weights
weights_path = hf_hub_download(
repo_id=MODEL_NAME,
filename="FetalCLIP_weights.pt"
)
# Create model
self.model, _, self.preprocess = open_clip.create_model_and_transforms(
"FetalCLIP",
pretrained=weights_path
)
self.tokenizer = open_clip.get_tokenizer("FetalCLIP")
self.model = self.model.float()
self.model.eval()
self.model.to(self.device)
# Load text prompts and compute features
with open(prompts_path, 'r') as f:
text_prompts = json.load(f)
        # Build one L2-normalized text prototype per anatomical plane by
        # averaging that plane's prompt embeddings
        list_text_features = []
        self.list_plane = []
        with torch.no_grad():
            for plane, prompts in text_prompts.items():
                self.list_plane.append(plane)
                tokens = self.tokenizer(prompts).to(self.device)
                features = self.model.encode_text(tokens)
                features /= features.norm(dim=-1, keepdim=True)
                # Average the per-prompt embeddings, then re-normalize
                features = features.mean(dim=0, keepdim=True)
                features /= features.norm(dim=-1, keepdim=True)
                list_text_features.append(features)
        self.text_features = torch.cat(list_text_features, dim=0)  # [n_planes, dim]
print(f"✓ FetalCLIP model loaded on {self.device}")
return True
def classify_view(self, image: Image.Image, top_k: int = 5) -> List[Dict]:
"""Classify fetal ultrasound view from preprocessed image."""
if self.model is None:
raise RuntimeError("Model not loaded. Call load_model() first.")
top_k = min(top_k, len(self.list_plane))
# Apply model preprocessing (resize to 224, normalize)
img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
# Inference
with torch.no_grad():
image_features = self.model.encode_image(img_tensor)
image_features /= image_features.norm(dim=-1, keepdim=True)
            # Softmax over cosine similarities scaled by 99.2198 (presumably the checkpoint's learned logit scale)
            similarity = (99.2198 * image_features @ self.text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(top_k)
results = []
for idx, val in zip(indices, values):
results.append({
"label": self.list_plane[idx],
"confidence": round(val.item() * 100, 2)
})
return results
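    # Illustrative output (labels depend on prompt_fetal_view.json; values here are made up):
    #   [{"label": "<plane>", "confidence": 97.42}, ...]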
def classify_from_file(self, file_bytes: bytes, filename: str, top_k: int = 5) -> Tuple[List[Dict], Dict]:
"""
Classify from raw file bytes with automatic preprocessing.
Returns:
Tuple of (predictions, preprocessing_info)
"""
# Preprocess based on file type
processed_image, preprocessing_info = preprocess_file(file_bytes, filename)
# Classify
predictions = self.classify_view(processed_image, top_k)
return predictions, preprocessing_info
    def _get_ga_text_features(self, template: str, pixel_spacing: float) -> torch.Tensor:
        """Encode one prompt template across every candidate gestational age (14w0d to 38w6d)."""
prompts = []
for weeks in range(14, 39):
for days in range(0, 7):
prompt = template.format(
weeks=weeks,
day=days,
pixel_spacing=f"{pixel_spacing:.2f}"
)
prompts.append(prompt)
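        # 25 weeks * 7 days = 175 prompts, aligned one-to-one with LIST_GA_IN_DAYS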
with torch.no_grad():
tokens = self.tokenizer(prompts).to(self.device)
features = self.model.encode_text(tokens)
features /= features.norm(dim=-1, keepdim=True)
return features
    def _get_unnormalized_dot_products(self, image_features: torch.Tensor, list_text_features: List[torch.Tensor]) -> torch.Tensor:
        """Compute scaled image-text dot products, averaged over prompt templates."""
        text_features = torch.cat(list_text_features, dim=0)  # [n_prompts * n_days, dim]
        text_dot_prods = 100.0 * image_features @ text_features.T  # [batch, n_prompts * n_days]
        n_prompts = len(list_text_features)
        n_days = len(list_text_features[0])  # leading dim of each tensor = number of candidate GA days
        text_dot_prods = text_dot_prods.view(image_features.shape[0], n_prompts, n_days)
        # Average over the prompt templates, leaving one score per candidate GA day
        text_dot_prods = text_dot_prods.mean(dim=1)  # [batch, n_days]
        return text_dot_prods
    def _find_median_from_top_n(self, text_dot_prods: np.ndarray, n: int) -> int:
        """Return the median index among the top-n scoring predictions."""
        # Keep the n best-scoring candidate indices, then take the median index,
        # so that isolated high-scoring outliers do not skew the estimate
        tmp = [[i, t] for i, t in enumerate(text_dot_prods)]
        tmp = sorted(tmp, key=lambda x: x[1], reverse=True)[:n]
        tmp = sorted(tmp, key=lambda x: x[0])
        return tmp[n // 2][0]
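    # Hypothetical example: if the top-5 scoring indices are {40, 41, 42, 43, 120},
    # sorting by index gives [40, 41, 42, 43, 120] and the median entry (position
    # 5 // 2 = 2) is 42, so the single outlier at 120 is ignored.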
def _get_biometry_from_ga(self, ga_days: int, biometry_type: str, percentile: str = '0.5') -> float:
"""
Calculate expected fetal biometry from gestational age using WHO coefficients.
Formula: measurement = exp(b0 + b1*GA + b2*GA² + b3*GA³ + b4*GA⁴)
where GA is in weeks.
Args:
ga_days: Gestational age in days
biometry_type: 'HC', 'AC', or 'FL'
percentile: '0.025', '0.5', or '0.975'
Returns:
Expected measurement in mm
"""
ga_weeks = ga_days / 7
# WHO Fetal Growth Coefficients (from coefficientsGlobalV3.csv)
WHO_COEFFICIENTS = {
# Head Circumference (mm)
'HC': {
'0.025': [1.59317517131532e+0, 2.9459800552433e-1, -7.3860372566707e-3, 6.56951770216148e-5, 0e+0],
'0.5': [2.09924879247164e+0, 2.53373656106037e-1, -6.05647816678282e-3, 5.14256072059917e-5, 0e+0],
'0.975': [2.50074069629423e+0, 2.20067854715719e-1, -4.93623111462443e-3, 3.89066000946519e-5, 0e+0],
},
# Abdominal Circumference (mm)
'AC': {
'0.025': [1.19202778944614e+0, 3.14756681991964e-1, -8.01581308902169e-3, 7.51395976546808e-5, 0e+0],
'0.5': [1.58552931028045e+0, 2.89936781915424e-1, -7.32651929135797e-3, 6.9261631643994e-5, 0e+0],
'0.975': [2.03674472691951e+0, 2.57138461817474e-1, -6.34918788914223e-3, 6.0053745113196e-5, 0e+0],
},
# Femur Length (mm) - uses all 5 coefficients
'FL': {
'0.025': [-7.27187176976836e+0, 1.28298928826162e+0, -5.80601892487905e-2, 1.21314319801879e-3, -9.60171505470123e-6],
'0.5': [-5.54922620776446e+0, 1.09559990166124e+0, -5.01310925949098e-2, 1.0678072569586e-3, -8.63970606288493e-6],
'0.975': [-3.64483930811801e+0, 8.57028131514986e-1, -3.84005685481303e-2, 8.12062784461527e-4, -6.55932416998498e-6],
},
}
if biometry_type not in WHO_COEFFICIENTS:
raise ValueError(f"Unknown biometry type: {biometry_type}")
if percentile not in WHO_COEFFICIENTS[biometry_type]:
raise ValueError(f"Unknown percentile: {percentile}")
b0, b1, b2, b3, b4 = WHO_COEFFICIENTS[biometry_type][percentile]
return np.exp(b0 + b1*ga_weeks + b2*ga_weeks**2 + b3*ga_weeks**3 + b4*ga_weeks**4)
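    # Worked example (median HC at 25 weeks, i.e. ga_days = 175):
    #   exp(2.0992 + 0.2534 * 25 - 6.0565e-3 * 25**2 + 5.1426e-5 * 25**3) ≈ 233 mm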
def estimate_gestational_age(self, image: Image.Image, pixel_size: float, view: str = "brain") -> Dict:
"""Estimate gestational age from preprocessed fetal ultrasound."""
if self.model is None:
raise RuntimeError("Model not loaded. Call load_model() first.")
        # Effective pixel spacing once the image is scaled so its longest side maps to INPUT_SIZE pixels
        pixel_spacing = max(image.size) / INPUT_SIZE * pixel_size
# Apply model preprocessing
img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
# Inference
with torch.no_grad():
image_features = self.model.encode_image(img_tensor)
image_features /= image_features.norm(dim=-1, keepdim=True)
        # Text features for every prompt template of the requested view
        # (unrecognized views fall back to the brain prompts)
        view_prompts = GA_TEXT_PROMPTS.get(view, GA_TEXT_PROMPTS["brain"])
text_features_list = [
self._get_ga_text_features(template, pixel_spacing)
for template in view_prompts
]
text_dot_prods = self._get_unnormalized_dot_products(image_features, text_features_list)
# Compute prediction
text_dot_prod = text_dot_prods.detach().cpu().numpy()[0]
med_idx = self._find_median_from_top_n(text_dot_prod, TOP_N_PROBS)
pred_day = LIST_GA_IN_DAYS[med_idx]
pred_weeks = pred_day // 7
pred_days = pred_day % 7
# Map view to biometry type
VIEW_TO_BIOMETRY = {
"brain": "HC",
"abdomen": "AC",
"femur": "FL"
}
biometry_type = VIEW_TO_BIOMETRY.get(view, "HC")
# Compute view-specific biometry percentiles using WHO formulas
q025 = self._get_biometry_from_ga(pred_day, biometry_type, '0.025')
q500 = self._get_biometry_from_ga(pred_day, biometry_type, '0.5')
q975 = self._get_biometry_from_ga(pred_day, biometry_type, '0.975')
# Biometry labels for response
BIOMETRY_LABELS = {
"HC": "head_circumference",
"AC": "abdominal_circumference",
"FL": "femur_length"
}
biometry_key = BIOMETRY_LABELS.get(biometry_type, "head_circumference")
        # Biometry units (all WHO measurements are in mm), reported alongside the percentiles
        BIOMETRY_UNITS = {
            "HC": "mm",
            "AC": "mm",
            "FL": "mm"
        }
        return {
            "gestational_age": {
                "weeks": pred_weeks,
                "days": pred_days,
                "total_days": pred_day
            },
            "view": view,
            biometry_key: {
                "p2_5": round(q025, 2),
                "p50": round(q500, 2),
                "p97_5": round(q975, 2),
                "unit": BIOMETRY_UNITS.get(biometry_type, "mm")
            }
        }
def estimate_ga_from_file(self, file_bytes: bytes, filename: str, pixel_size: float, view: str = "brain") -> Tuple[Dict, Dict]:
"""
Estimate GA from raw file bytes with automatic preprocessing.
Returns:
Tuple of (ga_results, preprocessing_info)
"""
# Preprocess based on file type
processed_image, preprocessing_info = preprocess_file(file_bytes, filename)
# Use pixel spacing from DICOM if available
if preprocessing_info["type"] == "dicom":
pixel_size = preprocessing_info["metadata"].get("pixel_spacing", pixel_size)
# Estimate GA with the specified view
ga_results = self.estimate_gestational_age(processed_image, pixel_size, view)
return ga_results, preprocessing_info
# Singleton instance
model_service = FetalCLIPService()
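# Minimal usage sketch ("assets", "scan.png", and pixel_size=0.2 below are
# placeholder values, not part of this module):
#
#     from pathlib import Path
#     model_service.load_model(Path("assets"))  # dir with FetalCLIP_config.json + prompt_fetal_view.json
#     with open("scan.png", "rb") as f:
#         data = f.read()
#     predictions, info = model_service.classify_from_file(data, "scan.png")
#     ga, info = model_service.estimate_ga_from_file(data, "scan.png", pixel_size=0.2, view="femur")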