Inoue1 commited on
Commit
f879f74
·
verified ·
1 Parent(s): f2bbf2c

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +46 -52
model.py CHANGED
@@ -24,16 +24,6 @@ class FixedDropout(tf.keras.layers.Dropout):
24
  class EfficientNetB3(tf.keras.Model):
25
  pass
26
 
27
- # MobileNetV2
28
- from tensorflow.keras.applications import MobileNetV2
29
-
30
- # ================= INPUT SIZE PER MODEL =================
31
-
32
- KERAS_INPUT_SIZES = {
33
- "corn": 300,
34
- "bean": 224,
35
- }
36
-
37
  # ================= IMAGE PREPROCESS =================
38
 
39
  def preprocess_pytorch(img, size=224):
@@ -48,14 +38,12 @@ def preprocess_pytorch(img, size=224):
48
  ])
49
  return transform(img).unsqueeze(0)
50
 
51
- def preprocess_keras(img, crop_name):
52
- img = img.convert("RGB")
53
- size = KERAS_INPUT_SIZES.get(crop_name, 224)
54
- img = img.resize((size, size))
55
  arr = np.array(img).astype("float32") / 255.0
56
  return np.expand_dims(arr, axis=0)
57
 
58
- # ================= MODEL REGISTRIES =================
59
 
60
  PYTORCH_MODELS = {}
61
  KERAS_MODELS = {}
@@ -66,63 +54,69 @@ LABELS = {}
66
  def load_models():
67
  for file in os.listdir(MODELS_DIR):
68
  name, ext = os.path.splitext(file)
69
- crop_name = name.replace("_model", "").lower()
70
 
71
  model_path = os.path.join(MODELS_DIR, file)
72
- label_path = os.path.join(LABELS_DIR, f"{crop_name}_labels.json")
73
-
74
- if not os.path.exists(label_path):
75
- raise FileNotFoundError(f"Missing label file: {label_path}")
76
 
77
- with open(label_path, "r") as f:
78
- LABELS[crop_name] = json.load(f)
79
 
80
  # ---------- PyTorch ----------
81
  if ext == ".pth":
82
- num_classes = len(LABELS[crop_name])
83
  model = models.resnet18(weights=None)
84
- model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
85
  model.load_state_dict(torch.load(model_path, map_location="cpu"))
86
  model.eval()
87
- PYTORCH_MODELS[crop_name] = model
88
 
89
  # ---------- Keras ----------
90
  elif ext in [".keras", ".h5"]:
91
- KERAS_MODELS[crop_name] = tf.keras.models.load_model(
92
- model_path,
93
- custom_objects={
94
- "swish": tf.keras.activations.swish,
95
- "FixedDropout": FixedDropout,
96
- "EfficientNetB3": EfficientNetB3, # corn
97
- "MobileNetV2": MobileNetV2, # bean (REAL)
98
- },
99
- compile=False
100
- )
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  load_models()
103
 
104
- # ================= PREDICTION =================
105
 
106
- def predict(image, crop_name):
107
- crop_name = crop_name.strip().lower()
108
 
109
- if crop_name in PYTORCH_MODELS:
110
- model = PYTORCH_MODELS[crop_name]
111
- labels = LABELS[crop_name]
112
- tensor = preprocess_pytorch(image)
113
  with torch.no_grad():
114
- output = model(tensor)
115
- probs = torch.softmax(output[0], dim=0)
116
- idx = probs.argmax().item()
117
  return labels[idx], float(probs[idx])
118
 
119
- elif crop_name in KERAS_MODELS:
120
- model = KERAS_MODELS[crop_name]
121
- labels = LABELS[crop_name]
122
- arr = preprocess_keras(image, crop_name)
123
- preds = model.predict(arr, verbose=0)[0]
 
124
  idx = int(np.argmax(preds))
125
  return labels[idx], float(preds[idx])
126
 
127
- else:
128
- raise ValueError(f"No model found for crop: {crop_name}")
 
24
  class EfficientNetB3(tf.keras.Model):
25
  pass
26
 
 
 
 
 
 
 
 
 
 
 
27
  # ================= IMAGE PREPROCESS =================
28
 
29
  def preprocess_pytorch(img, size=224):
 
38
  ])
39
  return transform(img).unsqueeze(0)
40
 
41
+ def preprocess_keras(img, size):
42
+ img = img.convert("RGB").resize((size, size))
 
 
43
  arr = np.array(img).astype("float32") / 255.0
44
  return np.expand_dims(arr, axis=0)
45
 
46
+ # ================= REGISTRIES =================
47
 
48
  PYTORCH_MODELS = {}
49
  KERAS_MODELS = {}
 
54
  def load_models():
55
  for file in os.listdir(MODELS_DIR):
56
  name, ext = os.path.splitext(file)
57
+ crop = name.replace("_model", "").lower()
58
 
59
  model_path = os.path.join(MODELS_DIR, file)
60
+ label_path = os.path.join(LABELS_DIR, f"{crop}_labels.json")
 
 
 
61
 
62
+ with open(label_path) as f:
63
+ LABELS[crop] = json.load(f)
64
 
65
  # ---------- PyTorch ----------
66
  if ext == ".pth":
 
67
  model = models.resnet18(weights=None)
68
+ model.fc = torch.nn.Linear(model.fc.in_features, len(LABELS[crop]))
69
  model.load_state_dict(torch.load(model_path, map_location="cpu"))
70
  model.eval()
71
+ PYTORCH_MODELS[crop] = model
72
 
73
  # ---------- Keras ----------
74
  elif ext in [".keras", ".h5"]:
75
+
76
+ # Bean model — load clean
77
+ if crop == "bean":
78
+ model = tf.keras.models.load_model(
79
+ model_path,
80
+ compile=False
81
+ )
82
+
83
+ # Corn model — needs EfficientNet
84
+ else:
85
+ model = tf.keras.models.load_model(
86
+ model_path,
87
+ custom_objects={
88
+ "FixedDropout": FixedDropout,
89
+ "EfficientNetB3": EfficientNetB3,
90
+ "swish": tf.keras.activations.swish,
91
+ },
92
+ compile=False
93
+ )
94
+
95
+ KERAS_MODELS[crop] = model
96
 
97
  load_models()
98
 
99
+ # ================= PREDICT =================
100
 
101
+ def predict(image, crop):
102
+ crop = crop.lower()
103
 
104
+ if crop in PYTORCH_MODELS:
105
+ model = PYTORCH_MODELS[crop]
106
+ labels = LABELS[crop]
107
+ x = preprocess_pytorch(image)
108
  with torch.no_grad():
109
+ probs = torch.softmax(model(x)[0], dim=0)
110
+ idx = probs.argmax().item()
 
111
  return labels[idx], float(probs[idx])
112
 
113
+ if crop in KERAS_MODELS:
114
+ model = KERAS_MODELS[crop]
115
+ labels = LABELS[crop]
116
+ size = 224 if crop == "bean" else 300
117
+ x = preprocess_keras(image, size)
118
+ preds = model.predict(x, verbose=0)[0]
119
  idx = int(np.argmax(preds))
120
  return labels[idx], float(preds[idx])
121
 
122
+ raise ValueError(f"No model found for {crop}")