# rld / app.py
# Author: zynt31
# Last commit: "Updated app.py with recommendation" (a760cc2, verified)
import io
import json
import logging
import os
import sys
import tempfile
import time
from contextlib import asynccontextmanager
from typing import Any, Dict

import requests
import torch
import torch.nn.functional as F
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware

from models.rice_model import RiceDiseaseCNN
from models.rice_leaf_validator import RiceLeafValidator, is_rice_leaf
from utils.image_processing import process_image
# Hugging Face transformers pipeline for text recommendations (optional).
# google/flan-t5-small is an encoder-decoder (T5) model, so it has to be served
# by the "text2text-generation" task; the causal-LM "text-generation" task
# cannot load it, which would always leave `recommender` as None.
try:
    from transformers import pipeline
    recommender = pipeline("text2text-generation", model="google/flan-t5-small")
except Exception as e:  # best-effort: missing package, download failure, etc.
    recommender = None
    # Use a named logger: the module-level logging.warning() implicitly calls
    # logging.basicConfig() when the root logger has no handlers yet, which
    # would make a later basicConfig() call (without force=True) a no-op.
    logging.getLogger(__name__).warning(f"Transformers pipeline not loaded: {e}")
# Configure root logging: INFO level, timestamped format, written to stdout.
# force=True (Python 3.8+) first removes any handlers already attached to the
# root logger (e.g. from an implicit basicConfig() triggered by an earlier
# module-level logging call, or added by an imported library); without it this
# call would silently do nothing in that situation.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True,
)
# Model state shared between load_model() and the request handlers.
# Both networks start unset; load_model() fills them in and flips the flags.
model = leaf_validator_model = None

# Checkpoint files (resolved relative to the current working directory).
model_path = "rice_disease_model_final.pth"
leaf_validator_model_path = "rice_leaf_validator_model.pth"

# Loading status, reported by the /, /health and /status endpoints.
model_loaded = validator_loaded = model_loading = False
last_error = None  # message from a failed load attempt, if any

# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Application lifespan hook (replaces the deprecated on_event startup handler).
@asynccontextmanager
async def lifespan(app):
    """Wrap the FastAPI application's lifetime with startup/shutdown work.

    Startup: load the disease classifier and the rice-leaf validator via
    ``load_model()``. Its boolean result is not checked here, so the app still
    starts when loading fails; the failure is recorded in module globals.
    Shutdown: nothing to release.
    """
    # --- startup ---
    load_model()
    # --- application serves requests while suspended here ---
    yield
    # --- shutdown: no cleanup required ---
# FastAPI application; `lifespan` loads the models before requests are served.
app = FastAPI(lifespan=lifespan)

# CORS: accept cross-origin requests from any origin, method and header.
# NOTE(review): the FastAPI CORS docs say allow_origins=["*"] should not be
# combined with allow_credentials=True. This API does not appear to rely on
# cookies or auth headers, so allow_credentials=False (or an explicit origin
# list) may be the safer setting -- confirm with the frontend before changing.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Class names matching the training data, in the index order of the
# classifier's output (predict_disease maps output indices through this list).
CLASS_NAMES = ["Bacterialblight", "Blast", "Brownspot", "Tungro"]

# Human-readable disease names keyed by class name, intended for more
# specific prompts.
disease_display_names = {
    "Bacterialblight": "Bacterial Leaf Blight (BLB)",
    "Blast": "Rice Blast Disease",
    "Brownspot": "Brown Spot Disease",
    "Tungro": "Rice Tungro Disease",
}

# Static management advice per class. /predict/ serves it when AI
# recommendations are disabled or fail. Each entry holds a short actionable
# "recommendation" plus background "details"; extend as needed.
RECOMMENDATIONS = {
    "Bacterialblight": {
        "recommendation": "Use resistant varieties like PSB Rc18 (ALA), NSIC Rc354, or NSIC Rc302. Apply balanced nutrients, especially nitrogen (avoid excess). Maintain good drainage and field sanitation: weed removal, plow under stubbles and straw. Destroy ratoons and volunteer seedlings. Allow fallow fields to dry to suppress pathogens. Consult your local DA technician for specific advice.",
        "details": "Bacterial Leaf Blight (BLB) is caused by Xanthomonas oryzae pv. oryzae. Symptoms include water-soaked stripes that expand to large grayish-white lesions with wavy light brown margins. In susceptible varieties, yield loss can reach 70%. At booting stage, causes poor quality grains and high broken grain percentage. Spread by ooze droplets, strong winds, heavy rains, contaminated stubbles, and infected weeds in tropical and temperate lowlands (25-34°C, >70% relative humidity).",
    },
    "Blast": {
        "recommendation": "Use resistant varieties such as PSB Rc18 or IRRI-derived lines. Avoid high nitrogen application and split nitrogen fertilizer application. Maintain proper irrigation (avoid water stress) and field sanitation (remove stubbles, debris). Adjust planting calendar to avoid peak infection period. Apply calcium silicate to strengthen cell walls. Apply fungicides: triazoles and strobilurins at first signs, especially during tillering/booting. Burn or plow under infected straw. Consult local DA.",
        "details": "Rice Blast is caused by Magnaporthe oryzae fungus. Symptoms include leaf blast (small, spindle-shaped spots with brown border and gray center), node blast (nodes become black and break easily), and panicle blast (base becomes black, grains remain unfilled). Occurs from seedling to reproductive stage under cool temperature, high relative humidity, continuous rain, and large day-night temperature differences. Can reduce yield significantly by reducing leaf area for grain filling.",
    },
    "Brownspot": {
        "recommendation": "Plant resistant varieties and ensure balanced fertilization, especially sufficient potassium (use complete or potash fertilizers). Conduct soil test and correct deficiencies. Improve soil health and maintain proper field drainage. Use clean, healthy seeds and practice field sanitation (remove stubbles, weeds). Split nitrogen fertilizer application. Apply calcium silicate & potassium fertilizers. If severe, apply fungicides (triazoles, strobilurins) following BPI registration and safety rules. Consult local DA.",
        "details": "Brown Spot is caused by Cochliobolus miyabeanus. Symptoms include small, circular to oval spots with gray centers on leaves, black spots on glumes with dark brown velvety fungal spores, and discolored, shriveled grains. Historically caused the 1943 Great Bengal Famine. Yield loss ranges (5-45%) and causes seedling blight mortality of 10-58%. Fungus survives 2-4 years in infected tissues under high humidity (86-100%), 16-36°C, in nutrient-deficient or toxic soils.",
    },
    "Tungro": {
        "recommendation": "Plant tungro- or GLH-resistant varieties like IR64 if available. Practice synchronous planting (sabay-sabay) with neighbors and maintain a fallow period (at least 1 month) between crops to break vector cycles. Remove (rogue) and destroy infected plants immediately. Plow infected stubbles after harvest to destroy inoculum and GLH breeding sites. Maintain balanced nutrient management. Note: Chemical control of GLH is not effective as they move quickly and spread tungro even with short feeding. Consult local DA.",
        "details": "Rice Tungro Disease is caused by a combination of two viruses (Rice Tungro Bacilliform Virus + Rice Tungro Spherical Virus) transmitted by green leafhoppers (GLH). Symptoms include young leaves becoming mottled, older leaves turning yellow to orange, stunted growth, reduced tillers, delayed flowering, small panicles, and high sterility. Highly destructive in South & Southeast Asia with yield loss up to (100 Percent) in susceptible varieties infected early. Most damaging during vegetative/tillering stage.",
    },
}
# New Hugging Face recommendation endpoint
@app.post("/recommend/")
async def recommend(disease: str):
    """Generate farming recommendations for a rice disease using the Hugging Face model.

    Args:
        disease: Disease label (query parameter). Known class names such as
            ``"Blast"`` are expanded to their display name (e.g.
            ``"Rice Blast Disease"``) to make the prompt more specific; any
            other text is used verbatim.

    Returns:
        ``{"recommendation": <generated text>}``, or ``{"error": ...}`` when the
        transformers pipeline could not be created at import time.
    """
    if recommender is None:
        return {"error": "Text generation pipeline not available."}
    disease_label = disease_display_names.get(disease, disease)
    prompt = f"Give 2-3 short farming recommendations for managing {disease_label} in rice plants."
    # Model inference is synchronous; running it in the threadpool keeps it
    # from blocking the event loop (and therefore every other request).
    response = await run_in_threadpool(
        recommender, prompt, max_length=100, num_return_sequences=1
    )
    return {"recommendation": response[0]["generated_text"]}
