Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from matplotlib import image | |
| plt.style.use("fivethirtyeight") | |
| import PIL | |
| from PIL import Image | |
| from PIL import ImageFile | |
| from matplotlib import image | |
| import os, shutil, tqdm | |
| from tqdm.auto import tqdm, trange | |
| import pathlib | |
| from pathlib import Path | |
| import torch, torchvision, torchmetrics | |
| import torch.nn as nn | |
| from torchvision.transforms import v2 as v2 | |
| import lightning.pytorch as pl | |
| from lightning.pytorch import LightningModule, LightningDataModule | |
# Let PIL load image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Run on the GPU when torch reports one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The checkpoint is looked up in a "checkpoints" directory located two
# levels above this file (i.e. next to this file's parent directory).
current_file = Path(__file__).resolve()
checkpoint_path = current_file.parent.parent.joinpath(
    "checkpoints", "epoch=14-step=12120.ckpt"
)
# Preprocessing applied to every input image, in order:
#   1. resize to 224x224,
#   2. convert to a torchvision image tensor,
#   3. cast to float32 with value scaling enabled,
#   4. normalize each of the three channels with the given mean / std
#      (NOTE(review): these look like the usual ImageNet statistics).
transform = v2.Compose([
    v2.Resize((224, 224)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Maps the model's output index (0-10) to the human-readable class name.
# The position in this list IS the class index, so the order must not change.
idx_to_class = dict(enumerate([
    'Leaf mold',
    'Tomato mosaic virus',
    'Powdery mildew',
    'Spider mites',
    'Bacterial spot',
    'Early blight',
    'Healthy',
    'Late blight',
    'Tomato yellow leaf curl virus',
    'Septoria leaf spot',
    'Target spot',
]))
def prepare_image(img_path):
    """Turn an image into a single-item batch ready for the model.

    Args:
        img_path: Either a filesystem path (``str`` or ``os.PathLike``) to an
            image file, or an already-loaded image object accepted by
            ``transform`` (e.g. a ``PIL.Image.Image``). Despite the name, the
            original implementation only worked with already-loaded images;
            paths are now opened here as well.

    Returns:
        The result of ``transform`` with a leading batch dimension added via
        ``unsqueeze(0)``.
    """
    image_ = img_path

    if isinstance(image_, (str, os.PathLike)):
        # A real path was given: load it. ``convert`` forces the pixel data
        # to be read before the file handle is closed by the ``with`` block.
        with Image.open(image_) as opened:
            image_ = opened.convert("RGB")
    elif isinstance(image_, Image.Image) and image_.mode != "RGB":
        # ``transform`` normalizes with three per-channel statistics, so
        # RGBA / palette / grayscale inputs are brought to three channels.
        image_ = image_.convert("RGB")

    # Any other input type is handed to ``transform`` unchanged, as before.
    return transform(image_).unsqueeze(0)
def make_preds_return_class_class_confidence_dict(img, model):
    """Classify one preprocessed image and report the most likely classes.

    Args:
        img: Input tensor holding a single image with a leading batch
            dimension of size 1 (as produced by ``prepare_image``).
        model: A ``torch.nn.Module`` whose output is one row of logits with
            one column per entry of ``idx_to_class``.

    Returns:
        A tuple ``(class_, label_dict)`` where ``class_`` is the name of the
        highest-probability class and ``label_dict`` maps the (up to) five
        most probable class names to their softmax confidence as plain
        Python floats, ordered from most to least confident.

    Side effects:
        Puts ``model`` into eval mode.
    """
    model.eval()

    # The model may live on the GPU while the input was built on the CPU;
    # run the forward pass on whatever device the model's weights are on.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        img = img.to(first_param.device)

    with torch.inference_mode():
        logits = model(img)
        # Bring the probabilities back to the CPU so they can be turned into
        # Python values regardless of where the model ran.
        pred_probs = torch.softmax(logits, dim=1).cpu()

    # ``.item()` only succeeds for a single prediction, which keeps the
    # original "exactly one image per call" contract.
    class_ = idx_to_class[torch.argmax(pred_probs, dim=1).item()]

    # Top five classes (or fewer if the model has fewer outputs), already
    # sorted by descending confidence.
    k = min(5, pred_probs.shape[1])
    top_conf, top_idx = torch.topk(pred_probs[0], k)
    label_dict = {
        idx_to_class[index]: confidence
        for index, confidence in zip(top_idx.tolist(), top_conf.tolist())
    }
    return class_, label_dict
def load_model():
    """Rebuild the EfficientNet-B0 classifier and restore it from ``checkpoint_path``.

    The architecture (EfficientNet-B0 with an 11-way linear head) is
    recreated here and its weights are then filled in from the Lightning
    checkpoint, mapped onto ``device``.

    Returns:
        The restored Lightning module, switched to eval mode.
    """

    class myLightningModel(pl.LightningModule):
        """Lightning wrapper bundling a backbone with its loss, metric and optimizer."""

        def __init__(self, model, lr):
            super().__init__()
            self.model = model
            self.lr = lr
            self.loss_fn = nn.CrossEntropyLoss()
            self.metric_fn = torchmetrics.classification.MulticlassAccuracy(num_classes=11)
            # The backbone is a module rather than a plain hyperparameter,
            # so it is excluded from the saved hyperparameters.
            self.save_hyperparameters(ignore=["model"])

        def forward(self, x):
            return self.model(x)

        def training_step(self, batch, batch_idx):
            self.model.train()
            X, y = batch
            logits = self.model(X)
            loss = self.loss_fn(logits, y)
            preds = torch.flatten(torch.argmax(torch.softmax(logits, dim=1), axis=1))
            acc = self.metric_fn(preds, y)
            self.log("Train accuracy", acc, prog_bar=True, on_epoch=True, on_step=False)
            self.log("Train logloss", loss, prog_bar=True, on_epoch=True, on_step=False)
            return {"Train Accuracy": acc, "loss": loss}

        def validation_step(self, batch, batch_idx):
            self.model.eval()
            X, y = batch
            logits = self.model(X)
            val_loss = self.loss_fn(logits, y)
            preds = torch.flatten(torch.argmax(torch.softmax(logits, dim=1), axis=1))
            val_acc = self.metric_fn(preds, y)
            self.log("Val accuracy", val_acc, prog_bar=True, on_epoch=True, on_step=False)
            self.log("Val logloss", val_loss, prog_bar=True, on_epoch=True, on_step=False)
            return {"Val Accuracy": val_acc, "Val loss": val_loss}

        def configure_optimizers(self):
            return torch.optim.Adam(params=self.model.parameters(), lr=self.lr, weight_decay=1e-4)

    # Only the architecture is needed: every weight is restored from the
    # checkpoint below, so requesting the pretrained ImageNet weights would
    # just trigger a download whose result is immediately overwritten.
    model = torchvision.models.efficientnet.efficientnet_b0(weights=None)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.2, inplace=True),
        nn.Linear(in_features=1280, out_features=11, bias=True),
    )

    # ``model`` is passed explicitly because it was excluded from the saved
    # hyperparameters; ``lr`` is supplied alongside it as in the original call.
    lightning_model = myLightningModel.load_from_checkpoint(
        checkpoint_path=checkpoint_path,
        map_location=device,
        model=model,
        lr=1e-3,
    )

    # This loader is used for inference; make sure dropout / batch-norm are
    # in evaluation mode even if a caller forgets to call ``eval()``.
    lightning_model.eval()
    return lightning_model