# PR-IQA / submodules / loftup / eval_seg.py
# (provenance: uploaded by 2cu1001, "Upload 349 files", commit 52d0a0e verified)
from os.path import join
import gc
import hydra
import matplotlib.pyplot as plt
import numpy as np
import cv2
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode
from omegaconf import DictConfig
from omegaconf import OmegaConf
from pytorch_lightning import Trainer
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader
from torchmetrics.classification import Accuracy, JaccardIndex
from upsamplers import load_lift_checkpoints, norm, unnorm, get_upsampler, load_upsampler_weights
from datasets import get_dataset, EmbeddingFile
from featurizers import get_featurizer
from utils import ADE20K_150_CATEGORIES, pca, ToTensorWithoutScaling
class SemSegEvaluator(pl.LightningModule):
    """Linear-probe semantic-segmentation evaluator.

    Features come either from a frozen featurizer (``model``) run on the fly
    or precomputed in the batch; an optional upsampler densifies them, and a
    single 1x1 conv (``self.classifier``) — the only trainable component —
    predicts per-pixel class logits.  Pixel accuracy and mIoU are accumulated
    over the validation set with torchmetrics.
    """

    # Supported datasets and their class counts.
    _N_CLASSES = {"cocostuff": 27, "cityscapes": 19}

    def __init__(self, n_dim, upsampler_type=None, upsampler_path=None, cfg=None, model=None):
        """
        Args:
            n_dim: channel dimension of the backbone features.
            upsampler_type: "no" (raw low-res features), "lift", "featup",
                or any type ``get_upsampler`` understands (e.g. "loftup",
                "bilinear").
            upsampler_path: checkpoint path for trainable upsamplers; also
                used to derive where the trained classifier is saved.
            cfg: hydra config (lr, dataset, load_size, guidance_res, ...).
            model: optional frozen featurizer; when None, batches must
                already contain precomputed features.
        """
        super().__init__()
        self.n_dim = n_dim
        self.lr = cfg.lr
        self.dataset = cfg.dataset
        self.upsampler_type = upsampler_type
        self.model = model
        # class id -> RGB color in [0, 1], used by visualize_results.
        self.color_map = {category["id"]: np.array(category["color"]) / 255.0 for category in ADE20K_150_CATEGORIES}
        self.name = f"{self.dataset}_{cfg.load_size}_{cfg.upsampler_type}_{cfg.model_type}"
        self.guidance_res = cfg.guidance_res
        self.train_small_res = cfg.train_small_res
        self.upsampler_path = upsampler_path
        self.visualize = cfg.visualize

        # FIX: the original if/elif left ``n_classes`` unbound for any other
        # dataset, crashing later with an opaque NameError.
        if cfg.dataset not in self._N_CLASSES:
            raise ValueError(f"Dataset {cfg.dataset} is not supported (expected one of {sorted(self._N_CLASSES)})")
        n_classes = self._N_CLASSES[cfg.dataset]
        self.n_classes = n_classes

        # Classifier checkpoint lives next to the upsampler checkpoint
        # (strip the ".ckpt" suffix, append a dataset-specific name).
        self.classifier_path = upsampler_path[:-5] + f'{cfg.dataset}_{cfg.load_size}_classifier.ckpt'
        self.use_pretrained_classifier = cfg.use_pretrained_classifier
        self.classifier = torch.nn.Conv2d(n_dim, n_classes, 1)
        if self.use_pretrained_classifier:  # Load a pre-trained classifier and freeze it
            # map_location lets a GPU-saved checkpoint load on any machine;
            # Lightning moves the module to the right device afterwards.
            ckpt_weight = torch.load(cfg.pretrained_classifier_path, map_location="cpu")
            self.classifier.load_state_dict(ckpt_weight)
            for param in self.classifier.parameters():
                param.requires_grad = False
            print(f"Loaded classifier from {cfg.pretrained_classifier_path} and froze it")

        # Low-resolution feature-grid size for a 224px input
        # (dinov2 uses a 14px patch -> 16x16 grid; others 16px -> 14x14).
        final_size = 16 if cfg.model_type == "dinov2" else 14

        # FIX: in the original chain ``elif upsampler_type != "no"`` matched
        # every remaining value, so the trailing ``raise`` was unreachable;
        # the fallback is now explicit.
        if upsampler_type == "no":  # evaluate raw low-resolution features
            upsampler = None
        elif upsampler_type == "lift":
            upsampler = load_lift_checkpoints("/path/to/your/lift.ckpt")
            for param in upsampler.parameters():
                param.requires_grad = False
        elif upsampler_type == "featup":
            upsampler = torch.hub.load("mhamilton723/FeatUp", cfg.model_type, use_norm=False).upsampler.to("cuda")
            for param in upsampler.parameters():
                param.requires_grad = False
        else:  # e.g. loftup, bilinear, ...
            upsampler = get_upsampler(upsampler_type, n_dim, lr_size=final_size, cfg=cfg)
            if upsampler_type != "bilinear":  # contains trainable weights
                upsampler = load_upsampler_weights(upsampler, upsampler_path, n_dim)
        self.upsampler = upsampler

        self.linear_acc_metric = Accuracy(num_classes=n_classes, task="multiclass")
        self.linear_iou_metric = JaccardIndex(num_classes=n_classes, task="multiclass")
        # Buffers hold the final metric values so they travel with the
        # checkpoint.  (register_buffer returns None, so the original
        # ``self.linear_acc_buff = ...`` assignments were dropped.)
        self.register_buffer("linear_acc", torch.tensor(0.0))
        self.register_buffer("linear_iou", torch.tensor(0.0))
        self.ce = torch.nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        """One optimization step of the linear classifier (CE loss)."""
        if self.model is not None:
            imgs = batch["img"].cuda()
            label = batch["label"]
            feats = self.model(imgs)
        else:
            # Precomputed-embedding path: batch is (images, features, labels).
            imgs, feats, label = batch
        # 255 is the conventional "ignore" index; remap it to -1 so the
        # ``> -1`` mask below selects only valid pixels.
        label = label.to(torch.int32)
        label[label == 255] = -1
        label = label.float()  # interpolate() below needs a float tensor
        if self.upsampler is not None:
            if self.guidance_res != 224:
                guidance_imgs = F.interpolate(imgs, size=(self.guidance_res, self.guidance_res), mode='bilinear')
            else:
                guidance_imgs = imgs
            feats = self.upsampler(feats, guidance_imgs)
        b, c, h, w = feats.shape
        if self.train_small_res:
            # Consistent with FeatUp's setup: shrink labels to feature size
            # (nearest interpolation keeps labels categorical).
            labels = F.interpolate(
                label.unsqueeze(1),
                size=(h, w), mode='nearest').to(torch.int64)
        else:
            # Common practice: grow predictions to label/image resolution.
            labels = F.interpolate(
                label.unsqueeze(1),
                size=(imgs.shape[2], imgs.shape[3]), mode='nearest').to(torch.int64)
            h, w = labels.shape[2], labels.shape[3]
        linear_preds = self.classifier(feats)
        if not self.train_small_res:
            linear_preds = F.interpolate(
                linear_preds.float(),
                size=(imgs.shape[2], imgs.shape[3]), mode='nearest')
        flat_labels = labels.permute(0, 2, 3, 1).reshape(b * h * w)
        flat_linear_preds = linear_preds.permute(0, 2, 3, 1).reshape(b * h * w, -1)
        selected = flat_labels > -1  # drop ignore pixels
        linear_loss = self.ce(
            flat_linear_preds[selected],
            flat_labels[selected])
        loss = linear_loss
        self.log("linear_loss", linear_loss)
        self.log("loss", loss)
        # Periodic cleanup keeps long runs from fragmenting GPU memory.
        if self.global_step % 100 == 0:
            gc.collect()
            torch.cuda.empty_cache()
        return loss

    def validation_step(self, batch, batch_idx):
        """Accumulate accuracy/mIoU at label resolution; optionally visualize."""
        with torch.no_grad():
            if self.model is not None:
                imgs = batch["img"].cuda()
                label = batch["label"]
                feats = self.model(imgs)
            else:
                imgs, feats, label = batch
            label = label.to(torch.int32)
            label[label == 255] = -1  # remap ignore index
            lr_feats = feats.clone()  # keep low-res features for PCA viz
            if self.upsampler is not None:
                if self.guidance_res != 224:
                    guidance_imgs = F.interpolate(imgs, size=(self.guidance_res, self.guidance_res), mode='bilinear')
                else:
                    guidance_imgs = imgs
                feats = self.upsampler(feats, guidance_imgs)
            linear_preds = self.classifier(feats).argmax(1, keepdim=True)
            b, h, w = label.shape
            flat_labels = label.flatten()
            selected = flat_labels > -1
            flat_labels = flat_labels[selected]
            # Nearest-resize predictions to label resolution before scoring.
            flat_linear_preds = F.interpolate(linear_preds.float(), (h, w), mode='nearest').to(torch.int64).flatten()[selected]
            self.linear_acc_metric.update(flat_linear_preds, flat_labels)
            self.linear_iou_metric.update(flat_linear_preds, flat_labels)
            if self.visualize:
                unnorm_imgs = unnorm(imgs)
                # Fit PCA on low-res features, reuse it on the upsampled ones
                # so both visualizations share a color basis.
                [red_lr_feats], fit_pca = pca([lr_feats[0].unsqueeze(0)])
                [red_hr_feats], _ = pca([feats[0].unsqueeze(0)], fit_pca=fit_pca)
                self.visualize_results(unnorm_imgs[0].cpu().permute(1, 2, 0).numpy(), red_hr_feats[0].cpu().numpy(), linear_preds[0].cpu().numpy().squeeze(), label[0].cpu().numpy(), batch_idx, name=self.name)
            return None

    def visualize_results(self, original_image, feats, seg_pred, gt_sem_seg, img_idx, name=""):
        """Save a 4-panel figure: image, PCA features, prediction, ground truth.

        Args:
            original_image: HxWx3 float array in [0, 1].
            feats: CxHxW PCA-reduced feature array.
            seg_pred / gt_sem_seg: HxW integer class maps.
            img_idx: batch index used in the output filename.
            name: filename prefix identifying the run.
        """
        seg_pred_rgb = np.zeros((*seg_pred.shape, 3), dtype=np.float32)
        gt_sem_seg_rgb = np.zeros((*gt_sem_seg.shape, 3), dtype=np.float32)
        for class_id, color in self.color_map.items():
            seg_pred_rgb[seg_pred == class_id] = color
            gt_sem_seg_rgb[gt_sem_seg == class_id] = color
        pca_feats = feats.transpose(1, 2, 0)  # CHW -> HWC for imshow
        alpha = 0.8  # transparency for overlays
        # Resize overlays to match the original image before blending.
        seg_pred_rgb = cv2.resize(seg_pred_rgb, (original_image.shape[1], original_image.shape[0]))
        gt_sem_seg_rgb = cv2.resize(gt_sem_seg_rgb, (original_image.shape[1], original_image.shape[0]))
        seg_pred_overlay = (1 - alpha) * original_image + alpha * seg_pred_rgb
        gt_overlay = (1 - alpha) * original_image + alpha * gt_sem_seg_rgb
        plt.figure(figsize=(20, 5))
        plt.subplot(1, 4, 1)
        plt.imshow(original_image)
        plt.title("Original Image")
        plt.axis("off")
        plt.subplot(1, 4, 2)
        plt.imshow(pca_feats)
        plt.title("Feature Visualization (PCA)")
        plt.axis("off")
        plt.subplot(1, 4, 3)
        plt.imshow(seg_pred_overlay)
        plt.title("Segmentation Prediction")
        plt.axis("off")
        plt.subplot(1, 4, 4)
        plt.imshow(gt_overlay)
        plt.title("Ground Truth Overlay")
        plt.axis("off")
        # FIX: savefig fails if the output directory does not exist.
        import os
        os.makedirs("visualization", exist_ok=True)
        plt.savefig(f"visualization/{name}_seg_results_{img_idx}.png", bbox_inches="tight")
        plt.close()

    def on_validation_epoch_end(self):
        """Compute epoch metrics, log them, reset, and save the classifier."""
        self.linear_acc = self.linear_acc_metric.compute()
        self.linear_iou = self.linear_iou_metric.compute()
        print('Validation:')
        print('Linear Accuracy:', self.linear_acc)
        print('Linear mIoU:', self.linear_iou)
        # Log directly to the underlying TensorBoard writer.
        writer = self.logger.experiment
        writer.add_scalar('Linear Accuracy', self.linear_acc, self.global_step)
        writer.add_scalar('Linear mIoU', self.linear_iou, self.global_step)
        self.linear_acc_metric.reset()
        self.linear_iou_metric.reset()
        # Persist the probe so a later run can reuse it
        # (skipped when the classifier was loaded pre-trained and frozen).
        if self.classifier_path is not None and (not self.use_pretrained_classifier):
            torch.save(self.classifier.state_dict(), self.classifier_path)
            print(f"Saved classifier to {self.classifier_path}")

    def configure_optimizers(self):
        """Only the linear classifier is optimized; everything else is frozen."""
        return torch.optim.NAdam(self.classifier.parameters(), lr=self.lr)
