# fyp / app.py
# Muhammad Saleem
# Update app.py
# b50af6d verified
import logging
import os
from io import BytesIO

import cloudinary
import cloudinary.uploader
import torch
import torchvision.transforms as transforms
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image

from branchedcnn import BranchedCNN
from utils import predict_image, get_top3, generate_and_upload_saliency_images
# ============ CONFIG ============
app = FastAPI()

# CORS policy: every origin, method and header is accepted, and credentials
# are enabled.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive — confirm this is intended before exposing the API publicly.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Label strings handed to the prediction helpers.
# NOTE(review): presumably listed in the order of the model's output indices —
# confirm against the training label order before editing.
class_names = [
    'Corn__Common_Rust',
    'Corn__Gray_Leaf_Spot',
    'Corn__Healthy',
    'Corn__Leaf_Blight',
    'Potato_Early_Blight',
    'Potato_Healthy',
    'Potato_Late_Blight',
    'Tomato__Early_Blight',
    'Tomato__Healthy',
    'Tomato__Late_Blight',
    'Wheat__brown_rust',
    'Wheat__healthy',
    'Wheat__yellow_rust',
]

# (crop, disease) for every label, in the same order as class_names. Kept as
# an explicit table because the label strings do not use one consistent
# separator (the Potato labels use a single underscore).
_CROP_DISEASE_PAIRS = (
    ('Corn', 'Common_Rust'),
    ('Corn', 'Gray_Leaf_Spot'),
    ('Corn', 'Healthy'),
    ('Corn', 'Leaf_Blight'),
    ('Potato', 'Early_Blight'),
    ('Potato', 'Healthy'),
    ('Potato', 'Late_Blight'),
    ('Tomato', 'Early_Blight'),
    ('Tomato', 'Healthy'),
    ('Tomato', 'Late_Blight'),
    ('Wheat', 'brown_rust'),
    ('Wheat', 'healthy'),
    ('Wheat', 'yellow_rust'),
)

# Per-index crop/disease records; the /predict handler looks these up by the
# predicted class index.
classes = [
    {'crop': crop, 'disease': disease}
    for crop, disease in _CROP_DISEASE_PAIRS
]
# Build the network with one output per label and restore its weights from
# the checkpoint file, mapped onto the CPU.
# NOTE(review): torch.load unpickles the checkpoint — only load
# "TheHulkNet.pth" from a trusted source.
_CPU = torch.device('cpu')
model = BranchedCNN(num_classes=len(class_names))
model.load_state_dict(torch.load("TheHulkNet.pth", map_location=_CPU))
model.eval()  # switch to evaluation mode for serving

# Preprocessing applied to every uploaded image: resize to 256x256, convert
# to a tensor, then normalize each of the three channels with mean 0.5 and
# std 0.5.
_preprocessing_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
transform = transforms.Compose(_preprocessing_steps)
# ============ CLOUDINARY CONFIG ============
# Credentials are read from the environment when present, so a deployment can
# supply (and rotate) them without editing the source.
# SECURITY: the literal fallbacks below are credentials committed to the
# repository. They are kept only so existing deployments keep working —
# rotate the API secret and delete these fallbacks once the environment
# variables are set.
cloudinary.config(
    cloud_name=os.environ.get("CLOUDINARY_CLOUD_NAME", "dyqklqdp1"),
    api_key=os.environ.get("CLOUDINARY_API_KEY", "561197523251247"),
    api_secret=os.environ.get(
        "CLOUDINARY_API_SECRET", "_0CexqYMQ60r6aJNthJPrcFhcp4"
    ),
    secure=True,
)
# ============ ROUTES ============
@app.get("/")
def root():
    """Return a static message confirming that the API is up."""
    status = {"message": "BranchedCNN PyTorch API running!"}
    return status
def _format_prediction(prediction):
    """Convert one entry returned by get_top3 into its JSON response form."""
    label = classes[prediction['index']]
    return {
        'confidence': f"{prediction['confidence'] * 100:.2f}%",
        'crop': label['crop'],
        'disease': label['disease'],
    }


@app.post("/predict")
def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the ranked predictions.

    The upload is decoded to RGB, run through ``transform`` and the model,
    and the best prediction's index is passed to
    ``generate_and_upload_saliency_images`` to obtain the image URLs.

    Returns a JSON response with:
      * ``top_3_predictions`` - up to three entries with ``confidence``
        (percentage string), ``crop`` and ``disease``;
      * ``ai_focus_map_url``, ``overlay_url``, ``original_image_url`` - the
        URLs returned by the saliency helper;
      * ``confidence``, ``predicted_class``, ``predicted_crop`` - the best
        prediction repeated at the top level.

    Error responses carry ``{"error": <message>}``: status 400 when the
    upload cannot be read as an image, status 500 for any other failure.
    """
    log = logging.getLogger(__name__)
    try:
        raw_bytes = file.file.read()
        try:
            image = Image.open(BytesIO(raw_bytes)).convert("RGB")
        except OSError as exc:
            # Pillow signals unreadable or truncated images with OSError
            # subclasses; that is a bad request, not a server fault.
            return JSONResponse(status_code=400, content={"error": str(exc)})

        # Add the leading batch dimension expected by the helpers.
        image_tensor = transform(image).unsqueeze(0)

        # Get prediction probabilities
        probs = predict_image(model, image_tensor, class_names)
        top3 = get_top3(probs, class_names)

        # Generate & upload the saliency images for the best prediction
        saliency_urls = generate_and_upload_saliency_images(
            model, image_tensor, top3[0]['index']
        )
        log.info("Top predictions: %s", top3)

        ranked = [_format_prediction(entry) for entry in top3[:3]]
        best = ranked[0]
        return JSONResponse({
            "top_3_predictions": ranked,
            "ai_focus_map_url": saliency_urls['saliency'],
            "overlay_url": saliency_urls['overlay'],
            "original_image_url": saliency_urls['original'],
            "confidence": best['confidence'],
            "predicted_class": best['disease'],
            "predicted_crop": best['crop'],
        })
    except Exception as e:
        # Top-level boundary: keep the traceback in the logs, since the
        # client only receives the message text.
        log.exception("Prediction request failed")
        return JSONResponse(status_code=500, content={"error": str(e)})