Spaces:
Running
Running
| import os | |
| import numpy as np | |
| import tensorflow as tf | |
| from fastapi import FastAPI, File, UploadFile | |
| from fastapi.responses import JSONResponse | |
| from io import BytesIO | |
| from PIL import Image | |
| from tensorflow.keras.preprocessing.image import img_to_array | |
| import tensorflow_addons as tfa | |
| import uvicorn | |
| import requests | |
# ASGI application object; uvicorn serves it from the __main__ guard below.
app = FastAPI()

# Extra deserialization hooks passed to tf.keras.models.load_model: the saved
# model refers to the serialized name "Addons>CohenKappa" (presumably the
# metric it was compiled with), which must resolve to the TF Addons class.
custom_objects = {}
custom_objects["Addons>CohenKappa"] = tfa.metrics.CohenKappa
# Saved Keras model, resolved relative to the current working directory.
model_path = "model.h5"

# Maps each class index (the argmax position over the model's output vector)
# to a human-readable label. A name's position in the list below is its index.
# NOTE(review): the order must match the label order used at training time;
# it appears to be sorted by name, as directory-based loaders do — confirm.
class_labels = dict(
    enumerate(
        [
            # 0-4
            "Apple___Apple_scab",
            "Apple___Black_rot",
            "Apple___Cedar_apple_rust",
            "Apple___healthy",
            "Background_without_leaves",
            # 5-11
            "Blueberry___healthy",
            "Cherry___Powdery_mildew",
            "Cherry___healthy",
            "Corn___Cercospora_leaf_spot Gray_leaf_spot",
            "Corn___Common_rust_",
            "Corn___Northern_Leaf_Blight",
            "Corn___healthy",
            # 12-16
            "Grape___Black_rot",
            "Grape___Esca_(Black_Measles)",
            "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
            "Grape___healthy",
            "Orange___Haunglongbing_(Citrus_greening)",
            # 17-28
            "Peach___Bacterial_spot",
            "Peach___healthy",
            "Pepper,_bell___Bacterial_spot",
            "Pepper,_bell___healthy",
            "Potato___Early_blight",
            "Potato___Late_blight",
            "Potato___healthy",
            "Raspberry___healthy",
            "Soybean___healthy",
            "Squash___Powdery_mildew",
            "Strawberry___Leaf_scorch",
            "Strawberry___healthy",
            # 29-38
            "Tomato___Bacterial_spot",
            "Tomato___Early_blight",
            "Tomato___Late_blight",
            "Tomato___Leaf_Mold",
            "Tomato___Septoria_leaf_spot",
            "Tomato___Spider_mites Two-spotted_spider_mite",
            "Tomato___Target_Spot",
            "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
            "Tomato___Tomato_mosaic_virus",
            "Tomato___healthy",
        ]
    )
)
# Load the model at import time if the file is present. ``model`` is always
# bound (to ``None`` when the file is missing) so code that uses it can test
# for a missing model instead of failing on an undefined global name.
model = None
if os.path.exists(model_path):
    model = tf.keras.models.load_model(model_path, custom_objects=custom_objects)
    print("Model loaded successfully.")
else:
    print(f"Model file not found at {model_path}. Please upload the model.")
# Function to preprocess input image
def preprocess_image(image_data, img_size=224):
    """Decode raw image bytes into a normalized single-image batch.

    The image is converted to RGB first, so palette, grayscale and
    alpha-channel uploads all yield exactly three channels. It is then resized
    to ``img_size`` x ``img_size``, and pixel values are scaled from [0, 255]
    to [0, 1].

    Args:
        image_data: Encoded image bytes in any format PIL can open.
        img_size: Side length, in pixels, of the square output image.

    Returns:
        A float array of shape (1, img_size, img_size, 3).

    Raises:
        PIL.UnidentifiedImageError: If ``image_data`` is not a readable image.
    """
    # ``with`` releases the decoder's resources; convert() returns a new,
    # fully loaded image, so ``rgb`` stays usable after the source is closed.
    with Image.open(BytesIO(image_data)) as img:
        # Without the RGB conversion, RGBA/LA inputs give 4 or 2 channels and
        # L inputs give 1, which would not match a 3-channel model input
        # (assumed here — confirm against the model's input shape).
        rgb = img.convert("RGB").resize((img_size, img_size))
    img_array = img_to_array(rgb) / 255.0
    # Add the leading batch dimension expected by model.predict.
    return np.expand_dims(img_array, axis=0)
