File size: 4,469 Bytes
01b6023
 
 
 
 
 
 
 
 
 
2b9e141
01b6023
 
 
2b9e141
 
 
 
 
01b6023
2b9e141
01b6023
 
2b9e141
01b6023
 
 
 
 
2b9e141
01b6023
 
 
 
 
 
 
 
 
 
2b9e141
01b6023
2b9e141
01b6023
2b9e141
01b6023
 
 
 
 
 
 
 
 
2b9e141
 
 
 
 
 
 
 
 
 
01b6023
2b9e141
 
 
 
01b6023
2b9e141
01b6023
 
2b9e141
01b6023
 
 
 
2b9e141
01b6023
 
 
 
 
 
 
2b9e141
01b6023
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b9e141
01b6023
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b9e141
 
 
 
 
01b6023
 
 
 
 
 
 
 
 
 
 
 
2b9e141
01b6023
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import torch
from torchvision import transforms
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from model import EfficientNetB0Hybrid
from PIL import Image
from io import BytesIO
import logging
import os

# ---------------------- Logging ----------------------
# Module-level logger at INFO so model-loading progress shows up in logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ---------------------- App Setup ----------------------
# FastAPI application object; routes are registered on it below.
app = FastAPI(
    description="API for classifying tea leaf diseases using EfficientNetB0Hybrid",
    title="Tea Disease Classification API",
)

# Allow CORS for development/testing
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — restrict allow_origins to the real frontend domain in production.
_CORS_OPTIONS = {
    "allow_origins": ["*"],  # Change to your domain in production
    "allow_credentials": True,
    "allow_methods": ["GET", "POST"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# ---------------------- Class Names ----------------------
# Output labels; index-aligned with the model's logits, so order matters.
class_names = [
    "Algal Leaf",
    "Brown Blight",
    "Gray Blight",
    "Healthy Leaf",
    "Helopeltis",
    "Mirid_Looper Bug",
    "Red Spider",
]

# ---------------------- Device ----------------------
# Prefer GPU when available; otherwise fall back to CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
logger.info("Using device: %s", device)

# ---------------------- Model Loading ----------------------
# Global model handle; populated by load_model() at import time and read by
# predict() below. None means "not usable".
model = None

def load_model():
    """Instantiate EfficientNetB0Hybrid and load weights from tea_proposed.pth.

    Sets the module-level ``model`` global on success and puts it in eval mode.

    Returns:
        bool: True when the checkpoint was loaded successfully, False on any
        failure (missing file, architecture mismatch, corrupt checkpoint, ...).
    """
    global model
    try:
        model_path = "tea_proposed.pth"
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file '{model_path}' not found")

        model = EfficientNetB0Hybrid(
            num_classes=len(class_names),
            msfe_then_danet_indices=(6,),
            danet_only_indices=(4,),
            branch_out_ratio=0.33,  # Fix for checkpoint alignment
            drop_p=0.0,
            use_pretrained=False
        ).to(device)

        # Load checkpoint with safe fallback (strict=False tolerates partial
        # key mismatches but logs them so drift is visible).
        # NOTE(review): torch.load without weights_only=True unpickles
        # arbitrary objects; pass weights_only=True if the checkpoint is a
        # plain state_dict — confirm how it was saved.
        checkpoint = torch.load(model_path, map_location=device)
        missing, unexpected = model.load_state_dict(checkpoint, strict=False)
        if missing or unexpected:
            logger.warning("Missing keys: %s, Unexpected keys: %s", missing, unexpected)

        model.eval()
        logger.info("✅ Model loaded successfully.")
        return True
    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception("❌ Error loading model: %s", e)
        # Reset so predict() cannot run a half-initialized model with
        # random/partial weights.
        model = None
        return False

model_loaded = load_model()

# ---------------------- Preprocessing ----------------------
# Resize to the 224x224 input the EfficientNet-B0 backbone expects, then
# normalize with the standard ImageNet channel statistics.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])

# ---------------------- Prediction ----------------------
def predict(image):
    """Run the classifier on a single image.

    Args:
        image: a PIL image (callers convert to RGB before passing it in).

    Returns:
        tuple: ``(class_name, confidence)`` where ``class_name`` is an entry
        of ``class_names`` and ``confidence`` is the softmax probability.

    Raises:
        HTTPException: status 500 when the model is unavailable or
        inference fails.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    try:
        with torch.no_grad():
            img_tensor = preprocess(image).unsqueeze(0).to(device)
            outputs = model(img_tensor)
            probs = torch.softmax(outputs, dim=1)
            pred_class = torch.argmax(probs, dim=1).item()
            confidence = probs[0, pred_class].item()
            return class_names[pred_class], confidence
    except Exception as e:
        # logger.exception keeps the traceback; chaining with `from e`
        # preserves the original cause on the raised HTTPException.
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}") from e

# ---------------------- API Routes ----------------------
@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded tea-leaf image.

    Validates that the upload is an image, decodes it to RGB, and returns a
    JSON object with the filename, predicted class, and confidence score
    rounded to 4 decimals.

    Raises:
        HTTPException: 500 if the model is not loaded, 400 for non-image or
        undecodable uploads.
    """
    if not model_loaded:
        raise HTTPException(status_code=500, detail="Model not loaded")

    # content_type can be None for some clients; guard before startswith
    # to avoid an AttributeError masquerading as a 500.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    contents = await file.read()
    try:
        image = Image.open(BytesIO(contents)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image file")

    pred_class, confidence = predict(image)
    return {
        "filename": file.filename,
        "predicted_class": pred_class,
        "confidence_score": round(confidence, 4)
    }

@app.get("/")
async def root():
    """Landing endpoint returning a short welcome message."""
    return {"message": "Welcome to the Tea Disease Classification API"}

@app.get("/health")
async def health_check():
    """Health probe: reports whether the model loaded and which device is in use."""
    status = "healthy" if model_loaded else "unhealthy"
    return {"status": status, "device": str(device)}

# ---------------------- Entry Point ----------------------
# Direct-run development server; when served by an external ASGI runner the
# guard keeps this from executing on import.
if __name__ == "__main__":
    import uvicorn
    # NOTE(review): port 7860 is the Hugging Face Spaces convention — confirm
    # this matches the actual deployment target.
    uvicorn.run(app, host="0.0.0.0", port=7860)