Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| import torch | |
| from PIL import Image | |
| from io import BytesIO | |
| import torchvision.transforms as transforms | |
| from branchedcnn import BranchedCNN | |
| from utils import predict_image, get_top3, generate_and_upload_saliency_images | |
| import cloudinary | |
| import cloudinary.uploader | |
| import os | |
# ============ CONFIG ============
app = FastAPI()

# CORS: accept requests from any origin, using any HTTP method and any request
# header, with credentials allowed.
# NOTE(review): the CORS spec does not allow a literal "*" origin on
# credentialed requests, and nothing in this file uses cookies or auth headers.
# Consider allow_credentials=False once the frontend is confirmed not to need it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],    # any origin
    allow_methods=["*"],    # any HTTP verb
    allow_headers=["*"],    # any request header
    allow_credentials=True,
)
# Model output labels. Position i here is the class index that get_top3 reports
# and that `classes` below is looked up with.
# NOTE(review): presumably this is the label order used when TheHulkNet.pth was
# trained. Confirm before reordering or adding entries.
class_names = [
    'Corn__Common_Rust',
    'Corn__Gray_Leaf_Spot',
    'Corn__Healthy',
    'Corn__Leaf_Blight',
    'Potato_Early_Blight',
    'Potato_Healthy',
    'Potato_Late_Blight',
    'Tomato__Early_Blight',
    'Tomato__Healthy',
    'Tomato__Late_Blight',
    'Wheat__brown_rust',
    'Wheat__healthy',
    'Wheat__yellow_rust',
]

# Per-class {'crop': ..., 'disease': ...} records, index-aligned with
# `class_names`. They are derived from the labels rather than maintained by
# hand, so the two tables cannot drift apart.
# Each label is split at its first underscore. Stripping the leftover leading
# underscore handles both the "Crop__Disease" labels and the single-underscore
# "Potato_Disease" labels.
classes = [
    {'crop': crop, 'disease': disease.lstrip('_')}
    for crop, disease in (name.split('_', 1) for name in class_names)
]
# ============ MODEL ============
# Build the network with one output per entry in `class_names` and load the
# trained weights, mapped onto the CPU.
# NOTE: the weights path is relative, so it resolves against the process's
# current working directory, not this file's location.
model = BranchedCNN(num_classes=len(class_names))
model.load_state_dict(
    torch.load("TheHulkNet.pth", map_location=torch.device("cpu"))
)
# Inference mode: any dropout / batch-norm layers switch to eval behavior.
model.eval()

# Preprocessing for every uploaded image:
# - resize to 256x256;
# - convert to a float tensor in [0, 1];
# - normalize each RGB channel with mean 0.5 / std 0.5, giving roughly [-1, 1].
# NOTE(review): presumably these values mirror the training pipeline. Confirm
# before changing them.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
# ============ CLOUDINARY CONFIG ============
# Credentials come from the environment instead of being committed to source.
# Set CLOUDINARY_API_KEY and CLOUDINARY_API_SECRET (e.g. as Space secrets).
# If they are missing, saliency uploads in /predict will fail and be reported
# as errors.
# CLOUDINARY_CLOUD_NAME is optional; the previous cloud name is kept as the
# default because it is not a secret.
# NOTE(review): the API key/secret that used to be hard-coded here remain in
# version history. Rotate them in the Cloudinary console.
cloudinary.config(
    cloud_name=os.environ.get("CLOUDINARY_CLOUD_NAME", "dyqklqdp1"),
    api_key=os.environ.get("CLOUDINARY_API_KEY"),
    api_secret=os.environ.get("CLOUDINARY_API_SECRET"),
    secure=True,
)
# ============ ROUTES ============
# The handler must be registered on `app`; an undecorated function is never
# served by FastAPI.
@app.get("/")
def root():
    """Liveness endpoint: return a fixed message confirming the API is up."""
    return {"message": "BranchedCNN PyTorch API running!"}
# NOTE(review): route path assumed. Confirm it matches the URL the frontend calls.
@app.post("/predict")
def predict(file: UploadFile = File(...)):
    """Classify an uploaded leaf image and return top-3 predictions plus saliency URLs.

    Processing steps:
    1. Decode the upload to RGB and preprocess it with the module-level
       ``transform``, adding a batch dimension of 1.
    2. Score it with ``model`` via ``predict_image``.
    3. Reduce the result to the three best classes with ``get_top3``.
    4. Generate and upload saliency images for the best class with
       ``generate_and_upload_saliency_images``.

    Args:
        file: the uploaded image (multipart form field ``file``).

    Returns:
        On success, a JSONResponse containing:
        - ``top_3_predictions``: each entry has a ``confidence`` string
          formatted as a percentage with two decimals, plus ``crop`` and
          ``disease``;
        - ``ai_focus_map_url``, ``overlay_url`` and ``original_image_url``;
        - the best prediction repeated as ``confidence``,
          ``predicted_class`` (the disease) and ``predicted_crop``.

        If the upload cannot be decoded as an image, status 400 with
        ``{"error": ...}``. On any other failure, status 500 with
        ``{"error": ...}``.
    """
    import logging

    logger = logging.getLogger(__name__)

    # Decode first so a bad or non-image upload is reported as a client error
    # (400) instead of a server failure. PIL raises UnidentifiedImageError
    # (an OSError subclass) for unreadable data and DecompressionBombError for
    # images over its pixel limit.
    try:
        image = Image.open(BytesIO(file.file.read())).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        return JSONResponse(status_code=400, content={"error": str(e)})

    try:
        # Add a leading batch dimension of size 1 for the model.
        image_tensor = transform(image).unsqueeze(0)

        # Get prediction probabilities and keep the three most likely classes.
        probs = predict_image(model, image_tensor, class_names)
        top3 = get_top3(probs, class_names)

        # Generate & upload the saliency images for the best class.
        saliency_urls = generate_and_upload_saliency_images(
            model, image_tensor, top3[0]["index"]
        )
        logger.info("Top-3 predictions: %s", top3)

        # `index` refers to a position in class_names, which `classes` mirrors.
        predictions = [
            {
                "confidence": f"{pred['confidence'] * 100:.2f}%",
                "crop": classes[pred["index"]]["crop"],
                "disease": classes[pred["index"]]["disease"],
            }
            for pred in top3[:3]
        ]
        best = predictions[0]

        return JSONResponse({
            "top_3_predictions": predictions,
            "ai_focus_map_url": saliency_urls["saliency"],
            "overlay_url": saliency_urls["overlay"],
            "original_image_url": saliency_urls["original"],
            "confidence": best["confidence"],
            "predicted_class": best["disease"],
            "predicted_crop": best["crop"],
        })
    except Exception as e:
        # Request boundary: keep the traceback in the server log, and report
        # the failure to the client as a 500.
        logger.exception("Prediction failed")
        return JSONResponse(status_code=500, content={"error": str(e)})