Inoue1 commited on
Commit
f2bbf2c
·
verified ·
1 Parent(s): bd1516d

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +6 -15
model.py CHANGED
@@ -19,15 +19,13 @@ class FixedDropout(tf.keras.layers.Dropout):
19
  def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
20
  super().__init__(rate, noise_shape=noise_shape, seed=seed, **kwargs)
21
 
22
- # EfficientNetB3
23
  @tf.keras.utils.register_keras_serializable()
24
  class EfficientNetB3(tf.keras.Model):
25
  pass
26
 
27
  # MobileNetV2
28
- @tf.keras.utils.register_keras_serializable()
29
- class MobileNetV2(tf.keras.Model):
30
- pass
31
 
32
  # ================= INPUT SIZE PER MODEL =================
33
 
@@ -79,7 +77,7 @@ def load_models():
79
  with open(label_path, "r") as f:
80
  LABELS[crop_name] = json.load(f)
81
 
82
- # ---------- PyTorch (.pth) ----------
83
  if ext == ".pth":
84
  num_classes = len(LABELS[crop_name])
85
  model = models.resnet18(weights=None)
@@ -88,20 +86,19 @@ def load_models():
88
  model.eval()
89
  PYTORCH_MODELS[crop_name] = model
90
 
91
- # ---------- Keras (.keras / .h5) ----------
92
  elif ext in [".keras", ".h5"]:
93
  KERAS_MODELS[crop_name] = tf.keras.models.load_model(
94
  model_path,
95
  custom_objects={
96
  "swish": tf.keras.activations.swish,
97
  "FixedDropout": FixedDropout,
98
- "EfficientNetB3": EfficientNetB3,
99
- "MobileNetV2": MobileNetV2,
100
  },
101
  compile=False
102
  )
103
 
104
- # Load once at startup
105
  load_models()
106
 
107
  # ================= PREDICTION =================
@@ -109,28 +106,22 @@ load_models()
109
  def predict(image, crop_name):
110
  crop_name = crop_name.strip().lower()
111
 
112
- # -------- PyTorch --------
113
  if crop_name in PYTORCH_MODELS:
114
  model = PYTORCH_MODELS[crop_name]
115
  labels = LABELS[crop_name]
116
-
117
  tensor = preprocess_pytorch(image)
118
  with torch.no_grad():
119
  output = model(tensor)
120
  probs = torch.softmax(output[0], dim=0)
121
  idx = probs.argmax().item()
122
-
123
  return labels[idx], float(probs[idx])
124
 
125
- # -------- Keras --------
126
  elif crop_name in KERAS_MODELS:
127
  model = KERAS_MODELS[crop_name]
128
  labels = LABELS[crop_name]
129
-
130
  arr = preprocess_keras(image, crop_name)
131
  preds = model.predict(arr, verbose=0)[0]
132
  idx = int(np.argmax(preds))
133
-
134
  return labels[idx], float(preds[idx])
135
 
136
  else:
 
19
  def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
20
  super().__init__(rate, noise_shape=noise_shape, seed=seed, **kwargs)
21
 
22
+ # EfficientNet
23
  @tf.keras.utils.register_keras_serializable()
24
  class EfficientNetB3(tf.keras.Model):
25
  pass
26
 
27
  # MobileNetV2
28
+ from tensorflow.keras.applications import MobileNetV2
 
 
29
 
30
  # ================= INPUT SIZE PER MODEL =================
31
 
 
77
  with open(label_path, "r") as f:
78
  LABELS[crop_name] = json.load(f)
79
 
80
+ # ---------- PyTorch ----------
81
  if ext == ".pth":
82
  num_classes = len(LABELS[crop_name])
83
  model = models.resnet18(weights=None)
 
86
  model.eval()
87
  PYTORCH_MODELS[crop_name] = model
88
 
89
+ # ---------- Keras ----------
90
  elif ext in [".keras", ".h5"]:
91
  KERAS_MODELS[crop_name] = tf.keras.models.load_model(
92
  model_path,
93
  custom_objects={
94
  "swish": tf.keras.activations.swish,
95
  "FixedDropout": FixedDropout,
96
+ "EfficientNetB3": EfficientNetB3, # corn
97
+ "MobileNetV2": MobileNetV2, # bean (REAL)
98
  },
99
  compile=False
100
  )
101
 
 
102
  load_models()
103
 
104
  # ================= PREDICTION =================
 
106
  def predict(image, crop_name):
107
  crop_name = crop_name.strip().lower()
108
 
 
109
  if crop_name in PYTORCH_MODELS:
110
  model = PYTORCH_MODELS[crop_name]
111
  labels = LABELS[crop_name]
 
112
  tensor = preprocess_pytorch(image)
113
  with torch.no_grad():
114
  output = model(tensor)
115
  probs = torch.softmax(output[0], dim=0)
116
  idx = probs.argmax().item()
 
117
  return labels[idx], float(probs[idx])
118
 
 
119
  elif crop_name in KERAS_MODELS:
120
  model = KERAS_MODELS[crop_name]
121
  labels = LABELS[crop_name]
 
122
  arr = preprocess_keras(image, crop_name)
123
  preds = model.predict(arr, verbose=0)[0]
124
  idx = int(np.argmax(preds))
 
125
  return labels[idx], float(preds[idx])
126
 
127
  else: