Spaces:
Sleeping
Sleeping
| """ | |
| Project Phoenix - Cervical Cancer Cell Classification | |
| Gradio application for running inference on ConvNeXt V2 model from Hugging Face | |
| with explainability features (GRAD-CAM). | |
| Deployed on Hugging Face Spaces. | |
| """ | |
| import os | |
| import numpy as np | |
| import cv2 | |
| from typing import Dict, Tuple, Optional | |
| # Deep Learning | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| # Transformers | |
| from transformers import ( | |
| ConvNextV2ForImageClassification, | |
| AutoImageProcessor | |
| ) | |
| # GRAD-CAM variants | |
| from pytorch_grad_cam import GradCAM, GradCAMPlusPlus, LayerCAM | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
| # Gradio | |
| import gradio as gr | |
# ========== CONFIGURATION ==========
# Model directory - model files are in the root directory of the Space
MODEL_DIR = os.path.dirname(__file__)  # Current directory where app.py is located

# Class names as stored in the model's label mapping (raw dataset folder names).
CLASS_NAMES = [
    'im_Dyskeratotic',
    'im_Koilocytotic',
    'im_Metaplastic',
    'im_Parabasal',
    'im_Superficial-Intermediate'
]

# Display names (cleaner for UI) - must stay index-aligned with CLASS_NAMES.
DISPLAY_NAMES = [
    'Dyskeratotic',
    'Koilocytotic',
    'Metaplastic',
    'Parabasal',
    'Superficial-Intermediate'
]

# Device: prefer GPU when available, otherwise fall back to CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# ========== MODEL LOADING ==========
# Runs at import time so the Space has the model ready before serving requests.
print("Loading model from local directory...")
print(f"Model directory: {MODEL_DIR}")
print(f"Device: {DEVICE}")

# Load image processor (resize/normalization config shipped alongside the weights).
processor = AutoImageProcessor.from_pretrained(MODEL_DIR)
print("β Processor loaded")

# Load model weights and move them to the selected device.
model = ConvNextV2ForImageClassification.from_pretrained(MODEL_DIR)
model = model.to(DEVICE)
model.eval()  # inference mode: disables dropout / uses running batch-norm stats
print("β Model loaded and set to evaluation mode")

# Log basic model facts for the Space build/run logs.
print(f"Model configuration:")
print(f" - Number of classes: {model.config.num_labels}")
print(f" - Image size: {model.config.image_size}")
print(f" - Total parameters: {sum(p.numel() for p in model.parameters()):,}")
| # ========== HELPER FUNCTIONS ========== | |
def preprocess_image(image: Image.Image) -> Tuple[torch.Tensor, np.ndarray]:
    """
    Run the model's image processor on *image* and keep an RGB copy.

    Returns:
        Tuple of (preprocessed_tensor, original_image_array)
    """
    # RGB numpy copy retained for CAM overlay rendering later on.
    rgb_copy = np.array(image.convert('RGB'))

    # Let the model's own processor handle resizing and normalization.
    encoded = processor(images=image, return_tensors="pt")
    tensor = encoded['pixel_values'].to(DEVICE)

    return tensor, rgb_copy
