Spaces:
Running
Running
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from fastapi.responses import JSONResponse | |
| import tensorflow as tf | |
| import numpy as np | |
| import os | |
| import requests | |
| from tensorflow.keras.models import load_model | |
| from tensorflow.keras.preprocessing import image | |
| from tensorflow.keras.layers import Layer, Conv2D, Softmax, Concatenate | |
| import shutil | |
| import uvicorn | |
# ASGI application object served by uvicorn.
app = FastAPI()

# Directory where models are stored
MODEL_DIRECTORY = "dsanet_models"

# Temporary directory for uploaded files; overridable through the TMP_DIR
# environment variable.
TMP_DIR = os.environ.get("TMP_DIR", "/app/temp")

# Create the temp directory at import time so request handlers can write to it.
os.makedirs(TMP_DIR, exist_ok=True)
# Plant disease class names, keyed by plant.
# The position of each label in its list is used as the class index when the
# model output is decoded with argmax, so the order must not be changed.
# Plants with exactly one label skip model inference in predict_plant_disease.
plant_disease_dict = {
    "Rice": [
        "Blight",
        "Brown_Spots",
    ],
    "Tomato": [
        "Tomato___Bacterial_spot",
        "Tomato___Early_blight",
        "Tomato___Late_blight",
        "Tomato___Leaf_Mold",
        "Tomato___Septoria_leaf_spot",
        "Tomato___Spider_mites Two-spotted_spider_mite",
        "Tomato___Target_Spot",
        "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
        "Tomato___Tomato_mosaic_virus",
        "Tomato___healthy",
    ],
    "Strawberry": [
        "Strawberry___Leaf_scorch",
        "Strawberry___healthy",
    ],
    "Potato": [
        "Potato___Early_blight",
        "Potato___Late_blight",
        "Potato___healthy",
    ],
    "Pepperbell": [
        "Pepper,_bell___Bacterial_spot",
        "Pepper,_bell___healthy",
    ],
    "Peach": [
        "Peach___Bacterial_spot",
        "Peach___healthy",
    ],
    "Grape": [
        "Grape___Black_rot",
        "Grape___Esca_(Black_Measles)",
        "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
        "Grape___healthy",
    ],
    "Apple": [
        "Apple___Apple_scab",
        "Apple___Black_rot",
        "Apple___Cedar_apple_rust",
        "Apple___healthy",
    ],
    "Cherry": [
        "Cherry___Powdery_mildew",
        "Cherry___healthy",
    ],
    "Corn": [
        "Corn___Cercospora_leaf_spot Gray_leaf_spot",
        "Corn___Common_rust",
        "Corn___Northern_Leaf_Blight",
        "Corn___healthy",
    ],
    # NOTE(review): "okk" looks like a placeholder label — confirm the intended value.
    "Blueberry": ["okk"],
}
# Custom Self-Attention Layer
class SelfAttention(Layer):
    """Keras layer that appends an attention-weighted projection to its input.

    Query, key and value maps are produced by three bias-free 1x1
    convolutions, each with ``input_channels // reduction_ratio`` filters.
    The attended values are concatenated to the original input along the
    last axis.

    Registered as a custom object when loading saved models, so the
    constructor argument is round-tripped through ``get_config``.
    """

    def __init__(self, reduction_ratio=2, **kwargs):
        """Store the channel reduction factor; other kwargs go to ``Layer``."""
        super().__init__(**kwargs)
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        """Create the query/key/value 1x1 convolutions for ``input_shape``."""
        # The last entry of input_shape is taken as the channel count.
        reduced_channels = input_shape[-1] // self.reduction_ratio
        self.query_conv = Conv2D(reduced_channels, kernel_size=1, use_bias=False)
        self.key_conv = Conv2D(reduced_channels, kernel_size=1, use_bias=False)
        self.value_conv = Conv2D(reduced_channels, kernel_size=1, use_bias=False)
        super().build(input_shape)

    def call(self, inputs):
        """Return ``inputs`` concatenated with the attended value map."""
        q = self.query_conv(inputs)
        k = self.key_conv(inputs)
        v = self.value_conv(inputs)

        # Scores are query x key^T over the last two axes, normalised with a
        # softmax along axis 1.
        scores = tf.matmul(q, k, transpose_b=True)
        weights = Softmax(axis=1)(scores)

        # Weight the value map by the attention scores.
        attended = tf.matmul(weights, v)
        return Concatenate(axis=-1)([inputs, attended])

    def get_config(self):
        """Return the base layer config extended with ``reduction_ratio``."""
        return {**super().get_config(), "reduction_ratio": self.reduction_ratio}
# **Load all models into memory at startup**
# Maps plant name -> loaded Keras model; filled in by load_all_models().
loaded_models = {}


def load_all_models():
    """Load every available per-plant model into ``loaded_models``.

    For each plant in ``plant_disease_dict`` the file
    ``<MODEL_DIRECTORY>/model_<plant>.keras`` is looked up. Missing files and
    load failures are reported with ``print`` and skipped, so the function
    itself does not raise for a bad or absent model.
    """
    for plant_name in plant_disease_dict:
        model_path = os.path.join(MODEL_DIRECTORY, f"model_{plant_name}.keras")

        if not os.path.isfile(model_path):
            print(f"⚠ Warning: Model file '{model_path}' not found!")
            continue

        # The Rice model is loaded without custom objects; every other model
        # is loaded with the SelfAttention layer registered.
        if plant_name == "Rice":
            load_options = {}
        else:
            load_options = {"custom_objects": {"SelfAttention": SelfAttention}}

        try:
            loaded_models[plant_name] = load_model(model_path, **load_options)
            print(f"✅ Model for {plant_name} loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading model '{plant_name}': {e}")


# Load models at startup
load_all_models()
async def api_health_check():
    """Return a fixed JSON payload reporting that the service is running.

    NOTE(review): no route decorator is visible on this function — confirm
    that it is registered on ``app`` somewhere.
    """
    payload = {"status": "Service is running"}
    return JSONResponse(content=payload)
# Base URL of the external service that supplies disease details and remedies.
_DISEASE_INFO_URL = "https://navpan2-sarva-ai-back.hf.space/kotlinback"

# Seconds to wait on the external service before giving up, so that a stalled
# upstream cannot hang a request indefinitely.
_DISEASE_INFO_TIMEOUT = 10


def _fetch_disease_info(disease_label, lang):
    """Fetch details for ``disease_label`` from the external API, never raising.

    Returns the decoded JSON object on success, otherwise a dict with a single
    ``"error"`` key.
    """
    try:
        response = requests.get(
            f"{_DISEASE_INFO_URL}/{disease_label}",
            # Passed via params so a caller-supplied value is URL-encoded
            # instead of being spliced raw into the query string.
            params={"lang": lang},
            timeout=_DISEASE_INFO_TIMEOUT,
        )
        if response.status_code != 200:
            return {"error": "Failed to fetch external data"}
        data = response.json()
    except Exception as e:
        # Deliberately best-effort: the prediction is still returned with
        # fallback values when the external lookup fails.
        return {"error": str(e)}

    # The response builder expects a JSON object; anything else counts as a
    # failed lookup.
    if not isinstance(data, dict):
        return {"error": "Failed to fetch external data"}
    return data


def _build_disease_response(plant_name, disease_label, external_data):
    """Build the JSON response, falling back to defaults for missing fields."""
    disease_desc = external_data.get("diseaseDesc")
    if not isinstance(disease_desc, dict):
        disease_desc = {}

    remedies = external_data.get("diseaseRemedyList")
    if not isinstance(remedies, list):
        remedies = []

    return JSONResponse(content={
        "plantName": external_data.get("plantName", plant_name),
        "botanicalName": external_data.get("botanicalName", "Unknown"),
        "diseaseDesc": {
            "diseaseName": disease_desc.get("diseaseName", disease_label),
            "symptoms": disease_desc.get("symptoms", "Not Available"),
            "diseaseCauses": disease_desc.get("diseaseCauses", "Not Available"),
        },
        "diseaseRemedyList": [
            {
                "title": remedy.get("title", "Unknown"),
                "diseaseRemedyShortDesc": remedy.get("diseaseRemedyShortDesc", "Not Available"),
                "diseaseRemedy": remedy.get("diseaseRemedy", "Not Available"),
            }
            for remedy in remedies
            if isinstance(remedy, dict)
        ],
    })


async def predict_plant_disease(plant_name: str, file: UploadFile = File(...), lang: str = "en"):
    """
    API endpoint to predict plant disease from an uploaded image.

    Args:
        plant_name (str): The plant type (must match a key in `plant_disease_dict`).
        file (UploadFile): The image file uploaded by the user.
        lang (str): Language code forwarded to the external details API.

    Returns:
        JSON response with the predicted class and additional details from an external API.

    Raises:
        HTTPException: 400 when no model is loaded for `plant_name`.

    NOTE(review): no route decorator is visible on this function — confirm
    that it is registered on ``app`` somewhere.
    NOTE(review): `requests.get` and `model.predict` are synchronous calls
    inside an `async def` handler — confirm that blocking here is acceptable.
    """
    import tempfile  # local import: only this handler needs it

    class_names = plant_disease_dict.get(plant_name, [])

    # Plants with a single known class need no inference: return that class.
    # `lang` is forwarded here as well, matching the model-based path below.
    if len(class_names) == 1:
        single_disease = class_names[0]
        return _build_disease_response(
            plant_name, single_disease, _fetch_disease_info(single_disease, lang)
        )

    # Ensure the plant name is valid and its model was loaded at startup.
    if plant_name not in loaded_models:
        raise HTTPException(status_code=400, detail=f"Invalid plant name or model not loaded: {plant_name}")
    model = loaded_models[plant_name]

    # Save the upload under a server-generated unique name. The client-supplied
    # filename is untrusted (it may contain path separators or be missing) and
    # identical names from concurrent requests would overwrite each other, so
    # only a plain alphanumeric extension is kept from it.
    suffix = os.path.splitext(os.path.basename(file.filename or ""))[1]
    if not suffix[1:].isalnum():
        suffix = ""
    tmp_file = tempfile.NamedTemporaryFile(dir=TMP_DIR, suffix=suffix, delete=False)
    temp_path = tmp_file.name

    try:
        with tmp_file:
            shutil.copyfileobj(file.file, tmp_file)

        # Load and preprocess the image
        img = image.load_img(temp_path, target_size=(224, 224))
        img_array = image.img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)  # Add a leading batch axis
        img_array = img_array / 255.0  # Normalize

        # Make prediction; the argmax position indexes into the class list.
        prediction = model.predict(img_array)
        class_label = class_names[int(np.argmax(prediction))]

        return _build_disease_response(
            plant_name, class_label, _fetch_disease_info(class_label, lang)
        )
    finally:
        # Clean up temporary file
        os.remove(temp_path)
# Script entry point: serve the app with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)