File size: 12,577 Bytes
f255e67
fba4e48
 
f255e67
fba4e48
f255e67
 
 
 
 
fba4e48
f255e67
 
 
 
 
 
 
 
 
 
 
 
 
4c590a1
 
f255e67
 
 
fba4e48
 
 
f255e67
 
b0364d3
 
f255e67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d87f7c0
 
f255e67
 
ca4908b
d87f7c0
f255e67
 
ca4908b
d87f7c0
f255e67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c590a1
f255e67
 
 
4c590a1
f255e67
4c590a1
f255e67
 
 
 
 
 
 
4c590a1
f255e67
 
 
 
 
 
 
4c590a1
 
 
 
f255e67
 
 
 
 
 
 
 
 
 
 
 
 
4c590a1
f255e67
 
4c590a1
 
 
 
f255e67
 
4c590a1
f255e67
 
4c590a1
 
f255e67
4c590a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f255e67
 
 
 
4c590a1
f255e67
 
fba4e48
f255e67
fba4e48
f255e67
fba4e48
 
 
 
 
f255e67
fba4e48
f255e67
fba4e48
 
 
f255e67
fba4e48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f255e67
fba4e48
 
f255e67
 
fba4e48
f255e67
4c590a1
fba4e48
 
 
 
f255e67
4c590a1
f255e67
fba4e48
4c590a1
fba4e48
f255e67
fba4e48
 
 
 
 
f255e67
fba4e48
 
 
 
 
 
 
 
 
4c590a1
 
 
 
fba4e48
 
 
 
 
 
 
4c590a1
fba4e48
4c590a1
fba4e48
 
 
4c590a1
f255e67
 
fba4e48
f255e67
fba4e48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca4908b
9466b22
 
fba4e48
 
 
4c590a1
 
fba4e48
 
 
 
 
 
 
 
4c590a1
 
 
 
fba4e48
 
 
 
 
4c590a1
9466b22
 
fba4e48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f255e67
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
"""

Project Phoenix - Cervical Cancer Cell Classification

Gradio application for running inference on ConvNeXt V2 model from Hugging Face

with explainability features (GRAD-CAM).

Deployed on Hugging Face Spaces.

"""

import os
import numpy as np
import cv2
from typing import Dict, Tuple, Optional

# Deep Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

# Transformers
from transformers import (
    ConvNextV2ForImageClassification,
    AutoImageProcessor
)

# GRAD-CAM variants
from pytorch_grad_cam import GradCAM, GradCAMPlusPlus, LayerCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# Gradio
import gradio as gr

# ========== CONFIGURATION ==========

# Model directory - the model files live in the root directory of the Space,
# next to this script. abspath() makes the result independent of the current
# working directory: a relative __file__ (e.g. "app.py") would give "" here.
MODEL_DIR = os.path.dirname(os.path.abspath(__file__))

# Raw class names; presumably index-aligned with the model's output logits
# (TODO confirm against model.config.id2label).
CLASS_NAMES = [
    'im_Dyskeratotic',
    'im_Koilocytotic',
    'im_Metaplastic',
    'im_Parabasal',
    'im_Superficial-Intermediate'
]

# Display names (cleaner for UI): CLASS_NAMES without the 'im_' prefix, in the
# same order. Derived rather than duplicated so the two lists cannot drift apart.
_CLASS_PREFIX = 'im_'
DISPLAY_NAMES = [
    name[len(_CLASS_PREFIX):] if name.startswith(_CLASS_PREFIX) else name
    for name in CLASS_NAMES
]

# Device: use the GPU when one is available, otherwise the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ========== MODEL LOADING ==========

print("Loading model from local directory...")
print(f"Model directory: {MODEL_DIR}")
print(f"Device: {DEVICE}")

# Load the image processor saved alongside the model weights in MODEL_DIR.
processor = AutoImageProcessor.from_pretrained(MODEL_DIR)
print("✓ Processor loaded")

# Load the classifier, move it to DEVICE and switch it to evaluation mode
# for inference.
model = ConvNextV2ForImageClassification.from_pretrained(MODEL_DIR)
model = model.to(DEVICE)
model.eval()
print("✓ Model loaded and set to evaluation mode")

