Spaces:
Sleeping
Sleeping
Update utils/model_loader.py
Browse files- utils/model_loader.py +38 -23
utils/model_loader.py
CHANGED
|
@@ -2,44 +2,59 @@ import torch
|
|
| 2 |
import json
|
| 3 |
import os
|
| 4 |
from models.model import PlantCNN
|
| 5 |
-
from utils.config import load_config
|
|
|
|
| 6 |
|
| 7 |
def load_model_and_config():
|
| 8 |
-
|
| 9 |
-
|
| 10 |
MODEL_PATH = "saved_models/plant_cnn.pt"
|
| 11 |
CLASS_NAMES_PATH = "ui_text/class_names.json"
|
|
|
|
| 12 |
|
| 13 |
-
# Load config
|
| 14 |
config = load_config()
|
| 15 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 16 |
|
| 17 |
-
CHANNELS = config["channels"]
|
| 18 |
-
DROPOUT = config["dropout"]
|
| 19 |
-
NUM_CLASSES = config["num_classes"]
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
with open(CLASS_NAMES_PATH, "r") as f:
|
| 23 |
class_names = json.load(f)
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
with open("ui_text/disease_info.json", "r", encoding="utf-8") as f:
|
| 27 |
disease_db = json.load(f)
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
if os.path.exists(MODEL_PATH):
|
| 32 |
print("Loading trained model weights...")
|
| 33 |
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
|
| 34 |
-
model.load_state_dict(state_dict)
|
| 35 |
model.eval()
|
| 36 |
else:
|
| 37 |
-
print("Model file not found at
|
| 38 |
exit()
|
| 39 |
-
|
| 40 |
return {
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import json
|
| 3 |
import os
|
| 4 |
from models.model import PlantCNN
|
| 5 |
+
from utils.config import load_config
|
| 6 |
+
|
| 7 |
|
| 8 |
def load_model_and_config():
|
| 9 |
+
"""Load the trained model and associated config + metadata"""
|
| 10 |
+
|
| 11 |
MODEL_PATH = "saved_models/plant_cnn.pt"
|
| 12 |
CLASS_NAMES_PATH = "ui_text/class_names.json"
|
| 13 |
+
DISEASE_INFO_PATH = "ui_text/disease_info.json"
|
| 14 |
|
|
|
|
| 15 |
config = load_config()
|
| 16 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 17 |
|
| 18 |
+
CHANNELS = config["channels"] # e.g., [64, 128, 256, 512]
|
| 19 |
+
DROPOUT = config["dropout"] # 0.4
|
| 20 |
+
NUM_CLASSES = config["num_classes"] # 39
|
| 21 |
+
|
| 22 |
+
with open(CLASS_NAMES_PATH, "r", encoding="utf-8") as f:
|
|
|
|
| 23 |
class_names = json.load(f)
|
| 24 |
+
|
| 25 |
+
with open(DISEASE_INFO_PATH, "r", encoding="utf-8") as f:
|
|
|
|
| 26 |
disease_db = json.load(f)
|
| 27 |
+
|
| 28 |
+
model = PlantCNN(
|
| 29 |
+
num_classes=NUM_CLASSES,
|
| 30 |
+
channels=CHANNELS,
|
| 31 |
+
dropout=DROPOUT
|
| 32 |
+
).to(DEVICE)
|
| 33 |
+
|
| 34 |
if os.path.exists(MODEL_PATH):
|
| 35 |
print("Loading trained model weights...")
|
| 36 |
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
|
| 37 |
+
model.load_state_dict(state_dict)
|
| 38 |
model.eval()
|
| 39 |
else:
|
| 40 |
+
print(f"❌ Model file not found at {MODEL_PATH}")
|
| 41 |
exit()
|
| 42 |
+
|
| 43 |
return {
|
| 44 |
+
"model": model,
|
| 45 |
+
"class_names": class_names,
|
| 46 |
+
"disease_db": disease_db,
|
| 47 |
+
"device": DEVICE,
|
| 48 |
}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def load_ui_text():
|
| 52 |
+
"""Load intro.md and about.md for display in the UI"""
|
| 53 |
+
|
| 54 |
+
with open("ui_text/intro.md", "r", encoding="utf-8") as f:
|
| 55 |
+
intro_md = f.read()
|
| 56 |
+
|
| 57 |
+
with open("ui_text/about.md", "r", encoding="utf-8") as f:
|
| 58 |
+
about_md = f.read()
|
| 59 |
+
|
| 60 |
+
return intro_md, about_md
|