# CropGuardFastAPI / fastapi_app.py
# Jude Joseph Agustino
# Initial Commit: CropGuard disease detection app
# ca8798e
import io

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError

from models import ResNet9
# ASGI application instance; the route decorators below register the
# /predict and / endpoints on it.
app = FastAPI(title="CropGuard - Plant Disease Detection")
# Class labels indexed by model output position: predict() looks up
# CLASS_NAMES[i] for output logit i, and load_model() sizes the classifier
# head with len(CLASS_NAMES). The order must therefore match the order the
# checkpoint was trained with (these appear to be the sorted class-folder
# names of the training dataset — confirm against the training script).
#
# Every entry has the form "Crop___Condition" (exactly three underscores as
# the separator), which predict() relies on to build "Crop - Condition"
# display names. Several entries had lost underscores (e.g. "Corn(maize)healthy")
# and are restored here without changing their positions.
CLASS_NAMES = [
    'Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy',
    'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy',
    'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_',
    'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black_rot',
    'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy',
    'Orange___Haunglongbing_(Citrus_greening)', 'Peach___Bacterial_spot', 'Peach___healthy',
    'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy', 'Potato___Early_blight',
    'Potato___Late_blight', 'Potato___healthy', 'Raspberry___healthy', 'Soybean___healthy',
    'Squash___Powdery_mildew', 'Strawberry___Leaf_scorch', 'Strawberry___healthy',
    'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight', 'Tomato___Leaf_Mold',
    'Tomato___Septoria_leaf_spot', 'Tomato___Spider_mites Two-spotted_spider_mite', 'Tomato___Target_Spot',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus', 'Tomato___Tomato_mosaic_virus', 'Tomato___healthy',
]
# Classifier shared by all requests; populated by load_model().
model = None


def load_model(path="plant-disease-model-state-dict.pth"):
    """Load the ResNet9 plant-disease classifier into the module-level ``model``.

    Does nothing if a model is already loaded. The network is built and its
    weights restored into a local variable first, and only published to the
    global once every step succeeds. If reading the checkpoint or loading the
    state dict raises, ``model`` therefore stays ``None`` instead of holding
    an untrained, train-mode network.

    Args:
        path: Location of the saved state dict. A relative path is resolved
            against the current working directory, as before.

    Returns:
        The loaded model (also stored in the global ``model``).

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise (e.g. a missing
        file or mismatched keys); the exception propagates to the caller.
    """
    global model
    if model is None:
        net = ResNet9(3, len(CLASS_NAMES))
        # NOTE: torch.load unpickles the file — only point it at trusted checkpoints.
        state_dict = torch.load(path, map_location="cpu")
        net.load_state_dict(state_dict)
        net.eval()  # inference mode (e.g. fixed batch-norm statistics)
        model = net
    return model


# Load eagerly at import so the first request does not pay the loading cost
# and a broken checkpoint fails at startup.
load_model()
# Preprocessing applied to every uploaded image. It holds no per-request
# state, so it is built once here instead of on every request.
_TRANSFORM = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Number of ranked predictions returned per request (capped at the class count).
_TOP_K = 5


def _clean_label(class_name):
    """Turn a raw ``'Crop___Condition_name'`` label into ``'Crop - Condition name'``."""
    # .strip() drops the stray space left by labels ending in an underscore.
    return class_name.replace('___', ' - ').replace('_', ' ').strip()


def _predict_top_k(image_bytes):
    """Classify raw image bytes and return ``{display_label: probability}``.

    Entries are ordered from most to least likely. This is synchronous and
    CPU-bound, so ``predict`` runs it in a worker thread.

    Raises:
        PIL.UnidentifiedImageError: if ``image_bytes`` is not a readable image.
        RuntimeError: if the model could not be loaded.
    """
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    # ToTensor returns a single (C, H, W) tensor; add a batch dimension of one.
    img_tensor = _TRANSFORM(image).unsqueeze(0)

    if model is None:
        load_model()
    if model is None:
        raise RuntimeError("Model failed to load.")

    with torch.no_grad():
        outputs = model(img_tensor)
    probabilities = F.softmax(outputs[0], dim=0)
    k = min(_TOP_K, probabilities.numel())
    top_prob, top_indices = torch.topk(probabilities, k)

    return {
        _clean_label(CLASS_NAMES[int(idx)]): float(prob)
        for prob, idx in zip(top_prob, top_indices)
    }


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded plant-leaf image.

    Returns ``{"predictions": {label: probability, ...}}`` for the top-ranked
    classes. An upload that is not a readable image yields HTTP 400; any other
    failure yields HTTP 500. In both cases the body is ``{"error": message}``.
    """
    try:
        image_bytes = await file.read()
        # Inference is CPU-bound; running it directly inside this coroutine
        # would block the event loop and stall every other request meanwhile.
        results = await run_in_threadpool(_predict_top_k, image_bytes)
    except UnidentifiedImageError as e:
        # Client sent something that is not an image: a client error, not a server one.
        return JSONResponse(content={"error": str(e)}, status_code=400)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
    return JSONResponse(content={"predictions": results})
@app.get("/")
def root():
    """Landing endpoint: confirm the service is up and point callers at /predict."""
    banner = "CropGuard FastAPI is running. Use /predict to POST an image."
    return dict(message=banner)