"""Preprocessing, checkpoint-loading and inference helpers for a 9-band,
2-class segmentation U-Net wrapped in a Lightning module (floodLighningModel).
"""
import numpy as np
import torch
import lightning.pytorch as pl
import torchmetrics
import segmentation_models_pytorch as smp
import skimage
from skimage.transform import resize
import rasterio
from rasterio.plot import show
import PIL
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"


def prepare_image(path: str, *, size: tuple = (224, 224), n_bands: int = 9) -> torch.Tensor:
    """Load a multi-band raster and turn it into a model-ready batch of one.

    Args:
        path: Path to a raster readable by rasterio.
        size: Target (height, width) after resizing. Defaults to (224, 224).
        n_bands: Number of bands the raster must have. Defaults to 9, matching
            the ``in_channels`` of the network built in ``load_model``.

    Returns:
        A float32 tensor of shape (1, n_bands, size[0], size[1]). NaNs are
        replaced by 0 and the resized values are cast back to the raster's
        own dtype (truncating for integer rasters) before conversion.

    Raises:
        ValueError: If the raster does not have exactly ``n_bands`` bands.
            Resizing across the band axis would otherwise blend neighbouring
            bands into meaningless channels.
    """
    with rasterio.open(path) as src:
        tif = src.read()  # (bands, rows, cols)
    if tif.shape[0] != n_bands:
        raise ValueError(f"expected {n_bands} bands, got {tif.shape[0]} in {path!r}")
    # Zero NaNs *before* resizing. The Gaussian anti-aliasing filter and the
    # bilinear interpolation would otherwise spread each NaN onto its neighbours.
    # NOTE(review): confirm this matches the preprocessing used at training time.
    tif = np.nan_to_num(tif, nan=0)
    resized = resize(
        tif,
        (n_bands, *size),
        order=1,
        preserve_range=True,
        anti_aliasing=True,
    ).astype(tif.dtype)
    # Cast to float32 in NumPy first: torch cannot wrap every NumPy dtype
    # (e.g. uint16 on older torch versions).
    return torch.from_numpy(resized.astype(np.float32)).unsqueeze(0)


def get_gt(path: str) -> np.ndarray:
    """Return an 8-bit (H, W, 3) preview image of the raster at *path*.

    Bands 3, 2 and 1 (rasterio's 1-based band indexes) become the three
    channels, in that order. They are divided by the maximum value across
    *all* bands, clipped to [0, 1], scaled by 255 and truncated to uint8.
    NOTE(review): assumes bands 3/2/1 are red/green/blue. Confirm for the sensor.
    """
    with rasterio.open(path) as src:
        bands = src.read()  # read once and reuse, instead of once per band
    # A single NaN (e.g. nodata) would make the max NaN and corrupt the whole image.
    bands = np.nan_to_num(bands, nan=0)
    rgb = np.dstack([bands[i - 1] for i in (3, 2, 1)])
    peak = np.max(bands)
    if peak <= 0:
        # There is no positive brightness to normalise by (e.g. an all-zero
        # tile). Return black instead of dividing 0/0.
        return np.zeros(rgb.shape, dtype=np.uint8)
    return ((rgb / peak).clip(0, 1) * 255).astype(np.uint8)


def make_preds_return_mask(img: torch.Tensor, model: torch.nn.Module) -> np.ndarray:
    """Predict a per-pixel class mask for *img* with *model*.

    Args:
        img: Input batch, e.g. the output of ``prepare_image``. It is moved to
            the device of the model's parameters before the forward pass.
        model: Segmentation network returning logits with classes on dim 1.
            It is switched to eval mode as a side effect.

    Returns:
        The argmax class indices as an int64 NumPy array, with size-1 axes
        squeezed. For a single-image batch this is (H, W).
    """
    model.eval()
    first_param = next(model.parameters(), None)
    if first_param is not None:
        # Follow the model's device so a CPU tensor works with a GPU-loaded model.
        img = img.to(first_param.device)
    with torch.inference_mode():
        logits = model(img)
    # Softmax is monotonic, so the argmax of the raw logits gives the same classes.
    pred = logits.argmax(dim=1).cpu().numpy()
    return np.squeeze(pred)


class floodLighningModel(pl.LightningModule):
    """LightningModule wrapping a 2-class segmentation network.

    Training and validation minimise a multiclass Dice loss on the logits and
    log IoU, accuracy and F1 computed on the argmax class map.

    Args:
        model: The wrapped ``torch.nn.Module`` (for example an ``smp.Unet``).
        lr: Learning rate for the AdamW optimizer.
    """

    def __init__(self, model, lr):
        super().__init__()
        self.model = model
        self.lr = lr
        self.loss_fn = smp.losses.DiceLoss(from_logits=True, mode="multiclass")
        # The metric modules are registered as children, so Lightning moves
        # them to the module's device. No explicit .to(device) is needed.
        self.iou = torchmetrics.JaccardIndex(task="binary")
        self.acc_fn = torchmetrics.classification.BinaryAccuracy()
        self.f1_fn = torchmetrics.classification.MulticlassF1Score(num_classes=2)

    def forward(self, x):
        return self.model(x)

    def _loss_and_metrics(self, batch):
        """Run one (image, gt) batch and return (loss, iou, accuracy, f1)."""
        image, gt = batch
        logits = self.model(image)
        # assumes gt is (B, 1, H, W) or (B, H, W) holding class ids {0, 1}. TODO confirm.
        # .long() keeps the tensor on its current device; .type(LongTensor) would copy it to CPU.
        target = gt.squeeze(1).long()
        loss = self.loss_fn(logits, target)
        # Compute the class map once and reuse it for every metric.
        preds = logits.argmax(dim=1).flatten()
        flat_target = target.flatten()
        return (
            loss,
            self.iou(preds, flat_target),
            self.acc_fn(preds, flat_target),
            self.f1_fn(preds, flat_target),
        )

    def _log_epoch_metrics(self, names, values):
        """Log each (name, value) pair as an epoch-level progress-bar metric."""
        for name, value in zip(names, values):
            self.log(name, value, prog_bar=True, on_step=False, on_epoch=True)

    # Lightning switches train/eval mode around the training and validation
    # loops itself, so the steps do not toggle self.model's mode.
    def training_step(self, batch, batch_idx):
        loss, iou, acc, f1 = self._loss_and_metrics(batch)
        self._log_epoch_metrics(("loss", "iou", "accuracy", "f1"), (loss, iou, acc, f1))
        return {"loss": loss, "f1": f1, "iou": iou, "accuracy": acc}

    def validation_step(self, batch, batch_idx):
        val_loss, val_iou, val_acc, val_f1 = self._loss_and_metrics(batch)
        self._log_epoch_metrics(
            ("validation loss", "validation iou", "validation accuracy", "validation f1"),
            (val_loss, val_iou, val_acc, val_f1),
        )
        return {
            "validation loss": val_loss,
            "validation f1": val_f1,
            "validation iou": val_iou,
            "validation accuracy": val_acc,
        }

    def configure_optimizers(self):
        return torch.optim.AdamW(params=self.model.parameters(), lr=self.lr)


def load_model(checkpoint_path: str = "ckpts/epoch=39-step=2760.ckpt", lr: float = 5e-6):
    """Rebuild the network and restore its weights from a Lightning checkpoint.

    Args:
        checkpoint_path: Path of the ``.ckpt`` file to load.
        lr: Learning rate handed to the module. It is only used if the
            restored module is trained further.

    Returns:
        A ``floodLighningModel`` whose tensors are mapped to the module-level
        ``device``.
    """
    # Seeds the global RNGs, as before, so process-wide random state is unchanged.
    pl.seed_everything(2025)
    backbone = smp.Unet(
        encoder_name="efficientnet-b0",
        # The backbone is a registered submodule, so its parameters are in the
        # checkpoint's state_dict. With Lightning's default strict loading they
        # overwrite any initial weights, so skip the ImageNet download.
        encoder_weights=None,
        in_channels=9,
        classes=2,
    )
    return floodLighningModel.load_from_checkpoint(
        checkpoint_path=checkpoint_path,
        map_location=torch.device(device),
        model=backbone,
        lr=lr,
    )