| import sys, os |
| import numpy as np |
| import scipy |
| import torch |
| import torch.nn as nn |
| from scipy import ndimage |
| from tqdm import tqdm, trange |
| from PIL import Image |
| import torch.hub |
| import torchvision |
| import torch.nn.functional as F |
|
|
| |
| |
| |
# Checkpoint file whose state dict COCOStuffSegmenter loads into the hub model.
# "TODO" is a placeholder and must be replaced with a real path before running.
CKPT_PATH = "TODO"


def rescale(x):
    """Return ``(x + 1) / 2``, which sends -1 to 0 and 1 to 1."""
    return (x + 1.) / 2.
|
|
def rescale_bgr(x):
    """Scale ``x`` by ``(x + 1) * 127.5`` and reverse it along dimension 0.

    The scaling sends -1 to 0 and 1 to 255.  Going by the name, ``x`` is
    presumably a channel-first RGB image so the flip yields BGR order --
    confirm against the caller.
    """
    return torch.flip((x + 1) * 127.5, dims=[0])
|
|
|
|
class COCOStuffSegmenter(nn.Module):
    """DeepLabV2-ResNet101 segmentation network bundled with its preprocessing.

    The network is fetched through ``torch.hub`` from
    ``kazuto1011/deeplab-pytorch`` with ``n_labels`` (182) output classes and
    initialised from the checkpoint at ``CKPT_PATH``.
    """

    def __init__(self, config):
        """Build the network, load its weights and set up the input transform.

        Args:
            config: Stored on ``self.config``; it is not read anywhere else
                in this class.
        """
        super().__init__()
        self.config = config
        self.n_labels = 182
        model = torch.hub.load(
            "kazuto1011/deeplab-pytorch",
            "deeplabv2_resnet101",
            n_classes=self.n_labels,
        )
        ckpt_path = CKPT_PATH
        # NOTE: torch.load unpickles the file, so CKPT_PATH must point at a
        # trusted checkpoint.  map_location="cpu" makes loading independent of
        # the device the checkpoint was saved from; callers move the module to
        # their target device afterwards (e.g. with .cuda()).
        model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
        self.model = model

        normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std)
        # Applied sample by sample: rescale and flip each image, normalise it,
        # then stack the results back into a batch.
        self.image_transform = torchvision.transforms.Compose([
            torchvision.transforms.Lambda(lambda image: torch.stack(
                [normalize(rescale_bgr(x)) for x in image]))
        ])

    def forward(self, x, upsample=None):
        """Preprocess ``x``, run the network and optionally resize the output.

        Args:
            x: Batch of images; the transform iterates over its first
                dimension.
            upsample: Target spatial size for bilinear resizing of the network
                output, or ``None`` to return the output unchanged.

        Returns:
            The network output, resized to ``upsample`` when it is given.
        """
        x = self._pre_process(x)
        x = self.model(x)
        if upsample is not None:
            # Replaces the deprecated upsample_bilinear, which is documented as
            # interpolate(..., mode="bilinear", align_corners=True).
            x = torch.nn.functional.interpolate(
                x, size=upsample, mode="bilinear", align_corners=True)
        return x

    def _pre_process(self, x):
        """Apply ``self.image_transform`` to the batch ``x``."""
        return self.image_transform(x)

    @property
    def mean(self):
        """Per-channel means handed to ``Normalize``.

        Presumably in BGR order, matching the channel flip performed by
        ``rescale_bgr`` -- confirm against the checkpoint's training setup.
        """
        return [104.008, 116.669, 122.675]

    @property
    def std(self):
        """Per-channel standard deviations handed to ``Normalize`` (all 1)."""
        return [1.0, 1.0, 1.0]

    @property
    def input_size(self):
        """Return ``[3, 224, 224]``; not used elsewhere in this class."""
        return [3, 224, 224]
