import os
import pickle

import cv2
import numpy as np
import requests
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from tensorflow.keras.models import load_model, Model
|
|
# Application instance; this metadata populates the generated OpenAPI schema
# and the interactive docs advertised by the root endpoint.
app = FastAPI(
    version="2.0.0",
    title="Embryo Quality Classifier & Ranker API",
    description="Classify and rank multiple embryos by viability score",
)
|
|
# Browser origins allowed to call this API cross-origin. By default any origin
# ("*") is accepted, exactly as before. Set CORS_ALLOW_ORIGINS to a
# comma-separated list (e.g. "https://a.example,https://b.example") to restrict
# it in deployment. An explicitly empty value yields an empty list, which
# allows no cross-origin callers.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
| |
# Model artifacts are loaded once, when this module is imported. All paths are
# relative, so they resolve against the process's current working directory,
# not this file's directory.
print("Loading models...")
# Full EfficientNet classifier; here it is used to build the feature extractor.
full_model = load_model("efficientnet_embryo_model.h5")
# Feature extractor sharing full_model's graph, with its output taken from the
# third-from-last layer.
# NOTE(review): presumably this drops the classification head (e.g. dropout +
# dense); re-check the -3 index whenever the model architecture changes.
efficientnet_feature_extractor = Model(
    inputs=full_model.input,
    outputs=full_model.layers[-3].output,
)
# Fusion classifier. analyze_single_image feeds it the deep features
# concatenated with the scaled morphology features.
fusion_model = load_model("dual_branch_embryo_model.keras")
# Scaler applied to the morphology features before fusion.
# NOTE: pickle.load can execute arbitrary code embedded in the file; only
# deploy a trusted morph_scaler.pkl.
with open("morph_scaler.pkl", "rb") as f:
    scaler = pickle.load(f)
print("All models loaded!")
|
|
|
|
| |
def download_image(url: str, max_bytes: int = 20 * 1024 * 1024):
    """Fetch an image over HTTP(S) and decode it into a BGR array.

    This is best-effort: every failure is logged with ``print`` and turned
    into ``None``, so callers can answer with a 400.

    Args:
        url: Client-supplied image URL (untrusted). Values arriving through
            the untyped ``/rank`` payload may not even be strings.
        max_bytes: Upper bound on the (decompressed) response body size.
            Larger downloads are abandoned so one request cannot exhaust
            server memory. Defaults to 20 MiB.

    Returns:
        The result of ``cv2.imdecode(..., cv2.IMREAD_COLOR)``, or ``None`` in
        any of these cases: the URL is missing or not http(s); the request
        fails or returns an HTTP error status; the body is empty or larger
        than ``max_bytes``. OpenCV documents that ``imdecode`` also returns
        ``None`` for data it cannot decode.
    """
    # requests itself only handles http/https. Checking up front gives a
    # clear log line and also rejects non-string values (e.g. ints or lists
    # from a raw JSON dict) before they reach the network layer.
    if not isinstance(url, str) or not url.lstrip().lower().startswith(("http://", "https://")):
        print(f"Download error: unsupported or missing URL {url!r}")
        return None
    # NOTE(review): the URL is client-controlled, so callers can make this
    # server fetch arbitrary hosts, including internal ones (SSRF). Add a
    # host allow-list if the API is reachable by untrusted clients.
    try:
        # Stream the body so an oversized response is abandoned early instead
        # of being buffered in full; the context manager releases the
        # connection on every exit path.
        with requests.get(url, timeout=10, stream=True) as resp:
            resp.raise_for_status()
            chunks = []
            received = 0
            for chunk in resp.iter_content(chunk_size=64 * 1024):
                received += len(chunk)
                if received > max_bytes:
                    print(f"Download error: response larger than {max_bytes} bytes")
                    return None
                chunks.append(chunk)
        arr = np.frombuffer(b"".join(chunks), np.uint8)
        if arr.size == 0:
            # cv2.imdecode raises on an empty buffer; report it explicitly.
            print("Download error: empty response body")
            return None
        return cv2.imdecode(arr, cv2.IMREAD_COLOR)
    except Exception as e:
        # Boundary catch, kept deliberately broad: any network or decode
        # failure maps to None, which callers translate into an HTTP 400.
        print(f"Download error: {e}")
        return None
