import functools

import albumentations
import cv2
import gradio as gr
import numpy as np
import timm
import torch
# Inference device: predict() loads the weights here and inference_fn() moves
# each input image here before the forward pass.
device = torch.device("cpu")

# Output index -> class name. predict() pairs score i of the model output with
# the name stored under key i, so this order must match the classifier head.
_CLASS_NAMES = [
    "bacterial_leaf_blight",
    "bacterial_leaf_streak",
    "bacterial_panicle_blight",
    "blast",
    "brown_spot",
    "dead_heart",
    "downy_mildew",
    "hispa",
    "normal",
    "tungro",
]
labels = dict(enumerate(_CLASS_NAMES))
def inference_fn(model, image=None):
    """Run ``model`` on a single image and return its per-class sigmoid scores.

    Args:
        model: a torch module that maps a batched image tensor to logits.
            It is switched to eval mode before the forward pass.
        image: one image tensor *without* a batch dimension (presumably the
            CHW float32 tensor built by ``predict`` — confirm for other
            callers). A batch axis of size 1 is added here.

    Returns:
        A 1-D numpy array with one sigmoid score per model output.

    Raises:
        ValueError: if ``image`` is None.
    """
    if image is None:
        raise ValueError("inference_fn requires an image tensor, got None")

    model.eval()

    # Send the input to wherever the model's weights live so the two can never
    # end up on different devices; a parameter-less model runs on the CPU.
    first_param = next(model.parameters(), None)
    target_device = (
        first_param.device if first_param is not None else torch.device("cpu")
    )
    image = image.to(target_device)

    with torch.no_grad():  # no autograd graph is built, so no .detach() needed
        output = model(image.unsqueeze(0))

    # NOTE(review): sigmoid yields independent per-class scores. If the model
    # was trained with softmax cross-entropy, softmax would give scores that
    # sum to 1 — confirm against the training code.
    return output.sigmoid().cpu().numpy().flatten()
@functools.lru_cache(maxsize=1)
def _load_model(weights_path="paddy_model.pth"):
    """Build the EfficientNet-B0 classifier, load its weights, and cache it.

    The model is built once and reused, so the weights file is not re-read
    and the network is not rebuilt on every request.
    """
    model = timm.create_model(
        "efficientnet_b0", pretrained=False, num_classes=len(labels)
    )
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.to(device)
    model.eval()
    return model


@functools.lru_cache(maxsize=1)
def _inference_transform():
    """Return the deterministic resize + normalisation pipeline (built once)."""
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    # No random flips: augmentation at inference time made the prediction for
    # the same image differ from call to call. Normalize already defaults to
    # p=1.0, so the deprecated always_apply flag is unnecessary.
    return albumentations.Compose(
        [
            albumentations.Resize(256, 256),
            albumentations.Normalize(mean, std, max_pixel_value=255.0),
        ]
    )


def _preprocess(image):
    """Turn an input image array into a normalised CHW float32 tensor."""
    # NOTE(review): Gradio's Image input normally supplies RGB arrays, in which
    # case COLOR_BGR2RGB actually produces BGR. Left unchanged because the
    # channel order the weights were trained on is not visible here — confirm
    # against the training pipeline.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = _inference_transform()(image=image)["image"]
    # HWC -> CHW, the layout torch convolution layers expect.
    image = np.transpose(image, (2, 0, 1))
    return torch.tensor(image, dtype=torch.float32)


def predict(image=None):
    """Classify an image into the classes listed in ``labels``.

    Args:
        image: numpy image array from the Gradio Image input (presumably
            HxWx3 uint8 — confirm if called from elsewhere).

    Returns:
        A dict mapping every label name to the score ``inference_fn``
        returns for it, which is the format ``gr.outputs.Label`` consumes.

    Raises:
        ValueError: if no image was supplied.
    """
    if image is None:
        raise ValueError("No image was provided; please upload an image.")

    scores = inference_fn(_load_model(), _preprocess(image))
    return {name: float(scores[idx]) for idx, name in labels.items()}
# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy namespaces (removed
# in Gradio 4); switch to gr.Image / gr.Label if the pinned version is raised.
demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(shape=(256, 256)),
    # Show a score for every class rather than a hard-coded count.
    outputs=gr.outputs.Label(num_top_classes=len(labels)),
    examples=["200005.jpg", "200006.jpg"],
)

# Launch only when run as a script, so importing this module (e.g. for tests
# or `gradio` reload mode) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()