sedtha's picture
Update app.py
6a7584c verified
import io

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from torchvision import models, transforms
# ---------------------------
# App
# ---------------------------
# Single FastAPI application object; the routes below register on it.
app = FastAPI(
    title="Rice Leaf Disease Classification API",
)
# ---------------------------
# Device (CPU for HF Spaces)
# ---------------------------
# All tensors and the model are placed on the CPU; no GPU is assumed.
device = torch.device("cpu")
# ---------------------------
# Load checkpoint
# ---------------------------
# The checkpoint is a dict holding the trained weights ("model_state_dict")
# together with the label metadata needed to rebuild and interpret the
# classifier head. map_location remaps every stored tensor onto `device`.
# NOTE(review): the path is relative to the process working directory.
checkpoint = torch.load("rice_leaf_model.pth", map_location=device)
class_names, num_classes = checkpoint["class_names"], checkpoint["num_classes"]
# ---------------------------
# Model
# ---------------------------
# Rebuild a ResNet-18 whose final fully connected layer is resized to the
# checkpoint's class count (presumably mirroring the training setup).
# pretrained=False avoids fetching ImageNet weights, since load_state_dict
# (strict by default) restores every parameter from the checkpoint anyway.
model = models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)
model.load_state_dict(checkpoint["model_state_dict"])
# nn.Module.to() and nn.Module.eval() both return the module itself, so they
# chain; eval() switches batch-norm/dropout layers to inference behaviour.
model = model.to(device).eval()
print("✅ Model loaded successfully")
# ---------------------------
# Image Transform
# ---------------------------
# Resize to a fixed 224x224 (aspect ratio is not preserved), convert to a
# float CHW tensor in [0, 1], then normalise each channel. The mean/std values
# are the commonly used ImageNet statistics; presumably they match the
# preprocessing used at training time (TODO confirm against training code).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# ---------------------------
# Helper: preprocess image
# ---------------------------
def preprocess_image(image: Image.Image):
    """Convert a PIL image into a single-image batch ready for the model.

    Args:
        image: PIL image in any mode; it is converted to RGB first so that
            grayscale, palette, or RGBA inputs all yield 3 channels.

    Returns:
        A tensor of shape (1, 3, H, W) on ``device``, where H and W are set
        by ``transform`` (224x224 as currently configured).
    """
    rgb_image = image.convert("RGB")
    # unsqueeze(0) prepends the batch dimension expected by the model.
    batch = transform(rgb_image).unsqueeze(0)
    return batch.to(device)
# ---------------------------
# Routes
# ---------------------------
@app.get("/")
def home():
    """Return a static message confirming that the service is up."""
    status = {"message": "Rice Leaf Disease API is running 🚀"}
    return status
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded rice-leaf image.

    Args:
        file: Multipart upload whose body is the raw image bytes.

    Returns:
        A dict with ``predicted_class`` (the label from the checkpoint's
        ``class_names`` with the highest softmax probability) and
        ``confidence`` (that probability, rounded to 4 decimal places).

    Raises:
        HTTPException: 400 if the upload is empty, is not a recognised image
            format, is truncated/corrupt, or exceeds Pillow's
            decompression-bomb limit.
    """
    image_bytes = await file.read()
    try:
        image = Image.open(io.BytesIO(image_bytes))
        # Image.open is lazy and only parses the header; force a full decode
        # here so truncated or corrupt data is reported as a client error
        # instead of surfacing later (inside preprocessing) as a 500.
        image.load()
    except (OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError (unknown/empty data) subclasses OSError.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    input_tensor = preprocess_image(image)
    # NOTE(review): this CPU-bound forward pass runs directly on the event
    # loop inside an async endpoint, so concurrent requests are serialised.
    with torch.no_grad():
        outputs = model(input_tensor)
        probs = F.softmax(outputs, dim=1)
        confidence, predicted = torch.max(probs, 1)

    predicted_class = class_names[predicted.item()]
    return {
        "predicted_class": predicted_class,
        "confidence": round(confidence.item(), 4),
    }