File size: 3,061 Bytes
2c0eb09 163c547 9a7fc09 d1f56a5 9a81fea d1f56a5 9a81fea 0c6c8ff 9a81fea 2c0eb09 9a81fea 381e2e8 9a81fea 78539e2 9a81fea 9a7fc09 381e2e8 9a81fea 381e2e8 2c0eb09 9a7fc09 2c0eb09 9a81fea 2c0eb09 9a81fea 2c0eb09 381e2e8 2c0eb09 9a81fea 2c0eb09 9a81fea 2c0eb09 9a81fea 42d431e 9a81fea d1f56a5 9a81fea 2c0eb09 9a81fea dd2886a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import io
import logging
import os

import torch
import torch.nn as nn
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet50
# ----------------------------------
# App
# ----------------------------------
# Title handed to FastAPI (it appears in the generated OpenAPI schema/docs).
API_TITLE = "Crop Disease Classification API"
app = FastAPI(title=API_TITLE)
# ----------------------------------
# Labels
# ----------------------------------
# Maps model output index -> class name: CLASS_LABELS[i] names logit i.
# NOTE(review): this order must match the class order used at training time.
# It is not fully alphabetical (Sugarcane follows Wheat, and the Sugarcane
# entries are unsorted), which would differ from a sorted-directory loader
# such as torchvision's ImageFolder -- confirm against the training script.
CLASS_LABELS = [
    # Corn
    "Corn_Common_Rust",
    "Corn_Gray_Leaf_Spot",
    "Corn_Healthy",
    "Corn_Northern_Leaf_Blight",
    # Potato
    "Potato_Early_Blight",
    "Potato_Healthy",
    "Potato_Late_Blight",
    # Rice
    "Rice_Brown_Spot",
    "Rice_Healthy",
    "Rice_Leaf_Blast",
    "Rice_Neck_Blast",
    # Wheat
    "Wheat_Brown_Rust",
    "Wheat_Healthy",
    "Wheat_Yellow_Rust",
    # Sugarcane
    "Sugarcane_Red_Rot",
    "Sugarcane_Healthy",
    "Sugarcane_Bacterial_Blight",
]
# Width of the classifier's output layer.
NUM_CLASSES = len(CLASS_LABELS)
# ----------------------------------
# Model
# ----------------------------------
class ResNetPlantDisease(nn.Module):
    """ResNet-50 backbone whose final ``fc`` layer is replaced by an MLP head.

    The backbone is built with ``weights=None`` (no pretrained weights are
    downloaded); trained weights are expected to be loaded afterwards via
    ``load_state_dict``. The head is:
    Dropout(0.5) -> Linear(2048, 512) -> ReLU -> Dropout(0.3) -> Linear(512, num_classes).

    NOTE: the attribute name ``backbone`` and the layer order inside the head
    determine the ``state_dict`` keys (e.g. ``backbone.fc.1.weight``,
    ``backbone.fc.4.weight``); changing either breaks loading saved checkpoints.
    """

    def __init__(self, num_classes=17):
        super().__init__()
        self.backbone = resnet50(weights=None)
        # Width of the backbone's pooled features (2048 for ResNet-50).
        feature_dim = self.backbone.fc.in_features
        head_layers = [
            nn.Dropout(0.5),
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        ]
        self.backbone.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return raw class logits (no softmax is applied here) for batch ``x``."""
        return self.backbone(x)
# Path to the trained checkpoint; set the MODEL_PATH environment variable to
# load a different file without editing the code.
MODEL_PATH = os.environ.get("MODEL_PATH", "plant_disease_resnet_model.pth")

model = ResNetPlantDisease(NUM_CLASSES)
# The checkpoint is a plain state_dict, so weights_only=True is sufficient and
# restricts unpickling to tensors/primitive containers: a tampered checkpoint
# cannot execute arbitrary code at load time. The loaded dict is passed inline
# so no second copy of the weights stays referenced at module level.
model.load_state_dict(
    torch.load(MODEL_PATH, map_location="cpu", weights_only=True)
)
# Switch to inference behaviour (disables the Dropout layers in the head).
model.eval()
# ----------------------------------
# Transform
# ----------------------------------
# Standard ImageNet per-channel (R, G, B) statistics.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to each uploaded image before inference: resize to
# exactly 224x224 (aspect ratio is not preserved), convert to a float tensor
# in [0, 1], then normalise each channel with the statistics above.
# NOTE(review): must mirror the evaluation transform used during training.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)
# ----------------------------------
# API Endpoint
# ----------------------------------
# Predictions whose top-class probability is below this value are treated as
# too uncertain, and the client is asked for a clearer photo instead.
CONFIDENCE_THRESHOLD = 0.6

logger = logging.getLogger(__name__)


def _load_rgb_image(image_bytes):
    """Decode uploaded bytes into a fully loaded RGB ``PIL.Image``.

    Raises:
        OSError: the bytes are not a recognisable image
            (``PIL.UnidentifiedImageError`` is a subclass) or are truncated.
        PIL.Image.DecompressionBombError: the image exceeds Pillow's pixel limit.
    """
    with Image.open(io.BytesIO(image_bytes)) as raw:
        # convert() forces a full decode and returns a new image, so the
        # source can be closed when the with-block exits.
        return raw.convert("RGB")


def _classify(image_bytes):
    """Run the model on one uploaded image.

    Returns:
        ``(best_idx, confidence)``: the index into ``CLASS_LABELS`` of the most
        probable class and its softmax probability as a Python float.
    """
    image = _load_rgb_image(image_bytes)
    batch = transform(image).unsqueeze(0)  # prepend a batch dimension of 1
    with torch.no_grad():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)[0]
    best_idx = torch.argmax(probs).item()
    return best_idx, float(probs[best_idx])


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify the crop disease shown in an uploaded leaf image.

    Responses:
        200 ``{"model", "disease", "confidence"}`` for a confident prediction.
        200 ``{"message": ...}`` when confidence < ``CONFIDENCE_THRESHOLD``.
        400 ``{"error": ...}`` when the upload cannot be decoded as an image.
        500 ``{"error": ...}`` for any other failure (details are logged,
        not returned to the client).
    """
    try:
        image_bytes = await file.read()
        # Decoding and inference are CPU-bound; running them in the worker
        # thread pool keeps the event loop free to serve other requests.
        best_idx, confidence = await run_in_threadpool(_classify, image_bytes)
    except (OSError, Image.DecompressionBombError) as exc:
        # Bad input from the client, not a server fault.
        return JSONResponse({"error": str(exc)}, status_code=400)
    except Exception:
        logger.exception("Prediction failed for upload %r", file.filename)
        return JSONResponse({"error": "Prediction failed"}, status_code=500)

    if confidence < CONFIDENCE_THRESHOLD:
        return JSONResponse({
            "message": "Please upload a clearer leaf image"
        })

    return {
        "model": "ResNet50 Classification",
        "disease": CLASS_LABELS[best_idx],
        "confidence": round(confidence, 4),
    }
