# Author: Prasanta4
# Last commit: "Update app.py" (2b9e141, verified)
import logging
import os
from io import BytesIO

import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from torchvision import transforms

from model import EfficientNetB0Hybrid
# ---------------------- Logging ----------------------
# Configure the root logger at INFO (a no-op if the root logger already has
# handlers) and grab a module-scoped logger for this file.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
# ---------------------- App Setup ----------------------
app = FastAPI(
    title="Tea Disease Classification API",
    description="API for classifying tea leaf diseases using EfficientNetB0Hybrid",
)

# CORS origins, comma-separated, e.g.
#   CORS_ALLOW_ORIGINS="https://example.com,https://app.example.com"
# Defaults to "*" (any origin) for development/testing; set an explicit list
# of your domain(s) in production. An empty value also falls back to "*".
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # A wildcard origin must not be combined with credentials: with both set,
    # Starlette reflects whatever Origin the browser sends, letting any site
    # make credentialed requests. The endpoints in this file use no cookies or
    # auth, so credentials are only enabled for an explicit origin list.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
# ---------------------- Class Names ----------------------
# Human-readable labels indexed by the model's output position. The order is
# presumably the label order used at training time — do not reorder without
# retraining / re-checking the checkpoint.
class_names = [
    "Algal Leaf",
    "Brown Blight",
    "Gray Blight",
    "Healthy Leaf",
    "Helopeltis",
    "Mirid_Looper Bug",
    "Red Spider",
]
# ---------------------- Device ----------------------
# Run on CUDA when PyTorch reports it available; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logger.info(f"Using device: {device}")
# ---------------------- Model Loading ----------------------
# Set by load_model() only after a checkpoint has been loaded successfully;
# stays None otherwise so predict() can refuse to run.
model = None


def load_model(model_path="tea_proposed.pth"):
    """Build the EfficientNetB0Hybrid network, load its weights, and publish it.

    The network is constructed and populated in a local variable and is only
    assigned to the module-level ``model`` once every step has succeeded, so
    a failed load leaves ``model`` unchanged (``None`` at startup).

    Args:
        model_path: Path to the checkpoint holding the network's state dict.
            Defaults to ``"tea_proposed.pth"`` (relative to the working
            directory).

    Returns:
        True if the model was loaded and published, False otherwise. Errors
        are logged with a traceback rather than raised, so the app can still
        start and report "unhealthy" from /health.
    """
    global model
    try:
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file '{model_path}' not found")

        net = EfficientNetB0Hybrid(
            num_classes=len(class_names),
            msfe_then_danet_indices=(6,),
            danet_only_indices=(4,),
            branch_out_ratio=0.33,  # Fix for checkpoint alignment
            drop_p=0.0,
            use_pretrained=False,
        ).to(device)

        # The checkpoint is consumed as a plain state dict below, so restrict
        # unpickling to tensors / primitive containers instead of arbitrary
        # Python objects.
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)

        # strict=False tolerates partial key mismatches (logged below), but if
        # *no* key matched, the network would silently serve random weights.
        missing, unexpected = net.load_state_dict(checkpoint, strict=False)
        if missing or unexpected:
            logger.warning(f"Missing keys: {missing}, Unexpected keys: {unexpected}")
        if missing and len(missing) == len(net.state_dict()):
            raise RuntimeError(
                f"Checkpoint '{model_path}' matched none of the model's parameters"
            )

        net.eval()
        model = net  # publish only a fully loaded, eval-mode network
        logger.info("✅ Model loaded successfully.")
        return True
    except Exception as e:
        logger.exception("❌ Error loading model: %s", e)
        return False


model_loaded = load_model()
# ---------------------- Preprocessing ----------------------
# Standard ImageNet per-channel mean / std, applied after ToTensor scales
# pixels to [0, 1].
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Resize to a fixed 224x224, convert to a tensor, then normalise.
preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# ---------------------- Prediction ----------------------
def predict(image):
    """Classify one image with the loaded model.

    Args:
        image: A PIL image accepted by ``preprocess`` (the /predict route
            passes an RGB image).

    Returns:
        Tuple ``(class_name, confidence)``: the arg-max label from
        ``class_names`` and its softmax probability as a Python float.

    Raises:
        HTTPException: 500 if the model is not loaded or inference fails.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    try:
        with torch.no_grad():
            # unsqueeze(0) adds a batch dimension of size 1.
            img_tensor = preprocess(image).unsqueeze(0).to(device)
            probs = torch.softmax(model(img_tensor), dim=1)
            pred_class = torch.argmax(probs, dim=1).item()
            confidence = probs[0, pred_class].item()
        return class_names[pred_class], confidence
    except Exception as e:
        # Keep the full traceback in the server log and chain the cause so
        # the original error isn't lost behind the HTTPException.
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}") from e
# ---------------------- API Routes ----------------------
@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Returns:
        JSON object with ``filename``, ``predicted_class`` and
        ``confidence_score`` (rounded to 4 decimal places).

    Raises:
        HTTPException: 500 if the model failed to load or inference fails;
            400 if the upload is not declared as an image or cannot be
            decoded.
    """
    if not model_loaded:
        raise HTTPException(status_code=500, detail="Model not loaded")
    # content_type is None when the multipart part has no Content-Type header;
    # treat that as "not an image" (400) instead of crashing with a 500.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    contents = await file.read()
    try:
        image = Image.open(BytesIO(contents)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail="Invalid image file") from e

    # predict() is synchronous, compute-bound model inference; run it in the
    # thread pool so it does not block the event loop for other requests.
    pred_class, confidence = await run_in_threadpool(predict, image)
    return {
        "filename": file.filename,
        "predicted_class": pred_class,
        "confidence_score": round(confidence, 4),
    }
@app.get("/")
async def root():
    """Return a static welcome message."""
    greeting = "Welcome to the Tea Disease Classification API"
    return {"message": greeting}
@app.get("/health")
async def health_check():
    """Report whether the model loaded at startup and which torch device is used."""
    if model_loaded:
        status = "healthy"
    else:
        status = "unhealthy"
    return {"status": status, "device": str(device)}
# ---------------------- Entry Point ----------------------
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces. The port defaults to 7860 and can be
    # overridden with the PORT environment variable.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))