# Dharine's picture
# Update app.py
# b0cb778 verified
from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
from torchvision import transforms
import torchvision.transforms
from PIL import Image
from io import BytesIO
import base64
from typing import List, Tuple
import os
import gdown
from timm.models.resnet import ResNet # <- safe load
# Flask application object, wrapped by flask_cors.
app = Flask(__name__)
CORS(app)

# Device: use CUDA when torch reports it as available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Class names, one label per line. The order matters: a label's position in
# this list is the index used to look it up.
class_names = [
    'Apple_Apple_scab',
    'Apple_Black_rot',
    'Apple_Cedar_apple_rust',
    'Apple_healthy',
    'Blueberry_healthy',
    'Cherry(including_sour)Powdery_mildew',
    'Cherry(including_sour)healthy',
    'Corn(maize)Cercospora_leaf_spot Gray_leaf_spot',
    'Corn(maize)Common_rust',
    'Corn(maize)Northern_Leaf_Blight',
    'Corn(maize)healthy',
    'Grape_Black_rot',
    'Grape_Esca(Black_Measles)',
    'Grape_Leaf_blight(Isariopsis_Leaf_Spot)',
    'Grapehealthy',
    'Orange_Haunglongbing(Citrus_greening)',
    'PeachBacterial_spot',
    'Peach_healthy',
    'Pepper,_bell_Bacterial_spot',
    'Pepper,_bell_healthy',
    'Potato_Early_blight',
    'Potato_Late_blight',
    'Potato_healthy',
    'Raspberry_healthy',
    'Soybean_healthy',
    'Squash_Powdery_mildew',
    'Strawberry_Leaf_scorch',
    'Strawberry_healthy',
    'Tomato_Bacterial_spot',
    'Tomato_Early_blight',
    'Tomato_Late_blight',
    'Tomato_Leaf_Mold',
    'Tomato_Septoria_leaf_spot',
    'Tomato_Spider_mites Two-spotted_spider_mite',
    'Tomato_Target_Spot',
    'Tomato_Tomato_Yellow_Leaf_Curl_Virus',
    'Tomato_Tomato_mosaic_virus',
    'Tomato_healthy',
]
# Download model if not present
model_path = "full_model.pth"
model_drive_url = "https://drive.google.com/uc?id=1DXpL1anOs6943Ifj1Uno7_4nd99RjGU3"
if not os.path.exists(model_path):
    print("Downloading model from Google Drive...")
    gdown.download(model_drive_url, model_path, quiet=False)
    # Stop here with a clear message rather than letting torch.load fail
    # later on a file that was never written.
    if not os.path.exists(model_path):
        raise RuntimeError(f"Model download failed: {model_path} was not created")

# Load the fully pickled model object.
# NOTE(security): weights_only=False runs the unrestricted pickle loader, so
# the checkpoint must come from a trusted source. The safe_globals allowlist
# is documented as applying to weights_only loads; it is kept so a later
# switch to weights_only=True already has ResNet allowlisted.
# map_location remaps stored tensors onto the selected device, which lets a
# checkpoint saved on a GPU load on a CPU-only host.
with torch.serialization.safe_globals([ResNet]):
    model = torch.load(model_path, map_location=device, weights_only=False)
model.to(device)
model.eval()
# Transform: resize to 224x224, convert to a tensor, then normalize each
# channel with the fixed mean/std values below.
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# Prediction function
def pred(model: torch.nn.Module, image_path: bytes, class_names: List[str],
         image_size: Tuple[int, int] = (299, 299),
         transform: torchvision.transforms.Compose = None,
         device: torch.device = device):
    """Classify one encoded image and return its class name.

    Args:
        model: Network to run; it is moved to ``device`` and set to eval mode.
        image_path: Encoded image bytes (despite the name, not a filesystem
            path); they are wrapped in a ``BytesIO`` and opened with PIL.
        class_names: Labels indexed by the position of the model's outputs.
        image_size: Resize target, used only when ``transform`` is None.
        transform: Preprocessing callable applied to the PIL image. When
            None, a resize / to-tensor / normalize pipeline is built.
        device: Device on which the model and the input batch are placed.

    Returns:
        The entry of ``class_names`` at the arg-max of the model output.
    """
    # Force three channels: RGBA, grayscale and palette images would otherwise
    # yield a tensor whose channel count does not match the 3-value mean/std
    # used by Normalize.
    img = Image.open(BytesIO(image_path)).convert("RGB")

    if transform is not None:
        image_transform = transform
    else:
        image_transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    model.to(device)
    model.eval()
    with torch.inference_mode():
        # unsqueeze adds the leading batch dimension for a single image.
        transformed_image = image_transform(img).unsqueeze(dim=0)
        target_image_pred = model(transformed_image.to(device))
        target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
        target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)

    # Convert the one-element tensor to a plain int before indexing the list.
    return class_names[int(target_image_pred_label.item())]
# Routes
@app.route('/')
def home():
    """Root route: return a plain-text welcome message."""
    greeting = "Welcome to the crop prediction deep learning API"
    return greeting
@app.route('/favicon.ico')
def favicon():
    """Answer favicon requests with an empty body and status code 204."""
    empty_body, status_code = '', 204
    return empty_body, status_code
@app.route('/predict', methods=['POST'])
def predict():
    """Classify a base64-encoded image posted as JSON.

    The request body is expected to be a JSON object whose ``image`` key holds
    the base64 text, optionally prefixed with a ``data:...;base64,`` data-URL
    header.

    Returns:
        200 with ``{'status': 'ok', 'predicted_class': <name>}`` on success;
        400 when the body is not a JSON object, ``image`` is missing or not a
        string, or the base64 text cannot be decoded;
        500 when opening the image or running the model fails.
    """
    # silent=True returns None for a missing or unparsable JSON body instead
    # of raising, so malformed requests get the same JSON error shape.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'No image data found'}), 400

    base64img = data.get('image')
    if not base64img or not isinstance(base64img, str):
        return jsonify({'error': 'No image data found'}), 400

    # Accept data URLs by keeping only the payload after the first comma.
    if base64img.startswith('data:') and ',' in base64img:
        base64img = base64img.split(',', 1)[1]

    try:
        image_data = base64.b64decode(base64img)
    except (ValueError, TypeError):
        # binascii.Error (bad padding) is a ValueError subclass: client error.
        return jsonify({'error': 'Invalid base64 image data'}), 400

    try:
        op = pred(model=model, image_path=image_data, class_names=class_names,
                  transform=transform, image_size=(224, 224))
    except Exception:
        # Top-level boundary: log with traceback and return a generic error.
        app.logger.exception("Error processing image")
        return jsonify({'error': 'Failed to process image'}), 500

    return jsonify({'status': 'ok', 'predicted_class': op}), 200
if __name__ == '__main__':
    # HF Docker Spaces require 7860, so that is the fallback when PORT is unset.
    raw_port = os.getenv("PORT", 7860)
    app.run(host='0.0.0.0', port=int(raw_port))