ItzRoBeerT commited on
Commit
6bc48aa
·
1 Parent(s): 03f76bc

Added leaves classifier

Browse files
requirements.txt CHANGED
@@ -1 +1,2 @@
1
- gradio==6
 
 
1
+ gradio==6
2
+ transformers
src/pictures/185161-004-EAF28842.jpg ADDED
src/ui/tabs/leaf_tab.py CHANGED
@@ -5,7 +5,53 @@ UI component for analyzing individual leaves.
5
  """
6
 
7
  import gradio as gr
8
- from ...utils.leaf_analysis import analyze_leaf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
 
11
  def create_leaf_tab():
 
5
  """
6
 
7
  import gradio as gr
8
+ from src.utils.leaf_classifier import predict as classify_image
9
+
10
+
11
+ def analyze_leaf(image):
12
+ """
13
+ Analyze a leaf image to detect diseases.
14
+
15
+ Args:
16
+ image: PIL.Image from gr.Image component
17
+
18
+ Returns:
19
+ str: Result formatted as Markdown
20
+ """
21
+ if image is None:
22
+ return "⚠️ Please upload an image of a leaf."
23
+
24
+ # Call classifier
25
+ result = classify_image(image)
26
+
27
+ # Handle error
28
+ if not result["success"]:
29
+ return f"❌ Error: {result['error']}"
30
+
31
+ # Format result as Markdown
32
+ emoji = "✅" if result["is_healthy"] else "⚠️"
33
+ status = "🌿 Healthy Plant" if result["is_healthy"] else "🦠 Disease Detected"
34
+
35
+ output = f"""
36
+ ## 🔬 Analysis Result
37
+
38
+ ### Main Diagnosis
39
+ - **Prediction:** {emoji} {result["prediction"]}
40
+ - **Confidence:** {result["confidence"]}%
41
+ - **Status:** {status}
42
+
43
+ ### Details
44
+ - **Plant:** {result["plant"]}
45
+ - **Condition:** {result["disease"]}
46
+
47
+ ### Other Possibilities
48
+ """
49
+
50
+ # Add top-k alternatives (skip first one, it's the main prediction)
51
+ for i, alt in enumerate(result["top_k"][1:], start=2):
52
+ output += f"{i}. {alt['plant']} - {alt['disease']} ({alt['confidence']}%)\n"
53
+
54
+ return output
55
 
56
 
57
  def create_leaf_tab():
