# Flagship Wheat Disease Detection API — FastAPI inference service (main.py).
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import io
import numpy as np
import cv2
from PIL import Image
import tensorflow as tf
from tensorflow.keras.models import load_model
import os
# FastAPI application instance; metadata is surfaced in the auto-generated docs.
app = FastAPI(
    title="Flagship Wheat Disease Detection API",
    description="High-performance inference API for wheat disease classification.",
    version="1.0.0"
)

# CORS Configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all for dev; restrict in prod
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global Model Variable
# Stays None until the startup hook loads it; endpoints must handle the None case.
model = None
# Class labels, assumed to match the order of the model's output vector — TODO confirm against training.
CLASSES = ["Crown and Root Rot", "Healthy Wheat", "Leaf Rust", "Wheat Loose Smut"]
# Resolve the model file relative to this module, not the process CWD.
MODEL_PATH = os.path.join(os.path.dirname(__file__), "../models/flagship_model.keras")
@app.on_event("startup")
async def load_ml_model():
    """Load the Keras model once at process startup.

    Leaves the global ``model`` as None (mock-prediction mode) when the
    file is missing or loading fails; failures are reported on stdout
    rather than crashing the server.
    """
    global model
    print(f"Loading model from {MODEL_PATH}...")
    try:
        # Guard clause: no file on disk means we stay in mock mode.
        if not os.path.exists(MODEL_PATH):
            print(f"WARNING: Model file not found at {MODEL_PATH}. API will return mock predictions.")
            return
        model = load_model(MODEL_PATH)
        print("Model loaded successfully!")
    except Exception as e:
        # Startup must not die on a bad model file; endpoints degrade gracefully.
        print(f"ERROR: Failed to load model: {e}")
@app.get("/")
def read_root():
    """Health/debug endpoint.

    Reports runtime paths, whether the model file exists, why the global
    model might be None (by retrying the load), and the contents of the
    models directory. Never raises — any inspection error is returned in
    the payload instead.
    """
    debug_info = {
        "message": "Wheat Analysis Flagship API is running.",
        "cwd": os.getcwd(),
        "file": __file__,
        "model_path": MODEL_PATH,
        "model_exists": os.path.exists(MODEL_PATH),
        "dir_structure": [],
    }
    try:
        error_msg = "No error captured"
        if model is None:
            # Retry the load purely to capture the failure reason for the caller.
            try:
                temp_model = load_model(MODEL_PATH)
                error_msg = "Model loaded successfully on retry! (Global var was None)"
            except Exception as e:
                error_msg = f"Load Error: {str(e)}"
        debug_info["model_variable_status"] = str(model) if model else "None"
        debug_info["load_error"] = error_msg

        models_dir = os.path.join(os.path.dirname(__file__), "../models")
        if os.path.exists(models_dir):
            # List every file alongside its size for quick deployment checks.
            debug_info["dir_structure"] = [
                f"{entry} ({os.path.getsize(os.path.join(models_dir, entry))} bytes)"
                for entry in os.listdir(models_dir)
            ]
        else:
            debug_info["dir_structure"] = "Models dir not found"
    except Exception as e:
        # Surface the inspection failure in the response rather than a 500.
        debug_info["dir_structure"] = str(e)
    return debug_info
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded wheat-leaf image.

    Returns a dict with the predicted class, its confidence, and the full
    per-class score map. Predictions below a 70% confidence threshold are
    reported as "Unknown / Not Wheat". When no model is loaded, a random
    mock prediction is returned with a warning field.

    Raises:
        HTTPException(400): non-image content type or undecodable image data.
    """
    # content_type can be None for some clients — guard before startswith().
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    # Read image
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception:
        # A corrupt/undecodable payload is a client error, not a server crash.
        raise HTTPException(status_code=400, detail="Invalid or corrupt image file")

    # Preprocess: resize to 260x260 (EfficientNetV2B2 native input size).
    img_np = np.array(image)
    img_resized = cv2.resize(img_np, (260, 260))
    img_batch = np.expand_dims(img_resized, axis=0)

    # Inference — explicit None check; truthiness of a Keras model is not a contract.
    if model is not None:
        predictions = model.predict(img_batch)
        confidence_scores = predictions[0]
        # Cast to plain int so it indexes CLASSES cleanly and serializes as JSON.
        predicted_class_index = int(np.argmax(confidence_scores))
        raw_confidence = float(confidence_scores[predicted_class_index])

        # PROBABILITY THRESHOLD: Filter out non-wheat images.
        # If the model isn't at least 70% sure, we flag it.
        THRESHOLD = 0.70
        scores_dict = {cls: float(score) for cls, score in zip(CLASSES, confidence_scores)}
        if raw_confidence < THRESHOLD:
            result = {
                "prediction": "Unknown / Not Wheat",
                "confidence": raw_confidence,
                "scores": scores_dict,
                "alert": "Low confidence. Please ensure the image is a clear leaf photo."
            }
        else:
            result = {
                "prediction": CLASSES[predicted_class_index],
                "confidence": raw_confidence,
                "scores": scores_dict
            }
    else:
        # Mock Response for Dev/Testing if model isn't trained yet
        import random
        mock_class = random.choice(CLASSES)
        result = {
            "prediction": mock_class,
            "confidence": 0.95,
            "scores": {c: (0.95 if c == mock_class else 0.01) for c in CLASSES},
            "warning": "Mock prediction (Model not loaded)"
        }
    return result
# Dev entry point; run with `python main.py`. Use a proper ASGI process
# manager (and drop reload=True) in production.
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|