| import gradio as gr | |
| import torch | |
| from huggingface_hub import from_pretrained_fastai | |
| from pathlib import Path | |
# Bundled sample images offered as clickable examples in the Gradio UI
# (./examples/image_1.png … ./examples/image_5.png).
examples = [f"./examples/image_{idx}.png" for idx in range(1, 6)]

# Hugging Face Hub repository from which the pretrained learner is loaded.
repo_id = "hugginglearners/rice_image_classification"

# Base directory that `get_x` joins record file names onto.
path = Path("./")
def get_y(r):
    """Return the target label of record *r* (the value under its ``"label"`` key).

    NOTE(review): not called anywhere in this file. It is defined at module level
    before the learner is loaded, presumably because the pickled fastai learner
    refers to it by name when it is unpickled. Confirm before removing.
    """
    label = r["label"]
    return label
def get_x(r):
    """Return the image path for record *r*: its ``"fname"`` joined onto ``path``.

    NOTE(review): not called anywhere in this file. Like ``get_y``, it is presumably
    kept so the pickled fastai learner can resolve it by name when loaded.
    Confirm before removing.
    """
    base_dir = path
    return base_dir / r["fname"]
# Load the pretrained fastai learner published at `repo_id` on the Hugging Face Hub.
# Kept after the get_x/get_y definitions: presumably the pickled learner looks them
# up by name while loading (TODO confirm).
learner = from_pretrained_fastai(repo_id=repo_id)
def inference(image):
    """Classify a rice image and describe the prediction in one sentence.

    Args:
        image: The value produced by the Gradio ``"image"`` input, or ``None``
            when the form is submitted without an uploaded image.

    Returns:
        A sentence naming the predicted rice type and the model's probability for
        it as a percentage. If *image* is ``None``, returns a prompt to upload an
        image instead.
    """
    # Without an image, learner.predict would raise and the user would only see
    # an error, so reply with a hint instead.
    if image is None:
        return "Please upload a rice image to classify."
    label_predict, _, probs = learner.predict(image)
    # The probability of the predicted class is the largest entry of `probs`
    # (equivalent to indexing it at argmax).
    confidence = 100 * probs.max().item()
    return f"This rice image is {label_predict} with {confidence:.2f}% probability"
# Web UI: one image in, one text prediction out, with the bundled sample images
# as examples. Examples are not pre-computed (cache_examples=False).
demo = gr.Interface(
    fn=inference,
    title="Rice image classification",
    description="Predict which type of rice belong to Arborio, Basmati, Ipsala, Jasmine, Karacadag",
    inputs="image",
    examples=examples,
    outputs=gr.Textbox(label='Prediction'),
    cache_examples=False,
    article="Author: <a href=\"https://www.linkedin.com/in/vumichien/\">Vu Minh Chien</a>",
)
# `launch(enable_queue=True)` is deprecated in Gradio 3.x and removed in 4.0
# (where it raises TypeError); enabling the queue via `.queue()` works on both.
demo.queue().launch(debug=True)