File size: 4,618 Bytes
43a1bb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import image
plt.style.use("fivethirtyeight")
import PIL
from PIL import Image
from PIL import ImageFile
from matplotlib import image

import os, shutil, tqdm
from tqdm.auto import tqdm, trange
import pathlib
from pathlib import Path

import torch, torchvision, torchmetrics
import torch.nn as nn
from torchvision.transforms import v2 as v2
import lightning.pytorch as pl
from lightning.pytorch import LightningModule, LightningDataModule

# Let Pillow decode image files whose data is cut short instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Prefer the GPU whenever torch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The checkpoint lives in a "checkpoints" directory that is a sibling of the
# directory containing this file.
current_file = Path(__file__).resolve()
checkpoint_path = current_file.parents[1] / "checkpoints" / "epoch=14-step=12120.ckpt"

# Preprocessing pipeline applied to an image before it is fed to the model:
# resize to 224x224, convert to a tensor image, cast to float32 with value
# scaling enabled, then normalise each of the three channels with a fixed
# mean/std (these look like the usual ImageNet statistics -- confirm against
# the training pipeline).
transform = v2.Compose([
    v2.Resize(size=(224, 224)),
    v2.ToImage(),
    v2.ToDtype(dtype=torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Maps the model's output index to a human-readable disease label. A label's
# position in the tuple is its class index (presumably the label order used
# during training -- verify before reordering).
idx_to_class = dict(enumerate((
    'Leaf mold',
    'Tomato mosaic virus',
    'Powdery mildew',
    'Spider mites',
    'Bacterial spot',
    'Early blight',
    'Healthy',
    'Late blight',
    'Tomato yellow leaf curl virus',
    'Septoria leaf spot',
    'Target spot',
)))

def prepare_image(img_path):
    """Turn a single image into a batched tensor ready for the model.

    Args:
        img_path: Either an already-loaded image (anything the module-level
            ``transform`` accepts, e.g. a ``PIL.Image.Image``) or a filesystem
            path (``str`` / ``os.PathLike``) pointing at an image file.

    Returns:
        The output of ``transform`` with an extra leading dimension of size 1
        added via ``unsqueeze(0)``.
    """
    source = img_path
    if isinstance(source, (str, os.PathLike)):
        # A path cannot be transformed directly: load it first. convert()
        # returns a fully loaded copy, so the file handle can be closed here.
        with Image.open(source) as opened:
            source = opened.convert("RGB")
    elif isinstance(source, Image.Image) and source.mode != "RGB":
        # The Normalize step is configured with three channel statistics, so
        # RGBA / greyscale / palette images must be brought to RGB first.
        source = source.convert("RGB")

    return transform(source).unsqueeze(0)


def make_preds_return_class_class_confidence_dict(img, model, class_names=None, top_k=5):
    """Classify one image and report the most likely classes.

    Args:
        img: Input tensor for a single image (batch dimension of size 1).
        model: A ``torch.nn.Module`` that returns logits of shape
            ``(1, num_classes)`` for ``img``.
        class_names: Optional mapping from class index to label. When omitted,
            the module-level ``idx_to_class`` is used.
        top_k: Maximum number of classes to include in the confidence dict
            (capped at the number of classes). Defaults to 5.

    Returns:
        A ``(label, confidences)`` tuple: ``label`` is the name of the
        highest-probability class and ``confidences`` maps the ``top_k`` most
        likely labels to their softmax probability as plain Python floats,
        ordered from most to least likely.

    Raises:
        ValueError: If the model output is not a ``(1, num_classes)`` tensor.
    """
    labels = idx_to_class if class_names is None else class_names

    model.eval()

    # Keep the input on the same device as the model's weights; a CPU tensor
    # fed to a CUDA model (or vice versa) would otherwise fail.
    first_param = next(model.parameters(), None)
    if first_param is not None and img.device != first_param.device:
        img = img.to(first_param.device)

    with torch.inference_mode():
        logits = model(img)
        if logits.dim() != 2 or logits.shape[0] != 1:
            raise ValueError(
                "expected model output of shape (1, num_classes), got "
                f"{tuple(logits.shape)}"
            )
        # Bring the probabilities back to the CPU so the conversion to Python
        # numbers below also works when the model runs on a GPU.
        probs = torch.softmax(logits, dim=1).squeeze(0).cpu()

    class_ = labels[int(torch.argmax(probs))]

    k = max(0, min(int(top_k), probs.numel()))
    top_probs, top_indices = torch.topk(probs, k)
    # tolist() yields built-in floats/ints, which (unlike numpy scalars) are
    # JSON-serialisable.
    label_dict = {
        labels[index]: confidence
        for index, confidence in zip(top_indices.tolist(), top_probs.tolist())
    }
    return class_, label_dict


def load_model():
    """Rebuild the EfficientNet-B0 classifier and restore it from the checkpoint.

    The network is an EfficientNet-B0 whose classifier head is replaced by
    dropout plus an 11-way linear layer, wrapped in a Lightning module. Its
    state is loaded from the module-level ``checkpoint_path`` onto the
    module-level ``device``.

    Returns:
        The restored Lightning module, switched to evaluation mode.
    """

    class myLightningModel(pl.LightningModule):
        """Lightning wrapper matching the module layout stored in the checkpoint."""

        def __init__(self, model, lr):
            super().__init__()
            self.model = model
            self.lr = lr
            self.loss_fn = nn.CrossEntropyLoss()
            self.metric_fn = torchmetrics.classification.MulticlassAccuracy(num_classes=11)
            # The backbone is passed in explicitly on load, so it is excluded
            # from the saved hyperparameters.
            self.save_hyperparameters(ignore=["model"])

        def forward(self, x):
            """Return the wrapped model's logits for ``x``."""
            return self.model(x)

        def training_step(self, batch, batch_idx):
            """Compute and log loss/accuracy for one training batch."""
            self.model.train()
            X, y = batch
            logits = self.model(X)
            loss = self.loss_fn(logits, y)
            # Softmax is monotonic, so argmax over raw logits picks the same class.
            acc = self.metric_fn(torch.argmax(logits, dim=1), y)
            self.log("Train accuracy", acc, prog_bar=True, on_epoch=True, on_step=False)
            self.log("Train logloss", loss, prog_bar=True, on_epoch=True, on_step=False)
            return {"Train Accuracy": acc, "loss": loss}

        def validation_step(self, batch, batch_idx):
            """Compute and log loss/accuracy for one validation batch."""
            self.model.eval()
            X, y = batch
            logits = self.model(X)
            val_loss = self.loss_fn(logits, y)
            val_acc = self.metric_fn(torch.argmax(logits, dim=1), y)
            self.log("Val accuracy", val_acc, prog_bar=True, on_epoch=True, on_step=False)
            self.log("Val logloss", val_loss, prog_bar=True, on_epoch=True, on_step=False)
            return {"Val Accuracy": val_acc, "Val loss": val_loss}

        def configure_optimizers(self):
            """Adam over the wrapped model's parameters with light weight decay."""
            return torch.optim.Adam(params=self.model.parameters(), lr=self.lr, weight_decay=1e-4)

    # weights=None: the checkpoint's state dict replaces the backbone's
    # parameters and buffers (assuming Lightning's default strict loading), so
    # fetching pretrained weights first would only add a network download that
    # can fail offline.
    model = torchvision.models.efficientnet.efficientnet_b0(weights=None)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.2, inplace=True),
        nn.Linear(in_features=1280, out_features=11, bias=True),
    )

    lightning_model = myLightningModel.load_from_checkpoint(
        checkpoint_path=checkpoint_path,
        map_location=device,
        model=model,
        lr=1e-3,
    )
    # This loader serves inference: make sure dropout is off and batch-norm
    # layers use their stored running statistics.
    lightning_model.eval()
    return lightning_model