---
license: mit
tags:
- medical-imaging
- pcos-detection
- explainable-ai
- grad-cam
- ultrasound
- image-classification
- pcos
language: en
metrics:
- accuracy
library_name: tensorflow
---

# PCOS Detection with Explainable AI

A deep learning model for **Polycystic Ovary Syndrome (PCOS)** detection from ultrasound images, with **Grad-CAM** visualization for clinical interpretability.

## Model Overview

- **Architecture**: Dual-path CNN with multi-head attention
- **Input**: 224×224 RGB ultrasound images
- **Output**: Binary classification (PCOS-positive / Healthy)
- **Accuracy**: approximately 95% or higher on the test set
- **XAI**: Grad-CAM heatmaps for interpretability

## 🚀 Quick Start

```bash
pip install tensorflow opencv-python matplotlib numpy requests huggingface-hub
```

### Complete Working Example

```python
# ============================================================
# 🔍 PCOS Prediction + Grad-CAM (HF VERSION)
# ============================================================
import numpy as np
import cv2
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import Model, Input
from tensorflow.keras.layers import (
    Conv2D, MaxPooling2D, Flatten, Dense, Lambda,
    Reshape, Concatenate, MultiHeadAttention, GlobalAveragePooling1D
)
import requests
from huggingface_hub import hf_hub_download

# ============================================================
# Config
# ============================================================
IMG_SIZE = (224, 224)
HF_MODEL_REPO = "Dehsahk-AI/Pcos-Detect"
MODEL_FILENAME = "best_pcos_model.h5"
IMAGE_URL = "https://example.com/ultrasound.jpg"  # Your image URL
# "infected" = PCOS-positive, "noninfected" = healthy
CLASS_NAMES = ["infected", "noninfected"]

# ============================================================
# Download model from HF
# ============================================================
MODEL_PATH = hf_hub_download(repo_id=HF_MODEL_REPO, filename=MODEL_FILENAME)
print(f"Model downloaded to: {MODEL_PATH}")

# ============================================================
# Custom Lambda Functions
# ============================================================
def split_image(image):
    # Split along the height axis into top and bottom halves.
    upper = image[:, :IMG_SIZE[0] // 2, :, :]
    lower = image[:, IMG_SIZE[0] // 2:, :, :]
    return upper, lower

def flip_lower(lower_half):
    # Mirror the bottom half left-right before it enters the lower CNN.
    return tf.image.flip_left_right(lower_half)

# ============================================================
# Rebuild Model Architecture
# ============================================================
input_layer = Input(shape=(224, 224, 3))
upper_half, lower_half = Lambda(split_image)(input_layer)
lower_half = Lambda(flip_lower)(lower_half)

# Upper CNN
u = Conv2D(32, 3, activation="relu", padding="same")(upper_half)
u = MaxPooling2D(2)(u)
u = Conv2D(64, 3, activation="relu", padding="same")(u)
u = MaxPooling2D(2)(u)
u = Conv2D(128, 3, activation="relu", padding="same", name="upper_last_conv")(u)
u = MaxPooling2D(2)(u)
u = Flatten()(u)

# Lower CNN
l = Conv2D(32, 3, activation="relu", padding="same")(lower_half)
l = MaxPooling2D(2)(l)
l = Conv2D(64, 3, activation="relu", padding="same")(l)
l = MaxPooling2D(2)(l)
l = Conv2D(128, 3, activation="relu", padding="same", name="lower_last_conv")(l)
l = MaxPooling2D(2)(l)
l = Flatten()(l)

# Fuse the two paths: one 512-d token per half, then self-attention.
u_dense = Dense(512, activation="relu")(u)
l_dense = Dense(512, activation="relu")(l)
u_r = Reshape((1, 512))(u_dense)
l_r = Reshape((1, 512))(l_dense)
concat = Concatenate(axis=1)([u_r, l_r])
att = MultiHeadAttention(num_heads=4, key_dim=64)(concat, concat)
att = GlobalAveragePooling1D()(att)
fc = Dense(256, activation="relu")(att)
fc = Dense(128, activation="relu")(fc)

# Logits for Grad-CAM
logits = Dense(2, name="logits")(fc)
output = tf.keras.layers.Activation("softmax", name="softmax")(logits)

model = Model(input_layer, output)
model.load_weights(MODEL_PATH)
print("Weights loaded successfully")

# ============================================================
# Load & Preprocess Image
# ============================================================
response = requests.get(IMAGE_URL, timeout=30)
response.raise_for_status()  # fail early on HTTP errors
img_array_raw = np.frombuffer(response.content, dtype=np.uint8)
img = cv2.imdecode(img_array_raw, cv2.IMREAD_COLOR)
if img is None:
    raise ValueError(f"Could not decode an image from {IMAGE_URL}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, IMG_SIZE)
img = img.astype(np.float32) / 255.0
img_array = np.expand_dims(img, axis=0)

# ============================================================
# Prediction
# ============================================================
pred = model.predict(img_array, verbose=0)[0]
pred_class = int(np.argmax(pred))
confidence = pred[pred_class]
print(f"\nPrediction: {CLASS_NAMES[pred_class]}")
print(f"Confidence: {confidence:.2%}")

# ============================================================
# Grad-CAM
# ============================================================
def gradcam(img_array, model, layer_name, pred_index):
    logits_layer = model.get_layer("logits")
    grad_model = Model(
        model.input,
        [model.get_layer(layer_name).output, logits_layer.output]
    )
    with tf.GradientTape() as tape:
        conv_out, logits = grad_model(img_array)
        loss = logits[:, pred_index]  # pre-softmax score of the target class
    grads = tape.gradient(loss, conv_out)
    pooled = tf.reduce_mean(grads, axis=(0, 1, 2))  # one weight per channel
    conv_out = conv_out[0]
    heatmap = conv_out @ pooled[..., tf.newaxis]    # weighted sum of feature maps
    heatmap = tf.squeeze(heatmap)
    heatmap = tf.maximum(heatmap, 0)                # keep positive evidence only
    if tf.reduce_max(heatmap) > 0:
        heatmap /= tf.reduce_max(heatmap)
    return heatmap.numpy()

upper = gradcam(img_array, model, "upper_last_conv", pred_class)
lower = gradcam(img_array, model, "lower_last_conv", pred_class)

# Resize each half-map to half the image height, undo the flip, and stack.
h = IMG_SIZE[0] // 2
upper = cv2.resize(upper, (IMG_SIZE[1], h))
lower = cv2.resize(lower, (IMG_SIZE[1], h))
lower = cv2.flip(lower, 1)
heatmap = np.vstack([upper, lower])

heatmap_color = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB) / 255.0
overlay = 0.5 * heatmap_color + 0.5 * img

# ============================================================
# Visualization
# ============================================================
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.imshow(img)
plt.title("Original")
plt.axis("off")

plt.subplot(1, 3, 2)
plt.imshow(heatmap, cmap="jet")
plt.title("Grad-CAM")
plt.axis("off")

plt.subplot(1, 3, 3)
plt.imshow(overlay)
plt.title(f"{CLASS_NAMES[pred_class]} ({confidence:.2%})")
plt.axis("off")

plt.tight_layout()
plt.show()
```

### Load from Local File

```python
# Replace URL loading with:
img = cv2.imread('path/to/ultrasound.jpg')
if img is None:  # cv2.imread returns None instead of raising on a bad path
    raise FileNotFoundError("Could not read 'path/to/ultrasound.jpg'")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, IMG_SIZE)
img = img.astype(np.float32) / 255.0
img_array = np.expand_dims(img, axis=0)
```

## Understanding Grad-CAM Output

- **Red/Hot regions**: High importance for the prediction (follicles, cysts)
- **Blue/Cool regions**: Low influence on the decision
- **Dual visualization**: Separate heatmaps for the upper and lower halves of the image. Each half is normalized to its own maximum, so intensities are not directly comparable across the two halves.

## Model Architecture

```
Input (224×224×3)
├── Split horizontally (upper/lower)
├── Upper Path: Conv32 → Conv64 → Conv128 → Dense512
├── Lower Path: Conv32 → Conv64 → Conv128 → Dense512
├── Multi-Head Attention (4 heads, dim=64)
└── Classification: Dense256 → Dense128 → Dense2
```

**Key Features:**

- Dual-path CNN for separate ovarian region analysis
- Lower region flipped left-right for symmetry normalization
- Multi-head attention for feature fusion
- Logits-based Grad-CAM (gradients are taken from the pre-softmax logits, which avoids the vanishing gradients of a saturated softmax)

## Dataset

- **Total**: 11,784 ultrasound images
- **PCOS-positive**: 6,784 images (57.6%)
- **Healthy**: 5,000 images (42.4%)
- **Source**: 3 clinics (2018-2022), expert-annotated
- **Dataset**: [PCOS XAI Ultrasound](https://www.kaggle.com/datasets/ibadeus/pcos-xai-ultrasound-dataset)

## Important Notes

**Clinical Use:**

- Research purposes only
- NOT FDA approved
- Not a diagnostic tool - requires professional validation
- Must be validated on local datasets before clinical deployment

**Technical:**

- Fixed 224×224 input size required
- RGB images only
- Model performance may vary across different ultrasound machines

## Citation

```bibtex
@misc{pcos_xai_2024,
  title={PCOS Detection with Explainable AI},
  author={Dehsahk-AI},
  year={2025},
  url={https://huggingface.co/Dehsahk-AI/Pcos-Detect}
}
```

## License

MIT License - See LICENSE file for details.

---

**Model Version**: 1.0 | **Last Updated**: December 2025
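
## Appendix: Grad-CAM Arithmetic in NumPy

The `gradcam` function in the Quick Start mixes the Grad-CAM arithmetic with TensorFlow plumbing. The sketch below isolates just the arithmetic (channel weights from averaged gradients, weighted sum, ReLU, normalization) and the flip-and-stack step that joins the two half-maps. It is an illustration only: the function names `gradcam_heatmap` and `stitch_halves` and the toy arrays are hypothetical and are not part of the released model code.

```python
import numpy as np

def gradcam_heatmap(activations, gradients):
    """Grad-CAM map from one conv layer's activations and gradients, both (H, W, C)."""
    # Channel weights: the class-score gradient averaged over spatial positions.
    weights = gradients.mean(axis=(0, 1))             # shape (C,)
    # Weighted sum of feature maps, then ReLU to keep positive evidence only.
    heatmap = np.maximum(activations @ weights, 0.0)  # shape (H, W)
    peak = heatmap.max()
    # Normalize to [0, 1] unless the map is all zeros.
    return heatmap / peak if peak > 0 else heatmap

def stitch_halves(upper, lower):
    """Undo the left-right flip on the lower map and stack upper over lower."""
    return np.vstack([upper, lower[:, ::-1]])

# Toy inputs (hypothetical values, chosen so the result is easy to check by hand).
acts = np.zeros((2, 2, 2))
acts[0, 0, 0] = 4.0    # channel 0 fires at the top-left position
acts[1, 1, 1] = 2.0    # channel 1 fires at the bottom-right position
grads = np.zeros((2, 2, 2))
grads[..., 0] = 1.0    # channel 0 supports the class
grads[..., 1] = -1.0   # channel 1 opposes it, so ReLU removes its contribution

upper_map = gradcam_heatmap(acts, grads)
full_map = stitch_halves(upper_map, upper_map)
print(upper_map)
print(full_map.shape)
```

In the full pipeline each half-map is also resized to half the image height before stacking. That step is omitted here to keep the sketch free of OpenCV.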