Dehsahk-AI commited on
Commit
fbd9eea
Β·
verified Β·
1 Parent(s): cdc7620

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +295 -3
README.md CHANGED
@@ -1,3 +1,295 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ ---
3
+ license: mit
4
+ tags:
5
+ - medical-imaging
6
+ - pcos-detection
7
+ - explainable-ai
8
+ - grad-cam
9
+ - ultrasound
10
+ - tensorflow
11
+ language: en
12
+ metrics:
13
+ - accuracy
14
+ library_name: tensorflow
15
+ ---
16
+
17
+ # πŸ₯ PCOS Detection with Explainable AI
18
+
19
+ A deep learning model for **Polycystic Ovary Syndrome (PCOS)** detection from ultrasound images with **Grad-CAM** visualization for clinical interpretability.
20
+
21
+ ## 🎯 Model Overview
22
+
23
+ - **Architecture**: Dual-path CNN with multi-head attention
24
+ - **Input**: 224Γ—224 RGB ultrasound images
25
+ - **Output**: Binary classification (PCOS-positive / Healthy)
26
+ - **Accuracy**: ~95%+ on test set
27
+ - **XAI**: Grad-CAM heatmaps for interpretability
28
+
29
+ ## πŸš€ Quick Start
30
+
31
+ ```bash
32
+ pip install tensorflow opencv-python matplotlib numpy requests huggingface-hub
33
+ ```
34
+
35
+ ### Complete Working Example
36
+
37
+ ```python
38
+ # ============================================================
39
+ # πŸ” PCOS Prediction + Grad-CAM (HF VERSION)
40
+ # ============================================================
41
+
42
+ import numpy as np
43
+ import cv2
44
+ import tensorflow as tf
45
+ import matplotlib.pyplot as plt
46
+ from tensorflow.keras import Model, Input
47
+ from tensorflow.keras.layers import (
48
+ Conv2D, MaxPooling2D, Flatten, Dense,
49
+ Lambda, Reshape, Concatenate,
50
+ MultiHeadAttention, GlobalAveragePooling1D
51
+ )
52
+ import requests
53
+ from huggingface_hub import hf_hub_download
54
+
55
+ # ============================================================
56
+ # Config
57
+ # ============================================================
58
+ IMG_SIZE = (224, 224)
59
+ HF_MODEL_REPO = "Dehsahk-AI/Pcos-Detect"
60
+ MODEL_FILENAME = "best_pcos_model.h5"
61
+ IMAGE_URL = "https://example.com/ultrasound.jpg" # Your image URL
62
+ CLASS_NAMES = ["infected", "noninfected"]
63
+
64
+ # ============================================================
65
+ # Download model from HF
66
+ # ============================================================
67
+ MODEL_PATH = hf_hub_download(repo_id=HF_MODEL_REPO, filename=MODEL_FILENAME)
68
+ print(f"βœ… Model downloaded to: {MODEL_PATH}")
69
+
70
+ # ============================================================
71
+ # Custom Lambda Functions
72
+ # ============================================================
73
+ def split_image(image):
74
+ upper = image[:, :IMG_SIZE[0]//2, :, :]
75
+ lower = image[:, IMG_SIZE[0]//2:, :, :]
76
+ return upper, lower
77
+
78
+ def flip_lower(lower_half):
79
+ return tf.image.flip_left_right(lower_half)
80
+
81
+ # ============================================================
82
+ # Rebuild Model Architecture
83
+ # ============================================================
84
+ input_layer = Input(shape=(224,224,3))
85
+
86
+ upper_half, lower_half = Lambda(split_image)(input_layer)
87
+ lower_half = Lambda(flip_lower)(lower_half)
88
+
89
+ # Upper CNN
90
+ u = Conv2D(32, 3, activation="relu", padding="same")(upper_half)
91
+ u = MaxPooling2D(2)(u)
92
+ u = Conv2D(64, 3, activation="relu", padding="same")(u)
93
+ u = MaxPooling2D(2)(u)
94
+ u = Conv2D(128, 3, activation="relu", padding="same", name="upper_last_conv")(u)
95
+ u = MaxPooling2D(2)(u)
96
+ u = Flatten()(u)
97
+
98
+ # Lower CNN
99
+ l = Conv2D(32, 3, activation="relu", padding="same")(lower_half)
100
+ l = MaxPooling2D(2)(l)
101
+ l = Conv2D(64, 3, activation="relu", padding="same")(l)
102
+ l = MaxPooling2D(2)(l)
103
+ l = Conv2D(128, 3, activation="relu", padding="same", name="lower_last_conv")(l)
104
+ l = MaxPooling2D(2)(l)
105
+ l = Flatten()(l)
106
+
107
+ u_dense = Dense(512, activation="relu")(u)
108
+ l_dense = Dense(512, activation="relu")(l)
109
+
110
+ u_r = Reshape((1,512))(u_dense)
111
+ l_r = Reshape((1,512))(l_dense)
112
+
113
+ concat = Concatenate(axis=1)([u_r, l_r])
114
+
115
+ att = MultiHeadAttention(num_heads=4, key_dim=64)(concat, concat)
116
+ att = GlobalAveragePooling1D()(att)
117
+
118
+ fc = Dense(256, activation="relu")(att)
119
+ fc = Dense(128, activation="relu")(fc)
120
+
121
+ # Logits for Grad-CAM
122
+ logits = Dense(2, name="logits")(fc)
123
+ output = tf.keras.layers.Activation('softmax', name='softmax')(logits)
124
+
125
+ model = Model(input_layer, output)
126
+ model.load_weights(MODEL_PATH)
127
+ print("βœ… Weights loaded successfully")
128
+
129
+ # ============================================================
130
+ # Load & Preprocess Image
131
+ # ============================================================
132
+ response = requests.get(IMAGE_URL)
133
+ img_array_raw = np.asarray(bytearray(response.content), dtype=np.uint8)
134
+ img = cv2.imdecode(img_array_raw, cv2.IMREAD_COLOR)
135
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
136
+ img = cv2.resize(img, IMG_SIZE)
137
+ img = img.astype(np.float32) / 255.0
138
+ img_array = np.expand_dims(img, axis=0)
139
+
140
+ # ============================================================
141
+ # Prediction
142
+ # ============================================================
143
+ pred = model.predict(img_array, verbose=0)[0]
144
+ pred_class = np.argmax(pred)
145
+ confidence = pred[pred_class]
146
+
147
+ print(f"\nπŸ” Prediction: {CLASS_NAMES[pred_class]}")
148
+ print(f"πŸ“Š Confidence: {confidence:.2%}")
149
+
150
+ # ============================================================
151
+ # Grad-CAM
152
+ # ============================================================
153
+ def gradcam(img_array, model, layer_name, pred_index):
154
+ logits_layer = model.get_layer('logits')
155
+ grad_model = Model(
156
+ model.input,
157
+ [model.get_layer(layer_name).output, logits_layer.output]
158
+ )
159
+
160
+ with tf.GradientTape() as tape:
161
+ conv_out, logits = grad_model(img_array)
162
+ loss = logits[:, pred_index]
163
+
164
+ grads = tape.gradient(loss, conv_out)
165
+ pooled = tf.reduce_mean(grads, axis=(0,1,2))
166
+ conv_out = conv_out[0]
167
+
168
+ heatmap = conv_out @ pooled[..., tf.newaxis]
169
+ heatmap = tf.squeeze(heatmap)
170
+ heatmap = tf.maximum(heatmap, 0)
171
+
172
+ if tf.reduce_max(heatmap) > 0:
173
+ heatmap /= tf.reduce_max(heatmap)
174
+
175
+ return heatmap.numpy()
176
+
177
+ upper = gradcam(img_array, model, "upper_last_conv", pred_class)
178
+ lower = gradcam(img_array, model, "lower_last_conv", pred_class)
179
+
180
+ h = IMG_SIZE[0] // 2
181
+ upper = cv2.resize(upper, (IMG_SIZE[1], h))
182
+ lower = cv2.resize(lower, (IMG_SIZE[1], h))
183
+ lower = cv2.flip(lower, 1)
184
+
185
+ heatmap = np.vstack([upper, lower])
186
+
187
+ heatmap_color = cv2.applyColorMap(np.uint8(255*heatmap), cv2.COLORMAP_JET)
188
+ heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB) / 255.0
189
+
190
+ overlay = 0.5 * heatmap_color + 0.5 * img
191
+
192
+ # ============================================================
193
+ # Visualization
194
+ # ============================================================
195
+ plt.figure(figsize=(15,5))
196
+
197
+ plt.subplot(1,3,1)
198
+ plt.imshow(img)
199
+ plt.title("Original")
200
+ plt.axis("off")
201
+
202
+ plt.subplot(1,3,2)
203
+ plt.imshow(heatmap, cmap="jet")
204
+ plt.title("Grad-CAM")
205
+ plt.axis("off")
206
+
207
+ plt.subplot(1,3,3)
208
+ plt.imshow(overlay)
209
+ plt.title(f"{CLASS_NAMES[pred_class]} ({confidence:.2%})")
210
+ plt.axis("off")
211
+
212
+ plt.tight_layout()
213
+ plt.show()
214
+ ```
215
+
216
+ ### Load from Local File
217
+
218
+ ```python
219
+ # Replace URL loading with:
220
+ img = cv2.imread('path/to/ultrasound.jpg')
221
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
222
+ img = cv2.resize(img, IMG_SIZE)
223
+ img = img.astype(np.float32) / 255.0
224
+ img_array = np.expand_dims(img, axis=0)
225
+ ```
226
+
227
+ ## πŸ”¬ Understanding Grad-CAM Output
228
+
229
+ - **Red/Hot regions**: High importance for prediction (follicles, cysts)
230
+ - **Blue/Cool regions**: Low influence on decision
231
+ - **Dual visualization**: Separate heatmaps for upper and lower ovarian regions
232
+
233
+ ## πŸ“Š Model Architecture
234
+
235
+ ```
236
+ Input (224Γ—224Γ—3)
237
+ β”œβ”€β”€ Split horizontally (upper/lower)
238
+ β”œβ”€β”€ Upper Path: Conv32 β†’ Conv64 β†’ Conv128 β†’ Dense512
239
+ β”œβ”€β”€ Lower Path: Conv32 β†’ Conv64 β†’ Conv128 β†’ Dense512
240
+ β”œβ”€β”€ Multi-Head Attention (4 heads, dim=64)
241
+ └── Classification: Dense256 β†’ Dense128 β†’ Dense2
242
+ ```
243
+
244
+ **Key Features:**
245
+ - Dual-path CNN for separate ovarian region analysis
246
+ - Lower region flipped for symmetry normalization
247
+ - Multi-head attention for feature fusion
248
+ - Logits-based Grad-CAM (fixes saturated softmax gradients)
249
+
250
+ ## πŸ“ˆ Dataset
251
+
252
+ - **Total**: 11,784 ultrasound images
253
+ - **PCOS-positive**: 6,784 images (57.5%)
254
+ - **Healthy**: 5,000 images (42.5%)
255
+ - **Source**: 3 clinics (2018-2022), expert-annotated
256
+ - **Dataset**: [PCOS XAI Ultrasound](https://www.kaggle.com/datasets/...)
257
+
258
+ ## ⚠️ Important Notes
259
+
260
+ **Clinical Use:**
261
+ - ⚠️ Research purposes only - NOT FDA approved
262
+ - ⚠️ Not a diagnostic tool - requires professional validation
263
+ - ⚠️ Must be validated on local datasets before clinical deployment
264
+
265
+ **Technical:**
266
+ - Fixed 224Γ—224 input size required
267
+ - RGB images only
268
+ - Model performance may vary across different ultrasound machines
269
+
270
+ ## πŸ“ Citation
271
+
272
+ ```bibtex
273
+ @misc{pcos_xai_2024,
274
+ title={PCOS Detection with Explainable AI},
275
+ author={Dehsahk-AI},
276
+ year={2024},
277
+ url={https://huggingface.co/Dehsahk-AI/Pcos-Detect}
278
+ }
279
+ ```
280
+
281
+ ## πŸ“œ License
282
+
283
+ MIT License - See LICENSE file for details.
284
+
285
+ ## πŸ™ Acknowledgments
286
+
287
+ - Grad-CAM: Selvaraju et al. (ICCV 2017)
288
+ - Multi-head Attention: Vaswani et al. (NeurIPS 2017)
289
+ - Dataset from clinical retrospective studies with ethical compliance
290
+
291
+ ---
292
+
293
+ **Model Version**: 1.0 | **Last Updated**: December 2024
294
+ license: apache-2.0
295
+ ---