@hydra.main(config_path="configs", config_name="eval_seg.yaml")
def my_app(cfg: DictConfig) -> None:
    """Entry point: build data loaders and the evaluator, then train/validate.

    Two data paths: ``cfg.image_train`` loads images and featurizes on the
    fly; otherwise precomputed embeddings are read from disk (much faster
    when evaluating repeatedly).
    """
    print(OmegaConf.to_yaml(cfg))
    print(cfg.output_root)
    seed_everything(seed=0, workers=True)
    log_dir = join(cfg.output_root, "logs/eval/semseg")  # was a placeholder-free f-string
    upsampler_path = cfg.upsampler_path
    emb_root = join(cfg.pytorch_data_dir, cfg.dataset, "embedding", cfg.model_type)
    if cfg.image_train:
        # Standard data loading: images are transformed on the fly.  Slower
        # across repeated evaluations, but the only option when the dataset
        # is too large to precompute embeddings for.
        load_size = cfg.load_size
        transform = T.Compose([
            T.Resize(load_size, InterpolationMode.BILINEAR),
            T.CenterCrop(load_size),
            T.ToTensor(),
            norm])
        # Labels use NEAREST resizing and no value scaling so class ids
        # stay exact integers.
        target_transforms = T.Compose([
            T.Resize(load_size, InterpolationMode.NEAREST),
            T.CenterCrop(load_size),
            ToTensorWithoutScaling()
        ])
        train_dataset = get_dataset(
            cfg.pytorch_data_dir,
            cfg.dataset,
            "train",
            transform=transform,
            target_transform=target_transforms,
            include_labels=True)
        train_loader = DataLoader(
            train_dataset, cfg.batch_size, shuffle=True, num_workers=cfg.num_workers)
        val_dataset = get_dataset(
            cfg.pytorch_data_dir,
            cfg.dataset,
            "val",
            transform=transform,
            target_transform=target_transforms,
            include_labels=True)
        val_loader = DataLoader(
            val_dataset, cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)
    else:
        # Embeddings were extracted beforehand: load them straight from disk.
        train_dataset = EmbeddingFile(join(emb_root, "train", f"embeddings_{cfg.model_type}.npz"), loading_imgs=cfg.loading_imgs)
        train_loader = DataLoader(train_dataset, cfg.batch_size, shuffle=True, num_workers=cfg.num_workers)
        val_dataset = EmbeddingFile(join(emb_root, "val", f"embeddings_{cfg.model_type}.npz"), loading_imgs=cfg.loading_imgs)
        # NOTE(review): shuffle=True on a validation loader is unusual —
        # metrics are order-independent, but it makes visualization output
        # non-deterministic across runs; confirm this is intended.
        val_loader = DataLoader(val_dataset, cfg.batch_size, shuffle=True, num_workers=cfg.num_workers)
    if cfg.image_train:
        model, patch_size, dim = get_featurizer(cfg.model_type, cfg.activation_type, num_classes=1000)
        model.eval()
        for param in model.parameters():
            param.requires_grad = False
    else:
        # FIX: ``model`` was never bound on this path, so the
        # SemSegEvaluator(...) call below raised NameError.  None tells the
        # evaluator to expect precomputed features in each batch.
        model = None
        _, _, dim = get_featurizer(cfg.model_type, cfg.activation_type, num_classes=1000)
    evaluator = SemSegEvaluator(dim, upsampler_type=cfg.upsampler_type, upsampler_path=upsampler_path, cfg=cfg, model=model)
    tb_logger = TensorBoardLogger(log_dir, default_hp_metric=False)
    trainer = Trainer(
        accelerator='gpu',
        strategy='ddp',
        devices=cfg.num_gpus,
        max_epochs=cfg.epochs,
        logger=tb_logger,
        log_every_n_steps=100,
        reload_dataloaders_every_n_epochs=1,
        check_val_every_n_epoch=1,
    )
    if cfg.use_pretrained_classifier:
        # Classifier is frozen: skip training and run validation directly.
        trainer.validate(evaluator, val_loader)
    else:
        trainer.fit(evaluator, train_loader, val_loader)
    result = {
        "Linear Accuracy": float(evaluator.linear_acc),
        "Linear mIoU": float(evaluator.linear_iou),
        "Model": cfg.model_type
    }
    print(result)
# Script entry point: hydra parses CLI/config overrides and calls my_app(cfg).
if __name__ == "__main__":
    my_app()