File size: 1,712 Bytes
4ea8080
 
5dfae46
 
 
c7d1130
4ea8080
9053631
4ea8080
a5b9f6a
 
 
 
4ea8080
 
5dfae46
 
4ea8080
 
2e9bd2b
 
a7394d1
5dfae46
9053631
a5b9f6a
5dfae46
4ea8080
c7d1130
 
 
 
 
 
 
 
4ea8080
 
 
 
 
 
 
 
 
09b45e3
4ea8080
 
 
 
 
 
c7d1130
4ea8080
5dfae46
4ea8080
 
5dfae46
4ea8080
 
 
 
 
 
 
 
 
 
5dfae46
e247352
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import logging
import os

import gradio as gr
import huggingface_hub
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models

# Run on the first CUDA GPU when PyTorch can see one; otherwise use the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# ResNet-18 backbone created without pretrained weights; the trained weights
# for this classifier are loaded into it separately. The final fully-connected
# layer is replaced by a 6-output head, which must stay in sync with the six
# entries of `labels`.
net = models.resnet18(pretrained=False)
net.fc = nn.Linear(in_features=net.fc.in_features, out_features=6)
net = net.to(device)  # Module.to moves parameters in place and returns self

# Access token for the Hugging Face Hub repository that holds the trained
# weights; it must be supplied through the HF_Token environment variable
# (a KeyError is raised at startup if it is missing).
HF_Token = os.environ['HF_Token']

# Download the weights file, reusing the local Hub cache when it is already
# present. huggingface_hub.cached_download/hf_hub_url were deprecated and have
# been removed from current huggingface_hub releases; hf_hub_download is the
# supported replacement and returns the local file path.
model = huggingface_hub.hf_hub_download(
    repo_id='danyalmalik/sceneryclassifier',
    filename='1655988285.9725637_Acc0.88_modelweights.pth',
    token=HF_Token,
)

# `model` is a path to a saved state dict. map_location remaps the stored
# tensors onto the selected device (e.g. GPU-saved weights on a CPU host).
net.load_state_dict(torch.load(model, map_location=device))
# Switch to inference behaviour (e.g. batch-norm uses running statistics).
net.eval()

# Per-channel (R, G, B) normalisation statistics applied after ToTensor().
# NOTE(review): presumably these match the preprocessing used in training —
# confirm before changing.
mean = np.array([0.5] * 3)
std = np.array([0.25] * 3)

# Input pipeline for predict(): resize to 150x150, convert the image to a
# float tensor, then normalise each channel with `mean` / `std`.
data_transforms = transforms.Compose(
    [
        transforms.Resize((150, 150)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

# Class names; position i corresponds to output unit i of the network.
labels = [
    'Buildings',
    'Forest',
    'Glacier',
    'Mountain',
    'Sea',
    'Street',
]

# Heading shown at the top of the Gradio page.
title = "Scenery Classifier"


def examples(directory='examples', limit=None):
    """Return the paths of the example images offered in the UI.

    Args:
        directory: Folder to scan (default ``'examples'``, resolved relative
            to the current working directory).
        limit: Optional maximum number of paths to return; ``None`` (the
            default) returns one path per directory entry.

    Returns:
        list of paths ``os.path.join(directory, name)``, sorted by entry name
        so the example order is stable across runs.

    Raises:
        FileNotFoundError: if ``directory`` does not exist.
    """
    # os.listdir order is arbitrary; sort for a deterministic gallery.
    names = sorted(os.listdir(directory))
    if limit is not None:
        names = names[:limit]
    return [os.path.join(directory, name) for name in names]


def predict(img):
    """Classify a scenery image and return per-class confidences.

    Args:
        img: The image from the ``gr.Image(type='pil')`` input (a PIL image);
            may be ``None`` when no image is supplied.

    Returns:
        dict mapping each name in ``labels`` to its softmax probability.
        If ``img`` is ``None`` or inference fails, every label maps to 0.0
        so the UI still renders; failures are logged with a traceback.
    """
    fallback = {label: 0.0 for label in labels}
    if img is None:
        return fallback

    try:
        # Normalize() is configured for 3 channels, so convert RGBA /
        # greyscale PIL images; inputs without a `mode` are passed through.
        if getattr(img, 'mode', 'RGB') != 'RGB':
            img = img.convert('RGB')
        batch = data_transforms(img).unsqueeze(0).to(device)  # add batch dim

        with torch.no_grad():
            probs = F.softmax(net(batch), dim=1)[0].tolist()
    except Exception:
        # Best-effort: keep the UI responsive, but don't hide the failure.
        logging.exception('Scenery prediction failed')
        return fallback

    return dict(zip(labels, probs))


# Web UI: a single PIL image input and a label output that displays the
# confidence dict returned by predict(), with the bundled example images.
demo = gr.Interface(fn=predict, inputs=gr.Image(type='pil'),
                    outputs='label', title=title, examples=examples())

# Start the server only when run as a script, so importing this module
# (tests, tooling) does not launch the app.
if __name__ == "__main__":
    demo.launch()