Spaces:
Build error
Build error
# AUTOGENERATED! DO NOT EDIT! File to edit: car_or_not_nb.ipynb.

# %% auto 0
# Public names exported by `from <module> import *`.
# NOTE: 'input_image' was removed: its only assignment in this module is
# commented out, and a name listed in __all__ that does not exist makes
# star-imports fail with AttributeError. If the notebook is re-exported,
# make sure that cell does not export `input_image` again.
__all__ = ['imagenet_labels', 'model', 'transform', 'catogories', 'title', 'description', 'examples', 'intf',
           'get_imagenet_classes', 'create_model', 'pil_loader', 'car_or_not_inference']
| # %% car_or_not_nb.ipynb 1 | |
| # imports | |
| import os | |
| import timm | |
| import json | |
| import torch | |
| import gradio as gr | |
| import pickle as pk | |
| from PIL import Image | |
| from collections import Counter, defaultdict | |
| # from fastai.vision.all import * | |
# %% car_or_not_nb.ipynb 2
# ImageNet class labels
def get_imagenet_classes(path="imagenet_class_index.txt"):
    """Read the ImageNet label file and return one label per line.

    Each line of the file holds a comma-separated list of names for a class;
    only the first name is kept. The list is indexed by the model's output
    class index, so the file is assumed to be ordered by class index
    (TODO confirm against the shipped file).

    Args:
        path: Location of the label file. Defaults to
            ``"imagenet_class_index.txt"`` relative to the current working
            directory.

    Returns:
        list[str]: The first comma-separated field of every line.
    """
    # A context manager closes the handle deterministically instead of
    # leaving it to the garbage collector.
    with open(path, "r", encoding="utf-8") as f:
        imagenet_file = f.read()
    # strip() drops the trailing newline so no spurious empty entry is
    # produced; then split the contents into one entry per line.
    imagenet_labels_raw = imagenet_file.strip().split('\n')
    # Keep only the first name listed for each class.
    return [item.split(',')[0] for item in imagenet_labels_raw]


imagenet_labels = get_imagenet_classes()
# %% car_or_not_nb.ipynb 3
# Model construction
def create_model(model_name='vgg16.tv_in1k'):
    """Build a pretrained timm classifier together with its preprocessing.

    Args:
        model_name: timm model identifier; ``'vgg16.tv_in1k'`` by default
            (``'mobilenetv3_large_100'`` was considered as an alternative).

    Returns:
        tuple: ``(model, transform)``. ``model`` has been switched to eval
        mode. ``transform`` is the input pipeline that timm builds from the
        model's ``pretrained_cfg``.
    """
    net = timm.create_model(model_name, pretrained=True)
    net.eval()  # inference mode (affects layers such as dropout)
    # Derive the preprocessing settings from the weights' own config so that
    # inputs match what the pretrained model expects.
    data_cfg = timm.data.resolve_data_config(net.pretrained_cfg)
    preprocess = timm.data.create_transform(**data_cfg)
    return net, preprocess


model, transform = create_model()
# %% car_or_not_nb.ipynb 5
def pil_loader(path):
    """Load the image file at *path* as a 3-channel RGB PIL image.

    The conversion happens while the file is still open, so the returned
    image holds its pixel data after the handle is closed.
    """
    with open(path, 'rb') as handle:
        rgb_image = Image.open(handle).convert('RGB')
    return rgb_image
# %% car_or_not_nb.ipynb 7
# Main inference code
# Output labels for the UI. (The misspelled name is kept because it is
# exported in __all__.)
catogories = ('Is a Car', 'Not a Car')


def car_or_not_inference(input_image):
    """Decide whether *input_image* shows a car.

    The image is classified with the pretrained ImageNet model. If any of its
    top-5 predicted labels is among the 36 most common labels stored in
    ``car_predict_map.pk``, the image counts as a car.

    Args:
        input_image: Either a path to an image file (``str``) or an image
            array as produced by ``gr.Image()``, which is passed to
            ``PIL.Image.fromarray``.

    Returns:
        dict: ``{'Is a Car': 1.0, 'Not a Car': 0.0}`` for a car, or
        ``{'Is a Car': 0.0, 'Not a Car': 1.0}`` otherwise. This is the mapping
        format ``gr.Label`` displays.

    Raises:
        gr.Error: If no image was supplied (``input_image`` is ``None``).
    """
    print("Validating that this is a picture of a car...")
    if input_image is None:
        # Gradio passes None when the form is submitted without an image;
        # show a readable message instead of a PIL traceback.
        raise gr.Error("Please provide an image.")

    # The pickle is a local artifact shipped with the app. Unpickling can
    # execute arbitrary code, so never point this at user-supplied files.
    # The object must provide most_common() (e.g. a collections.Counter).
    with open('car_predict_map.pk', 'rb') as f:
        car_predict_map = pk.load(f)
    # Keep the 36 most frequent labels; a set makes membership tests O(1).
    top_n_cat_set = {label for label, _ in car_predict_map.most_common(36)}

    if isinstance(input_image, str):
        image = pil_loader(input_image)
    else:
        # fromarray returns mode 'L' for 2-D (greyscale) arrays and 'RGBA'
        # for 4-channel ones, but the model expects 3 channels, so force RGB
        # just like pil_loader does.
        image = Image.fromarray(input_image).convert('RGB')

    image_tensor = transform(image)
    # Prediction needs no gradients; no_grad avoids building the autograd
    # graph, which saves memory and time on every request.
    with torch.no_grad():
        output = model(image_tensor.unsqueeze(0))  # add a batch dimension
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    _, indices = torch.topk(probabilities, 5)

    for idx in indices.tolist():
        if imagenet_labels[idx] in top_n_cat_set:
            # Validation complete - proceed to damage evaluation.
            return dict(zip(catogories, [1.0, 0.0]))
    # No car-like label in the top 5: ask the user for another picture.
    return dict(zip(catogories, [0.0, 1.0]))


# Example usage:
# car_or_not_inference('rolls.jpg')
# %% car_or_not_nb.ipynb 8
# Text and sample inputs shown in the Gradio UI.
title = "Car Identifier"
description = "A car or not classifier trained on images scraped from the web."
examples = ['rolls.jpg', 'forest.jpg', 'dog.jpg']
# %% car_or_not_nb.ipynb 9
# Wire the classifier into an image-in / label-out web UI and start serving
# it. share=True additionally requests a public Gradio link.
intf = gr.Interface(
    fn=car_or_not_inference,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=2),
    title=title,
    description=description,
    examples=examples,
)
intf.launch(share=True)