Spaces:
Sleeping
Sleeping
Update utils/model_loader.py
Browse files- utils/model_loader.py +11 -17
utils/model_loader.py
CHANGED
|
@@ -2,19 +2,21 @@ import torch
|
|
| 2 |
import json
|
| 3 |
import os
|
| 4 |
from models.model import PlantCNN
|
|
|
|
| 5 |
|
| 6 |
def load_model_and_config():
|
| 7 |
-
"""Load the trained model and all configuration files"""
|
| 8 |
|
| 9 |
# Paths
|
| 10 |
MODEL_PATH = "saved_models/plant_cnn.pt"
|
| 11 |
CLASS_NAMES_PATH = "ui_text/class_names.json"
|
| 12 |
-
|
| 13 |
-
#
|
|
|
|
| 14 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
| 18 |
|
| 19 |
# Load class names
|
| 20 |
with open(CLASS_NAMES_PATH, "r") as f:
|
|
@@ -28,9 +30,11 @@ def load_model_and_config():
|
|
| 28 |
model = PlantCNN(num_classes=NUM_CLASSES, channels=CHANNELS, dropout=DROPOUT).to(DEVICE)
|
| 29 |
if os.path.exists(MODEL_PATH):
|
| 30 |
print("Loading trained model weights...")
|
| 31 |
-
|
|
|
|
| 32 |
model.eval()
|
| 33 |
else:
|
|
|
|
| 34 |
exit()
|
| 35 |
|
| 36 |
return {
|
|
@@ -39,13 +43,3 @@ def load_model_and_config():
|
|
| 39 |
'disease_db': disease_db,
|
| 40 |
'device': DEVICE
|
| 41 |
}
|
| 42 |
-
|
| 43 |
-
def load_ui_text():
|
| 44 |
-
"""Load intro and about markdown files"""
|
| 45 |
-
with open("ui_text/intro.md", "r", encoding="utf-8") as f:
|
| 46 |
-
intro_md = f.read()
|
| 47 |
-
|
| 48 |
-
with open("ui_text/about.md", "r", encoding="utf-8") as f:
|
| 49 |
-
about_md = f.read()
|
| 50 |
-
|
| 51 |
-
return intro_md, about_md
|
|
|
|
| 2 |
import json
|
| 3 |
import os
|
| 4 |
from models.model import PlantCNN
|
| 5 |
+
from utils.config import load_config
|
| 6 |
|
| 7 |
def load_model_and_config():
|
|
|
|
| 8 |
|
| 9 |
# Paths
|
| 10 |
MODEL_PATH = "saved_models/plant_cnn.pt"
|
| 11 |
CLASS_NAMES_PATH = "ui_text/class_names.json"
|
| 12 |
+
|
| 13 |
+
# Load config
|
| 14 |
+
config = load_config()
|
| 15 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 16 |
+
|
| 17 |
+
CHANNELS = config["channels"]
|
| 18 |
+
DROPOUT = config["dropout"]
|
| 19 |
+
NUM_CLASSES = config["num_classes"]
|
| 20 |
|
| 21 |
# Load class names
|
| 22 |
with open(CLASS_NAMES_PATH, "r") as f:
|
|
|
|
| 30 |
model = PlantCNN(num_classes=NUM_CLASSES, channels=CHANNELS, dropout=DROPOUT).to(DEVICE)
|
| 31 |
if os.path.exists(MODEL_PATH):
|
| 32 |
print("Loading trained model weights...")
|
| 33 |
+
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
|
| 34 |
+
model.load_state_dict(state_dict)
|
| 35 |
model.eval()
|
| 36 |
else:
|
| 37 |
+
print("Model file not found at", MODEL_PATH)
|
| 38 |
exit()
|
| 39 |
|
| 40 |
return {
|
|
|
|
| 43 |
'disease_db': disease_db,
|
| 44 |
'device': DEVICE
|
| 45 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|