Spaces:
Sleeping
Sleeping
| import torch | |
| from torchvision import transforms | |
| from PIL import Image | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from timm import create_model | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel | |
| import io | |
| import os | |
| import cloudinary | |
| import cloudinary.uploader | |
| import cloudinary.api | |
| import uuid | |
| import uvicorn | |
| import requests | |
| import urllib.parse | |
| import config | |
| import time | |
| import cloudinary.utils | |
| from dotenv import load_dotenv | |
# Pull variables from a local .env file (if one exists) into os.environ.
# NOTE(review): `config` is imported above, before this call executes. If
# config.py reads os.environ at import time, values defined only in .env will
# not be visible to it — confirm config.py loads .env itself, or move this call
# ahead of `import config`.
load_dotenv()

# === FastAPI App ===
app = FastAPI(
    title="Crop Disease Detection API",
    description="API for detecting crop diseases using ResMamba model",
)
# === Cloudinary Configuration ===
# Credentials are taken from the project-local `config` module.
cloudinary_config = dict(
    cloud_name=config.CLOUDINARY_CLOUD_NAME,
    api_key=config.CLOUDINARY_API_KEY,
    api_secret=config.CLOUDINARY_API_SECRET,
)

# Empty/missing credentials only produce a warning; the SDK is configured with
# whatever values are present either way (best effort, no abort at import).
if any(not value for value in cloudinary_config.values()):
    print("Warning: Some Cloudinary environment variables are missing!")
    missing = [name for name, value in cloudinary_config.items() if not value]
    print(f"Missing: {missing}")
cloudinary.config(**cloudinary_config)
# === Upload directory ===
# Scratch directory for temporary files, created at import time (relative to
# the current working directory). exist_ok=True replaces the original
# exists()-then-makedirs() pair, which could race when several workers start
# simultaneously and both try to create the directory.
UPLOAD_FOLDER = 'Uploads'
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
# === Device Setup ===
# Run on CUDA when PyTorch reports it available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# === Helper Function to Generate Signed Cloudinary URL ===
def get_signed_cloudinary_url(public_id, resource_type='image', expires_in=300):
    """Build a signed delivery URL for an ``authenticated`` Cloudinary asset.

    Args:
        public_id: Cloudinary public ID of the asset.
        resource_type: Cloudinary resource type, e.g. ``'image'`` or ``'raw'``.
        expires_in: seconds added to the current UNIX time and passed to the
            SDK as ``expires_at``.

    Returns:
        The URL string, or ``None`` if URL generation raised (the error is
        printed, never propagated).
    """
    # NOTE(review): confirm that cloudinary_url() actually enforces `expires_at`
    # for signed delivery URLs of type "authenticated"; the signature may not
    # expire unless token-based auth is configured.
    try:
        signed_url, _url_options = cloudinary.utils.cloudinary_url(
            public_id,
            resource_type=resource_type,
            type="authenticated",  # authenticated delivery type requires a signature
            sign_url=True,
            expires_at=int(time.time()) + expires_in,
        )
    except Exception as e:
        print(f"Failed to generate signed URL: {e}")
        return None
    return signed_url
# === Helper Function to Download Files from Cloudinary ===
def download_file(public_id: str, save_path: str, file_type: str = 'image/jpeg') -> bool:
    """Fetch a Cloudinary asset through a signed URL and write it to ``save_path``.

    Args:
        public_id: Cloudinary public ID of the asset.
        save_path: local file path; it is created or overwritten on success.
        file_type: ``'raw'`` selects Cloudinary's raw resource type, any other
            value selects ``'image'``. The value is also sent as the request's
            ``Content-Type`` header.

    Returns:
        True when the asset was fetched with HTTP 200 and written to disk;
        False on any failure. Errors are printed, never raised.
    """
    try:
        signed_url = get_signed_cloudinary_url(
            public_id,
            resource_type='raw' if file_type == 'raw' else 'image',
        )
        if not signed_url:
            print(f"Failed to generate signed URL for {public_id}")
            return False

        response = requests.get(signed_url, headers={'Content-Type': file_type}, timeout=30)
        if response.status_code != 200:
            print(f"Failed to download file. Status: {response.status_code}")
            return False

        with open(save_path, 'wb') as out_file:
            out_file.write(response.content)
        return True
    except Exception as e:
        print(f"Error downloading file: {e}")
        return False
# === Helper Function to Load Model ===
def load_model_without_module(model, path, device):
    """Load a saved state dict into ``model``, dropping a leading ``module.`` key prefix.

    The ``module.`` prefix typically appears when a model was saved while wrapped
    in ``nn.DataParallel``/``DistributedDataParallel``. Only a *leading* prefix is
    removed: the previous ``str.replace`` stripped every occurrence, which
    corrupted keys that merely contain the substring (``submodule.weight`` became
    ``subweight``) and made the load fail.

    Args:
        model: ``nn.Module`` whose parameters are replaced in place.
        path: path of the checkpoint file passed to ``torch.load``.
        device: ``map_location`` used when loading the tensors.

    Returns:
        The same ``model`` instance, for chaining.

    Raises:
        Exception: wrapping any load failure (missing file, key mismatch under
            the default strict ``load_state_dict``, ...), with ``path`` in the
            message and the original error as ``__cause__``.
    """
    prefix = "module."
    try:
        # NOTE: torch.load unpickles the file — only load trusted checkpoints.
        state_dict = torch.load(path, map_location=device)
        cleaned = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        model.load_state_dict(cleaned)
        return model
    except Exception as e:
        raise Exception(f"Error loading model from {path}: {str(e)}") from e
# === ResNet50Classifier ===
class ResNet50Classifier(nn.Module):
    """timm ResNet-50 feature extractor followed by a small MLP classification head.

    Args:
        num_classes: number of output logits.
        pretrained: forwarded to ``timm.create_model``.
        dropout_rate: dropout probability used in the head.
        hidden_dim: width of the hidden layer; a falsy value gives a single
            ``Dropout -> Linear`` head instead.
        use_bn: insert ``BatchNorm1d`` after the hidden ``Linear``.
        activation: ``"relu"`` or ``"gelu"``; any other value selects SiLU.
    """

    def __init__(self, num_classes, pretrained=True, dropout_rate=0.3, hidden_dim=512, use_bn=True, activation="relu"):
        super().__init__()
        # num_classes=0 asks timm for the backbone without its classifier, so it
        # emits feature vectors of length `num_features`.
        self.backbone = create_model('resnet50', pretrained=pretrained, num_classes=0)
        self.feature_dim = self.backbone.num_features

        if activation == "relu":
            act_fn = nn.ReLU(inplace=True)
        elif activation == "gelu":
            act_fn = nn.GELU()
        else:
            act_fn = nn.SiLU()

        # The position of each layer determines its state_dict key
        # (classifier.<index>.*), which saved checkpoints depend on — keep order.
        head = [nn.Dropout(dropout_rate)]
        if not hidden_dim:
            head.append(nn.Linear(self.feature_dim, num_classes))
        else:
            head.append(nn.Linear(self.feature_dim, hidden_dim))
            if use_bn:
                head.append(nn.BatchNorm1d(hidden_dim))
            head += [act_fn, nn.Dropout(dropout_rate), nn.Linear(hidden_dim, num_classes)]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        return self.classifier(self.backbone(x))
# === VMambaBlock ===
class VMambaBlock(nn.Module):
    """Lightweight VMamba-style block: LayerNorm, depthwise Conv1d token mixer, MLP.

    Input is token-major ``(batch, tokens, dim)``: LayerNorm(dim) normalises the
    last axis, and the tensor is transposed to ``(batch, dim, tokens)`` for the
    Conv1d. Output has the same shape as the input.
    """

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # Depthwise (groups=dim) conv along the token axis; padding=1 keeps length.
        self.ssm = nn.Conv1d(dim, dim, kernel_size=3, padding=1, groups=dim)
        self.ff = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

    def forward(self, x):
        # NOTE: both residuals start from the *normalized* tensor rather than the
        # block input. This is how the trained checkpoints were produced — keep it.
        normed = self.norm(x)
        mixed = self.ssm(normed.transpose(1, 2)).transpose(1, 2)
        hidden = normed + mixed
        return hidden + self.ff(hidden)
