---
license: mit
tags:
- medical-imaging
- pcos-detection
- explainable-ai
- grad-cam
- ultrasound
- image-classification
- pcos
language: en
metrics:
- accuracy
library_name: tensorflow
---
# PCOS Detection with Explainable AI
A deep learning model for **Polycystic Ovary Syndrome (PCOS)** detection from ultrasound images with **Grad-CAM** visualization for clinical interpretability.
## Model Overview
- **Architecture**: Dual-path CNN with multi-head attention
- **Input**: 224×224 RGB ultrasound images
- **Output**: Binary classification (PCOS-positive / Healthy)
- **Accuracy**: ~95%+ on test set
- **XAI**: Grad-CAM heatmaps for interpretability
## 🚀 Quick Start
```bash
pip install tensorflow opencv-python matplotlib numpy requests huggingface-hub
```
### Complete Working Example
```python
# ============================================================
# πŸ” PCOS Prediction + Grad-CAM (HF VERSION)
# ============================================================
import numpy as np
import cv2
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import Model, Input
from tensorflow.keras.layers import (
    Conv2D, MaxPooling2D, Flatten, Dense,
    Lambda, Reshape, Concatenate,
    MultiHeadAttention, GlobalAveragePooling1D,
)
import requests
from huggingface_hub import hf_hub_download
# ============================================================
# Config
# ============================================================
IMG_SIZE = (224, 224)
HF_MODEL_REPO = "Dehsahk-AI/Pcos-Detect"
MODEL_FILENAME = "best_pcos_model.h5"
IMAGE_URL = "https://example.com/ultrasound.jpg" # Your image URL
CLASS_NAMES = ["infected", "noninfected"]  # index 0 = PCOS-positive, index 1 = healthy
# ============================================================
# Download model from HF
# ============================================================
MODEL_PATH = hf_hub_download(repo_id=HF_MODEL_REPO, filename=MODEL_FILENAME)
print(f" Model downloaded to: {MODEL_PATH}")
# ============================================================
# Custom Lambda Functions
# ============================================================
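def split_image(image):
    # Split along the height axis into upper and lower halves (112 rows each)
    upper = image[:, :IMG_SIZE[0]//2, :, :]
    lower = image[:, IMG_SIZE[0]//2:, :, :]
    return upper, lower

def flip_lower(lower_half):
    # Mirror the lower half left-right before it enters its own CNN path
    return tf.image.flip_left_right(lower_half)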
# ============================================================
# Rebuild Model Architecture
# ============================================================
input_layer = Input(shape=(224,224,3))
upper_half, lower_half = Lambda(split_image)(input_layer)
lower_half = Lambda(flip_lower)(lower_half)
# Upper CNN
u = Conv2D(32, 3, activation="relu", padding="same")(upper_half)
u = MaxPooling2D(2)(u)
u = Conv2D(64, 3, activation="relu", padding="same")(u)
u = MaxPooling2D(2)(u)
u = Conv2D(128, 3, activation="relu", padding="same", name="upper_last_conv")(u)
u = MaxPooling2D(2)(u)
u = Flatten()(u)
# Lower CNN
l = Conv2D(32, 3, activation="relu", padding="same")(lower_half)
l = MaxPooling2D(2)(l)
l = Conv2D(64, 3, activation="relu", padding="same")(l)
l = MaxPooling2D(2)(l)
l = Conv2D(128, 3, activation="relu", padding="same", name="lower_last_conv")(l)
l = MaxPooling2D(2)(l)
l = Flatten()(l)
u_dense = Dense(512, activation="relu")(u)
l_dense = Dense(512, activation="relu")(l)
u_r = Reshape((1,512))(u_dense)
l_r = Reshape((1,512))(l_dense)
concat = Concatenate(axis=1)([u_r, l_r])
att = MultiHeadAttention(num_heads=4, key_dim=64)(concat, concat)
att = GlobalAveragePooling1D()(att)
fc = Dense(256, activation="relu")(att)
fc = Dense(128, activation="relu")(fc)
# Logits for Grad-CAM
logits = Dense(2, name="logits")(fc)
output = tf.keras.layers.Activation('softmax', name='softmax')(logits)
model = Model(input_layer, output)
model.load_weights(MODEL_PATH)
print(" Weights loaded successfully")
# ============================================================
# Load & Preprocess Image
# ============================================================
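response = requests.get(IMAGE_URL, timeout=30)
response.raise_for_status()  # fail early on HTTP errors
img_array_raw = np.asarray(bytearray(response.content), dtype=np.uint8)
img = cv2.imdecode(img_array_raw, cv2.IMREAD_COLOR)  # 3-channel BGR, or None on failure
if img is None:
    raise ValueError(f"Could not decode an image from {IMAGE_URL}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, IMG_SIZE)
img = img.astype(np.float32) / 255.0  # scale pixels to [0, 1]
img_array = np.expand_dims(img, axis=0)  # add batch dim -> (1, 224, 224, 3)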
# ============================================================
# Prediction
# ============================================================
pred = model.predict(img_array, verbose=0)[0]
pred_class = np.argmax(pred)
confidence = pred[pred_class]
print(f"\n Prediction: {CLASS_NAMES[pred_class]}")
print(f" Confidence: {confidence:.2%}")
# ============================================================
# Grad-CAM
# ============================================================
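def gradcam(img_array, model, layer_name, pred_index):
    # Sub-model returning the target conv activations and the pre-softmax logits
    logits_layer = model.get_layer("logits")
    grad_model = Model(
        model.input,
        [model.get_layer(layer_name).output, logits_layer.output],
    )
    with tf.GradientTape() as tape:
        conv_out, logits = grad_model(img_array)
        loss = logits[:, pred_index]  # class score before softmax (must be inside the tape)
    grads = tape.gradient(loss, conv_out)
    # Channel weights: gradients averaged over batch and spatial dimensions
    pooled = tf.reduce_mean(grads, axis=(0, 1, 2))
    conv_out = conv_out[0]
    # Weighted sum of the feature maps -> (H, W) heatmap
    heatmap = conv_out @ pooled[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)
    heatmap = tf.maximum(heatmap, 0)  # ReLU: keep positive evidence only
    if tf.reduce_max(heatmap) > 0:
        heatmap /= tf.reduce_max(heatmap)
    return heatmap.numpy()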
upper = gradcam(img_array, model, "upper_last_conv", pred_class)
lower = gradcam(img_array, model, "lower_last_conv", pred_class)
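h = IMG_SIZE[0] // 2
# cv2.resize takes (width, height): upsample each 28x56 map to 112 rows x 224 columns
upper = cv2.resize(upper, (IMG_SIZE[1], h))
lower = cv2.resize(lower, (IMG_SIZE[1], h))
# Undo the left-right flip that the model applied to the lower half
lower = cv2.flip(lower, 1)
heatmap = np.vstack([upper, lower])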
heatmap_color = cv2.applyColorMap(np.uint8(255*heatmap), cv2.COLORMAP_JET)
heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB) / 255.0
overlay = 0.5 * heatmap_color + 0.5 * img
# ============================================================
# Visualization
# ============================================================
plt.figure(figsize=(15,5))
plt.subplot(1,3,1)
plt.imshow(img)
plt.title("Original")
plt.axis("off")
plt.subplot(1,3,2)
plt.imshow(heatmap, cmap="jet")
plt.title("Grad-CAM")
plt.axis("off")
plt.subplot(1,3,3)
plt.imshow(overlay)
plt.title(f"{CLASS_NAMES[pred_class]} ({confidence:.2%})")
plt.axis("off")
plt.tight_layout()
plt.show()
```
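The core of the `gradcam` function above is the standard Grad-CAM weighting: average the gradients of the class logit over each channel's spatial grid, use those averages as channel weights, sum the weighted feature maps, and apply a ReLU. The NumPy sketch below reproduces only that arithmetic on plain arrays so it can be checked without TensorFlow; `gradcam_from_arrays` is a hypothetical helper, not part of the model repository.

```python
import numpy as np

def gradcam_from_arrays(conv_out: np.ndarray, grads: np.ndarray) -> np.ndarray:
    """Grad-CAM map for one image.

    conv_out, grads: (H, W, K) activations of the target conv layer and the
    gradients of the class logit with respect to them (batch dim removed).
    """
    # alpha_k: one weight per channel, the gradient averaged over H and W
    weights = grads.mean(axis=(0, 1))       # shape (K,)
    # Weighted sum of the K activation maps -> (H, W)
    heatmap = conv_out @ weights
    # ReLU keeps only regions that push the class logit up
    heatmap = np.maximum(heatmap, 0.0)
    peak = heatmap.max()
    if peak > 0:
        heatmap = heatmap / peak            # normalise to [0, 1]
    return heatmap

# Tiny hand-checkable case: channel 0 supports the class, channel 1 opposes it
A = np.zeros((2, 2, 2))
A[0, 0, 0] = 1.0   # channel 0 active at the top-left cell
A[1, 1, 1] = 2.0   # channel 1 active at the bottom-right cell
G = np.stack([np.ones((2, 2)), -np.ones((2, 2))], axis=-1)
print(gradcam_from_arrays(A, G))
```

Here the channel weights are `[1, -1]`, so the top-left cell scores 1, the bottom-right cell scores -2 and is clipped by the ReLU, and the printed map is 1 at the top-left and 0 elsewhere.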
### Load from Local File
```python
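# Replace the "Load & Preprocess Image" block with:
img = cv2.imread("path/to/ultrasound.jpg")  # returns None if the file cannot be read
if img is None:
    raise FileNotFoundError("Could not read path/to/ultrasound.jpg")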
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, IMG_SIZE)
img = img.astype(np.float32) / 255.0
img_array = np.expand_dims(img, axis=0)
```
## Understanding Grad-CAM Output
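- **Red/hot regions**: strongest positive influence on the predicted class (expected to highlight follicles and cysts)
- **Blue/cool regions**: little or no influence on the decision
- **Dual visualization**: one heatmap per CNN path (upper and lower image halves), stacked into a single 224×224 map after the lower map is flipped back to the original orientation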
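The stitching step in the example (resize each half-map, flip the lower one back, stack vertically) can be sketched in NumPy alone. `stitch_half_heatmaps` is a hypothetical helper that swaps `cv2.resize`'s bilinear interpolation for nearest-neighbour upsampling, and it assumes each half-map is 28×56, the output size of `upper_last_conv`/`lower_last_conv` for a 224×224 input.

```python
import numpy as np

IMG_SIZE = (224, 224)  # (height, width), as in the example

def stitch_half_heatmaps(upper: np.ndarray, lower: np.ndarray) -> np.ndarray:
    """Combine the two per-path Grad-CAM maps into one full-size heatmap."""
    half_h, full_w = IMG_SIZE[0] // 2, IMG_SIZE[1]

    def upsample(m: np.ndarray) -> np.ndarray:
        # Nearest-neighbour upsampling by integer factors
        # (the example uses cv2.resize, which interpolates bilinearly)
        if half_h % m.shape[0] or full_w % m.shape[1]:
            raise ValueError(f"map shape {m.shape} does not divide {(half_h, full_w)}")
        fy, fx = half_h // m.shape[0], full_w // m.shape[1]
        return np.repeat(np.repeat(m, fy, axis=0), fx, axis=1)

    top = upsample(upper)
    # The lower half was mirrored left-right before its CNN path,
    # so mirror its heatmap back to line up with the original image
    bottom = np.fliplr(upsample(lower))
    return np.vstack([top, bottom])

upper = np.zeros((28, 56))
lower = np.zeros((28, 56))
lower[0, 0] = 1.0               # hot cell at the left edge of the flipped lower half
full = stitch_half_heatmaps(upper, lower)
print(full.shape)               # (224, 224)
print(full[112, 223])           # 1.0
```

The hot cell from the left edge of the model's (flipped) lower input ends up on the right edge of the lower half in the stitched map, which is where it sits in the original image.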
## Model Architecture
```
Input (224×224×3)
├── Split along height into upper/lower halves (112×224×3 each)
├── Upper Path: Conv32 → Pool → Conv64 → Pool → Conv128 → Pool → Flatten → Dense512
├── Lower Path: Flip left-right → Conv32 → Pool → Conv64 → Pool → Conv128 → Pool → Flatten → Dense512
├── Multi-Head Attention (4 heads, key_dim=64) over the 2 path embeddings → Global Average Pooling
└── Classification: Dense256 → Dense128 → Dense2 (logits) → Softmax
```
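The spatial sizes in each path follow from the layer settings in the example code (3×3 convolutions with `padding="same"`, each followed by `MaxPooling2D(2)`). A small sketch that traces them, assuming each path receives one 112×224 half of the input:

```python
def conv_path_shapes(height=112, width=224, filters=(32, 64, 128), pool=2):
    """Trace (layer, H, W, C) through one path of Conv2D('same') + MaxPooling2D(2)."""
    shapes = []
    h, w = height, width
    for f in filters:
        # padding="same" with stride 1 keeps H and W; only the channel count changes
        shapes.append(("conv", h, w, f))
        # 2x2 max-pooling (stride 2, "valid" padding) halves each spatial dim
        h, w = h // pool, w // pool
        shapes.append(("pool", h, w, f))
    return shapes

shapes = conv_path_shapes()          # one half of the 224x224 input
for layer, h, w, c in shapes:
    print(f"{layer:>4}: {h}x{w}x{c}")

_, h, w, c = shapes[-1]
flat = h * w * c                     # size of the Flatten() output
dense512_params = flat * 512 + 512   # Dense(512) weights + biases
print("Flatten size:", flat)                               # 50176
print("Dense512 parameters per path:", dense512_params)    # 25690624
```

The 28×56×128 entry is the output of `upper_last_conv`/`lower_last_conv`, which is why each Grad-CAM half-map has 28×56 cells before resizing. The two Dense512 layers together hold about 51.4 M parameters, most of the network's weights.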
**Key Features:**
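- Dual-path CNN that analyzes the upper and lower image halves separately
- Lower half mirrored left-right before its CNN path (symmetry normalization)
- Multi-head attention fuses the two path embeddings
- Grad-CAM computed on the pre-softmax logits, because gradients through a saturated softmax become vanishingly small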
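Why logits rather than softmax probabilities: for a two-class softmax, every gradient flowing back from the output probability is multiplied by p(1 - p), which collapses toward zero when the model is confident. A short NumPy illustration with hypothetical logit values:

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    e = np.exp(z - z.max())   # subtract the max for numerical stability
    return e / e.sum()

logits = np.array([12.0, 0.0])   # hypothetical, very confident two-class output
p = softmax(logits)

# For two classes, each derivative of p[0] w.r.t. the logits has magnitude p0 * (1 - p0),
# so gradients back-propagated from the softmax output are scaled by this factor
softmax_scale = p[0] * (1.0 - p[0])
# The derivative of the logit with respect to itself is 1, whatever the confidence
logit_scale = 1.0

print(f"p0 = {p[0]:.6f}")                        # p0 = 0.999994
print(f"softmax scale = {softmax_scale:.1e}")    # softmax scale = 6.1e-06
```

At this confidence the softmax path shrinks the Grad-CAM gradients by roughly five orders of magnitude, while the logit path leaves them at full scale.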
## Dataset
- **Total**: 11,784 ultrasound images
- **PCOS-positive**: 6,784 images (57.6%)
- **Healthy**: 5,000 images (42.4%)
- **Source**: 3 clinics (2018-2022), expert-annotated
- **Dataset**: [PCOS XAI Ultrasound](https://www.kaggle.com/datasets/ibadeus/pcos-xai-ultrasound-dataset)
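The class shares follow directly from the counts above. This sketch recomputes them along with the accuracy of a trivial majority-class predictor, a useful reference point for the reported ~95% accuracy. It assumes the test split keeps the dataset's overall class ratio, which this card does not state.

```python
counts = {"PCOS-positive": 6784, "Healthy": 5000}
total = sum(counts.values())
shares = {name: n / total for name, n in counts.items()}

print(f"Total: {total:,}")
for name, share in shares.items():
    print(f"{name}: {share:.1%}")

# Accuracy of always predicting the larger class, assuming the test split
# has the same class ratio as the full dataset (not documented here)
majority_baseline = max(shares.values())
print(f"Majority-class baseline: {majority_baseline:.1%}")
```

This prints a total of 11,784 images, shares of 57.6% and 42.4%, and a majority-class baseline of 57.6%.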
## Important Notes
**Clinical Use:**
- Research purposes only - NOT FDA approved
- Not a diagnostic tool - requires professional validation
- Must be validated on local datasets before clinical deployment
**Technical:**
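- Fixed 224×224 input size required; resize images before inference
- 3-channel RGB input scaled to [0, 1], as in the example; grayscale ultrasound images must be expanded to 3 channels (`cv2.IMREAD_COLOR` does this when decoding)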
- Model performance may vary across different ultrasound machines
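A hypothetical helper, `to_model_batch`, that enforces these input requirements before calling `model.predict`. It assumes the image has already been resized to 224×224 (for example with `cv2.resize`) and that the model expects [0, 1] scaling, as the example's `/ 255.0` step implies.

```python
import numpy as np

EXPECTED_HW = (224, 224)  # fixed input size from this model card

def to_model_batch(img: np.ndarray) -> np.ndarray:
    """Turn an already-resized image into a (1, 224, 224, 3) float32 batch in [0, 1]."""
    if img.ndim == 2:                               # grayscale (H, W): replicate into 3 channels
        img = np.stack([img] * 3, axis=-1)
    elif img.ndim == 3 and img.shape[-1] == 1:      # (H, W, 1): same treatment
        img = np.repeat(img, 3, axis=-1)
    if img.shape != (*EXPECTED_HW, 3):
        raise ValueError(f"expected {EXPECTED_HW + (3,)}, got {img.shape}; resize first")
    if img.dtype == np.uint8:
        img = img.astype(np.float32) / 255.0        # same scaling as the example
    else:
        img = img.astype(np.float32)
    if img.min() < 0.0 or img.max() > 1.0:
        raise ValueError("pixel values must be in [0, 1]")
    return img[np.newaxis, ...]                     # add batch dimension
```

Replicating a grayscale image into three identical channels matches what OpenCV produces when it decodes a grayscale file with `cv2.IMREAD_COLOR`. Inputs of other integer types (for example 16-bit images) are rejected by the range check rather than silently rescaled.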
## Citation
```bibtex
@misc{pcos_xai_2024,
title={PCOS Detection with Explainable AI},
author={Dehsahk-AI},
year={2025},
url={https://huggingface.co/Dehsahk-AI/Pcos-Detect}
}
```
## License
MIT License - See LICENSE file for details.
---
**Model Version**: 1.0 | **Last Updated**: December 2025