Image Classification
torch
File size: 965 Bytes
c35dbf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from huggingface_hub import hf_hub_download
import torch
from models.resnet50 import ResNet50
from models.natural_disaster_dataset import NaturalDisasterDataset
from torch.utils.data import DataLoader
import json

# Evaluation script: download the trained weights, rebuild the classifier,
# and run it over the held-out test split.

# Fetch the checkpoint file from the Hugging Face Hub; the call returns the
# local filesystem path of the downloaded file.
model_path = hf_hub_download(
    repo_id="DanielCruz09/disaster-image-classifier",
    filename="models/model_weights.pth",
)

print("Model downloaded to: ", model_path)

# Explicit encoding so the label file is decoded the same way on every
# platform instead of depending on the locale default.
# Presumably a JSON array of label strings -- verify against the file.
with open("class_names.json", "r", encoding="utf-8") as f:
    class_names = json.load(f)

# Position in class_names defines the integer class index for each label.
mapping = {name: idx for idx, name in enumerate(class_names)}

model = ResNet50(num_classes=len(class_names), mapping=mapping)

# NOTE(security): torch.load unpickles the file, and this file comes from a
# remote repository. Only load checkpoints from a source you trust; consider
# passing weights_only=True once it is confirmed that the checkpoint holds
# nothing but tensors and plain containers.
# Despite its name, state_dict is the whole checkpoint dictionary; the actual
# weights sit under the "model_state_dict" key.
state_dict = torch.load(model_path, map_location="cpu")
model.model.load_state_dict(state_dict["model_state_dict"])

test_path = "data/processed/Test/"
test_dataset = NaturalDisasterDataset(root=test_path)
# Shuffling gives no benefit at evaluation time and makes the batch order
# differ between runs, so iterate the test set in a fixed order.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# This is the project wrapper's eval(test_loader, write_path), not the
# argument-less torch.nn.Module.eval(). Presumably it runs inference over the
# loader and writes the results to write_path -- confirm in models/resnet50.
model.eval(test_loader=test_loader, write_path="results/resnet50_results.csv")