AyoAgbaje committed on
Commit
65936d6
·
verified ·
1 Parent(s): 99f4b1d

Update utils/utils.py

Browse files
Files changed (1) hide show
  1. utils/utils.py +110 -110
utils/utils.py CHANGED
@@ -1,111 +1,111 @@
1
- import numpy as np
2
- import torch
3
- import lightning.pytorch as pl
4
- import torchmetrics
5
- import segmentation_models_pytorch as smp
6
- import skimage
7
- from skimage.transform import resize
8
- import rasterio
9
- from rasterio.plot import show
10
- import PIL
11
- from PIL import Image
12
-
13
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
14
-
15
def prepare_image(path: str):
    """Read a GeoTIFF, resize it to (9, 224, 224), and return a model-ready tensor.

    Args:
        path: path to a raster file readable by rasterio.

    Returns:
        A float32 tensor of shape (1, 9, 224, 224) with NaNs replaced by 0.
    """
    # Context manager closes the dataset handle (the original leaked it);
    # the original's trailing [:, :, :] slice was a no-op and is dropped.
    with rasterio.open(path) as src:
        tif = src.read()
    # Bilinear (order=1) resize of all 9 bands to the network input size.
    resized = resize(tif, (9, 224, 224),
                     order=1,
                     preserve_range=True,
                     anti_aliasing=True)
    tif = resized.astype(tif.dtype)
    # Replace NaNs (common in masked satellite rasters) with 0.
    # NOTE(review): if the source dtype is integer, the astype above has
    # already removed any NaNs and this is a no-op.
    tif = np.nan_to_num(tif, nan=0)
    return torch.tensor(tif).type(torch.float32).unsqueeze(0)
25
-
26
def get_gt(path: str):
    """Load bands at indices 2, 1, 0 of a raster as an (H, W, 3) array.

    Axes are moved from rasterio's (band, H, W) layout to (H, W, band)
    for display.
    """
    # Context manager closes the dataset handle (the original leaked it).
    with rasterio.open(path) as src:
        data = src.read()
    return np.moveaxis(data[[2, 1, 0], :, :], (0, 1, 2), (2, 0, 1))
29
-
30
- # def make_preds_return_mask(img, model):
31
- # model.eval()
32
- # with torch.inference_mode():
33
- # logits = model(img)
34
- # pred = torch.argmax(torch.softmax(logits, dim = 1), axis = 1).to(device).numpy()
35
- # return Image.fromarray(np.squeeze(np.moveaxis(pred, (0,1,2), (2,0,1)),-1)*255.0).convert("L").resize((300,224))
36
-
37
def make_preds_return_mask(img, model):
    """Run inference and return the predicted class mask as a numpy array.

    Args:
        img: input batch tensor of shape (1, C, H, W).
        model: a torch module producing per-class logits of shape
            (N, num_classes, H, W).

    Returns:
        A numpy array of class indices with singleton axes squeezed out
        (shape (H, W) for a single image).
    """
    model.eval()
    with torch.inference_mode():
        logits = model(img)
        # argmax over the class axis; softmax is monotonic so it does not
        # change the argmax, but it is kept to mirror the original output.
        # The original moved the tensor .to(device) only to call .cpu()
        # immediately after — go straight to CPU instead.
        pred = torch.argmax(torch.softmax(logits, dim=1), dim=1).cpu().numpy()
    return np.squeeze(pred)
