Inoue1 commited on
Commit
bd1516d
·
verified ·
1 Parent(s): 4da8345

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +17 -3
model.py CHANGED
@@ -19,11 +19,16 @@ class FixedDropout(tf.keras.layers.Dropout):
19
  def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
20
  super().__init__(rate, noise_shape=noise_shape, seed=seed, **kwargs)
21
 
22
- # Dummy class to satisfy EfficientNet deserialization
23
  @tf.keras.utils.register_keras_serializable()
24
  class EfficientNetB3(tf.keras.Model):
25
  pass
26
 
 
 
 
 
 
27
  # ================= INPUT SIZE PER MODEL =================
28
 
29
  KERAS_INPUT_SIZES = {
@@ -34,6 +39,7 @@ KERAS_INPUT_SIZES = {
34
  # ================= IMAGE PREPROCESS =================
35
 
36
  def preprocess_pytorch(img, size=224):
 
37
  transform = transforms.Compose([
38
  transforms.Resize((size, size)),
39
  transforms.ToTensor(),
@@ -73,7 +79,7 @@ def load_models():
73
  with open(label_path, "r") as f:
74
  LABELS[crop_name] = json.load(f)
75
 
76
- # ---------- PyTorch ----------
77
  if ext == ".pth":
78
  num_classes = len(LABELS[crop_name])
79
  model = models.resnet18(weights=None)
@@ -82,7 +88,7 @@ def load_models():
82
  model.eval()
83
  PYTORCH_MODELS[crop_name] = model
84
 
85
- # ---------- Keras ----------
86
  elif ext in [".keras", ".h5"]:
87
  KERAS_MODELS[crop_name] = tf.keras.models.load_model(
88
  model_path,
@@ -90,10 +96,12 @@ def load_models():
90
  "swish": tf.keras.activations.swish,
91
  "FixedDropout": FixedDropout,
92
  "EfficientNetB3": EfficientNetB3,
 
93
  },
94
  compile=False
95
  )
96
 
 
97
  load_models()
98
 
99
  # ================= PREDICTION =================
@@ -101,22 +109,28 @@ load_models()
101
  def predict(image, crop_name):
102
  crop_name = crop_name.strip().lower()
103
 
 
104
  if crop_name in PYTORCH_MODELS:
105
  model = PYTORCH_MODELS[crop_name]
106
  labels = LABELS[crop_name]
 
107
  tensor = preprocess_pytorch(image)
108
  with torch.no_grad():
109
  output = model(tensor)
110
  probs = torch.softmax(output[0], dim=0)
111
  idx = probs.argmax().item()
 
112
  return labels[idx], float(probs[idx])
113
 
 
114
  elif crop_name in KERAS_MODELS:
115
  model = KERAS_MODELS[crop_name]
116
  labels = LABELS[crop_name]
 
117
  arr = preprocess_keras(image, crop_name)
118
  preds = model.predict(arr, verbose=0)[0]
119
  idx = int(np.argmax(preds))
 
120
  return labels[idx], float(preds[idx])
121
 
122
  else:
 
19
  def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
20
  super().__init__(rate, noise_shape=noise_shape, seed=seed, **kwargs)
21
 
22
+ # EfficientNetB3
23
  @tf.keras.utils.register_keras_serializable()
24
  class EfficientNetB3(tf.keras.Model):
25
  pass
26
 
27
+ # MobileNetV2
28
+ @tf.keras.utils.register_keras_serializable()
29
+ class MobileNetV2(tf.keras.Model):
30
+ pass
31
+
32
  # ================= INPUT SIZE PER MODEL =================
33
 
34
  KERAS_INPUT_SIZES = {
 
39
  # ================= IMAGE PREPROCESS =================
40
 
41
  def preprocess_pytorch(img, size=224):
42
+ img = img.convert("RGB")
43
  transform = transforms.Compose([
44
  transforms.Resize((size, size)),
45
  transforms.ToTensor(),
 
79
  with open(label_path, "r") as f:
80
  LABELS[crop_name] = json.load(f)
81
 
82
+ # ---------- PyTorch (.pth) ----------
83
  if ext == ".pth":
84
  num_classes = len(LABELS[crop_name])
85
  model = models.resnet18(weights=None)
 
88
  model.eval()
89
  PYTORCH_MODELS[crop_name] = model
90
 
91
+ # ---------- Keras (.keras / .h5) ----------
92
  elif ext in [".keras", ".h5"]:
93
  KERAS_MODELS[crop_name] = tf.keras.models.load_model(
94
  model_path,
 
96
  "swish": tf.keras.activations.swish,
97
  "FixedDropout": FixedDropout,
98
  "EfficientNetB3": EfficientNetB3,
99
+ "MobileNetV2": MobileNetV2,
100
  },
101
  compile=False
102
  )
103
 
104
+ # Load once at startup
105
  load_models()
106
 
107
  # ================= PREDICTION =================
 
109
  def predict(image, crop_name):
110
  crop_name = crop_name.strip().lower()
111
 
112
+ # -------- PyTorch --------
113
  if crop_name in PYTORCH_MODELS:
114
  model = PYTORCH_MODELS[crop_name]
115
  labels = LABELS[crop_name]
116
+
117
  tensor = preprocess_pytorch(image)
118
  with torch.no_grad():
119
  output = model(tensor)
120
  probs = torch.softmax(output[0], dim=0)
121
  idx = probs.argmax().item()
122
+
123
  return labels[idx], float(probs[idx])
124
 
125
+ # -------- Keras --------
126
  elif crop_name in KERAS_MODELS:
127
  model = KERAS_MODELS[crop_name]
128
  labels = LABELS[crop_name]
129
+
130
  arr = preprocess_keras(image, crop_name)
131
  preds = model.predict(arr, verbose=0)[0]
132
  idx = int(np.argmax(preds))
133
+
134
  return labels[idx], float(preds[idx])
135
 
136
  else: