Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| import lightning.pytorch as pl | |
| import torchmetrics | |
| import segmentation_models_pytorch as smp | |
| import skimage | |
| from skimage.transform import resize | |
| import rasterio | |
| from rasterio.plot import show | |
| import PIL | |
| from PIL import Image | |
# Pick the accelerator once at import time: the GPU when PyTorch can see one,
# otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def prepare_image(path: str, size=(224, 224), bands=9):
    """Load a multi-band raster and return it as a float32 model-input batch.

    Parameters
    ----------
    path : str
        Path to a raster file readable by ``rasterio``.
    size : tuple of int, optional
        Target ``(rows, cols)`` after resampling. Defaults to the
        ``(224, 224)`` this function always used.
    bands : int, optional
        Length of the output band axis (default 9). The band axis is
        resampled to this length too, exactly as the original hard-coded
        ``(9, 224, 224)`` target did.

    Returns
    -------
    torch.Tensor
        float32 tensor of shape ``(1, bands, *size)``. NaNs are replaced by 0
        (and +/-inf by the dtype's extremes, per ``np.nan_to_num``).
    """
    # Context manager closes the dataset handle (the original leaked it).
    with rasterio.open(path) as src:
        raw = src.read()  # all bands, band-first

    # Replace NaNs *before* resampling. The anti-aliasing Gaussian and
    # bilinear interpolation would otherwise smear each NaN pixel over its
    # neighbours, and the whole neighbourhood would end up as 0.
    raw = np.nan_to_num(raw, nan=0)

    resized = resize(
        raw,
        (bands, *size),
        order=1,
        preserve_range=True,
        anti_aliasing=True,
    )
    # Round-trip through the source dtype as before; for integer rasters
    # (e.g. uint16) this truncates fractional interpolated values.
    resized = resized.astype(raw.dtype)

    # Convert to float32 in NumPy first: older torch releases cannot build a
    # tensor directly from uint16 arrays.
    return torch.tensor(resized.astype(np.float32)).unsqueeze(0)
def _rgb_preview(data):
    """Scale bands 3, 2, 1 of a band-first raster array into an 8-bit RGB image.

    Parameters
    ----------
    data : np.ndarray
        Band-first array ``(bands, rows, cols)`` with at least 3 bands.

    Returns
    -------
    np.ndarray
        ``(rows, cols, 3)`` uint8 image. Bands 3, 2, 1 are divided by the
        largest finite value over *all* bands, clipped to [0, 1] and scaled
        to 0..255.
    """
    # rasterio band indexes are 1-based, so bands (3, 2, 1) are rows 2, 1, 0.
    rgb = np.dstack([data[band - 1] for band in (3, 2, 1)])

    # Normalise by the largest *finite* value. A single NaN/inf pixel would
    # otherwise make the maximum NaN/inf and blank out the whole preview.
    finite = data[np.isfinite(data)]
    peak = finite.max() if finite.size else 0
    if peak == 0:
        # All-zero (or all-NaN) raster: nothing to scale by, and 0/0 would
        # produce NaN, whose uint8 cast is undefined.
        return np.zeros(rgb.shape, dtype=np.uint8)

    scaled = np.nan_to_num(rgb / peak, nan=0.0)  # NaN pixels render black
    return (scaled.clip(0, 1) * 255).astype(np.uint8)


def get_gt(path: str):
    """Load the raster at *path* and return an RGB preview of bands 3, 2, 1.

    Parameters
    ----------
    path : str
        Path to a raster file readable by ``rasterio`` with at least 3 bands.

    Returns
    -------
    np.ndarray
        ``(rows, cols, 3)`` uint8 image (see ``_rgb_preview``).
    """
    # Open once and close it: the original opened the file twice, leaked both
    # handles and read the full raster twice.
    with rasterio.open(path) as src:
        data = src.read()
    return _rgb_preview(data)
| # def make_preds_return_mask(img, model): | |
| # model.eval() | |
| # with torch.inference_mode(): | |
| # logits = model(img) | |
| # pred = torch.argmax(torch.softmax(logits, dim = 1), axis = 1).to(device).numpy() | |
| # return Image.fromarray(np.squeeze(np.moveaxis(pred, (0,1,2), (2,0,1)),-1)*255.0).convert("L").resize((300,224)) | |
def make_preds_return_mask(img, model):
    """Run *model* on *img* and return the per-pixel predicted class indices.

    Parameters
    ----------
    img : torch.Tensor
        Input batch. Presumably ``(1, C, H, W)`` as built by
        ``prepare_image``; TODO confirm for other callers.
    model : torch.nn.Module
        Segmentation model returning class logits along dim 1. It is switched
        to eval mode as a side effect.

    Returns
    -------
    np.ndarray
        Integer class-index mask of shape ``(H, W)`` for a single-image batch
        (``(B, H, W)`` when the batch holds several images).
    """
    model.eval()

    # Feed the input from wherever the model's weights live. prepare_image()
    # builds CPU tensors, which would fail against a CUDA model.
    try:
        target_device = next(model.parameters()).device
    except StopIteration:  # parameter-free model: leave the input where it is
        target_device = img.device

    with torch.inference_mode():
        logits = model(img.to(target_device))

    # softmax is monotonic, so argmax over the raw logits selects the same
    # class without the extra pass.
    pred = logits.argmax(dim=1)
    # Squeeze only the batch axis so a 1-pixel-high/wide mask stays 2-D.
    return pred.squeeze(0).cpu().numpy()
def load_model(checkpoint_path="ckpts/epoch=39-step=2760.ckpt"):
    """Rebuild the flood-segmentation U-Net and restore its trained weights.

    Parameters
    ----------
    checkpoint_path : str, optional
        Lightning checkpoint to restore. Defaults to the epoch-39 checkpoint
        this function always loaded.

    Returns
    -------
    pl.LightningModule
        The restored ``floodLighningModel``, loaded with
        ``map_location=device``. It wraps a 9-band, 2-class ``smp.Unet`` with
        an efficientnet-b0 encoder.
    """

    class floodLighningModel(pl.LightningModule):
        """Lightning wrapper used for training; inference only needs forward()."""

        def __init__(self, model, lr):
            super().__init__()
            self.model = model
            self.lr = lr
            self.loss_fn = smp.losses.DiceLoss(from_logits=True, mode="multiclass")
            # Metrics are registered as submodules, so Lightning moves them
            # together with the model; no manual .to(device) is needed.
            self.iou = torchmetrics.JaccardIndex(task="binary")
            self.acc_fn = torchmetrics.classification.BinaryAccuracy()
            self.f1_fn = torchmetrics.classification.MulticlassF1Score(num_classes=2)

        def forward(self, x):
            return self.model(x)

        def _shared_step(self, batch, prefix):
            """Compute loss and metrics for one batch; log and return them under *prefix*."""
            image, gt = batch
            logits = self.model(image)
            # .long() keeps the target on its current device, whereas
            # .type(torch.LongTensor) forced a round trip through the CPU.
            loss = self.loss_fn(logits, gt.squeeze(1).long())
            # argmax(logits) == argmax(softmax(logits)); compute it once.
            preds = torch.flatten(torch.argmax(logits, dim=1))
            target = torch.flatten(gt)
            iou = self.iou(preds, target)
            acc = self.acc_fn(preds, target)
            f1 = self.f1_fn(preds, target)
            for name, value in (("loss", loss), ("iou", iou), ("accuracy", acc), ("f1", f1)):
                self.log(prefix + name, value, prog_bar=True, on_step=False, on_epoch=True)
            return {
                prefix + "loss": loss,
                prefix + "f1": f1,
                prefix + "iou": iou,
                prefix + "accuracy": acc,
            }

        # Lightning already puts the module in train/eval mode around these
        # loops, so the steps do not toggle it themselves.
        def training_step(self, batch, batch_idx):
            return self._shared_step(batch, "")

        def validation_step(self, batch, batch_idx):
            return self._shared_step(batch, "validation ")

        def configure_optimizers(self):
            return torch.optim.AdamW(params=self.model.parameters(), lr=self.lr)

    pl.seed_everything(2025)
    # encoder_weights=None: the checkpoint below is loaded strictly and
    # overwrites every parameter and buffer, so fetching ImageNet weights only
    # added a network download to every cold start.
    unet = smp.Unet(
        encoder_name="efficientnet-b0",
        encoder_weights=None,
        in_channels=9,
        classes=2,
    )
    return floodLighningModel.load_from_checkpoint(
        checkpoint_path=checkpoint_path,
        map_location=torch.device(device),
        model=unet,
        lr=5e-6,
    )