| import gradio as gr |
| import torch |
| from huggingface_hub import from_pretrained_fastai |
| from pathlib import Path |
|
|
# Sample images shown in the UI's "Examples" panel (relative paths):
# examples/example_0.png ... examples/example_4.png.
examples = [f"examples/example_{idx}.png" for idx in range(5)]

# Hugging Face Hub repository the fastai learner is loaded from.
repo_id = "hugginglearners/rice_image_classification"

# Base directory that get_x() joins record file names onto.
path = Path("./")
|
|
def get_y(r):
    """Return the target label stored under the ``"label"`` key of record *r*.

    NOTE(review): nothing in this file calls this function. It is presumably
    kept at module level so the pickled fastai learner loaded further down can
    resolve it by name. Confirm before removing or renaming it.
    """
    label = r["label"]
    return label
|
|
def get_x(r):
    """Return the image location for record *r*: its ``"fname"`` joined onto ``path``.

    NOTE(review): nothing in this file calls this function. It is presumably
    kept at module level so the pickled fastai learner loaded further down can
    resolve it by name. Confirm before removing or renaming it.
    """
    fname = r["fname"]
    return path / fname
|
|
# Fetch and load the fastai learner published at `repo_id`. This runs at
# import time, so the script does not proceed until the model has loaded.
learner = from_pretrained_fastai(repo_id=repo_id)
# Class names from the learner's DataLoaders. inference() pairs them
# positionally with the probabilities returned by learner.predict().
labels = learner.dls.vocab
|
|
def inference(image):
    """Classify *image* and return a score for every class in ``labels``.

    Args:
        image: The value produced by the ``gr.Image`` input component, or
            ``None`` when the user submits without providing an image.

    Returns:
        A ``{class_name: probability}`` dict with one entry per label (the
        mapping form ``gr.Label`` displays), or ``None`` when no image was
        given, which leaves the output component empty.
    """
    if image is None:
        # Nothing to classify; don't hand None to learner.predict().
        return None
    # Only the third element of predict()'s result (the per-class
    # probabilities) is needed to fill the label component.
    _, _, probs = learner.predict(image)
    # probs is positionally aligned with labels (the learner's vocab).
    return {label: float(prob) for label, prob in zip(labels, probs)}
|
|
# Build the web UI around inference() and start the server.
demo = gr.Interface(
    fn=inference,
    title="Rice Disease Classification",
    # NOTE(review): this class list is hand-written and not derived from
    # `labels` (the loaded model's vocab); keep the two in sync.
    description="Predict which type of rice disease is affecting the leaf: Tungro, Rice Blast, Bacterial Blight, or Healthy Rice Leaf.",
    inputs=gr.Image(),
    examples=examples,
    # Show a score for every class the model can predict, instead of a
    # hard-coded count that can drift from the model's vocab.
    outputs=gr.Label(num_top_classes=len(labels), label='Prediction'),
    cache_examples=False,
    # TODO: replace the placeholder author credit.
    article="Authors: Your Team Name",
)
# `launch(enable_queue=...)` is deprecated in Gradio 3.x and was removed in
# Gradio 4 (TypeError); `.queue()` enables request queuing on both versions.
# debug=True blocks the main thread and surfaces errors in the console.
demo.queue().launch(debug=True)
|
|