|
|
|
|
def extract_efficientnet_features(img: np.ndarray) -> np.ndarray:
    """Embed one image with the truncated EfficientNet and return a 1-D vector.

    The BGR input is converted to RGB, resized to 224x224, scaled by 1/255,
    batched as a single sample, and passed to
    ``efficientnet_feature_extractor``. The prediction is flattened before
    being returned.

    NOTE(review): if the underlying network is one of Keras' built-in
    EfficientNet applications, those include their own input rescaling and
    expect 0-255 pixels. Confirm that the /255.0 here matches how the model
    was trained.
    """
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    normalized = cv2.resize(rgb, (224, 224)) / 255.0
    # Add a leading batch axis: (H, W, C) -> (1, H, W, C).
    single_batch = normalized[np.newaxis, ...]
    embedding = efficientnet_feature_extractor.predict(single_batch, verbose=0)
    return embedding.flatten()
|
|
|
|
def extract_morphological_features(img: np.ndarray, fragment_max_area: float = 500) -> np.ndarray:
    """Compute the two hand-crafted morphology features used by the fusion model.

    Pipeline: BGR -> grayscale, 5x5 Gaussian blur, Otsu binary threshold.
    The external contours of the white (255) regions of that mask are then
    traced.

    Features:
        symmetry_score: mean Euclidean distance over every pair of contour
            centroids. Centroids come from image moments and are truncated
            to integer pixels. Despite its name this measures spread, not
            symmetry. It is 0.0 when fewer than two contours have non-zero
            area.
        fragmentation_ratio: summed area of contours smaller than
            ``fragment_max_area``, divided by the number of white pixels in
            the thresholded mask. It is 0.0 when the mask has no white
            pixels.

    These features are standardised by the scaler loaded at start-up before
    they reach the fusion model. Their exact computation (including the
    integer truncation) must therefore stay in step with how that scaler and
    model were trained.

    Args:
        img: 3-channel BGR image (required by ``cv2.COLOR_BGR2GRAY``).
        fragment_max_area: Contour-area cutoff, in pixels, below which a
            contour counts as a fragment. The default of 500 is the
            previously hard-coded value.

    Returns:
        ``np.array([symmetry_score, fragmentation_ratio])``.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    blur = cv2.GaussianBlur(gray, (5, 5), 0)
    _, thresh = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # OpenCV 3.x returns (image, contours, hierarchy) and 4.x returns
    # (contours, hierarchy). Indexing from the end works with both.
    contours = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    centroids = []
    for cnt in contours:
        m = cv2.moments(cnt)
        if m["m00"] != 0:  # zero-area contours have no defined centroid
            centroids.append((int(m["m10"] / m["m00"]), int(m["m01"] / m["m00"])))

    symmetry_score = 0.0
    if len(centroids) > 1:
        pts = np.asarray(centroids, dtype=np.float64)
        # Vectorise each row of the upper triangle (pairs i < j, row-major).
        # This produces the same distances, in the same order, as the
        # per-pair loop, but avoids k^2 Python-level norm calls on noisy
        # masks with many contours.
        pair_distances = np.concatenate(
            [np.linalg.norm(pts[i + 1:] - pts[i], axis=1) for i in range(len(pts) - 1)]
        )
        symmetry_score = float(np.mean(pair_distances))

    embryo_area = int(np.sum(thresh == 255))
    areas = [cv2.contourArea(c) for c in contours]  # compute each area once
    fragmented_area = sum(a for a in areas if a < fragment_max_area)
    fragmentation_ratio = fragmented_area / embryo_area if embryo_area > 0 else 0.0

    return np.array([symmetry_score, fragmentation_ratio])
|
|
|
|
def analyze_single_image(img: np.ndarray) -> dict:
    """Score one decoded image with the dual-branch fusion model.

    The deep-feature vector and the scaled morphology vector are
    concatenated into a single-sample batch for ``fusion_model``. Output
    index 1 is treated as the "good" class and index 0 as "poor".

    Returns a dict with these keys:
        - ``class_id``, ``label``;
        - ``confidence``, the max probability;
        - ``viability_score_percent``, the good probability x 100;
        - ``good_probability``, ``poor_probability``;
        - ``symmetry_score``, ``fragmentation_ratio``, the raw, unscaled
          morphology features.

    All values are rounded for presentation.
    """
    deep_vec = extract_efficientnet_features(img)
    morph_vec = extract_morphological_features(img)
    morph_std = scaler.transform([morph_vec])[0]

    fused_input = np.concatenate([deep_vec, morph_std])[np.newaxis, :]
    probs = fusion_model.predict(fused_input, verbose=0)[0]

    predicted = int(np.argmax(probs))
    p_good = float(probs[1])
    p_poor = float(probs[0])
    is_good = predicted == 1

    return {
        "class_id": predicted,
        "label": "Good Quality Embryo" if is_good else "Poor Quality Embryo",
        "confidence": round(float(np.max(probs)), 4),
        "viability_score_percent": round(p_good * 100, 2),
        "good_probability": round(p_good, 4),
        "poor_probability": round(p_poor, 4),
        "symmetry_score": round(float(morph_vec[0]), 4),
        "fragmentation_ratio": round(float(morph_vec[1]), 6),
    }
|
|
|
|
| |
class SingleRequest(BaseModel):
    """Request body for ``POST /predict``."""

    image_url: str  # fetched server-side by download_image
|
|
class RankRequest(BaseModel):
    """Request body for ``POST /rank``."""

    # Each item is a plain dict. rank_embryos reads only its "id" and
    # "image_url" keys and does its own validation (count limits, missing
    # image_url), so this model checks only that every item is an object.
    embryos: list[dict]
|
|
class EmbryoResult(BaseModel):
    """One embryo's entry in the ``POST /rank`` response."""

    rank: int  # 1 = highest viability_score_percent; ties keep request order
    id: str  # caller-supplied "id", or "Embryo_<n>" (1-based position) when none is given
    label: str  # "Good Quality Embryo" / "Poor Quality Embryo", from the argmax class
    viability_score_percent: float  # good-class probability x 100, rounded to 2 dp
    confidence: float  # largest class probability, rounded to 4 dp
    good_probability: float  # fusion-model output index 1, rounded to 4 dp
    poor_probability: float  # fusion-model output index 0, rounded to 4 dp
    symmetry_score: float  # mean pairwise contour-centroid distance in pixels, 4 dp
    fragmentation_ratio: float  # small-contour area / thresholded white-pixel area, 6 dp
    recommendation: str  # transfer-advice tier derived from viability_score_percent
