Inoue1 commited on
Commit
e70906b
·
verified ·
1 Parent(s): 33979f5

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +18 -4
model.py CHANGED
@@ -6,6 +6,9 @@ import tensorflow as tf
6
  from PIL import Image
7
  from torchvision import models, transforms
8
 
 
 
 
9
  # ================= PATHS =================
10
 
11
  BASE_DIR = os.path.dirname(__file__)
@@ -19,6 +22,14 @@ class FixedDropout(tf.keras.layers.Dropout):
19
  def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
20
  super().__init__(rate, noise_shape=noise_shape, seed=seed, **kwargs)
21
 
 
 
 
 
 
 
 
 
22
  # ================= IMAGE PREPROCESS =================
23
 
24
  def preprocess_pytorch(img, size=224):
@@ -32,7 +43,8 @@ def preprocess_pytorch(img, size=224):
32
  ])
33
  return transform(img).unsqueeze(0)
34
 
35
- def preprocess_keras(img, size=224):
 
36
  img = img.resize((size, size))
37
  arr = np.array(img).astype("float32") / 255.0
38
  return np.expand_dims(arr, axis=0)
@@ -56,7 +68,7 @@ def load_models():
56
  label_path = os.path.join(LABELS_DIR, f"{crop_name}_labels.json")
57
 
58
  if not os.path.exists(label_path):
59
- raise FileNotFoundError(f"Missing label file: {label_path}")
60
 
61
  with open(label_path, "r") as f:
62
  LABELS[crop_name] = json.load(f)
@@ -79,6 +91,8 @@ def load_models():
79
  custom_objects={
80
  "swish": tf.keras.activations.swish,
81
  "FixedDropout": FixedDropout,
 
 
82
  },
83
  compile=False # IMPORTANT for HF
84
  )
@@ -109,11 +123,11 @@ def predict(image, crop_name):
109
  model = KERAS_MODELS[crop_name]
110
  labels = LABELS[crop_name]
111
 
112
- arr = preprocess_keras(image)
113
  preds = model.predict(arr, verbose=0)[0]
114
  idx = int(np.argmax(preds))
115
 
116
  return labels[idx], float(preds[idx])
117
 
118
  else:
119
- raise ValueError(f"No model found for crop: {crop_name}")
 
6
  from PIL import Image
7
  from torchvision import models, transforms
8
 
9
+ # EfficientNet (needed ONLY for corn)
10
+ from efficientnet.tfkeras import EfficientNetB3
11
+
12
  # ================= PATHS =================
13
 
14
  BASE_DIR = os.path.dirname(__file__)
 
22
  def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
23
  super().__init__(rate, noise_shape=noise_shape, seed=seed, **kwargs)
24
 
25
+ # ================= INPUT SIZE PER MODEL =================
26
+ # Only corn differs — others remain 224
27
+
28
+ KERAS_INPUT_SIZES = {
29
+ "corn": 300,
30
+ "bean": 224,
31
+ }
32
+
33
  # ================= IMAGE PREPROCESS =================
34
 
35
  def preprocess_pytorch(img, size=224):
 
43
  ])
44
  return transform(img).unsqueeze(0)
45
 
46
+ def preprocess_keras(img, crop_name):
47
+ size = KERAS_INPUT_SIZES.get(crop_name, 224)
48
  img = img.resize((size, size))
49
  arr = np.array(img).astype("float32") / 255.0
50
  return np.expand_dims(arr, axis=0)
 
68
  label_path = os.path.join(LABELS_DIR, f"{crop_name}_labels.json")
69
 
70
  if not os.path.exists(label_path):
71
+ raise FileNotFoundError(f"Missing label file: {label_path}")
72
 
73
  with open(label_path, "r") as f:
74
  LABELS[crop_name] = json.load(f)
 
91
  custom_objects={
92
  "swish": tf.keras.activations.swish,
93
  "FixedDropout": FixedDropout,
94
+ # Needed ONLY for corn, harmless for others
95
+ "EfficientNetB3": EfficientNetB3,
96
  },
97
  compile=False # IMPORTANT for HF
98
  )
 
123
  model = KERAS_MODELS[crop_name]
124
  labels = LABELS[crop_name]
125
 
126
+ arr = preprocess_keras(image, crop_name)
127
  preds = model.predict(arr, verbose=0)[0]
128
  idx = int(np.argmax(preds))
129
 
130
  return labels[idx], float(preds[idx])
131
 
132
  else:
133
+ raise ValueError(f"No model found for crop: {crop_name}")