class ConvNeXtGradCAMWrapper(nn.Module):
    """Adapter exposing bare logits so pytorch-grad-cam can hook the model.

    ConvNextV2ForImageClassification returns a ModelOutput object, while the
    CAM library expects forward() to return a plain tensor of logits.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        # Forward through the wrapped model and strip the output container.
        return self.model(pixel_values=x).logits
def get_target_layers(model):
    """Return the CAM target layers: the last layer of the last ConvNeXt stage."""
    final_stage = model.convnextv2.encoder.stages[-1]
    return [final_stage.layers[-1]]
def apply_cam_methods(
    pixel_values: torch.Tensor,
    original_image: np.ndarray,
    target_class: Optional[int] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int, float]:
    """
    Visualize model attention with GRAD-CAM, GRAD-CAM++, and LayerCAM.

    Args:
        pixel_values: Preprocessed image tensor
        original_image: Original image as numpy array
        target_class: Target class index (None for predicted class)

    Returns:
        Tuple of (gradcam_viz, gradcam_pp_viz, layercam_viz, predicted_class, confidence)
    """
    wrapped = ConvNeXtGradCAMWrapper(model)
    layers = get_target_layers(model)

    # One instance per CAM variant, all hooked on the same target layer.
    cam_methods = [
        GradCAM(model=wrapped, target_layers=layers),
        GradCAMPlusPlus(model=wrapped, target_layers=layers),
        LayerCAM(model=wrapped, target_layers=layers),
    ]

    # Plain forward pass (no gradients needed) to obtain the prediction.
    model.eval()
    with torch.no_grad():
        logits = model(pixel_values).logits
    predicted_class = logits.argmax(-1).item()
    probabilities = F.softmax(logits, dim=-1)[0]

    # Default to explaining the predicted class when none was requested.
    if target_class is None:
        target_class = predicted_class
    targets = [ClassifierOutputTarget(target_class)]

    # Grayscale attention maps, one per CAM method (same order as cam_methods).
    heatmaps = [cam(input_tensor=pixel_values, targets=targets)[0, :] for cam in cam_methods]

    # Resize the original image to the CAM grid and overlay each heatmap.
    cam_h, cam_w = heatmaps[0].shape
    base = cv2.resize(original_image, (cam_w, cam_h)).astype(np.float32) / 255.0
    overlays = [
        show_cam_on_image(base, hm, use_rgb=True, colormap=cv2.COLORMAP_JET)
        for hm in heatmaps
    ]

    confidence = float(probabilities[predicted_class].item())
    return overlays[0], overlays[1], overlays[2], predicted_class, confidence
| # ========== GRADIO INTERFACE FUNCTIONS ========== | |
def predict_basic(image):
    """
    Classify an image without any explainability output.

    Args:
        image: PIL Image or numpy array

    Returns:
        Dictionary with class probabilities for Gradio Label component
    """
    if image is None:
        return None

    try:
        # Gradio may hand us a numpy array; normalize to PIL.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        pixel_values, _ = preprocess_image(image)

        # Inference without gradient tracking.
        model.eval()
        with torch.no_grad():
            logits = model(pixel_values).logits
        probs = F.softmax(logits, dim=-1)[0]

        # {display name -> probability}, as the Label component expects.
        return {name: float(p) for name, p in zip(DISPLAY_NAMES, probs)}
    except Exception as e:
        # Best-effort handler: log and let the UI show an empty result.
        print(f"Error in prediction: {e}")
        return None
def predict_with_explainability(image):
    """
    Classify an image and render multiple CAM explainability heatmaps.

    Args:
        image: PIL Image or numpy array

    Returns:
        Tuple of (probabilities_dict, gradcam_image, gradcam_pp_image, layercam_image, info_text)
    """
    if image is None:
        return None, None, None, None, "Please upload an image."

    try:
        # Gradio may hand us a numpy array; normalize to PIL.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        pixel_values, original_image = preprocess_image(image)

        # Forward pass for the probability table.
        model.eval()
        with torch.no_grad():
            logits = model(pixel_values).logits
        probabilities = F.softmax(logits, dim=-1)[0]
        predicted_class = logits.argmax(-1).item()

        # Heatmaps for the predicted class (target defaults inside the helper).
        viz_gradcam, viz_gradcam_pp, viz_layercam, _, confidence = apply_cam_methods(
            pixel_values, original_image
        )

        # {display name -> probability}, as the Label component expects.
        probs_dict = {name: float(p) for name, p in zip(DISPLAY_NAMES, probabilities)}

        info_text = (
            f"**Predicted Class:** {DISPLAY_NAMES[predicted_class]}\n\n"
            f"**Confidence:** {confidence*100:.2f}%\n\n"
            "The heatmaps show regions the model focused on for classification using different visualization methods."
        )
        return probs_dict, viz_gradcam, viz_gradcam_pp, viz_layercam, info_text
    except Exception as e:
        # Best-effort handler: log and surface the error text in the UI.
        print(f"Error in prediction with explainability: {e}")
        return None, None, None, None, f"Error: {str(e)}"
# ========== GRADIO INTERFACE ==========
# Custom CSS for better styling
# Injected into the Blocks app below; keep selectors in sync with the layout.
custom_css = """
.gradio-container {
font-family: 'Arial', sans-serif;
}
.header {
text-align: center;
margin-bottom: 2rem;
}
"""
# Create Gradio Blocks interface
with gr.Blocks(css=custom_css, title="Project Phoenix - Cervical Cancer Cell Classification") as demo:
    # App header and class legend.
    gr.Markdown("""
