# cv_first / app.py
# SoraRyuu's picture
# Update app.py
# dd2886a verified
import torch
import torch.nn as nn
from torchvision.models import resnet50
from torchvision import transforms
from PIL import Image
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
import io
# ----------------------------------
# App
# ----------------------------------
# ASGI application object; the /predict route is registered on it below.
app = FastAPI(
    title="Crop Disease Classification API",
)
# ----------------------------------
# Labels
# ----------------------------------
# Human-readable class names. Order matters: predict() indexes this list with
# the argmax of the model output, so position i must name output unit i.
CLASS_LABELS = [
    # Corn
    "Corn_Common_Rust",
    "Corn_Gray_Leaf_Spot",
    "Corn_Healthy",
    "Corn_Northern_Leaf_Blight",
    # Potato
    "Potato_Early_Blight",
    "Potato_Healthy",
    "Potato_Late_Blight",
    # Rice
    "Rice_Brown_Spot",
    "Rice_Healthy",
    "Rice_Leaf_Blast",
    "Rice_Neck_Blast",
    # Wheat
    "Wheat_Brown_Rust",
    "Wheat_Healthy",
    "Wheat_Yellow_Rust",
    # Sugarcane
    "Sugarcane_Red_Rot",
    "Sugarcane_Healthy",
    "Sugarcane_Bacterial_Blight",
]

# Size of the classifier output layer, derived from the label list.
NUM_CLASSES = len(CLASS_LABELS)
# ----------------------------------
# Model
# ----------------------------------
class ResNetPlantDisease(nn.Module):
    """ResNet-50 whose final ``fc`` layer is replaced by a small classifier head.

    The head is Dropout(0.5) -> Linear(2048, 512) -> ReLU -> Dropout(0.3)
    -> Linear(512, num_classes).

    Args:
        num_classes: Number of output units of the last linear layer.
    """

    def __init__(self, num_classes=17):
        super().__init__()

        # weights=None: build the architecture only; no pretrained weights
        # are requested from torchvision here.
        network = resnet50(weights=None)

        # The position of each layer inside the Sequential determines its
        # state-dict key (fc.1.*, fc.4.*), so the order must not change.
        head_layers = [
            nn.Dropout(0.5),
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        ]
        network.fc = nn.Sequential(*head_layers)

        self.backbone = network

    def forward(self, x):
        """Run ``x`` through the backbone (including the replaced head)."""
        scores = self.backbone(x)
        return scores
# Instantiate the network with one output per label, restore its parameters
# from the checkpoint file (tensors mapped onto the CPU), then switch the
# module to evaluation mode for serving.
model = ResNetPlantDisease(num_classes=NUM_CLASSES)
model.load_state_dict(
    torch.load(
        "plant_disease_resnet_model.pth",
        map_location=torch.device("cpu"),
    )
)
model.eval()
# ----------------------------------
# Transform
# ----------------------------------
# Per-channel statistics passed to Normalize (these are the values torchvision
# documents for its ImageNet-trained models).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every uploaded image before inference:
# resize to 224x224, convert to a tensor, then normalize each channel.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
    ]
)
# ----------------------------------
# API Endpoint
# ----------------------------------
# Predictions whose top-class probability is below this are not reported.
_MIN_CONFIDENCE = 0.6


def _top_prediction(image):
    """Return ``(class_index, probability)`` of the most likely class for a PIL image."""
    batch = transform(image).unsqueeze(0)  # add a leading batch dimension of 1
    with torch.no_grad():
        scores = model(batch)
    probs = torch.softmax(scores, dim=1)[0]
    best_idx = torch.argmax(probs).item()
    return best_idx, float(probs[best_idx])


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded leaf image and report the most likely label.

    Args:
        file: The uploaded image file (multipart form field ``file``).

    Returns:
        * HTTP 200 with ``{"model", "disease", "confidence"}`` when the top
          softmax probability is at least 0.6 (``confidence`` is rounded to
          4 decimal places).
        * HTTP 200 with ``{"message": ...}`` when the top probability is
          below 0.6.
        * HTTP 400 with ``{"error": ...}`` when the upload cannot be decoded
          as an image.
        * HTTP 500 with ``{"error": ...}`` for any other failure.

        Earlier revisions returned every error with status 200; the JSON
        body shape is unchanged, only the status code now reflects failure.
    """
    try:
        image_bytes = await file.read()

        try:
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except (OSError, ValueError, Image.DecompressionBombError) as exc:
            # Empty, truncated, oversized or non-image uploads are the
            # client's fault: answer with a 4xx instead of a 200.
            return JSONResponse({"error": str(exc)}, status_code=400)

        # NOTE(review): the forward pass runs synchronously inside this
        # coroutine, so the event loop is blocked while it executes.
        # Consider offloading it to a worker thread if concurrent requests
        # matter.
        best_idx, confidence = _top_prediction(image)

        if confidence < _MIN_CONFIDENCE:
            return JSONResponse({
                "message": "Please upload a clearer leaf image"
            })

        return {
            "model": "ResNet50 Classification",
            "disease": CLASS_LABELS[best_idx],
            "confidence": round(confidence, 4)
        }
    except Exception as exc:
        # Top-level boundary: keep the JSON error body callers already
        # parse, but signal the failure through the status code.
        return JSONResponse({"error": str(exc)}, status_code=500)
import gradio as gr
import uvicorn
import threading
def run_api():
    """Serve the FastAPI ``app`` with uvicorn on 0.0.0.0:8000."""
    uvicorn.run(app, host="0.0.0.0", port=8000)


# Run the API server on its own (non-daemon) thread so the statements that
# follow in this module can execute while it is serving.
api_thread = threading.Thread(target=run_api)
api_thread.start()
# Landing-page text. Kept flush-left so Markdown does not render the lines as
# an indented code block. The heading emoji was previously mis-encoded.
_LANDING_MARKDOWN = """
# 🌾 Crop Disease Classification API
### Endpoint:
POST `/predict`
Use Postman / frontend to send image.
"""

# A bare gr.Markdown component has no launch() method; only a Blocks (or
# Interface) app can be launched, so the component is placed inside one.
with gr.Blocks() as demo:
    gr.Markdown(_LANDING_MARKDOWN)

demo.launch(server_name="0.0.0.0", server_port=7860)