print("Model configuration:")
print(f"  - Number of classes: {model.config.num_labels}")
print(f"  - Image size: {model.config.image_size}")
print(f"  - Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# The prediction functions index the model's probabilities by position in
# DISPLAY_NAMES, so surface a mismatch at startup instead of at request time.
if model.config.num_labels != len(DISPLAY_NAMES):
    print(
        f"WARNING: model has {model.config.num_labels} labels but "
        f"{len(DISPLAY_NAMES)} display names are configured"
    )

# ========== HELPER FUNCTIONS ==========

def preprocess_image(image: Image.Image) -> Tuple[torch.Tensor, np.ndarray]:
    """
    Preprocess image for model input.

    The image is converted to 3-channel RGB once, and that same RGB image is
    used both for the model input and for the visualization copy. Grayscale
    ("L"), palette ("P") or RGBA uploads therefore reach the processor as
    3-channel images rather than in their original mode.

    Args:
        image: PIL image in any mode.

    Returns:
        Tuple of (preprocessed_tensor on DEVICE, original RGB image as a numpy array)
    """
    # Normalize the channel layout before anything else sees the image.
    rgb_image = image.convert('RGB')

    # Store original for visualization
    original_image = np.array(rgb_image)

    # Preprocess using the model's processor
    inputs = processor(images=rgb_image, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(DEVICE)

    return pixel_values, original_image


class ConvNeXtGradCAMWrapper(nn.Module):
    """Expose a Transformers ConvNeXt V2 classifier as a tensor-in / logits-out module.

    The wrapped model returns an output object; the CAM methods are given this
    wrapper instead, whose forward() hands back the ``logits`` tensor directly.
    """

    def __init__(self, model):
        super().__init__()
        # Assigning an nn.Module attribute registers it as a child submodule.
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` (as ``pixel_values``) and return its logits."""
        return self.model(pixel_values=x).logits


def get_target_layers(model):
    """Return the list of layers the CAM methods should hook.

    Picks the last layer of the last encoder stage of the ``convnextv2``
    backbone, i.e. the deepest layer of the encoder.
    """
    encoder_stages = model.convnextv2.encoder.stages
    deepest_layer = encoder_stages[-1].layers[-1]
    return [deepest_layer]


def apply_cam_methods(
    pixel_values: torch.Tensor,
    original_image: np.ndarray,
    target_class: Optional[int] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int, float]:
    """
    Apply GRAD-CAM, GRAD-CAM++, and LayerCAM to visualize model attention.

    Each CAM method runs in its own ``with`` block. Only one method's hooks are
    attached to the target layer at a time, and they are detached as soon as
    that method finishes. This avoids keeping all three CAM objects (and their
    hooks) alive together and relying on garbage collection to remove them.

    Args:
        pixel_values: Preprocessed image tensor; only the first item of the
            batch is visualized.
        original_image: Original image as numpy array; it is resized to the
            heatmap size for the overlay.
        target_class: Target class index (None for predicted class). Negative
            indices count from the end, as in normal tensor indexing.

    Returns:
        Tuple of (gradcam_viz, gradcam_pp_viz, layercam_viz, predicted_class, confidence),
        where confidence is the softmax probability of the *predicted* class.

    Raises:
        ValueError: If target_class is outside the model's label range.
        RuntimeError: If a CAM method exits without producing a heatmap.
    """
    # CAM methods need a module that maps a tensor straight to logits.
    wrapped_model = ConvNeXtGradCAMWrapper(model)
    target_layers = get_target_layers(model)

    # Plain forward pass for the prediction and its confidence.
    model.eval()
    with torch.no_grad():
        logits = model(pixel_values).logits
        predicted_class = logits.argmax(-1).item()
        probabilities = F.softmax(logits, dim=-1)[0]

    # Use predicted class if target not specified
    if target_class is None:
        target_class = predicted_class

    # Fail early with a clear message instead of deep inside the CAM library.
    num_classes = logits.shape[-1]
    if not -num_classes <= target_class < num_classes:
        raise ValueError(
            f"target_class must be in [-{num_classes}, {num_classes}), got {target_class}"
        )

    targets = [ClassifierOutputTarget(target_class)]

    # Grayscale heatmaps, in the order GRAD-CAM, GRAD-CAM++, LayerCAM.
    heatmaps = []
    for cam_class in (GradCAM, GradCAMPlusPlus, LayerCAM):
        heatmap = None
        # Leaving the context manager releases this method's hooks on the layer.
        with cam_class(model=wrapped_model, target_layers=target_layers) as cam:
            heatmap = cam(input_tensor=pixel_values, targets=targets)[0, :]
        # pytorch-grad-cam's __exit__ suppresses IndexError (it only prints it),
        # so confirm a heatmap was actually produced before continuing.
        if heatmap is None:
            raise RuntimeError(f"{cam_class.__name__} did not produce a heatmap")
        heatmaps.append(heatmap)

    # Resize original image to match CAM dimensions.
    # NOTE(review): if the processor crops (depends on its config, e.g. crop_pct),
    # the resized full image will not line up exactly with the heatmap — confirm.
    cam_h, cam_w = heatmaps[0].shape
    rgb_image_for_overlay = cv2.resize(original_image, (cam_w, cam_h)).astype(np.float32) / 255.0

    viz_gradcam, viz_gradcam_pp, viz_layercam = (
        show_cam_on_image(
            rgb_image_for_overlay,
            gray,
            use_rgb=True,
            colormap=cv2.COLORMAP_JET
        )
        for gray in heatmaps
    )

    return viz_gradcam, viz_gradcam_pp, viz_layercam, predicted_class, float(probabilities[predicted_class].item())


# ========== GRADIO INTERFACE FUNCTIONS ==========

def predict_basic(image):
    """
    Basic prediction without explainability.

    Args:
        image: PIL Image or numpy array (None when nothing has been uploaded).

    Returns:
        Dictionary mapping each display name to its softmax probability, for the
        Gradio Label component; None if there is no image or prediction fails.
    """
    if image is None:
        return None

    try:
        # Convert to PIL Image if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        pixel_values, _ = preprocess_image(image)

        model.eval()
        with torch.no_grad():
            logits = model(pixel_values).logits
            probabilities = F.softmax(logits, dim=-1)[0]

        # Format for Gradio Label component
        return {DISPLAY_NAMES[i]: float(probabilities[i]) for i in range(len(DISPLAY_NAMES))}

    except Exception as e:
        # UI boundary: show an empty result instead of crashing the handler,
        # but keep the full traceback in the logs so failures can be diagnosed.
        import traceback
        traceback.print_exc()
        print(f"Error in prediction: {e}")
        return None


def predict_with_explainability(image):
    """
    Prediction with multiple CAM explainability methods.

    The predicted class, its confidence and the probability table all come from
    a single forward pass. That class is passed explicitly to
    ``apply_cam_methods``, so the heatmaps explain exactly the class shown.

    Args:
        image: PIL Image or numpy array (None when nothing has been uploaded).

    Returns:
        Tuple of (probabilities_dict, gradcam_image, gradcam_pp_image, layercam_image, info_text).
        On missing input or failure, the first four items are None and
        info_text carries the message.
    """
    if image is None:
        return None, None, None, None, "Please upload an image."

    try:
        # Convert to PIL Image if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        pixel_values, original_image = preprocess_image(image)

        model.eval()
        with torch.no_grad():
            logits = model(pixel_values).logits
            probabilities = F.softmax(logits, dim=-1)[0]
            predicted_class = logits.argmax(-1).item()

        # Explain the class we display; ignore the helper's own prediction and
        # confidence so every output below derives from the pass above.
        viz_gradcam, viz_gradcam_pp, viz_layercam, _, _ = apply_cam_methods(
            pixel_values, original_image, target_class=predicted_class
        )
        confidence = float(probabilities[predicted_class])

        # Format probabilities for Gradio
        probs_dict = {DISPLAY_NAMES[i]: float(probabilities[i]) for i in range(len(DISPLAY_NAMES))}

        # Create info text
        info_text = f"**Predicted Class:** {DISPLAY_NAMES[predicted_class]}\n\n"
        info_text += f"**Confidence:** {confidence*100:.2f}%\n\n"
        info_text += "The heatmaps show regions the model focused on for classification using different visualization methods."

        return probs_dict, viz_gradcam, viz_gradcam_pp, viz_layercam, info_text

    except Exception as e:
        # UI boundary: report the error in the UI instead of crashing the handler,
        # and keep the full traceback in the logs for debugging.
        import traceback
        traceback.print_exc()
        print(f"Error in prediction with explainability: {e}")
        return None, None, None, None, f"Error: {str(e)}"


# ========== GRADIO INTERFACE ==========

# Custom CSS for better styling.
# Passed to gr.Blocks(css=...). `.gradio-container` sets the app-wide font;
# `.header` centers text with a bottom margin.
# NOTE(review): no component in this file sets elem_classes="header", so the
# `.header` rule appears unused — confirm before relying on it.
custom_css = """

.gradio-container {

    font-family: 'Arial', sans-serif;

}

.header {

    text-align: center;

    margin-bottom: 2rem;

}

"""

# Create Gradio Blocks interface.
# Layout: header description, two tabs (plain classification and classification
# with CAM heatmaps), then an "About the Model" footer. The emoji in the visible
# text are written as proper UTF-8 characters (previously mis-encoded).
with gr.Blocks(css=custom_css, title="Project Phoenix - Cervical Cancer Cell Classification") as demo:

    gr.Markdown("""

    # 🔬 Project Phoenix - Cervical Cancer Cell Classification

    

    ConvNeXt V2 model for automated classification of cervical cancer cells into 5 categories:

    - **Dyskeratotic**: Abnormal keratinization

    - **Koilocytotic**: HPV-infected cells

    - **Metaplastic**: Transitional cells

    - **Parabasal**: Immature cells

    - **Superficial-Intermediate**: Mature cells

    """)

    with gr.Tabs():
        # Tab 1: Basic Prediction (probabilities only)
        with gr.TabItem("🎯 Basic Prediction"):
            gr.Markdown("Upload an image to classify the cervical cell type.")

            with gr.Row():
                with gr.Column():
                    input_image_basic = gr.Image(type="pil", label="Upload Cell Image")
                    predict_btn_basic = gr.Button("Classify", variant="primary", size="lg")

                with gr.Column():
                    output_label_basic = gr.Label(label="Classification Results", num_top_classes=5)

            # queue=False: call the handler directly rather than through Gradio's queue.
            predict_btn_basic.click(
                fn=predict_basic,
                inputs=input_image_basic,
                outputs=output_label_basic,
                api_name="predict_basic",
                queue=False
            )

        # Tab 2: Prediction with Explainability (probabilities + three CAM heatmaps)
        with gr.TabItem("🔍 Prediction + Explainability (CAM Methods)"):
            gr.Markdown("Upload an image to classify and visualize model attention using GRAD-CAM, GRAD-CAM++, and LayerCAM.")

            with gr.Row():
                with gr.Column():
                    input_image_explain = gr.Image(type="pil", label="Upload Cell Image")
                    predict_btn_explain = gr.Button("Classify with Explainability", variant="primary", size="lg")

                with gr.Column():
                    output_label_explain = gr.Label(label="Classification Results", num_top_classes=5)
                    with gr.Row():
                        output_gradcam = gr.Image(label="GRAD-CAM")
                        output_gradcam_pp = gr.Image(label="GRAD-CAM++")
                        output_layercam = gr.Image(label="LayerCAM")
                    output_info = gr.Markdown(label="Analysis")

            # Output order must match the tuple returned by predict_with_explainability.
            predict_btn_explain.click(
                fn=predict_with_explainability,
                inputs=input_image_explain,
                outputs=[output_label_explain, output_gradcam, output_gradcam_pp, output_layercam, output_info],
                api_name="predict_with_explainability",
                queue=False
            )

    # Footer
    gr.Markdown("""

    ---

    ### 📊 About the Model

    

    This model is a fine-tuned **ConvNeXt V2** neural network trained on the SIPaKMeD dataset 

    for cervical cancer cell classification. The model achieves high accuracy in distinguishing 

    between different cell types, which is crucial for early cancer detection and diagnosis.

    

    **GRAD-CAM** (Gradient-weighted Class Activation Mapping) provides visual explanations by 

    highlighting the regions in the image that were most important for the model's decision.

    

    🔗 **Model**: [Meet2304/convnextv2-cervical-cell-classification](https://huggingface.co/Meet2304/convnextv2-cervical-cell-classification)

    """)

# ========== LAUNCH ==========

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )