# File size: 965 Bytes
import json
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download
from torch.utils.data import DataLoader

from models.natural_disaster_dataset import NaturalDisasterDataset
from models.resnet50 import ResNet50
# Evaluate the fine-tuned ResNet50 disaster-image classifier on the test split
# and write per-run results to a CSV file.

# Fetch the checkpoint from the Hugging Face Hub; hf_hub_download returns the
# local filesystem path of the downloaded file.
model_path = hf_hub_download(
    repo_id="DanielCruz09/disaster-image-classifier",
    filename="models/model_weights.pth",
)
print("Model downloaded to: ", model_path)

# Build the class-name -> output-index mapping from the JSON list order.
# NOTE(review): assumes this order matches the label order used in training —
# confirm against the training script.
with open("class_names.json", "r", encoding="utf-8") as f:
    class_names = json.load(f)
mapping = {name: idx for idx, name in enumerate(class_names)}

model = ResNet50(num_classes=len(class_names), mapping=mapping)

# NOTE(review): torch.load unpickles the file, and this file comes from a
# remote repository. If the checkpoint holds only tensors and plain containers,
# consider torch.load(..., weights_only=True) to avoid arbitrary code
# execution — confirm the checkpoint contents before switching.
checkpoint = torch.load(model_path, map_location="cpu")
# The checkpoint is a dict; the network weights live under "model_state_dict"
# and are loaded into the wrapper's underlying `.model` module.
model.model.load_state_dict(checkpoint["model_state_dict"])

test_path = "data/processed/Test/"
test_dataset = NaturalDisasterDataset(root=test_path)
# Do not shuffle during evaluation. Shuffling cannot change aggregate metrics,
# but it makes per-sample output order vary between runs (it draws from the
# global RNG), so results would not be reproducible.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

results_path = "results/resnet50_results.csv"
# Ensure the output directory exists before the results file is written.
Path(results_path).parent.mkdir(parents=True, exist_ok=True)
# `eval` here is the ResNet50 wrapper's own evaluation method (it takes a
# loader and an output path), not torch.nn.Module.eval().
model.eval(test_loader=test_loader, write_path=results_path)