import torch import open_clip import json import numpy as np from PIL import Image from pathlib import Path from huggingface_hub import hf_hub_download from typing import List, Dict, Tuple, Optional from .preprocessing import preprocess_file, preprocess_image # Constants MODEL_NAME = "numansaeed/fetalclip-model" INPUT_SIZE = 224 TOP_N_PROBS = 15 # GA Text prompts - view-specific prompts for brain, abdomen, and femur GA_TEXT_PROMPTS = { "brain": [ "Ultrasound image at {weeks} weeks and {day} days gestation focusing on the fetal brain, highlighting anatomical structures with a pixel spacing of {pixel_spacing} mm/pixel.", "Fetal ultrasound image at {weeks} weeks, {day} days of gestation, focusing on the developing brain, with a pixel spacing of {pixel_spacing} mm/pixel, highlighting the structures of the fetal brain.", "Fetal ultrasound image at {weeks} weeks and {day} days gestational age, highlighting the developing brain structures with a pixel spacing of {pixel_spacing} mm/pixel, providing important visual insights for ongoing prenatal assessments.", "Ultrasound image at {weeks} weeks and {day} days gestation, highlighting the fetal brain structures with a pixel spacing of {pixel_spacing} mm/pixel.", "Fetal ultrasound at {weeks} weeks and {day} days, showing a clear view of the developing brain, with an image pixel spacing of {pixel_spacing} mm/pixel." ], "abdomen": [ "Fetal ultrasound at {weeks} weeks and {day} days gestation, focusing on the abdominal area, highlighting structural development with a pixel spacing of {pixel_spacing} mm/pixel.", "Ultrasound image at {weeks} weeks and {day} days gestation, focusing on the fetal abdomen, with pixel spacing of {pixel_spacing} mm/pixel, highlighting the structural development in this stage of gestation.", "Ultrasound image of the fetal abdomen at {weeks} weeks and {day} days gestational age, highlighting anatomical structures with a pixel spacing of {pixel_spacing} mm/pixel.", "Ultrasound image of the fetal abdomen at {weeks} weeks and {day} days gestational age, highlighting the development of abdominal structures, with a pixel spacing of {pixel_spacing} mm/pixel.", "Fetal ultrasound image at {weeks} weeks and {day} days gestational age, focusing on the abdomen with a pixel spacing of {pixel_spacing} mm/pixel." ], "femur": [ "Ultrasound image at {weeks} weeks and {day} days gestation, focusing on the developing fetal femur, with a pixel spacing of {pixel_spacing} mm/pixel, highlighting bone length and structure.", "The ultrasound image highlights the fetal femur at {weeks} weeks and {day} days of gestation, with a pixel spacing of {pixel_spacing} mm/pixel, providing a detailed view of the developing bone.", "Ultrasound image at {weeks} weeks and {day} days gestation, focusing on the fetal femur, highlighting skeletal development at a pixel spacing of {pixel_spacing} mm/pixel.", "Fetal ultrasound image at {weeks} weeks and {day} days gestation, highlighting the femur with a pixel spacing of {pixel_spacing} mm/pixel, providing a detailed view of bone development.", "Ultrasound image at {weeks} weeks and {day} days gestation, highlighting the fetal femur with a pixel spacing of {pixel_spacing} mm/pixel." ] } LIST_GA_IN_DAYS = [weeks * 7 + days for weeks in range(14, 39) for days in range(0, 7)] class FetalCLIPService: _instance = None _initialized = False def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance def __init__(self): if FetalCLIPService._initialized: return self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = None self.preprocess = None self.tokenizer = None self.text_features = None self.list_plane = [] FetalCLIPService._initialized = True def load_model(self, assets_dir: Path): """Load the FetalCLIP model and precompute text features.""" config_path = assets_dir / "FetalCLIP_config.json" prompts_path = assets_dir / "prompt_fetal_view.json" # Load config with open(config_path, "r") as f: config = json.load(f) open_clip.factory._MODEL_CONFIGS["FetalCLIP"] = config # Download weights weights_path = hf_hub_download( repo_id=MODEL_NAME, filename="FetalCLIP_weights.pt" ) # Create model self.model, _, self.preprocess = open_clip.create_model_and_transforms( "FetalCLIP", pretrained=weights_path ) self.tokenizer = open_clip.get_tokenizer("FetalCLIP") self.model = self.model.float() self.model.eval() self.model.to(self.device) # Load text prompts and compute features with open(prompts_path, 'r') as f: text_prompts = json.load(f) list_text_features = [] self.list_plane = [] with torch.no_grad(): for plane, prompts in text_prompts.items(): self.list_plane.append(plane) tokens = self.tokenizer(prompts).to(self.device) features = self.model.encode_text(tokens) features /= features.norm(dim=-1, keepdim=True) features = features.mean(dim=0).unsqueeze(0) features /= features.norm(dim=-1, keepdim=True) list_text_features.append(features) self.text_features = torch.stack(list_text_features)[:, 0] print(f"✓ FetalCLIP model loaded on {self.device}") return True def classify_view(self, image: Image.Image, top_k: int = 5) -> List[Dict]: """Classify fetal ultrasound view from preprocessed image.""" if self.model is None: raise RuntimeError("Model not loaded. Call load_model() first.") top_k = min(top_k, len(self.list_plane)) # Apply model preprocessing (resize to 224, normalize) img_tensor = self.preprocess(image).unsqueeze(0).to(self.device) # Inference with torch.no_grad(): image_features = self.model.encode_image(img_tensor) image_features /= image_features.norm(dim=-1, keepdim=True) # Compute similarity similarity = (99.2198 * image_features @ self.text_features.T).softmax(dim=-1) values, indices = similarity[0].topk(top_k) results = [] for idx, val in zip(indices, values): results.append({ "label": self.list_plane[idx], "confidence": round(val.item() * 100, 2) }) return results def classify_from_file(self, file_bytes: bytes, filename: str, top_k: int = 5) -> Tuple[List[Dict], Dict]: """ Classify from raw file bytes with automatic preprocessing. Returns: Tuple of (predictions, preprocessing_info) """ # Preprocess based on file type processed_image, preprocessing_info = preprocess_file(file_bytes, filename) # Classify predictions = self.classify_view(processed_image, top_k) return predictions, preprocessing_info def _get_ga_text_features(self, template: str, pixel_spacing: float) -> torch.Tensor: """Generate text features for GA estimation.""" prompts = [] for weeks in range(14, 39): for days in range(0, 7): prompt = template.format( weeks=weeks, day=days, pixel_spacing=f"{pixel_spacing:.2f}" ) prompts.append(prompt) with torch.no_grad(): tokens = self.tokenizer(prompts).to(self.device) features = self.model.encode_text(tokens) features /= features.norm(dim=-1, keepdim=True) return features def _get_unnormalized_dot_products(self, image_features: torch.Tensor, list_text_features: List[torch.Tensor]) -> torch.Tensor: """Compute dot products between image and text features.""" text_features = torch.cat(list_text_features, dim=0) text_dot_prods = (100.0 * image_features @ text_features.T) n_prompts = len(list_text_features) n_days = len(list_text_features[0]) text_dot_prods = text_dot_prods.view(image_features.shape[0], n_prompts, n_days) text_dot_prods = text_dot_prods.mean(dim=1) return text_dot_prods def _find_median_from_top_n(self, text_dot_prods: np.ndarray, n: int) -> int: """Find median index from top N predictions.""" tmp = [[i, t] for i, t in enumerate(text_dot_prods)] tmp = sorted(tmp, key=lambda x: x[1], reverse=True)[:n] tmp = sorted(tmp, key=lambda x: x[0]) return tmp[n // 2][0] def _get_biometry_from_ga(self, ga_days: int, biometry_type: str, percentile: str = '0.5') -> float: """ Calculate expected fetal biometry from gestational age using WHO coefficients. Formula: measurement = exp(b0 + b1*GA + b2*GA² + b3*GA³ + b4*GA⁴) where GA is in weeks. Args: ga_days: Gestational age in days biometry_type: 'HC', 'AC', or 'FL' percentile: '0.025', '0.5', or '0.975' Returns: Expected measurement in mm """ ga_weeks = ga_days / 7 # WHO Fetal Growth Coefficients (from coefficientsGlobalV3.csv) WHO_COEFFICIENTS = { # Head Circumference (mm) 'HC': { '0.025': [1.59317517131532e+0, 2.9459800552433e-1, -7.3860372566707e-3, 6.56951770216148e-5, 0e+0], '0.5': [2.09924879247164e+0, 2.53373656106037e-1, -6.05647816678282e-3, 5.14256072059917e-5, 0e+0], '0.975': [2.50074069629423e+0, 2.20067854715719e-1, -4.93623111462443e-3, 3.89066000946519e-5, 0e+0], }, # Abdominal Circumference (mm) 'AC': { '0.025': [1.19202778944614e+0, 3.14756681991964e-1, -8.01581308902169e-3, 7.51395976546808e-5, 0e+0], '0.5': [1.58552931028045e+0, 2.89936781915424e-1, -7.32651929135797e-3, 6.9261631643994e-5, 0e+0], '0.975': [2.03674472691951e+0, 2.57138461817474e-1, -6.34918788914223e-3, 6.0053745113196e-5, 0e+0], }, # Femur Length (mm) - uses all 5 coefficients 'FL': { '0.025': [-7.27187176976836e+0, 1.28298928826162e+0, -5.80601892487905e-2, 1.21314319801879e-3, -9.60171505470123e-6], '0.5': [-5.54922620776446e+0, 1.09559990166124e+0, -5.01310925949098e-2, 1.0678072569586e-3, -8.63970606288493e-6], '0.975': [-3.64483930811801e+0, 8.57028131514986e-1, -3.84005685481303e-2, 8.12062784461527e-4, -6.55932416998498e-6], }, } if biometry_type not in WHO_COEFFICIENTS: raise ValueError(f"Unknown biometry type: {biometry_type}") if percentile not in WHO_COEFFICIENTS[biometry_type]: raise ValueError(f"Unknown percentile: {percentile}") b0, b1, b2, b3, b4 = WHO_COEFFICIENTS[biometry_type][percentile] return np.exp(b0 + b1*ga_weeks + b2*ga_weeks**2 + b3*ga_weeks**3 + b4*ga_weeks**4) def estimate_gestational_age(self, image: Image.Image, pixel_size: float, view: str = "brain") -> Dict: """Estimate gestational age from preprocessed fetal ultrasound.""" if self.model is None: raise RuntimeError("Model not loaded. Call load_model() first.") # Calculate effective pixel spacing pixel_spacing = max(image.size) / INPUT_SIZE * pixel_size # Apply model preprocessing img_tensor = self.preprocess(image).unsqueeze(0).to(self.device) # Inference with torch.no_grad(): image_features = self.model.encode_image(img_tensor) image_features /= image_features.norm(dim=-1, keepdim=True) # Get text features for all prompts for the specified view view_prompts = GA_TEXT_PROMPTS.get(view, GA_TEXT_PROMPTS["brain"]) text_features_list = [ self._get_ga_text_features(template, pixel_spacing) for template in view_prompts ] text_dot_prods = self._get_unnormalized_dot_products(image_features, text_features_list) # Compute prediction text_dot_prod = text_dot_prods.detach().cpu().numpy()[0] med_idx = self._find_median_from_top_n(text_dot_prod, TOP_N_PROBS) pred_day = LIST_GA_IN_DAYS[med_idx] pred_weeks = pred_day // 7 pred_days = pred_day % 7 # Map view to biometry type VIEW_TO_BIOMETRY = { "brain": "HC", "abdomen": "AC", "femur": "FL" } biometry_type = VIEW_TO_BIOMETRY.get(view, "HC") # Compute view-specific biometry percentiles using WHO formulas q025 = self._get_biometry_from_ga(pred_day, biometry_type, '0.025') q500 = self._get_biometry_from_ga(pred_day, biometry_type, '0.5') q975 = self._get_biometry_from_ga(pred_day, biometry_type, '0.975') # Biometry labels for response BIOMETRY_LABELS = { "HC": "head_circumference", "AC": "abdominal_circumference", "FL": "femur_length" } biometry_key = BIOMETRY_LABELS.get(biometry_type, "head_circumference") # Biometry units BIOMETRY_UNITS = { "HC": "mm", "AC": "mm", "FL": "mm" } return { "gestational_age": { "weeks": pred_weeks, "days": pred_days, "total_days": pred_day }, "view": view, biometry_key: { "p2_5": round(q025, 2), "p50": round(q500, 2), "p97_5": round(q975, 2) } } def estimate_ga_from_file(self, file_bytes: bytes, filename: str, pixel_size: float, view: str = "brain") -> Tuple[Dict, Dict]: """ Estimate GA from raw file bytes with automatic preprocessing. Returns: Tuple of (ga_results, preprocessing_info) """ # Preprocess based on file type processed_image, preprocessing_info = preprocess_file(file_bytes, filename) # Use pixel spacing from DICOM if available if preprocessing_info["type"] == "dicom": pixel_size = preprocessing_info["metadata"].get("pixel_spacing", pixel_size) # Estimate GA with the specified view ga_results = self.estimate_gestational_age(processed_image, pixel_size, view) return ga_results, preprocessing_info # Singleton instance model_service = FetalCLIPService()