File size: 1,498 Bytes
867bae1
 
 
 
64515d2
 
867bae1
 
64515d2
867bae1
 
64515d2
7a068ee
 
867bae1
7a068ee
126c878
 
 
64515d2
 
867bae1
64515d2
 
867bae1
64515d2
 
 
 
 
 
 
867bae1
 
7a068ee
64515d2
867bae1
 
842df8b
867bae1
64515d2
867bae1
64515d2
 
 
 
867bae1
64515d2
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import json
import os
import sys

import torch

from models.model import PlantCNN
from utils.config import load_config


def load_model_and_config(
    model_path="saved_models/plant_cnn.pt",
    class_names_path="ui_text/class_names.json",
    disease_info_path="ui_text/disease_info.json",
):
    """Build PlantCNN, load its trained weights, and load the UI lookup data.

    Model hyperparameters (``channels``, ``dropout``, ``num_classes``) are read
    from ``load_config()``. The model is placed on CUDA when available,
    otherwise on CPU.

    Args:
        model_path: Path of the saved ``state_dict`` for ``PlantCNN``.
        class_names_path: UTF-8 JSON file holding the class names.
        disease_info_path: UTF-8 JSON file holding the disease information.

    Returns:
        dict with keys ``"model"`` (weights loaded, in eval mode, on the
        device), ``"class_names"``, ``"disease_db"`` (the parsed JSON
        contents) and ``"device"`` (the ``torch.device`` used).

    Raises:
        SystemExit: with status 1 if ``model_path`` does not exist.
        KeyError: if the config lacks one of the required keys.
        OSError / json.JSONDecodeError: if a JSON file is missing or invalid.
    """
    config = load_config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    channels = config["channels"]
    dropout = config["dropout"]
    num_classes = config["num_classes"]

    with open(class_names_path, "r", encoding="utf-8") as f:
        class_names = json.load(f)

    with open(disease_info_path, "r", encoding="utf-8") as f:
        disease_db = json.load(f)

    model = PlantCNN(
        num_classes=num_classes,
        channels=channels,
        dropout=dropout,
    ).to(device)

    if not os.path.exists(model_path):
        # Fail with a non-zero exit status. A bare exit() reports success (0)
        # and depends on the `site` module's helper, which is not always present.
        print(f"Model file not found at {model_path}", file=sys.stderr)
        sys.exit(1)

    print("Loading trained model weights...")
    # NOTE(review): torch.load unpickles the file. Only a state_dict is
    # expected, so weights_only=True (torch>=1.13; default since 2.6) would be
    # a safer explicit choice if the supported torch version allows it.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    # Inference mode: disables dropout and uses running batch-norm statistics.
    model.eval()

    return {
        "model": model,
        "class_names": class_names,
        "disease_db": disease_db,
        "device": device,
    }


def load_ui_text():
    """Return the intro and about Markdown texts displayed in the UI.

    Both files are read as UTF-8 from ``ui_text/``, relative to the current
    working directory.

    Returns:
        tuple[str, str]: ``(intro_md, about_md)``.
    """

    def _read_markdown(path):
        # Read a whole UTF-8 text file into a single string.
        with open(path, "r", encoding="utf-8") as handle:
            return handle.read()

    intro_text = _read_markdown("ui_text/intro.md")
    about_text = _read_markdown("ui_text/about.md")
    return intro_text, about_text