def load_model():
    """Load the PyTorch models (disease classifier and rice-leaf validator).

    Weights are read from ``model_path`` and ``leaf_validator_model_path`` onto
    ``device``, and both networks are switched to eval mode. Progress is
    published through module globals (``model_loading``, ``model_loaded``,
    ``validator_loaded``, ``last_error``) that the status endpoints report.

    Each network is built and filled in a local variable first. It is assigned
    to its global only after its weights have loaded, so a failed (re)load never
    replaces a working model with an untrained one.

    Returns:
        True if both models loaded. On failure, returns False and stores the
        error message in ``last_error``; exceptions are logged, not raised.
    """
    global model, leaf_validator_model, model_loaded, validator_loaded, model_loading, last_error
    model_loading = True
    try:
        logging.info(f"Loading disease model from: {model_path}")
        # Output size follows CLASS_NAMES, which maps output indices to labels.
        disease_model = RiceDiseaseCNN(num_classes=len(CLASS_NAMES)).to(device)
        # NOTE: torch.load unpickles the checkpoint; only ever load trusted files.
        disease_model.load_state_dict(torch.load(model_path, map_location=device))
        disease_model.eval()  # inference mode
        model = disease_model
        model_loaded = True

        logging.info(f"Loading leaf validator model from: {leaf_validator_model_path}")
        validator = RiceLeafValidator().to(device)
        validator.load_state_dict(torch.load(leaf_validator_model_path, map_location=device))
        validator.eval()
        leaf_validator_model = validator
        validator_loaded = True
    except Exception as e:
        last_error = str(e)
        logging.exception(f"Failed to load models: {e}")
        return False
    finally:
        # Cleared on every exit path, success or failure.
        model_loading = False

    last_error = None  # drop any error left over from an earlier failed attempt
    logging.info("Models loaded successfully!")
    return True
@app.get("/")
async def root():
    """Confirm the API is up and report the current model-loading state."""
    payload = {"message": "Rice Disease Detection API is running!"}
    payload.update(
        model_loaded=model_loaded,
        validator_loaded=validator_loaded,
        model_loading=model_loading,
        error=last_error,
    )
    return payload
@app.get("/health")
async def health():
    """Health-check endpoint polled by the PHP status checker.

    ``model_status`` is "loaded", "loading" or "not_loaded", and
    ``validator_status`` is "loaded" or "not_loaded". The raw flags, the last
    load error and the compute device are included alongside them.
    """
    if model_loaded:
        model_status = "loaded"
    elif model_loading:
        model_status = "loading"
    else:
        model_status = "not_loaded"

    if validator_loaded:
        validator_status = "loaded"
    else:
        validator_status = "not_loaded"

    return dict(
        status="running",
        model_status=model_status,
        validator_status=validator_status,
        model_loaded=model_loaded,
        model_loading=model_loading,
        last_error=last_error,
        device=str(device),
    )
@app.get("/status")
async def status():
    """Report the raw model-loading flags, the last load error and the device."""
    return dict(
        model_loaded=model_loaded,
        validator_loaded=validator_loaded,
        model_loading=model_loading,
        error=last_error,
        device=str(device),
    )
@app.get("/test-ollama")
async def test_ollama():
    """Test if Ollama is working.

    Sends a one-shot, non-streaming prompt to the local Ollama server at
    ``http://localhost:11434`` using the ``mistral:7b-instruct`` model.

    Returns:
        A dict with ``success`` (bool) and ``message``, plus either ``response``
        (the generated text) or ``error`` (failure details). Failures are
        reported in the payload rather than raised.
    """
    try:
        # requests is blocking; running it in the threadpool keeps the (up to
        # 10 s) wait from stalling the event loop and every other request.
        response = await run_in_threadpool(
            requests.post,
            "http://localhost:11434/api/generate",
            json={
                "model": "mistral:7b-instruct",
                "prompt": "Give a short greeting to verify the API is working",
                "stream": False,
            },
            timeout=10,
        )
    except Exception as e:
        return {
            "success": False,
            "message": "Failed to connect to Ollama",
            "error": str(e),
        }

    if response.status_code != 200:
        return {
            "success": False,
            "message": f"Ollama API returned status code {response.status_code}",
            "error": response.text,
        }

    try:
        generated = response.json().get("response", "")
    except (ValueError, AttributeError) as e:
        # Connected fine, but the body was not the expected JSON object;
        # do not misreport this as a connection failure.
        return {
            "success": False,
            "message": "Ollama returned an unexpected response body",
            "error": str(e),
        }

    return {
        "success": True,
        "message": "Ollama is working correctly",
        "response": generated,
    }
@app.post("/validate-leaf/")
async def validate_rice_leaf_image(file: UploadFile = File(...)):
    """Check if an image contains a rice leaf.

    ``is_rice_leaf`` takes a file path, so the upload is written to a uniquely
    named temporary file in ``/tmp`` (Hugging Face Spaces requirement). The file
    is always removed afterwards.

    Returns:
        ``{"is_rice_leaf": bool, "confidence": "<percent>%"}``.

    Raises:
        HTTPException: 503 while models are loading; 500 if the models fail to
            load or validation fails.
    """
    if not validator_loaded:
        if model_loading:
            raise HTTPException(status_code=503, detail="Models are still loading. Please try again later.")
        success = load_model()
        if not success:
            raise HTTPException(status_code=500, detail=f"Failed to load models: {last_error}")
    try:
        contents = await file.read()
        # mkstemp creates a unique file per request. Concurrent uploads with the
        # same filename can no longer overwrite or delete each other's file, and
        # the client-supplied name never becomes part of the path; only a plain
        # alphanumeric extension is kept.
        name_ext = os.path.splitext(file.filename or "")[1]
        suffix = name_ext if name_ext[1:].isalnum() else ""
        fd, temp_file_path = tempfile.mkstemp(prefix="temp_", suffix=suffix, dir="/tmp")
        try:
            with os.fdopen(fd, "wb") as buffer:
                buffer.write(contents)
            is_rice, confidence = is_rice_leaf(temp_file_path, leaf_validator_model, device)
        finally:
            # Remove the temporary file even when validation raises.
            os.remove(temp_file_path)
        return {
            "is_rice_leaf": is_rice,
            "confidence": f"{confidence * 100:.2f}%"
        }
    except Exception as e:
        logging.error(f"Error during leaf validation: {e}")
        raise HTTPException(status_code=500, detail=f"Validation failed: {str(e)}")
