AdeshJain commited on
Commit
dbad62f
·
verified ·
1 Parent(s): ab25090

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +237 -80
app.py CHANGED
@@ -1,13 +1,15 @@
1
- from fastapi import FastAPI, File, UploadFile, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from huggingface_hub import hf_hub_download
4
  from tensorflow import keras
5
  import numpy as np
6
  from PIL import Image
7
  import io
8
- from typing import Dict
9
  import uvicorn
10
  import os
 
 
11
 
12
  app = FastAPI(title="CNN Image Prediction API")
13
 
@@ -20,13 +22,20 @@ app.add_middleware(
20
  allow_headers=["*"],
21
  )
22
 
 
 
 
 
 
 
 
 
23
  # Global variable for model
24
  model = None
25
- IMG_SIZE = (224, 224) # Adjust based on your model's input size
26
 
27
  # Class labels
28
  CLASS_LABELS = [
29
-
30
  "Apple___Apple_scab", "Apple___Black_rot", "Apple___Cedar_apple_rust", "Apple___Healthy",
31
  "Blueberry___Healthy", "Cherry_(including_sour)___Powdery_mildew", "Cherry_(including_sour)___Healthy",
32
  "Corn_(maize)___Cercospora_leaf_spot_Gray_leaf_spot", "Corn_(maize)__Common_rust", "Corn_(maize)___Northern_Leaf_Blight",
@@ -34,35 +43,27 @@ CLASS_LABELS = [
34
  "Grape___Healthy", "Orange__Haunglongbing(Citrus_greening)", "Peach___Bacterial_spot", "Peach___Healthy",
35
  "Pepper,bell__Bacterial_spot", "Pepper,bell__Healthy", "Potato___Early_blight", "Potato___Late_blight", "Potato___Healthy",
36
  "Raspberry___Healthy", "Soybean___Healthy", "Squash___Powdery_mildew", "Strawberry___Leaf_scorch", "Strawberry___Healthy",
37
- 'cotton : bacterial_blight',
38
- ' cotton : curl_virus',
39
- ' fussarium_wilt',"Tomato___Leaf_Mold", "Tomato___Septoria_leaf_spot",
40
- ' fussarium_wilt',"Tomato___Leaf_Mold", "Tomato___Septoria_leaf_spot",
41
  "Tomato___Spider_mites_Two-spotted_spider_mite", "Tomato___Target_Spot", "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
42
  "Tomato___Tomato_mosaic_virus", "Tomato___Healthy"
43
  ]
44
 
45
-
46
-
47
  @app.on_event("startup")
48
  async def load_model_keras():
49
  """Load the Keras model from Hugging Face on startup"""
50
  global model
51
  try:
52
- # Replace with your Hugging Face repo details
53
- repo_id = "AdeshJain/plant-detection" # e.g., "john/plant-disease-classifier"
54
  filename = "plant_disease_model.keras"
55
 
56
- # Download model from Hugging Face Hub
57
  model_path = hf_hub_download(
58
  repo_id=repo_id,
59
  filename=filename,
60
  cache_dir="./model_cache"
61
  )
62
 
63
- # Load the model
64
  model = keras.models.load_model(model_path)
65
-
66
  print(f"Model loaded successfully from {repo_id}!")
67
  except Exception as e:
68
  print(f"Error loading model: {e}")
@@ -70,24 +71,112 @@ async def load_model_keras():
70
 
71
  def preprocess_image(image: Image.Image) -> np.ndarray:
72
  """Preprocess the image for model prediction"""
73
- # Convert to RGB if needed
74
  if image.mode != 'RGB':
75
  image = image.convert('RGB')
76
 
77
- # Resize image
78
  image = image.resize(IMG_SIZE)
79
-
80
- # Convert to numpy array
81
  img_array = np.array(image)
82
-
83
- # Normalize pixel values to [0, 1]
84
  img_array = img_array.astype('float32') / 255.0
85
-
86
- # Add batch dimension
87
  img_array = np.expand_dims(img_array, axis=0)
88
 
89
  return img_array
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  @app.get("/")
92
  async def root():
93
  """Health check endpoint"""
@@ -97,39 +186,68 @@ async def root():
97
  "model_loaded": model is not None
98
  }
99
 
100
- @app.get("/test")
101
- async def test_prediction():
102
- """
103
- Test endpoint using a hardcoded image from local directory
104
- Place a test image named 'test_image.jpg' in the same directory as main.py
105
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  if model is None:
107
  raise HTTPException(status_code=500, detail="Model not loaded")
108
 
109
- # Path to test image (change this to your test image path)
110
- test_image_path = "test2.jpg"
111
-
112
- if not os.path.exists(test_image_path):
113
- raise HTTPException(
114
- status_code=404,
115
- detail=f"Test image not found at {test_image_path}. Please place a test image in the directory."
116
- )
117
 
118
  try:
119
- # Load and process the test image
120
- image = Image.open(test_image_path)
121
  processed_img = preprocess_image(image)
122
-
123
- # Make prediction
124
  prediction = model.predict(processed_img)
125
 
126
- # Get predicted class and confidence
127
- # Get top-1 prediction
128
  predicted_class_idx = int(np.argmax(prediction, axis=1)[0])
129
  confidence = float(np.max(prediction))
130
  predicted_class_name = CLASS_LABELS[predicted_class_idx]
131
-
132
- # Get top-5 predictions
133
  top_5_indices = np.argsort(prediction[0])[-5:][::-1]
134
  top_5_predictions = [
135
  {
@@ -138,79 +256,118 @@ async def test_prediction():
138
  }
139
  for idx in top_5_indices
140
  ]
141
- print("Top-5 indices & confidences:",top_5_indices
142
- )
143
-
144
- # ✅ Return or print all results
145
  return {
146
- "top_prediction": {
147
- "class": predicted_class_name,
148
- "confidence": confidence
149
- }
150
- # "top_5_predictions": top_5_predictions
 
151
  }
152
-
153
  except Exception as e:
154
- raise HTTPException(status_code=500, detail=f"Test prediction error: {str(e)}")
155
 
156
- @app.post("/predict")
157
- async def predict(file: UploadFile = File(...)) -> Dict:
 
 
 
 
 
158
  """
159
- Predict image class using the loaded CNN model
160
 
161
  Args:
162
- file: Uploaded image file
 
 
 
163
 
164
  Returns:
