fatimaxa commited on
Commit
7a068ee
·
verified ·
1 Parent(s): 3d05927

Update utils/model_loader.py

Browse files
Files changed (1) hide show
  1. utils/model_loader.py +11 -17
utils/model_loader.py CHANGED
@@ -2,19 +2,21 @@ import torch
2
  import json
3
  import os
4
  from models.model import PlantCNN
 
5
 
6
  def load_model_and_config():
7
- """Load the trained model and all configuration files"""
8
 
9
  # Paths
10
  MODEL_PATH = "saved_models/plant_cnn.pt"
11
  CLASS_NAMES_PATH = "ui_text/class_names.json"
12
-
13
- # Config
 
14
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
- CHANNELS = [96, 192, 384, 768]
16
- DROPOUT = 0.4
17
- NUM_CLASSES = 39
 
18
 
19
  # Load class names
20
  with open(CLASS_NAMES_PATH, "r") as f:
@@ -28,9 +30,11 @@ def load_model_and_config():
28
  model = PlantCNN(num_classes=NUM_CLASSES, channels=CHANNELS, dropout=DROPOUT).to(DEVICE)
29
  if os.path.exists(MODEL_PATH):
30
  print("Loading trained model weights...")
31
- model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
 
32
  model.eval()
33
  else:
 
34
  exit()
35
 
36
  return {
@@ -39,13 +43,3 @@ def load_model_and_config():
39
  'disease_db': disease_db,
40
  'device': DEVICE
41
  }
42
-
43
- def load_ui_text():
44
- """Load intro and about markdown files"""
45
- with open("ui_text/intro.md", "r", encoding="utf-8") as f:
46
- intro_md = f.read()
47
-
48
- with open("ui_text/about.md", "r", encoding="utf-8") as f:
49
- about_md = f.read()
50
-
51
- return intro_md, about_md
 
2
  import json
3
  import os
4
  from models.model import PlantCNN
5
+ from utils.config import load_config
6
 
7
  def load_model_and_config():
 
8
 
9
  # Paths
10
  MODEL_PATH = "saved_models/plant_cnn.pt"
11
  CLASS_NAMES_PATH = "ui_text/class_names.json"
12
+
13
+ # Load config
14
+ config = load_config()
15
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+ CHANNELS = config["channels"]
18
+ DROPOUT = config["dropout"]
19
+ NUM_CLASSES = config["num_classes"]
20
 
21
  # Load class names
22
  with open(CLASS_NAMES_PATH, "r") as f:
 
30
  model = PlantCNN(num_classes=NUM_CLASSES, channels=CHANNELS, dropout=DROPOUT).to(DEVICE)
31
  if os.path.exists(MODEL_PATH):
32
  print("Loading trained model weights...")
33
+ state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
34
+ model.load_state_dict(state_dict)
35
  model.eval()
36
  else:
37
+ print("Model file not found at", MODEL_PATH)
38
  exit()
39
 
40
  return {
 
43
  'disease_db': disease_db,
44
  'device': DEVICE
45
  }