src/utils/__init__.py DELETED
@@ -1,11 +0,0 @@
1
- """
2
- Utils Package
3
- =============
4
- Utility functions for analysis and processing.
5
- """
6
-
7
- from .leaf_analysis import analyze_leaf
8
- from .farm_analysis import analyze_farm
9
- from .weather_alerts import get_weather_alerts
10
-
11
- __all__ = ["analyze_leaf", "analyze_farm", "get_weather_alerts"]
 
 
 
 
 
 
 
 
 
 
 
 
src/utils/leaf_analysis.py DELETED
@@ -1,19 +0,0 @@
1
- """
2
- Leaf Analysis Module
3
- ====================
4
- Analyze individual leaves to detect diseases.
5
- """
6
-
7
-
8
- def analyze_leaf(image):
9
- """Analyze leaf to detect diseases."""
10
- if image is None:
11
- return "⚠️ Please upload an image."
12
-
13
- # TODO: Implement real classification
14
- # - Load AI model (ResNet, EfficientNet, etc.)
15
- # - Preprocess image
16
- # - Run inference
17
- # - Return disease classification with confidence
18
-
19
- return "✅ Image received. Analysis pending implementation."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/utils/leaf_classifier.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Plant Disease Classifier
3
+ =========================
4
+
5
+ Classifies plant leaf diseases using MobileNetV2.
6
+
7
+ Model: linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification
8
+ - 38 classes (26 diseases + 12 healthy plants)
9
+ - 99.47% accuracy on PlantVillage dataset
10
+ - Input: 224x224 RGB image
11
+
12
+ Usage:
13
+ from src.classifier import predict
14
+
15
+ result = predict(pil_image)
16
+ print(result["prediction"]) # "Tomato - Late Blight"
17
+ """
18
+
19
+ import torch
20
+ from PIL import Image
21
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
22
+
23
+
24
+ # ============================================================
25
+ # CONFIGURATION
26
+ # ============================================================
27
+
28
+ MODEL_NAME = "linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification"
29
+
30
+
31
+ # ============================================================
32
+ # MODULE STATE
33
+ # ============================================================
34
+
35
+ _model = None
36
+ _processor = None
37
+ _device = None
38
+
39
+
40
+ # ============================================================
41
+ # PRIVATE FUNCTIONS
42
+ # ============================================================
43
+
44
+ def _load_model():
45
+ """
46
+ Load model and processor from HuggingFace.
47
+
48
+ Executes only ONCE (lazy loading).
49
+ Subsequent calls return cached objects.
50
+
51
+ Returns:
52
+ tuple: (model, processor, device)
53
+ """
54
+ global _model, _processor, _device
55
+
56
+ # Return cached if already loaded
57
+ if _model is not None:
58
+ return _model, _processor, _device
59
+
60
+ print("🌱 Loading classification model...")
61
+
62
+ # Determine device (GPU or CPU)
63
+ _device = "cuda" if torch.cuda.is_available() else "cpu"
64
+ print(f" Device: {_device}")
65
+
66
+ # Load processor (prepares images for model)
67
+ _processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
68
+
69
+ # Load model
70
+ _model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
71
+ _model.to(_device)
72
+ _model.eval() # Set to evaluation mode
73
+
74
+ print(f"✅ Model loaded: {len(_model.config.id2label)} classes")
75
+
76
+ return _model, _processor, _device
77
+
78
+
79
+ def _parse_label(raw_label: str) -> tuple:
80
+ """
81
+ Parse raw model label into (plant, disease).
82
+
83
+ Args:
84
+ raw_label: Model label, e.g. "Tomato___Late_blight"
85
+
86
+ Returns:
87
+ tuple: (plant, disease)
88
+ e.g. ("Tomato", "Late blight")
89
+ """
90
+ try:
91
+ # Split by triple underscore
92
+ parts = raw_label.split("___")
93
+ plant = parts[0].replace("_", " ").replace("(", "").replace(")", "").strip()
94
+
95
+ if len(parts) > 1:
96
+ disease = parts[1].replace("_", " ").strip()
97
+ # Capitalize properly
98
+ disease = disease.title() if disease.lower() != "healthy" else "Healthy"
99
+ else:
100
+ disease = "Unknown"
101
+
102
+ return (plant, disease)
103
+ except:
104
+ return (raw_label, "Unknown")
105
+
106
+
107
+ # ============================================================
108
+ # PUBLIC FUNCTION
109
+ # ============================================================
110
+
111
+ def predict(image: Image.Image, top_k: int = 3) -> dict:
112
+ """
113
+ Predict disease in a plant leaf image.
114
+
115
+ Args:
116
+ image: PIL Image (PIL.Image.Image)
117
+ top_k: Number of alternative predictions to return
118
+
119
+ Returns:
120
+ dict with result:
121
+ {
122
+ "success": True,
123
+ "prediction": "Tomato - Late Blight",
124
+ "confidence": 95.23,
125
+ "is_healthy": False,
126
+ "plant": "Tomato",
127
+ "disease": "Late Blight",
128
+ "raw_label": "Tomato___Late_blight",
129
+ "top_k": [
130
+ {"plant": "Tomato", "disease": "Late Blight", "confidence": 95.23},
131
+ ...
132
+ ]
133
+ }
134
+
135
+ On error:
136
+ {
137
+ "success": False,
138
+ "error": "Error description"
139
+ }
140
+ """
141
+
142
+ # Validate input
143
+ if image is None:
144
+ return {
145
+ "success": False,
146
+ "error": "No image provided"
147
+ }
148
+
149
+ if not isinstance(image, Image.Image):
150
+ return {
151
+ "success": False,
152
+ "error": f"Invalid image type: {type(image)}. Expected PIL.Image"
153
+ }
154
+
155
+ try:
156
+ # Load model (only first time)
157
+ model, processor, device = _load_model()
158
+
159
+ # Preprocess image
160
+ image = image.convert("RGB")
161
+ inputs = processor(images=image, return_tensors="pt")
162
+ inputs = {k: v.to(device) for k, v in inputs.items()}
163
+
164
+ # Inference
165
+ with torch.no_grad():
166
+ outputs = model(**inputs)
167
+
168
+ # Process results
169
+ logits = outputs.logits
170
+ probs = torch.nn.functional.softmax(logits, dim=-1)
171
+
172
+ # Get top prediction
173
+ top_prob, top_idx = torch.max(probs, dim=-1)
174
+ raw_label = model.config.id2label[top_idx.item()]
175
+ confidence = round(top_prob.item() * 100, 2)
176
+
177
+ # Parse label
178
+ plant, disease = _parse_label(raw_label)
179
+ is_healthy = "healthy" in raw_label.lower()
180
+
181
+ # Get top-k predictions
182
+ top_k_probs, top_k_indices = torch.topk(probs, min(top_k, probs.shape[-1]))
183
+
184
+ top_k_results = []
185
+ for idx, prob in zip(top_k_indices[0], top_k_probs[0]):
186
+ label = model.config.id2label[idx.item()]
187
+ p, d = _parse_label(label)
188
+ top_k_results.append({
189
+ "plant": p,
190
+ "disease": d,
191
+ "confidence": round(prob.item() * 100, 2),
192
+ "raw_label": label
193
+ })
194
+
195
+ # Return structured result
196
+ return {
197
+ "success": True,
198
+ "prediction": f"{plant} - {disease}",
199
+ "confidence": confidence,
200
+ "is_healthy": is_healthy,
201
+ "plant": plant,
202
+ "disease": disease,
203
+ "raw_label": raw_label,
204
+ "top_k": top_k_results
205
+ }
206
+
207
+ except Exception as e:
208
+ return {
209
+ "success": False,
210
+ "error": str(e)
211
+ }
212
+
213
+
214
+
215
+ print("\n" + "="*50)
216
+ print("🧪 CLASSIFIER TEST")
217
+ print("="*50)
218
+
219
+ model, processor, device = _load_model()
220
+ print(f"\n📊 Available classes: {len(model.config.id2label)}")
221
+ print(f"🖥️ Device: {device}")
222
+
223
+ print("\n📋 Sample classes:")
224
+ for i, (idx, label) in enumerate(list(model.config.id2label.items())[:5]):
225
+ plant, disease = _parse_label(label)
226
+ print(f" {idx}: {plant} - {disease}")
227
+
228
+ print("\n✅ Classifier ready")
229
+ print("="*50)