# π¬ Project Phoenix - Cervical Cancer Cell Classification
ConvNeXt V2 model for automated classification of cervical cancer cells into 5 categories:
- **Dyskeratotic**: Abnormal keratinization
- **Koilocytotic**: HPV-infected cells
- **Metaplastic**: Transitional cells
- **Parabasal**: Immature cells
- **Superficial-Intermediate**: Mature cells
""")

    with gr.Tabs():
        # Tab 1: Basic Prediction
        with gr.TabItem("π― Basic Prediction"):
            gr.Markdown("Upload an image to classify the cervical cell type.")
            with gr.Row():
                with gr.Column():
                    input_image_basic = gr.Image(type="pil", label="Upload Cell Image")
                    predict_btn_basic = gr.Button("Classify", variant="primary", size="lg")
                with gr.Column():
                    output_label_basic = gr.Label(label="Classification Results", num_top_classes=5)
            # Wire the button to the plain-classification handler.
            predict_btn_basic.click(
                fn=predict_basic,
                inputs=input_image_basic,
                outputs=output_label_basic,
                api_name="predict_basic",
                queue=False
            )

        # Tab 2: Prediction with Explainability
        with gr.TabItem("π Prediction + Explainability (CAM Methods)"):
            gr.Markdown("Upload an image to classify and visualize model attention using GRAD-CAM, GRAD-CAM++, and LayerCAM.")
            with gr.Row():
                with gr.Column():
                    input_image_explain = gr.Image(type="pil", label="Upload Cell Image")
                    predict_btn_explain = gr.Button("Classify with Explainability", variant="primary", size="lg")
                with gr.Column():
                    output_label_explain = gr.Label(label="Classification Results", num_top_classes=5)
            # Three side-by-side heatmap panels, one per CAM method.
            with gr.Row():
                output_gradcam = gr.Image(label="GRAD-CAM")
                output_gradcam_pp = gr.Image(label="GRAD-CAM++")
                output_layercam = gr.Image(label="LayerCAM")
            output_info = gr.Markdown(label="Analysis")
            # Wire the button to the CAM-enabled handler.
            predict_btn_explain.click(
                fn=predict_with_explainability,
                inputs=input_image_explain,
                outputs=[output_label_explain, output_gradcam, output_gradcam_pp, output_layercam, output_info],
                api_name="predict_with_explainability",
                queue=False
            )

    # Footer
    gr.Markdown("""
---
### π About the Model
This model is a fine-tuned **ConvNeXt V2** neural network trained on the SIPaKMeD dataset
for cervical cancer cell classification. The model achieves high accuracy in distinguishing
between different cell types, which is crucial for early cancer detection and diagnosis.
**GRAD-CAM** (Gradient-weighted Class Activation Mapping) provides visual explanations by
highlighting the regions in the image that were most important for the model's decision.
π **Model**: [Meet2304/convnextv2-cervical-cell-classification](https://huggingface.co/Meet2304/convnextv2-cervical-cell-classification)
""")
# ========== LAUNCH ==========
if __name__ == "__main__":
    # Bind on all interfaces on the standard HF Spaces port (7860).
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )