Jude Joseph Agustino commited on
Commit
ca8798e
·
1 Parent(s): e0c3c75

Initial Commit: CropGuard disease detection app

Browse files
Files changed (8) hide show
  1. Dockerfile +30 -0
  2. README.md +22 -4
  3. app.py +178 -0
  4. converter.py +19 -0
  5. fastapi_app.py +71 -0
  6. models.py +60 -0
  7. plant-disease-model-state-dict.pth +3 -0
  8. requirements.txt +4 -0
Dockerfile ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use Python 3.13.7 slim image for better performance and smaller size
2
+ FROM python:3.13.7-slim
3
+
4
+ # Set working directory
5
+ WORKDIR /app
6
+
7
+ # Install system dependencies
8
+ RUN apt-get update && apt-get install -y \
9
+ build-essential \
10
+ && rm -rf /var/lib/apt/lists/*
11
+
12
+ # Copy requirements first to leverage Docker cache
13
+ COPY requirements.txt .
14
+
15
+ # Install Python dependencies
16
+ RUN pip install --no-cache-dir --upgrade pip && \
17
+ pip install --no-cache-dir -r requirements.txt
18
+
19
+ # Copy application files
20
+ COPY . .
21
+
22
+ # Create a non-root user for security
23
+ RUN useradd --create-home --shell /bin/bash app && chown -R app:app /app
24
+ USER app
25
+
26
+ # Expose the port that Hugging Face Spaces expects
27
+ EXPOSE 7860
28
+
29
+ # Command to run the application
30
+ CMD ["uvicorn", "fastapi_app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,10 +1,28 @@
1
  ---
2
- title: CropGuardFastAPI
3
- emoji: 📉
4
  colorFrom: green
5
- colorTo: gray
6
  sdk: docker
7
  pinned: false
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: CropGuard Plant Disease Detection
3
+ emoji: 🌱
4
  colorFrom: green
5
+ colorTo: blue
6
  sdk: docker
7
  pinned: false
8
+ license: apache-2.0
9
  ---
10
 
11
+ # CropGuard - Plant Disease Detection
12
+
13
+ This is a plant disease classification system using a ResNet-9 model trained on the PlantVillage dataset.
14
+
15
+ ## Features
16
+ - Detects diseases in 38 different plant categories
17
+ - Supports multiple plant types: Apple, Corn, Tomato, Potato, Grape, and more
18
+ - Real-time prediction with confidence scores
19
+ - User-friendly web interface
20
+
21
+ ## Usage
22
+ Simply upload an image of a plant leaf, and the model will predict the disease (if any) with confidence scores.
23
+
24
+ ## Model Details
25
+ - Architecture: ResNet-9
26
+ - Dataset: PlantVillage
27
+ - Classes: 38 plant disease categories
28
+ - Input: 256x256 RGB images
app.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torchvision.transforms as transforms
5
+ from PIL import Image
6
+ import os
7
+
8
+ # Import your model
9
+ from models import ResNet9
10
+
11
+ # Plant disease class names
12
+ CLASS_NAMES = [
13
+ 'Apple___Apple_scab',
14
+ 'Apple___Black_rot',
15
+ 'Apple___Cedar_apple_rust',
16
+ 'Apple___healthy',
17
+ 'Blueberry___healthy',
18
+ 'Cherry_(including_sour)__Powdery_mildew',
19
+ 'Cherry(including_sour)__healthy',
20
+ 'Corn(maize)__Cercospora_leaf_spot Gray_leaf_spot',
21
+ 'Corn(maize)_Common_rust',
22
+ 'Corn(maize)__Northern_Leaf_Blight',
23
+ 'Corn(maize)healthy',
24
+ 'Grape___Black_rot',
25
+ 'Grape___Esca(Black_Measles)',
26
+ 'Grape___Leaf_blight(Isariopsis_Leaf_Spot)',
27
+ 'Grape___healthy',
28
+ 'Orange___Haunglongbing(Citrus_greening)',
29
+ 'Peach___Bacterial_spot',
30
+ 'Peach___healthy',
31
+ 'Pepper,_bell___Bacterial_spot',
32
+ 'Pepper,_bell___healthy',
33
+ 'Potato___Early_blight',
34
+ 'Potato___Late_blight',
35
+ 'Potato___healthy',
36
+ 'Raspberry___healthy',
37
+ 'Soybean___healthy',
38
+ 'Squash___Powdery_mildew',
39
+ 'Strawberry___Leaf_scorch',
40
+ 'Strawberry___healthy',
41
+ 'Tomato___Bacterial_spot',
42
+ 'Tomato___Early_blight',
43
+ 'Tomato___Late_blight',
44
+ 'Tomato___Leaf_Mold',
45
+ 'Tomato___Septoria_leaf_spot',
46
+ 'Tomato___Spider_mites Two-spotted_spider_mite',
47
+ 'Tomato___Target_Spot',
48
+ 'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
49
+ 'Tomato___Tomato_mosaic_virus',
50
+ 'Tomato___healthy'
51
+ ]
52
+
53
+ # Load model
54
+ model = None
55
+
56
+ def load_model():
57
+ global model
58
+ try:
59
+ model = ResNet9(3, len(CLASS_NAMES))
60
+ state_dict = torch.load("plant-disease-model-state-dict.pth", map_location="cpu")
61
+ model.load_state_dict(state_dict)
62
+ model.eval()
63
+ print("✅ Model loaded successfully")
64
+ return True
65
+ except Exception as e:
66
+ print(f"❌ Model load failed: {e}")
67
+ return False
68
+
69
+ def predict_disease(image):
70
+ """Predict plant disease from image"""
71
+ if model is None:
72
+ if not load_model():
73
+ return {"Error": "Model not available"}
74
+
75
+ # Transform image
76
+ transform = transforms.Compose([
77
+ transforms.Resize((256, 256)),
78
+ transforms.ToTensor()
79
+ ])
80
+
81
+ try:
82
+ # Convert and transform image
83
+ if image is None:
84
+ return {"Error": "No image provided"}
85
+
86
+ img_tensor = transform(image).unsqueeze(0)
87
+
88
+ # Make prediction
89
+ with torch.no_grad():
90
+ outputs = model(img_tensor)
91
+ probabilities = F.softmax(outputs[0], dim=0)
92
+
93
+ # Get top 5 predictions
94
+ top5_prob, top5_indices = torch.topk(probabilities, 5)
95
+
96
+ # Format results for Gradio
97
+ results = {}
98
+ for i, (prob, idx) in enumerate(zip(top5_prob, top5_indices)):
99
+ class_name = CLASS_NAMES[idx.item()]
100
+ # Clean up class name for display
101
+ clean_name = class_name.replace('___', ' - ').replace('_', ' ')
102
+ results[clean_name] = float(prob)
103
+
104
+ return results
105
+
106
+ except Exception as e:
107
+ return {"Error": f"Prediction failed: {str(e)}"}
108
+
109
+ def format_class_info():
110
+ """Format class information for display"""
111
+ plants = {}
112
+ for class_name in CLASS_NAMES:
113
+ if '___' in class_name:
114
+ plant, condition = class_name.split('___', 1)
115
+ if plant not in plants:
116
+ plants[plant] = []
117
+ plants[plant].append(condition.replace('_', ' '))
118
+
119
+ info = "## Supported Plants and Conditions:\n\n"
120
+ for plant, conditions in sorted(plants.items()):
121
+ info += f"**{plant.replace('_', ' ')}**: {', '.join(conditions)}\n\n"
122
+
123
+ return info
124
+
125
+ # Load model on startup
126
+ load_model()
127
+
128
+ # Create Gradio interface
129
+ with gr.Blocks(title="🌱 CropGuard - Plant Disease Detection", theme=gr.themes.Soft()) as demo:
130
+ gr.Markdown("""
131
+ # 🌱 CropGuard - Plant Disease Detection
132
+
133
+ Upload an image of a plant leaf to detect diseases using our ResNet-9 model trained on the PlantVillage dataset.
134
+
135
+ **Supported formats**: JPG, PNG, JPEG
136
+ """)
137
+
138
+ with gr.Row():
139
+ with gr.Column():
140
+ image_input = gr.Image(
141
+ type="pil",
142
+ label="Upload Plant Image",
143
+ height=400
144
+ )
145
+ predict_btn = gr.Button("🔍 Analyze Disease", variant="primary", size="lg")
146
+
147
+ with gr.Column():
148
+ output = gr.Label(
149
+ label="Disease Prediction Results",
150
+ num_top_classes=5,
151
+ show_label=True
152
+ )
153
+
154
+ # Example images (you can add these later)
155
+ gr.Markdown("### 📋 Examples")
156
+ gr.Markdown("Try uploading images of plant leaves to see the disease detection in action!")
157
+
158
+ # Info section
159
+ with gr.Accordion("ℹ️ Supported Plants & Diseases", open=False):
160
+ gr.Markdown(format_class_info())
161
+
162
+ # Event handlers
163
+ predict_btn.click(
164
+ fn=predict_disease,
165
+ inputs=image_input,
166
+ outputs=output
167
+ )
168
+
169
+ # Also predict on image upload
170
+ image_input.change(
171
+ fn=predict_disease,
172
+ inputs=image_input,
173
+ outputs=output
174
+ )
175
+
176
+ # Launch the app
177
+ if __name__ == "__main__":
178
+ demo.launch()
converter.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from models import ResNet9
3
+
4
+ # Load the existing model
5
+ try:
6
+ old_model = torch.load("plant-disease-model.pth", map_location="cpu", weights_only=False)
7
+
8
+ # Save only the state dict
9
+ torch.save(old_model.state_dict(), "plant-disease-model-state-dict.pth")
10
+ print("✅ Model converted to state dict format")
11
+
12
+ # Test loading the new format
13
+ new_model = ResNet9(3, 38) # 38 classes based on your CLASS_NAMES
14
+ new_model.load_state_dict(torch.load("plant-disease-model-state-dict.pth"))
15
+ new_model.eval()
16
+ print("✅ Converted model loads successfully")
17
+
18
+ except Exception as e:
19
+ print(f"❌ Conversion failed: {e}")
fastapi_app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from fastapi.responses import JSONResponse
3
+ from PIL import Image
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import torchvision.transforms as transforms
7
+ import io
8
+ from models import ResNet9
9
+
10
+ app = FastAPI(title="CropGuard - Plant Disease Detection")
11
+
12
+ CLASS_NAMES = [
13
+ 'Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy',
14
+ 'Blueberry___healthy', 'Cherry_(including_sour)__Powdery_mildew', 'Cherry(including_sour)__healthy',
15
+ 'Corn(maize)__Cercospora_leaf_spot Gray_leaf_spot', 'Corn(maize)_Common_rust',
16
+ 'Corn(maize)__Northern_Leaf_Blight', 'Corn(maize)healthy', 'Grape___Black_rot',
17
+ 'Grape___Esca(Black_Measles)', 'Grape___Leaf_blight(Isariopsis_Leaf_Spot)', 'Grape___healthy',
18
+ 'Orange___Haunglongbing(Citrus_greening)', 'Peach___Bacterial_spot', 'Peach___healthy',
19
+ 'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy', 'Potato___Early_blight',
20
+ 'Potato___Late_blight', 'Potato___healthy', 'Raspberry___healthy', 'Soybean___healthy',
21
+ 'Squash___Powdery_mildew', 'Strawberry___Leaf_scorch', 'Strawberry___healthy',
22
+ 'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight', 'Tomato___Leaf_Mold',
23
+ 'Tomato___Septoria_leaf_spot', 'Tomato___Spider_mites Two-spotted_spider_mite', 'Tomato___Target_Spot',
24
+ 'Tomato___Tomato_Yellow_Leaf_Curl_Virus', 'Tomato___Tomato_mosaic_virus', 'Tomato___healthy'
25
+ ]
26
+
27
+ model = None
28
+
29
+ def load_model():
30
+ global model
31
+ if model is None:
32
+ model = ResNet9(3, len(CLASS_NAMES))
33
+ state_dict = torch.load("plant-disease-model-state-dict.pth", map_location="cpu")
34
+ model.load_state_dict(state_dict)
35
+ model.eval()
36
+
37
+ load_model()
38
+
39
+ @app.post("/predict")
40
+ async def predict(file: UploadFile = File(...)):
41
+ try:
42
+ image_bytes = await file.read()
43
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
44
+ transform = transforms.Compose([
45
+ transforms.Resize((256, 256)),
46
+ transforms.ToTensor()
47
+ ])
48
+ img_tensor = transform(image)
49
+ if isinstance(img_tensor, torch.Tensor) and img_tensor.ndimension() == 3:
50
+ img_tensor = img_tensor.unsqueeze(0)
51
+ global model
52
+ if model is None:
53
+ load_model()
54
+ if model is None:
55
+ raise RuntimeError("Model failed to load.")
56
+ with torch.no_grad():
57
+ outputs = model(img_tensor)
58
+ probabilities = F.softmax(outputs[0], dim=0)
59
+ top5_prob, top5_indices = torch.topk(probabilities, 5)
60
+ results = {}
61
+ for prob, idx in zip(top5_prob, top5_indices):
62
+ class_name = CLASS_NAMES[int(idx.item())]
63
+ clean_name = class_name.replace('___', ' - ').replace('_', ' ')
64
+ results[clean_name] = float(prob)
65
+ return JSONResponse(content={"predictions": results})
66
+ except Exception as e:
67
+ return JSONResponse(content={"error": str(e)}, status_code=500)
68
+
69
+ @app.get("/")
70
+ def root():
71
+ return {"message": "CropGuard FastAPI is running. Use /predict to POST an image."}
models.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ def ConvBlock(in_channels, out_channels, pool=False):
6
+ layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
7
+ nn.BatchNorm2d(out_channels),
8
+ nn.ReLU(inplace=True)]
9
+ if pool:
10
+ layers.append(nn.MaxPool2d(4))
11
+ return nn.Sequential(*layers)
12
+
13
+ class ImageClassificationBase(nn.Module):
14
+ def training_step(self, batch):
15
+ images, labels = batch
16
+ out = self(images)
17
+ loss = F.cross_entropy(out, labels)
18
+ return loss
19
+
20
+ def validation_step(self, batch):
21
+ images, labels = batch
22
+ out = self(images)
23
+ loss = F.cross_entropy(out, labels)
24
+ acc = (out.argmax(dim=1) == labels).float().mean()
25
+ return {"val_loss": loss.detach(), "val_accuracy": acc}
26
+
27
+ def validation_epoch_end(self, outputs):
28
+ batch_losses = [x["val_loss"] for x in outputs]
29
+ batch_accuracy = [x["val_accuracy"] for x in outputs]
30
+ epoch_loss = torch.stack(batch_losses).mean()
31
+ epoch_accuracy = torch.stack(batch_accuracy).mean()
32
+ return {"val_loss": epoch_loss, "val_accuracy": epoch_accuracy}
33
+
34
+ def epoch_end(self, epoch, result):
35
+ print(f"Epoch [{epoch}], train_loss: {result['train_loss']:.4f}, val_loss: {result['val_loss']:.4f}, val_acc: {result['val_accuracy']:.4f}")
36
+
37
+
38
+ class ResNet9(ImageClassificationBase):
39
+ def __init__(self, in_channels, num_classes):
40
+ super().__init__()
41
+ self.conv1 = ConvBlock(in_channels, 64)
42
+ self.conv2 = ConvBlock(64, 128, pool=True)
43
+ self.res1 = nn.Sequential(ConvBlock(128, 128), ConvBlock(128, 128))
44
+ self.conv3 = ConvBlock(128, 256, pool=True)
45
+ self.conv4 = ConvBlock(256, 512, pool=True)
46
+ self.res2 = nn.Sequential(ConvBlock(512, 512), ConvBlock(512, 512))
47
+ self.classifier = nn.Sequential(
48
+ nn.MaxPool2d(4),
49
+ nn.Flatten(),
50
+ nn.Linear(512, num_classes)
51
+ )
52
+ def forward(self, xb):
53
+ out = self.conv1(xb)
54
+ out = self.conv2(out)
55
+ out = self.res1(out) + out
56
+ out = self.conv3(out)
57
+ out = self.conv4(out)
58
+ out = self.res2(out) + out
59
+ out = self.classifier(out)
60
+ return out
plant-disease-model-state-dict.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5e9497ddbb41600362d95c00ad1fe96b2b3ed637ac8a27cc2803e8acff1144d
3
+ size 26397109
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ Pillow
4
+ fastapi[standard]