Update app.py
Browse files
app.py
CHANGED
|
@@ -9,6 +9,8 @@ import torchvision.transforms as transforms
|
|
| 9 |
# model = timm.create_model("hf_hub:nateraw/resnet18-random", pretrained=True)
|
| 10 |
# model.train()
|
| 11 |
|
|
|
|
|
|
|
| 12 |
import os
|
| 13 |
|
| 14 |
def print_bn():
|
|
@@ -45,7 +47,7 @@ def greet_backdoor(image):
|
|
| 45 |
image = transform_nor(image).unsqueeze(0)
|
| 46 |
print(image.shape)
|
| 47 |
output = model(image).squeeze()
|
| 48 |
-
return 'classified as
|
| 49 |
|
| 50 |
|
| 51 |
def greet(image):
|
|
|
|
| 9 |
# model = timm.create_model("hf_hub:nateraw/resnet18-random", pretrained=True)
|
| 10 |
# model.train()
|
| 11 |
|
| 12 |
+
id_label = {0:'airplane', 1:'automobile', 2:'bird', 3:'cat', 4:'deer', 5:'dog', 6:'frog', 7:'horse', 8:'ship', 9:'trunk'}
|
| 13 |
+
|
| 14 |
import os
|
| 15 |
|
| 16 |
def print_bn():
|
|
|
|
| 47 |
image = transform_nor(image).unsqueeze(0)
|
| 48 |
print(image.shape)
|
| 49 |
output = model(image).squeeze()
|
| 50 |
+
return 'classified as: ' + id_label[int(torch.argmax(output))]
|
| 51 |
|
| 52 |
|
| 53 |
def greet(image):
|