Spaces:
Runtime error
Runtime error
| import torch | |
| import json | |
| import os | |
| from models.model import PlantCNN | |
def load_model_and_config():
    """Load the trained model and all configuration files.

    Reads the class-name list and the disease-info database from the
    ``ui_text`` directory, builds a ``PlantCNN`` with the hard-coded
    architecture settings below, restores the weights saved at
    ``saved_models/plant_cnn.pt`` and puts the model in eval mode.
    All paths are relative to the current working directory.

    Returns:
        dict: ``{'model': model, 'class_names': ..., 'disease_db': ...,
        'device': torch.device}``; the device is CUDA when available,
        otherwise CPU.

    Raises:
        FileNotFoundError: If the weights file or either JSON file is missing.
    """
    # Paths
    MODEL_PATH = "saved_models/plant_cnn.pt"
    CLASS_NAMES_PATH = "ui_text/class_names.json"
    DISEASE_INFO_PATH = "ui_text/disease_info.json"

    # Config -- these must match the architecture the checkpoint was trained with.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    CHANNELS = [96, 192, 384, 768]
    DROPOUT = 0.4
    # NOTE(review): presumably equals len(class_names) -- confirm.
    NUM_CLASSES = 39

    # Fail fast with an explicit error. A bare exit() is only provided by the
    # `site` module and would end the process with status 0 and no message,
    # hiding why the app did not start.
    if not os.path.isfile(MODEL_PATH):
        raise FileNotFoundError(
            f"Trained model weights not found at '{MODEL_PATH}' "
            f"(working directory: {os.getcwd()})"
        )

    # Load class names (explicit UTF-8 so decoding does not depend on the locale)
    with open(CLASS_NAMES_PATH, "r", encoding="utf-8") as f:
        class_names = json.load(f)

    # Load disease info
    with open(DISEASE_INFO_PATH, "r", encoding="utf-8") as f:
        disease_db = json.load(f)

    # Build the network and restore the trained weights.
    model = PlantCNN(num_classes=NUM_CLASSES, channels=CHANNELS, dropout=DROPOUT).to(DEVICE)
    print("Loading trained model weights...")
    # map_location remaps stored tensors onto DEVICE, so a GPU-saved
    # checkpoint can be loaded on a CPU-only host.
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (consider weights_only=True on torch >= 1.13).
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
    model.load_state_dict(state_dict)
    # Switch layers such as dropout to inference behaviour.
    model.eval()

    return {
        'model': model,
        'class_names': class_names,
        'disease_db': disease_db,
        'device': DEVICE
    }
def load_ui_text():
    """Read the intro and about Markdown pages shown in the UI.

    Returns:
        tuple[str, str]: ``(intro_md, about_md)``, the full UTF-8 decoded text
        of ``ui_text/intro.md`` and ``ui_text/about.md`` respectively.
    """
    pages = []
    # Read in a fixed order: intro first, then about.
    for md_path in ("ui_text/intro.md", "ui_text/about.md"):
        with open(md_path, "r", encoding="utf-8") as handle:
            pages.append(handle.read())
    intro_md, about_md = pages
    return intro_md, about_md