Update model.py
Browse files
model.py
CHANGED
|
@@ -19,15 +19,13 @@ class FixedDropout(tf.keras.layers.Dropout):
|
|
| 19 |
def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
|
| 20 |
super().__init__(rate, noise_shape=noise_shape, seed=seed, **kwargs)
|
| 21 |
|
| 22 |
-
#
|
| 23 |
@tf.keras.utils.register_keras_serializable()
|
| 24 |
class EfficientNetB3(tf.keras.Model):
|
| 25 |
pass
|
| 26 |
|
| 27 |
# MobileNetV2
|
| 28 |
-
|
| 29 |
-
class MobileNetV2(tf.keras.Model):
|
| 30 |
-
pass
|
| 31 |
|
| 32 |
# ================= INPUT SIZE PER MODEL =================
|
| 33 |
|
|
@@ -79,7 +77,7 @@ def load_models():
|
|
| 79 |
with open(label_path, "r") as f:
|
| 80 |
LABELS[crop_name] = json.load(f)
|
| 81 |
|
| 82 |
-
# ---------- PyTorch
|
| 83 |
if ext == ".pth":
|
| 84 |
num_classes = len(LABELS[crop_name])
|
| 85 |
model = models.resnet18(weights=None)
|
|
@@ -88,20 +86,19 @@ def load_models():
|
|
| 88 |
model.eval()
|
| 89 |
PYTORCH_MODELS[crop_name] = model
|
| 90 |
|
| 91 |
-
# ---------- Keras
|
| 92 |
elif ext in [".keras", ".h5"]:
|
| 93 |
KERAS_MODELS[crop_name] = tf.keras.models.load_model(
|
| 94 |
model_path,
|
| 95 |
custom_objects={
|
| 96 |
"swish": tf.keras.activations.swish,
|
| 97 |
"FixedDropout": FixedDropout,
|
| 98 |
-
"EfficientNetB3": EfficientNetB3,
|
| 99 |
-
"MobileNetV2": MobileNetV2,
|
| 100 |
},
|
| 101 |
compile=False
|
| 102 |
)
|
| 103 |
|
| 104 |
-
# Load once at startup
|
| 105 |
load_models()
|
| 106 |
|
| 107 |
# ================= PREDICTION =================
|
|
@@ -109,28 +106,22 @@ load_models()
|
|
| 109 |
def predict(image, crop_name):
|
| 110 |
crop_name = crop_name.strip().lower()
|
| 111 |
|
| 112 |
-
# -------- PyTorch --------
|
| 113 |
if crop_name in PYTORCH_MODELS:
|
| 114 |
model = PYTORCH_MODELS[crop_name]
|
| 115 |
labels = LABELS[crop_name]
|
| 116 |
-
|
| 117 |
tensor = preprocess_pytorch(image)
|
| 118 |
with torch.no_grad():
|
| 119 |
output = model(tensor)
|
| 120 |
probs = torch.softmax(output[0], dim=0)
|
| 121 |
idx = probs.argmax().item()
|
| 122 |
-
|
| 123 |
return labels[idx], float(probs[idx])
|
| 124 |
|
| 125 |
-
# -------- Keras --------
|
| 126 |
elif crop_name in KERAS_MODELS:
|
| 127 |
model = KERAS_MODELS[crop_name]
|
| 128 |
labels = LABELS[crop_name]
|
| 129 |
-
|
| 130 |
arr = preprocess_keras(image, crop_name)
|
| 131 |
preds = model.predict(arr, verbose=0)[0]
|
| 132 |
idx = int(np.argmax(preds))
|
| 133 |
-
|
| 134 |
return labels[idx], float(preds[idx])
|
| 135 |
|
| 136 |
else:
|
|
|
|
| 19 |
def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
|
| 20 |
super().__init__(rate, noise_shape=noise_shape, seed=seed, **kwargs)
|
| 21 |
|
| 22 |
+
# EfficientNet
|
| 23 |
@tf.keras.utils.register_keras_serializable()
|
| 24 |
class EfficientNetB3(tf.keras.Model):
|
| 25 |
pass
|
| 26 |
|
| 27 |
# MobileNetV2
|
| 28 |
+
from tensorflow.keras.applications import MobileNetV2
|
|
|
|
|
|
|
| 29 |
|
| 30 |
# ================= INPUT SIZE PER MODEL =================
|
| 31 |
|
|
|
|
| 77 |
with open(label_path, "r") as f:
|
| 78 |
LABELS[crop_name] = json.load(f)
|
| 79 |
|
| 80 |
+
# ---------- PyTorch ----------
|
| 81 |
if ext == ".pth":
|
| 82 |
num_classes = len(LABELS[crop_name])
|
| 83 |
model = models.resnet18(weights=None)
|
|
|
|
| 86 |
model.eval()
|
| 87 |
PYTORCH_MODELS[crop_name] = model
|
| 88 |
|
| 89 |
+
# ---------- Keras ----------
|
| 90 |
elif ext in [".keras", ".h5"]:
|
| 91 |
KERAS_MODELS[crop_name] = tf.keras.models.load_model(
|
| 92 |
model_path,
|
| 93 |
custom_objects={
|
| 94 |
"swish": tf.keras.activations.swish,
|
| 95 |
"FixedDropout": FixedDropout,
|
| 96 |
+
"EfficientNetB3": EfficientNetB3, # corn
|
| 97 |
+
"MobileNetV2": MobileNetV2, # bean (REAL)
|
| 98 |
},
|
| 99 |
compile=False
|
| 100 |
)
|
| 101 |
|
|
|
|
| 102 |
load_models()
|
| 103 |
|
| 104 |
# ================= PREDICTION =================
|
|
|
|
| 106 |
def predict(image, crop_name):
|
| 107 |
crop_name = crop_name.strip().lower()
|
| 108 |
|
|
|
|
| 109 |
if crop_name in PYTORCH_MODELS:
|
| 110 |
model = PYTORCH_MODELS[crop_name]
|
| 111 |
labels = LABELS[crop_name]
|
|
|
|
| 112 |
tensor = preprocess_pytorch(image)
|
| 113 |
with torch.no_grad():
|
| 114 |
output = model(tensor)
|
| 115 |
probs = torch.softmax(output[0], dim=0)
|
| 116 |
idx = probs.argmax().item()
|
|
|
|
| 117 |
return labels[idx], float(probs[idx])
|
| 118 |
|
|
|
|
| 119 |
elif crop_name in KERAS_MODELS:
|
| 120 |
model = KERAS_MODELS[crop_name]
|
| 121 |
labels = LABELS[crop_name]
|
|
|
|
| 122 |
arr = preprocess_keras(image, crop_name)
|
| 123 |
preds = model.predict(arr, verbose=0)[0]
|
| 124 |
idx = int(np.argmax(preds))
|
|
|
|
| 125 |
return labels[idx], float(preds[idx])
|
| 126 |
|
| 127 |
else:
|