Spaces:
Sleeping
Sleeping
| import io | |
| import uvicorn | |
| import numpy as np | |
| from fastapi import FastAPI, File, UploadFile | |
| from fastapi.responses import JSONResponse | |
| from PIL import Image | |
| import tensorflow as tf | |
| from task_1_model import load_task_1_model | |
| from task_2_model import load_task_2_model | |
| from task_3_model import load_task_3_model | |
# Load every task model exactly once, at import time, so that each request
# handled below reuses the same in-memory instances.
# Task 1 -> disease classifier, task 2 -> variety classifier, task 3 -> age model.
disease_model = load_task_1_model("task_1_model_inception.keras")
variety_model = load_task_2_model("task_2_model_inception.keras")
age_model = load_task_3_model("task_3_ensemble_model_og_data.keras")
# Human-readable labels for the classifier outputs. Each list must stay in the
# same order that was used during training: position i is the label for
# output index i of the corresponding model.
DISEASE_CLASSES = [
    'Bacterial leaf blight',
    'Bacterial leaf streak',
    'Bacterial panicle blight',
    'Blast',
    'Brown spot',
    'Dead heart',
    'Downy mildew',
    'Hispa',
    'Normal',
    'Tungro',
]
VARIETY_CLASSES = [
    'ADT45',
    'AndraPonni',
    'AtchayaPonni',
    'IR20',
    'KarnatakaPonni',
    'Onthanel',
    'Ponni',
    'RR',
    'Surya',
    'Zonal',
]

# The ASGI application object; it is passed to uvicorn.run in the __main__ guard.
app = FastAPI()
def preprocess_image(contents: bytes, target_size=(256, 256)):
    """Decode raw image bytes into a one-image batch for the models.

    The bytes are decoded with Pillow, forced to 3-channel RGB, and resized to
    ``target_size``, given as ``(width, height)`` per Pillow's ``resize``.
    Pixel values are not rescaled.

    Args:
        contents: Encoded image file contents, e.g. a PNG or JPEG upload.
        target_size: ``(width, height)`` the image is resized to.

    Returns:
        A NumPy array with a leading batch axis of length 1, shaped
        ``(1, height, width, 3)``.

    Raises:
        PIL.UnidentifiedImageError / OSError: if ``contents`` cannot be
        decoded as an image.
    """
    rgb = Image.open(io.BytesIO(contents)).convert("RGB")
    pixels = np.asarray(rgb.resize(target_size))
    # Prepend the batch dimension the models' predict() expects.
    return pixels[np.newaxis, ...]
def _top_prediction(preds, class_names):
    """Return the highest-scoring label and its score for one model output row.

    ``preds`` must be indexed in the same order as ``class_names``. The score
    is converted to a plain ``float`` so it is JSON-serialisable.
    """
    idx = int(np.argmax(preds))
    return {"class": class_names[idx], "confidence": float(preds[idx])}


# Register the handler on the app; without a route decorator FastAPI never
# exposes it. TODO confirm "/predict" matches the path clients call.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run the variety, age and disease models on an uploaded image.

    Responds with JSON of the form::

        {"variety": {"class": str, "confidence": float},
         "disease": {"class": str, "confidence": float},
         "age": int}

    If the upload cannot be decoded as an image, responds with HTTP 400 and
    ``{"error": ...}`` instead of letting the decoder exception become a 500.
    """
    contents = await file.read()
    try:
        img_batch = preprocess_image(contents)
    except OSError as exc:
        # PIL raises UnidentifiedImageError (an OSError subclass) for data it
        # cannot identify and OSError for truncated files. Both are client errors.
        return JSONResponse(
            status_code=400,
            content={"error": f"Could not decode uploaded file as an image: {exc}"},
        )

    # NOTE(review): Model.predict is synchronous, so these calls block the
    # event loop during inference; consider offloading if concurrency matters.
    variety_preds = variety_model.predict(img_batch)[0]
    age_preds = age_model.predict(img_batch)[0]
    disease_preds = disease_model.predict(img_batch)[0]

    # Assumes the age model yields a single regressed value per image
    # (TODO confirm). .item() extracts it without the NumPy >= 1.25
    # deprecation of int() on an ndim>0 array, and it raises if the output
    # unexpectedly holds more than one value.
    age = int(np.round(np.asarray(age_preds).item()))

    result = {
        "variety": _top_prediction(variety_preds, VARIETY_CLASSES),
        "disease": _top_prediction(disease_preds, DISEASE_CLASSES),
        "age": age,
    }
    return JSONResponse(content=result)
if __name__ == "__main__":
    import os

    # Serve the app directly when run as a script. HOST and PORT environment
    # variables may override the bind address (e.g. when a hosting platform
    # assigns the port); the defaults match the previous hard-coded values.
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
    )