|
|
class RankResponse(BaseModel):
    """Response body for ``POST /rank``."""

    total_analyzed: int  # number of embryos scored (equals len(ranked_embryos))
    best_embryo_id: str  # id of the rank-1 entry
    ranked_embryos: list[EmbryoResult]  # ordered best (rank 1) to worst
|
|
|
|
| |
@app.get("/")
def root():
    """Describe the service and point clients at its main endpoints."""
    endpoint_summary = {
        "POST /predict": "Analyze a single embryo image",
        "POST /rank": "Rank multiple embryos by viability score",
    }
    return dict(
        message="Embryo Quality Classifier & Ranker",
        endpoints=endpoint_summary,
        docs="/docs",
    )
|
|
@app.get("/health")
def health():
    """Liveness probe.

    The models load at import time, so if this route answers, loading
    succeeded.
    """
    return dict(status="ok")
|
|
|
|
@app.post("/predict")
def predict_single(request: SingleRequest):
    """Analyze one embryo image from a URL.

    Responds with the dict built by ``analyze_single_image``. Responds with
    HTTP 400 when ``download_image`` yields ``None``, which it does for any
    download failure.
    """
    image = download_image(request.image_url)
    if image is not None:
        return analyze_single_image(image)
    raise HTTPException(status_code=400, detail="Could not download image.")
|
|
|
|
_MAX_EMBRYOS_PER_REQUEST = 20
# Cap on concurrent image downloads for a single /rank call.
_MAX_DOWNLOAD_WORKERS = 8


