Spaces:
Sleeping
Sleeping
import io

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from torchvision import models, transforms
# ---------------------------
# App
# ---------------------------
# Title shown in FastAPI's generated OpenAPI / docs pages.
API_TITLE = "Rice Leaf Disease Classification API"
app = FastAPI(title=API_TITLE)
# ---------------------------
# Device (CPU for HF Spaces)
# ---------------------------
# Inference is pinned to the CPU: the checkpoint is loaded with this as its
# map_location and every preprocessed input tensor is moved onto it.
_DEVICE_TYPE = "cpu"
device = torch.device(_DEVICE_TYPE)
# ---------------------------
# Load checkpoint
# ---------------------------
# The checkpoint is a dict carrying the trained weights ("model_state_dict")
# plus the label metadata needed to rebuild the classifier head.
# NOTE(review): torch.load unpickles the file, so only point it at trusted,
# locally shipped checkpoints.
checkpoint = torch.load("rice_leaf_model.pth", map_location=device)
class_names, num_classes = checkpoint["class_names"], checkpoint["num_classes"]
# ---------------------------
# Model
# ---------------------------
# ResNet-18 whose final fully-connected layer is resized to the checkpoint's
# class count; all weights then come from the checkpoint itself.
# NOTE(review): `pretrained=` is deprecated since torchvision 0.13 in favour of
# `weights=None`; kept as-is for compatibility with older torchvision.
model = models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=num_classes)
model.load_state_dict(checkpoint["model_state_dict"])
# nn.Module.to() and .eval() both return the module, so they can be chained.
model.to(device).eval()
print("✅ Model loaded successfully")
# ---------------------------
# Image Transform
# ---------------------------
# Per-channel RGB mean/std (the commonly used ImageNet statistics); presumably
# these match the training-time preprocessing — confirm against training code.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
# ---------------------------
# Helper: preprocess image
# ---------------------------
def preprocess_image(image: Image.Image):
    """Convert a PIL image into a single-image batch on ``device``.

    The image is forced to RGB, run through ``transform`` (resize to 224x224,
    tensor conversion, normalization), and given a leading batch axis.

    Args:
        image: any PIL image; its mode is converted to RGB first.

    Returns:
        A tensor of shape (1, 3, 224, 224) on ``device``.
    """
    rgb = image.convert("RGB")
    batch = transform(rgb).unsqueeze(0)
    return batch.to(device)
# ---------------------------
# Routes
# ---------------------------
# NOTE(review): SOURCE had no route decorator here, so the handler was never
# registered on `app`. "/" is assumed as the root/health path — confirm.
@app.get("/")
def home():
    """Report that the API process is up and serving requests."""
    return {"message": "Rice Leaf Disease API is running 🚀"}
# NOTE(review): SOURCE had no route decorator here, so the handler was never
# registered on `app`. "/predict" is assumed from the handler name — confirm
# against clients.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded rice-leaf image.

    Args:
        file: multipart upload whose bytes are decoded as an image by PIL.

    Returns:
        A dict with ``predicted_class`` (the ``class_names`` entry at the
        highest-probability index) and ``confidence`` (that class's softmax
        probability, rounded to 4 decimal places).

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image
            (including empty uploads), instead of an unhandled 500.
    """
    image_bytes = await file.read()
    try:
        # PIL decodes lazily, so decode errors can surface inside
        # preprocess_image (at convert/load time), not only in Image.open.
        # UnidentifiedImageError and truncated-image errors are both OSError.
        with Image.open(io.BytesIO(image_bytes)) as image:
            input_tensor = preprocess_image(image)
    except OSError as err:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file could not be decoded as an image.",
        ) from err

    with torch.no_grad():
        outputs = model(input_tensor)
        probs = F.softmax(outputs, dim=1)
        confidence, predicted = torch.max(probs, 1)

    predicted_class = class_names[predicted.item()]
    return {
        "predicted_class": predicted_class,
        "confidence": round(confidence.item(), 4),
    }