Inoue1 commited on
Commit
94a728b
·
verified ·
1 Parent(s): f879f74

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +31 -19
model.py CHANGED
@@ -16,20 +16,29 @@ LABELS_DIR = os.path.join(BASE_DIR, "labels")
16
 
17
  @tf.keras.utils.register_keras_serializable()
18
  class FixedDropout(tf.keras.layers.Dropout):
19
- def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
20
- super().__init__(rate, noise_shape=noise_shape, seed=seed, **kwargs)
21
 
22
- # EfficientNet
23
  @tf.keras.utils.register_keras_serializable()
24
  class EfficientNetB3(tf.keras.Model):
25
  pass
26
 
 
 
 
 
 
 
 
 
 
 
27
  # ================= IMAGE PREPROCESS =================
28
 
29
- def preprocess_pytorch(img, size=224):
30
  img = img.convert("RGB")
31
  transform = transforms.Compose([
32
- transforms.Resize((size, size)),
33
  transforms.ToTensor(),
34
  transforms.Normalize(
35
  mean=[0.485, 0.456, 0.406],
@@ -38,7 +47,8 @@ def preprocess_pytorch(img, size=224):
38
  ])
39
  return transform(img).unsqueeze(0)
40
 
41
- def preprocess_keras(img, size):
 
42
  img = img.convert("RGB").resize((size, size))
43
  arr = np.array(img).astype("float32") / 255.0
44
  return np.expand_dims(arr, axis=0)
@@ -56,12 +66,11 @@ def load_models():
56
  name, ext = os.path.splitext(file)
57
  crop = name.replace("_model", "").lower()
58
 
59
- model_path = os.path.join(MODELS_DIR, file)
60
- label_path = os.path.join(LABELS_DIR, f"{crop}_labels.json")
61
-
62
- with open(label_path) as f:
63
  LABELS[crop] = json.load(f)
64
 
 
 
65
  # ---------- PyTorch ----------
66
  if ext == ".pth":
67
  model = models.resnet18(weights=None)
@@ -73,15 +82,14 @@ def load_models():
73
  # ---------- Keras ----------
74
  elif ext in [".keras", ".h5"]:
75
 
76
- # Bean model — load clean
77
  if crop == "bean":
 
78
  model = tf.keras.models.load_model(
79
  model_path,
80
  compile=False
81
  )
82
 
83
- # Corn model — needs EfficientNet
84
- else:
85
  model = tf.keras.models.load_model(
86
  model_path,
87
  custom_objects={
@@ -91,6 +99,8 @@ def load_models():
91
  },
92
  compile=False
93
  )
 
 
94
 
95
  KERAS_MODELS[crop] = model
96
 
@@ -103,20 +113,22 @@ def predict(image, crop):
103
 
104
  if crop in PYTORCH_MODELS:
105
  model = PYTORCH_MODELS[crop]
106
- labels = LABELS[crop]
107
  x = preprocess_pytorch(image)
108
  with torch.no_grad():
109
  probs = torch.softmax(model(x)[0], dim=0)
110
  idx = probs.argmax().item()
111
- return labels[idx], float(probs[idx])
112
 
113
  if crop in KERAS_MODELS:
114
  model = KERAS_MODELS[crop]
115
- labels = LABELS[crop]
116
- size = 224 if crop == "bean" else 300
117
- x = preprocess_keras(image, size)
 
 
 
118
  preds = model.predict(x, verbose=0)[0]
119
  idx = int(np.argmax(preds))
120
- return labels[idx], float(preds[idx])
121
 
122
  raise ValueError(f"No model found for {crop}")
 
16
 
17
  @tf.keras.utils.register_keras_serializable()
18
  class FixedDropout(tf.keras.layers.Dropout):
19
+ def __init__(self, rate, **kwargs):
20
+ super().__init__(rate, **kwargs)
21
 
 
22
  @tf.keras.utils.register_keras_serializable()
23
  class EfficientNetB3(tf.keras.Model):
24
  pass
25
 
26
+ # ================= HARD INPUT SIZES =================
27
+
28
+ INPUT_SIZE = {
29
+ "bean": 224,
30
+ "corn": 300,
31
+ "banana": 224,
32
+ "chilli": 224,
33
+ "rice": 224,
34
+ }
35
+
36
  # ================= IMAGE PREPROCESS =================
37
 
38
+ def preprocess_pytorch(img):
39
  img = img.convert("RGB")
40
  transform = transforms.Compose([
41
+ transforms.Resize((224, 224)),
42
  transforms.ToTensor(),
43
  transforms.Normalize(
44
  mean=[0.485, 0.456, 0.406],
 
47
  ])
48
  return transform(img).unsqueeze(0)
49
 
50
+ def preprocess_keras(img, crop):
51
+ size = INPUT_SIZE[crop]
52
  img = img.convert("RGB").resize((size, size))
53
  arr = np.array(img).astype("float32") / 255.0
54
  return np.expand_dims(arr, axis=0)
 
66
  name, ext = os.path.splitext(file)
67
  crop = name.replace("_model", "").lower()
68
 
69
+ with open(os.path.join(LABELS_DIR, f"{crop}_labels.json")) as f:
 
 
 
70
  LABELS[crop] = json.load(f)
71
 
72
+ model_path = os.path.join(MODELS_DIR, file)
73
+
74
  # ---------- PyTorch ----------
75
  if ext == ".pth":
76
  model = models.resnet18(weights=None)
 
82
  # ---------- Keras ----------
83
  elif ext in [".keras", ".h5"]:
84
 
 
85
  if crop == "bean":
86
+ # NO custom_objects
87
  model = tf.keras.models.load_model(
88
  model_path,
89
  compile=False
90
  )
91
 
92
+ elif crop == "corn":
 
93
  model = tf.keras.models.load_model(
94
  model_path,
95
  custom_objects={
 
99
  },
100
  compile=False
101
  )
102
+ else:
103
+ model = tf.keras.models.load_model(model_path, compile=False)
104
 
105
  KERAS_MODELS[crop] = model
106
 
 
113
 
114
  if crop in PYTORCH_MODELS:
115
  model = PYTORCH_MODELS[crop]
 
116
  x = preprocess_pytorch(image)
117
  with torch.no_grad():
118
  probs = torch.softmax(model(x)[0], dim=0)
119
  idx = probs.argmax().item()
120
+ return LABELS[crop][idx], float(probs[idx])
121
 
122
  if crop in KERAS_MODELS:
123
  model = KERAS_MODELS[crop]
124
+ x = preprocess_keras(image, crop)
125
+
126
+ # SAFETY ASSERT
127
+ expected = INPUT_SIZE[crop]
128
+ assert x.shape[1] == expected, f"{crop} received wrong input size!"
129
+
130
  preds = model.predict(x, verbose=0)[0]
131
  idx = int(np.argmax(preds))
132
+ return LABELS[crop][idx], float(preds[idx])
133
 
134
  raise ValueError(f"No model found for {crop}")