"""Flask API that classifies crop-leaf images with a pre-trained model.

On import, the module downloads the serialized model from Google Drive if it
is not already on disk, loads it, and builds the preprocessing pipeline.
``POST /predict`` accepts a JSON body ``{"image": "<base64>"}`` and returns the
predicted class name.
"""

import base64
import os
from io import BytesIO
from typing import Callable, List, Optional, Tuple, Union

import gdown
import torch
import torchvision.transforms
from flask import Flask, request, jsonify
from flask_cors import CORS
from PIL import Image
from timm.models.resnet import ResNet  # <- safe load
from torchvision import transforms

app = Flask(__name__)
CORS(app)

# Device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Class names; the position in this list is the model's output index.
# NOTE(review): a few entries (e.g. 'Grapehealthy', 'PeachBacterial_spot') look
# like separators were lost at some point. They are kept byte-for-byte because
# they are part of the API's output -- confirm against the training labels.
class_names = [
    'Apple_Apple_scab',
    'Apple_Black_rot',
    'Apple_Cedar_apple_rust',
    'Apple_healthy',
    'Blueberry_healthy',
    'Cherry(including_sour)Powdery_mildew',
    'Cherry(including_sour)healthy',
    'Corn(maize)Cercospora_leaf_spot Gray_leaf_spot',
    'Corn(maize)Common_rust',
    'Corn(maize)Northern_Leaf_Blight',
    'Corn(maize)healthy',
    'Grape_Black_rot',
    'Grape_Esca(Black_Measles)',
    'Grape_Leaf_blight(Isariopsis_Leaf_Spot)',
    'Grapehealthy',
    'Orange_Haunglongbing(Citrus_greening)',
    'PeachBacterial_spot',
    'Peach_healthy',
    'Pepper,_bell_Bacterial_spot',
    'Pepper,_bell_healthy',
    'Potato_Early_blight',
    'Potato_Late_blight',
    'Potato_healthy',
    'Raspberry_healthy',
    'Soybean_healthy',
    'Squash_Powdery_mildew',
    'Strawberry_Leaf_scorch',
    'Strawberry_healthy',
    'Tomato_Bacterial_spot',
    'Tomato_Early_blight',
    'Tomato_Late_blight',
    'Tomato_Leaf_Mold',
    'Tomato_Septoria_leaf_spot',
    'Tomato_Spider_mites Two-spotted_spider_mite',
    'Tomato_Target_Spot',
    'Tomato_Tomato_Yellow_Leaf_Curl_Virus',
    'Tomato_Tomato_mosaic_virus',
    'Tomato_healthy',
]

# Download model if not present
model_path = "full_model.pth"
model_drive_url = "https://drive.google.com/uc?id=1DXpL1anOs6943Ifj1Uno7_4nd99RjGU3"

if not os.path.exists(model_path):
    print("Downloading model from Google Drive...")
    gdown.download(model_drive_url, model_path, quiet=False)
    # Fail here with a clear message rather than later inside torch.load.
    if not os.path.exists(model_path):
        raise FileNotFoundError(
            f"Model download failed: '{model_path}' was not created"
        )

# Load model
# SECURITY NOTE: weights_only=False performs a full pickle load, which can
# execute arbitrary code contained in the file, and the safe_globals allowlist
# has no effect in that mode. This is only acceptable while the Google Drive
# file is fully trusted. Prefer shipping a state_dict and loading it with
# weights_only=True.
with torch.serialization.safe_globals([ResNet]):
    model = torch.load(model_path, weights_only=False)
model.to(device)
model.eval()

# Transform
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# Prediction function
def pred(model: torch.nn.Module,
         image_path: bytes,
         class_names: List[str],
         image_size: Tuple[int, int] = (299, 299),
         transform: Optional[Callable] = None,
         device: Union[str, torch.device] = device) -> str:
    """Classify one encoded image and return its class name.

    Args:
        model: Classifier called on a ``(1, C, H, W)`` batch.
        image_path: Raw encoded image bytes (despite the name, not a path).
        class_names: Labels indexed by the model's output position.
        image_size: Resize target, used only when ``transform`` is ``None``.
        transform: Preprocessing callable mapping a PIL image to a tensor.
            When ``None`` a default resize/to-tensor/normalize pipeline is
            built from ``image_size``.
        device: Device the model and the input batch are moved to.

    Returns:
        The entry of ``class_names`` at the highest-scoring output index.

    Raises:
        OSError: If PIL cannot identify or fully decode the image bytes.
    """
    if transform is not None:
        image_transform = transform
    else:
        image_transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    # Force three channels: RGBA, grayscale or palette images would otherwise
    # fail in the 3-channel Normalize step. convert() loads the pixel data, so
    # the underlying file handle can be closed right away.
    with Image.open(BytesIO(image_path)) as raw_image:
        img = raw_image.convert("RGB")

    model.to(device)
    model.eval()
    with torch.inference_mode():
        batch = image_transform(img).unsqueeze(dim=0)
        logits = model(batch.to(device))
        probabilities = torch.softmax(logits, dim=1)
        # .item() yields a plain int, so any sequence type can be indexed.
        label_index = int(torch.argmax(probabilities, dim=1).item())

    return class_names[label_index]


# Routes
@app.route('/')
def home():
    """Return a plain-text greeting for the API root."""
    return "Welcome to the crop prediction deep learning API"


@app.route('/favicon.ico')
def favicon():
    """Answer favicon requests with an empty 204 response."""
    return '', 204


@app.route('/predict', methods=['POST'])
def predict():
    """Classify the base64-encoded image sent in the JSON request body.

    Expects ``{"image": "<base64 string>"}``; a ``data:...;base64,`` prefix is
    accepted and stripped. Returns 200 with ``status`` and ``predicted_class``
    on success, 400 for a missing, malformed or undecodable image, and 500
    for any other failure.
    """
    # silent=True returns None instead of raising on a bad body or content
    # type, so every client error below gets the same JSON error shape.
    data = request.get_json(silent=True)
    base64img = data.get('image') if isinstance(data, dict) else None
    if not base64img or not isinstance(base64img, str):
        return jsonify({'error': 'No image data found'}), 400

    # Browsers often send a data URL; without stripping the header its
    # characters would be fed into the decoder and corrupt the image bytes.
    if base64img.startswith('data:') and ',' in base64img:
        base64img = base64img.split(',', 1)[1]

    try:
        image_data = base64.b64decode(base64img)
    except ValueError:
        # binascii.Error (bad padding/characters) subclasses ValueError.
        return jsonify({'error': 'Invalid base64 image data'}), 400

    try:
        op = pred(model=model,
                  image_path=image_data,
                  class_names=class_names,
                  transform=transform,
                  image_size=(224, 224))
    except OSError:
        # PIL signals unidentified or truncated images with OSError
        # subclasses; that is bad client input, not a server fault.
        app.logger.exception("Error processing image")
        return jsonify({'error': 'Invalid image data'}), 400
    except Exception:
        app.logger.exception("Error processing image")
        return jsonify({'error': 'Failed to process image'}), 500

    return jsonify({'status': 'ok', 'predicted_class': op}), 200


if __name__ == '__main__':
    port = int(os.environ.get("PORT", 7860))  # HF Docker Spaces require 7860
    app.run(host='0.0.0.0', port=port)