Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile | |
| from fastapi.responses import JSONResponse | |
| from PIL import Image | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision.transforms as transforms | |
| import io | |
| from models import ResNet9 | |
# FastAPI application instance for the CropGuard service.
app = FastAPI(
    title="CropGuard - Plant Disease Detection",
)
# Labels of the classifier in output-index order: predict() looks up
# CLASS_NAMES[i] for the i-th model output, and load_model() sizes the network
# with len(CLASS_NAMES). Do not reorder.
#
# Every entry follows "<Crop>___<Condition>"; predict() turns the triple
# underscore into " - " and the remaining underscores into spaces.
CLASS_NAMES = [
    'Apple___Apple_scab',
    'Apple___Black_rot',
    'Apple___Cedar_apple_rust',
    'Apple___healthy',
    'Blueberry___healthy',
    'Cherry_(including_sour)___Powdery_mildew',
    'Cherry_(including_sour)___healthy',
    'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
    'Corn_(maize)___Common_rust',
    'Corn_(maize)___Northern_Leaf_Blight',
    'Corn_(maize)___healthy',
    'Grape___Black_rot',
    'Grape___Esca_(Black_Measles)',
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
    'Grape___healthy',
    'Orange___Haunglongbing_(Citrus_greening)',
    'Peach___Bacterial_spot',
    'Peach___healthy',
    'Pepper,_bell___Bacterial_spot',
    'Pepper,_bell___healthy',
    'Potato___Early_blight',
    'Potato___Late_blight',
    'Potato___healthy',
    'Raspberry___healthy',
    'Soybean___healthy',
    'Squash___Powdery_mildew',
    'Strawberry___Leaf_scorch',
    'Strawberry___healthy',
    'Tomato___Bacterial_spot',
    'Tomato___Early_blight',
    'Tomato___Late_blight',
    'Tomato___Leaf_Mold',
    'Tomato___Septoria_leaf_spot',
    'Tomato___Spider_mites Two-spotted_spider_mite',
    'Tomato___Target_Spot',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
    'Tomato___Tomato_mosaic_virus',
    'Tomato___healthy',
]
# Process-wide cache for the classifier; stays None until load_model() succeeds.
model = None


def load_model() -> None:
    """Load the ResNet9 classifier into the module-level ``model`` once.

    Builds ``ResNet9(3, len(CLASS_NAMES))``, restores its weights from
    ``plant-disease-model-state-dict.pth`` (mapped onto the CPU) and switches
    the network to eval mode. Calls made after a successful load do nothing.

    Raises:
        Any exception raised by ``torch.load`` or ``load_state_dict``. In that
        case ``model`` is left as ``None`` so that a later call retries
        instead of serving a network without its trained weights.
    """
    global model
    if model is not None:
        return

    # Build into a local first. Publishing the instance before the weights are
    # restored would leave an untrained network in ``model`` if loading
    # fails, and every later ``model is None`` check would then pass.
    net = ResNet9(3, len(CLASS_NAMES))
    # NOTE(review): torch.load unpickles the file; only ship trusted checkpoints.
    state_dict = torch.load("plant-disease-model-state-dict.pth", map_location="cpu")
    net.load_state_dict(state_dict)
    net.eval()
    model = net


# Load once at import time; predict() retries lazily if ``model`` is still None.
load_model()
# Preprocessing applied to every upload; built once instead of per request.
_PREPROCESS = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded leaf image and return the most likely classes.

    Args:
        file: The uploaded image file (multipart form field ``file``).

    Returns:
        A JSON response ``{"predictions": {display_name: probability}}``
        holding up to five classes in descending probability order. Display
        names are the ``CLASS_NAMES`` entries with ``___`` replaced by
        ``" - "`` and remaining underscores replaced by spaces.
        If the upload cannot be decoded as an image the response is
        ``{"error": ...}`` with status 400; any other failure yields
        ``{"error": ...}`` with status 500.
    """
    try:
        image_bytes = await file.read()

        # A file Pillow cannot read is the client's fault, not a server error.
        # Pillow's UnidentifiedImageError is a subclass of OSError.
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except OSError as exc:
            return JSONResponse(content={"error": str(exc)}, status_code=400)

        # ToTensor yields a 3-D tensor; add the batch dimension the model call expects.
        batch = _PREPROCESS(image).unsqueeze(0)

        # Lazy fallback in case the import-time load did not populate the cache.
        if model is None:
            load_model()
        if model is None:
            raise RuntimeError("Model failed to load.")

        # NOTE(review): inference runs synchronously inside this coroutine.
        with torch.no_grad():
            outputs = model(batch)
            probabilities = F.softmax(outputs[0], dim=0)
            # Never request more entries than there are classes.
            k = min(5, probabilities.numel())
            top_prob, top_idx = torch.topk(probabilities, k)

        results = {
            CLASS_NAMES[idx].replace('___', ' - ').replace('_', ' '): prob
            for prob, idx in zip(top_prob.tolist(), top_idx.tolist())
        }
        return JSONResponse(content={"predictions": results})
    except Exception as e:
        # Top-level boundary: report unexpected failures as a JSON 500.
        return JSONResponse(content={"error": str(e)}, status_code=500)
@app.get("/")
def root():
    """Return a short status message pointing clients at the /predict endpoint."""
    return {"message": "CropGuard FastAPI is running. Use /predict to POST an image."}