File size: 3,831 Bytes
4418807
 
 
 
 
 
 
 
f3591b4
4418807
f3591b4
4418807
 
 
4e13890
f3591b4
 
 
010286a
 
f3591b4
010286a
4e13890
 
 
 
010286a
94a728b
010286a
94a728b
 
 
4418807
 
010286a
4418807
010286a
4418807
 
 
 
 
 
 
 
010286a
 
 
 
f3591b4
4418807
 
010286a
4418807
 
 
 
 
4e13890
f3591b4
4418807
 
 
010286a
4418807
94a728b
010286a
 
 
 
 
 
 
94a728b
f2bbf2c
4418807
010286a
4418807
010286a
4418807
 
010286a
4418807
f2bbf2c
65af2fd
010286a
 
 
 
 
 
 
 
 
 
 
4418807
 
010286a
4418807
010286a
 
4418807
010286a
 
 
 
4418807
010286a
 
 
 
 
 
 
 
 
 
f3591b4
010286a
4418807
010286a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import json
import torch
import numpy as np
import tensorflow as tf
from PIL import Image
from torchvision import models, transforms

# ================= PATHS =================

# Resolve relative to this file's absolute location: a bare relative
# __file__ (e.g. "app.py" when run as a script on older Pythons) would make
# dirname() return "" and tie model/label lookups to the current directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(BASE_DIR, "models")
LABELS_DIR = os.path.join(BASE_DIR, "labels")

# ================= TF CUSTOM OBJECTS =================

@tf.keras.utils.register_keras_serializable()
class FixedDropout(tf.keras.layers.Dropout):
    """Keras ``Dropout`` registered under the name ``FixedDropout``.

    Saved models that reference a ``FixedDropout`` layer need this class to
    deserialize. It adds no behavior: the constructor only forwards its
    arguments to ``tf.keras.layers.Dropout``.
    """

    def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
        super(FixedDropout, self).__init__(
            rate,
            noise_shape=noise_shape,
            seed=seed,
            **kwargs,
        )

# Dummy class to satisfy EfficientNet deserialization.
# It is registered with Keras and also passed to load_model() via
# custom_objects (see load_models), so the class name "EfficientNetB3" can be
# resolved when a saved model refers to it.
# NOTE(review): this subclass is empty; if a saved model needs the real
# EfficientNetB3 architecture rebuilt from its config, this placeholder may
# not be sufficient — confirm against the model files in MODELS_DIR.
@tf.keras.utils.register_keras_serializable()
class EfficientNetB3(tf.keras.Model):
    """Empty ``tf.keras.Model`` subclass used only as a deserialization placeholder."""

    pass

# ================= INPUT SIZE PER MODEL =================

# Square input side length (pixels) per crop for Keras models. Crops not
# listed here fall back to the 224-pixel default in preprocess_keras.
KERAS_INPUT_SIZES = dict(
    corn=300,
)

# ================= IMAGE PREPROCESS =================

def preprocess_pytorch(img, size=224):
    """Turn a PIL image into a normalized batch-of-one tensor for a torchvision model.

    The image is converted to RGB first, matching ``preprocess_keras``.
    Grayscale/palette images would otherwise produce a 1-channel tensor and
    RGBA images a 4-channel one; the 3-channel ``Normalize`` below raises on
    either.

    Args:
        img: PIL image in any mode.
        size: side length, in pixels, of the square resize.

    Returns:
        torch.Tensor of shape (1, 3, size, size).
    """
    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        # Standard ImageNet per-channel mean/std; presumably what the
        # ResNet-18 checkpoints were trained with — TODO confirm.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
    return transform(img.convert("RGB")).unsqueeze(0)

def preprocess_keras(img, crop_name):
    """Turn a PIL image into a float32 batch of one for a Keras model.

    The image is converted to RGB and resized to a square whose side comes
    from ``KERAS_INPUT_SIZES`` (224 when the crop is not listed). Pixel values
    are scaled from 0..255 down to 0..1.

    Args:
        img: PIL image in any mode.
        crop_name: key used to look up the model's input size.

    Returns:
        numpy array of shape (1, side, side, 3), dtype float32.
    """
    side = KERAS_INPUT_SIZES.get(crop_name, 224)
    resized = img.convert("RGB").resize((side, side))
    pixels = np.array(resized).astype(np.float32)
    scaled = pixels / 255.0
    return scaled[np.newaxis, ...]

# ================= MODEL REGISTRIES =================