44
-
45
-
46
def load_model():
    """Construct the flood-segmentation U-Net and load its trained weights.

    Builds an EfficientNet-B0 U-Net (9 input channels, 2 classes) and
    restores the Lightning checkpoint at ``ckpts/epoch=39-step=2760.ckpt``
    onto the global ``device``.

    Returns:
        The restored LightningModule, ready for inference.
    """

    class floodLighningModel(pl.LightningModule):
        """LightningModule wrapping a 2-class segmentation network."""

        def __init__(self, model, lr):
            super().__init__()
            self.model = model
            self.lr = lr
            # Dice loss on raw logits; "multiclass" mode expects integer
            # class targets of shape (N, H, W).
            self.loss_fn = smp.losses.DiceLoss(from_logits=True, mode="multiclass")
            self.iou = torchmetrics.JaccardIndex(task="binary")
            self.acc_fn = torchmetrics.classification.BinaryAccuracy()
            self.f1_fn = torchmetrics.classification.MulticlassF1Score(num_classes=2).to(device)

        def forward(self, x):
            return self.model(x)

        def _flat_preds_targets(self, logits, gt):
            # Hard labels from logits. The original recomputed this argmax
            # once per metric (three times per step) — compute it once.
            # softmax is monotonic, so it does not change the argmax.
            preds = torch.flatten(torch.argmax(torch.softmax(logits, dim=1), dim=1)).to(device)
            targets = torch.flatten(gt).to(device)
            return preds, targets

        def training_step(self, batch, batch_idx):
            self.model.train()
            image, gt = batch
            logits = self.model(image)
            # .type(torch.LongTensor) also moves the target to CPU before
            # the .to(device) — kept exactly as in the original.
            loss = self.loss_fn(logits.to(device), gt.squeeze(1).type(torch.LongTensor).to(device))
            preds, targets = self._flat_preds_targets(logits, gt)
            iou = self.iou(preds, targets)
            acc = self.acc_fn(preds, targets)
            f1 = self.f1_fn(preds, targets)
            self.log("loss", loss, prog_bar=True, on_step=False, on_epoch=True)
            self.log("iou", iou, prog_bar=True, on_step=False, on_epoch=True)
            self.log("accuracy", acc, prog_bar=True, on_step=False, on_epoch=True)
            self.log("f1", f1, prog_bar=True, on_step=False, on_epoch=True)
            return {"loss": loss, "f1": f1, "iou": iou, "accuracy": acc}

        def validation_step(self, batch, batch_idx):
            self.model.eval()
            image, gt = batch
            logits = self.model(image)
            val_loss = self.loss_fn(logits.to(device), gt.squeeze(1).type(torch.LongTensor).to(device))
            preds, targets = self._flat_preds_targets(logits, gt)
            val_iou = self.iou(preds, targets)
            val_acc = self.acc_fn(preds, targets)
            val_f1 = self.f1_fn(preds, targets)
            self.log("validation loss", val_loss, prog_bar=True, on_step=False, on_epoch=True)
            self.log("validation iou", val_iou, prog_bar=True, on_step=False, on_epoch=True)
            self.log("validation accuracy", val_acc, prog_bar=True, on_step=False, on_epoch=True)
            self.log("validation f1", val_f1, prog_bar=True, on_step=False, on_epoch=True)
            return {"validation loss": val_loss, "validation f1": val_f1,
                    "validation iou": val_iou, "validation accuracy": val_acc}

        def configure_optimizers(self):
            return torch.optim.AdamW(params=self.model.parameters(), lr=self.lr)

    pl.seed_everything(2025)
    model_ = smp.Unet(
        encoder_name="efficientnet-b0",
        encoder_weights="imagenet",
        in_channels=9,
        classes=2,
    )
    lightning_model = floodLighningModel.load_from_checkpoint(
        model=model_,
        lr=5e-6,
        map_location=torch.device(device=device),
        checkpoint_path="ckpts/epoch=39-step=2760.ckpt",
    )
    return lightning_model
 
1
+ import numpy as np
2
+ import torch
3
+ import lightning.pytorch as pl
4
+ import torchmetrics
5
+ import segmentation_models_pytorch as smp
6
+ import skimage
7
+ from skimage.transform import resize
8
+ import rasterio
9
+ from rasterio.plot import show
10
+ import PIL
11
+ from PIL import Image
12
+
13
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
14
+
15
def prepare_image(path: str):
    """Read a GeoTIFF, resize it to (9, 224, 224), and return a model-ready tensor.

    Args:
        path: path to a raster file readable by rasterio.

    Returns:
        A float32 tensor of shape (1, 9, 224, 224) with NaNs replaced by 0.
    """
    # Context manager closes the dataset handle (the original leaked it);
    # the original's trailing [:, :, :] slice was a no-op and is dropped.
    with rasterio.open(path) as src:
        tif = src.read()
    # Bilinear (order=1) resize of all 9 bands to the network input size.
    resized = resize(tif, (9, 224, 224),
                     order=1,
                     preserve_range=True,
                     anti_aliasing=True)
    tif = resized.astype(tif.dtype)
    # Replace NaNs (common in masked satellite rasters) with 0.
    # NOTE(review): if the source dtype is integer, the astype above has
    # already removed any NaNs and this is a no-op.
    tif = np.nan_to_num(tif, nan=0)
    return torch.tensor(tif).type(torch.float32).unsqueeze(0)
25
+
26
def get_gt(path: str):
    """Load bands 3, 2, 1 of a raster as an (H, W, 3) uint8 RGB image.

    The RGB stack is normalised by the maximum over *all* bands of the
    raster (not just the three displayed), clipped to [0, 1], and scaled
    to 0-255.
    """
    # Open and read once — the original called rasterio.open(path).read()
    # twice, reading the whole file twice and leaking both dataset handles.
    with rasterio.open(path) as src:
        data = src.read()  # shape (bands, H, W); rasterio band i is data[i - 1]
    rgb = np.dstack([data[band - 1] for band in (3, 2, 1)])
    # NOTE(review): divides by the global max — an all-zero raster would
    # divide by zero; confirm inputs are non-degenerate.
    return ((rgb / np.max(data)).clip(0, 1) * 255).astype(np.uint8)
