""" Project Phoenix - Cervical Cancer Cell Classification Gradio application for running inference on ConvNeXt V2 model from Hugging Face with explainability features (GRAD-CAM). Deployed on Hugging Face Spaces. """ import os import numpy as np import cv2 from typing import Dict, Tuple, Optional # Deep Learning import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image # Transformers from transformers import ( ConvNextV2ForImageClassification, AutoImageProcessor ) # GRAD-CAM variants from pytorch_grad_cam import GradCAM, GradCAMPlusPlus, LayerCAM from pytorch_grad_cam.utils.image import show_cam_on_image from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget # Gradio import gradio as gr # ========== CONFIGURATION ========== # Model directory - model files are in the root directory of the Space MODEL_DIR = os.path.dirname(__file__) # Current directory where app.py is located # Class names CLASS_NAMES = [ 'im_Dyskeratotic', 'im_Koilocytotic', 'im_Metaplastic', 'im_Parabasal', 'im_Superficial-Intermediate' ] # Display names (cleaner for UI) DISPLAY_NAMES = [ 'Dyskeratotic', 'Koilocytotic', 'Metaplastic', 'Parabasal', 'Superficial-Intermediate' ] # Device DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # ========== MODEL LOADING ========== print("Loading model from local directory...") print(f"Model directory: {MODEL_DIR}") print(f"Device: {DEVICE}") # Load image processor processor = AutoImageProcessor.from_pretrained(MODEL_DIR) print("✓ Processor loaded") # Load model model = ConvNextV2ForImageClassification.from_pretrained(MODEL_DIR) model = model.to(DEVICE) model.eval() print("✓ Model loaded and set to evaluation mode") print(f"Model configuration:") print(f" - Number of classes: {model.config.num_labels}") print(f" - Image size: {model.config.image_size}") print(f" - Total parameters: {sum(p.numel() for p in model.parameters()):,}") # ========== HELPER FUNCTIONS ========== def 
preprocess_image(image: Image.Image) -> Tuple[torch.Tensor, np.ndarray]: """ Preprocess image for model input. Returns: Tuple of (preprocessed_tensor, original_image_array) """ # Store original for visualization original_image = np.array(image.convert('RGB')) # Preprocess using the model's processor inputs = processor(images=image, return_tensors="pt") pixel_values = inputs['pixel_values'].to(DEVICE) return pixel_values, original_image class ConvNeXtGradCAMWrapper(nn.Module): """Wrapper for ConvNeXtV2ForImageClassification to make it compatible with GRAD-CAM.""" def __init__(self, model): super().__init__() self.model = model def forward(self, x): outputs = self.model(pixel_values=x) return outputs.logits def get_target_layers(model): """Get the target layers for GRAD-CAM from ConvNeXt model.""" return [model.convnextv2.encoder.stages[-1].layers[-1]] def apply_cam_methods( pixel_values: torch.Tensor, original_image: np.ndarray, target_class: Optional[int] = None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int, float]: """ Apply GRAD-CAM, GRAD-CAM++, and LayerCAM to visualize model attention. 
Args: pixel_values: Preprocessed image tensor original_image: Original image as numpy array target_class: Target class index (None for predicted class) Returns: Tuple of (gradcam_viz, gradcam_pp_viz, layercam_viz, predicted_class, confidence) """ # Wrap the model wrapped_model = ConvNeXtGradCAMWrapper(model) # Get target layers target_layers = get_target_layers(model) # Initialize all CAM methods gradcam = GradCAM(model=wrapped_model, target_layers=target_layers) gradcam_pp = GradCAMPlusPlus(model=wrapped_model, target_layers=target_layers) layercam = LayerCAM(model=wrapped_model, target_layers=target_layers) # Get prediction model.eval() with torch.no_grad(): outputs = model(pixel_values) logits = outputs.logits predicted_class = logits.argmax(-1).item() probabilities = F.softmax(logits, dim=-1)[0] # Use predicted class if target not specified if target_class is None: target_class = predicted_class # Create target for CAM methods targets = [ClassifierOutputTarget(target_class)] # Generate all CAM visualizations grayscale_gradcam = gradcam(input_tensor=pixel_values, targets=targets)[0, :] grayscale_gradcam_pp = gradcam_pp(input_tensor=pixel_values, targets=targets)[0, :] grayscale_layercam = layercam(input_tensor=pixel_values, targets=targets)[0, :] # Resize original image to match CAM dimensions cam_h, cam_w = grayscale_gradcam.shape rgb_image_for_overlay = cv2.resize(original_image, (cam_w, cam_h)).astype(np.float32) / 255.0 # Create visualizations for all methods viz_gradcam = show_cam_on_image( rgb_image_for_overlay, grayscale_gradcam, use_rgb=True, colormap=cv2.COLORMAP_JET ) viz_gradcam_pp = show_cam_on_image( rgb_image_for_overlay, grayscale_gradcam_pp, use_rgb=True, colormap=cv2.COLORMAP_JET ) viz_layercam = show_cam_on_image( rgb_image_for_overlay, grayscale_layercam, use_rgb=True, colormap=cv2.COLORMAP_JET ) return viz_gradcam, viz_gradcam_pp, viz_layercam, predicted_class, float(probabilities[predicted_class].item()) # ========== GRADIO INTERFACE 
# ========== GRADIO INTERFACE FUNCTIONS ==========

def _to_pil(image):
    """Coerce a Gradio input (numpy array or PIL Image) to a PIL Image."""
    if isinstance(image, np.ndarray):
        return Image.fromarray(image)
    return image


def _format_probabilities(probabilities) -> Dict[str, float]:
    """Map per-class probabilities to display names for a Gradio Label component."""
    return {DISPLAY_NAMES[i]: float(probabilities[i]) for i in range(len(DISPLAY_NAMES))}


def predict_basic(image):
    """
    Basic prediction without explainability.

    Args:
        image: PIL Image or numpy array.

    Returns:
        Dictionary with class probabilities for the Gradio Label component,
        or None on missing input / error.
    """
    if image is None:
        return None

    try:
        # Preprocess
        pixel_values, _ = preprocess_image(_to_pil(image))

        # Predict
        model.eval()
        with torch.no_grad():
            logits = model(pixel_values).logits
            probabilities = F.softmax(logits, dim=-1)[0]

        # Format for Gradio Label component
        return _format_probabilities(probabilities)

    except Exception as e:
        print(f"Error in prediction: {e}")
        return None


def predict_with_explainability(image):
    """
    Prediction with multiple CAM explainability methods.

    Args:
        image: PIL Image or numpy array.

    Returns:
        Tuple of (probabilities_dict, gradcam_image, gradcam_pp_image,
        layercam_image, info_text).
    """
    if image is None:
        return None, None, None, None, "Please upload an image."

    try:
        # Preprocess
        pixel_values, original_image = preprocess_image(_to_pil(image))

        # Full class distribution for the Label component.
        model.eval()
        with torch.no_grad():
            logits = model(pixel_values).logits
            probabilities = F.softmax(logits, dim=-1)[0]

        # Apply all CAM methods; reuse the predicted class and confidence it
        # reports instead of recomputing them here.
        viz_gradcam, viz_gradcam_pp, viz_layercam, predicted_class, confidence = (
            apply_cam_methods(pixel_values, original_image)
        )

        probs_dict = _format_probabilities(probabilities)

        # Create info text
        info_text = (
            f"**Predicted Class:** {DISPLAY_NAMES[predicted_class]}\n\n"
            f"**Confidence:** {confidence*100:.2f}%\n\n"
            "The heatmaps show regions the model focused on for classification "
            "using different visualization methods."
        )

        return probs_dict, viz_gradcam, viz_gradcam_pp, viz_layercam, info_text

    except Exception as e:
        print(f"Error in prediction with explainability: {e}")
        return None, None, None, None, f"Error: {str(e)}"


# ========== GRADIO INTERFACE ==========

# Custom CSS for better styling
custom_css = """
.gradio-container {
    font-family: 'Arial', sans-serif;
}
.header {
    text-align: center;
    margin-bottom: 2rem;
}
"""

# Create Gradio Blocks interface
with gr.Blocks(css=custom_css, title="Project Phoenix - Cervical Cancer Cell Classification") as demo:
    gr.Markdown("""
    # 🔬 Project Phoenix - Cervical Cancer Cell Classification

    ConvNeXt V2 model for automated classification of cervical cancer cells into 5 categories:
    - **Dyskeratotic**: Abnormal keratinization
    - **Koilocytotic**: HPV-infected cells
    - **Metaplastic**: Transitional cells
    - **Parabasal**: Immature cells
    - **Superficial-Intermediate**: Mature cells
    """)

    with gr.Tabs():
        # Tab 1: Basic Prediction
        with gr.TabItem("🎯 Basic Prediction"):
            gr.Markdown("Upload an image to classify the cervical cell type.")
            with gr.Row():
                with gr.Column():
                    input_image_basic = gr.Image(type="pil", label="Upload Cell Image")
                    predict_btn_basic = gr.Button("Classify", variant="primary", size="lg")
                with gr.Column():
                    output_label_basic = gr.Label(label="Classification Results", num_top_classes=5)

            predict_btn_basic.click(
                fn=predict_basic,
                inputs=input_image_basic,
                outputs=output_label_basic,
                api_name="predict_basic",
                queue=False
            )

        # Tab 2: Prediction with Explainability
        with gr.TabItem("🔍 Prediction + Explainability (CAM Methods)"):
            gr.Markdown("Upload an image to classify and visualize model attention using GRAD-CAM, GRAD-CAM++, and LayerCAM.")
            with gr.Row():
                with gr.Column():
                    input_image_explain = gr.Image(type="pil", label="Upload Cell Image")
                    predict_btn_explain = gr.Button("Classify with Explainability", variant="primary", size="lg")
                with gr.Column():
                    output_label_explain = gr.Label(label="Classification Results", num_top_classes=5)

            with gr.Row():
                output_gradcam = gr.Image(label="GRAD-CAM")
                output_gradcam_pp = gr.Image(label="GRAD-CAM++")
                output_layercam = gr.Image(label="LayerCAM")

            output_info = gr.Markdown(label="Analysis")

            predict_btn_explain.click(
                fn=predict_with_explainability,
                inputs=input_image_explain,
                outputs=[output_label_explain, output_gradcam, output_gradcam_pp, output_layercam, output_info],
                api_name="predict_with_explainability",
                queue=False
            )

    # Footer
    gr.Markdown("""
    ---
    ### 📊 About the Model

    This model is a fine-tuned **ConvNeXt V2** neural network trained on the SIPaKMeD dataset
    for cervical cancer cell classification. The model achieves high accuracy in distinguishing
    between different cell types, which is crucial for early cancer detection and diagnosis.

    **GRAD-CAM** (Gradient-weighted Class Activation Mapping) provides visual explanations by
    highlighting the regions in the image that were most important for the model's decision.

    🔗 **Model**: [Meet2304/convnextv2-cervical-cell-classification](https://huggingface.co/Meet2304/convnextv2-cervical-cell-classification)
    """)

# ========== LAUNCH ==========

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )