Inoue1 commited on
Commit
65af2fd
·
verified ·
1 Parent(s): 1aa5542

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +16 -7
model.py CHANGED
@@ -38,22 +38,31 @@ LABELS = {}
38
  def load_models():
39
  for file in os.listdir(MODELS_DIR):
40
  name, ext = os.path.splitext(file)
 
 
 
 
41
  model_path = os.path.join(MODELS_DIR, file)
 
 
 
 
 
 
42
 
43
- # Load labels
44
- with open(os.path.join(LABELS_DIR, f"{name}.json")) as f:
45
- LABELS[name] = json.load(f)
46
 
47
  if ext == ".pth":
48
- num_classes = len(LABELS[name])
49
  model = models.resnet18(weights=None)
50
  model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
51
  model.load_state_dict(torch.load(model_path, map_location="cpu"))
52
  model.eval()
53
- PYTORCH_MODELS[name] = model
54
 
55
- elif ext == ".keras":
56
- KERAS_MODELS[name] = tf.keras.models.load_model(model_path)
57
 
58
  # Load once
59
  load_models()
 
38
  def load_models():
39
  for file in os.listdir(MODELS_DIR):
40
  name, ext = os.path.splitext(file)
41
+
42
+ # Normalize crop key
43
+ crop_name = name.replace("_model", "").lower()
44
+
45
  model_path = os.path.join(MODELS_DIR, file)
46
+ label_path = os.path.join(LABELS_DIR, f"{crop_name}_labels.json")
47
+
48
+ if not os.path.exists(label_path):
49
+ raise FileNotFoundError(
50
+ f"Label file missing for {crop_name}: {label_path}"
51
+ )
52
 
53
+ with open(label_path) as f:
54
+ LABELS[crop_name] = json.load(f)
 
55
 
56
  if ext == ".pth":
57
+ num_classes = len(LABELS[crop_name])
58
  model = models.resnet18(weights=None)
59
  model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
60
  model.load_state_dict(torch.load(model_path, map_location="cpu"))
61
  model.eval()
62
+ PYTORCH_MODELS[crop_name] = model
63
 
64
+ elif ext in [".keras", ".h5"]:
65
+ KERAS_MODELS[crop_name] = tf.keras.models.load_model(model_path)
66
 
67
  # Load once
68
  load_models()