Inoue1 committed on
Commit
010286a
·
verified ·
1 Parent(s): 0cf105d

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +54 -65
model.py CHANGED
@@ -16,29 +16,25 @@ LABELS_DIR = os.path.join(BASE_DIR, "labels")
16
 
17
  @tf.keras.utils.register_keras_serializable()
18
  class FixedDropout(tf.keras.layers.Dropout):
19
- def __init__(self, rate, **kwargs):
20
- super().__init__(rate, **kwargs)
21
 
 
22
  @tf.keras.utils.register_keras_serializable()
23
  class EfficientNetB3(tf.keras.Model):
24
  pass
25
 
26
- # ================= HARD INPUT SIZES =================
27
 
28
- INPUT_SIZE = {
29
- "bean": 224,
30
  "corn": 300,
31
- "banana": 224,
32
- "chilli": 224,
33
- "rice": 224,
34
  }
35
 
36
  # ================= IMAGE PREPROCESS =================
37
 
38
- def preprocess_pytorch(img):
39
- img = img.convert("RGB")
40
  transform = transforms.Compose([
41
- transforms.Resize((224, 224)),
42
  transforms.ToTensor(),
43
  transforms.Normalize(
44
  mean=[0.485, 0.456, 0.406],
@@ -47,13 +43,14 @@ def preprocess_pytorch(img):
47
  ])
48
  return transform(img).unsqueeze(0)
49
 
50
- def preprocess_keras(img, crop):
51
- size = INPUT_SIZE[crop]
52
- img = img.convert("RGB").resize((size, size))
 
53
  arr = np.array(img).astype("float32") / 255.0
54
  return np.expand_dims(arr, axis=0)
55
 
56
- # ================= REGISTRIES =================
57
 
58
  PYTORCH_MODELS = {}
59
  KERAS_MODELS = {}
@@ -64,71 +61,63 @@ LABELS = {}
64
  def load_models():
65
  for file in os.listdir(MODELS_DIR):
66
  name, ext = os.path.splitext(file)
67
- crop = name.replace("_model", "").lower()
68
-
69
- with open(os.path.join(LABELS_DIR, f"{crop}_labels.json")) as f:
70
- LABELS[crop] = json.load(f)
71
 
72
  model_path = os.path.join(MODELS_DIR, file)
 
 
 
 
 
 
 
73
 
74
  # ---------- PyTorch ----------
75
  if ext == ".pth":
 
76
  model = models.resnet18(weights=None)
77
- model.fc = torch.nn.Linear(model.fc.in_features, len(LABELS[crop]))
78
  model.load_state_dict(torch.load(model_path, map_location="cpu"))
79
  model.eval()
80
- PYTORCH_MODELS[crop] = model
81
 
82
  # ---------- Keras ----------
83
  elif ext in [".keras", ".h5"]:
84
-
85
- if crop == "bean":
86
- # NO custom_objects
87
- model = tf.keras.models.load_model(
88
- model_path,
89
- compile=False
90
- )
91
-
92
- elif crop == "corn":
93
- model = tf.keras.models.load_model(
94
- model_path,
95
- custom_objects={
96
- "FixedDropout": FixedDropout,
97
- "EfficientNetB3": EfficientNetB3,
98
- "swish": tf.keras.activations.swish,
99
- },
100
- compile=False
101
- )
102
- else:
103
- model = tf.keras.models.load_model(model_path, compile=False)
104
-
105
- KERAS_MODELS[crop] = model
106
-
107
  load_models()
108
 
109
- # ================= PREDICT =================
110
 
111
- def predict(image, crop):
112
- crop = crop.lower()
113
 
114
- if crop in PYTORCH_MODELS:
115
- model = PYTORCH_MODELS[crop]
116
- x = preprocess_pytorch(image)
 
117
  with torch.no_grad():
118
- probs = torch.softmax(model(x)[0], dim=0)
119
- idx = probs.argmax().item()
120
- return LABELS[crop][idx], float(probs[idx])
121
-
122
- if crop in KERAS_MODELS:
123
- model = KERAS_MODELS[crop]
124
- x = preprocess_keras(image, crop)
125
-
126
- # SAFETY ASSERT
127
- expected = INPUT_SIZE[crop]
128
- assert x.shape[1] == expected, f"{crop} received wrong input size!"
129
-
130
- preds = model.predict(x, verbose=0)[0]
131
  idx = int(np.argmax(preds))
132
- return LABELS[crop][idx], float(preds[idx])
133
 
134
- raise ValueError(f"No model found for {crop}")
 
 
16
 
17
@tf.keras.utils.register_keras_serializable()
class FixedDropout(tf.keras.layers.Dropout):
    """Drop-in alias of keras Dropout, registered under the name that
    EfficientNet checkpoints reference during deserialization."""

    def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
        super().__init__(rate, noise_shape=noise_shape, seed=seed, **kwargs)


# Dummy class to satisfy EfficientNet deserialization
@tf.keras.utils.register_keras_serializable()
class EfficientNetB3(tf.keras.Model):
    """Placeholder so saved models that serialized this class name load."""
    pass
26
 
27
# ================= INPUT SIZE PER MODEL =================
# Keras models default to 224x224 input; only exceptions are listed here.
KERAS_INPUT_SIZES = {
    "corn": 300,
}
32
 
33
  # ================= IMAGE PREPROCESS =================
34
 
35
+ def preprocess_pytorch(img, size=224):
 
36
  transform = transforms.Compose([
37
+ transforms.Resize((size, size)),
38
  transforms.ToTensor(),
39
  transforms.Normalize(
40
  mean=[0.485, 0.456, 0.406],
 
43
  ])
44
  return transform(img).unsqueeze(0)
45
 
46
def preprocess_keras(img, crop_name):
    """Turn a PIL image into a normalized float32 batch for a Keras model.

    The side length comes from KERAS_INPUT_SIZES (224 when the crop has no
    entry). Returns an array of shape (1, size, size, 3) scaled to [0, 1].
    """
    side = KERAS_INPUT_SIZES.get(crop_name, 224)
    rgb = img.convert("RGB").resize((side, side))
    batch = np.asarray(rgb, dtype="float32") / 255.0
    return batch[np.newaxis, ...]
52
 
53
# ================= MODEL REGISTRIES =================
# Filled in by load_models(); keyed by lowercase crop name.
PYTORCH_MODELS = {}
KERAS_MODELS = {}
 
def load_models():
    """Scan MODELS_DIR and register each model plus its label list.

    Supported model files: ``.pth`` (PyTorch ResNet-18 state dict) and
    ``.keras``/``.h5`` (Keras). For each model ``<crop>_model.<ext>`` the
    matching ``<crop>_labels.json`` must exist in LABELS_DIR.

    Raises:
        FileNotFoundError: when a model file has no matching label file.
    """
    for file in os.listdir(MODELS_DIR):
        name, ext = os.path.splitext(file)

        # Skip anything that is not a recognized model file (e.g. .gitkeep,
        # READMEs) — previously labels were loaded for EVERY directory entry,
        # so a stray file raised a spurious missing-label error.
        if ext not in (".pth", ".keras", ".h5"):
            continue

        crop_name = name.replace("_model", "").lower()
        model_path = os.path.join(MODELS_DIR, file)
        label_path = os.path.join(LABELS_DIR, f"{crop_name}_labels.json")

        if not os.path.exists(label_path):
            raise FileNotFoundError(f"Missing label file: {label_path}")

        with open(label_path, "r") as f:
            LABELS[crop_name] = json.load(f)

        # ---------- PyTorch ----------
        if ext == ".pth":
            num_classes = len(LABELS[crop_name])
            model = models.resnet18(weights=None)
            model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
            model.load_state_dict(torch.load(model_path, map_location="cpu"))
            model.eval()  # inference mode: fixes dropout / batch-norm stats
            PYTORCH_MODELS[crop_name] = model

        # ---------- Keras ----------
        else:
            # custom_objects supplies the serialization shims defined above.
            KERAS_MODELS[crop_name] = tf.keras.models.load_model(
                model_path,
                custom_objects={
                    "swish": tf.keras.activations.swish,
                    "FixedDropout": FixedDropout,
                    "EfficientNetB3": EfficientNetB3,
                },
                compile=False,
            )
95
+
96
# Populate the model/label registries once, at import time.
load_models()
98
 
99
# ================= PREDICTION =================

def predict(image, crop_name):
    """Classify *image* with the model registered for *crop_name*.

    Returns:
        (label, confidence) — the top-1 class name and its probability.

    Raises:
        ValueError: when no model is registered for the crop.
    """
    key = crop_name.strip().lower()

    if key in PYTORCH_MODELS:
        net = PYTORCH_MODELS[key]
        batch = preprocess_pytorch(image)
        with torch.no_grad():
            logits = net(batch)
        scores = torch.softmax(logits[0], dim=0)
        best = scores.argmax().item()
        return LABELS[key][best], float(scores[best])

    if key in KERAS_MODELS:
        net = KERAS_MODELS[key]
        scores = net.predict(preprocess_keras(image, key), verbose=0)[0]
        best = int(np.argmax(scores))
        return LABELS[key][best], float(scores[best])

    raise ValueError(f"No model found for crop: {key}")