29
+
30
+ # def make_preds_return_mask(img, model):
31
+ # model.eval()
32
+ # with torch.inference_mode():
33
+ # logits = model(img)
34
+ # pred = torch.argmax(torch.softmax(logits, dim = 1), axis = 1).to(device).numpy()
35
+ # return Image.fromarray(np.squeeze(np.moveaxis(pred, (0,1,2), (2,0,1)),-1)*255.0).convert("L").resize((300,224))
36
+
37
def make_preds_return_mask(img, model):
    """Run inference and return the predicted class mask as a numpy array.

    Args:
        img: input batch tensor of shape (1, C, H, W).
        model: a torch module producing per-class logits of shape
            (N, num_classes, H, W).

    Returns:
        A numpy array of class indices with singleton axes squeezed out
        (shape (H, W) for a single image).
    """
    model.eval()
    with torch.inference_mode():
        logits = model(img)
        # argmax over the class axis; softmax is monotonic so it does not
        # change the argmax, but it is kept to mirror the original output.
        # The original moved the tensor .to(device) only to call .cpu()
        # immediately after — go straight to CPU instead.
        pred = torch.argmax(torch.softmax(logits, dim=1), dim=1).cpu().numpy()
    return np.squeeze(pred)
44
+
45
+
46
def load_model():
    """Construct the flood-segmentation U-Net and load its trained weights.

    Builds an EfficientNet-B0 U-Net (9 input channels, 2 classes) and
    restores the Lightning checkpoint at ``ckpts/epoch=39-step=2760.ckpt``
    onto the global ``device``.

    Returns:
        The restored LightningModule, ready for inference.
    """

    class floodLighningModel(pl.LightningModule):
        """LightningModule wrapping a 2-class segmentation network."""

        def __init__(self, model, lr):
            super().__init__()
            self.model = model
            self.lr = lr
            # Dice loss on raw logits; "multiclass" mode expects integer
            # class targets of shape (N, H, W).
            self.loss_fn = smp.losses.DiceLoss(from_logits=True, mode="multiclass")
            self.iou = torchmetrics.JaccardIndex(task="binary")
            self.acc_fn = torchmetrics.classification.BinaryAccuracy()
            self.f1_fn = torchmetrics.classification.MulticlassF1Score(num_classes=2).to(device)

        def forward(self, x):
            return self.model(x)

        def _flat_preds_targets(self, logits, gt):
            # Hard labels from logits. The original recomputed this argmax
            # once per metric (three times per step) — compute it once.
            # softmax is monotonic, so it does not change the argmax.
            preds = torch.flatten(torch.argmax(torch.softmax(logits, dim=1), dim=1)).to(device)
            targets = torch.flatten(gt).to(device)
            return preds, targets

        def training_step(self, batch, batch_idx):
            self.model.train()
            image, gt = batch
            logits = self.model(image)
            # .type(torch.LongTensor) also moves the target to CPU before
            # the .to(device) — kept exactly as in the original.
            loss = self.loss_fn(logits.to(device), gt.squeeze(1).type(torch.LongTensor).to(device))
            preds, targets = self._flat_preds_targets(logits, gt)
            iou = self.iou(preds, targets)
            acc = self.acc_fn(preds, targets)
            f1 = self.f1_fn(preds, targets)
            self.log("loss", loss, prog_bar=True, on_step=False, on_epoch=True)
            self.log("iou", iou, prog_bar=True, on_step=False, on_epoch=True)
            self.log("accuracy", acc, prog_bar=True, on_step=False, on_epoch=True)
            self.log("f1", f1, prog_bar=True, on_step=False, on_epoch=True)
            return {"loss": loss, "f1": f1, "iou": iou, "accuracy": acc}

        def validation_step(self, batch, batch_idx):
            self.model.eval()
            image, gt = batch
            logits = self.model(image)
            val_loss = self.loss_fn(logits.to(device), gt.squeeze(1).type(torch.LongTensor).to(device))
            preds, targets = self._flat_preds_targets(logits, gt)
            val_iou = self.iou(preds, targets)
            val_acc = self.acc_fn(preds, targets)
            val_f1 = self.f1_fn(preds, targets)
            self.log("validation loss", val_loss, prog_bar=True, on_step=False, on_epoch=True)
            self.log("validation iou", val_iou, prog_bar=True, on_step=False, on_epoch=True)
            self.log("validation accuracy", val_acc, prog_bar=True, on_step=False, on_epoch=True)
            self.log("validation f1", val_f1, prog_bar=True, on_step=False, on_epoch=True)
            return {"validation loss": val_loss, "validation f1": val_f1,
                    "validation iou": val_iou, "validation accuracy": val_acc}

        def configure_optimizers(self):
            return torch.optim.AdamW(params=self.model.parameters(), lr=self.lr)

    pl.seed_everything(2025)
    model_ = smp.Unet(
        encoder_name="efficientnet-b0",
        encoder_weights="imagenet",
        in_channels=9,
        classes=2,
    )
    lightning_model = floodLighningModel.load_from_checkpoint(
        model=model_,
        lr=5e-6,
        map_location=torch.device(device=device),
        checkpoint_path="ckpts/epoch=39-step=2760.ckpt",
    )
    return lightning_model