# Source: Hugging Face file view — AyoAgbaje, "Update utils/utils.py", commit 65936d6 (verified)
import numpy as np
import torch
import lightning.pytorch as pl
import torchmetrics
import segmentation_models_pytorch as smp
import skimage
from skimage.transform import resize
import rasterio
from rasterio.plot import show
import PIL
from PIL import Image
device = "cuda" if torch.cuda.is_available() else "cpu"
def prepare_image(path: str) -> torch.Tensor:
    """Load a multi-band raster and prepare it as model input.

    Args:
        path: Path to a raster file readable by rasterio.

    Returns:
        A float32 tensor of shape (1, 9, 224, 224) — the raster resized
        to the network's fixed input size, NaNs replaced by 0, with a
        leading batch dimension added.
    """
    # Bug fix: the original called rasterio.open(path) without ever
    # closing the dataset; the context manager releases the file handle.
    with rasterio.open(path) as src:
        tif = src.read()  # (bands, H, W)
    # Bilinear (order=1) resize to the fixed size the model was trained on.
    # NOTE(review): this also resizes along the band axis if the raster
    # does not already have exactly 9 bands — presumably inputs always
    # carry 9 bands; confirm against the data pipeline.
    resized = resize(tif, (9, 224, 224),
                     order=1,
                     preserve_range=True,
                     anti_aliasing=True)
    tif = resized.astype(tif.dtype)
    # Satellite rasters commonly contain NaN nodata values; zero them so
    # the network never sees non-finite inputs.
    tif = np.nan_to_num(tif, nan=0)
    return torch.tensor(tif).type(torch.float32).unsqueeze(0)
def get_gt(path: str) -> np.ndarray:
    """Build an 8-bit RGB preview image from a raster's bands 3, 2, 1.

    Args:
        path: Path to a raster file readable by rasterio.

    Returns:
        A (H, W, 3) uint8 array: bands stacked in (3, 2, 1) order,
        normalised by the global maximum across *all* bands, clipped to
        [0, 1] and scaled to 0-255.
    """
    # Bug fix: the original opened the same file four times (once per
    # band plus once for the max) and never closed any handle. Open once,
    # read once, and let the context manager close the dataset.
    with rasterio.open(path) as src:
        all_bands = src.read()  # (bands, H, W) — used only for the max
        rgb = np.dstack([src.read(i) for i in (3, 2, 1)])
    # Normalising by the max over all bands (not just RGB) matches the
    # original behaviour exactly.
    im = ((rgb / np.max(all_bands)).clip(0, 1) * 255).astype(np.uint8)
    return im
def make_preds_return_mask(img, model):
    """Run the segmentation model on one image and return the class mask.

    Args:
        img: Input batch tensor of shape (1, C, H, W).
        model: A torch module producing per-class logits of shape
            (1, num_classes, H, W).

    Returns:
        A numpy array of shape (H, W) holding argmax class indices.
    """
    model.eval()
    with torch.inference_mode():
        logits = model(img)
        # softmax is monotonic, so argmax over probabilities equals argmax
        # over raw logits; kept for parity with the training metrics.
        pred = torch.argmax(torch.softmax(logits, dim=1), dim=1)
    # Bug fix: the original did `.to(device).cpu()` — moving the tensor to
    # the GPU only to immediately move it back. Go straight to CPU.
    mask = np.squeeze(pred.cpu().numpy())  # shape (H, W)
    return mask
def load_model(checkpoint_path: str = "ckpts/epoch=39-step=2760.ckpt",
               lr: float = 5e-6):
    """Build the U-Net flood-segmentation model and restore trained weights.

    Args:
        checkpoint_path: Lightning checkpoint to load. Defaults to the
            path the original hard-coded, so existing callers are unchanged.
        lr: Learning rate stored on the LightningModule (only relevant if
            the returned model is trained further).

    Returns:
        A LightningModule wrapping the U-Net, with checkpoint weights
        loaded and mapped onto `device`.
    """

    class FloodLightningModel(pl.LightningModule):
        """Wraps a segmentation net with Dice loss and IoU/accuracy/F1 metrics.

        (Renamed from the original's typo `floodLighningModel`; the name is
        local to load_model so no external caller is affected.)
        """

        def __init__(self, model, lr):
            super().__init__()
            self.model = model
            self.lr = lr
            # Dice loss computed from raw logits over 2 classes.
            self.loss_fn = smp.losses.DiceLoss(from_logits=True, mode="multiclass")
            self.iou = torchmetrics.JaccardIndex(task="binary")
            self.acc_fn = torchmetrics.classification.BinaryAccuracy()
            self.f1_fn = torchmetrics.classification.MulticlassF1Score(num_classes=2).to(device)

        def forward(self, x):
            return self.model(x)

        def _shared_step(self, batch):
            """Compute loss + metrics for one batch (shared by train/val).

            The original duplicated this block verbatim in training_step
            and validation_step; computing the hard predictions once also
            avoids re-running softmax/argmax three times per step.
            """
            image, gt = batch
            logits = self.model(image)
            loss = self.loss_fn(logits.to(device),
                                gt.squeeze(1).type(torch.LongTensor).to(device))
            # softmax is monotonic, so this equals argmax over logits.
            preds = torch.flatten(torch.argmax(torch.softmax(logits, dim=1), dim=1)).to(device)
            target = torch.flatten(gt).to(device)
            return loss, self.iou(preds, target), self.acc_fn(preds, target), self.f1_fn(preds, target)

        def training_step(self, batch, batch_idx):
            # NOTE(review): Lightning already puts the module in train mode;
            # kept from the original for safety.
            self.model.train()
            loss, iou, acc, f1 = self._shared_step(batch)
            for name, value in (("loss", loss), ("iou", iou),
                                ("accuracy", acc), ("f1", f1)):
                self.log(name, value, prog_bar=True, on_step=False, on_epoch=True)
            return {"loss": loss, "f1": f1, "iou": iou, "accuracy": acc}

        def validation_step(self, batch, batch_idx):
            self.model.eval()
            val_loss, val_iou, val_acc, val_f1 = self._shared_step(batch)
            for name, value in (("validation loss", val_loss),
                                ("validation iou", val_iou),
                                ("validation accuracy", val_acc),
                                ("validation f1", val_f1)):
                self.log(name, value, prog_bar=True, on_step=False, on_epoch=True)
            return {"validation loss": val_loss, "validation f1": val_f1,
                    "validation iou": val_iou, "validation accuracy": val_acc}

        def configure_optimizers(self):
            return torch.optim.AdamW(params=self.model.parameters(), lr=self.lr)

    pl.seed_everything(2025)
    # 9-channel input (matches prepare_image's output), 2 output classes.
    model_ = smp.Unet(
        encoder_name="efficientnet-b0",
        encoder_weights="imagenet",
        in_channels=9,
        classes=2,
    )
    lightning_model = FloodLightningModel.load_from_checkpoint(
        model=model_,
        lr=lr,
        map_location=torch.device(device=device),
        checkpoint_path=checkpoint_path,
    )
    return lightning_model