"""Compute segmentation maps for images in the input folder. """ import os import glob import cv2 import argparse import torch import util.io from torchvision.transforms import Compose from dpt.models import DPTSegmentationModel from dpt.transforms import Resize, NormalizeImage, PrepareForNet def run(input_path, output_path, model_path, model_type="dpt_hybrid", optimize=True): """Run segmentation network Args: input_path (str): path to input folder output_path (str): path to output folder model_path (str): path to saved model """ print("initialize") # select device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("device: %s" % device) net_w = net_h = 480 # load network if model_type == "dpt_large": model = DPTSegmentationModel( 150, path=model_path, backbone="vitl16_384", ) elif model_type == "dpt_hybrid": model = DPTSegmentationModel( 150, path=model_path, backbone="vitb_rn50_384", ) else: assert ( False ), f"model_type '{model_type}' not implemented, use: --model_type [dpt_large|dpt_hybrid]" transform = Compose( [ Resize( net_w, net_h, resize_target=None, keep_aspect_ratio=True, ensure_multiple_of=32, resize_method="minimal", image_interpolation_method=cv2.INTER_CUBIC, ), NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), PrepareForNet(), ] ) model.eval() if optimize == True and device == torch.device("cuda"): model = model.to(memory_format=torch.channels_last) model = model.half() model.to(device) # get input img_names = glob.glob(os.path.join(input_path, "*")) num_images = len(img_names) # create output folder os.makedirs(output_path, exist_ok=True) print("start processing") for ind, img_name in enumerate(img_names): print(" processing {} ({}/{})".format(img_name, ind + 1, num_images)) # input img = util.io.read_image(img_name) img_input = transform({"image": img})["image"] # compute with torch.no_grad(): sample = torch.from_numpy(img_input).to(device).unsqueeze(0) if optimize == True and device == torch.device("cuda"): sample = sample.to(memory_format=torch.channels_last) sample = sample.half() out = model.forward(sample) prediction = torch.nn.functional.interpolate( out, size=img.shape[:2], mode="bicubic", align_corners=False ) prediction = torch.argmax(prediction, dim=1) + 1 prediction = prediction.squeeze().cpu().numpy() # output filename = os.path.join( output_path, os.path.splitext(os.path.basename(img_name))[0] ) util.io.write_segm_img(filename, img, prediction, alpha=0.5) print("finished") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "-i", "--input_path", default="input", help="folder with input images" ) parser.add_argument( "-o", "--output_path", default="output_semseg", help="folder for output images" ) parser.add_argument( "-m", "--model_weights", default=None, help="path to the trained weights of model", ) # 'vit_large', 'vit_hybrid' parser.add_argument("-t", "--model_type", default="dpt_hybrid", help="model type") parser.add_argument("--optimize", dest="optimize", action="store_true") parser.add_argument("--no-optimize", dest="optimize", action="store_false") parser.set_defaults(optimize=True) args = parser.parse_args() default_models = { "dpt_large": "weights/dpt_large-ade20k-b12dca68.pt", "dpt_hybrid": "weights/dpt_hybrid-ade20k-53898607.pt", } if args.model_weights is None: args.model_weights = default_models[args.model_type] # set torch options torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True # compute segmentation maps run( args.input_path, args.output_path, args.model_weights, args.model_type, args.optimize, )