# TEAM_7_GAP / utils/model_loader.py
# Last updated by fatimaxa — commit e767536 ("Update utils/model_loader.py"), verified.
import torch
import json
import os
from models.model import PlantCNN
from utils.config import load_config
def load_model_and_config(
    *,
    model_path="saved_models/plant_cnn.pt",
    class_names_path="ui_text/class_names.json",
    disease_info_path="ui_text/disease_info.json",
):
    """Build the PlantCNN classifier, load its trained weights and the UI lookup data.

    The network hyper-parameters are read from ``load_config()`` via the keys
    ``"channels"``, ``"dropout"`` and ``"num_classes"``. Relative paths are
    resolved against the current working directory.

    Args:
        model_path: File holding the saved ``state_dict`` for ``PlantCNN``.
        class_names_path: UTF-8 JSON file with the class names (presumably
            ordered by class index — TODO confirm against the training code).
        disease_info_path: UTF-8 JSON file with the disease information
            returned as ``"disease_db"``.

    Returns:
        dict with keys:
            ``"model"``: ``PlantCNN`` moved to ``device``, weights loaded, in
            eval mode.
            ``"class_names"``: parsed JSON from ``class_names_path``.
            ``"disease_db"``: parsed JSON from ``disease_info_path``.
            ``"device"``: ``torch.device`` — ``cuda`` when available, else ``cpu``.

    Raises:
        FileNotFoundError: if either JSON file is missing, or if ``model_path``
            is not an existing regular file.
    """
    config = load_config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    channels = config["channels"]
    dropout = config["dropout"]
    num_classes = config["num_classes"]

    with open(class_names_path, "r", encoding="utf-8") as f:
        class_names = json.load(f)
    with open(disease_info_path, "r", encoding="utf-8") as f:
        disease_db = json.load(f)

    # Raise instead of calling exit(). Exiting from inside a helper would
    # terminate the whole host application, and bare exit() reports status 0,
    # so a missing checkpoint would look like success. This also matches how
    # a missing JSON file above is reported.
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"Model file not found at {model_path}")

    model = PlantCNN(
        num_classes=num_classes,
        channels=channels,
        dropout=dropout,
    ).to(device)

    print("Loading trained model weights...")
    # NOTE(security): torch.load unpickles the file, so only point model_path
    # at trusted checkpoints. PyTorch >= 2.6 defaults to weights_only=True.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()  # inference mode for dropout / normalization layers

    return {
        "model": model,
        "class_names": class_names,
        "disease_db": disease_db,
        "device": device,
    }
def load_ui_text():
    """Return the Markdown text for the UI's intro and about sections.

    Both files are read in full as UTF-8 text. Their paths are relative to
    the current working directory.

    Returns:
        tuple[str, str]: ``(intro_md, about_md)`` — the contents of
        ``ui_text/intro.md`` and ``ui_text/about.md``, in that order.
    """
    def _read_text(path):
        # One file per call; the handle is closed as soon as it has been read.
        with open(path, "r", encoding="utf-8") as handle:
            return handle.read()

    # Evaluated left to right, so intro.md is read (and can fail) before about.md.
    return _read_text("ui_text/intro.md"), _read_text("ui_text/about.md")