"""FastAPI service that classifies tea-leaf images with EfficientNetB0Hybrid."""

import torch
from torchvision import transforms
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from model import EfficientNetB0Hybrid
from PIL import Image
from io import BytesIO
import logging
import os

# ---------------------- Logging ----------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ---------------------- App Setup ----------------------
app = FastAPI(
    title="Tea Disease Classification API",
    description="API for classifying tea leaf diseases using EfficientNetB0Hybrid"
)

# Allow CORS for development/testing.
# NOTE(review): combining allow_origins=["*"] with allow_credentials=True lets
# any site make credentialed requests (Starlette echoes the request Origin in
# that case). Restrict allow_origins before a public deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Change to your domain in production
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)

# ---------------------- Class Names ----------------------
# Output index i of the model maps to class_names[i]; this order presumably
# matches the label order used to train the checkpoint — confirm if retraining.
class_names = [
    'Algal Leaf',
    'Brown Blight',
    'Gray Blight',
    'Healthy Leaf',
    'Helopeltis',
    'Mirid_Looper Bug',
    'Red Spider',
]

# ---------------------- Device ----------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info("Using device: %s", device)

# ---------------------- Model Loading ----------------------
# The ready-to-use model; stays None until load_model() succeeds.
model = None


def load_model(model_path="tea_proposed.pth"):
    """Build EfficientNetB0Hybrid, load weights from *model_path*, and publish it.

    Args:
        model_path: Path to a state-dict checkpoint. Defaults to
            ``"tea_proposed.pth"`` in the working directory.

    Returns:
        True if the model was built, loaded and stored in the module-level
        ``model`` global (on ``device``, in eval mode); False if any step
        failed. On failure the error is logged and ``model`` is left unchanged.
    """
    global model
    try:
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file '{model_path}' not found")

        # Build into a local so a failure below never leaves a half-initialised
        # (randomly weighted, train-mode) network in the global `model`.
        net = EfficientNetB0Hybrid(
            num_classes=len(class_names),
            msfe_then_danet_indices=(6,),
            danet_only_indices=(4,),
            branch_out_ratio=0.33,  # Fix for checkpoint alignment
            drop_p=0.0,
            use_pretrained=False,
        ).to(device)

        # NOTE(review): torch.load unpickles the file — only load trusted
        # checkpoints (consider weights_only=True on torch >= 1.13).
        checkpoint = torch.load(model_path, map_location=device)

        # strict=False tolerates key mismatches instead of raising; report them
        # so a partially loaded model does not go unnoticed.
        missing, unexpected = net.load_state_dict(checkpoint, strict=False)
        if missing or unexpected:
            logger.warning(
                "Missing keys: %s, Unexpected keys: %s", missing, unexpected
            )

        net.eval()
        model = net
        logger.info("✅ Model loaded successfully.")
        return True
    except Exception as e:
        logger.exception("❌ Error loading model: %s", e)
        return False


model_loaded = load_model()

# ---------------------- Preprocessing ----------------------
# Resize to 224x224, convert to a float tensor, and normalise with the
# per-channel mean/std constants given below.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


# ---------------------- Prediction ----------------------
def predict(image):
    """Classify a single PIL image.

    Args:
        image: A PIL image (the /predict route passes an RGB-converted image).

    Returns:
        A ``(class_name, confidence)`` tuple, where confidence is the softmax
        probability of the predicted class.

    Raises:
        HTTPException: 500 if the model is not loaded or inference fails.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    try:
        with torch.no_grad():
            # unsqueeze(0) adds the batch dimension expected by the model.
            img_tensor = preprocess(image).unsqueeze(0).to(device)
            outputs = model(img_tensor)
            probs = torch.softmax(outputs, dim=1)
            pred_class = torch.argmax(probs, dim=1).item()
            confidence = probs[0, pred_class].item()
        return class_names[pred_class], confidence
    except Exception as e:
        logger.exception("Prediction error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Prediction error: {str(e)}"
        ) from e


# ---------------------- API Routes ----------------------
@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image file.

    Returns the original filename, the predicted class and its confidence
    rounded to 4 decimals. Responds 400 for non-image or undecodable uploads
    and 500 when the model is unavailable or inference fails.
    """
    if not model_loaded:
        raise HTTPException(status_code=500, detail="Model not loaded")

    # content_type is None when the client omits the part's Content-Type;
    # treat that as "not an image" (400) rather than crashing with a 500.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    contents = await file.read()
    try:
        image = Image.open(BytesIO(contents)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image file") from None

    # NOTE(review): predict() is synchronous, so inference blocks the event
    # loop while it runs; consider fastapi.concurrency.run_in_threadpool.
    pred_class, confidence = predict(image)
    return {
        "filename": file.filename,
        "predicted_class": pred_class,
        "confidence_score": round(confidence, 4),
    }


@app.get("/")
async def root():
    """Return a static welcome message."""
    return {"message": "Welcome to the Tea Disease Classification API"}


@app.get("/health")
async def health_check():
    """Report whether the model loaded at startup and which device is used."""
    return {
        "status": "healthy" if model_loaded else "unhealthy",
        "device": str(device),
    }


# ---------------------- Entry Point ----------------------
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)