import gradio as gr
import uvicorn
import threading


def run_api():
    """Serve the FastAPI ``app`` with uvicorn on 0.0.0.0:8000 (blocks until shutdown)."""
    uvicorn.run(app, host="0.0.0.0", port=8000)


def main():
    """Start the REST API in a background thread, then serve a Gradio info page.

    The API listens on port 8000 and the Gradio page on port 7860.
    NOTE(review): on hosts that expose only one port (e.g. 7860), port 8000
    may be unreachable from outside -- confirm the deployment's port mapping.
    """
    # daemon=True so the process exits once the Gradio server stops (e.g. on
    # Ctrl+C) instead of waiting forever on the uvicorn thread. Python delivers
    # signals only to the main thread, so that thread cannot be interrupted.
    threading.Thread(target=run_api, name="uvicorn-api", daemon=True).start()

    # A bare component such as gr.Markdown is not a launchable app; it must
    # be placed inside a gr.Blocks (or Interface), which provides launch().
    with gr.Blocks() as demo:
        gr.Markdown("""
# 🌾 Crop Disease Classification API
### Endpoint:
POST `/predict`
Use Postman / frontend to send image.
""")
    # When run as a script, launch() blocks the main thread by default.
    demo.launch(server_name="0.0.0.0", server_port=7860)


# Guard so that importing this module (e.g. `uvicorn module:app`) only builds
# the FastAPI app rather than also spawning a second server and the Gradio UI.
if __name__ == "__main__":
    main()
|