Inoue1 commited on
Commit
4e13890
·
verified ·
1 Parent(s): e70906b

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +10 -21
model.py CHANGED
@@ -6,24 +6,25 @@ import tensorflow as tf
6
  from PIL import Image
7
  from torchvision import models, transforms
8
 
9
- # EfficientNet (needed ONLY for corn)
10
- from efficientnet.tfkeras import EfficientNetB3
11
-
12
  # ================= PATHS =================
13
 
14
  BASE_DIR = os.path.dirname(__file__)
15
  MODELS_DIR = os.path.join(BASE_DIR, "models")
16
  LABELS_DIR = os.path.join(BASE_DIR, "labels")
17
 
18
- # ================= FIX: EfficientNet Custom Layer =================
19
 
20
  @tf.keras.utils.register_keras_serializable()
21
  class FixedDropout(tf.keras.layers.Dropout):
22
  def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
23
  super().__init__(rate, noise_shape=noise_shape, seed=seed, **kwargs)
24
 
 
 
 
 
 
25
  # ================= INPUT SIZE PER MODEL =================
26
- # Only corn differs — others remain 224
27
 
28
  KERAS_INPUT_SIZES = {
29
  "corn": 300,
@@ -55,13 +56,11 @@ PYTORCH_MODELS = {}
55
  KERAS_MODELS = {}
56
  LABELS = {}
57
 
58
- # ================= LOAD MODELS ONCE =================
59
 
60
  def load_models():
61
  for file in os.listdir(MODELS_DIR):
62
  name, ext = os.path.splitext(file)
63
-
64
- # Normalize crop key (banana_model -> banana)
65
  crop_name = name.replace("_model", "").lower()
66
 
67
  model_path = os.path.join(MODELS_DIR, file)
@@ -73,31 +72,27 @@ def load_models():
73
  with open(label_path, "r") as f:
74
  LABELS[crop_name] = json.load(f)
75
 
76
- # ---------- PyTorch (.pth) ----------
77
  if ext == ".pth":
78
  num_classes = len(LABELS[crop_name])
79
-
80
  model = models.resnet18(weights=None)
81
  model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
82
  model.load_state_dict(torch.load(model_path, map_location="cpu"))
83
  model.eval()
84
-
85
  PYTORCH_MODELS[crop_name] = model
86
 
87
- # ---------- Keras (.keras / .h5) ----------
88
  elif ext in [".keras", ".h5"]:
89
  KERAS_MODELS[crop_name] = tf.keras.models.load_model(
90
  model_path,
91
  custom_objects={
92
  "swish": tf.keras.activations.swish,
93
  "FixedDropout": FixedDropout,
94
- # Needed ONLY for corn, harmless for others
95
  "EfficientNetB3": EfficientNetB3,
96
  },
97
- compile=False # IMPORTANT for HF
98
  )
99
 
100
- # Load models at startup
101
  load_models()
102
 
103
  # ================= PREDICTION =================
@@ -105,28 +100,22 @@ load_models()
105
  def predict(image, crop_name):
106
  crop_name = crop_name.strip().lower()
107
 
108
- # -------- PyTorch --------
109
  if crop_name in PYTORCH_MODELS:
110
  model = PYTORCH_MODELS[crop_name]
111
  labels = LABELS[crop_name]
112
-
113
  tensor = preprocess_pytorch(image)
114
  with torch.no_grad():
115
  output = model(tensor)
116
  probs = torch.softmax(output[0], dim=0)
117
  idx = probs.argmax().item()
118
-
119
  return labels[idx], float(probs[idx])
120
 
121
- # -------- Keras --------
122
  elif crop_name in KERAS_MODELS:
123
  model = KERAS_MODELS[crop_name]
124
  labels = LABELS[crop_name]
125
-
126
  arr = preprocess_keras(image, crop_name)
127
  preds = model.predict(arr, verbose=0)[0]
128
  idx = int(np.argmax(preds))
129
-
130
  return labels[idx], float(preds[idx])
131
 
132
  else:
 
6
  from PIL import Image
7
  from torchvision import models, transforms
8
 
 
 
 
9
  # ================= PATHS =================
10
 
11
  BASE_DIR = os.path.dirname(__file__)
12
  MODELS_DIR = os.path.join(BASE_DIR, "models")
13
  LABELS_DIR = os.path.join(BASE_DIR, "labels")
14
 
15
+ # ================= TF CUSTOM OBJECTS =================
16
 
17
  @tf.keras.utils.register_keras_serializable()
18
  class FixedDropout(tf.keras.layers.Dropout):
19
  def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
20
  super().__init__(rate, noise_shape=noise_shape, seed=seed, **kwargs)
21
 
22
+ # Dummy class to satisfy EfficientNet deserialization
23
+ @tf.keras.utils.register_keras_serializable()
24
+ class EfficientNetB3(tf.keras.Model):
25
+ pass
26
+
27
  # ================= INPUT SIZE PER MODEL =================
 
28
 
29
  KERAS_INPUT_SIZES = {
30
  "corn": 300,
 
56
  KERAS_MODELS = {}
57
  LABELS = {}
58
 
59
+ # ================= LOAD MODELS =================
60
 
61
  def load_models():
62
  for file in os.listdir(MODELS_DIR):
63
  name, ext = os.path.splitext(file)
 
 
64
  crop_name = name.replace("_model", "").lower()
65
 
66
  model_path = os.path.join(MODELS_DIR, file)
 
72
  with open(label_path, "r") as f:
73
  LABELS[crop_name] = json.load(f)
74
 
75
+ # ---------- PyTorch ----------
76
  if ext == ".pth":
77
  num_classes = len(LABELS[crop_name])
 
78
  model = models.resnet18(weights=None)
79
  model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
80
  model.load_state_dict(torch.load(model_path, map_location="cpu"))
81
  model.eval()
 
82
  PYTORCH_MODELS[crop_name] = model
83
 
84
+ # ---------- Keras ----------
85
  elif ext in [".keras", ".h5"]:
86
  KERAS_MODELS[crop_name] = tf.keras.models.load_model(
87
  model_path,
88
  custom_objects={
89
  "swish": tf.keras.activations.swish,
90
  "FixedDropout": FixedDropout,
 
91
  "EfficientNetB3": EfficientNetB3,
92
  },
93
+ compile=False
94
  )
95
 
 
96
  load_models()
97
 
98
  # ================= PREDICTION =================
 
100
  def predict(image, crop_name):
101
  crop_name = crop_name.strip().lower()
102
 
 
103
  if crop_name in PYTORCH_MODELS:
104
  model = PYTORCH_MODELS[crop_name]
105
  labels = LABELS[crop_name]
 
106
  tensor = preprocess_pytorch(image)
107
  with torch.no_grad():
108
  output = model(tensor)
109
  probs = torch.softmax(output[0], dim=0)
110
  idx = probs.argmax().item()
 
111
  return labels[idx], float(probs[idx])
112
 
 
113
  elif crop_name in KERAS_MODELS:
114
  model = KERAS_MODELS[crop_name]
115
  labels = LABELS[crop_name]
 
116
  arr = preprocess_keras(image, crop_name)
117
  preds = model.predict(arr, verbose=0)[0]
118
  idx = int(np.argmax(preds))
 
119
  return labels[idx], float(preds[idx])
120
 
121
  else: