File size: 3,061 Bytes
2c0eb09
 
163c547
9a7fc09
d1f56a5
9a81fea
 
 
d1f56a5
9a81fea
 
 
 
0c6c8ff
9a81fea
 
 
2c0eb09
 
 
 
 
 
 
 
 
 
9a81fea
 
 
381e2e8
9a81fea
78539e2
9a81fea
9a7fc09
381e2e8
9a81fea
381e2e8
 
 
 
2c0eb09
 
9a7fc09
2c0eb09
9a81fea
 
2c0eb09
 
9a81fea
 
 
2c0eb09
381e2e8
2c0eb09
 
 
 
 
 
 
9a81fea
 
 
 
 
 
 
 
 
2c0eb09
9a81fea
 
 
2c0eb09
9a81fea
 
42d431e
9a81fea
 
 
 
d1f56a5
9a81fea
 
 
 
 
2c0eb09
9a81fea
 
dd2886a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import asyncio
import io
import logging

import torch
import torch.nn as nn
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet50

# ----------------------------------
# App
# ----------------------------------
# The FastAPI application object; the HTTP routes in this module are
# registered on it through decorators.
app = FastAPI(
    title="Crop Disease Classification API",
)

# ----------------------------------
# Labels
# ----------------------------------
# Class names indexed by model output position: the /predict handler maps the
# argmax over the model's softmax output straight into this list, so the
# ORDER of entries is significant. It presumably mirrors the label order used
# when the checkpoint was trained -- confirm before reordering or editing.
CLASS_LABELS = [
    # Corn
    'Corn_Common_Rust',
    'Corn_Gray_Leaf_Spot',
    'Corn_Healthy',
    'Corn_Northern_Leaf_Blight',
    # Potato
    'Potato_Early_Blight',
    'Potato_Healthy',
    'Potato_Late_Blight',
    # Rice
    'Rice_Brown_Spot',
    'Rice_Healthy',
    'Rice_Leaf_Blast',
    'Rice_Neck_Blast',
    # Wheat
    'Wheat_Brown_Rust',
    'Wheat_Healthy',
    'Wheat_Yellow_Rust',
    # Sugarcane (intentionally not alphabetical -- order must be preserved)
    'Sugarcane_Red_Rot',
    'Sugarcane_Healthy',
    'Sugarcane_Bacterial_Blight',
]

# Width of the classifier's final layer; passed to ResNetPlantDisease.
NUM_CLASSES = len(CLASS_LABELS)

# ----------------------------------
# Model
# ----------------------------------
class ResNetPlantDisease(nn.Module):
    """ResNet-50 backbone with a small MLP classification head.

    The backbone is created with ``weights=None`` (no pretrained weights), so
    every parameter is randomly initialized until a state dict is loaded.

    Args:
        num_classes: number of logits produced by the final linear layer.
    """

    def __init__(self, num_classes=17):
        super().__init__()
        backbone = resnet50(weights=None)

        # Replacement for the backbone's stock ``fc`` layer. The layers stay in
        # an unnamed nn.Sequential, in this exact order, so checkpoint keys
        # remain ``backbone.fc.1.*`` / ``backbone.fc.4.*``; naming or
        # reordering them would break loading existing state dicts.
        head_layers = [
            nn.Dropout(0.5),
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        ]
        backbone.fc = nn.Sequential(*head_layers)

        self.backbone = backbone

    def forward(self, x):
        """Return the raw (pre-softmax) class logits for the batch ``x``."""
        return self.backbone(x)

# Build the network once at import time and load the trained weights; every
# request handled by this module reuses this single instance. The checkpoint
# path is relative, so it is resolved against the process's working directory.
model = ResNetPlantDisease(NUM_CLASSES)
# weights_only=True restricts unpickling to tensors and plain containers --
# all a state dict needs -- instead of allowing arbitrary pickled objects to
# be reconstructed. It is already the default from PyTorch 2.6 onward
# (the keyword itself requires PyTorch >= 1.13).
model.load_state_dict(
    torch.load(
        "plant_disease_resnet_model.pth",
        map_location="cpu",  # load onto CPU regardless of where it was saved
        weights_only=True,
    )
)
# Switch train/eval-dependent layers (e.g. the Dropout in the head) to
# inference behaviour.
model.eval()

# ----------------------------------
# Transform
# ----------------------------------
# Per-channel normalization statistics (the standard ImageNet values). They
# must match whatever the checkpoint was trained with -- presumably the same
# values; confirm against the training pipeline before changing.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to each uploaded image before inference: resize to
# 224x224, convert the PIL image to a tensor, then normalize each channel.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# ----------------------------------
# API Endpoint
# ----------------------------------
# Minimum top-1 softmax probability required to report a diagnosis; below
# this the client is asked for a clearer image instead.
CONFIDENCE_THRESHOLD = 0.6

logger = logging.getLogger(__name__)


def _decode_image(image_bytes):
    """Decode raw upload bytes into an RGB PIL image.

    Raises ``OSError`` (PIL's ``UnidentifiedImageError`` is a subclass) for
    unreadable or truncated data, and ``Image.DecompressionBombError`` for
    images PIL refuses as too large.
    """
    return Image.open(io.BytesIO(image_bytes)).convert("RGB")


def _top_prediction(image):
    """Run the model on one RGB image and return ``(best_idx, confidence)``.

    ``best_idx`` indexes ``CLASS_LABELS``; ``confidence`` is that class's
    softmax probability as a Python float.
    """
    batch = transform(image).unsqueeze(0)  # add a leading batch dim of size 1
    # torch.no_grad() is thread-local, so it is entered here, inside the
    # worker thread that actually executes the forward pass.
    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)[0]
    best_idx = torch.argmax(probs).item()
    return best_idx, float(probs[best_idx])


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded leaf image.

    Responses:
      * 200 ``{"model", "disease", "confidence"}`` when the top class
        probability is at least ``CONFIDENCE_THRESHOLD``;
      * 200 ``{"message": ...}`` when it is below the threshold;
      * 400 ``{"error": ...}`` when the upload cannot be decoded as an image;
      * 500 ``{"error": ...}`` for any other failure (details are logged
        server-side, not returned to the client).

    Decoding and inference are CPU-bound, so they run in worker threads via
    ``asyncio.to_thread`` rather than blocking the event loop for every
    other in-flight request.
    """
    image_bytes = await file.read()

    try:
        image = await asyncio.to_thread(_decode_image, image_bytes)
    except (OSError, Image.DecompressionBombError) as exc:
        # Bad input from the client (empty, non-image, truncated, oversized).
        return JSONResponse({"error": f"Invalid image: {exc}"}, status_code=400)

    try:
        best_idx, confidence = await asyncio.to_thread(_top_prediction, image)
    except Exception:
        # Request boundary: keep the traceback in the server log and return a
        # generic error instead of leaking internal exception text with a 200.
        logger.exception("Inference failed for upload %r", file.filename)
        return JSONResponse({"error": "Internal server error"}, status_code=500)

    if confidence < CONFIDENCE_THRESHOLD:
        return JSONResponse({
            "message": "Please upload a clearer leaf image"
        })

    return {
        "model": "ResNet50 Classification",
        "disease": CLASS_LABELS[best_idx],
        "confidence": round(confidence, 4),
    }

import gradio as gr
import uvicorn
import threading


def run_api():
    """Serve the FastAPI ``app`` with uvicorn on 0.0.0.0:8000 (blocking call)."""
    uvicorn.run(app, host="0.0.0.0", port=8000)


# The REST API runs in a background thread while the Gradio page below owns
# the main thread. daemon=True so this thread does not keep the interpreter
# alive after the main thread finishes; as a non-daemon thread it would leave
# the process hanging once the Gradio server stops or fails to start.
threading.Thread(target=run_api, name="uvicorn-api", daemon=True).start()

# Static landing page describing the API, built in a gr.Blocks container and
# launched through Blocks.launch(). NOTE(review): the API listens on port 8000
# and this page on 7860; if the host only exposes 7860, /predict will not be
# reachable from outside -- confirm the deployment setup.
with gr.Blocks() as demo:
    gr.Markdown("""
# 🌾 Crop Disease Classification API

### Endpoint:
POST `/predict`

Use Postman / frontend to send image.
""")

demo.launch(server_name="0.0.0.0", server_port=7860)