|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision.models import resnet50 |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
from fastapi import FastAPI, File, UploadFile |
|
|
from fastapi.responses import JSONResponse |
|
|
import io |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ASGI application object; the /predict route further down registers itself
# on this instance, and run_api() hands it to uvicorn.
app = FastAPI(
    title="Crop Disease Classification API",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Human-readable class names. The position of each entry is the class index:
# predict() looks up CLASS_LABELS[argmax of the model output].
# NOTE(review): the order presumably has to match the label order used when
# the checkpoint was trained -- confirm before reordering or inserting entries.
CLASS_LABELS = [
    # Corn
    "Corn_Common_Rust",
    "Corn_Gray_Leaf_Spot",
    "Corn_Healthy",
    "Corn_Northern_Leaf_Blight",
    # Potato
    "Potato_Early_Blight",
    "Potato_Healthy",
    "Potato_Late_Blight",
    # Rice
    "Rice_Brown_Spot",
    "Rice_Healthy",
    "Rice_Leaf_Blast",
    "Rice_Neck_Blast",
    # Wheat
    "Wheat_Brown_Rust",
    "Wheat_Healthy",
    "Wheat_Yellow_Rust",
    # Sugarcane
    "Sugarcane_Red_Rot",
    "Sugarcane_Healthy",
    "Sugarcane_Bacterial_Blight",
]

# Size of the classifier's output layer; derived so it cannot drift from the list.
NUM_CLASSES = len(CLASS_LABELS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ResNetPlantDisease(nn.Module):
    """ResNet-50 with its final ``fc`` layer replaced by a small classifier head.

    The backbone is built with ``weights=None`` (no pretrained weights are
    requested here; the file loads a checkpoint after construction). The
    replacement head is::

        Dropout(0.5) -> Linear(2048, 512) -> ReLU -> Dropout(0.3) -> Linear(512, num_classes)

    Args:
        num_classes: Width of the final linear layer, i.e. number of output
            scores per input. Defaults to 17.
    """

    def __init__(self, num_classes=17):
        super().__init__()

        # Build the stock network first, then swap its classifier in place so
        # parameter names stay under ``backbone.*`` / ``backbone.fc.<index>.*``.
        net = resnet50(weights=None)
        net.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )
        self.backbone = net

    def forward(self, x):
        """Run ``x`` through the backbone and return its raw output scores."""
        scores = self.backbone(x)
        return scores
|
|
|
|
|
# Build the classifier and restore its trained parameters on CPU.
model = ResNetPlantDisease(NUM_CLASSES)

# Only a state dict is needed here, so restrict unpickling to tensors and
# plain containers (weights_only=True) instead of allowing arbitrary pickled
# objects to be executed while loading the checkpoint file.
model.load_state_dict(
    torch.load(
        "plant_disease_resnet_model.pth",
        map_location="cpu",
        weights_only=True,
    )
)

# Inference only: put modules such as the Dropout layers into eval behavior.
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Preprocessing applied to every uploaded image before it is fed to the model:
# resize to 224x224, convert to a tensor, then normalize each of the three
# channels with the mean/std below (the values commonly used for ImageNet).
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded leaf image and report the most likely disease.

    Args:
        file: Multipart upload containing the image bytes.

    Returns:
        * On success: a dict with ``model``, ``disease`` (an entry of
          ``CLASS_LABELS``) and ``confidence`` (softmax probability of the top
          class, rounded to 4 decimals).
        * When the top probability is below 0.6: a JSON body with a
          ``message`` asking for a clearer image (HTTP 200).
        * When the upload cannot be decoded as an image: ``{"error": ...}``
          with HTTP 400.
        * On any other failure: ``{"error": ...}`` with HTTP 500.

    The error body keeps the same ``{"error": <text>}`` shape in every case;
    only the status code distinguishes client errors from server errors, so
    failures are no longer reported as HTTP 200.
    """
    try:
        image_bytes = await file.read()

        # Decoding failures (unrecognized or truncated image data) are the
        # client's problem, so answer 400 instead of a generic failure.
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except (OSError, ValueError) as e:
            return JSONResponse({"error": str(e)}, status_code=400)

        # Add a leading batch dimension of size 1 for the model.
        img = transform(image).unsqueeze(0)

        # No gradients are needed for inference.
        with torch.no_grad():
            outputs = model(img)
            probs = torch.softmax(outputs, dim=1)[0]

        best_idx = torch.argmax(probs).item()
        confidence = float(probs[best_idx])

        # Below this probability the prediction is not reported.
        if confidence < 0.6:
            return JSONResponse({
                "message": "Please upload a clearer leaf image"
            })

        return {
            "model": "ResNet50 Classification",
            "disease": CLASS_LABELS[best_idx],
            "confidence": round(confidence, 4)
        }

    except Exception as e:
        # Top-level boundary: keep the JSON error shape but signal a server
        # failure through the status code.
        return JSONResponse({"error": str(e)}, status_code=500)
|
|
|
|
|
import gradio as gr |
|
|
import uvicorn |
|
|
import threading |
|
|
|
|
|
def run_api():
    """Serve the FastAPI ``app`` with uvicorn on all interfaces, port 8000.

    ``uvicorn.run`` is a blocking call, which is why the file invokes this
    function from a separate thread.
    """
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)
|
|
|
|
|
# Start the API server in a background (non-daemon) thread so the statements
# that follow in this file can run while uvicorn is serving.
_api_thread = threading.Thread(target=run_api)
_api_thread.start()
|
|
|
|
|
# Static landing page text. Kept at column 0 so the markdown is not treated
# as an indented code block. The heading emoji was mojibake in the original
# literal and is restored here.
LANDING_PAGE_MARKDOWN = """
# 🌾 Crop Disease Classification API

### Endpoint:
POST `/predict`

Use Postman / frontend to send image.
"""

# launch() is part of the Blocks app API, so the Markdown component is placed
# inside a Blocks container and the container is launched, rather than calling
# launch() on the component itself.
with gr.Blocks() as demo:
    gr.Markdown(LANDING_PAGE_MARKDOWN)

demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
|
|
|