165
- Dictionary containing prediction results
166
  """
167
  if model is None:
168
  raise HTTPException(status_code=500, detail="Model not loaded")
169
 
170
- # Validate file type
171
  if not file.content_type.startswith('image/'):
172
  raise HTTPException(status_code=400, detail="File must be an image")
173
 
174
  try:
175
- # Read image file
176
  contents = await file.read()
177
  image = Image.open(io.BytesIO(contents))
178
-
179
- # Preprocess image
180
  processed_img = preprocess_image(image)
181
-
182
- # Make prediction
183
  prediction = model.predict(processed_img)
184
 
185
- # Get predicted class and confidence
186
  predicted_class_idx = int(np.argmax(prediction, axis=1)[0])
187
  confidence = float(np.max(prediction))
188
  predicted_class_name = CLASS_LABELS[predicted_class_idx]
189
 
190
- # Get top 5 predictions
191
- top_5_indices = np.argsort(prediction[0])[-5:][::-1]
192
- top_5_predictions = [
193
  {
194
  "class": CLASS_LABELS[idx],
195
  "confidence": float(prediction[0][idx])
196
  }
197
- for idx in top_5_indices
198
  ]
199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  return {
201
  "success": True,
202
- "predicted_class": predicted_class_name,
203
- "predicted_class_index": predicted_class_idx,
204
- "confidence": confidence,
205
- "top_5_predictions": top_5_predictions,
 
 
 
 
 
 
 
 
 
 
 
 
206
  "filename": file.filename
207
  }
208
 
 
 
209
  except Exception as e:
210
- raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
211
-
212
-
213
 
214
  if __name__ == "__main__":
215
  uvicorn.run(app, host="0.0.0.0", port=7860)
216
-
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Form
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from huggingface_hub import hf_hub_download
4
  from tensorflow import keras
5
  import numpy as np
6
  from PIL import Image
7
  import io
8
+ from typing import Dict, Optional
9
  import uvicorn
10
  import os
11
+ import httpx
12
+ from groq import Groq
13
 
14
  app = FastAPI(title="CNN Image Prediction API")
15
 
 
22
  allow_headers=["*"],
23
  )
24
 
25
+ # Initialize Groq client
26
+ client = Groq(
27
+ api_key="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",
28
+ )
29
+
30
+ # Weather API key - get from openweathermap.org
31
+ WEATHER_API_KEY = "6af4193be6ef82758dde960743909a80"
32
+
33
  # Global variable for model
34
  model = None
35
+ IMG_SIZE = (224, 224)
36
 
37
  # Class labels
38
  CLASS_LABELS = [
 
39
  "Apple___Apple_scab", "Apple___Black_rot", "Apple___Cedar_apple_rust", "Apple___Healthy",
40
  "Blueberry___Healthy", "Cherry_(including_sour)___Powdery_mildew", "Cherry_(including_sour)___Healthy",
41
  "Corn_(maize)___Cercospora_leaf_spot_Gray_leaf_spot", "Corn_(maize)__Common_rust", "Corn_(maize)___Northern_Leaf_Blight",
 
43
  "Grape___Healthy", "Orange__Haunglongbing(Citrus_greening)", "Peach___Bacterial_spot", "Peach___Healthy",
44
  "Pepper,bell__Bacterial_spot", "Pepper,bell__Healthy", "Potato___Early_blight", "Potato___Late_blight", "Potato___Healthy",
45
  "Raspberry___Healthy", "Soybean___Healthy", "Squash___Powdery_mildew", "Strawberry___Leaf_scorch", "Strawberry___Healthy",
46
+ 'cotton : bacterial_blight', ' cotton : curl_virus', ' fussarium_wilt', "Tomato___Leaf_Mold", "Tomato___Septoria_leaf_spot",
47
+ ' fussarium_wilt', "Tomato___Leaf_Mold", "Tomato___Septoria_leaf_spot",
 
 
48
  "Tomato___Spider_mites_Two-spotted_spider_mite", "Tomato___Target_Spot", "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
49
  "Tomato___Tomato_mosaic_virus", "Tomato___Healthy"
50
  ]
51
 
 
 
52
  @app.on_event("startup")
53
  async def load_model_keras():
54
  """Load the Keras model from Hugging Face on startup"""
55
  global model
56
  try:
57
+ repo_id = "AdeshJain/plant-detection"
 
58
  filename = "plant_disease_model.keras"
59
 
 
60
  model_path = hf_hub_download(
61
  repo_id=repo_id,
62
  filename=filename,
63
  cache_dir="./model_cache"
64
  )
65
 
 
66
  model = keras.models.load_model(model_path)
 
67
  print(f"Model loaded successfully from {repo_id}!")
68
  except Exception as e:
69
  print(f"Error loading model: {e}")
 
71
 
72
  def preprocess_image(image: Image.Image) -> np.ndarray:
73
  """Preprocess the image for model prediction"""
 
74
  if image.mode != 'RGB':
75
  image = image.convert('RGB')
76
 
 
77
  image = image.resize(IMG_SIZE)
 
 
78
  img_array = np.array(image)
 
 
79
  img_array = img_array.astype('float32') / 255.0
 
 
80
  img_array = np.expand_dims(img_array, axis=0)
81
 
82
  return img_array
83
 
84
+ async def get_weather_data(lat: float, lon: float) -> Dict:
85
+ """Fetch weather data from OpenWeather API"""
86
+ url = f"https://api.openweathermap.org/data/2.5/weather?lat={lat}&lon={lon}&appid={WEATHER_API_KEY}&units=metric"
87
+
88
+ try:
89
+ async with httpx.AsyncClient() as client_http:
90
+ response = await client_http.get(url)
91
+ if response.status_code != 200:
92
+ return None
93
+ return response.json()
94
+ except Exception as e:
95
+ print(f"Weather API error: {e}")
96
+ return None
97
+
98
+ def get_llm_remedies(disease: str, weather_data: Optional[Dict], location: str) -> Dict:
99
+ """Get disease remedies from Groq LLM"""
100
+
101
+ # Prepare weather context
102
+ weather_context = ""
103
+ if weather_data:
104
+ temp = weather_data['main']['temp']
105
+ humidity = weather_data['main']['humidity']
106
+ weather_desc = weather_data['weather'][0]['description']
107
+ weather_context = f"""
108
+ Current Weather Conditions at {location}:
109
+ - Temperature: {temp}°C
110
+ - Humidity: {humidity}%
111
+ - Conditions: {weather_desc}
112
+ """
113
+ else:
114
+ weather_context = f"Location: {location}\n(Weather data unavailable)"
115
+
116
+ prompt = f"""You are an expert agricultural consultant specializing in plant disease management. A farmer has detected the following plant disease through image analysis:
117
+
118
+ **Detected Disease: {disease}**
119
+
120
+ {weather_context}
121
+
122
+ Please provide comprehensive treatment recommendations in the following structured format:
123
+
124
+ ## 1. CHEMICAL TREATMENT METHODS
125
+ Provide specific chemical treatments including:
126
+ - Recommended fungicides/pesticides (with active ingredients)
127
+ - Application dosage and concentration
128
+ - Application frequency and timing
129
+ - Safety precautions and protective equipment needed
130
+ - Pre-harvest intervals if applicable
131
+
132
+ ## 2. SUSTAINABLE & NATURAL TREATMENT METHODS
133
+ Provide organic and eco-friendly solutions including:
134
+ - Natural/organic sprays and remedies
135
+ - Biological control methods
136
+ - Cultural practices and preventive measures
137
+ - Soil management techniques
138
+ - Plant-based solutions
139
+
140
+ ## 3. WEATHER-SPECIFIC RECOMMENDATIONS
141
+ Based on the current weather conditions:
142
+ - How weather affects disease progression
143
+ - Best time to apply treatments
144
+ - Additional precautions needed
145
+ - Environmental considerations
146
+
147
+ ## 4. PREVENTIVE MEASURES
148
+ Long-term strategies to prevent recurrence:
149
+ - Crop rotation suggestions
150
+ - Irrigation management
151
+ - Nutrient management
152
+ - Monitoring practices
153
+
154
+ Please be specific, practical, and actionable. Consider both immediate treatment and long-term disease management."""
155
+
156
+ try:
157
+ chat_completion = client.chat.completions.create(
158
+ messages=[
159
+ {
160
+ "role": "system",
161
+ "content": "You are an expert agricultural consultant with deep knowledge of plant pathology and sustainable farming practices. Provide detailed, practical advice for farmers."
162
+ },
163
+ {
164
+ "role": "user",
165
+ "content": prompt
166
+ }
167
+ ],
168
+ model="llama-3.3-70b-versatile",
169
+ temperature=0.7,
170
+ max_tokens=500
171
+ )
172
+
173
+ return {
174
+ "remedies": chat_completion.choices[0].message.content,
175
+ "model_used": "llama-3.3-70b-versatile"
176
+ }
177
+ except Exception as e:
178
+ raise HTTPException(status_code=500, detail=f"LLM error: {str(e)}")
179
+
180
  @app.get("/")
181
  async def root():
182
  """Health check endpoint"""
 
186
  "model_loaded": model is not None
187
  }
188
 
189
+ # @app.get("/test")
190
+ # async def test_prediction():
191
+ # """Test endpoint using a hardcoded image from local directory"""
192
+ # if model is None:
193
+ # raise HTTPException(status_code=500, detail="Model not loaded")
194
+
195
+ # test_image_path = "test2.jpg"
196
+
197
+ # if not os.path.exists(test_image_path):
198
+ # raise HTTPException(
199
+ # status_code=404,
200
+ # detail=f"Test image not found at {test_image_path}. Please place a test image in the directory."
201
+ # )
202
+
203
+ # try:
204
+ # image = Image.open(test_image_path)
205
+ # processed_img = preprocess_image(image)
206
+ # prediction = model.predict(processed_img)
207
+
208
+ # predicted_class_idx = int(np.argmax(prediction, axis=1)[0])
209
+ # confidence = float(np.max(prediction))
210
+ # predicted_class_name = CLASS_LABELS[predicted_class_idx]
211
+
212
+ # top_5_indices = np.argsort(prediction[0])[-5:][::-1]
213
+ # top_5_predictions = [
214
+ # {
215
+ # "class": CLASS_LABELS[idx],
216
+ # "confidence": float(prediction[0][idx])
217
+ # }
218
+ # for idx in top_5_indices
219
+ # ]
220
+ # print("Top-5 indices & confidences:", top_5_indices)
221
+
222
+ # return {
223
+ # "top_prediction": {
224
+ # "class": predicted_class_name,
225
+ # "confidence": confidence
226
+ # }
227
+ # }
228
+
229
+ # except Exception as e:
230
+ # raise HTTPException(status_code=500, detail=f"Test prediction error: {str(e)}")
231
+
232
+ @app.post("/predict")
233
+ async def predict(file: UploadFile = File(...)) -> Dict:
234
+ """Predict image class using the loaded CNN model"""
235
  if model is None:
236
  raise HTTPException(status_code=500, detail="Model not loaded")
237
 
238
+ if not file.content_type.startswith('image/'):
239
+ raise HTTPException(status_code=400, detail="File must be an image")
 
 
 
 
 
 
240
 
241
  try:
242
+ contents = await file.read()
243
+ image = Image.open(io.BytesIO(contents))
244
  processed_img = preprocess_image(image)
 
 
245
  prediction = model.predict(processed_img)
246
 
 
 
247
  predicted_class_idx = int(np.argmax(prediction, axis=1)[0])
248
  confidence = float(np.max(prediction))
249
  predicted_class_name = CLASS_LABELS[predicted_class_idx]
250
+
 
251
  top_5_indices = np.argsort(prediction[0])[-5:][::-1]
252
  top_5_predictions = [
253
  {
 
256
  }
257
  for idx in top_5_indices
258
  ]
259
+
 
 
 
260
  return {
261
+ "success": True,
262
+ "predicted_class": predicted_class_name,
263
+ "predicted_class_index": predicted_class_idx,
264
+ "confidence": confidence,
265
+ "top_5_predictions": top_5_predictions,
266
+ "filename": file.filename
267
  }
268
+
269
  except Exception as e:
270
+ raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
271
 
272
+ @app.post("/predict-with-remedies")
273
+ async def predict_with_remedies(
274
+ file: UploadFile = File(...),
275
+ latitude: float = Form(...),
276
+ longitude: float = Form(...),
277
+ city: Optional[str] = Form(None)
278
+ ) -> Dict:
279
  """
