Spaces:
Runtime error
Runtime error
| from typing import Any | |
| import pytorch_lightning as pl | |
| from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights | |
| import torch | |
| from torch import nn | |
| from torchvision import transforms | |
| import yaml | |
| from yaml.loader import SafeLoader | |
| import gradio as gr | |
| import os | |
class WeedModel(pl.LightningModule):
    """Weed image classifier built on a torchvision backbone.

    Only ``params["model"] == "efficientnet"`` (case-insensitive) is
    supported. It builds an EfficientNetV2-S and replaces the last layer of
    its classifier head with an ``nn.Linear`` producing ``params["n_class"]``
    outputs.

    Args:
        params: Configuration mapping. Keys read here: ``"model"`` (backbone
            name), ``"pretrained"`` (truthy -> start from the torchvision
            ImageNet weights) and ``"n_class"`` (number of output classes).

    Raises:
        ValueError: If ``params["model"]`` names an unsupported backbone.
    """

    def __init__(self, params):
        super().__init__()
        self.params = params
        model = self.params["model"]
        if model.lower() != "efficientnet":
            # Previously this only printed a message and left ``base_model``
            # unset, so the error surfaced later as an AttributeError inside
            # ``forward``. Fail fast with an explicit message instead.
            raise ValueError(
                f"Unsupported model {model!r}; only 'efficientnet' is implemented."
            )
        weights = (
            EfficientNet_V2_S_Weights.IMAGENET1K_V1
            if self.params["pretrained"]
            else None
        )
        self.base_model = efficientnet_v2_s(weights=weights)
        # Replace the final classifier layer so its output size matches our
        # label set.
        num_ftrs = self.base_model.classifier[-1].in_features
        self.base_model.classifier[-1] = nn.Linear(num_ftrs, self.params["n_class"])

    def forward(self, x):
        """Return the backbone's raw (pre-softmax) class scores for ``x``."""
        return self.base_model(x)

    def predict_step(
        self, batch: Any, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> Any:
        """Return softmax class probabilities for ``batch`` as nested lists.

        ``batch_idx`` and ``dataloader_idx`` are unused. They exist to match
        Lightning's ``predict_step`` signature.
        """
        # Inference only: the result is converted to plain Python lists, so
        # recording an autograd graph would waste memory and time (this method
        # is also called directly, outside Lightning's Trainer).
        with torch.no_grad():
            logits = self(batch)
            probs = torch.softmax(logits, dim=-1)
        return probs.tolist()
def predict(image):
    """Classify one PIL image for the Gradio ``Label`` output.

    Uses the module-level ``transform``, ``model`` and ``class_names``.
    ``transform`` and ``model`` are created in the ``__main__`` block.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the button is pressed with no image loaded.

    Returns:
        Mapping of class name to probability (float). The mapping is empty
        when no image was supplied.
    """
    if image is None:
        # Nothing uploaded yet: show an empty prediction instead of crashing
        # inside the transform.
        return {}
    # Uploaded PNGs may carry an alpha channel, and some images are grayscale
    # or palette-based. The backbone expects exactly 3 channels, so normalise
    # the mode first (a no-op copy for images that are already RGB).
    tensor_image = transform(image.convert("RGB"))
    with torch.no_grad():
        # unsqueeze(0) adds the batch dimension expected by the model.
        outs = model.predict_step(tensor_image.unsqueeze(0))
    # NOTE(review): the last class probability is deliberately dropped.
    # Presumably the final model output has no entry in class_names.txt (e.g. a
    # catch-all class); confirm against the training labels.
    probs = outs[0][:-1]
    return {class_names[k]: float(v) for k, v in enumerate(probs)}
title = " AISeed AI Application Demo "
description = "# A Demo of Deep Learning for Weed Classification"
# os.listdir returns entries in arbitrary order; sort so the examples gallery
# is shown in the same order on every deployment. Each example is wrapped in
# its own list because gr.Examples expects one row of input values per example.
example_list = [
    [os.path.join("examples", example)] for example in sorted(os.listdir("examples"))
]
# One class name per line; indices line up with the model's output positions.
with open("class_names.txt", "r", encoding="utf-8") as f:
    class_names = f.read().splitlines()
# NOTE(review): the pasted source lost its indentation. The nesting below is the
# most plausible reconstruction (button under the predictions column, examples
# below the row); confirm it against the deployed layout.
with gr.Blocks(title=title) as demo:
    gr.Markdown(description)
    with gr.Tabs(), gr.TabItem("Images"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    type="pil",
                    label="input image",
                    sources=["upload", "webcam"],
                )
            with gr.Column():
                prediction_label = gr.Label(label="Predictions", num_top_classes=4)
                predict_button = gr.Button(value="predict")
        # Clicking the button runs `predict` on the current image.
        predict_button.click(fn=predict, inputs=image_input, outputs=[prediction_label])
        # No fn is passed, so choosing an example only fills the image input.
        gr.Examples(
            examples=example_list,
            inputs=[image_input],
            outputs=[prediction_label],
        )
if __name__ == "__main__":
    # Read the model configuration. SafeLoader only builds plain YAML types.
    # The explicit encoding matches how class_names.txt is read above, rather
    # than depending on the platform default.
    with open("config.yaml", encoding="utf-8") as f:
        PARAMS = yaml.load(f, Loader=SafeLoader)
    print(PARAMS)
    # Rebuild the network from the config and restore the trained weights.
    # map_location lets the checkpoint load on a machine without a GPU.
    model = WeedModel.load_from_checkpoint(
        "model/epoch=08.ckpt", params=PARAMS, map_location=torch.device("cpu")
    )
    model.eval()
    # Preprocessing used by `predict`: resize the shorter side to 256, take
    # the central 224x224 crop, then convert to a tensor.
    # NOTE(review): ImageNet Normalize was commented out in the original.
    # This pipeline must match the checkpoint's training preprocessing; confirm.
    transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ]
    )
    demo.launch()