def _resolve_embryo_id(entry: dict, position: int) -> str:
    """Return the entry's "id" as a string, or ``Embryo_<position>`` if absent or empty."""
    raw_id = entry.get("id")
    if raw_id is None or raw_id == "":
        return f"Embryo_{position}"
    # EmbryoResult.id is a str field. Coerce here so numeric ids from JSON do
    # not fail response validation (pydantic v2 does not coerce int -> str),
    # and so an id of 0 is kept rather than replaced.
    return str(raw_id)


def _recommendation(score: float) -> str:
    """Map a viability percentage onto a transfer-recommendation tier."""
    if score >= 80:
        return "Highly recommended for transfer"
    if score >= 60:
        return "Suitable for transfer"
    if score >= 40:
        return "Marginal quality — use only if no better option available"
    return "Poor quality — not recommended for transfer"


@app.post("/rank", response_model=RankResponse)
def rank_embryos(request: RankRequest):
    """
    Rank multiple embryos from a list of image URLs.

    Request body example:
    {
        "embryos": [
            {"id": "E1", "image_url": "https://..."},
            {"id": "E2", "image_url": "https://..."},
            {"id": "E3", "image_url": "https://..."}
        ]
    }

    Returns all embryos ranked #1 (best viability) to #N (worst). Each embryo
    gets a viability_score_percent, label, and transfer recommendation.
    Embryos with equal scores keep their request order.

    Raises:
        HTTPException(400): in any of these cases:
            - the list is empty or longer than 20 entries;
            - any entry lacks an image_url (checked before any download
              starts);
            - any image cannot be downloaded. The first such entry in
              request order is reported.
    """
    from concurrent.futures import ThreadPoolExecutor

    embryos = request.embryos
    if not embryos:
        raise HTTPException(status_code=400, detail="Provide at least 1 embryo.")
    if len(embryos) > _MAX_EMBRYOS_PER_REQUEST:
        raise HTTPException(
            status_code=400,
            detail=f"Maximum {_MAX_EMBRYOS_PER_REQUEST} embryos per request.",
        )

    # Validate every entry before any network I/O, so a malformed request
    # fails fast instead of after several downloads.
    jobs = []
    for position, entry in enumerate(embryos, start=1):
        embryo_id = _resolve_embryo_id(entry, position)
        image_url = entry.get("image_url")
        if not image_url:
            raise HTTPException(status_code=400, detail=f"Missing image_url for '{embryo_id}'.")
        jobs.append((embryo_id, image_url))

    # Downloads are I/O-bound and can each take up to the request timeout.
    # Fetching them concurrently keeps the worst case near one timeout rather
    # than N. pool.map returns results in input order.
    with ThreadPoolExecutor(max_workers=min(_MAX_DOWNLOAD_WORKERS, len(jobs))) as pool:
        images = list(pool.map(download_image, [url for _, url in jobs]))

    for (embryo_id, _), img in zip(jobs, images):
        if img is None:
            raise HTTPException(status_code=400, detail=f"Could not download image for '{embryo_id}'.")

    # Inference stays sequential on the request thread.
    results = [
        {"id": embryo_id, **analyze_single_image(img)}
        for (embryo_id, _), img in zip(jobs, images)
    ]

    # list.sort is stable even with reverse=True, so ties keep request order.
    results.sort(key=lambda r: r["viability_score_percent"], reverse=True)

    ranked = [
        EmbryoResult(
            rank=rank_pos,
            id=r["id"],
            label=r["label"],
            viability_score_percent=r["viability_score_percent"],
            confidence=r["confidence"],
            good_probability=r["good_probability"],
            poor_probability=r["poor_probability"],
            symmetry_score=r["symmetry_score"],
            fragmentation_ratio=r["fragmentation_ratio"],
            # The tier is chosen from the same rounded score the client sees.
            recommendation=_recommendation(r["viability_score_percent"]),
        )
        for rank_pos, r in enumerate(results, start=1)
    ]

    return RankResponse(
        total_analyzed=len(ranked),
        best_embryo_id=ranked[0].id,
        ranked_embryos=ranked,
    )