File size: 5,403 Bytes
65936d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99f4b1d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import numpy as np
import torch
import lightning.pytorch as pl
import torchmetrics
import segmentation_models_pytorch as smp
import skimage
from skimage.transform import resize
import rasterio
from rasterio.plot import show
import PIL
from PIL import Image

# Device string used for map_location when restoring the checkpoint.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def prepare_image(path: str, out_shape=(9, 224, 224)):
    """Load a multi-band raster and turn it into a single-image float32 batch.

    The raster is resampled to ``out_shape`` with bilinear interpolation
    (``order=1``) and anti-aliasing. It is cast back to the file's own dtype
    (truncating for integer rasters), NaNs are replaced by 0, and a leading
    batch axis is added.

    Args:
        path: Path of a raster file readable by ``rasterio``.
        out_shape: Target ``(bands, rows, cols)`` passed to ``resize``. The
            default matches the 9-channel, 224x224 input that the U-Net in
            ``load_model`` is built for.

    Returns:
        ``torch.float32`` tensor of shape ``(1, *out_shape)``.
    """
    # The context manager closes the dataset handle; the old code leaked it.
    with rasterio.open(path) as src:
        raw = src.read()  # (bands, rows, cols)

    resized = resize(raw, out_shape, order=1, preserve_range=True, anti_aliasing=True)

    # Round-trip through the source dtype to keep the original value
    # quantisation, then zero out NaNs.
    arr = np.nan_to_num(resized.astype(raw.dtype), nan=0)

    # Cast in NumPy first: torch.tensor() cannot build tensors from some NumPy
    # dtypes (notably uint16 on older torch releases), but float32 always works.
    return torch.from_numpy(arr.astype(np.float32)).unsqueeze(0)

def get_gt(path: str):
    """Render raster bands 3, 2, 1 as an 8-bit RGB preview image.

    Band values are divided by the largest finite value across *all* bands,
    clipped to [0, 1], scaled to [0, 255] and truncated to ``uint8``.

    Args:
        path: Path of a raster file readable by ``rasterio`` with at least
            three bands.

    Returns:
        ``uint8`` array of shape ``(rows, cols, 3)``. It is all zeros when the
        raster has no finite non-zero maximum to normalise by.
    """
    # Read once and close the handle; the old code opened the file four times
    # and never closed any of them.
    with rasterio.open(path) as src:
        data = src.read()  # (bands, rows, cols)

    # rasterio band indexes are 1-based, so bands (3, 2, 1) are rows 2, 1, 0.
    rgb = np.dstack([data[band - 1] for band in (3, 2, 1)])

    # Ignore NaN/inf pixels when picking the scale, so a single bad pixel does
    # not turn the whole preview into NaN.
    finite = data[np.isfinite(data)]
    if finite.size == 0 or finite.max() == 0:
        # Nothing to normalise by; avoid 0/0 -> NaN -> undefined uint8 cast.
        return np.zeros(rgb.shape, dtype=np.uint8)

    scaled = np.nan_to_num(rgb / finite.max(), nan=0.0)
    return (scaled.clip(0, 1) * 255).astype(np.uint8)

# def make_preds_return_mask(img, model):
#     model.eval()
#     with torch.inference_mode():
#         logits = model(img)
#     pred = torch.argmax(torch.softmax(logits, dim = 1), axis = 1).to(device).numpy()
#     return Image.fromarray(np.squeeze(np.moveaxis(pred, (0,1,2), (2,0,1)),-1)*255.0).convert("L").resize((300,224))

def make_preds_return_mask(img, model):
    """Run ``model`` on ``img`` and return the per-pixel argmax class map.

    Args:
        img: Input batch tensor for ``model``. It is moved onto the device of
            the model's first parameter (if the model has any) before the
            forward pass.
        model: ``torch.nn.Module`` whose output holds class scores along
            dim 1. As a side effect it is switched to eval mode.

    Returns:
        NumPy integer array of class indices with all singleton axes squeezed
        away, e.g. ``(H, W)`` for a single-image batch.
    """
    model.eval()

    # prepare_image() builds CPU tensors, while the loaded model may have been
    # mapped to CUDA; send the input to wherever the weights live.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        img = img.to(first_param.device)

    with torch.inference_mode():
        logits = model(img)

    # softmax is strictly increasing, so argmax over raw logits picks the same class.
    pred = torch.argmax(logits, dim=1).cpu().numpy()
    return np.squeeze(pred)


