zynt31 commited on
Commit
216cda9
·
verified ·
1 Parent(s): b2121ca

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +369 -0
app.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import sys
3
+ import time
4
+ import logging
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from fastapi import FastAPI, UploadFile, File, HTTPException
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from contextlib import asynccontextmanager # Add this import
10
+ from models.rice_model import RiceDiseaseCNN
11
+ from models.rice_leaf_validator import RiceLeafValidator, is_rice_leaf
12
+ from utils.image_processing import process_image
13
+ import requests
14
+ import json
15
+ from typing import Dict, Any
16
+
17
+ # Configure logging
18
+ logging.basicConfig(
19
+ level=logging.INFO,
20
+ format="%(asctime)s [%(levelname)s] %(message)s",
21
+ handlers=[
22
+ logging.StreamHandler(sys.stdout)
23
+ ]
24
+ )
25
+
26
+ # Initialize model variables
27
+ model = None
28
+ leaf_validator_model = None
29
+ model_path = "rice_disease_model_final.pth"
30
+ leaf_validator_model_path = "rice_leaf_validator_model.pth"
31
+ model_loaded = False
32
+ validator_loaded = False
33
+ model_loading = False
34
+ last_error = None
35
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
36
+
37
+ # Define lifespan context manager for model loading
38
+ @asynccontextmanager
39
+ async def lifespan(app):
40
+ # Load models on startup
41
+ load_model()
42
+ yield
43
+ # Cleanup operations can be added here for shutdown
44
+
45
+ # Create FastAPI app with lifespan
46
+ app = FastAPI(lifespan=lifespan)
47
+
48
+ # Add CORS middleware
49
+ app.add_middleware(
50
+ CORSMiddleware,
51
+ allow_origins=["*"], # Allows all origins
52
+ allow_credentials=True,
53
+ allow_methods=["*"], # Allows all methods
54
+ allow_headers=["*"], # Allows all headers
55
+ )
56
+
57
+ # Class names matching your training data
58
+ CLASS_NAMES = ['Bacterialblight', 'Blast', 'Brownspot', 'Tungro']
59
+
60
+ # Disease display names for more specific prompts
61
+ disease_display_names = {
62
+ 'Bacterialblight': 'Bacterial Leaf Blight (BLB)',
63
+ 'Blast': 'Rice Blast Disease',
64
+ 'Brownspot': 'Brown Spot Disease',
65
+ 'Tungro': 'Rice Tungro Disease'
66
+ }
67
+
68
+ # Disease recommendations (you can expand these with more detailed information)
69
+ RECOMMENDATIONS = {
70
+ 'Bacterialblight': {
71
+ 'recommendation': 'Use resistant varieties like PSB Rc18 (ALA), NSIC Rc354, or NSIC Rc302. Apply balanced nutrients, especially nitrogen (avoid excess). Maintain good drainage and field sanitation: weed removal, plow under stubbles and straw. Destroy ratoons and volunteer seedlings. Allow fallow fields to dry to suppress pathogens. Consult your local DA technician for specific advice.',
72
+ 'details': 'Bacterial Leaf Blight (BLB) is caused by Xanthomonas oryzae pv. oryzae. Symptoms include water-soaked stripes that expand to large grayish-white lesions with wavy light brown margins. In susceptible varieties, yield loss can reach 70%. At booting stage, causes poor quality grains and high broken grain percentage. Spread by ooze droplets, strong winds, heavy rains, contaminated stubbles, and infected weeds in tropical and temperate lowlands (25-34°C, >70% relative humidity).'
73
+ },
74
+ 'Blast': {
75
+ 'recommendation': 'Use resistant varieties such as PSB Rc18 or IRRI-derived lines. Avoid high nitrogen application and split nitrogen fertilizer application. Maintain proper irrigation (avoid water stress) and field sanitation (remove stubbles, debris). Adjust planting calendar to avoid peak infection period. Apply calcium silicate to strengthen cell walls. Apply fungicides: triazoles and strobilurins at first signs, especially during tillering/booting. Burn or plow under infected straw. Consult local DA.',
76
+ 'details': 'Rice Blast is caused by Magnaporthe oryzae fungus. Symptoms include leaf blast (small, spindle-shaped spots with brown border and gray center), node blast (nodes become black and break easily), and panicle blast (base becomes black, grains remain unfilled). Occurs from seedling to reproductive stage under cool temperature, high relative humidity, continuous rain, and large day-night temperature differences. Can reduce yield significantly by reducing leaf area for grain filling.'
77
+ },
78
+ 'Brownspot': {
79
+ 'recommendation': 'Plant resistant varieties and ensure balanced fertilization, especially sufficient potassium (use complete or potash fertilizers). Conduct soil test and correct deficiencies. Improve soil health and maintain proper field drainage. Use clean, healthy seeds and practice field sanitation (remove stubbles, weeds). Split nitrogen fertilizer application. Apply calcium silicate & potassium fertilizers. If severe, apply fungicides (triazoles, strobilurins) following BPI registration and safety rules. Consult local DA.',
80
+ 'details': 'Brown Spot is caused by Cochliobolus miyabeanus. Symptoms include small, circular to oval spots with gray centers on leaves, black spots on glumes with dark brown velvety fungal spores, and discolored, shriveled grains. Historically caused the 1943 Great Bengal Famine. Yield loss ranges (5-45%) and causes seedling blight mortality of 10-58%. Fungus survives 2-4 years in infected tissues under high humidity (86-100%), 16-36°C, in nutrient-deficient or toxic soils.'
81
+ },
82
+ 'Tungro': {
83
+ 'recommendation': 'Plant tungro- or GLH-resistant varieties like IR64 if available. Practice synchronous planting (sabay-sabay) with neighbors and maintain a fallow period (at least 1 month) between crops to break vector cycles. Remove (rogue) and destroy infected plants immediately. Plow infected stubbles after harvest to destroy inoculum and GLH breeding sites. Maintain balanced nutrient management. Note: Chemical control of GLH is not effective as they move quickly and spread tungro even with short feeding. Consult local DA.',
84
+ 'details': 'Rice Tungro Disease is caused by a combination of two viruses (Rice Tungro Bacilliform Virus + Rice Tungro Spherical Virus) transmitted by green leafhoppers (GLH). Symptoms include young leaves becoming mottled, older leaves turning yellow to orange, stunted growth, reduced tillers, delayed flowering, small panicles, and high sterility. Highly destructive in South & Southeast Asia with yield loss up to (100 Percent) in susceptible varieties infected early. Most damaging during vegetative/tillering stage.'
85
+ }
86
+ }
87
+
88
+ def get_ollama_recommendation(disease_name: str, confidence: float) -> Dict[str, Any]:
89
+ """Get AI-generated recommendation using Ollama"""
90
+ try:
91
+ logging.info(f"Requesting Ollama recommendation for {disease_name} with {confidence:.2f}% confidence")
92
+
93
+ # Create a detailed prompt
94
+ prompt = f"""
95
+ As an agricultural expert specializing in rice diseases in the Philippines:
96
+
97
+ Our AI system has detected {disease_display_names.get(disease_name, disease_name)} with {confidence:.2f}% confidence in a rice plant image.
98
+
99
+ Please provide:
100
+ 1. A detailed analysis of how this disease affects rice plants
101
+ 2. Specific recommendations for treating this disease infection
102
+ 3. Preventive measures farmers should take to avoid future infections
103
+
104
+ Format your response in clear sections that can be easily read by farmers.
105
+ """
106
+
107
+ # Call local Ollama API
108
+ response = requests.post('http://localhost:11434/api/generate',
109
+ json={
110
+ "model": "mistral:7b-instruct",
111
+ "prompt": prompt,
112
+ "stream": False
113
+ },
114
+ timeout=60 # 60-second timeout
115
+ )
116
+
117
+ if response.status_code == 200:
118
+ result = response.json()
119
+ ai_recommendation = result.get("response", "")
120
+
121
+ # Add this debugging
122
+ logging.info(f"Ollama response for {disease_name} (first 100 chars): {ai_recommendation[:100]}...")
123
+
124
+ return {
125
+ "recommendation": ai_recommendation,
126
+ "source": "ai",
127
+ "details": f"AI-generated recommendation for {disease_name}"
128
+ }
129
+ else:
130
+ logging.error(f"Ollama API error: {response.status_code}")
131
+ # Fall back to static recommendation
132
+ return RECOMMENDATIONS.get(disease_name, {})
133
+
134
+ except Exception as e:
135
+ logging.error(f"Error calling Ollama: {e}")
136
+ # Fall back to static recommendation
137
+ return RECOMMENDATIONS.get(disease_name, {})
138
+
139
+ def load_model():
140
+ """Load the PyTorch models"""
141
+ global model, leaf_validator_model, model_loaded, validator_loaded, model_loading, last_error
142
+ try:
143
+ logging.info(f"Loading disease model from: {model_path}")
144
+ model_loading = True
145
+
146
+ # Initialize the disease model architecture
147
+ model = RiceDiseaseCNN(num_classes=4).to(device)
148
+
149
+ # Load the saved model weights
150
+ model.load_state_dict(torch.load(model_path, map_location=device))
151
+ model.eval() # Set to evaluation mode
152
+ model_loaded = True
153
+
154
+ # Load rice leaf validator model
155
+ logging.info(f"Loading leaf validator model from: {leaf_validator_model_path}")
156
+ leaf_validator_model = RiceLeafValidator().to(device)
157
+ leaf_validator_model.load_state_dict(torch.load(leaf_validator_model_path, map_location=device))
158
+ leaf_validator_model.eval()
159
+ validator_loaded = True
160
+
161
+ model_loading = False
162
+ logging.info("Models loaded successfully!")
163
+ return True
164
+ except Exception as e:
165
+ last_error = str(e)
166
+ model_loading = False
167
+ logging.error(f"Failed to load models: {e}")
168
+ return False
169
+
170
+ @app.get("/")
171
+ async def root():
172
+ """API root endpoint"""
173
+ return {
174
+ "message": "Rice Disease Detection API is running!",
175
+ "model_loaded": model_loaded,
176
+ "validator_loaded": validator_loaded,
177
+ "model_loading": model_loading,
178
+ "error": last_error
179
+ }
180
+
181
+ @app.get("/health")
182
+ async def health():
183
+ """Health check endpoint needed by PHP status checker"""
184
+ return {
185
+ "status": "running",
186
+ "model_status": "loaded" if model_loaded else "loading" if model_loading else "not_loaded",
187
+ "validator_status": "loaded" if validator_loaded else "not_loaded",
188
+ "model_loaded": model_loaded,
189
+ "model_loading": model_loading,
190
+ "last_error": last_error,
191
+ "device": str(device)
192
+ }
193
+
194
+ @app.get("/status")
195
+ async def status():
196
+ """Check model loading status"""
197
+ return {
198
+ "model_loaded": model_loaded,
199
+ "validator_loaded": validator_loaded,
200
+ "model_loading": model_loading,
201
+ "error": last_error,
202
+ "device": str(device)
203
+ }
204
+
205
+ @app.get("/test-ollama")
206
+ async def test_ollama():
207
+ """Test if Ollama is working"""
208
+ try:
209
+ response = requests.post('http://localhost:11434/api/generate',
210
+ json={
211
+ "model": "mistral:7b-instruct",
212
+ "prompt": "Give a short greeting to verify the API is working",
213
+ "stream": False
214
+ },
215
+ timeout=10
216
+ )
217
+
218
+ if response.status_code == 200:
219
+ result = response.json()
220
+ return {
221
+ "success": True,
222
+ "message": "Ollama is working correctly",
223
+ "response": result.get("response", "")
224
+ }
225
+ else:
226
+ return {
227
+ "success": False,
228
+ "message": f"Ollama API returned status code {response.status_code}",
229
+ "error": response.text
230
+ }
231
+ except Exception as e:
232
+ return {
233
+ "success": False,
234
+ "message": "Failed to connect to Ollama",
235
+ "error": str(e)
236
+ }
237
+
238
+ @app.post("/validate-leaf/")
239
+ async def validate_rice_leaf_image(file: UploadFile = File(...)):
240
+ """Check if an image contains a rice leaf"""
241
+ if not validator_loaded:
242
+ if model_loading:
243
+ raise HTTPException(status_code=503, detail="Models are still loading. Please try again later.")
244
+ else:
245
+ success = load_model()
246
+ if not success:
247
+ raise HTTPException(status_code=500, detail=f"Failed to load models: {last_error}")
248
+
249
+ try:
250
+ # Save uploaded file temporarily
251
+ temp_file_path = f"temp_{file.filename}"
252
+ with open(temp_file_path, "wb") as buffer:
253
+ buffer.write(await file.read())
254
+
255
+ # Check if image contains rice leaf
256
+ is_rice, confidence = is_rice_leaf(temp_file_path, leaf_validator_model, device)
257
+
258
+ # Remove temporary file
259
+ import os
260
+ os.remove(temp_file_path)
261
+
262
+ return {
263
+ "is_rice_leaf": is_rice,
264
+ "confidence": f"{confidence * 100:.2f}%"
265
+ }
266
+
267
+ except Exception as e:
268
+ logging.error(f"Error during leaf validation: {e}")
269
+ raise HTTPException(status_code=500, detail=f"Validation failed: {str(e)}")
270
+
271
+ @app.post("/predict/")
272
+ async def predict_disease(file: UploadFile = File(...), use_ai_recommendation: bool = True):
273
+ """Predict rice disease from uploaded image"""
274
+ # Check if models are loaded
275
+ if not model_loaded or not validator_loaded:
276
+ if model_loading:
277
+ raise HTTPException(status_code=503, detail="Models are still loading. Please try again later.")
278
+ else:
279
+ # Try loading the models if they're not already loading
280
+ success = load_model()
281
+ if not success:
282
+ raise HTTPException(status_code=500, detail=f"Failed to load models: {last_error}")
283
+
284
+ try:
285
+ # Read image file
286
+ start_time = time.time()
287
+
288
+ # Save uploaded file temporarily for validation
289
+ temp_file_path = f"temp_{file.filename}"
290
+ contents = await file.read()
291
+ with open(temp_file_path, "wb") as buffer:
292
+ buffer.write(contents)
293
+
294
+ logging.info(f"Image size: {len(contents)} bytes")
295
+
296
+ # First validate if it's a rice leaf
297
+ is_rice, rice_confidence = is_rice_leaf(temp_file_path, leaf_validator_model, device)
298
+
299
+ # Remove temporary file
300
+ import os
301
+ os.remove(temp_file_path)
302
+
303
+ if not is_rice:
304
+ logging.info(f"Image not recognized as rice leaf: {rice_confidence * 100:.2f}% confidence")
305
+ return {
306
+ "is_rice_leaf": False,
307
+ "confidence": f"{rice_confidence * 100:.2f}%",
308
+ "message": "The uploaded image does not appear to be a rice leaf. Please upload a clear image of a rice plant leaf."
309
+ }
310
+
311
+ # Process image for disease detection
312
+ input_tensor = process_image(contents).to(device)
313
+
314
+ # Make prediction
315
+ with torch.no_grad():
316
+ outputs = model(input_tensor)
317
+ probabilities = F.softmax(outputs, dim=1)[0]
318
+
319
+ # Get top predictions with probabilities
320
+ top_probs, top_classes = torch.topk(probabilities, len(CLASS_NAMES))
321
+
322
+ predictions = []
323
+ for i in range(len(top_classes)):
324
+ class_idx = top_classes[i].item()
325
+ predictions.append({
326
+ "label": CLASS_NAMES[class_idx],
327
+ "confidence": f"{top_probs[i].item()*100:.2f}%"
328
+ })
329
+
330
+ # Get primary prediction (highest confidence)
331
+ predicted_class = top_classes[0].item()
332
+ disease_name = CLASS_NAMES[predicted_class]
333
+ confidence = float(top_probs[0].item()) * 100
334
+
335
+ # Get recommendation - either AI or static
336
+ if use_ai_recommendation:
337
+ try:
338
+ recommendation_data = get_ollama_recommendation(disease_name, confidence)
339
+ except Exception as e:
340
+ logging.error(f"Failed to get AI recommendation: {e}")
341
+ recommendation_data = RECOMMENDATIONS.get(disease_name, {}) # Fallback
342
+ else:
343
+ recommendation_data = RECOMMENDATIONS.get(disease_name, {})
344
+
345
+ # Log prediction time
346
+ processing_time = time.time() - start_time
347
+ logging.info(f"Prediction completed in {processing_time:.3f} seconds")
348
+
349
+ # Return in format expected by PHP code
350
+ return {
351
+ "is_rice_leaf": True,
352
+ "leaf_confidence": f"{rice_confidence * 100:.2f}%",
353
+ "predictions": predictions,
354
+ "disease": disease_name,
355
+ "confidence": f"{confidence:.2f}%",
356
+ "recommendation": recommendation_data.get("recommendation", ""),
357
+ "details": recommendation_data.get("details", ""),
358
+ "recommendation_source": recommendation_data.get("source", "static"),
359
+ "inference_time_seconds": round(processing_time, 3)
360
+ }
361
+
362
+ except Exception as e:
363
+ logging.error(f"Error during prediction: {e}")
364
+ raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
365
+
366
+ if __name__ == "__main__":
367
+ import uvicorn
368
+ # Production settings: no hot reload, proper host binding
369
+ uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=False)