# cat_or_dog / app.py
# arnavam's picture
# Update app.py
# 4cd2b30 verified
from fastai.learner import load_model
from fastai.vision.all import *
import gradio as gr
import sys
# Log the interpreter version at startup so it shows up in the app logs.
print(f"Python: {sys.version}")
# Point at the 'images' sub-directory of the location fastai's untar_data
# returns for URLs.PETS (presumably downloading/extracting the dataset on
# first use -- confirm against the fastai docs).
path = untar_data(URLs.PETS)/'images'
def label_func(fname):
    """Return the class label for an image based on its file name.

    The decision looks only at the first character of the file name:
    an uppercase first character yields 'cat', anything else yields 'dog'.

    Args:
        fname: The file name as a string, or a path-like object exposing a
            ``name`` attribute (e.g. ``pathlib.Path``); for the latter only
            the final path component is inspected.

    Returns:
        'cat' or 'dog'.

    Raises:
        IndexError: If the file name is empty.
    """
    # Path objects are not subscriptable, so fall back to their final
    # component; plain strings have no `.name` and are used as-is.
    name = getattr(fname, "name", fname)
    return 'cat' if name[0].isupper() else 'dog'
# Rebuild the DataLoaders from the image files under `path`; the Learner
# created below is constructed from them, and their vocab is used later when
# mapping probabilities to labels.
dls = ImageDataLoaders.from_name_func(
    '.',
    get_image_files(path),
    valid_pct=0.2,
    seed=42,
    label_func=label_func,
    item_tfms=Resize(192),
)
print(f"Vocab: {dls.vocab}")

# Recreate Learner
learn = vision_learner(dls, resnet18, metrics=error_rate)

# Load the saved weights into the freshly built model. This app only runs
# inference and no optimizer has been created on `learn` yet, so optimizer
# restoration is explicitly disabled (with_opt=False) instead of letting
# load_model attempt it and warn.
# NOTE(review): weights_only=False is forwarded to the underlying checkpoint
# loader and permits full unpickling -- only acceptable because the checkpoint
# ships with the app; never point this at an untrusted file.
load_model(
    'resnet18-catdog.pth',
    learn.model,
    learn.opt,
    with_opt=False,
    device=default_device(),
    weights_only=False,
)
def classify_image(img):
    """Run the model on ``img`` and return a label -> probability mapping.

    Args:
        img: The image to classify, passed straight to ``learn.predict``.

    Returns:
        A dict whose keys are the two vocab labels and whose values are the
        corresponding probabilities as Python floats.
    """
    # Only the third element of the prediction tuple (the probabilities) is used.
    _, _, probs = learn.predict(img)

    vocab = learn.dls.vocab
    # The labels are deliberately paired in swapped order: probability 0 goes
    # with vocab[1] and probability 1 with vocab[0].
    # NOTE(review): presumably the checkpoint was trained with the opposite
    # class order from this vocab -- confirm before changing.
    swapped_labels = [vocab[1], vocab[0]]

    return {swapped_labels[i]: float(p) for i, p in enumerate(probs)}
# Build the Gradio UI: a PIL image input wired to classify_image, with a
# label output showing the top two classes and two bundled example images.
demo = gr.Interface(
    title='Cat_or_Dog Classifier',
    description='Upload an image of a cat or dog.',
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=2),
    examples=["cat.png", "dog.png"],
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()