"""Helpers that load the trained PlantCNN model and the UI text resources."""

import json
import os

import torch

from models.model import PlantCNN
from utils.config import load_config


def load_model_and_config() -> dict:
    """Build the PlantCNN model, load its trained weights and the label data.

    Reads the ``channels``, ``dropout`` and ``num_classes`` entries from the
    mapping returned by ``load_config()``, loads the two JSON files under
    ``ui_text/``, constructs the model on the selected device and restores
    the weights stored at ``saved_models/plant_cnn.pt``. All paths are
    relative to the current working directory.

    Returns:
        A dict with the keys ``"model"`` (the model, switched to eval mode),
        ``"class_names"`` and ``"disease_db"`` (the parsed JSON contents),
        and ``"device"`` (the ``torch.device`` the model was moved to).

    Raises:
        SystemExit: with exit status 1 if the model weights file is missing.
        OSError: if one of the JSON files cannot be opened.
        json.JSONDecodeError: if one of the JSON files is not valid JSON.
    """
    model_path = "saved_models/plant_cnn.pt"
    class_names_path = "ui_text/class_names.json"
    disease_info_path = "ui_text/disease_info.json"

    config = load_config()

    # Use the GPU when CUDA is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    channels = config["channels"]
    dropout = config["dropout"]
    num_classes = config["num_classes"]

    with open(class_names_path, "r", encoding="utf-8") as f:
        class_names = json.load(f)

    with open(disease_info_path, "r", encoding="utf-8") as f:
        disease_db = json.load(f)

    model = PlantCNN(
        num_classes=num_classes,
        channels=channels,
        dropout=dropout,
    ).to(device)

    if not os.path.exists(model_path):
        print(f"Model file not found at {model_path}")
        # Raise SystemExit directly instead of calling the site-provided
        # ``exit()`` helper: that helper is not guaranteed to exist (e.g.
        # ``python -S``) and, called without arguments, reports success
        # (status 0) even though loading failed.
        raise SystemExit(1)

    print("Loading trained model weights...")
    # NOTE(review): torch.load unpickles the file, so only checkpoints from a
    # trusted source should be loaded here. Consider passing
    # ``weights_only=True`` if the installed torch version supports it.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    return {
        "model": model,
        "class_names": class_names,
        "disease_db": disease_db,
        "device": device,
    }


def load_ui_text() -> tuple:
    """Read the intro and about markdown files used by the UI.

    Both files are read as UTF-8 text from the ``ui_text/`` directory,
    relative to the current working directory.

    Returns:
        A ``(intro_md, about_md)`` tuple holding the full text of
        ``ui_text/intro.md`` and ``ui_text/about.md`` respectively.

    Raises:
        OSError: if either file cannot be opened.
    """
    with open("ui_text/intro.md", "r", encoding="utf-8") as f:
        intro_md = f.read()

    with open("ui_text/about.md", "r", encoding="utf-8") as f:
        about_md = f.read()

    return intro_md, about_md