File size: 4,926 Bytes
ff88581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0fe12d
 
 
 
 
 
 
 
 
853a62c
 
 
 
 
 
 
 
 
 
 
 
 
b0fe12d
 
2b0d8f9
 
 
 
 
 
b0fe12d
 
 
 
 
 
ff88581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import io
import numpy as np
import cv2
from PIL import Image
import tensorflow as tf
from tensorflow.keras.models import load_model
import os

app = FastAPI(
    title="Flagship Wheat Disease Detection API",
    description="High-performance inference API for wheat disease classification.",
    version="1.0.0"
)

# CORS configuration.
# Allowed origins are read from the comma-separated CORS_ALLOW_ORIGINS
# environment variable so production deployments can restrict them without a
# code change. When the variable is unset or blank, every origin is allowed
# (the original development default).
# NOTE(review): combining a wildcard origin with allow_credentials=True lets
# Starlette's CORSMiddleware echo back the caller's Origin on credentialed
# requests, effectively permitting any site. Confirm against the installed
# Starlette version and set explicit origins in production.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Shared Keras model handle. It stays None until the startup hook loads it,
# and /predict falls back to mock output while it is None.
model = None

# Human-readable labels, positionally aligned with the model's output scores.
CLASSES = [
    "Crown and Root Rot",
    "Healthy Wheat",
    "Leaf Rust",
    "Wheat Loose Smut",
]

# Model weights, located relative to this module rather than the process CWD.
MODEL_PATH = os.path.join(
    os.path.dirname(__file__),
    "../models/flagship_model.keras",
)

@app.on_event("startup")
async def load_ml_model():
    """Load the Keras model from ``MODEL_PATH`` into the global ``model``.

    Failure is deliberately non-fatal. If the file is missing or cannot be
    deserialized, ``model`` stays ``None`` and ``/predict`` serves mock
    predictions instead.
    """
    global model
    print(f"Loading model from {MODEL_PATH}...")

    # os.path.exists never raises, so this check needs no try/except.
    if not os.path.exists(MODEL_PATH):
        print(f"WARNING: Model file not found at {MODEL_PATH}. API will return mock predictions.")
        return

    try:
        model = load_model(MODEL_PATH)
    except Exception as e:
        # Keep the API up in mock mode, but print the full traceback. The
        # one-line message alone often hides where deserialization failed.
        import traceback
        print(f"ERROR: Failed to load model: {e}")
        traceback.print_exc()
    else:
        print("Model loaded successfully!")

@app.get("/")
def read_root():
    """Liveness endpoint that also reports model-loading diagnostics.

    The response includes:
    - a status message;
    - the working directory and this module's path;
    - the resolved ``MODEL_PATH`` and whether it exists;
    - the global model's status;
    - the result of a load attempt (``load_error``);
    - a listing of the ``../models`` directory with file sizes
      (``dir_structure``).

    When no model is loaded, the load is retried here to surface the error.
    A successful retry is stored in the global ``model``. That way ``/predict``
    stops serving mock output and later calls don't repeat the full load.

    NOTE(review): this exposes server filesystem paths to any caller;
    consider restricting it outside development.
    """
    global model
    debug_info = {
        "message": "Wheat Analysis Flagship API is running.",
        "cwd": os.getcwd(),
        "file": __file__,
        "model_path": MODEL_PATH,
        "model_exists": os.path.exists(MODEL_PATH),
        "dir_structure": []
    }
    try:
        error_msg = "No error captured"
        if model is None:
            try:
                # Keep the result instead of discarding it: otherwise every
                # request re-loads the model and /predict never benefits.
                model = load_model(MODEL_PATH)
                error_msg = "Model loaded successfully on retry! (Global var was None)"
            except Exception as e:
                error_msg = f"Load Error: {str(e)}"

        debug_info["model_variable_status"] = str(model) if model is not None else "None"
        debug_info["load_error"] = error_msg

        models_dir = os.path.join(os.path.dirname(__file__), "../models")
        if os.path.isdir(models_dir):
            files = []
            for name in sorted(os.listdir(models_dir)):
                # A single unreadable entry (e.g. a dangling symlink) should
                # not wipe out the whole listing.
                try:
                    size = os.path.getsize(os.path.join(models_dir, name))
                except OSError as e:
                    files.append(f"{name} (size unavailable: {e})")
                else:
                    files.append(f"{name} ({size} bytes)")
            debug_info["dir_structure"] = files
        else:
            debug_info["dir_structure"] = "Models dir not found"
    except Exception as e:
        debug_info["dir_structure"] = str(e)

    return debug_info

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image into one of ``CLASSES``.

    The upload is processed as follows:
    1. Decode it with Pillow and convert it to RGB.
    2. Resize it to 260x260.
    3. Score it with the model as a batch of one.

    If the top score is below 0.70, the prediction is reported as
    "Unknown / Not Wheat" with an ``alert``. When no model is loaded, a random
    mock prediction carrying a ``warning`` key is returned instead.

    Returns:
        dict with ``prediction``, ``confidence`` and per-class ``scores``
        (plus ``alert`` or ``warning`` as described above).

    Raises:
        HTTPException: 400 when the declared content type is not ``image/*``
            or the payload cannot be decoded as an image.
    """
    # content_type can be None when the client omits it. Treat that as "not
    # an image" instead of failing with AttributeError (HTTP 500).
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    # Read and decode the image. Empty, truncated, or non-image payloads raise
    # OSError (PIL's UnidentifiedImageError subclasses it). Oversized images
    # raise DecompressionBombError. Both are client errors, not server faults.
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(status_code=400, detail="Could not decode image file") from err

    # Preprocess: resize to 260x260 (EfficientNetV2B2 native), add batch axis.
    img_np = np.array(image)
    img_resized = cv2.resize(img_np, (260, 260))
    img_batch = np.expand_dims(img_resized, axis=0)

    if model is None:
        # Mock response for dev/testing when the model isn't available.
        import random
        mock_class = random.choice(CLASSES)
        return {
            "prediction": mock_class,
            "confidence": 0.95,
            "scores": {c: (0.95 if c == mock_class else 0.01) for c in CLASSES},
            "warning": "Mock prediction (Model not loaded)"
        }

    # NOTE(review): model.predict is synchronous and runs on the event loop,
    # blocking other requests during inference. Consider a threadpool once
    # the model's thread-safety is confirmed. verbose=0 stops a progress bar
    # being printed on every request.
    predictions = model.predict(img_batch, verbose=0)
    confidence_scores = predictions[0]
    predicted_class_index = np.argmax(confidence_scores)
    raw_confidence = float(confidence_scores[predicted_class_index])

    # PROBABILITY THRESHOLD: filter out non-wheat images.
    # If the model isn't at least 70% sure, flag the result as unknown.
    THRESHOLD = 0.70

    scores_dict = {cls: float(score) for cls, score in zip(CLASSES, confidence_scores)}

    if raw_confidence < THRESHOLD:
        return {
            "prediction": "Unknown / Not Wheat",
            "confidence": raw_confidence,
            "scores": scores_dict,
            "alert": "Low confidence. Please ensure the image is a clear leaf photo."
        }
    return {
        "prediction": CLASSES[predicted_class_index],
        "confidence": raw_confidence,
        "scores": scores_dict
    }

if __name__ == "__main__":
    # reload=True requires uvicorn to receive an import string rather than
    # the app object. Derive "<module>:app" from this file's name instead of
    # hard-coding "main", which breaks if the file is renamed.
    _module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(f"{_module_name}:app", host="0.0.0.0", port=8000, reload=True)