File size: 4,234 Bytes
b0cb778
 
5c89ff8
 
b0cb778
5c89ff8
 
 
b0cb778
5c89ff8
b0cb778
 
5c89ff8
b0cb778
 
 
 
5c89ff8
 
b0cb778
 
 
 
 
 
 
 
 
 
 
 
 
 
5c89ff8
b0cb778
5c89ff8
 
 
 
b0cb778
5c89ff8
 
b0cb778
 
 
 
5c89ff8
 
b0cb778
5c89ff8
 
 
 
 
 
 
b0cb778
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c89ff8
b0cb778
 
 
 
 
 
 
 
5c89ff8
b0cb778
 
5c89ff8
b0cb778
5c89ff8
b0cb778
 
 
 
 
 
 
 
5c89ff8
b0cb778
 
5c89ff8
b0cb778
 
 
 
 
5c89ff8
b0cb778
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import base64
import os
from io import BytesIO
from typing import Callable, List, Optional, Tuple, Union

import gdown
import torch
import torchvision.transforms
from flask import Flask, jsonify, request
from flask_cors import CORS
from PIL import Image
from timm.models.resnet import ResNet  # <- safe load
from torchvision import transforms

# Flask application object that every route in this module is registered on.
app = Flask(__name__)
CORS(app=app)  # enable cross-origin requests on all routes with flask_cors defaults

# Device: run on CUDA when torch reports a usable GPU, otherwise on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Class names, indexed by the model's output class index: pred() returns
# class_names[argmax of the model output over dim=1], so this list's ORDER must
# match the class-index order used when the model was trained (38 entries).
# Presumably this is the sorted folder order of the PlantVillage dataset that
# these names come from — TODO confirm against the training script.
# NOTE(review): several names look mangled relative to their siblings (e.g.
# 'Grapehealthy' / 'PeachBacterial_spot' vs 'Grape_Black_rot' / 'Peach_healthy',
# and the missing separators in the Cherry/Corn entries). These strings are
# returned verbatim to API clients, so fix them only together with the callers.
class_names = ['Apple_Apple_scab', 'Apple_Black_rot', 'Apple_Cedar_apple_rust', 'Apple_healthy',
               'Blueberry_healthy', 'Cherry(including_sour)Powdery_mildew', 'Cherry(including_sour)healthy',
               'Corn(maize)Cercospora_leaf_spot Gray_leaf_spot', 'Corn(maize)Common_rust',
               'Corn(maize)Northern_Leaf_Blight', 'Corn(maize)healthy', 'Grape_Black_rot',
               'Grape_Esca(Black_Measles)', 'Grape_Leaf_blight(Isariopsis_Leaf_Spot)', 'Grapehealthy',
               'Orange_Haunglongbing(Citrus_greening)', 'PeachBacterial_spot', 'Peach_healthy',
               'Pepper,_bell_Bacterial_spot', 'Pepper,_bell_healthy', 'Potato_Early_blight',
               'Potato_Late_blight', 'Potato_healthy', 'Raspberry_healthy', 'Soybean_healthy',
               'Squash_Powdery_mildew', 'Strawberry_Leaf_scorch', 'Strawberry_healthy',
               'Tomato_Bacterial_spot', 'Tomato_Early_blight', 'Tomato_Late_blight', 'Tomato_Leaf_Mold',
               'Tomato_Septoria_leaf_spot', 'Tomato_Spider_mites Two-spotted_spider_mite',
               'Tomato_Target_Spot', 'Tomato_Tomato_Yellow_Leaf_Curl_Virus',
               'Tomato_Tomato_mosaic_virus', 'Tomato_healthy']

# Download the serialized model on first start if it is not already on disk.
model_path = "full_model.pth"
model_drive_url = "https://drive.google.com/uc?id=1DXpL1anOs6943Ifj1Uno7_4nd99RjGU3"

if not os.path.exists(model_path):
    print("Downloading model from Google Drive...")
    # gdown.download returns the output path on success; some gdown versions
    # may return None on failure instead of raising, so check explicitly
    # rather than failing later with a confusing torch.load error.
    if gdown.download(model_drive_url, model_path, quiet=False) is None:
        raise RuntimeError(f"Failed to download model from {model_drive_url}")

# Load the full pickled model object (architecture + weights).
# NOTE(security): weights_only=False uses the full pickle unpickler, which can
# execute arbitrary code embedded in the file. safe_globals() only allow-lists
# globals for weights_only=True loads, so it does NOT restrict this call; only
# load a model file from a trusted source.
# map_location remaps tensors that were saved on a GPU onto `device`, so a
# checkpoint saved with CUDA still loads on a CPU-only host.
with torch.serialization.safe_globals([ResNet]):
    model = torch.load(model_path, map_location=device, weights_only=False)
model.to(device)
model.eval()  # inference only: disable dropout / use running batch-norm stats

