Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| from model import BaselineCNN | |
| from data_pipeline import val_test_transform, IMAGE_SIZE | |
| import torch | |
| import pandas as pd | |
| from datasets import load_dataset | |
# Human-readable label names, indexed by the integer class id the model emits.
# NOTE(review): load_dataset(..., split="train") fetches the whole train split
# only to read the label names; a metadata-only lookup would speed cold starts.
dataset = load_dataset("DScomp380/plant_village", split="train")
CLASS_NAMES = dataset.features["label"].names

# Load the trained classifier for CPU-only inference.
# Number of output units of the classifier head; load_state_dict below must
# accept a checkpoint built with this many classes.
CLASSES = 39
if len(CLASS_NAMES) != CLASSES:
    # Fail fast at startup: a mismatch would either attach the wrong names to
    # the model's outputs or make predict() index past the end of CLASS_NAMES.
    raise ValueError(
        f"Dataset lists {len(CLASS_NAMES)} label names but the model has "
        f"{CLASSES} outputs; class names would not line up with predictions."
    )
model = BaselineCNN(num_classes=CLASSES)
# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state_dict needs, instead of allowing arbitrary pickled objects.
model.load_state_dict(
    torch.load("best_model.pth", map_location=torch.device("cpu"), weights_only=True)
)
# Switch layers such as dropout / batch-norm (if BaselineCNN has any) to
# inference behavior.
model.eval()
def predict(input_image, *, top_k: int = 5) -> pd.DataFrame:
    """Classify a leaf image and return the most likely disease classes.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when Submit is pressed without an uploaded image.
        top_k: Maximum number of classes to list (keyword-only so Gradio's
            positional call still passes just the image). Clamped to the
            number of classes the model outputs.

    Returns:
        DataFrame with a "Class" column (names from CLASS_NAMES) and a
        "Probability" column of percentage strings such as "97.31%",
        ordered from most to least likely.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable
            message instead of a stack trace.
    """
    if input_image is None:
        raise gr.Error("Please upload a leaf image before submitting.")

    # Force 3 channels: uploads such as PNGs may carry an alpha channel.
    # gr.Image may already convert (image_mode), but direct callers are not
    # covered by that. NOTE(review): assumes val_test_transform expects RGB.
    rgb_image = input_image.convert("RGB")

    # Resize/normalize, then add the batch dimension the model expects:
    # [batch, channels, height, width].
    batch = val_test_transform(rgb_image).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)

    # Raw scores -> probabilities for the single image in the batch.
    probabilities = torch.nn.functional.softmax(logits, dim=1)[0]

    # torch.topk raises when k exceeds the number of classes, so clamp it.
    k = min(top_k, probabilities.numel())
    top_probabilities, top_class_indices = torch.topk(probabilities, k)

    return pd.DataFrame({
        "Class": [CLASS_NAMES[i] for i in top_class_indices.tolist()],
        # Percentages rounded to two decimals, e.g. "97.31%".
        "Probability": [f"{p * 100:.2f}%" for p in top_probabilities.tolist()],
    })
# Gradio UI: an image upload and a predictions table side by side, plus a
# Submit button that runs predict() on the uploaded image.
with gr.Blocks(title="Plant Disease Classifier") as app:
    gr.Markdown("# Plant Disease Classification")
    gr.Markdown("Upload an image of a plant leaf to classify its disease.")

    with gr.Row():
        image_input = gr.Image(label="Upload Leaf Image", type="pil")
        # predict() returns a pandas DataFrame with these two columns.
        label_output = gr.DataFrame(type="pandas", headers=["Class", "Probability"])

    submit_btn = gr.Button("Submit")
    # Event listeners must be registered inside the Blocks context.
    submit_btn.click(
        fn=predict,
        inputs=image_input,
        outputs=label_output,
    )


if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    app.launch()