Spaces:
Sleeping
Sleeping
"""
FastAPI application for plant image classification using a hybrid
Swin Transformer + ConvNeXt model (referred to as "ViT-ConvNext").

Provides two prediction endpoints:
1. /predict/upload - upload an image file directly
2. /predict/url - provide an image URL

Both return the top 5 predictions with confidence scores.
"""
import ast
import io
from typing import List, Dict

import requests
import timm
import torch
import torch.nn as nn
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel, HttpUrl
from torchvision import transforms
# ============== Model Architecture ==============
class CBAMBlock(nn.Module):
    """Convolutional Block Attention Module: channel attention, then spatial attention.

    Args:
        channels: number of input (and output) channels.
        reduction: bottleneck ratio of the channel-attention 1x1-conv MLP.
        spatial_kernel: kernel size of the spatial-attention convolution; with
            ``padding = spatial_kernel // 2`` an odd size keeps H and W unchanged.

    The attribute names ``channel_att`` / ``spatial_att`` and the layer order
    inside them determine the ``state_dict`` keys, so keep them stable or saved
    checkpoints will stop loading.
    """

    def __init__(self, channels, reduction=16, spatial_kernel=7):
        super().__init__()
        hidden = channels // reduction
        # Global average pool -> bottleneck -> one sigmoid weight per channel.
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1, bias=False),
            nn.Sigmoid(),
        )
        # (mean, max) 2-channel descriptor -> single-channel sigmoid weight map.
        self.spatial_att = nn.Sequential(
            nn.Conv2d(
                2, 1,
                kernel_size=spatial_kernel,
                padding=spatial_kernel // 2,
                bias=False,
            ),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Reweight ``x`` (assumed (N, C, H, W)) along channels, then spatially."""
        refined = x * self.channel_att(x)
        avg_map = refined.mean(dim=1, keepdim=True)
        max_map = torch.max(refined, dim=1, keepdim=True).values
        spatial_weights = self.spatial_att(torch.cat((avg_map, max_map), dim=1))
        return refined * spatial_weights
class ViTCNNHybrid(nn.Module):
    """Hybrid model combining Swin Transformer and ConvNeXt with gated fusion.

    Both backbones see the same input batch. The input is presumably
    (B, 3, 224, 224) to match the 224 Swin variant — TODO confirm.

      * ``vit``: Swin-Tiny created with ``num_classes=0``, so timm returns pooled
        features. Their assumed shape is (B, 768); ``forward`` reshapes them to
        (B, 768, 1, 1).
      * ``cnn``: ConvNeXt-Tiny with ``global_pool=''``, so timm returns a spatial
        feature map. ``cnn_pool`` resizes it to 7x7.

    Each branch is reweighted by its own channel gate (global pool -> 1x1-conv
    bottleneck -> sigmoid). The two branches are blended with the learned weight
    ``sigmoid(alpha_param)``. The blend then passes through ``fusion`` (1x1 conv
    to 256 channels, BatchNorm, ReLU, dropout, optional CBAM, global pool) and
    the ``fc`` head, which outputs ``num_classes`` logits.

    Both backbones use ``pretrained=False``, so all weights are expected to come
    from a saved checkpoint. Submodule attribute names and the layer order
    inside each ``nn.Sequential`` define the ``state_dict`` keys. Do not rename
    or reorder them, or existing checkpoints will fail to load.

    Args:
        num_classes: output size of the final linear layer.
        use_cbam: insert ``CBAMBlock(256)`` into ``fusion``. This changes the
            ``state_dict`` layout, so it must match the trained checkpoint.
    """
    def __init__(self, num_classes, use_cbam=True):
        super(ViTCNNHybrid, self).__init__()
        # Swin Transformer branch
        self.vit = timm.create_model(
            'swin_tiny_patch4_window7_224', pretrained=False, num_classes=0, drop_rate=0.3
        )
        # Hard-coded; must equal the Swin backbone's pooled feature width.
        self.vit_out_features = 768
        # ConvNeXt-Tiny branch
        self.cnn = timm.create_model(
            'convnext_tiny', pretrained=False, num_classes=0, drop_rate=0.3, global_pool=''
        )
        # Hard-coded; must equal the ConvNeXt feature-map channel count.
        self.cnn_out_features = 768
        # Fixed 7x7 grid on which both branches are fused (see forward()).
        self.cnn_pool = nn.AdaptiveAvgPool2d((7, 7))
        # Gates for gated fusion (squeeze-excitation-style channel weights in (0, 1))
        self.vit_gate = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(self.vit_out_features, self.vit_out_features // 16, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.vit_out_features // 16, self.vit_out_features, 1),
            nn.Sigmoid()
        )
        self.cnn_gate = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(self.cnn_out_features, self.cnn_out_features // 16, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.cnn_out_features // 16, self.cnn_out_features, 1),
            nn.Sigmoid()
        )
        # NOTE(review): match_dim is never used in forward(). It is kept
        # presumably because its weights are in the saved checkpoint, and
        # load_state_dict is strict by default — confirm before removing.
        self.match_dim = nn.Conv2d(self.vit_out_features, self.cnn_out_features, 1)
        # Learnable α for dynamic fusion (passed through sigmoid in forward())
        self.alpha_param = nn.Parameter(torch.tensor(0.5))
        # Fusion layers: 768 -> 256 channels, optional CBAM, then global pool to 1x1
        fusion_layers = [
            nn.Conv2d(self.cnn_out_features, 256, kernel_size=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3)
        ]
        if use_cbam:
            fusion_layers.append(CBAMBlock(256))
        fusion_layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.fusion = nn.Sequential(*fusion_layers)
        # Classification head: 256 -> 512 -> num_classes logits
        self.fc = nn.Sequential(
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(512, num_classes)
        )
    def forward(self, x):
        """Return unnormalized class logits of shape (B, num_classes) for batch ``x``."""
        # ViT branch
        vit_out = self.vit(x)
        # Broadcast the pooled Swin vector over a 7x7 grid so it can be blended
        # element-wise with the 7x7 ConvNeXt map below.
        vit_out = vit_out.view(-1, self.vit_out_features, 1, 1).expand(-1, -1, 7, 7)
        vit_out = vit_out * self.vit_gate(vit_out)
        # CNN branch
        cnn_out = self.cnn(x)
        cnn_out = self.cnn_pool(cnn_out)
        cnn_out = cnn_out * self.cnn_gate(cnn_out)
        # Dynamic Fusion: sigmoid keeps the blend weight in (0, 1); alpha weights
        # the ViT branch and (1 - alpha) weights the CNN branch.
        alpha = torch.sigmoid(self.alpha_param)
        combined = alpha * vit_out + (1 - alpha) * cnn_out
        combined = self.fusion(combined)  # -> (B, 256, 1, 1)
        combined = combined.view(combined.size(0), -1)  # flatten -> (B, 256)
        out = self.fc(combined)
        return out
# ============== FastAPI Setup ==============
app = FastAPI(
    version="1.0.0",
    title="Plant Classification API",
    description="API for classifying plant images using ViT-ConvNext hybrid model",
)

# Runtime state. startup_event() assigns all four (it declares them `global`),
# and the prediction/health handlers read them. Each is None until then.
model = class_names = device = transform = None
class ImageURL(BaseModel):
    """Request body for URL-based prediction, e.g. ``{"url": "https://..."}``.

    Typing the field as ``HttpUrl`` makes pydantic reject values that are not
    valid HTTP(S) URLs before the endpoint body runs.
    """

    url: HttpUrl
class Prediction(BaseModel):
    """One ranked result: a class label and the classifier's confidence in it."""

    class_name: str  # label from the loaded class list
    confidence: float  # predicted probability for this label (0-1)
class PredictionResponse(BaseModel):
    """Response body: the top predictions, most confident first (top 5)."""

    predictions: List[Prediction]
def load_class_names(file_path: str = "class.txt") -> List[str]:
    """Read the ordered list of class labels from a text file.

    The file must contain the marker ``Classes: `` followed by a Python list
    literal, e.g. ``Classes: ['Aloe vera', 'Basil']``. Text before the first
    marker is ignored. List position ``i`` is used as the label for model
    output index ``i``, so the order must match training.

    Args:
        file_path: path of the class file (default ``class.txt``).

    Returns:
        The class names as a list of strings.

    Raises:
        OSError: if the file cannot be read.
        ValueError: if the marker is missing or the text after it is not a
            valid list/tuple literal of strings.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        content = f.read()
    # Split only at the FIRST marker, so a label that itself contains
    # "Classes: " does not truncate the literal.
    _, marker, literal = content.partition("Classes: ")
    if not marker:
        raise ValueError(f"{file_path!r} does not contain a 'Classes: ' marker")
    try:
        # literal_eval only evaluates Python literals, so the file cannot run code.
        classes = ast.literal_eval(literal.strip())
    except (ValueError, SyntaxError) as e:
        raise ValueError(f"{file_path!r}: could not parse class list: {e}") from e
    if not isinstance(classes, (list, tuple)) or not all(isinstance(c, str) for c in classes):
        raise ValueError(f"{file_path!r}: expected a list of class-name strings after 'Classes: '")
    return list(classes)
def get_transform():
    """Build the inference preprocessing pipeline (intended to match training).

    Steps: resize the shorter side to 256 px, center-crop 224x224, convert to a
    float tensor, then normalize each channel with the ImageNet mean/std.

    Returns:
        A ``transforms.Compose`` mapping a PIL image to a normalized
        (C, 224, 224) tensor.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return transforms.Compose(steps)
def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from each key.

    Checkpoints saved from a ``torch.nn.DataParallel`` wrapper prefix every key
    with ``module.``. Each key is handled independently, so dicts with and
    without the prefix (or a mix) are all handled correctly.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


@app.on_event("startup")
async def startup_event():
    """Load class names, build the model, load its weights, and create the transform.

    Populates the module globals ``class_names``, ``device``, ``transform`` and
    ``model``, which the endpoints read. ``model`` is assigned last, so a
    non-None ``model`` always means fully loaded weights. Reads ``class.txt`` and
    ``pbl6_model.pth`` relative to the current working directory.

    Exceptions from ``load_class_names``, ``torch.load`` or ``load_state_dict``
    (missing file, key mismatch, ...) propagate out of the startup handler.
    """
    global model, class_names, device, transform
    print("Loading class names...")
    class_names = load_class_names()
    num_classes = len(class_names)
    print(f"Loaded {num_classes} classes")
    print("Loading model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    net = ViTCNNHybrid(num_classes=num_classes, use_cbam=True)
    # NOTE: torch.load unpickles its input; only load trusted checkpoint files.
    checkpoint = torch.load("pbl6_model.pth", map_location=device)
    net.load_state_dict(_strip_module_prefix(checkpoint))
    net.to(device)
    net.eval()
    transform = get_transform()
    # Publish the model only once it is fully initialized.
    model = net
    print("Model loaded successfully!")
def predict_image(image: Image.Image, top_k: int = 5) -> List[Dict[str, object]]:
    """Classify a PIL image and return its highest-scoring classes.

    Args:
        image: PIL image; converted to RGB first if it is in another mode.
        top_k: maximum number of predictions to return (default 5). It is
            clamped to the number of classes, so small label sets do not make
            ``torch.topk`` fail.

    Returns:
        Up to ``top_k`` dicts ``{"class_name": str, "confidence": float}``,
        most confident first. Confidences are softmax probabilities.

    Raises:
        RuntimeError: if the model/transform have not been loaded yet.
        ValueError: if ``top_k`` is less than 1.
    """
    if model is None or transform is None:
        raise RuntimeError("Model is not loaded")
    if top_k < 1:
        raise ValueError("top_k must be at least 1")
    if image.mode != "RGB":
        image = image.convert("RGB")
    # Add a batch dimension: the model expects a batch of images.
    batch = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        probabilities = torch.nn.functional.softmax(model(batch), dim=1)
    k = min(top_k, probabilities.size(1))
    top_prob, top_idx = torch.topk(probabilities, k, dim=1)
    # .tolist() yields plain Python ints/floats, which are JSON-serializable.
    return [
        {"class_name": class_names[idx], "confidence": prob}
        for idx, prob in zip(top_idx[0].tolist(), top_prob[0].tolist())
    ]
@app.post("/predict/upload", response_model=PredictionResponse)
async def predict_upload(file: UploadFile = File(...)):
    """Classify a plant image uploaded as a file.

    Args:
        file: image file (JPEG, PNG, etc.).

    Returns:
        JSON ``{"predictions": [...]}`` with the top 5 classes and confidences.

    Raises:
        HTTPException: 503 if the model is not loaded yet; 400 if the upload
            cannot be decoded as an image; 500 if inference itself fails.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents))
        # Image.open is lazy; force a full decode so corrupt or truncated
        # uploads are reported as client errors here.
        image.load()
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error processing image: {str(e)}") from e
    try:
        predictions = predict_image(image)
    except Exception as e:
        # A valid image that fails inference is a server-side problem, not a bad request.
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
    return JSONResponse(content={"predictions": predictions})
# Upper bound on a remote image body; larger responses are rejected instead of
# being buffered in memory.
MAX_URL_IMAGE_BYTES = 20 * 1024 * 1024


def _download_image_bytes(url: str, timeout: float = 30, max_bytes: int = MAX_URL_IMAGE_BYTES) -> bytes:
    """Fetch ``url`` and return its body, refusing bodies larger than ``max_bytes``.

    The body is streamed so that an oversized (or endless) response is cut off
    early. ``timeout`` applies to connecting and to each read, not to the whole
    transfer.

    Raises:
        requests.exceptions.RequestException: connection error, timeout or non-2xx status.
        ValueError: if the body exceeds ``max_bytes``.
    """
    with requests.get(url, timeout=timeout, stream=True) as response:
        response.raise_for_status()
        body = bytearray()
        for chunk in response.iter_content(chunk_size=64 * 1024):
            body.extend(chunk)
            if len(body) > max_bytes:
                raise ValueError(f"Image is larger than the {max_bytes}-byte limit")
        return bytes(body)


@app.post("/predict/url", response_model=PredictionResponse)
async def predict_url(image_url: ImageURL):
    """Classify a plant image downloaded from a URL.

    Args:
        image_url: JSON body containing the image URL.

    Returns:
        JSON ``{"predictions": [...]}`` with the top 5 classes and confidences.

    Raises:
        HTTPException: 503 if the model is not loaded yet; 400 if the download
            fails, is too large, or is not a decodable image; 500 if inference fails.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # NOTE(security): the URL is user-supplied and fetched server-side. Nothing
    # here blocks internal/private addresses (SSRF). Add an allow-list or an
    # address check if this service is publicly exposed.
    try:
        # requests is blocking; run it in a worker thread so one slow download
        # does not stall every other request on the event loop.
        contents = await run_in_threadpool(_download_image_bytes, str(image_url.url))
    except (requests.exceptions.RequestException, ValueError) as e:
        raise HTTPException(status_code=400, detail=f"Error downloading image: {str(e)}") from e
    try:
        image = Image.open(io.BytesIO(contents))
        # Force a full decode so corrupt/truncated data is reported as a client error.
        image.load()
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error processing image: {str(e)}") from e
    try:
        predictions = predict_image(image)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
    return JSONResponse(content={"predictions": predictions})
@app.get("/")
async def root():
    """Basic status endpoint: service name, whether the model is loaded, and class count."""
    return {
        "message": "Plant Classification API",
        "status": "running",
        "model_loaded": model is not None,
        "num_classes": len(class_names) if class_names else 0
    }
@app.get("/health")
async def health():
    """Detailed health check: model load state, compute device, and number of classes.

    Always reports ``"status": "healthy"`` once the process is serving. Check
    the ``"model"`` field to see whether predictions are possible yet.
    """
    return {
        "status": "healthy",
        "model": "loaded" if model is not None else "not loaded",
        "device": str(device) if device else "unknown",
        "classes": len(class_names) if class_names else 0
    }