# Transform: preprocessing applied to every uploaded image before inference.
# Resize to 224x224, convert the PIL image to a float tensor, then normalize
# each channel with the widely used ImageNet mean/std values.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)

# Prediction function
def pred(model: torch.nn.Module, image_path: bytes, class_names: List[str],
         image_size: Tuple[int, int] = (299, 299),
         transform: Optional[Callable[[Image.Image], torch.Tensor]] = None,
         device: Union[str, torch.device] = device) -> str:
    """Classify one encoded image and return the predicted class name.

    Args:
        model: Classifier; its output is treated as (batch, num_classes)
            scores, since the arg-max is taken over dim=1.
        image_path: Raw encoded image bytes (e.g. the contents of a PNG or
            JPEG file). Despite the name this is the image data, not a path.
        class_names: Labels indexed by the model's output class index.
        image_size: Size passed to ``transforms.Resize`` for the fallback
            pipeline; only used when ``transform`` is None.
        transform: Preprocessing callable mapping a PIL image to a CHW
            tensor. When None, Resize(image_size) + ToTensor + Normalize
            (ImageNet mean/std) is used.
        device: Device the model and the input batch are moved to; defaults
            to the module-level ``device``.

    Returns:
        ``class_names[i]`` where ``i`` is the highest-scoring class index.

    Raises:
        PIL.UnidentifiedImageError: if ``image_path`` is not a recognizable
            image format.
    """
    # Decode and force 3-channel RGB. PNG uploads are often RGBA/LA/P and
    # grayscale images are "L"; without conversion the tensor would not have
    # the 3 channels the 3-value Normalize (and presumably the model) expects.
    # The context manager closes the decoder; convert() returns a new image.
    with Image.open(BytesIO(image_path)) as raw_img:
        img = raw_img.convert("RGB")

    if transform is not None:
        image_transform = transform
    else:
        image_transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    model.to(device)
    model.eval()
    with torch.inference_mode():
        # Add a leading batch dimension of size 1.
        batch = image_transform(img).unsqueeze(dim=0)
        logits = model(batch.to(device))
        # softmax is monotonic, so the arg-max over raw scores picks the same
        # class; .item() turns the 1-element index tensor into a plain int so
        # the Python list is indexed with an int rather than a tensor.
        label_idx = int(torch.argmax(logits, dim=1).item())
    return class_names[label_idx]

# Routes
@app.route('/')
def home():
    """Root endpoint: reply with a plain-text greeting identifying the API."""
    greeting = "Welcome to the crop prediction deep learning API"
    return greeting

@app.route('/favicon.ico')
def favicon():
    """Answer favicon requests with an empty body and HTTP 204 (No Content)."""
    body, status = '', 204
    return body, status

@app.route('/predict', methods=['POST'])
def predict():
    """Classify a base64-encoded image posted as JSON.

    Expects a JSON object body of the form ``{"image": "<base64>"}``. A
    ``data:<mime>;base64,`` data-URL prefix on the value is accepted and
    stripped before decoding.

    Returns:
        200 with ``{'status': 'ok', 'predicted_class': <name>}`` on success;
        400 with ``{'error': ...}`` when the body is not a JSON object, the
        ``image`` field is missing/empty/not a string, or the base64 is
        malformed; 500 with ``{'error': ...}`` when decoding the image or
        running the model fails.
    """
    data = request.get_json()
    # The body may be any JSON value (null, list, string); only an object has
    # .get(), so treat anything else as "no image" instead of letting an
    # AttributeError escape as an unhandled 500.
    base64img = data.get('image') if isinstance(data, dict) else None

    # Browser clients often send a data URL ("data:image/png;base64,...").
    # b64decode skips ':' ';' ',' but keeps "data", "image/png" and "base64"
    # ('/' is a valid base64 character), which corrupts the decoded bytes;
    # strip everything up to the first comma instead.
    if isinstance(base64img, str) and base64img.startswith('data:'):
        base64img = base64img.partition(',')[2]

    if not base64img:
        return jsonify({'error': 'No image data found'}), 400
    if not isinstance(base64img, str):
        return jsonify({'error': 'Image data must be a base64 string'}), 400

    try:
        image_data = base64.b64decode(base64img)
    except ValueError:  # binascii.Error (e.g. bad padding) subclasses ValueError
        return jsonify({'error': 'Invalid base64 image data'}), 400

    try:
        op = pred(model=model, image_path=image_data, class_names=class_names, transform=transform,
                  image_size=(224, 224))
    except Exception as e:
        # Top-level request boundary: log the full traceback server-side and
        # return a generic message to the client.
        app.logger.exception("Error processing image: %s", e)
        return jsonify({'error': 'Failed to process image'}), 500
    return jsonify({'status': 'ok', 'predicted_class': op}), 200

if __name__ == '__main__':
    # Listen on all interfaces. The port comes from $PORT and falls back to
    # 7860, the port HF Docker Spaces require.
    listen_port = int(os.getenv("PORT", "7860"))
    app.run(host='0.0.0.0', port=listen_port)