Inoue1 commited on
Commit
f3591b4
·
verified ·
1 Parent(s): 450d937

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +37 -14
model.py CHANGED
@@ -6,11 +6,19 @@ import tensorflow as tf
6
  from PIL import Image
7
  from torchvision import models, transforms
8
 
9
- BASE_DIR = os.path.dirname(__file__)
10
 
 
11
  MODELS_DIR = os.path.join(BASE_DIR, "models")
12
  LABELS_DIR = os.path.join(BASE_DIR, "labels")
13
 
 
 
 
 
 
 
 
14
  # ================= IMAGE PREPROCESS =================
15
 
16
  def preprocess_pytorch(img, size=224):
@@ -26,52 +34,64 @@ def preprocess_pytorch(img, size=224):
26
 
27
  def preprocess_keras(img, size=224):
28
  img = img.resize((size, size))
29
- arr = np.array(img) / 255.0
30
  return np.expand_dims(arr, axis=0)
31
 
32
- # ================= MODEL LOADERS =================
33
 
34
  PYTORCH_MODELS = {}
35
  KERAS_MODELS = {}
36
  LABELS = {}
37
 
 
 
38
  def load_models():
39
  for file in os.listdir(MODELS_DIR):
40
  name, ext = os.path.splitext(file)
 
 
41
  crop_name = name.replace("_model", "").lower()
42
 
43
  model_path = os.path.join(MODELS_DIR, file)
44
  label_path = os.path.join(LABELS_DIR, f"{crop_name}_labels.json")
45
 
46
  if not os.path.exists(label_path):
47
- raise FileNotFoundError(f"Missing labels for {crop_name}")
48
 
49
- with open(label_path) as f:
50
  LABELS[crop_name] = json.load(f)
51
 
 
52
  if ext == ".pth":
53
  num_classes = len(LABELS[crop_name])
 
54
  model = models.resnet18(weights=None)
55
  model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
56
  model.load_state_dict(torch.load(model_path, map_location="cpu"))
57
  model.eval()
 
58
  PYTORCH_MODELS[crop_name] = model
59
 
 
60
  elif ext in [".keras", ".h5"]:
61
  KERAS_MODELS[crop_name] = tf.keras.models.load_model(
62
  model_path,
63
- custom_objects={"swish": tf.keras.activations.swish},
64
- compile=False
 
 
 
65
  )
66
 
67
- # Load once
68
  load_models()
69
 
70
- # ================= PREDICT =================
71
 
72
  def predict(image, crop_name):
73
- crop_name = crop_name.lower()
74
 
 
75
  if crop_name in PYTORCH_MODELS:
76
  model = PYTORCH_MODELS[crop_name]
77
  labels = LABELS[crop_name]
@@ -81,16 +101,19 @@ def predict(image, crop_name):
81
  output = model(tensor)
82
  probs = torch.softmax(output[0], dim=0)
83
  idx = probs.argmax().item()
84
- return labels[idx], float(probs[idx])
85
 
 
 
 
86
  elif crop_name in KERAS_MODELS:
87
  model = KERAS_MODELS[crop_name]
88
  labels = LABELS[crop_name]
89
 
90
  arr = preprocess_keras(image)
91
- preds = model.predict(arr)[0]
92
- idx = np.argmax(preds)
 
93
  return labels[idx], float(preds[idx])
94
 
95
  else:
96
- raise ValueError(f"No model found for crop: {crop_name}")
 
6
  from PIL import Image
7
  from torchvision import models, transforms
8
 
9
+ # ================= PATHS =================
10
 
11
+ BASE_DIR = os.path.dirname(__file__)
12
  MODELS_DIR = os.path.join(BASE_DIR, "models")
13
  LABELS_DIR = os.path.join(BASE_DIR, "labels")
14
 
15
+ # ================= FIX: EfficientNet Custom Layer =================
16
+
17
+ @tf.keras.utils.register_keras_serializable()
18
+ class FixedDropout(tf.keras.layers.Dropout):
19
+ def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
20
+ super().__init__(rate, noise_shape=noise_shape, seed=seed, **kwargs)
21
+
22
  # ================= IMAGE PREPROCESS =================
23
 
24
  def preprocess_pytorch(img, size=224):
 
34
 
35
  def preprocess_keras(img, size=224):
36
  img = img.resize((size, size))
37
+ arr = np.array(img).astype("float32") / 255.0
38
  return np.expand_dims(arr, axis=0)
39
 
40
+ # ================= MODEL REGISTRIES =================
41
 
42
  PYTORCH_MODELS = {}
43
  KERAS_MODELS = {}
44
  LABELS = {}
45
 
46
+ # ================= LOAD MODELS ONCE =================
47
+
48
  def load_models():
49
  for file in os.listdir(MODELS_DIR):
50
  name, ext = os.path.splitext(file)
51
+
52
+ # Normalize crop key (banana_model -> banana)
53
  crop_name = name.replace("_model", "").lower()
54
 
55
  model_path = os.path.join(MODELS_DIR, file)
56
  label_path = os.path.join(LABELS_DIR, f"{crop_name}_labels.json")
57
 
58
  if not os.path.exists(label_path):
59
+ raise FileNotFoundError(f"Missing label file: {label_path}")
60
 
61
+ with open(label_path, "r") as f:
62
  LABELS[crop_name] = json.load(f)
63
 
64
+ # ---------- PyTorch (.pth) ----------
65
  if ext == ".pth":
66
  num_classes = len(LABELS[crop_name])
67
+
68
  model = models.resnet18(weights=None)
69
  model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
70
  model.load_state_dict(torch.load(model_path, map_location="cpu"))
71
  model.eval()
72
+
73
  PYTORCH_MODELS[crop_name] = model
74
 
75
+ # ---------- Keras (.keras / .h5) ----------
76
  elif ext in [".keras", ".h5"]:
77
  KERAS_MODELS[crop_name] = tf.keras.models.load_model(
78
  model_path,
79
+ custom_objects={
80
+ "swish": tf.keras.activations.swish,
81
+ "FixedDropout": FixedDropout,
82
+ },
83
+ compile=False # IMPORTANT for HF
84
  )
85
 
86
+ # Load models at startup
87
  load_models()
88
 
89
+ # ================= PREDICTION =================
90
 
91
  def predict(image, crop_name):
92
+ crop_name = crop_name.strip().lower()
93
 
94
+ # -------- PyTorch --------
95
  if crop_name in PYTORCH_MODELS:
96
  model = PYTORCH_MODELS[crop_name]
97
  labels = LABELS[crop_name]
 
101
  output = model(tensor)
102
  probs = torch.softmax(output[0], dim=0)
103
  idx = probs.argmax().item()
 
104
 
105
+ return labels[idx], float(probs[idx])
106
+
107
+ # -------- Keras --------
108
  elif crop_name in KERAS_MODELS:
109
  model = KERAS_MODELS[crop_name]
110
  labels = LABELS[crop_name]
111
 
112
  arr = preprocess_keras(image)
113
+ preds = model.predict(arr, verbose=0)[0]
114
+ idx = int(np.argmax(preds))
115
+
116
  return labels[idx], float(preds[idx])
117
 
118
  else:
119
+ raise ValueError(f"No model found for crop: {crop_name}")