# Predict function
def predict_image(image_data, timeout=10.0):
    """Classify an image and fetch backend details for the predicted class.

    Args:
        image_data: Encoded image bytes, decoded by ``preprocess_image``.
        timeout: Seconds to wait for the external backend (connect and read)
            before giving up. ``requests`` waits indefinitely without one.

    Returns:
        The decoded JSON body from the external backend for the predicted
        class label when it answers with HTTP 200. Otherwise a dict with a
        single ``"error"`` message.

    Raises:
        RuntimeError: If no model was loaded at startup.
        Any exception raised while decoding the image or running the model.
    """
    if model is None:
        raise RuntimeError(f"Model is not loaded; expected a model file at {model_path}.")

    preprocessed_image = preprocess_image(image_data)
    predictions = model.predict(preprocessed_image)
    # Highest-scoring class for the single image in the batch; int() turns the
    # NumPy integer into a plain int for the dict lookup below.
    class_idx = int(np.argmax(predictions, axis=1)[0])
    class_label = class_labels.get(class_idx, "Unknown")

    # Fetch additional data for the predicted class from the external API.
    # Failures become an error payload instead of failing the whole request.
    try:
        response = requests.get(
            f"https://navpan2-sarva-ai-back.hf.space/kotlinback/{class_label}",
            timeout=timeout,
        )
        if response.status_code == 200:
            external_data = response.json()
        else:
            external_data = {"error": "Failed to fetch external data"}
    except (requests.RequestException, ValueError) as e:
        # RequestException covers connection errors and timeouts. ValueError
        # covers a 200 response whose body is not valid JSON.
        external_data = {"error": str(e)}
    return external_data
# Route for health check
# Registered on ``app`` so FastAPI actually serves it.
# NOTE(review): the "/" path is an assumption — confirm it matches what
# clients or the hosting platform probe.
@app.get("/")
async def api_health_check():
    """Report that the service process is up and answering requests.

    Returns:
        A JSON response ``{"status": "Service is running"}``.
    """
    return JSONResponse(content={"status": "Service is running"})
# Route for prediction using image via API
# Registered on ``app`` so FastAPI actually serves it.
# NOTE(review): the "/predict" path is an assumption — confirm it matches
# what clients call.
@app.post("/predict")
async def api_predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image and return the backend payload for it.

    Args:
        file: The uploaded image, sent as the multipart form field ``file``.

    Returns:
        ``{"prediction": <result of predict_image>}`` on success, or
        ``{"error": "<message>"}`` if anything fails. Both are sent with
        HTTP 200.
    """
    # Local import keeps this block self-contained. FastAPI re-exports this
    # helper from Starlette.
    from fastapi.concurrency import run_in_threadpool

    try:
        image_data = await file.read()
        # predict_image blocks on model inference and a synchronous HTTP call.
        # Running it in the thread pool keeps the event loop free to serve
        # other requests (e.g. the health check) in the meantime.
        prediction = await run_in_threadpool(predict_image, image_data)
        return JSONResponse(content={"prediction": prediction})
    except Exception as e:
        # Top-level request boundary: report any failure as a JSON error.
        # NOTE(review): errors still use HTTP 200 to stay compatible with
        # existing clients; consider a 4xx/5xx status.
        return JSONResponse(content={"error": str(e)})
# Run the FastAPI app
if __name__ == "__main__":
    # Only when executed as a script: listen on all interfaces, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)