# Car_or_Not / app.py
# suku9's picture
# Upload app.py
# ea2d324 verified
# AUTOGENERATED! DO NOT EDIT! File to edit: car_or_not_nb.ipynb.
# %% auto 0
# Names exported by `from <module> import *`; this list is part of the
# autogenerated header, so regenerate it from the notebook rather than editing.
# NOTE: 'catogories' is spelled exactly like the module-level variable it names.
__all__ = ['imagenet_labels', 'model', 'transform', 'catogories', 'input_image', 'title', 'description', 'examples', 'intf',
'get_imagenet_classes', 'create_model', 'pil_loader', 'car_or_not_inference']
# %% car_or_not_nb.ipynb 1
# imports
import os
import timm
import json
import torch
import gradio as gr
import pickle as pk
from PIL import Image
from collections import Counter, defaultdict
# from fastai.vision.all import *
# %% car_or_not_nb.ipynb 2
# Imagenet Class
def get_imagenet_classes():
    """Load the ImageNet class labels from ``imagenet_class_index.txt``.

    The file is looked up relative to the current working directory. Each
    line is split on commas and only the first field is kept, so the result
    holds one label per line, in file order (blank lines inside the file are
    preserved as empty strings so that positions stay aligned with lines).

    Returns:
        list[str]: The first comma-separated label of every line, or an
        empty list when the file is empty or contains only whitespace.

    Raises:
        OSError: If the file cannot be opened.
    """
    # A context manager closes the handle deterministically instead of
    # leaving it to the garbage collector.
    with open("imagenet_class_index.txt", "r", encoding="utf-8") as label_file:
        raw_text = label_file.read()
    # Separate the lines, then keep only the first label of each one.
    # splitlines() returns [] for empty text, where split('\n') gave [''].
    return [line.split(',')[0] for line in raw_text.strip().splitlines()]
# Read the label file once at import time; the inference function indexes
# this list with the class indices predicted by the model.
imagenet_labels = get_imagenet_classes()
# %% car_or_not_nb.ipynb 3
# Create Model
def create_model(model_name='vgg16.tv_in1k'):
    """Build a pretrained timm model together with its input transform.

    Args:
        model_name: Name passed to ``timm.create_model``. Defaults to
            ``'vgg16.tv_in1k'``; the notebook also mentions
            ``'mobilenetv3_large_100'`` as another candidate name.

    Returns:
        tuple: ``(model, transform)`` where ``model`` has been switched to
        evaluation mode and ``transform`` is created from the data config
        resolved from the model's ``pretrained_cfg``.
    """
    # Instantiate the network with pretrained weights and put it in eval mode.
    net = timm.create_model(model_name, pretrained=True)
    net = net.eval()
    # Derive the preprocessing settings from the model's own configuration.
    data_config = timm.data.resolve_data_config(net.pretrained_cfg)
    preprocess = timm.data.create_transform(**data_config)
    return net, preprocess
# Build the default model and its preprocessing transform once at import
# time so every inference call reuses them.
model, transform = create_model()
# %% car_or_not_nb.ipynb 5
# open image as rgb 3 channel
def pil_loader(path):
    """Open the image file at ``path`` and return it as an RGB PIL image.

    Args:
        path: Filesystem path of the image to read.

    Returns:
        The image converted to mode ``'RGB'``.
    """
    # The file is opened explicitly in binary mode; convert() is called while
    # the handle is still open, and the converted image is returned afterwards.
    with open(path, 'rb') as stream:
        rgb_image = Image.open(stream).convert('RGB')
    return rgb_image
# %% car_or_not_nb.ipynb 7
# Main Inference Code
# Output class names, in the order (positive, negative).
catogories = ('Is a Car', 'Not a Car')

# How many of the most frequent entries of the pickled prediction map are
# accepted as "car" labels.
_TOP_N_CAR_CATEGORIES = 36
# How many of the model's highest-probability classes are checked.
_TOP_K_PREDICTIONS = 5


def car_or_not_inference(input_image):
    """Decide whether an image shows a car, using the module-level classifier.

    The image is classified with the module-level ``model``; if any of its
    top predicted ImageNet labels is among the most common keys of the
    pickled ``car_predict_map.pk`` counter, the image is reported as a car.

    Args:
        input_image: Either a path to an image file (``str``) or an array
            accepted by ``PIL.Image.fromarray`` (presumably the numpy array
            supplied by the Gradio image input - confirm against the UI).

    Returns:
        dict: Maps ``'Is a Car'`` and ``'Not a Car'`` to ``1.0``/``0.0``;
        exactly one of them is ``1.0``.

    Raises:
        OSError: If ``car_predict_map.pk`` (or the image path) cannot be
            opened.
    """
    print("Validating that this is a picture of a car...")

    # NOTE(security): pickle.load can execute arbitrary code; this is only
    # acceptable while 'car_predict_map.pk' is a trusted file shipped with
    # the app. Never point this at user-supplied data.
    with open('car_predict_map.pk', 'rb') as f:
        car_predict_map = pk.load(f)
    # Retain the n most frequently occurring labels; a set gives O(1) lookup.
    car_labels = {
        label for label, _ in car_predict_map.most_common(_TOP_N_CAR_CATEGORIES)
    }

    if isinstance(input_image, str):
        image = pil_loader(input_image)
    else:
        # fromarray may yield a greyscale (or RGBA) image depending on the
        # array shape; force RGB, matching what pil_loader does for paths.
        image = Image.fromarray(input_image).convert('RGB')

    # Transform the image as required by the model, then add a batch axis.
    image_tensor = transform(image)
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        output = model(image_tensor.unsqueeze(0))

    # Convert the scores of the single batch item to probabilities and take
    # the indices of the highest ones.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    _, indices = torch.topk(probabilities, _TOP_K_PREDICTIONS)

    # The image counts as a car if any top prediction is a known car label.
    is_car = any(imagenet_labels[idx] in car_labels for idx in indices.tolist())
    scores = [1.0, 0.0] if is_car else [0.0, 1.0]
    return dict(zip(catogories, scores))
# input_image = 'rolls.jpg'
# car_or_not_inference(input_image)
# input_image = 'rolls.jpg'
# car_or_not_inference(input_image)
# %% car_or_not_nb.ipynb 8
# Static text and example image paths handed to gr.Interface further down.
title = "Car Identifier"
# NOTE(review): this file only loads a pretrained timm model and does no
# training itself - confirm the wording of the description is still accurate.
description = "A car or not classifier trained on images scraped from the web."
# Example inputs; the paths are relative, so the files are presumably expected
# next to the app - TODO confirm.
examples = ['rolls.jpg', 'forest.jpg', 'dog.jpg']
# %% car_or_not_nb.ipynb 9
# Wire the inference function into a Gradio UI: one image input and a label
# output that displays both classes.
intf = gr.Interface(
    fn=car_or_not_inference,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=2),
    title=title,
    description=description,
    examples=examples,
)
# Start the app at import time, requesting a share link via share=True.
intf.launch(share=True)