|
|
import os |
|
|
import re |
|
|
import cv2 |
|
|
import yaml |
|
|
import torch |
|
|
import hashlib |
|
|
import argparse |
|
|
|
|
|
import albumentations as A |
|
|
from albumentations.core.transforms_interface import ImageOnlyTransform |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
from PIL import Image |
|
|
from threading import Thread |
|
|
from easydict import EasyDict |
|
|
|
|
|
# Known video file extensions (lower-case, without the leading dot).
# Duplicate entries were removed; first-occurrence order is preserved.
VID_EXTS = (
    'mp4', 'avi', 'h264', 'mkv', 'mov', 'flv', 'wmv', 'webm', 'ts', 'm4v',
    'vob', '3gp', '3g2', 'rm', 'rmvb', 'ogv', 'ogg', 'drc', 'gif', 'gifv',
    'mng', 'qt', 'yuv', 'asf', 'amv', 'm4p', 'mpg', 'mp2', 'mpeg', 'mpe',
    'mpv', 'm2v', 'svi', 'mxf', 'roq', 'nsv', 'f4v', 'f4p', 'f4a', 'f4b',
)

# Known image file extensions (lower-case, without the leading dot).
# Duplicate entries were removed; first-occurrence order is preserved.
IMG_EXTS = (
    'jpg', 'jpeg', 'bmp', 'png', 'ppm', 'pgm', 'pbm', 'pnm', 'webp', 'sr',
    'ras', 'tiff', 'tif', 'exr', 'hdr', 'pic', 'dib', 'jpe', 'jp2', 'j2k',
    'jpf', 'jpx', 'jpm', 'mj2', 'jxr', 'hdp', 'wdp', 'cur', 'ico', 'ani',
    'icns', 'bpg',
)
|
|
|
|
|
def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``source``, ``dest``, ``type``,
        ``reverse``, ``format``, ``resize``, ``jit``, ``device``, ``mode``,
        ``ckpt`` and ``threshold``.
    """
    # Each entry is (option strings, keyword arguments for add_argument).
    option_specs = (
        (('--source', '-s'), dict(
            type=str,
            help="Path to the source. Single image, video, directory of images, directory of videos is supported.")),
        (('--dest', '-d'), dict(
            type=str, default=None,
            help="Path to destination. Results will be stored in current directory if not specified.")),
        (('--type', '-t'), dict(
            type=str, default='rgba',
            help="Specify output type. If not specified, output results will make the background transparent. Please refer to the documentation for other types.")),
        (('--reverse', '-R'), dict(
            action='store_true',
            help="Output will be reverse and foreground will be removed instead of background if specified.")),
        (('--format', '-f'), dict(
            type=str, default=None,
            help="Specify output format. If not specified, it will be saved with the format of input.")),
        (('--resize', '-r'), dict(
            type=str, default='static',
            help="Specify resizing method. If not specified, static resize will be used. Choose from (static|dynamic).")),
        (('--jit', '-j'), dict(
            action='store_true',
            help="Speed up inference speed by using torchscript, but decreases output quality.")),
        (('--device', '-D'), dict(
            type=str, default=None,
            help="Designate device. If not specified, it will find available device.")),
        (('--mode', '-m'), dict(
            type=str, default='base',
            help="choose between base and fast mode. Also, use base-nightly for nightly release checkpoint.")),
        (('--ckpt', '-c'), dict(
            type=str, default=None,
            help="Designate checkpoint. If not specified, it will download or load pre-downloaded default checkpoint.")),
        (('--threshold', '-th'), dict(
            type=str, default=None,
            help="Designate threshold. If specified, it will output hard prediction above threshold. If not specified, it will output soft prediction.")),
    )

    parser = argparse.ArgumentParser()
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    return parser.parse_args()
|
|
|
|
|
def get_backend():
    """Pick the torch device string to run on.

    Returns:
        "cuda:0" when CUDA is available, otherwise "mps:0" when the MPS
        backend is available, otherwise "cpu".
    """
    # CUDA is checked first; MPS is only probed when CUDA is unavailable.
    if torch.cuda.is_available():
        return "cuda:0"
    if torch.backends.mps.is_available():
        return "mps:0"
    return "cpu"
|
|
|
|
|
def load_config(config_dir, easy=True):
    """Load a YAML configuration file.

    Args:
        config_dir: Path to the YAML file (despite the name, a file path is
            what gets opened).
        easy: When exactly ``True``, wrap the parsed mapping in an
            ``EasyDict`` for attribute-style access.

    Returns:
        The parsed configuration, as an ``EasyDict`` or as the plain object
        produced by the YAML loader.
    """
    # Use a context manager so the file handle is closed deterministically.
    # NOTE(security): FullLoader is kept for compatibility with existing
    # configs; only load configuration files from trusted sources.
    with open(config_dir) as config_file:
        cfg = yaml.load(config_file, yaml.FullLoader)

    if easy is True:
        cfg = EasyDict(cfg)
    return cfg
|
|
|
|
|
def get_format(source):
    """Classify a collection of file names as images or videos.

    Args:
        source: Iterable of file names or paths.

    Returns:
        'Image' if at least one name has an image extension and none has a
        video extension, 'Video' for the opposite case, and '' when the
        collection is mixed or contains neither.
    """
    # Match on ".ext" rather than the bare extension: the bare form treated
    # names such as "alarm" ('rm') or "cats" ('ts') as media files.
    img_suffixes = tuple('.' + ext for ext in IMG_EXTS)
    vid_suffixes = tuple('.' + ext for ext in VID_EXTS)

    # Materialize once so one-shot iterables can be scanned for both kinds.
    names = [name.lower() for name in source]
    has_image = any(name.endswith(img_suffixes) for name in names)
    has_video = any(name.endswith(vid_suffixes) for name in names)

    if has_image == has_video:
        # Either a mix of both kinds or nothing recognizable at all.
        return ''
    return 'Image' if has_image else 'Video'
|
|
|
|
|
def sort(x):
    """Return the strings of ``x`` in natural (human) order.

    Runs of ASCII digits are compared numerically and everything else is
    compared case-insensitively, so 'img2' sorts before 'img10'.

    Args:
        x: Iterable of strings.

    Returns:
        A new sorted list.
    """
    def natural_key(item):
        # Splitting on a captured digit group keeps the digit runs in the
        # result, alternating text and number chunks.
        parts = []
        for chunk in re.split('([0-9]+)', item):
            parts.append(int(chunk) if chunk.isdigit() else chunk.lower())
        return parts

    return sorted(x, key=natural_key)
|
|
|
|
|
def _run_external(args):
    """Run an external command without a shell; report failures instead of raising."""
    import subprocess
    import sys

    try:
        return subprocess.run(args, check=False).returncode
    except OSError as err:
        # Mirrors the old os.system behavior: a missing tool is reported on
        # stderr and does not abort the caller.
        print("Failed to run {}: {}".format(args[0], err), file=sys.stderr)
        return 127


def _file_md5(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of the file at ``path``, read in chunks."""
    digest = hashlib.md5()
    with open(path, 'rb') as stream:
        for chunk in iter(lambda: stream.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def download_and_unzip(filename, url, dest, unzip=True, **kwargs):
    """Download ``url`` into ``dest/filename`` and optionally unzip it.

    The download is performed with ``wget`` when the target file is missing,
    or when an ``md5`` keyword is given and does not match the existing
    file. When ``unzip`` is true the archive is extracted into ``dest`` with
    ``unzip -o`` and then deleted.

    Args:
        filename: Name of the file inside ``dest``.
        url: Address passed to ``wget``.
        dest: Destination directory; created if it does not exist.
        unzip: Whether to extract and then remove the downloaded file.
        **kwargs: Optional ``md5`` with the expected hex digest.

    Returns:
        None. Failures of the external tools are not raised.
    """
    os.makedirs(dest, exist_ok=True)
    target = os.path.join(dest, filename)

    needs_download = not os.path.isfile(target)
    if not needs_download and 'md5' in kwargs:
        needs_download = kwargs['md5'] != _file_md5(target)

    if needs_download:
        # Arguments are passed as a list (no shell), so spaces in paths and
        # characters such as '&' or ';' in URLs are handled safely.
        _run_external(["wget", "-O", target, url])

    if unzip:
        _run_external(["unzip", "-o", target, "-d", dest])
        try:
            os.remove(target)
        except FileNotFoundError:
            # Nothing to clean up, e.g. when the download itself failed.
            pass
|
|
|
|
|
class dynamic_resize:
    """Resize a PIL image so its shorter side is at most ``L``.

    When the shorter side exceeds ``L`` the image is scaled down, keeping the
    aspect ratio, until the shorter side equals ``L``. Both sides are then
    rounded to the nearest multiple of 32 (never below 32).
    """

    def __init__(self, L=1280):
        """Store ``L``, the upper bound for the shorter side in pixels."""
        self.L = L

    def __call__(self, img):
        """Return ``img`` resized with bilinear interpolation.

        Args:
            img: Object exposing ``size`` as a two-item sequence and
                ``resize(size, resample)``, such as a PIL image.
        """
        size = list(img.size)
        if (size[0] >= size[1]) and size[1] > self.L:
            size[0] = size[0] / (size[1] / self.L)
            size[1] = self.L
        elif (size[1] > size[0]) and size[0] > self.L:
            size[1] = size[1] / (size[0] / self.L)
            size[0] = self.L

        # Snap to multiples of 32. Clamp at 32 because sides under 16 px
        # used to round down to 0, which makes the resize fail.
        size = tuple(max(32, int(round(side / 32)) * 32) for side in size)

        return img.resize(size, Image.BILINEAR)
|
|
|
|
|
class dynamic_resize_a(ImageOnlyTransform):
    """Albumentations transform that caps the shorter image side at ``L``.

    When the shorter side exceeds ``L`` the array is scaled down, keeping the
    aspect ratio, until the shorter side equals ``L``. Both sides are then
    rounded to the nearest multiple of 32 (never below 32).
    """

    def __init__(self, L=1280, always_apply=False, p=1.0):
        """Store ``L`` and forward ``always_apply`` and ``p`` to the base class."""
        super(dynamic_resize_a, self).__init__(always_apply, p)
        self.L = L

    def apply(self, img, **params):
        """Return ``img`` resized; the first two axes are (height, width)."""
        size = list(img.shape[:2])
        if (size[0] >= size[1]) and size[1] > self.L:
            size[0] = size[0] / (size[1] / self.L)
            size[1] = self.L
        elif (size[1] > size[0]) and size[0] > self.L:
            size[1] = size[1] / (size[0] / self.L)
            size[0] = self.L

        # Snap to multiples of 32. Clamp at 32 because sides under 16 px
        # used to round down to 0, which is not a valid target size.
        size = tuple(max(32, int(round(side / 32)) * 32) for side in size)

        return A.resize(img, height=size[0], width=size[1])

    def get_transform_init_args_names(self):
        """Name the constructor arguments used when serializing the transform."""
        return ("L",)
|
|
|
|
|
class static_resize:
    """Resize an image to a fixed size with bilinear interpolation."""

    def __init__(self, size=None):
        """Store the target size.

        Args:
            size: Two-item target size passed to ``img.resize``. Defaults to
                a fresh ``[1024, 1024]`` list per instance; the previous
                mutable default was shared between all instances.
        """
        self.size = [1024, 1024] if size is None else size

    def __call__(self, img):
        """Return ``img`` resized to ``self.size``."""
        return img.resize(self.size, Image.BILINEAR)
|
|
|
|
|
class normalize:
    """Scale, shift and rescale an array in place: ``(img / div - mean) / std``."""

    def __init__(self, mean=None, std=None, div=255):
        """Store the normalization constants.

        Args:
            mean: Value subtracted after dividing by ``div``; 0.0 when None.
            std: Value the shifted result is divided by; 1.0 when None.
            div: Initial divisor.
        """
        if mean is None:
            mean = 0.0
        if std is None:
            std = 1.0
        self.mean, self.std, self.div = mean, std, div

    def __call__(self, img):
        """Normalize ``img`` using augmented assignments and return it.

        The operations are applied with ``/=`` and ``-=``, so a mutable
        array passed in is modified in place.
        """
        img /= self.div
        img -= self.mean
        img /= self.std
        return img
|
|
|
|
|
class tonumpy:
    """Callable that converts its input to a float32 NumPy array."""

    def __init__(self):
        """No configuration is needed."""

    def __call__(self, img):
        """Return a new ``np.float32`` array built from ``img``."""
        return np.array(img, dtype=np.float32)
|
|
|
|
|
class totensor:
    """Callable that turns a three-axis NumPy array into a float torch tensor."""

    def __init__(self):
        """No configuration is needed."""

    def __call__(self, img):
        """Move the last axis to the front and return a float tensor.

        Args:
            img: NumPy array with three axes; axis order (0, 1, 2) becomes
                (2, 0, 1).
        """
        channels_first = img.transpose((2, 0, 1))
        return torch.from_numpy(channels_first).float()
|
|
|
|
|
class ImageLoader:
    """Iterate over a single image file or the images inside a directory.

    Each iteration step yields ``(img, name)``, where ``img`` is a PIL image
    converted to RGB and ``name`` is the file's base name.
    """

    def __init__(self, root):
        """Collect the image paths to serve.

        Args:
            root: Path to an image file, or to a directory whose
                ``.jpg``/``.png``/``.jpeg`` files (case-insensitive) are
                served in natural sort order.

        Raises:
            FileNotFoundError: If ``root`` is neither a file nor a directory.
                (Previously this surfaced as an AttributeError.)
        """
        if os.path.isdir(root):
            candidates = [
                os.path.join(root, f)
                for f in os.listdir(root)
                if f.lower().endswith(('.jpg', '.png', '.jpeg'))
            ]
            self.images = sort(candidates)
        elif os.path.isfile(root):
            self.images = [root]
        else:
            raise FileNotFoundError(
                "Image source is neither a file nor a directory: {}".format(root))

        self.size = len(self.images)
        # Initialized here so next() also works without a prior iter().
        self.index = 0

    def __iter__(self):
        """Restart iteration from the first image."""
        self.index = 0
        return self

    def __next__(self):
        """Return the next ``(img, name)`` pair or raise StopIteration."""
        if self.index >= self.size:
            raise StopIteration

        path = self.images[self.index]
        # convert() returns a fully loaded image, so the file handle behind
        # Image.open can be closed right away.
        with Image.open(path) as opened:
            img = opened.convert('RGB')
        name = os.path.split(path)[-1]

        self.index += 1
        return img, name

    def __len__(self):
        """Number of images that will be served."""
        return self.size
|
|
|
|
|
class VideoLoader:
    """Iterate frame by frame over one video file or the videos of a directory.

    Each iteration step yields ``(img, name)``. ``img`` is an RGB PIL image
    for a decoded frame, or ``None`` exactly once when the current video has
    no more frames; ``name`` is the base name of the video concerned.
    """

    def __init__(self, root):
        """Collect the video paths to serve.

        Args:
            root: Path to a video file, or to a directory whose
                ``.mp4``/``.avi``/``.mov`` files (case-insensitive) are
                served in sorted order.

        Raises:
            FileNotFoundError: If ``root`` is neither a file nor a directory.
                (Previously this surfaced as an AttributeError.)
        """
        if os.path.isdir(root):
            # '.mov' needs its dot: the bare 'mov' suffix also matched names
            # that merely end in those letters. Sorting makes the processing
            # order deterministic (os.listdir order is arbitrary).
            self.videos = sorted(
                os.path.join(root, f)
                for f in os.listdir(root)
                if f.lower().endswith(('.mp4', '.avi', '.mov'))
            )
        elif os.path.isfile(root):
            self.videos = [root]
        else:
            raise FileNotFoundError(
                "Video source is neither a file nor a directory: {}".format(root))

        self.size = len(self.videos)
        # Initialized here so next() also works without a prior iter().
        self.index = 0
        self.cap = None
        self.fps = None

    def __iter__(self):
        """Restart iteration from the first video."""
        if self.cap is not None:
            # Do not leak a capture left open by an unfinished iteration.
            self.cap.release()
        self.index = 0
        self.cap = None
        self.fps = None
        return self

    def __next__(self):
        """Return the next ``(img, name)`` pair or raise StopIteration."""
        if self.index >= self.size:
            raise StopIteration

        if self.cap is None:
            # Lazily open the current video and record its frame rate.
            self.cap = cv2.VideoCapture(self.videos[self.index])
            self.fps = self.cap.get(cv2.CAP_PROP_FPS)
        ret, frame = self.cap.read()
        name = os.path.split(self.videos[self.index])[-1]

        if ret is False:
            # End of this video: close it, emit a None frame as the
            # end-of-video marker and move on to the next file.
            self.cap.release()
            self.cap = None
            img = None
            self.index += 1
        else:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(frame).convert('RGB')

        return img, name

    def __len__(self):
        """Number of videos (not frames) that will be served."""
        return self.size
|
|
|
|
|
class WebcamLoader:
    """Endless iterator over the most recent frame of a webcam.

    A daemon thread keeps reading frames from the capture device into
    ``self.imgs``; every iteration step yields ``(frame, None)`` with the
    newest frame as an RGB PIL image. Iteration stops when the reader thread
    has ended or when ``cv2.waitKey`` reports the ``q`` key.
    """

    def __init__(self, ID):
        """Open capture device ``ID`` at 640x480 and start the reader thread.

        Args:
            ID: Device index; converted with ``int``.
        """
        self.ID = int(ID)
        self.cap = cv2.VideoCapture(self.ID)
        self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
        self.imgs = []

        # Only keep the first frame if the read succeeded; a failed read
        # returns None, which must not be stored as a frame.
        ret, frame = self.cap.read()
        if ret:
            self.imgs.append(frame)

        self.thread = Thread(target=self.update, daemon=True)
        self.thread.start()

    def update(self):
        """Reader loop: append frames until the device closes or a read fails."""
        while self.cap.isOpened():
            ret, frame = self.cap.read()
            if ret is True:
                self.imgs.append(frame)
            else:
                break

    def __iter__(self):
        """Return the loader itself."""
        return self

    def __next__(self):
        """Return ``(frame, None)`` for the newest frame or raise StopIteration."""
        if len(self.imgs) > 0:
            frame = self.imgs[-1]
        else:
            # Black placeholder. It must be an ndarray, not a PIL image,
            # because it is passed to cv2.cvtColor below.
            frame = np.zeros((480, 640, 3), dtype=np.uint8)

        if self.thread.is_alive() is False or cv2.waitKey(1) == ord('q'):
            cv2.destroyAllWindows()
            raise StopIteration

        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = Image.fromarray(frame).convert('RGB')

        # Drop everything but the newest buffered frame to bound memory use.
        del self.imgs[:-1]
        return frame, None

    def __len__(self):
        """Always 0: the number of frames of a live stream is unknown."""
        return 0
|
|
|