fatimaxa commited on
Commit
64515d2
·
verified ·
1 Parent(s): 7a068ee

Update utils/model_loader.py

Browse files
Files changed (1) hide show
  1. utils/model_loader.py +38 -23
utils/model_loader.py CHANGED
@@ -2,44 +2,59 @@ import torch
2
  import json
3
  import os
4
  from models.model import PlantCNN
5
- from utils.config import load_config
 
6
 
7
  def load_model_and_config():
8
-
9
- # Paths
10
  MODEL_PATH = "saved_models/plant_cnn.pt"
11
  CLASS_NAMES_PATH = "ui_text/class_names.json"
 
12
 
13
- # Load config
14
  config = load_config()
15
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
- CHANNELS = config["channels"]
18
- DROPOUT = config["dropout"]
19
- NUM_CLASSES = config["num_classes"]
20
-
21
- # Load class names
22
- with open(CLASS_NAMES_PATH, "r") as f:
23
  class_names = json.load(f)
24
-
25
- # Load disease info
26
- with open("ui_text/disease_info.json", "r", encoding="utf-8") as f:
27
  disease_db = json.load(f)
28
-
29
- # Load model
30
- model = PlantCNN(num_classes=NUM_CLASSES, channels=CHANNELS, dropout=DROPOUT).to(DEVICE)
 
 
 
 
31
  if os.path.exists(MODEL_PATH):
32
  print("Loading trained model weights...")
33
  state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
34
- model.load_state_dict(state_dict)
35
  model.eval()
36
  else:
37
- print("Model file not found at", MODEL_PATH)
38
  exit()
39
-
40
  return {
41
- 'model': model,
42
- 'class_names': class_names,
43
- 'disease_db': disease_db,
44
- 'device': DEVICE
45
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import json
3
  import os
4
  from models.model import PlantCNN
5
+ from utils.config import load_config
6
+
7
 
8
  def load_model_and_config():
9
+ """Load the trained model and associated config + metadata"""
10
+
11
  MODEL_PATH = "saved_models/plant_cnn.pt"
12
  CLASS_NAMES_PATH = "ui_text/class_names.json"
13
+ DISEASE_INFO_PATH = "ui_text/disease_info.json"
14
 
 
15
  config = load_config()
16
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
 
18
+ CHANNELS = config["channels"] # e.g., [64, 128, 256, 512]
19
+ DROPOUT = config["dropout"] # 0.4
20
+ NUM_CLASSES = config["num_classes"] # 39
21
+
22
+ with open(CLASS_NAMES_PATH, "r", encoding="utf-8") as f:
 
23
  class_names = json.load(f)
24
+
25
+ with open(DISEASE_INFO_PATH, "r", encoding="utf-8") as f:
 
26
  disease_db = json.load(f)
27
+
28
+ model = PlantCNN(
29
+ num_classes=NUM_CLASSES,
30
+ channels=CHANNELS,
31
+ dropout=DROPOUT
32
+ ).to(DEVICE)
33
+
34
  if os.path.exists(MODEL_PATH):
35
  print("Loading trained model weights...")
36
  state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
37
+ model.load_state_dict(state_dict)
38
  model.eval()
39
  else:
40
+ print(f"Model file not found at {MODEL_PATH}")
41
  exit()
42
+
43
  return {
44
+ "model": model,
45
+ "class_names": class_names,
46
+ "disease_db": disease_db,
47
+ "device": DEVICE,
48
  }
49
+
50
+
51
+ def load_ui_text():
52
+ """Load intro.md and about.md for display in the UI"""
53
+
54
+ with open("ui_text/intro.md", "r", encoding="utf-8") as f:
55
+ intro_md = f.read()
56
+
57
+ with open("ui_text/about.md", "r", encoding="utf-8") as f:
58
+ about_md = f.read()
59
+
60
+ return intro_md, about_md