280
+ Predict plant disease and provide AI-generated remedies with weather context
281
 
282
  Args:
283
+ file: Uploaded plant image
284
+ latitude: Location latitude for weather data
285
+ longitude: Location longitude for weather data
286
+ city: Optional city name for display
287
 
288
  Returns:
289
+ Dictionary containing prediction, weather data, and treatment recommendations
290
  """
291
  if model is None:
292
  raise HTTPException(status_code=500, detail="Model not loaded")
293
 
 
294
  if not file.content_type.startswith('image/'):
295
  raise HTTPException(status_code=400, detail="File must be an image")
296
 
297
  try:
298
+ # Step 1: Make disease prediction using CNN model
299
  contents = await file.read()
300
  image = Image.open(io.BytesIO(contents))
 
 
301
  processed_img = preprocess_image(image)
 
 
302
  prediction = model.predict(processed_img)
303
 
 
304
  predicted_class_idx = int(np.argmax(prediction, axis=1)[0])
305
  confidence = float(np.max(prediction))
306
  predicted_class_name = CLASS_LABELS[predicted_class_idx]
307
 
308
+ # Get top 3 predictions for additional context
309
+ top_3_indices = np.argsort(prediction[0])[-3:][::-1]
310
+ top_3_predictions = [
311
  {
312
  "class": CLASS_LABELS[idx],
313
  "confidence": float(prediction[0][idx])
314
  }
315
+ for idx in top_3_indices
316
  ]
317
 
318
+ # Step 2: Fetch weather data for the location
319
+ location_name = city if city else f"Lat: {latitude}, Lon: {longitude}"
320
+ weather_data = await get_weather_data(latitude, longitude)
321
+
322
+ # Step 3: Get LLM-generated remedies
323
+ # Only generate remedies if disease detected (not healthy)
324
+ remedies_data = None
325
+ if "Healthy" not in predicted_class_name:
326
+ remedies_data = get_llm_remedies(predicted_class_name, weather_data, location_name)
327
+ else:
328
+ remedies_data = {
329
+ "remedies": "🎉 Great news! Your plant appears to be healthy. No treatment needed.\n\n**Preventive Care Tips:**\n- Continue regular monitoring\n- Maintain proper watering schedule\n- Ensure adequate sunlight\n- Keep the area clean and weed-free\n- Monitor for any changes in plant appearance",
330
+ "model_used": "rule-based"
331
+ }
332
+
333
+ # Prepare weather info for response
334
+ weather_info = None
335
+ if weather_data:
336
+ weather_info = {
337
+ "temperature": weather_data['main']['temp'],
338
+ "feels_like": weather_data['main']['feels_like'],
339
+ "humidity": weather_data['main']['humidity'],
340
+ "pressure": weather_data['main']['pressure'],
341
+ "conditions": weather_data['weather'][0]['description'],
342
+ "wind_speed": weather_data['wind']['speed']
343
+ }
344
+
345
  return {
346
  "success": True,
347
+ "prediction": {
348
+ "disease": predicted_class_name,
349
+ "confidence": confidence,
350
+ "is_healthy": "Healthy" in predicted_class_name,
351
+ "top_3_predictions": top_3_predictions
352
+ },
353
+ "location": {
354
+ "name": location_name,
355
+ "latitude": latitude,
356
+ "longitude": longitude
357
+ },
358
+ "weather": weather_info,
359
+ "treatment": {
360
+ "remedies": remedies_data["remedies"],
361
+ "llm_model": remedies_data["model_used"]
362
+ },
363
  "filename": file.filename
364
  }
365
 
366
+ except HTTPException:
367
+ raise
368
  except Exception as e:
369
+ raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
 
 
370
 
371
  if __name__ == "__main__":
372
  uvicorn.run(app, host="0.0.0.0", port=7860)
373
+