| | """Compute segmentation maps for images in the input folder. |
| | """ |
| | import os |
| | import glob |
| | import cv2 |
| | import argparse |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| |
|
| | import util.io |
| |
|
| | from torchvision.transforms import Compose |
| | from dpt.models import DPTSegmentationModel |
| | from dpt.transforms import Resize, NormalizeImage, PrepareForNet |
| |
|
| |
|
def run(input_path, output_path, model_path, model_type="dpt_hybrid", optimize=True):
    """Run the segmentation network on every file in the input folder.

    Each input is read, passed through the network, and the per-pixel class
    prediction is written to the output folder under the input's base name
    (without extension).

    Args:
        input_path (str): path to input folder
        output_path (str): path to output folder (created if it does not exist)
        model_path (str): path to saved model
        model_type (str): "dpt_large" or "dpt_hybrid"; selects the backbone
        optimize (bool): if true and the device is CUDA, run the model and its
            inputs in half precision with channels_last memory format

    Raises:
        ValueError: if model_type is not one of the supported model types.
    """
    # Backbone identifier passed to DPTSegmentationModel for each model type.
    backbones = {
        "dpt_large": "vitl16_384",
        "dpt_hybrid": "vitb_rn50_384",
    }

    # Fail fast with an explicit raise: an `assert` would be stripped under
    # `python -O`, leaving `model` unbound further down.
    if model_type not in backbones:
        raise ValueError(
            f"model_type '{model_type}' not implemented, use: --model_type [dpt_large|dpt_hybrid]"
        )

    print("initialize")

    # select device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device: %s" % device)

    net_w = net_h = 480

    # load network
    # NOTE(review): 150 is presumably the number of output classes -- confirm
    # against DPTSegmentationModel's signature.
    model = DPTSegmentationModel(
        150,
        path=model_path,
        backbone=backbones[model_type],
    )

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            PrepareForNet(),
        ]
    )

    model.eval()

    # The half-precision / channels_last path is only taken on CUDA; decide
    # once so that the model and every input sample are treated consistently.
    use_half = bool(optimize) and device.type == "cuda"

    if use_half:
        model = model.to(memory_format=torch.channels_last)
        model = model.half()

    model.to(device)

    # get input: regular files only, in sorted order so that runs are
    # deterministic and sub-directories do not reach the image reader
    img_names = sorted(
        name
        for name in glob.glob(os.path.join(input_path, "*"))
        if os.path.isfile(name)
    )
    num_images = len(img_names)

    # create output folder
    os.makedirs(output_path, exist_ok=True)

    print("start processing")

    for ind, img_name in enumerate(img_names, start=1):

        print(" processing {} ({}/{})".format(img_name, ind, num_images))

        # input
        img = util.io.read_image(img_name)
        img_input = transform({"image": img})["image"]

        # compute
        with torch.no_grad():
            sample = torch.from_numpy(img_input).to(device).unsqueeze(0)
            if use_half:
                sample = sample.to(memory_format=torch.channels_last)
                sample = sample.half()

            # Call the module rather than .forward() so registered hooks run.
            out = model(sample)

            # Resample the network output to the first two dimensions of the
            # input image (assumed to be height and width -- TODO confirm
            # against util.io.read_image).
            prediction = torch.nn.functional.interpolate(
                out, size=img.shape[:2], mode="bicubic", align_corners=False
            )
            # argmax over dim 1 picks one index per pixel; +1 makes the labels
            # start at 1 (presumably what write_segm_img expects -- verify).
            prediction = torch.argmax(prediction, dim=1) + 1
            # Drop only the batch dimension so that an image with a height or
            # width of 1 keeps both spatial dimensions.
            prediction = prediction.squeeze(0).cpu().numpy()

        # output
        filename = os.path.join(
            output_path, os.path.splitext(os.path.basename(img_name))[0]
        )
        util.io.write_segm_img(filename, img, prediction, alpha=0.5)

    print("finished")
| |
|
| |
|
if __name__ == "__main__":
    # Default checkpoint for each supported model type. The keys double as the
    # accepted values of --model_type, so an unsupported type is rejected by
    # argparse with a usage message instead of a KeyError on the lookup below.
    default_models = {
        "dpt_large": "weights/dpt_large-ade20k-b12dca68.pt",
        "dpt_hybrid": "weights/dpt_hybrid-ade20k-53898607.pt",
    }

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-i", "--input_path", default="input", help="folder with input images"
    )

    parser.add_argument(
        "-o", "--output_path", default="output_semseg", help="folder for output images"
    )

    parser.add_argument(
        "-m",
        "--model_weights",
        default=None,
        help="path to the trained weights of model",
    )

    # 'dpt_large' or 'dpt_hybrid'
    parser.add_argument(
        "-t",
        "--model_type",
        default="dpt_hybrid",
        choices=sorted(default_models),
        help="model type",
    )

    # --optimize / --no-optimize write to the same destination; optimization
    # is on unless --no-optimize is given.
    parser.add_argument(
        "--optimize",
        dest="optimize",
        action="store_true",
        help="use half precision and channels_last on CUDA (default)",
    )
    parser.add_argument(
        "--no-optimize",
        dest="optimize",
        action="store_false",
        help="disable the half precision / channels_last path",
    )
    parser.set_defaults(optimize=True)

    args = parser.parse_args()

    # Fall back to the bundled checkpoint path for the chosen model type when
    # no explicit weights file was given.
    if args.model_weights is None:
        args.model_weights = default_models[args.model_type]

    # set torch options
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # compute segmentation maps
    run(
        args.input_path,
        args.output_path,
        args.model_weights,
        args.model_type,
        args.optimize,
    )
| |
|