# NOTE: removed "Spaces: / Sleeping / Sleeping" lines — Hugging Face Spaces
# page-scrape artifacts, not part of the source code.
| """ | |
| Skin Cancer Lesion Segmentation - FastAPI Backend | |
| ================================================ | |
| This API provides endpoints for skin lesion segmentation using an | |
| Attention U-Net model with MobileNetV2 encoder. | |
| Features: | |
| - Real-time segmentation of dermoscopic images | |
| - Confidence score calculation | |
| - Lesion area analysis | |
| - Support for multiple image formats (JPEG, PNG, WebP) | |
| """ | |
import base64
import io
import os
from typing import Optional

import albumentations as A
import cv2
import numpy as np
import torch
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
from pydantic import BaseModel
from torchvision.models import mobilenet_v2
| # ============================================= | |
| # CONFIGURATION | |
| # ============================================= | |
| MODEL_PATH = os.path.join(os.path.dirname(__file__), "..", "models", "model.pth") | |
| IMG_SIZE = 256 | |
| DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
# =============================================
# MODEL ARCHITECTURE (must match training)
# =============================================
class AttentionGate(nn.Module):
    """Additive attention gate applied to U-Net skip connections.

    Given a gating signal ``g`` (decoder features) and a skip feature map
    ``x`` (encoder features), returns ``x`` re-weighted by a learned
    single-channel spatial attention map in [0, 1].
    """

    def __init__(self, F_g, F_l, F_int):
        super(AttentionGate, self).__init__()
        # Attribute names are checkpoint keys — they must stay as-is so
        # trained weights keep loading.
        # 1x1 projections of gate and skip into a shared F_int-dim space.
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        # Collapse the fused features to one attention channel in [0, 1].
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        gate_proj = self.W_g(g)
        skip_proj = self.W_x(x)
        # Transposed convs can leave the decoder a pixel off the encoder
        # grid; resample the gate projection onto the skip's spatial size.
        if gate_proj.shape[2:] != skip_proj.shape[2:]:
            gate_proj = nn.functional.interpolate(
                gate_proj,
                size=skip_proj.shape[2:],
                mode='bilinear',
                align_corners=True,
            )
        attention = self.psi(self.relu(gate_proj + skip_proj))
        return x * attention