|
|
|
|
def run_model(img, model):
    """Run ``model`` on ``img`` in eval mode and return per-pixel argmax labels.

    Args:
        img: Input batch; dimensions 2 and 3 are used as the size the model is
            asked to upsample to (assumed to be height and width -- confirm).
        model: Object exposing ``eval()`` and callable as
            ``model(img, upsample=size)``.

    Returns:
        The argmax over dimension 1 of the model output (dimension kept),
        detached and moved to the CPU.
    """
    net = model.eval()
    target_size = (img.shape[2], img.shape[3])
    with torch.no_grad():
        logits = net(img, upsample=target_size)
        labels = logits.argmax(dim=1, keepdim=True)
    return labels.detach().cpu()
|
|
|
|
def get_input(batch, k):
    """Fetch ``batch[k]`` and return it as a contiguous float tensor.

    A rank-3 entry first gets a trailing singleton dimension appended.  The
    last dimension is then moved to position 1 (``permute(0, 3, 1, 2)``), which
    presumably turns a channels-last batch into channels-first -- confirm
    against the dataset.
    """
    tensor = batch[k]
    if tensor.ndim == 3:
        tensor = tensor.unsqueeze(-1)
    reordered = tensor.permute(0, 3, 1, 2)
    return reordered.to(memory_format=torch.contiguous_format).float()
|
|
|
|
def save_segmentation(segmentation, path):
    """Write a single segmentation map to ``path`` as an 8-bit image.

    Args:
        segmentation: Rank-4 tensor whose first dimension is 1.  It is
            converted with ``.numpy()``, so it has to live on the CPU.
        path: Destination file name; missing parent directories are created.

    Raises:
        AssertionError: If ``segmentation`` is not rank 4 with a leading
            dimension of 1.
    """
    # os.path.dirname returns "" for a bare file name and os.makedirs("")
    # raises, so only create the parent when there is one.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    assert len(segmentation.shape) == 4
    assert segmentation.shape[0] == 1
    for seg in segmentation:
        seg = seg.permute(1, 2, 0).numpy()
        if seg.shape[-1] == 1:
            # Drop only the singleton trailing axis.  A bare squeeze() would
            # also remove a height or width of 1 and change the array's rank.
            seg = seg[..., 0]
        seg = Image.fromarray(seg.astype(np.uint8))
        seg.save(path)
|
|
|
|
def iterate_dataset(dataloader, destpath, model):
    """Segment every batch from ``dataloader`` and save the results as PNGs.

    For each batch the "image" entry is moved to the GPU and segmented with
    ``run_model``.  The output file name is the first "relative_file_path_"
    entry with its extension replaced by ".png", joined onto ``destpath``.
    A failing batch is reported on stdout and skipped.

    Args:
        dataloader: Iterable of batches with "image" and
            "relative_file_path_" entries.
        destpath: Output root directory; created if missing.
        model: Segmentation model passed through to ``run_model``.
    """
    os.makedirs(destpath, exist_ok=True)
    num_processed = 0
    # Wrap the loader itself rather than enumerate(loader): tqdm can then read
    # its length and show a total and ETA.  The index was never used.
    for batch in tqdm(dataloader, desc="Data"):
        try:
            img = get_input(batch, "image").cuda()
            seg = run_model(img, model)

            stem = os.path.splitext(batch["relative_file_path_"][0])[0]
            save_segmentation(seg, os.path.join(destpath, stem + ".png"))
            num_processed += 1
        except Exception as e:
            # Deliberate best effort: report the failure and move on.
            print(e)
            print("but anyhow..")

    print("Processed {} files. Bye.".format(num_processed))
|
|
|
|
| from taming.data.sflckr import Examples |
| from torch.utils.data import DataLoader |
|
|
if __name__ == "__main__":
    # The destination directory is the only command-line argument; fail with a
    # usage message instead of a bare IndexError when it is missing.
    if len(sys.argv) < 2:
        sys.exit("usage: {} <destination-directory>".format(sys.argv[0]))
    dest = sys.argv[1]
    # save_segmentation asserts a batch of exactly one and iterate_dataset
    # reads only the first file path of each batch, so keep this at 1.
    batchsize = 1
    print("Running with batch-size {}, saving to {}...".format(batchsize, dest))

    model = COCOStuffSegmenter({}).cuda()
    print("Instantiated model.")

    dataset = Examples()
    dloader = DataLoader(dataset, batch_size=batchsize)
    iterate_dataset(dataloader=dloader, destpath=dest, model=model)
    print("done.")
|
|