@app.post("/predict/")
async def predict_disease(file: UploadFile = File(...), use_ai_recommendation: bool = True):
    """Predict rice disease from an uploaded image.

    Steps:
    1. Make sure both models are loaded.
    2. Check with the leaf validator that the image is a rice leaf.
    3. Classify the disease.
    4. Attach a recommendation: AI-generated when requested, otherwise (or on
       failure) taken from the static RECOMMENDATIONS table.

    Args:
        file: Uploaded image.
        use_ai_recommendation: Try ``get_ollama_recommendation`` first; on any
            error, or an empty result, fall back to the static table.

    Returns:
        If the image is not a rice leaf: ``is_rice_leaf`` False, the leaf
        ``confidence`` and a ``message``. Otherwise the payload expected by the
        PHP client:
        - all class ``predictions``, sorted by descending confidence;
        - the top ``disease`` and its ``confidence``;
        - ``recommendation``, ``details`` and ``recommendation_source``;
        - ``inference_time_seconds``.

    Raises:
        HTTPException: 503 while models are loading; 500 on load or prediction failure.
    """
    # Check if models are loaded
    if not model_loaded or not validator_loaded:
        if model_loading:
            raise HTTPException(status_code=503, detail="Models are still loading. Please try again later.")
        # Try loading the models if they're not already loading
        if not load_model():
            raise HTTPException(status_code=500, detail=f"Failed to load models: {last_error}")
    try:
        start_time = time.perf_counter()
        contents = await file.read()
        logging.info(f"Image size: {len(contents)} bytes")

        # The leaf validator takes a file path, so write the upload to a unique
        # temp file in /tmp (Hugging Face Spaces requirement). A per-request name
        # avoids clashes between concurrent uploads with the same filename and
        # keeps client-supplied names out of the path (only a plain extension is kept).
        name_ext = os.path.splitext(file.filename or "")[1]
        suffix = name_ext if name_ext[1:].isalnum() else ""
        fd, temp_file_path = tempfile.mkstemp(prefix="temp_", suffix=suffix, dir="/tmp")
        try:
            with os.fdopen(fd, "wb") as buffer:
                buffer.write(contents)
            is_rice, rice_confidence = is_rice_leaf(temp_file_path, leaf_validator_model, device)
        finally:
            # Remove the temporary file even when validation raises.
            os.remove(temp_file_path)

        if not is_rice:
            logging.info(f"Image not recognized as rice leaf: {rice_confidence * 100:.2f}% confidence")
            return {
                "is_rice_leaf": False,
                "confidence": f"{rice_confidence * 100:.2f}%",
                "message": "The uploaded image does not appear to be a rice leaf. Please upload a clear image of a rice plant leaf."
            }

        # Process image for disease detection and run the classifier
        input_tensor = process_image(contents).to(device)
        with torch.no_grad():
            outputs = model(input_tensor)
            probabilities = F.softmax(outputs, dim=1)[0]

        # topk over every class == all classes sorted by descending probability
        top_probs, top_classes = torch.topk(probabilities, len(CLASS_NAMES))
        predictions = [
            {
                "label": CLASS_NAMES[class_idx.item()],
                "confidence": f"{prob.item() * 100:.2f}%",
            }
            for prob, class_idx in zip(top_probs, top_classes)
        ]

        # Primary prediction (highest confidence)
        disease_name = CLASS_NAMES[top_classes[0].item()]
        confidence = float(top_probs[0].item()) * 100

        # Get recommendation - either AI or static
        recommendation_data = None
        if use_ai_recommendation:
            try:
                # NOTE(review): get_ollama_recommendation is neither defined nor
                # imported in the code shown for this module. If it is missing,
                # the NameError is caught here and the static table is used;
                # confirm it exists.
                recommendation_data = get_ollama_recommendation(disease_name, confidence)
            except Exception as e:
                logging.error(f"Failed to get AI recommendation: {e}")
        if not recommendation_data:
            # Static fallback; also covers a None/empty AI result.
            recommendation_data = RECOMMENDATIONS.get(disease_name, {})

        # Log prediction time
        processing_time = time.perf_counter() - start_time
        logging.info(f"Prediction completed in {processing_time:.3f} seconds")

        # Return in format expected by PHP code
        return {
            "is_rice_leaf": True,
            "leaf_confidence": f"{rice_confidence * 100:.2f}%",
            "predictions": predictions,
            "disease": disease_name,
            "confidence": f"{confidence:.2f}%",
            "recommendation": recommendation_data.get("recommendation", ""),
            "details": recommendation_data.get("details", ""),
            "recommendation_source": recommendation_data.get("source", "static"),
            "inference_time_seconds": round(processing_time, 3)
        }
    except Exception as e:
        logging.error(f"Error during prediction: {e}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
if __name__ == "__main__":
    import uvicorn

    # Production-style launch: listen on every interface at port 8000 with
    # auto-reload turned off.
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": False}
    uvicorn.run("app:app", **server_options)