"""Rice Leaf Disease Classification API.

Loads a ResNet-18 classifier from ``rice_leaf_model.pth`` at import time and
serves it through a FastAPI app with two routes:

* ``GET /``        - liveness message.
* ``POST /predict`` - accepts an uploaded image and returns the top-1 class
  name and its softmax probability.
"""

import io

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from torchvision import models, transforms

# ---------------------------
# App
# ---------------------------
app = FastAPI(title="Rice Leaf Disease Classification API")

# ---------------------------
# Device (CPU for HF Spaces)
# ---------------------------
device = torch.device("cpu")

# ---------------------------
# Load checkpoint
# ---------------------------
# The checkpoint must be a dict providing at least the "class_names",
# "num_classes" and "model_state_dict" entries read below.
checkpoint = torch.load("rice_leaf_model.pth", map_location=device)

class_names = checkpoint["class_names"]
num_classes = checkpoint["num_classes"]

# ---------------------------
# Model
# ---------------------------
# Passing num_classes builds the final fc layer with the right output size
# directly. This gives the same architecture (and state-dict keys) as
# replacing model.fc afterwards. It also avoids the `pretrained=` keyword,
# which torchvision >= 0.13 deprecates in favour of `weights=`. Without either
# keyword the weights are randomly initialised on old and new torchvision
# alike, and the checkpoint overwrites them immediately.
model = models.resnet18(num_classes=num_classes)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()

print("✅ Model loaded successfully")

# ---------------------------
# Image Transform
# ---------------------------
# Resize to 224x224 and normalise with the ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


# ---------------------------
# Helper: preprocess image
# ---------------------------
def preprocess_image(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a single-item batch tensor on ``device``.

    The image is forced to RGB (so grayscale/RGBA/palette inputs work), passed
    through ``transform``, and given a leading batch dimension of size 1.
    """
    image = image.convert("RGB")
    image = transform(image)
    image = image.unsqueeze(0)
    return image.to(device)


def _predict_from_bytes(image_bytes: bytes) -> dict:
    """Decode raw upload bytes, run the classifier and return the top-1 result.

    Args:
        image_bytes: The raw contents of the uploaded file.

    Returns:
        ``{"predicted_class": <class name>, "confidence": <prob rounded to 4 dp>}``

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    try:
        # PIL opens lazily; preprocess_image's convert() forces the full decode,
        # so truncated/corrupt data is also caught here.
        with Image.open(io.BytesIO(image_bytes)) as image:
            input_tensor = preprocess_image(image)
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError is a subclass of OSError. Client-supplied
        # garbage is a client error, not a 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    with torch.no_grad():
        outputs = model(input_tensor)
        probs = F.softmax(outputs, dim=1)
        confidence, predicted = torch.max(probs, 1)

    predicted_class = class_names[predicted.item()]

    return {
        "predicted_class": predicted_class,
        "confidence": round(confidence.item(), 4),
    }


# ---------------------------
# Routes
# ---------------------------
@app.get("/")
def home():
    """Liveness check."""
    return {"message": "Rice Leaf Disease API is running 🚀"}


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded rice-leaf image.

    Returns the predicted class name and its softmax confidence. Responds with
    HTTP 400 when the upload is not a decodable image.
    """
    image_bytes = await file.read()
    # Image decoding and the forward pass are CPU-bound. Running them in the
    # threadpool keeps the event loop free to serve other requests meanwhile.
    return await run_in_threadpool(_predict_from_bytes, image_bytes)