# rice-scanner / app.py
# Author: NickNam2710
# Commit e4e0607: "add server"
import io
import threading

import numpy as np
import tensorflow as tf
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image

from task_1_model import load_task_1_model
from task_2_model import load_task_2_model
from task_3_model import load_task_3_model
# Load every model once, at import time; the /predict handler reuses these
# module-level objects for all requests.
disease_model, variety_model, age_model = (
    load_task_1_model("task_1_model_inception.keras"),
    load_task_2_model("task_2_model_inception.keras"),
    load_task_3_model("task_3_ensemble_model_og_data.keras"),
)
# Human-readable labels, listed in the same index order used during training:
# entry i names output column i of the corresponding model's prediction.
DISEASE_CLASSES = [
    'Bacterial leaf blight',
    'Bacterial leaf streak',
    'Bacterial panicle blight',
    'Blast',
    'Brown spot',
    'Dead heart',
    'Downy mildew',
    'Hispa',
    'Normal',
    'Tungro',
]
VARIETY_CLASSES = [
    'ADT45',
    'AndraPonni',
    'AtchayaPonni',
    'IR20',
    'KarnatakaPonni',
    'Onthanel',
    'Ponni',
    'RR',
    'Surya',
    'Zonal',
]
# The ASGI application object; the /predict route is registered on it via
# the @app.post decorator, and the __main__ guard serves it with uvicorn.
app = FastAPI()
def preprocess_image(contents: bytes, target_size=(256, 256)):
    """Decode raw image bytes into a one-image batch for the models.

    The image is converted to RGB and resized to ``target_size``, which is
    given in PIL's ``(width, height)`` order. The pixel array gets a leading
    batch axis of length 1, so the result has shape ``(1, height, width, 3)``.

    Raises whatever ``PIL.Image.open``/``convert`` raise for bytes that are
    not a decodable image.
    """
    buffer = io.BytesIO(contents)
    rgb_image = Image.open(buffer).convert("RGB")
    pixels = np.asarray(rgb_image.resize(target_size))
    # Prepend the batch dimension expected by Keras ``predict``.
    return pixels[np.newaxis, ...]
# Serializes model calls. Inference now runs in worker threads; this lock
# keeps it one-request-at-a-time, as it was when predict() ran the models
# directly on the event loop.
_INFERENCE_LOCK = threading.Lock()


def _run_models(img_batch):
    """Run the variety, age and disease models on one preprocessed batch.

    Blocking; intended to be called through ``run_in_threadpool`` from
    async code.

    Returns:
        ``(variety_preds, age_preds, disease_preds)``: the first row of each
        model's ``predict`` output, in that order.
    """
    with _INFERENCE_LOCK:
        variety_preds = variety_model.predict(img_batch)[0]
        age_preds = age_model.predict(img_batch)[0]
        disease_preds = disease_model.predict(img_batch)[0]
    return variety_preds, age_preds, disease_preds


def _top_class(preds, class_names):
    """Return the highest-scoring label in *preds* and its score as a dict."""
    idx = int(np.argmax(preds))
    return {"class": class_names[idx], "confidence": float(preds[idx])}


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Predict variety, disease and age for an uploaded image.

    Image decoding and model inference run in a worker thread so the
    blocking PIL and model calls do not stall the event loop.

    Returns:
        JSON ``{"variety": {"class", "confidence"}, "disease": {"class",
        "confidence"}, "age": int}`` on success, or HTTP 400 with a
        ``detail`` message when the upload cannot be decoded as an image.
    """
    contents = await file.read()
    try:
        img_batch = await run_in_threadpool(preprocess_image, contents)
    except (OSError, Image.DecompressionBombError):
        # Unreadable/empty/truncated files (PIL's UnidentifiedImageError is
        # an OSError) and oversized images are client errors, not 500s.
        return JSONResponse(
            status_code=400,
            content={"detail": "Uploaded file could not be decoded as an image."},
        )

    variety_preds, age_preds, disease_preds = await run_in_threadpool(
        _run_models, img_batch
    )

    # The age row is presumably a single regression value (TODO confirm).
    # .item() extracts it without NumPy's deprecated ndim>0 array-to-scalar
    # conversion, and still raises if the row holds more than one value.
    age = int(np.round(np.asarray(age_preds).item()))

    result = {
        "variety": _top_class(variety_preds, VARIETY_CLASSES),
        "disease": _top_class(disease_preds, DISEASE_CLASSES),
        "age": age,
    }
    return JSONResponse(content=result)
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on every interface, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")