class BoundaryRefinementHead(nn.Module):
    """Small conv head producing a boundary-refined single-channel mask."""

    def __init__(self, in_channels, filters=64):
        super(BoundaryRefinementHead, self).__init__()
        # Attribute names are checkpoint keys — keep them unchanged.
        self.conv1 = nn.Conv2d(in_channels, filters, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(filters, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        features = self.relu1(self.conv1(x))
        features = self.relu2(self.conv2(features))
        # One probability channel in [0, 1].
        return self.sigmoid(self.conv3(features))
class AttentionUNetMobileNet(nn.Module):
    """
    Attention U-Net with MobileNetV2 encoder for skin lesion segmentation.

    Architecture:
      * Encoder: torchvision MobileNetV2 ``features`` (ImageNet-pretrained;
        NOTE: constructing this module downloads the weights on first use
        via ``weights='DEFAULT'``).
      * Decoder: five transposed-conv upsampling stages; the first four
        concatenate an attention-gated encoder skip connection.
      * Output: the mean of a 1x1-conv segmentation head and a boundary
        refinement head — a single-channel probability map in [0, 1].

    Attribute names are checkpoint keys and must not be renamed.
    """
    def __init__(self, deep_supervision: bool = False):
        super(AttentionUNetMobileNet, self).__init__()
        # When True, auxiliary heads ds_out1..ds_out4 are created for the
        # intermediate decoder stages. forward() never returns them, so they
        # only matter if a training loop reads them directly.
        self.deep_supervision = deep_supervision
        # Encoder (MobileNetV2 - pretrained)
        mobilenet = mobilenet_v2(weights='DEFAULT')
        self.encoder_features = mobilenet.features
        # Attention Gates: F_g = decoder channels entering the stage,
        # F_l = channels of the matching encoder skip map.
        self.att1 = AttentionGate(F_g=256, F_l=96, F_int=64)
        self.att2 = AttentionGate(F_g=128, F_l=32, F_int=32)
        self.att3 = AttentionGate(F_g=64, F_l=24, F_int=16)
        self.att4 = AttentionGate(F_g=32, F_l=16, F_int=8)
        # Decoder stage 1: 1280-ch bottleneck -> 256 ch, fused with 96-ch skip.
        self.up1 = nn.ConvTranspose2d(1280, 256, kernel_size=2, stride=2)
        self.dec1_conv1 = nn.Conv2d(256 + 96, 256, kernel_size=3, padding=1)
        self.dec1_bn1 = nn.BatchNorm2d(256)
        self.dec1_relu1 = nn.ReLU(inplace=True)
        self.dec1_conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.dec1_bn2 = nn.BatchNorm2d(256)
        self.dec1_relu2 = nn.ReLU(inplace=True)
        # Decoder stage 2: 256 -> 128 ch, fused with 32-ch skip.
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2_conv1 = nn.Conv2d(128 + 32, 128, kernel_size=3, padding=1)
        self.dec2_bn1 = nn.BatchNorm2d(128)
        self.dec2_relu1 = nn.ReLU(inplace=True)
        self.dec2_conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.dec2_bn2 = nn.BatchNorm2d(128)
        self.dec2_relu2 = nn.ReLU(inplace=True)
        # Decoder stage 3: 128 -> 64 ch, fused with 24-ch skip.
        self.up3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec3_conv1 = nn.Conv2d(64 + 24, 64, kernel_size=3, padding=1)
        self.dec3_bn1 = nn.BatchNorm2d(64)
        self.dec3_relu1 = nn.ReLU(inplace=True)
        self.dec3_conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.dec3_bn2 = nn.BatchNorm2d(64)
        self.dec3_relu2 = nn.ReLU(inplace=True)
        # Decoder stage 4: 64 -> 32 ch, fused with 16-ch skip.
        self.up4 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec4_conv1 = nn.Conv2d(32 + 16, 32, kernel_size=3, padding=1)
        self.dec4_bn1 = nn.BatchNorm2d(32)
        self.dec4_relu1 = nn.ReLU(inplace=True)
        self.dec4_conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.dec4_bn2 = nn.BatchNorm2d(32)
        self.dec4_relu2 = nn.ReLU(inplace=True)
        # Decoder stage 5: 32 -> 16 ch; no skip — plain upsample to input size.
        self.up5 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
        self.dec5_conv1 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
        self.dec5_bn1 = nn.BatchNorm2d(16)
        self.dec5_relu1 = nn.ReLU(inplace=True)
        self.dec5_conv2 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
        self.dec5_bn2 = nn.BatchNorm2d(16)
        self.dec5_relu2 = nn.ReLU(inplace=True)
        # Main output
        self.seg_output = nn.Sequential(
            nn.Conv2d(16, 1, kernel_size=1),
            nn.Sigmoid()
        )
        # Boundary head
        self.boundary_head = BoundaryRefinementHead(16, filters=64)
        # Deep supervision outputs (only used during training)
        if self.deep_supervision:
            self.ds_out1 = nn.Sequential(nn.Conv2d(256, 1, kernel_size=1), nn.Sigmoid())
            self.ds_out2 = nn.Sequential(nn.Conv2d(128, 1, kernel_size=1), nn.Sigmoid())
            self.ds_out3 = nn.Sequential(nn.Conv2d(64, 1, kernel_size=1), nn.Sigmoid())
            self.ds_out4 = nn.Sequential(nn.Conv2d(32, 1, kernel_size=1), nn.Sigmoid())
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Encoder: run MobileNetV2 stages, stashing the maps used as skips.
        # Indices 1/3/6/13 yield 16/24/32/96 channels (matching the F_l
        # values above); index 18 is the 1280-ch bottleneck, which is also
        # the loop's final x, so skips[4] is never indexed again.
        skips = []
        for idx, layer in enumerate(self.encoder_features):
            x = layer(x)
            if idx in [1, 3, 6, 13, 18]:
                skips.append(x)
        # Decoder Stage 1
        d1 = self.up1(x)
        skip1_att = self.att1(d1, skips[3])
        d1 = torch.cat([d1, skip1_att], dim=1)
        d1 = self.dec1_relu1(self.dec1_bn1(self.dec1_conv1(d1)))
        d1 = self.dec1_relu2(self.dec1_bn2(self.dec1_conv2(d1)))
        # Decoder Stage 2
        d2 = self.up2(d1)
        skip2_att = self.att2(d2, skips[2])
        d2 = torch.cat([d2, skip2_att], dim=1)
        d2 = self.dec2_relu1(self.dec2_bn1(self.dec2_conv1(d2)))
        d2 = self.dec2_relu2(self.dec2_bn2(self.dec2_conv2(d2)))
        # Decoder Stage 3
        d3 = self.up3(d2)
        skip3_att = self.att3(d3, skips[1])
        d3 = torch.cat([d3, skip3_att], dim=1)
        d3 = self.dec3_relu1(self.dec3_bn1(self.dec3_conv1(d3)))
        d3 = self.dec3_relu2(self.dec3_bn2(self.dec3_conv2(d3)))
        # Decoder Stage 4
        d4 = self.up4(d3)
        skip4_att = self.att4(d4, skips[0])
        d4 = torch.cat([d4, skip4_att], dim=1)
        d4 = self.dec4_relu1(self.dec4_bn1(self.dec4_conv1(d4)))
        d4 = self.dec4_relu2(self.dec4_bn2(self.dec4_conv2(d4)))
        # Final upsampling
        d5 = self.up5(d4)
        d5 = self.dec5_relu1(self.dec5_bn1(self.dec5_conv1(d5)))
        d5 = self.dec5_relu2(self.dec5_bn2(self.dec5_conv2(d5)))
        # Main output: average the segmentation and boundary probabilities.
        seg_output = self.seg_output(d5)
        boundary_output = self.boundary_head(d5)
        final_output = (seg_output + boundary_output) / 2
        return final_output
# =============================================
# PREPROCESSING
# =============================================
def get_inference_transforms(img_size=IMG_SIZE):
    """Build the deterministic albumentations pipeline used at inference."""
    steps = [
        A.Resize(img_size, img_size),
        # ImageNet statistics, matching the pretrained MobileNetV2 encoder.
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
    return A.Compose(steps)
def preprocess_image(image: np.ndarray) -> torch.Tensor:
    """Preprocess an RGB image for model inference.

    Args:
        image: HxW (grayscale), HxWx3 (assumed RGB, as produced by the
            /segment endpoint) or HxWx4 (RGBA) uint8 array.

    Returns:
        Float tensor of shape (1, 3, IMG_SIZE, IMG_SIZE), resized and
        ImageNet-normalized.
    """
    # Promote grayscale / drop alpha so the array is 3-channel RGB.
    if len(image.shape) == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    # BUG FIX: the original also applied COLOR_BGR2RGB to 3-channel input,
    # but the caller (segment_image) had already converted BGR->RGB, so the
    # channels were silently swapped back to BGR before normalization.
    # Resize + normalize + HWC->CHW, then add the batch dimension.
    transform = get_inference_transforms()
    transformed = transform(image=image)
    img_tensor = transformed['image'].unsqueeze(0)
    return img_tensor
def postprocess_mask(mask: np.ndarray, original_size: tuple) -> np.ndarray:
    """Resize a model-resolution mask back to the original (height, width)."""
    height, width = original_size
    # cv2.resize takes its size argument as (width, height).
    return cv2.resize(mask, (width, height), interpolation=cv2.INTER_LINEAR)
def mask_to_base64(mask: np.ndarray) -> str:
    """Encode a [0, 1] float mask as a base64 PNG string."""
    # Scale to the 0-255 byte range expected by the PNG encoder.
    mask_bytes = (mask * 255).astype(np.uint8)
    _, png_buffer = cv2.imencode('.png', mask_bytes)
    return base64.b64encode(png_buffer).decode('utf-8')
def create_overlay(image: np.ndarray, mask: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """Blend a magenta highlight over the masked pixels of an RGB image."""
    # Constant magenta tint plane in (R, G, B) order.
    tint = np.zeros_like(image)
    tint[..., 0] = int(255 * 0.94)  # R
    tint[..., 1] = int(255 * 0.38)  # G
    tint[..., 2] = int(255 * 0.57)  # B
    # Broadcast the single-channel mask across all three color channels and
    # pre-multiply by the blend strength.
    weights = np.stack((mask, mask, mask), axis=-1) * alpha
    blended = image * (1 - weights) + tint * weights
    return blended.astype(np.uint8)
def image_to_base64(image: np.ndarray) -> str:
    """Encode an RGB image as a base64 PNG string."""
    # OpenCV's PNG encoder expects BGR channel order.
    bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    _, png_buffer = cv2.imencode('.png', bgr)
    return base64.b64encode(png_buffer).decode('utf-8')
# =============================================
# FASTAPI APP
# =============================================
app = FastAPI(
    title="Skin Lesion Segmentation API",
    description="AI-powered skin cancer lesion segmentation using Attention U-Net",
    version="1.0.0"
)
# CORS middleware
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# overly permissive (and browsers reject wildcard origins for credentialed
# requests) — pin concrete origins if cookies/auth are ever used.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global model variable — None until load_model() populates it at startup.
model = None
@app.on_event("startup")
async def load_model():
    """Load the segmentation model when the application starts.

    FIX: the function was defined but never registered as a startup hook,
    so `model` stayed None and every /segment request returned 503.
    On any failure it falls back to randomly initialized weights (with a
    console warning) so the API still starts.
    """
    global model
    print(f"Loading model from {MODEL_PATH}...")
    print(f"Using device: {DEVICE}")
    try:
        model = AttentionUNetMobileNet(deep_supervision=False).to(DEVICE)
        if os.path.exists(MODEL_PATH):
            # weights_only=True refuses arbitrary pickled objects; a plain
            # state_dict loads fine and this avoids torch.load's pickle
            # code-execution risk. (Requires torch >= 1.13 — confirm.)
            state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
            # strict=False tolerates renamed/missing keys; report mismatches
            # so a silently incompatible checkpoint is visible in the logs.
            missing, unexpected = model.load_state_dict(state_dict, strict=False)
            if missing or unexpected:
                print(f"⚠ Partial load: {len(missing)} missing / {len(unexpected)} unexpected keys")
            print("✓ Model loaded successfully!")
        else:
            print(f"⚠ Model file not found at {MODEL_PATH}")
            print("  The API will work but predictions will use untrained weights.")
        model.eval()
    except Exception as e:
        # Best-effort fallback: keep the service up with untrained weights.
        print(f"✗ Error loading model: {e}")
        model = AttentionUNetMobileNet(deep_supervision=False).to(DEVICE)
        model.eval()
class SegmentationResponse(BaseModel):
    """Response model for segmentation endpoint.

    On failure, ``success`` is False and only ``message`` is populated; on
    success the base64 fields hold PNG-encoded images.
    """
    success: bool
    # Binary lesion mask as a base64-encoded PNG.
    mask_base64: Optional[str] = None
    # Original image blended with a magenta lesion overlay, base64 PNG.
    overlay_base64: Optional[str] = None
    # Mean model probability over predicted-lesion pixels, as a percentage.
    confidence: Optional[float] = None
    # Percentage of image pixels predicted as lesion.
    lesion_area_percent: Optional[float] = None
    # Human-readable error/status text (set on failure).
    message: Optional[str] = None
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve a minimal page that redirects the browser to the frontend.

    FIX: the handler was defined but never registered, so "/" did not
    exist. response_class=HTMLResponse (imported at the top of the file
    for this purpose) returns the markup as text/html rather than a
    JSON-encoded string.
    """
    return """
    <html>
        <head>
            <meta http-equiv="refresh" content="0; url=/static/index.html" />
        </head>
        <body>
            <p>Redirecting to the application...</p>
        </body>
    </html>
    """
@app.get("/health")
async def health_check():
    """Health check endpoint: reports model-load state and compute device.

    FIX: the handler was defined but never registered with the app, so the
    endpoint was unreachable.
    """
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "device": str(DEVICE)
    }
@app.post("/segment", response_model=SegmentationResponse)
async def segment_image(file: UploadFile = File(...)):
    """
    Segment a skin lesion from an uploaded dermoscopic image.

    FIX: the route decorator was missing, so the endpoint was unreachable.

    Args:
        file: Image file (JPEG, PNG, WebP)

    Returns:
        SegmentationResponse with mask and overlay in base64, the mean
        confidence over predicted-lesion pixels, and lesion area percent.

    Raises:
        HTTPException: 503 when the model has not been loaded.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # Validate the declared MIME type ("image/jpg" is non-standard but some
    # clients send it).
    allowed_types = ["image/jpeg", "image/png", "image/webp", "image/jpg"]
    if file.content_type not in allowed_types:
        return SegmentationResponse(
            success=False,
            message=f"Invalid file type. Allowed: {', '.join(allowed_types)}"
        )
    try:
        # Read and decode the upload. cv2.imdecode signals failure by
        # returning None rather than raising, so check explicitly instead
        # of letting the later cvtColor fail with an opaque error.
        contents = await file.read()
        nparr = np.frombuffer(contents, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if image is None:
            return SegmentationResponse(
                success=False,
                message="Error processing image: could not decode image data"
            )
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        original_size = image_rgb.shape[:2]
        # Preprocess (resize to IMG_SIZE, ImageNet-normalize, to tensor)
        img_tensor = preprocess_image(image_rgb).to(DEVICE)
        # Inference
        with torch.no_grad():
            output = model(img_tensor)
            mask = output.cpu().numpy().squeeze()
        # Metrics: mean probability over pixels classified as lesion, and
        # the fraction of the frame predicted as lesion.
        confidence = float(np.mean(mask[mask > 0.5])) if np.any(mask > 0.5) else 0.0
        lesion_area = float(np.mean(mask > 0.5) * 100)
        # Threshold, restore original resolution, and build the overlay.
        mask_binary = (mask > 0.5).astype(np.float32)
        mask_resized = postprocess_mask(mask_binary, original_size)
        overlay = create_overlay(image_rgb, mask_resized, alpha=0.4)
        return SegmentationResponse(
            success=True,
            mask_base64=mask_to_base64(mask_resized),
            overlay_base64=image_to_base64(overlay),
            confidence=round(confidence * 100, 2),
            lesion_area_percent=round(lesion_area, 2)
        )
    except Exception as e:
        # API boundary: report the failure in the response body instead of
        # surfacing an unhandled 500.
        return SegmentationResponse(
            success=False,
            message=f"Error processing image: {str(e)}"
        )
# Mount static files (frontend)
# Assets in <repo-root>/client are served at /static/* (see read_root).
client_path = os.path.join(os.path.dirname(__file__), "..", "client")
if os.path.exists(client_path):
    app.mount("/static", StaticFiles(directory=client_path), name="static")
# Dev entry point: `python <this file>` runs a local uvicorn server.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)