File size: 1,473 Bytes
a45fad6
7087444
ba0661b
 
 
 
eef7700
a45fad6
031b4b5
 
a45fad6
7087444
 
 
 
 
 
 
 
 
cadf3a7
9d44e76
7087444
cadf3a7
 
7087444
031b4b5
 
7087444
031b4b5
 
 
 
 
 
 
7087444
031b4b5
 
60b57d9
57ecb4e
031b4b5
60b57d9
57ecb4e
60b57d9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os
import pathlib

import gradio as gr
from fastai.vision.all import *

# Fix for a Windows-path runtime issue: the exported learner (presumably created
# on Windows) may contain pickled WindowsPath objects, which cannot be
# instantiated on POSIX hosts. Aliasing WindowsPath to PosixPath lets those
# objects unpickle there.
# The alias must NOT be applied on Windows itself. There, pathlib.Path() looks up
# WindowsPath as a module global (at least on CPython <= 3.12), and PosixPath
# cannot be instantiated on Windows, so every Path() call would fail.
# `temp` keeps the original class so the patch can be undone with
# `pathlib.WindowsPath = temp`.
temp = pathlib.WindowsPath
if os.name != 'nt':
    pathlib.WindowsPath = pathlib.PosixPath

# Label function
def is_cat(x:string): return x[0].isupper()

def get_data():
    # Read the dataset from fastai
    path = untar_data(URLs.PETS)/'images'
    return ImageDataLoaders.from_name_func(
        path,get_image_files(path), valid_pct=0.2, seed=42,
        label_func=is_cat, item_tfms=Resize(224))
    
if __name__ == '__main__':
    # This is required for windows users
    # multiprocessing.set_start_method('spawn')
    # dls = get_data() 

    # Since the model is already trained, I have commented out the code to train it
    
    # Train the model with vision_learner
    # learn = vision_learner(dls, resnet34, metrics=error_rate)
    # learn.fine_tune(1)

    # #Export the model
    # learn.path = Path('.')
    # learn.export(
    #     'cats_classifier.pkl'
    # )

    model = load_learner('cats_classifier.pkl')

    def predict(image):
        img = PILImage.create(image)
        _,_,probs = model.predict(img)
        return {'Not a Cat':float("{:.2f}".format(probs[0].item())), 'Cat':float("{:.2f}".format(probs[1].item()))}

    # outputs label is not working, need to investigate further
    demo = gr.Interface(fn=predict, inputs=gr.Image(), outputs='label',examples=['examples/cat-1.jpg','examples/dog-1.jpg'],allow_flagging='never')
    demo.launch(show_error=True)