Image Classification
torch
DanielCruz09 commited on
Commit
c35dbf2
·
verified ·
1 Parent(s): cf8c846

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +28 -0
main.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import hf_hub_download
2
+ import torch
3
+ from models.resnet50 import ResNet50
4
+ from models.natural_disaster_dataset import NaturalDisasterDataset
5
+ from torch.utils.data import DataLoader
6
+ import json
7
+
8
+ model_path = hf_hub_download(
9
+ repo_id="DanielCruz09/disaster-image-classifier",
10
+ filename="models/model_weights.pth"
11
+ )
12
+
13
+ print("Model downloaded to: ", model_path)
14
+
15
+ with open("class_names.json", "r") as f:
16
+ class_names = json.load(f)
17
+
18
+ mapping = {name: idx for idx, name in enumerate(class_names)}
19
+
20
+ model = ResNet50(num_classes=len(class_names), mapping=mapping)
21
+ state_dict = torch.load(model_path, map_location="cpu")
22
+ model.model.load_state_dict(state_dict["model_state_dict"])
23
+
24
+ test_path = "data/processed/Test/"
25
+ test_dataset = NaturalDisasterDataset(root=test_path)
26
+ test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)
27
+
28
+ model.eval(test_loader=test_loader, write_path="results/resnet50_results.csv")