def load_model(checkpoint_path="ckpts/epoch=39-step=2760.ckpt", lr=5e-6):
    """Rebuild the 9-band, 2-class U-Net and restore its weights from a checkpoint.

    Args:
        checkpoint_path: Lightning ``.ckpt`` file to restore. The default is
            the checkpoint this function previously hard-coded.
        lr: Learning rate handed to the LightningModule constructor. It is only
            used by ``configure_optimizers``, i.e. if the module is trained further.

    Returns:
        The restored ``floodLighningModel``, loaded with
        ``map_location`` set to the module-level ``device``.
    """

    class floodLighningModel(pl.LightningModule):
        """LightningModule wrapping a 2-class segmentation network.

        Logs Dice loss, binary IoU, binary accuracy and 2-class F1 for both
        training and validation. Train/eval mode switching is left to the
        Lightning Trainer loops.
        """

        def __init__(self, model, lr):
            super().__init__()
            self.model = model
            self.lr = lr
            self.loss_fn = smp.losses.DiceLoss(from_logits=True, mode="multiclass")
            # Metrics are submodules, so they follow the LightningModule onto
            # whatever device it is moved to; no manual .to(device) needed.
            self.iou = torchmetrics.JaccardIndex(task="binary")
            self.acc_fn = torchmetrics.classification.BinaryAccuracy()
            self.f1_fn = torchmetrics.classification.MulticlassF1Score(num_classes=2)

        def forward(self, x):
            return self.model(x)

        def _shared_step(self, batch):
            """Forward one ``(image, gt)`` batch and return ``(loss, iou, acc, f1)``."""
            image, gt = batch
            logits = self.model(image)
            # gt presumably carries a singleton channel axis at dim 1 that the
            # multiclass Dice loss must not see -- TODO confirm with the dataset.
            # .long() keeps the target on the logits' device, whereas
            # .type(torch.LongTensor) forced it through the CPU.
            loss = self.loss_fn(logits, gt.squeeze(1).long())
            # softmax is monotonic, so argmax over raw logits gives the same labels;
            # compute the flattened predictions once for all three metrics.
            preds = torch.flatten(torch.argmax(logits, dim=1))
            target = torch.flatten(gt)
            return (
                loss,
                self.iou(preds, target),
                self.acc_fn(preds, target),
                self.f1_fn(preds, target),
            )

        def training_step(self, batch, batch_idx):
            loss, iou, acc, f1 = self._shared_step(batch)
            self.log("loss", loss, prog_bar=True, on_step=False, on_epoch=True)
            self.log("iou", iou, prog_bar=True, on_step=False, on_epoch=True)
            self.log("accuracy", acc, prog_bar=True, on_step=False, on_epoch=True)
            self.log("f1", f1, prog_bar=True, on_step=False, on_epoch=True)
            return {"loss": loss, "f1": f1, "iou": iou, "accuracy": acc}

        def validation_step(self, batch, batch_idx):
            val_loss, val_iou, val_acc, val_f1 = self._shared_step(batch)
            self.log("validation loss", val_loss, prog_bar=True, on_step=False, on_epoch=True)
            self.log("validation iou", val_iou, prog_bar=True, on_step=False, on_epoch=True)
            self.log("validation accuracy", val_acc, prog_bar=True, on_step=False, on_epoch=True)
            self.log("validation f1", val_f1, prog_bar=True, on_step=False, on_epoch=True)
            return {"validation loss": val_loss, "validation f1": val_f1, "validation iou": val_iou, "validation accuracy": val_acc}

        def configure_optimizers(self):
            return torch.optim.AdamW(params=self.model.parameters(), lr=self.lr)

    pl.seed_everything(2025)
    model_ = smp.Unet(
        encoder_name="efficientnet-b0",
        # No ImageNet download: every weight is overwritten by the (strict)
        # checkpoint load below, and the architecture does not depend on it.
        encoder_weights=None,
        in_channels=9,
        classes=2,
    )
    # model_ = smp.DPT(
    #     encoder_name="tu-vit_base_patch16_224.augreg_in21k",
    #     encoder_weights="imagenet",
    #     in_channels=9,
    #     classes=2,
    # )
    return floodLighningModel.load_from_checkpoint(
        checkpoint_path=checkpoint_path,
        map_location=torch.device(device),
        model=model_,
        lr=lr,
    )