Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -34,6 +34,19 @@ def get_model(model_name, classes, device):
|
|
| 34 |
model.load_state_dict(torch.load('BaseLine-Model.pt', map_location=torch.device(device)))
|
| 35 |
|
| 36 |
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
def make_predictions(input_img, model_name):
|
| 39 |
classes = ['buildings','forest', 'glacier', 'mountain', 'sea', 'street']
|
|
|
|
| 34 |
model.load_state_dict(torch.load('BaseLine-Model.pt', map_location=torch.device(device)))
|
| 35 |
|
| 36 |
return model
|
| 37 |
+
|
| 38 |
+
def get_transform(input_img, device):
|
| 39 |
+
normalize = transforms.Normalize(
|
| 40 |
+
[0.485, 0.456, 0.406],
|
| 41 |
+
[0.229, 0.224, 0.225]
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
test_transform = transforms.Compose([
|
| 45 |
+
transforms.ToTensor(),
|
| 46 |
+
normalize,
|
| 47 |
+
])
|
| 48 |
+
input_img = test_transform(input_img).unsqueeze(0).to(device)
|
| 49 |
+
return input_img
|
| 50 |
|
| 51 |
def make_predictions(input_img, model_name):
|
| 52 |
classes = ['buildings','forest', 'glacier', 'mountain', 'sea', 'street']
|