# Registries filled by load_models(), each keyed by the lower-cased crop name.
PYTORCH_MODELS = dict()  # crop -> torchvision ResNet-18 in eval mode
KERAS_MODELS = dict()    # crop -> model returned by tf.keras.models.load_model
LABELS = dict()          # crop -> parsed contents of <crop>_labels.json

# ================= LOAD MODELS =================

def _read_labels(crop_name):
    """Return the parsed ``<crop_name>_labels.json`` from LABELS_DIR.

    Raises:
        FileNotFoundError: if the label file does not exist.
    """
    label_path = os.path.join(LABELS_DIR, f"{crop_name}_labels.json")
    if not os.path.exists(label_path):
        raise FileNotFoundError(f"Missing label file: {label_path}")
    # Explicit UTF-8 so non-ASCII label names decode identically on every
    # platform (the locale default differs, e.g. on Windows).
    with open(label_path, "r", encoding="utf-8") as f:
        return json.load(f)


def _load_pytorch_model(model_path, num_classes):
    """Return a ResNet-18 in eval mode with a *num_classes*-way head and weights from *model_path*."""
    model = models.resnet18(weights=None)
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    # NOTE(review): torch.load unpickles the checkpoint, which can execute
    # arbitrary code; only place trusted files in MODELS_DIR (on torch >= 1.13
    # consider passing weights_only=True).
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()
    return model


def _load_keras_model(model_path):
    """Return the Keras model at *model_path*, loaded with compile=False."""
    return tf.keras.models.load_model(
        model_path,
        custom_objects={
            "swish": tf.keras.activations.swish,
            "FixedDropout": FixedDropout,
            "EfficientNetB3": EfficientNetB3,
        },
        compile=False,
    )


def load_models():
    """Populate PYTORCH_MODELS, KERAS_MODELS and LABELS from MODELS_DIR.

    Each ``*.pth`` file is loaded as a ResNet-18 state dict. Each ``*.keras``
    or ``*.h5`` file is loaded as a Keras model. Extensions are matched
    case-insensitively.

    The crop key is the file stem with ``_model`` removed, lower-cased. Every
    model must have a matching ``<crop>_labels.json`` in LABELS_DIR; for
    PyTorch models its length sets the number of output classes.

    Files with any other extension (``.DS_Store``, ``.gitkeep``, README files,
    sub-directories without an extension) are skipped. They no longer abort
    startup with a spurious missing-label error.

    Raises:
        FileNotFoundError: if a model file has no matching label file.
    """
    # sorted() makes load order deterministic across filesystems.
    for filename in sorted(os.listdir(MODELS_DIR)):
        name, ext = os.path.splitext(filename)
        ext = ext.lower()
        if ext not in (".pth", ".keras", ".h5"):
            continue

        crop_name = name.replace("_model", "").lower()
        model_path = os.path.join(MODELS_DIR, filename)
        LABELS[crop_name] = _read_labels(crop_name)

        if ext == ".pth":
            PYTORCH_MODELS[crop_name] = _load_pytorch_model(
                model_path, len(LABELS[crop_name])
            )
        else:
            KERAS_MODELS[crop_name] = _load_keras_model(model_path)


# Load models at startup (runs once, at import time).
load_models()

# ================= PREDICTION =================

def predict(image, crop_name):
    """Classify *image* with the model registered for *crop_name*.

    The crop name is stripped of surrounding whitespace and lower-cased before
    lookup. PyTorch models are checked before Keras models, so a crop present
    in both registries is served by PyTorch.

    Args:
        image: PIL image; it is handed to ``preprocess_pytorch`` or
            ``preprocess_keras``.
        crop_name: crop identifier, case- and whitespace-insensitive.

    Returns:
        ``(label, confidence)``: the label at the arg-max index and that
        index's score as a float. For PyTorch models the score is a softmax
        probability. For Keras models it is the raw model output.
        NOTE(review): the Keras score is a probability only if the model ends
        in a softmax — confirm.

    Raises:
        ValueError: if no model is registered for the normalized crop name.
    """
    key = crop_name.strip().lower()

    if key in PYTORCH_MODELS:
        batch = preprocess_pytorch(image)
        with torch.no_grad():
            logits = PYTORCH_MODELS[key](batch)
            scores = torch.softmax(logits[0], dim=0)
            best = scores.argmax().item()
        return LABELS[key][best], float(scores[best])

    if key in KERAS_MODELS:
        batch = preprocess_keras(image, key)
        scores = KERAS_MODELS[key].predict(batch, verbose=0)[0]
        best = int(np.argmax(scores))
        return LABELS[key][best], float(scores[best])

    raise ValueError(f"No model found for crop: {key}")