import random

from fastai.vision.all import *
import gradio as gr
def is_cat(x):
    """Report whether the first character of ``x`` is upper-case.

    NOTE(review): nothing in this script calls this directly; it is presumably
    kept so that unpickling ``model.pkl`` can resolve the label function the
    learner was trained with (``x`` presumably a file name whose capitalised
    initial marks a cat image) — confirm.
    """
    leading_char = x[0]
    return leading_char.isupper()
# Load the exported fastai learner. load_learner unpickles the file, so only
# ever point it at a model file you trust.
learn = load_learner('model.pkl')


def _label_name(label):
    """Return the display name for one entry of the learner's vocab.

    The learner is presumably trained with ``is_cat`` as its label function, so
    its vocab holds booleans, which fastai sorts as [False, True]: False means
    "not a cat", i.e. a dog. String labels are passed through unchanged in case
    the model was trained with named classes instead.
    """
    if isinstance(label, str):
        return label
    return 'Cat' if label else 'Dog'


# Category names aligned index-for-index with the probabilities returned by
# ``learn.predict``. They are derived from the learner's own vocab because a
# hard-coded ('Cat', 'Dog') would pair "Cat" with probs[0], which for a
# boolean is_cat vocab is the dog probability.
categories = tuple(_label_name(label) for label in learn.dls.vocab)
# Markdown headings shown when the top category clears the confidence
# threshold; ``{}`` is filled in with that category's name.
prompts = [
    "# Definitely a {}!",
    "# Well, that must be a {}!",
    "# Oh, that's a {}!",
    "# That's a {}!",
    "# Looks like a {} to me!",
]

# Markdown headings shown when no category is confident enough.
failure_prompts = [
    "# I'm not sure what that is.",
    "# I don't know what that thing is.",
    "# I've never seen that before.",
    "# Looks familiar, but unsure.",
    "# Something, something?",
    "# Beats me.",
]
def classify_image(img):
    """Run the learner on ``img`` and map each category name to its probability.

    The probabilities returned by ``learn.predict`` are paired positionally with
    ``categories`` and converted to plain floats, giving ``{name: probability}``.
    """
    # learn.predict returns (decoded prediction, class index, probabilities);
    # only the probabilities are needed here.
    _, _, probabilities = learn.predict(img)
    return {name: float(p) for name, p in zip(categories, probabilities)}
def calculate(confidence_threshold, img):
    """Classify ``img`` and choose a Markdown heading describing the result.

    Args:
        confidence_threshold: the most likely category must have a probability
            strictly greater than this to be announced.
        img: the value of the Gradio image input, or None when nothing has
            been uploaded.

    Returns:
        ``[heading_markdown, classifications]`` where ``classifications`` is the
        ``{category: probability}`` dict from ``classify_image``, or None when
        no image was supplied.
    """
    # Gradio passes None when the button is clicked with an empty image input;
    # learn.predict cannot handle that, so answer without classifying.
    if img is None:
        return ["# Please upload an image first.", None]

    classifications = classify_image(img)

    # Judge the most likely category rather than the first one above the
    # threshold: with a threshold below 0.5 several categories can clear it.
    best = max(classifications, key=classifications.get, default=None)
    if best is not None and classifications[best] > confidence_threshold:
        classification = random.choice(prompts).format(best)
    else:
        classification = random.choice(failure_prompts)
    return [classification, classifications]
# Build the Gradio page. Statement order and nesting inside the context
# managers determine the on-screen layout.
# NOTE(review): nesting reconstructed from a whitespace-stripped copy of this
# file — confirm it matches the intended layout.
with gr.Blocks() as ui:
    # Created with render=False so they can be wired as outputs of the button
    # below before being placed in the right-hand column.
    heading = gr.Markdown(" # Dog or Cat?", render=False)
    results = gr.Label(value="Waiting to receive image.", label="Details", show_label=False, render=False)
    with gr.Row(equal_height=True):
        # Left column: image upload, confidence threshold and the trigger button.
        with gr.Column():
            gr.Markdown("Upload an image of a cat or a dog.")
            with gr.Group():
                image = gr.Image(show_label=False, height=300)
                confidence = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Confidence Threshold")
            btn = gr.Button(value="Classify")
            # calculate(confidence_threshold, img) returns [heading text, label dict].
            btn.click(calculate, inputs=[confidence, image], outputs=[heading, results])
        # Right column: where the pre-built outputs are actually rendered.
        with gr.Column():
            gr.Markdown("Then wait for the magic to happen")
            with gr.Group():
                results.render()
                heading.render()
    gr.Markdown(" # Examples")
    with gr.Group():
        # Sample images offered as clickable examples for the image input.
        gr.Examples(inputs=image, examples=['images/cat1.jpeg', 'images/cat2.jpeg', 'images/cat3.jpeg', 'images/dog1.jpeg', 'images/dog2.jpeg', 'images/dog3.jpeg'])
# Start the Gradio server only when this file is run directly, not on import.
if __name__ == "__main__":
    ui.launch()