|
|
import torch |
|
|
from torchvision import transforms |
|
|
from fastapi import FastAPI, File, UploadFile, HTTPException |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from model import EfficientNetB0Hybrid |
|
|
from PIL import Image |
|
|
from io import BytesIO |
|
|
import logging |
|
|
import os |
|
|
|
|
|
|
|
|
# Configure root logging once at import; INFO so model-load progress is visible.
logging.basicConfig(level=logging.INFO)


# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
# FastAPI application instance; title/description appear in the
# auto-generated OpenAPI docs (/docs, /redoc).
app = FastAPI(


    title="Tea Disease Classification API",


    description="API for classifying tea leaf diseases using EfficientNetB0Hybrid"


)
|
|
|
|
|
|
|
|
# Allow cross-origin browser calls (e.g. a separately hosted frontend).
# NOTE(review): the CORS spec forbids credentials with a wildcard origin —
# browsers will reject credentialed requests under this combination.
# Consider listing explicit origins or dropping allow_credentials; verify
# against the actual frontend deployment.
app.add_middleware(


    CORSMiddleware,


    allow_origins=["*"],


    allow_credentials=True,


    allow_methods=["GET", "POST"],


    allow_headers=["*"]


)
|
|
|
|
|
|
|
|
# Label names indexed by model output position: predict() maps the argmax
# logit index i to class_names[i], so this order must not be changed.
class_names = [
    'Algal Leaf', 'Brown Blight', 'Gray Blight', 'Healthy Leaf',
    'Helopeltis', 'Mirid_Looper Bug', 'Red Spider',
]
|
|
|
|
|
|
|
|
# Select the compute device once at import; the model and every input
# tensor are moved onto it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logger.info("Using device: %s", device)
|
|
|
|
|
|
|
|
# Global model handle; populated by load_model(), stays None on failure.
model = None
|
|
|
|
|
def load_model():
    """Build the EfficientNetB0Hybrid model and load its trained weights.

    Populates the module-level ``model`` global and returns True on
    success.  On any failure the error is logged, ``model`` is reset to
    None (so callers never see a half-initialized model), and False is
    returned.

    Returns:
        bool: True if the model was constructed and weights loaded.
    """
    global model
    try:
        model_path = "tea_proposed.pth"
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file '{model_path}' not found")

        # Architecture hyper-parameters must match those used at training
        # time, otherwise the state dict below will not line up.
        model = EfficientNetB0Hybrid(
            num_classes=len(class_names),
            msfe_then_danet_indices=(6,),
            danet_only_indices=(4,),
            branch_out_ratio=0.33,
            drop_p=0.0,
            use_pretrained=False
        ).to(device)

        # SECURITY NOTE(review): torch.load deserializes via pickle and can
        # execute arbitrary code from an untrusted checkpoint. If the torch
        # version in use supports it, prefer weights_only=True for a plain
        # state-dict file like this one.
        checkpoint = torch.load(model_path, map_location=device)
        # strict=False tolerates partial mismatches, but surface them loudly.
        missing, unexpected = model.load_state_dict(checkpoint, strict=False)
        if missing or unexpected:
            logger.warning(f"Missing keys: {missing}, Unexpected keys: {unexpected}")

        model.eval()  # inference mode: freeze dropout / batch-norm statistics
        logger.info("✅ Model loaded successfully.")
        return True
    except Exception as e:
        # Reset the global so a partially constructed model (e.g. built but
        # weights failed to load) is never left behind for predict() to use.
        model = None
        logger.error(f"❌ Error loading model: {str(e)}")
        return False
|
|
|
|
|
# Load weights eagerly at import time; the endpoints consult this flag.
model_loaded = load_model()
|
|
|
|
|
|
|
|
# Normalization statistics — these are the standard ImageNet mean/std;
# presumably what the EfficientNet backbone was trained with (TODO confirm
# against the training pipeline).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Inference-time pipeline: PIL image -> 224x224 -> float tensor -> normalized.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])
|
|
|
|
|
|
|
|
def predict(image):
    """Run a single inference pass on one image.

    Args:
        image: a PIL image; the caller converts it to RGB beforehand.

    Returns:
        tuple: (predicted class name, softmax confidence in [0, 1]).

    Raises:
        HTTPException: 500 if the model is unavailable or inference fails.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    try:
        with torch.no_grad():  # inference only — skip autograd bookkeeping
            img_tensor = preprocess(image).unsqueeze(0).to(device)  # add batch dim
            outputs = model(img_tensor)
            probs = torch.softmax(outputs, dim=1)
            pred_class = torch.argmax(probs, dim=1).item()
            confidence = probs[0, pred_class].item()
        return class_names[pred_class], confidence
    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception("Prediction error")
        # Chain the cause so the original error survives in debug output.
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}") from e
|
|
|
|
|
|
|
|
@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded tea-leaf image.

    Args:
        file: multipart image upload; must have an image/* content type.

    Returns:
        dict with the original filename, predicted class name, and the
        softmax confidence rounded to 4 decimal places.

    Raises:
        HTTPException: 400 for a non-image/empty/unreadable upload,
            500 if the model is not loaded or inference fails.
    """
    if not model_loaded:
        raise HTTPException(status_code=500, detail="Model not loaded")

    # content_type can be None for some clients; guard before startswith()
    # so a missing header yields the intended 400, not an AttributeError 500.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    contents = await file.read()
    if not contents:
        raise HTTPException(status_code=400, detail="Invalid image file")
    try:
        image = Image.open(BytesIO(contents)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail="Invalid image file") from e

    pred_class, confidence = predict(image)
    return {
        "filename": file.filename,
        "predicted_class": pred_class,
        "confidence_score": round(confidence, 4)
    }
|
|
|
|
|
@app.get("/")


async def root():


    """Landing endpoint: returns a static welcome message."""
    return {"message": "Welcome to the Tea Disease Classification API"}
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Liveness/readiness probe.

    Reports whether the model loaded at startup and which compute device
    is in use.
    """
    status = "healthy" if model_loaded else "unhealthy"
    return {"status": status, "device": str(device)}
|
|
|
|
|
|
|
|
# Script entry point: serve on all interfaces, port 7860 (Hugging Face
# Spaces' conventional port).  When deployed behind uvicorn/gunicorn
# directly, this block is skipped.
if __name__ == "__main__":


    import uvicorn


    uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|