File size: 3,561 Bytes
6db3733
 
 
ea2d324
 
6db3733
 
 
 
 
 
 
13ca5c2
6db3733
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13ca5c2
6db3733
13ca5c2
6db3733
 
 
 
 
 
07ae0b6
ea2d324
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6db3733
07ae0b6
c019acc
07ae0b6
 
6db3733
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# AUTOGENERATED! DO NOT EDIT! File to edit: car_or_not_nb.ipynb.

# %% auto 0
# Public API for `from <module> import *`. Every entry must be a name bound at
# module level in this file; otherwise the star-import raises AttributeError.
# NOTE: this file is exported from car_or_not_nb.ipynb — mirror edits in the
# notebook or a re-export will overwrite them.
__all__ = ['imagenet_labels', 'model', 'transform', 'catogories', 'title', 'description', 'examples', 'intf',
           'get_imagenet_classes', 'create_model', 'pil_loader', 'car_or_not_inference']

# %% car_or_not_nb.ipynb 1
# imports
import os
import json
import pickle as pk
from collections import Counter, defaultdict
from functools import lru_cache

import timm
import torch
import gradio as gr
from PIL import Image
# from fastai.vision.all import *

# %% car_or_not_nb.ipynb 2
# Imagenet Class
def get_imagenet_classes(path="imagenet_class_index.txt"):
    """Read ImageNet class names from a text file, one class per line.

    A line may hold several comma-separated synonyms; only the first one is
    kept. Line ``i`` of the file therefore becomes element ``i`` of the
    result.

    Args:
        path: Text file to read. Defaults to ``imagenet_class_index.txt``,
            resolved against the current working directory.

    Returns:
        list[str]: The first label on every line, in file order.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    # Context manager guarantees the handle is closed even if reading fails.
    with open(path, "r", encoding="utf-8") as f:
        imagenet_file = f.read()
    # strip() removes the trailing newline so no empty final entry is produced.
    imagenet_labels_raw = imagenet_file.strip().split('\n')
    # Keep only the first synonym of each line.
    return [item.split(',')[0] for item in imagenet_labels_raw]

# Class names loaded once at import time; the inference code indexes this list
# with the model's top-k output indices (assumes the text file is in the
# model's class-index order — TODO confirm).
imagenet_labels = get_imagenet_classes()

# %% car_or_not_nb.ipynb 3
# Create Model
def create_model(model_name='vgg16.tv_in1k'):
    """Build a pretrained timm classifier together with its preprocessing.

    Args:
        model_name: timm model identifier (default ``'vgg16.tv_in1k'``).

    Returns:
        tuple: ``(model, transform)``. The model is switched to eval mode.
        The transform is built by timm from the model's ``pretrained_cfg``.
    """
    net = timm.create_model(model_name, pretrained=True)
    net.eval()
    # Derive the input size / normalisation the checkpoint expects.
    data_cfg = timm.data.resolve_data_config(net.pretrained_cfg)
    preprocess = timm.data.create_transform(**data_cfg)
    return net, preprocess


# Built once at import time and shared by the inference code below.
model, transform = create_model()

# %% car_or_not_nb.ipynb 5
# open image as rgb 3 channel
def pil_loader(path):
    """Load the image file at ``path`` as a 3-channel RGB PIL image.

    The RGB conversion is performed while the file is still open, so the
    returned image is a standalone object.
    """
    with open(path, 'rb') as handle:
        rgb_image = Image.open(handle).convert('RGB')
    return rgb_image

# %% car_or_not_nb.ipynb 7
# Main Inference Code
catogories = ('Is a Car', 'Not a Car')


@lru_cache(maxsize=None)
def _car_label_set(path='car_predict_map.pk', top_n=36):
    """Return the ``top_n`` most frequent labels in the car prediction map.

    The pickle is read once per ``(path, top_n)`` and then served from cache,
    instead of being re-read on every request. A missing file raises
    FileNotFoundError; exceptions are not cached, so a later call retries.
    """
    # NOTE(security): pickle can execute arbitrary code on load. Only ever
    # point this at the trusted file shipped with the app, never user input.
    with open(path, 'rb') as f:
        car_predict_map = pk.load(f)
    # Assumes the pickle holds a collections.Counter (needs .most_common).
    return frozenset(label for label, _ in car_predict_map.most_common(top_n))


def car_or_not_inference(input_image):
    """Classify an image as 'Is a Car' or 'Not a Car'.

    The image is run through the module-level ``model``/``transform``. It is
    accepted as a car if any of the model's top-5 ImageNet labels is among
    the 36 most common labels stored in ``car_predict_map.pk``.

    Args:
        input_image: Either a file path (str) or an array accepted by
            ``PIL.Image.fromarray`` — presumably the numpy array that
            ``gr.Image()`` passes; verify if the input component changes.

    Returns:
        dict: Maps each entry of ``catogories`` to a score. The result is
        ``{'Is a Car': 1.0, 'Not a Car': 0.0}`` on a match and the reverse
        otherwise.

    Raises:
        FileNotFoundError: If ``car_predict_map.pk`` is missing.
    """
    print("Validating that this is a picture of a car...")

    car_labels = _car_label_set()

    if isinstance(input_image, str):
        image = pil_loader(input_image)
    else:
        # fromarray keeps the array's channel layout (e.g. greyscale or
        # RGBA); the model transform expects 3-channel RGB.
        image = Image.fromarray(input_image).convert('RGB')

    # transform image as required for prediction
    image_tensor = transform(image)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        output = model(image_tensor.unsqueeze(0))
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    _, indices = torch.topk(probabilities, 5)

    if any(imagenet_labels[idx] in car_labels for idx in indices.tolist()):
        # "Validation complete - proceed to damage evaluation"
        return dict(zip(catogories, [1.0, 0.0]))
    # "Are you sure this is a picture of your car? Please take another picture
    # (try a different angle or lighting) and try again."
    return dict(zip(catogories, [0.0, 1.0]))

# %% car_or_not_nb.ipynb 8
# Text shown on the Gradio page, plus sample images offered as one-click
# examples (paths are relative to the working directory).
title = "Car Identifier"
description = "A car or not classifier trained on images scraped from the web."
examples = [
    'rolls.jpg',
    'forest.jpg',
    'dog.jpg',
]

# %% car_or_not_nb.ipynb 9
# Gradio UI: one image in, a two-class label (car / not car) out.
intf = gr.Interface(
    fn=car_or_not_inference,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=2),
    title=title,
    description=description,
    examples=examples,
)
# Starts the app as soon as this module is executed/imported.
intf.launch(share=True)