import json
from pathlib import Path
from typing import List, Dict, Tuple, Optional

import numpy as np
import open_clip
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from .preprocessing import preprocess_file, preprocess_image
# Constants
MODEL_NAME = "numansaeed/fetalclip-model"  # Hugging Face Hub repo id for the weights
INPUT_SIZE = 224  # model input resolution in pixels
TOP_N_PROBS = 15  # number of top-scoring GA candidates used for the median estimate
# GA text prompts - view-specific templates for brain, abdomen, and femur
GA_TEXT_PROMPTS = {
    "brain": [
        "Ultrasound image at {weeks} weeks and {day} days gestation focusing on the fetal brain, highlighting anatomical structures with a pixel spacing of {pixel_spacing} mm/pixel.",
        "Fetal ultrasound image at {weeks} weeks, {day} days of gestation, focusing on the developing brain, with a pixel spacing of {pixel_spacing} mm/pixel, highlighting the structures of the fetal brain.",
        "Fetal ultrasound image at {weeks} weeks and {day} days gestational age, highlighting the developing brain structures with a pixel spacing of {pixel_spacing} mm/pixel, providing important visual insights for ongoing prenatal assessments.",
        "Ultrasound image at {weeks} weeks and {day} days gestation, highlighting the fetal brain structures with a pixel spacing of {pixel_spacing} mm/pixel.",
        "Fetal ultrasound at {weeks} weeks and {day} days, showing a clear view of the developing brain, with an image pixel spacing of {pixel_spacing} mm/pixel."
    ],
    "abdomen": [
        "Fetal ultrasound at {weeks} weeks and {day} days gestation, focusing on the abdominal area, highlighting structural development with a pixel spacing of {pixel_spacing} mm/pixel.",
        "Ultrasound image at {weeks} weeks and {day} days gestation, focusing on the fetal abdomen, with pixel spacing of {pixel_spacing} mm/pixel, highlighting the structural development in this stage of gestation.",
        "Ultrasound image of the fetal abdomen at {weeks} weeks and {day} days gestational age, highlighting anatomical structures with a pixel spacing of {pixel_spacing} mm/pixel.",
        "Ultrasound image of the fetal abdomen at {weeks} weeks and {day} days gestational age, highlighting the development of abdominal structures, with a pixel spacing of {pixel_spacing} mm/pixel.",
        "Fetal ultrasound image at {weeks} weeks and {day} days gestational age, focusing on the abdomen with a pixel spacing of {pixel_spacing} mm/pixel."
    ],
    "femur": [
        "Ultrasound image at {weeks} weeks and {day} days gestation, focusing on the developing fetal femur, with a pixel spacing of {pixel_spacing} mm/pixel, highlighting bone length and structure.",
        "The ultrasound image highlights the fetal femur at {weeks} weeks and {day} days of gestation, with a pixel spacing of {pixel_spacing} mm/pixel, providing a detailed view of the developing bone.",
        "Ultrasound image at {weeks} weeks and {day} days gestation, focusing on the fetal femur, highlighting skeletal development at a pixel spacing of {pixel_spacing} mm/pixel.",
        "Fetal ultrasound image at {weeks} weeks and {day} days gestation, highlighting the femur with a pixel spacing of {pixel_spacing} mm/pixel, providing a detailed view of bone development.",
        "Ultrasound image at {weeks} weeks and {day} days gestation, highlighting the fetal femur with a pixel spacing of {pixel_spacing} mm/pixel."
    ]
}
LIST_GA_IN_DAYS = [weeks * 7 + days for weeks in range(14, 39) for days in range(0, 7)]
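# LIST_GA_IN_DAYS enumerates every candidate age from 14w0d to 38w6d
# (25 weeks x 7 days = 175 values), matching the prompt grid built in
# _get_ga_text_features below.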
class FetalCLIPService:
    """Singleton service wrapping the FetalCLIP model for fetal view
    classification and gestational age (GA) estimation."""

    _instance = None
    _initialized = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if FetalCLIPService._initialized:
            return
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.preprocess = None
        self.tokenizer = None
        self.text_features = None
        self.list_plane = []
        FetalCLIPService._initialized = True
    def load_model(self, assets_dir: Path):
        """Load the FetalCLIP model and precompute per-view text features."""
        config_path = assets_dir / "FetalCLIP_config.json"
        prompts_path = assets_dir / "prompt_fetal_view.json"

        # Load the model config and register the custom architecture with
        # open_clip (note: _MODEL_CONFIGS is a private open_clip registry)
        with open(config_path, "r") as f:
            config = json.load(f)
        open_clip.factory._MODEL_CONFIGS["FetalCLIP"] = config

        # Download weights from the Hugging Face Hub
        weights_path = hf_hub_download(
            repo_id=MODEL_NAME,
            filename="FetalCLIP_weights.pt"
        )

        # Create model and preprocessing transform
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            "FetalCLIP",
            pretrained=weights_path
        )
        self.tokenizer = open_clip.get_tokenizer("FetalCLIP")
        self.model = self.model.float()
        self.model.eval()
        self.model.to(self.device)

        # Load view-classification prompts and precompute text features
        with open(prompts_path, "r") as f:
            text_prompts = json.load(f)

        list_text_features = []
        self.list_plane = []
        with torch.no_grad():
            for plane, prompts in text_prompts.items():
                self.list_plane.append(plane)
                tokens = self.tokenizer(prompts).to(self.device)
                features = self.model.encode_text(tokens)
                features /= features.norm(dim=-1, keepdim=True)
                # Prompt ensembling: average the normalized per-prompt
                # embeddings, then renormalize
                features = features.mean(dim=0).unsqueeze(0)
                features /= features.norm(dim=-1, keepdim=True)
                list_text_features.append(features)

        # Shape: (n_views, embed_dim)
        self.text_features = torch.stack(list_text_features)[:, 0]
        print(f"✓ FetalCLIP model loaded on {self.device}")
        return True
    def classify_view(self, image: Image.Image, top_k: int = 5) -> List[Dict]:
        """Classify the fetal ultrasound view from a preprocessed image."""
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        top_k = min(top_k, len(self.list_plane))

        # Apply model preprocessing (resize to 224, normalize)
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        # Inference
        with torch.no_grad():
            image_features = self.model.encode_image(img_tensor)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            # Scaled cosine similarity; 99.2198 acts as the logit scale
            # (CLIP's learned temperature, typically close to 100)
            similarity = (99.2198 * image_features @ self.text_features.T).softmax(dim=-1)

        values, indices = similarity[0].topk(top_k)
        results = []
        for idx, val in zip(indices, values):
            results.append({
                "label": self.list_plane[idx],
                "confidence": round(val.item() * 100, 2)
            })
        return results
    def classify_from_file(self, file_bytes: bytes, filename: str, top_k: int = 5) -> Tuple[List[Dict], Dict]:
        """
        Classify from raw file bytes with automatic preprocessing.

        Returns:
            Tuple of (predictions, preprocessing_info)
        """
        # Preprocess based on file type
        processed_image, preprocessing_info = preprocess_file(file_bytes, filename)

        # Classify
        predictions = self.classify_view(processed_image, top_k)
        return predictions, preprocessing_info
    def _get_ga_text_features(self, template: str, pixel_spacing: float) -> torch.Tensor:
        """Encode one prompt template across every candidate GA (14w0d to 38w6d)."""
        prompts = []
        for weeks in range(14, 39):
            for days in range(0, 7):
                prompt = template.format(
                    weeks=weeks,
                    day=days,
                    pixel_spacing=f"{pixel_spacing:.2f}"
                )
                prompts.append(prompt)
        with torch.no_grad():
            tokens = self.tokenizer(prompts).to(self.device)
            features = self.model.encode_text(tokens)
            features /= features.norm(dim=-1, keepdim=True)
        return features
    def _get_unnormalized_dot_products(self, image_features: torch.Tensor, list_text_features: List[torch.Tensor]) -> torch.Tensor:
        """Compute scaled image-text dot products, averaged over prompt templates."""
        text_features = torch.cat(list_text_features, dim=0)  # (n_prompts * n_days, dim)
        text_dot_prods = 100.0 * image_features @ text_features.T
        n_prompts = len(list_text_features)
        n_days = len(list_text_features[0])
        # Average the per-template scores for each candidate GA
        text_dot_prods = text_dot_prods.view(image_features.shape[0], n_prompts, n_days)
        text_dot_prods = text_dot_prods.mean(dim=1)
        return text_dot_prods
    def _find_median_from_top_n(self, text_dot_prods: np.ndarray, n: int) -> int:
        """Return the positional median index among the top-n scoring candidates."""
        tmp = [[i, t] for i, t in enumerate(text_dot_prods)]
        tmp = sorted(tmp, key=lambda x: x[1], reverse=True)[:n]  # top-n by score
        tmp = sorted(tmp, key=lambda x: x[0])                    # re-sort by index
        return tmp[n // 2][0]
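    # Illustrative example of the median-of-top-n heuristic above: for scores
    # [0.1, 0.9, 0.8, 0.2, 0.7] and n=3, the top-3 indices are {1, 2, 4};
    # re-sorted by index that is [1, 2, 4], and the middle entry yields index 2.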
    def _get_biometry_from_ga(self, ga_days: int, biometry_type: str, percentile: str = '0.5') -> float:
        """
        Calculate expected fetal biometry from gestational age using WHO coefficients.

        Formula: measurement = exp(b0 + b1*GA + b2*GA² + b3*GA³ + b4*GA⁴),
        where GA is in weeks.

        Args:
            ga_days: Gestational age in days
            biometry_type: 'HC', 'AC', or 'FL'
            percentile: '0.025', '0.5', or '0.975'

        Returns:
            Expected measurement in mm
        """
        ga_weeks = ga_days / 7
        # WHO Fetal Growth Coefficients (from coefficientsGlobalV3.csv)
        WHO_COEFFICIENTS = {
            # Head Circumference (mm)
            'HC': {
                '0.025': [1.59317517131532e+0, 2.9459800552433e-1, -7.3860372566707e-3, 6.56951770216148e-5, 0e+0],
                '0.5': [2.09924879247164e+0, 2.53373656106037e-1, -6.05647816678282e-3, 5.14256072059917e-5, 0e+0],
                '0.975': [2.50074069629423e+0, 2.20067854715719e-1, -4.93623111462443e-3, 3.89066000946519e-5, 0e+0],
            },
            # Abdominal Circumference (mm)
            'AC': {
                '0.025': [1.19202778944614e+0, 3.14756681991964e-1, -8.01581308902169e-3, 7.51395976546808e-5, 0e+0],
                '0.5': [1.58552931028045e+0, 2.89936781915424e-1, -7.32651929135797e-3, 6.9261631643994e-5, 0e+0],
                '0.975': [2.03674472691951e+0, 2.57138461817474e-1, -6.34918788914223e-3, 6.0053745113196e-5, 0e+0],
            },
            # Femur Length (mm) - uses all 5 coefficients
            'FL': {
                '0.025': [-7.27187176976836e+0, 1.28298928826162e+0, -5.80601892487905e-2, 1.21314319801879e-3, -9.60171505470123e-6],
                '0.5': [-5.54922620776446e+0, 1.09559990166124e+0, -5.01310925949098e-2, 1.0678072569586e-3, -8.63970606288493e-6],
                '0.975': [-3.64483930811801e+0, 8.57028131514986e-1, -3.84005685481303e-2, 8.12062784461527e-4, -6.55932416998498e-6],
            },
        }
        if biometry_type not in WHO_COEFFICIENTS:
            raise ValueError(f"Unknown biometry type: {biometry_type}")
        if percentile not in WHO_COEFFICIENTS[biometry_type]:
            raise ValueError(f"Unknown percentile: {percentile}")
        b0, b1, b2, b3, b4 = WHO_COEFFICIENTS[biometry_type][percentile]
        return np.exp(b0 + b1*ga_weeks + b2*ga_weeks**2 + b3*ga_weeks**3 + b4*ga_weeks**4)
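    # Illustrative check (values approximate): at 20 weeks (140 days),
    # _get_biometry_from_ga(140, 'HC', '0.5') evaluates
    # exp(2.0992 + 0.2534*20 - 6.0565e-3*20**2 + 5.1426e-5*20**3) ≈ 173 mm.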
    def estimate_gestational_age(self, image: Image.Image, pixel_size: float, view: str = "brain") -> Dict:
        """Estimate gestational age from a preprocessed fetal ultrasound image."""
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        # Effective pixel spacing after the image is resized to INPUT_SIZE
        pixel_spacing = max(image.size) / INPUT_SIZE * pixel_size

        # Apply model preprocessing
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        # Inference
        with torch.no_grad():
            image_features = self.model.encode_image(img_tensor)
            image_features /= image_features.norm(dim=-1, keepdim=True)

        # Text features for every prompt template of the requested view
        # (falls back to "brain" for unknown views)
        view_prompts = GA_TEXT_PROMPTS.get(view, GA_TEXT_PROMPTS["brain"])
        text_features_list = [
            self._get_ga_text_features(template, pixel_spacing)
            for template in view_prompts
        ]
        text_dot_prods = self._get_unnormalized_dot_products(image_features, text_features_list)

        # Predict GA as the positional median of the top-scoring candidate ages
        text_dot_prod = text_dot_prods.detach().cpu().numpy()[0]
        med_idx = self._find_median_from_top_n(text_dot_prod, TOP_N_PROBS)
        pred_day = LIST_GA_IN_DAYS[med_idx]
        pred_weeks = pred_day // 7
        pred_days = pred_day % 7

        # Map view to the corresponding biometry measurement
        VIEW_TO_BIOMETRY = {
            "brain": "HC",
            "abdomen": "AC",
            "femur": "FL"
        }
        biometry_type = VIEW_TO_BIOMETRY.get(view, "HC")

        # Compute view-specific biometry percentiles (in mm) from the WHO formulas
        q025 = self._get_biometry_from_ga(pred_day, biometry_type, '0.025')
        q500 = self._get_biometry_from_ga(pred_day, biometry_type, '0.5')
        q975 = self._get_biometry_from_ga(pred_day, biometry_type, '0.975')

        # Biometry labels for the response
        BIOMETRY_LABELS = {
            "HC": "head_circumference",
            "AC": "abdominal_circumference",
            "FL": "femur_length"
        }
        biometry_key = BIOMETRY_LABELS.get(biometry_type, "head_circumference")

        return {
            "gestational_age": {
                "weeks": pred_weeks,
                "days": pred_days,
                "total_days": pred_day
            },
            "view": view,
            biometry_key: {
                "p2_5": round(q025, 2),
                "p50": round(q500, 2),
                "p97_5": round(q975, 2)
            }
        }
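    # Illustrative shape of the dict returned above (numbers are placeholders):
    # {"gestational_age": {"weeks": 22, "days": 3, "total_days": 157},
    #  "view": "brain",
    #  "head_circumference": {"p2_5": ..., "p50": ..., "p97_5": ...}}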
    def estimate_ga_from_file(self, file_bytes: bytes, filename: str, pixel_size: float, view: str = "brain") -> Tuple[Dict, Dict]:
        """
        Estimate GA from raw file bytes with automatic preprocessing.

        Returns:
            Tuple of (ga_results, preprocessing_info)
        """
        # Preprocess based on file type
        processed_image, preprocessing_info = preprocess_file(file_bytes, filename)

        # Prefer the pixel spacing from DICOM metadata when available
        if preprocessing_info["type"] == "dicom":
            pixel_size = preprocessing_info["metadata"].get("pixel_spacing", pixel_size)

        # Estimate GA with the specified view
        ga_results = self.estimate_gestational_age(processed_image, pixel_size, view)
        return ga_results, preprocessing_info
# Singleton instance
model_service = FetalCLIPService()
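
# Example usage (a minimal sketch: the import path, assets directory, sample
# image, and pixel size below are assumptions, not part of this module):
#
#     from pathlib import Path
#     from PIL import Image
#     from app.services.model_service import model_service  # hypothetical path
#
#     model_service.load_model(Path("assets"))
#     img = Image.open("sample_brain.png").convert("RGB")
#     print(model_service.classify_view(img, top_k=3))
#     print(model_service.estimate_gestational_age(img, pixel_size=0.2, view="brain"))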