{ "cells": [ { "cell_type": "markdown", "id": "a3e1c6bc", "metadata": {}, "source": [ "# šŸ„ Medical Image Segmentation Demo\n", "## UW-Madison GI Tract Segmentation using SegFormer\n", "\n", "This notebook demonstrates how to use the pre-trained SegFormer model to segment medical images of the GI tract (stomach, small bowel, large bowel)." ] }, { "cell_type": "code", "execution_count": null, "id": "d82b1011", "metadata": {}, "outputs": [], "source": [ "import os\n", "import numpy as np\n", "from pathlib import Path\n", "from dataclasses import dataclass\n", "\n", "import torch\n", "import torch.nn.functional as F\n", "import torchvision.transforms as TF\n", "from transformers import SegformerForSemanticSegmentation\n", "from PIL import Image\n", "import matplotlib.pyplot as plt\n", "import matplotlib.patches as mpatches\n", "from glob import glob\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "\n", "# Display settings\n", "plt.style.use('seaborn-v0_8-darkgrid')\n", "%matplotlib inline\n", "\n", "# Define configuration\n", "@dataclass\n", "class Configs:\n", " NUM_CLASSES: int = 4 # including background\n", " CLASSES: tuple = (\"Large bowel\", \"Small bowel\", \"Stomach\")\n", " IMAGE_SIZE: tuple = (288, 288) # W, H\n", " MEAN: tuple = (0.485, 0.456, 0.406)\n", " STD: tuple = (0.229, 0.224, 0.225)\n", " MODEL_PATH: str = os.path.join(os.getcwd(), \"segformer_trained_weights\")\n", "\n", "config = Configs()\n", "print(f\"āœ“ Configuration loaded\")\n", "print(f\" Classes: {config.CLASSES}\")\n", "print(f\" Image size: {config.IMAGE_SIZE}\")\n" ] }, { "cell_type": "markdown", "id": "61b319d6", "metadata": {}, "source": [ "## 1ļøāƒ£ Load Pre-trained SegFormer Model" ] }, { "cell_type": "code", "execution_count": null, "id": "f36ff588", "metadata": {}, "outputs": [], "source": [ "# Set device\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "print(f\"šŸ–„ļø Device: {device}\")\n", "\n", "# Load model\n", "model = SegformerForSemanticSegmentation.from_pretrained(\n", " config.MODEL_PATH,\n", " num_labels=config.NUM_CLASSES,\n", " ignore_mismatched_sizes=True\n", ")\n", "model.to(device)\n", "model.eval()\n", "\n", "# Test forward pass\n", "with torch.no_grad():\n", " dummy_input = torch.randn(1, 3, *config.IMAGE_SIZE[::-1], device=device)\n", " _ = model(pixel_values=dummy_input)\n", "\n", "print(f\"āœ“ SegFormer model loaded successfully\")\n", "print(f\" Total parameters: {sum(p.numel() for p in model.parameters())/1e6:.1f}M\")\n" ] }, { "cell_type": "markdown", "id": "55e6ec48", "metadata": {}, "source": [ "## 2ļøāƒ£ Define Image Preprocessing Pipeline" ] }, { "cell_type": "code", "execution_count": null, "id": "5fdb5622", "metadata": {}, "outputs": [], "source": [ "# Create preprocessing pipeline\n", "preprocess = TF.Compose([\n", " TF.Resize(size=config.IMAGE_SIZE[::-1]),\n", " TF.ToTensor(),\n", " TF.Normalize(config.MEAN, config.STD, inplace=True),\n", "])\n", "\n", "print(\"āœ“ Preprocessing pipeline created\")\n", "print(f\" - Resize: {config.IMAGE_SIZE}\")\n", "print(f\" - Normalize with ImageNet statistics\")\n" ] }, { "cell_type": "markdown", "id": "848798e0", "metadata": {}, "source": [ "## 3ļøāƒ£ Implement Prediction Function" ] }, { "cell_type": "code", "execution_count": null, "id": "2862b1ca", "metadata": {}, "outputs": [], "source": [ "@torch.inference_mode()\n", "def predict_segmentation(image_path):\n", " \"\"\"\n", " Predict segmentation for a medical image\n", " \n", " Args:\n", " image_path: Path to input image\n", " \n", " Returns:\n", " Dictionary with predictions and confidence scores\n", " \"\"\"\n", " # Load image\n", " image = Image.open(image_path).convert(\"RGB\")\n", " original_size = image.size[::-1] # (H, W)\n", " \n", " # Preprocess\n", " input_tensor = preprocess(image)\n", " input_tensor = input_tensor.unsqueeze(0).to(device)\n", " \n", " # Model inference\n", " with torch.no_grad():\n", " outputs = model(pixel_values=input_tensor, return_dict=True)\n", " logits = outputs.logits\n", " \n", " # Interpolate to original size\n", " predictions = F.interpolate(\n", " logits,\n", " size=original_size,\n", " mode=\"bilinear\",\n", " align_corners=False\n", " )\n", " \n", " # Get predictions and confidence\n", " probs = torch.softmax(predictions, dim=1)\n", " pred_mask = predictions.argmax(dim=1)[0].cpu().numpy()\n", " confidence_map = probs.max(dim=1)[0][0].cpu().numpy()\n", " \n", " return {\n", " 'image': image,\n", " 'pred_mask': pred_mask,\n", " 'confidence_map': confidence_map,\n", " 'original_size': original_size\n", " }\n", "\n", "print(\"āœ“ Prediction function defined\")\n" ] }, { "cell_type": "markdown", "id": "361ce3b5", "metadata": {}, "source": [ "## 4ļøāƒ£ Load and Display Sample Medical Images" ] }, { "cell_type": "code", "execution_count": null, "id": "cee6daca", "metadata": {}, "outputs": [], "source": [ "# Load sample images\n", "sample_dir = \"./samples\"\n", "sample_images = sorted(glob(os.path.join(sample_dir, \"*.png\")))[:6]\n", "\n", "print(f\"āœ“ Found {len(sample_images)} sample images\")\n", "\n", "# Display sample images\n", "fig, axes = plt.subplots(2, 3, figsize=(15, 10))\n", "fig.suptitle(\"Sample Medical Images\", fontsize=16, fontweight='bold')\n", "\n", "for idx, (ax, img_path) in enumerate(zip(axes.flat, sample_images)):\n", " image = Image.open(img_path).convert(\"RGB\")\n", " ax.imshow(image)\n", " ax.set_title(Path(img_path).stem)\n", " ax.axis('off')\n", "\n", "plt.tight_layout()\n", "plt.show()\n", "\n", "print(f\"Loaded {len(sample_images)} sample images for testing\")\n" ] }, { "cell_type": "markdown", "id": "aeb2d9b7", "metadata": {}, "source": [ "## 5ļøāƒ£ Perform Segmentation and Visualize Results" ] }, { "cell_type": "code", "execution_count": null, "id": "21aed5f3", "metadata": {}, "outputs": [], "source": [ "# Color mapping for organs\n", "class2hexcolor = {\n", " \"Large bowel\": \"#FF0000\", # Red\n", " \"Small bowel\": \"#009A17\", # Green\n", " \"Stomach\": \"#007fff\" # Blue\n", "}\n", "\n", "# Predict on first 3 samples\n", "sample_predictions = [predict_segmentation(img_path) for img_path in sample_images[:3]]\n", "\n", "# Visualize predictions\n", "fig, axes = plt.subplots(3, 3, figsize=(18, 12))\n", "fig.suptitle(\"Medical Image Segmentation Results\", fontsize=16, fontweight='bold')\n", "\n", "for row, pred_result in enumerate(sample_predictions):\n", " image = pred_result['image']\n", " pred_mask = pred_result['pred_mask']\n", " confidence = pred_result['confidence_map']\n", " \n", " # Original image\n", " axes[row, 0].imshow(image)\n", " axes[row, 0].set_title(f\"Original Image\", fontweight='bold')\n", " axes[row, 0].axis('off')\n", " \n", " # Prediction mask (colored)\n", " pred_colored = np.zeros((*pred_mask.shape, 3))\n", " colors = [(1, 0, 0), (0, 0.6, 0.1), (0, 0.5, 1)] # RGB for each class\n", " for class_id, color in enumerate(colors, 1):\n", " mask = (pred_mask == class_id)\n", " pred_colored[mask] = color\n", " \n", " axes[row, 1].imshow(pred_colored)\n", " axes[row, 1].set_title(f\"Segmentation Mask\", fontweight='bold')\n", " axes[row, 1].axis('off')\n", " \n", " # Confidence map\n", " im = axes[row, 2].imshow(confidence, cmap='hot')\n", " axes[row, 2].set_title(f\"Confidence Map\", fontweight='bold')\n", " axes[row, 2].axis('off')\n", " plt.colorbar(im, ax=axes[row, 2], label='Confidence')\n", "\n", "plt.tight_layout()\n", "plt.show()\n", "\n", "print(\"āœ“ Segmentation predictions generated successfully\")\n" ] }, { "cell_type": "markdown", "id": "3f923634", "metadata": {}, "source": [ "## 6ļøāƒ£ Create Color-Coded Segmentation Overlays" ] }, { "cell_type": "code", "execution_count": null, "id": "414a8549", "metadata": {}, "outputs": [], "source": [ "def create_overlay(image, mask, alpha=0.5):\n", " \"\"\"\n", " Create color-coded segmentation overlay\n", " \"\"\"\n", " image_array = np.array(image).astype(float) / 255.0\n", " \n", " # Create colored mask\n", " overlay = image_array.copy()\n", " \n", " # Colors for each class (RGB)\n", " colors = {\n", " 1: np.array([1.0, 0.0, 0.0]), # Large bowel - Red\n", " 2: np.array([0.0, 0.6, 0.1]), # Small bowel - Green\n", " 3: np.array([0.0, 0.5, 1.0]) # Stomach - Blue\n", " }\n", " \n", " for class_id, color in colors.items():\n", " mask_region = (mask == class_id)\n", " overlay[mask_region] = (\n", " image_array[mask_region] * (1 - alpha) + \n", " np.array(color) * alpha\n", " )\n", " \n", " return (overlay * 255).astype(np.uint8)\n", "\n", "# Create overlays for all predictions\n", "fig, axes = plt.subplots(3, 2, figsize=(15, 15))\n", "fig.suptitle(\"Original vs Segmentation Overlay\", fontsize=16, fontweight='bold')\n", "\n", "class_names = {1: \"Large bowel\", 2: \"Small bowel\", 3: \"Stomach\"}\n", "\n", "for row, pred_result in enumerate(sample_predictions):\n", " image = pred_result['image']\n", " pred_mask = pred_result['pred_mask']\n", " \n", " # Original image\n", " axes[row, 0].imshow(image)\n", " axes[row, 0].set_title(\"Original Image\", fontweight='bold')\n", " axes[row, 0].axis('off')\n", " \n", " # Overlay\n", " overlay = create_overlay(image, pred_mask, alpha=0.4)\n", " axes[row, 1].imshow(overlay)\n", " \n", " # Add detected classes info\n", " detected_classes = [class_names[i] for i in np.unique(pred_mask) if i > 0]\n", " title = f\"Overlay - Detected: {', '.join(detected_classes) if detected_classes else 'None'}\"\n", " axes[row, 1].set_title(title, fontweight='bold')\n", " axes[row, 1].axis('off')\n", "\n", "plt.tight_layout()\n", "plt.show()\n", "\n", "print(\"āœ“ Segmentation overlays created\")\n" ] }, { "cell_type": "markdown", "id": "e8da1936", "metadata": {}, "source": [ "## 7ļøāƒ£ Evaluate Model Predictions on Batch Images" ] }, { "cell_type": "code", "execution_count": null, "id": "35db0fbf", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "from tqdm import tqdm\n", "\n", "# Process all sample images\n", "print(\"Processing all sample images...\")\n", "batch_results = []\n", "\n", "for img_path in tqdm(sample_images):\n", " try:\n", " pred_result = predict_segmentation(img_path)\n", " mask = pred_result['pred_mask']\n", " confidence = pred_result['confidence_map']\n", " \n", " # Get detected organs\n", " detected_organs = []\n", " organ_confidences = []\n", " \n", " for class_id, organ_name in [(1, \"Large bowel\"), (2, \"Small bowel\"), (3, \"Stomach\")]:\n", " if (mask == class_id).any():\n", " organ_mask = (mask == class_id)\n", " organ_conf = confidence[organ_mask].mean()\n", " detected_organs.append(organ_name)\n", " organ_confidences.append(f\"{organ_conf:.1%}\")\n", " \n", " batch_results.append({\n", " 'Image': Path(img_path).stem,\n", " 'Detected Organs': ', '.join(detected_organs) if detected_organs else 'None',\n", " 'Avg Confidence': f\"{confidence.mean():.1%}\",\n", " 'Max Confidence': f\"{confidence.max():.1%}\",\n", " 'Min Confidence': f\"{confidence.min():.1%}\"\n", " })\n", " except Exception as e:\n", " print(f\" Error processing {img_path}: {e}\")\n", "\n", "# Create results table\n", "results_df = pd.DataFrame(batch_results)\n", "\n", "print(\"\\n\" + \"=\"*80)\n", "print(\"šŸ“Š Batch Prediction Results\")\n", "print(\"=\"*80)\n", "display(results_df)\n", "\n", "# Summary statistics\n", "print(\"\\nšŸ“ˆ Summary Statistics:\")\n", "print(f\" Total images processed: {len(results_df)}\")\n", "print(f\" Average confidence: {results_df['Avg Confidence'].apply(lambda x: float(x.strip('%'))/100).mean():.1%}\")\n", "\n", "# Create legend\n", "print(\"\\nšŸŽØ Color Legend:\")\n", "print(\" šŸ”“ Red (#FF0000) : Large bowel\")\n", "print(\" 🟢 Green (#009A17) : Small bowel\")\n", "print(\" šŸ”µ Blue (#007fff) : Stomach\")\n" ] } ], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 5 }