import cv2 def to_rgb(image): if len(image.shape) == 3 and image.shape[-1] == 3: return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)