Inoue1 commited on
Commit
450d937
·
verified ·
1 Parent(s): 65af2fd

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +6 -6
model.py CHANGED
@@ -38,17 +38,13 @@ LABELS = {}
38
  def load_models():
39
  for file in os.listdir(MODELS_DIR):
40
  name, ext = os.path.splitext(file)
41
-
42
- # Normalize crop key
43
  crop_name = name.replace("_model", "").lower()
44
 
45
  model_path = os.path.join(MODELS_DIR, file)
46
  label_path = os.path.join(LABELS_DIR, f"{crop_name}_labels.json")
47
 
48
  if not os.path.exists(label_path):
49
- raise FileNotFoundError(
50
- f"Label file missing for {crop_name}: {label_path}"
51
- )
52
 
53
  with open(label_path) as f:
54
  LABELS[crop_name] = json.load(f)
@@ -62,7 +58,11 @@ def load_models():
62
  PYTORCH_MODELS[crop_name] = model
63
 
64
  elif ext in [".keras", ".h5"]:
65
- KERAS_MODELS[crop_name] = tf.keras.models.load_model(model_path)
 
 
 
 
66
 
67
  # Load once
68
  load_models()
 
38
  def load_models():
39
  for file in os.listdir(MODELS_DIR):
40
  name, ext = os.path.splitext(file)
 
 
41
  crop_name = name.replace("_model", "").lower()
42
 
43
  model_path = os.path.join(MODELS_DIR, file)
44
  label_path = os.path.join(LABELS_DIR, f"{crop_name}_labels.json")
45
 
46
  if not os.path.exists(label_path):
47
+ raise FileNotFoundError(f"Missing labels for {crop_name}")
 
 
48
 
49
  with open(label_path) as f:
50
  LABELS[crop_name] = json.load(f)
 
58
  PYTORCH_MODELS[crop_name] = model
59
 
60
  elif ext in [".keras", ".h5"]:
61
+ KERAS_MODELS[crop_name] = tf.keras.models.load_model(
62
+ model_path,
63
+ custom_objects={"swish": tf.keras.activations.swish},
64
+ compile=False
65
+ )
66
 
67
  # Load once
68
  load_models()