"""Crop disease classification service.

Exposes a ResNet50-based leaf-image classifier through a FastAPI endpoint
(``POST /predict``, port 8000) and serves a small Gradio landing page
(port 7860) that describes the endpoint.
"""

import io
import logging
import threading

import gradio as gr
import torch
import torch.nn as nn
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet50

logger = logging.getLogger(__name__)

# ----------------------------------
# App
# ----------------------------------
app = FastAPI(title="Crop Disease Classification API")

# ----------------------------------
# Labels
# ----------------------------------
# Position i names the class for model output index i (see `predict`), so the
# order of this list must not be changed independently of the checkpoint.
CLASS_LABELS = [
    'Corn_Common_Rust',
    'Corn_Gray_Leaf_Spot',
    'Corn_Healthy',
    'Corn_Northern_Leaf_Blight',
    'Potato_Early_Blight',
    'Potato_Healthy',
    'Potato_Late_Blight',
    'Rice_Brown_Spot',
    'Rice_Healthy',
    'Rice_Leaf_Blast',
    'Rice_Neck_Blast',
    'Wheat_Brown_Rust',
    'Wheat_Healthy',
    'Wheat_Yellow_Rust',
    'Sugarcane_Red_Rot',
    'Sugarcane_Healthy',
    'Sugarcane_Bacterial_Blight',
]

NUM_CLASSES = len(CLASS_LABELS)

# Checkpoint file holding the trained state dict, resolved relative to the
# working directory.
MODEL_PATH = "plant_disease_resnet_model.pth"

# Predictions whose top softmax probability is below this value are not
# reported; the client is asked for a clearer image instead.
MIN_CONFIDENCE = 0.6


# ----------------------------------
# Model
# ----------------------------------
class ResNetPlantDisease(nn.Module):
    """ResNet50 whose final ``fc`` layer is replaced by a two-layer head.

    Args:
        num_classes: Size of the output layer. Defaults to the number of
            entries in ``CLASS_LABELS``.
    """

    def __init__(self, num_classes: int = NUM_CLASSES) -> None:
        super().__init__()
        # weights=None builds the architecture only; the parameters are
        # filled in by load_state_dict() below.
        self.backbone = resnet50(weights=None)
        # Width of the features feeding the stock fc layer (2048 for ResNet50).
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw class scores produced by the backbone for ``x``."""
        return self.backbone(x)


model = ResNetPlantDisease(NUM_CLASSES)
# NOTE(review): torch.load() unpickles the file, so only load checkpoints from
# a trusted source. Consider passing weights_only=True if the installed torch
# version supports it.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
# eval() disables the Dropout layers so inference is deterministic.
model.eval()

# ----------------------------------
# Transform
# ----------------------------------
# Resize to 224x224, convert to a tensor and normalise each channel.
# NOTE(review): the mean/std values presumably mirror the preprocessing used
# when the checkpoint was trained - confirm before changing them.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


def _decode_image(image_bytes: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image."""
    return Image.open(io.BytesIO(image_bytes)).convert("RGB")


def _classify(image: Image.Image) -> tuple:
    """Run the model on one image and return ``(best_index, confidence)``."""
    batch = transform(image).unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
        outputs = model(batch)
        probs = torch.softmax(outputs, dim=1)[0]
    best_idx = torch.argmax(probs).item()
    return best_idx, float(probs[best_idx])


# ----------------------------------
# API Endpoint
# ----------------------------------
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded leaf image.

    Args:
        file: The uploaded image file.

    Returns:
        On success, a dict with ``model``, ``disease`` and ``confidence``
        (rounded to 4 decimals). If the top probability is below
        ``MIN_CONFIDENCE``, a JSON ``message`` asking for a clearer image.
        On failure, a JSON ``error`` body with status 400 when the upload
        cannot be decoded as an image, or 500 for any other error.
    """
    try:
        image_bytes = await file.read()
        try:
            image = _decode_image(image_bytes)
        except (OSError, ValueError, SyntaxError) as exc:
            # The upload is not a readable image: a client error, not a
            # server fault.
            return JSONResponse({"error": str(exc)}, status_code=400)

        best_idx, confidence = _classify(image)
    except Exception as exc:
        # Top-level boundary: keep the JSON error body callers already parse,
        # but report it as a server error and record the traceback.
        logger.exception("Prediction failed")
        return JSONResponse({"error": str(exc)}, status_code=500)

    if confidence < MIN_CONFIDENCE:
        return JSONResponse({
            "message": "Please upload a clearer leaf image"
        })

    return {
        "model": "ResNet50 Classification",
        "disease": CLASS_LABELS[best_idx],
        "confidence": round(confidence, 4),
    }


def run_api():
    """Serve the FastAPI app with uvicorn on port 8000 (blocks the caller)."""
    uvicorn.run(app, host="0.0.0.0", port=8000)


# NOTE(review): the servers are started at import time, as before. Consider an
# `if __name__ == "__main__":` guard if this module is ever imported elsewhere.
# daemon=True lets the process exit once the Gradio server (main thread) stops
# instead of hanging on the API thread.
threading.Thread(target=run_api, daemon=True).start()

# launch() is provided by gr.Blocks / gr.Interface, so the Markdown component
# is hosted inside a Blocks container and the container is launched.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# 🌾 Crop Disease Classification API
### Endpoint: POST `/predict`
Use Postman / frontend to send image.
"""
    )

demo.launch(server_name="0.0.0.0", server_port=7860)