# Hugging Face Spaces status: Sleeping
| import torch | |
| import json | |
| import os | |
| from models.model import PlantCNN | |
| from utils.config import load_config | |
def load_model_and_config(
    model_path="saved_models/plant_cnn.pt",
    class_names_path="ui_text/class_names.json",
    disease_info_path="ui_text/disease_info.json",
):
    """Build the PlantCNN classifier and load the lookup tables used by the UI.

    Model hyper-parameters (``channels``, ``dropout``, ``num_classes``) are
    read from ``load_config()``. The trained ``state_dict`` at *model_path*
    is loaded onto CUDA when available (CPU otherwise), and the model is
    switched to eval mode.

    Args:
        model_path: Path to the saved ``state_dict`` file.
        class_names_path: Path to a UTF-8 JSON file with the class names.
        disease_info_path: Path to a UTF-8 JSON file with disease information.

    Returns:
        A dict with keys:
            ``"model"``: the ``PlantCNN`` instance, in eval mode, on ``device``.
            ``"class_names"``: parsed JSON from *class_names_path*.
            ``"disease_db"``: parsed JSON from *disease_info_path*.
            ``"device"``: the ``torch.device`` the model lives on.

    Raises:
        SystemExit: with a non-zero exit status if *model_path* is not a file.
        OSError: if either JSON file cannot be opened.
        json.JSONDecodeError: if either JSON file is malformed.
    """
    config = load_config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with open(class_names_path, "r", encoding="utf-8") as f:
        class_names = json.load(f)
    with open(disease_info_path, "r", encoding="utf-8") as f:
        disease_db = json.load(f)

    # Fail fast before allocating the model (possibly on the GPU). The bare
    # `exit()` builtin is only injected by `site` and exits with status 0,
    # which would report success on a fatal error. `SystemExit(msg)` keeps
    # the same exception type but prints msg to stderr and exits with 1.
    if not os.path.isfile(model_path):
        raise SystemExit(f"Model file not found at {model_path}")

    model = PlantCNN(
        num_classes=config["num_classes"],
        channels=config["channels"],
        dropout=config["dropout"],
    ).to(device)

    print("Loading trained model weights...")
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # you trust. On torch >= 1.13, consider passing weights_only=True here.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    return {
        "model": model,
        "class_names": class_names,
        "disease_db": disease_db,
        "device": device,
    }
def load_ui_text():
    """Return the Markdown text for the app's intro and about sections.

    Both files are read as UTF-8 from paths relative to the current working
    directory.

    Returns:
        A 2-tuple ``(intro_md, about_md)`` with the full contents of
        ``ui_text/intro.md`` and ``ui_text/about.md``, in that order.

    Raises:
        OSError: if either file cannot be opened.
    """
    def _slurp(path):
        # Read one whole text file and close it immediately.
        with open(path, "r", encoding="utf-8") as handle:
            return handle.read()

    intro_text = _slurp("ui_text/intro.md")
    about_text = _slurp("ui_text/about.md")
    return intro_text, about_text