Spaces:
Sleeping
Sleeping
File size: 1,498 Bytes
867bae1 64515d2 867bae1 64515d2 867bae1 64515d2 7a068ee 867bae1 7a068ee 126c878 64515d2 867bae1 64515d2 867bae1 64515d2 867bae1 7a068ee 64515d2 867bae1 842df8b 867bae1 64515d2 867bae1 64515d2 867bae1 64515d2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 | import torch
import json
import os
from models.model import PlantCNN
from utils.config import load_config
def load_model_and_config(
    model_path="saved_models/plant_cnn.pt",
    class_names_path="ui_text/class_names.json",
    disease_info_path="ui_text/disease_info.json",
):
    """Build the PlantCNN classifier, load its trained weights, and read UI lookup data.

    All paths are interpreted relative to the current working directory; the
    defaults are the paths the app has always used.

    Args:
        model_path: Path of the saved ``state_dict`` loaded into ``PlantCNN``.
        class_names_path: JSON file whose parsed content is returned under
            ``"class_names"``.
        disease_info_path: JSON file whose parsed content is returned under
            ``"disease_db"``.

    Returns:
        dict with keys:
            ``"model"``: the ``PlantCNN`` instance moved to ``device``, with the
                trained weights loaded and switched to eval mode.
            ``"class_names"``: parsed content of ``class_names_path``.
            ``"disease_db"``: parsed content of ``disease_info_path``.
            ``"device"``: the ``torch.device`` used (CUDA if available, else CPU).

    Raises:
        SystemExit: with exit status 1 when ``model_path`` is not an existing file.
        OSError / json.JSONDecodeError: when a JSON file is missing or malformed.
        KeyError: when the config lacks ``channels``, ``dropout`` or ``num_classes``.
    """
    # Check for the weights first: without them the app cannot predict, so
    # there is no point loading config, JSON data, or building the network.
    if not os.path.isfile(model_path):
        print(f"Model file not found at {model_path}")
        # The previous bare ``exit()`` relied on the ``site``-injected builtin
        # (meant for the interactive shell) and exited with status 0, hiding
        # the failure from process supervisors. Exit explicitly with status 1.
        raise SystemExit(1)

    config = load_config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with open(class_names_path, "r", encoding="utf-8") as f:
        class_names = json.load(f)
    with open(disease_info_path, "r", encoding="utf-8") as f:
        disease_db = json.load(f)

    model = PlantCNN(
        num_classes=config["num_classes"],
        channels=config["channels"],
        dropout=config["dropout"],
    ).to(device)

    print("Loading trained model weights...")
    # NOTE(review): torch.load unpickles the checkpoint; on PyTorch versions
    # where weights_only defaults to False, a tampered file can execute code.
    # Since this file is a plain state_dict, consider passing weights_only=True.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout etc. for inference

    return {
        "model": model,
        "class_names": class_names,
        "disease_db": disease_db,
        "device": device,
    }
def load_ui_text():
    """Read the markdown pages shown in the UI.

    Both files are read as UTF-8 text from paths relative to the current
    working directory.

    Returns:
        tuple[str, str]: ``(intro_md, about_md)``, the full contents of
        ``ui_text/intro.md`` and ``ui_text/about.md`` respectively.
    """
    page_texts = []
    for page_path in ("ui_text/intro.md", "ui_text/about.md"):
        with open(page_path, "r", encoding="utf-8") as page_file:
            page_texts.append(page_file.read())
    return tuple(page_texts)
|