# === VMambaClassifier ===
class VMambaClassifier(nn.Module):
    """Small VMamba-style image classifier.

    Pipeline: non-overlapping patch embedding (Conv2d with stride = kernel =
    ``patch_size`` over 3-channel input) -> two ``VMambaBlock``s over the
    flattened patch tokens -> LayerNorm -> average over tokens -> Linear head.

    Args:
        num_classes: number of output logits.
        patch_size: kernel size and stride of the patch-embedding convolution.
        embed_dim: channel width of the embedding and blocks.
        img_size: accepted for interface compatibility; not used.
    """

    def __init__(self, num_classes, patch_size=4, embed_dim=512, img_size=224):
        super().__init__()
        self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.blocks = nn.Sequential(VMambaBlock(embed_dim), VMambaBlock(embed_dim))
        self.norm = nn.LayerNorm(embed_dim)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        patches = self.patch_embed(x)
        _batch, _chans, _rows, _cols = patches.shape  # requires a batched 4-D map, as before
        # (B, C, H, W) -> (B, H*W, C): one token per patch.
        tokens = self.norm(self.blocks(patches.flatten(2).transpose(1, 2)))
        pooled = self.avgpool(tokens.transpose(1, 2)).squeeze(-1)
        return self.classifier(pooled)
# === ResMamba ===
class ResMamba(nn.Module):
    """Fusion classifier over a ResNet branch and a VMamba branch.

    The branch sub-modules are *shared* with the models passed in (attribute
    assignment, not a copy). Per image, the ResNet backbone features and the
    token-averaged VMamba features are concatenated and fed to an MLP head.

    Args:
        eff_model: model exposing ``.backbone`` and ``.feature_dim`` (e.g. a
            ``ResNet50Classifier``).
        vmamba_model: model exposing ``.patch_embed``, ``.blocks``, ``.norm`` and
            a Linear ``.classifier`` whose ``in_features`` gives the VMamba width.
        num_classes: number of output logits.
    """

    def __init__(self, eff_model, vmamba_model, num_classes):
        super().__init__()
        self.resnet_feature_extractor = eff_model.backbone
        self.res_feat_dim = eff_model.feature_dim
        self.vmamba_patch = vmamba_model.patch_embed
        self.vmamba_blocks = vmamba_model.blocks
        self.vmamba_norm = vmamba_model.norm
        self.vmamba_pool = nn.AdaptiveAvgPool1d(1)
        self.vmamba_feat_dim = vmamba_model.classifier.in_features
        fused_dim = self.res_feat_dim + self.vmamba_feat_dim
        self.fusion_classifier = nn.Sequential(
            nn.Linear(fused_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def _vmamba_features(self, x):
        """Run the VMamba branch and average its tokens into one vector per image."""
        patches = self.vmamba_patch(x)
        _batch, _chans, _rows, _cols = patches.shape  # requires a batched 4-D map, as before
        tokens = self.vmamba_norm(self.vmamba_blocks(patches.flatten(2).transpose(1, 2)))
        return self.vmamba_pool(tokens.transpose(1, 2)).squeeze(-1)

    def forward(self, x):
        res_feat = self.resnet_feature_extractor(x)
        vm_feat = self._vmamba_features(x)
        return self.fusion_classifier(torch.cat((res_feat, vm_feat), dim=1))
# === Setup Models ===
num_classes = 38
try:
    # pretrained=False: the checkpoint below is loaded with the default strict
    # load_state_dict, which replaces every backbone parameter/buffer anyway, so
    # pretrained=True only downloaded weights that were immediately discarded —
    # extra work on every cold start and a startup failure when offline.
    eff_model = ResNet50Classifier(num_classes=num_classes, pretrained=False).to(device)
    load_model_without_module(eff_model, "./models/resnet50_classifier.pth", device)
    eff_model.eval()

    vmamba_model = VMambaClassifier(num_classes=num_classes).to(device)
    load_model_without_module(vmamba_model, "./models/vmamba_classifier.pth", device)
    vmamba_model.eval()

    # ResMamba shares (does not copy) the branch sub-modules, so this strict load
    # of the fused checkpoint also overwrites them; only the branch heads keep
    # the weights loaded above, and ResMamba does not use those heads.
    resmamba_model = ResMamba(eff_model, vmamba_model, num_classes=num_classes).to(device)
    load_model_without_module(resmamba_model, "./models/ResMamba.pth", device)
    resmamba_model.eval()
except Exception as e:
    print(f"Error initializing models: {str(e)}")
    raise
# === Transform for Image ===
# Inference preprocessing: fixed 224x224 resize (aspect ratio is not preserved),
# conversion to a float tensor, then per-channel normalisation with the standard
# ImageNet mean/std (presumably matching how the models were trained — confirm).
val_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# === Define Class Names ===
# Label for each model output index, in index order (alphabetical, 38 entries).
# Format is "<Crop>___<condition>", with underscores standing in for spaces and
# "healthy" as the condition for healthy plants. The spellings (including
# "Haunglongbing") appear to be the original dataset labels and flow into API
# responses, so they are kept verbatim.
class_names = [
    # Apple
    'Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy',
    # Blueberry
    'Blueberry___healthy',
    # Cherry
    'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy',
    # Corn
    'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
    'Corn_(maize)___Common_rust_',
    'Corn_(maize)___Northern_Leaf_Blight',
    'Corn_(maize)___healthy',
    # Grape
    'Grape___Black_rot',
    'Grape___Esca_(Black_Measles)',
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
    'Grape___healthy',
    # Orange
    'Orange___Haunglongbing_(Citrus_greening)',
    # Peach
    'Peach___Bacterial_spot', 'Peach___healthy',
    # Pepper
    'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy',
    # Potato
    'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy',
    # Raspberry / Soybean / Squash
    'Raspberry___healthy',
    'Soybean___healthy',
    'Squash___Powdery_mildew',
    # Strawberry
    'Strawberry___Leaf_scorch', 'Strawberry___healthy',
    # Tomato
    'Tomato___Bacterial_spot',
    'Tomato___Early_blight',
    'Tomato___Late_blight',
    'Tomato___Leaf_Mold',
    'Tomato___Septoria_leaf_spot',
    'Tomato___Spider_mites Two-spotted_spider_mite',
    'Tomato___Target_Spot',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
    'Tomato___Tomato_mosaic_virus',
    'Tomato___healthy',
]
# === Prediction Function ===
def predict_single_image(image: "Image.Image", model, class_names, transform, device):
    """Run ``model`` on one image and describe its top-1 class.

    Args:
        image: PIL image, passed unchanged to ``transform``.
        model: classifier; it is switched to eval mode (side effect) and called
            on a batch of one. Its output is soft-maxed over dim 1, so it is
            expected to return logits shaped (1, len(class_names)).
        class_names: labels indexed by predicted class id, formatted as
            ``"<Crop>___<condition>"`` with underscores for spaces and
            ``"healthy"`` as the condition for healthy plants.
        transform: callable mapping ``image`` to a tensor; a leading batch
            dimension is added here.
        device: device the input tensor is moved to.

    Returns:
        dict with keys:
          - "Crop_name": crop part of the label, underscores -> spaces.
          - "Disease": condition part (underscores -> spaces),
            "No Disease Detected" for "healthy", or "Unknown" when the label
            has no ``___`` separator.
          - "Accuracy": top-1 softmax probability in percent, rounded to two
            decimals (a confidence score; key name kept for API compatibility).
          - "Damage_Report": "Not Damaged" for healthy labels, else "Damaged".
    """

    def parse_class_name(class_name):
        parts = class_name.split('___')
        # strip(): labels such as 'Corn_(maize)___Common_rust_' end in '_',
        # which previously surfaced as a trailing space in the API response.
        crop = parts[0].replace('_', ' ').strip()
        if len(parts) < 2:
            return crop, "Unknown"
        if parts[1] == "healthy":
            return crop, "No Disease Detected"
        return crop, parts[1].replace('_', ' ').strip()

    model.eval()
    input_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = F.softmax(output, dim=1)
        confidence, predicted = torch.max(probabilities, 1)

    predicted_class = class_names[predicted.item()]
    accuracy = round(confidence.item() * 100, 2)
    crop, disease = parse_class_name(predicted_class)
    damage_status = "Not Damaged" if disease == "No Disease Detected" else "Damaged"
    return {
        "Crop_name": crop,
        "Disease": disease,
        "Accuracy": accuracy,
        "Damage_Report": damage_status,
    }
# === Pydantic Model for Cloudinary Request ===
# Request body of the Cloudinary-based prediction endpoint. (Comments rather
# than a docstring: pydantic would publish a docstring as the schema
# description and change the OpenAPI output.)
class CloudinaryRequest(BaseModel):
    # Cloudinary public ID of the asset to download.
    publicId: str
    # Either an extension such as "jpg", or the generic "image"/"raw"; "raw"
    # selects Cloudinary's raw resource type.
    fileType: str
    # Client-side file name; its extension is used when fileType is generic.
    originalName: str
# === Download and Predict Endpoint ===
@app.post("/api/damage_detection")
async def predict_from_cloudinary(request: CloudinaryRequest):
    """Download an image from Cloudinary by public ID and run disease prediction on it.

    The file extension is ``request.fileType`` or, when fileType is the generic
    ``"image"``/``"raw"``, the extension of ``request.originalName`` (``jpg`` if
    it has none). ``fileType == "raw"`` selects Cloudinary's raw resource type;
    anything else is treated as an image resource.

    Returns:
        JSONResponse with the dict produced by ``predict_single_image``.

    Raises:
        HTTPException(400): unknown file type, or a type PIL cannot decode.
        HTTPException(500): URL/download failure, empty or unreadable file, or
            any other unexpected error.
    """
    image_extensions = ['jpg', 'jpeg', 'png', 'gif', 'bmp', 'tiff', 'tif', 'webp', 'svg']
    raw_extensions = ['pdf', 'doc', 'docx', 'txt', 'csv', 'xlsx', 'zip', 'rar', 'mp4', 'avi', 'mov', 'mp3', 'wav']
    valid_extensions = image_extensions + raw_extensions + ['raw', 'image']
    # Raster formats PIL can actually open. SVG is a valid Cloudinary image type
    # but PIL cannot decode it, so it used to pass this check and then fail with
    # a 500 at Image.open; it is now rejected up front with a 400.
    processable_extensions = {'jpg', 'jpeg', 'png', 'gif', 'bmp', 'tiff', 'tif', 'webp'}

    try:
        file_type = request.fileType.lower()
        # A generic fileType carries no format, so fall back to the file name.
        if file_type in ["image", "raw"] and request.originalName:
            name = request.originalName
            file_extension = name.split('.')[-1].lower() if '.' in name else 'jpg'
        else:
            file_extension = file_type

        # Resource type and content type for the Cloudinary download.
        if file_type == 'raw':
            resource_type = 'raw'
            content_type = 'raw'
        else:
            resource_type = 'image'
            content_type = f'image/{file_extension}'

        if file_extension not in valid_extensions:
            raise HTTPException(status_code=400, detail=f"Invalid file type. Supported: {valid_extensions}")
        if file_extension not in processable_extensions:
            raise HTTPException(status_code=400, detail=f"File type '{file_extension}' is not a processable image format for disease detection")

        # The extension is whitelisted above, so it is safe to embed in a path.
        local_path = os.path.join(UPLOAD_FOLDER, f"{uuid.uuid4()}.{file_extension}")
        try:
            print(f"Attempting to download {resource_type} file: {request.publicId}")
            if not download_file(request.publicId, local_path, content_type):
                raise HTTPException(status_code=500, detail="Failed to download file from Cloudinary")
            print(f"Successfully downloaded file to: {local_path}")

            if not os.path.exists(local_path) or os.path.getsize(local_path) == 0:
                raise Exception("Downloaded file is empty or doesn't exist")

            # Context manager closes PIL's file handle before the finally-block
            # deletes the file (an open handle blocks deletion on some platforms).
            with Image.open(local_path) as downloaded:
                image = downloaded.convert('RGB')
            print(f"Successfully opened image with size: {image.size}")
        except HTTPException:
            raise
        except Exception as e:
            print(f"Error in file processing: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error downloading or opening file: {str(e)}") from e
        finally:
            if os.path.exists(local_path):
                os.remove(local_path)

        result = predict_single_image(image, resmamba_model, class_names, val_transform, device)
        return JSONResponse(content=result)
    except HTTPException:
        # Keep the intended status code (e.g. 400 for bad input) instead of
        # re-wrapping every HTTPException as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e
# === Upload and Predict Endpoint ===
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Accept an uploaded image, archive it to Cloudinary, and return the disease prediction.

    The bytes are decoded first, so anything PIL cannot read is rejected before
    it is uploaded. A temporary local copy is written for the Cloudinary upload
    and always removed afterwards.

    Returns:
        JSONResponse with the dict produced by ``predict_single_image``.

    Raises:
        HTTPException(400): non-image content type, missing filename, or bytes
            that are not a decodable image.
        HTTPException(500): Cloudinary upload failure or any other unexpected error.
    """
    try:
        if not file.content_type or not file.content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")
        if not file.filename:
            raise HTTPException(status_code=400, detail="File must have a filename")

        contents = await file.read()
        # Decode before any side effect: undecodable bytes are a client error
        # (400) and should not be archived to Cloudinary.
        try:
            image = Image.open(io.BytesIO(contents)).convert('RGB')
        except OSError as e:  # PIL's UnidentifiedImageError subclasses OSError
            raise HTTPException(status_code=400, detail=f"Uploaded file is not a readable image: {str(e)}") from e

        # The extension comes from the client; keep it only if purely
        # alphanumeric so it cannot smuggle path separators into local_path.
        file_extension = os.path.splitext(file.filename)[1][1:].lower()
        if not file_extension.isalnum():
            file_extension = 'jpg'
        local_path = os.path.join(UPLOAD_FOLDER, f"{uuid.uuid4()}.{file_extension}")

        with open(local_path, 'wb') as f:
            f.write(contents)
        try:
            upload_result = cloudinary.uploader.upload(
                local_path,
                folder="crop_disease_images",
                resource_type="image",
            )
            cloudinary_url = upload_result['secure_url']
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Cloudinary upload failed: {str(e)}") from e
        finally:
            if os.path.exists(local_path):
                os.remove(local_path)
        print(f"Uploaded image to Cloudinary: {cloudinary_url}")

        result = predict_single_image(image, resmamba_model, class_names, val_transform, device)
        return JSONResponse(content=result)
    except HTTPException:
        # Keep the intended status code (400s were previously re-wrapped as 500).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e
# === Root Endpoint ===
# The route decorator was missing, so this handler was never registered.
@app.get("/")
async def root():
    """Return a short usage message pointing clients at the prediction endpoints."""
    return {"message": "Welcome to the Crop Disease Detection API. Use POST /predict to upload an image or POST /api/damage_detection to process an image from Cloudinary."}
# === Run the FastAPI app with Uvicorn ===
if __name__ == "__main__":
    # Script entry point: serve the app on every interface, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")