diff --git a/third_party/flux/__init__.py b/third_party/flux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..43c365a49d6980e88acba10ef3069f110a59644a --- /dev/null +++ b/third_party/flux/__init__.py @@ -0,0 +1,11 @@ +try: + from ._version import version as __version__ # type: ignore + from ._version import version_tuple +except ImportError: + __version__ = "unknown (no version information available)" + version_tuple = (0, 0, "unknown", "noinfo") + +from pathlib import Path + +PACKAGE = __package__.replace("_", "-") +PACKAGE_ROOT = Path(__file__).parent diff --git a/third_party/flux/__main__.py b/third_party/flux/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..d5cf0fd2444d4cda4053fa74dad3371556b886e5 --- /dev/null +++ b/third_party/flux/__main__.py @@ -0,0 +1,4 @@ +from .cli import app + +if __name__ == "__main__": + app() diff --git a/third_party/flux/__pycache__/__init__.cpython-310.pyc b/third_party/flux/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ecb94b23095305e766128748b86324b1599f05b Binary files /dev/null and b/third_party/flux/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/flux/__pycache__/controlnet.cpython-310.pyc b/third_party/flux/__pycache__/controlnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b3640bfc597ace6fd155324c703c4f03d46e7c8 Binary files /dev/null and b/third_party/flux/__pycache__/controlnet.cpython-310.pyc differ diff --git a/third_party/flux/__pycache__/math.cpython-310.pyc b/third_party/flux/__pycache__/math.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6941439dd7280956f7cd141077fd478f8c2d154a Binary files /dev/null and b/third_party/flux/__pycache__/math.cpython-310.pyc differ diff --git a/third_party/flux/__pycache__/model.cpython-310.pyc b/third_party/flux/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cfb3d9d33eddb74fa0c1ae6b9c55b760e8c7eff Binary files /dev/null and b/third_party/flux/__pycache__/model.cpython-310.pyc differ diff --git a/third_party/flux/__pycache__/sampling.cpython-310.pyc b/third_party/flux/__pycache__/sampling.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ebb317ca49f7702b4a163b044ee7d50d1d61fa0 Binary files /dev/null and b/third_party/flux/__pycache__/sampling.cpython-310.pyc differ diff --git a/third_party/flux/__pycache__/util.cpython-310.pyc b/third_party/flux/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f4f9ddd0e2cd2a95c29809e64ffbbb639249e79 Binary files /dev/null and b/third_party/flux/__pycache__/util.cpython-310.pyc differ diff --git a/third_party/flux/__pycache__/xflux_pipeline.cpython-310.pyc b/third_party/flux/__pycache__/xflux_pipeline.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb755fe1d5906f0a6d56928682ebf7f6812535ad Binary files /dev/null and b/third_party/flux/__pycache__/xflux_pipeline.cpython-310.pyc differ diff --git a/third_party/flux/annotator/__pycache__/util.cpython-310.pyc b/third_party/flux/annotator/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..661767f8b715bb270e44635c6dea85635495514f Binary files /dev/null and b/third_party/flux/annotator/__pycache__/util.cpython-310.pyc differ diff --git a/third_party/flux/annotator/canny/__init__.py b/third_party/flux/annotator/canny/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb0da951dc838ec9dec2131007e036113281800b --- /dev/null +++ b/third_party/flux/annotator/canny/__init__.py @@ -0,0 +1,6 @@ +import cv2 + + +class CannyDetector: + def __call__(self, img, low_threshold, high_threshold): + return cv2.Canny(img, low_threshold, high_threshold) diff --git a/third_party/flux/annotator/canny/__pycache__/__init__.cpython-310.pyc b/third_party/flux/annotator/canny/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c25b2c33ab826d34ce8376ddae350ae2d287735 Binary files /dev/null and b/third_party/flux/annotator/canny/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/flux/annotator/ckpts/ckpts.txt b/third_party/flux/annotator/ckpts/ckpts.txt new file mode 100644 index 0000000000000000000000000000000000000000..1978551fb2a9226814eaf58459f414fcfac4e69b --- /dev/null +++ b/third_party/flux/annotator/ckpts/ckpts.txt @@ -0,0 +1 @@ +Weights here. \ No newline at end of file diff --git a/third_party/flux/annotator/dwpose/__init__.py b/third_party/flux/annotator/dwpose/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb6e172d05c9de3f1cdd61e330ad8d6dde46dfdd --- /dev/null +++ b/third_party/flux/annotator/dwpose/__init__.py @@ -0,0 +1,68 @@ +# Openpose +# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose +# 2nd Edited by https://github.com/Hzzone/pytorch-openpose +# 3rd Edited by ControlNet +# 4th Edited by ControlNet (added face and correct hands) + +import os +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + +import torch +import numpy as np +from . import util +from .wholebody import Wholebody + +def draw_pose(pose, H, W): + bodies = pose['bodies'] + faces = pose['faces'] + hands = pose['hands'] + candidate = bodies['candidate'] + subset = bodies['subset'] + canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) + + canvas = util.draw_bodypose(canvas, candidate, subset) + + canvas = util.draw_handpose(canvas, hands) + + canvas = util.draw_facepose(canvas, faces) + + return canvas + + +class DWposeDetector: + def __init__(self, device): + + self.pose_estimation = Wholebody(device) + + def __call__(self, oriImg): + oriImg = oriImg.copy() + H, W, C = oriImg.shape + with torch.no_grad(): + candidate, subset = self.pose_estimation(oriImg) + nums, keys, locs = candidate.shape + candidate[..., 0] /= float(W) + candidate[..., 1] /= float(H) + body = candidate[:,:18].copy() + body = body.reshape(nums*18, locs) + score = subset[:,:18] + for i in range(len(score)): + for j in range(len(score[i])): + if score[i][j] > 0.3: + score[i][j] = int(18*i+j) + else: + score[i][j] = -1 + + un_visible = subset<0.3 + candidate[un_visible] = -1 + + foot = candidate[:,18:24] + + faces = candidate[:,24:92] + + hands = candidate[:,92:113] + hands = np.vstack([hands, candidate[:,113:]]) + + bodies = dict(candidate=body, subset=score) + pose = dict(bodies=bodies, hands=hands, faces=faces) + + return draw_pose(pose, H, W) diff --git a/third_party/flux/annotator/dwpose/__pycache__/__init__.cpython-310.pyc b/third_party/flux/annotator/dwpose/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18f80f6ef3e364005f7ca59bc6ac665d8cae4375 Binary files /dev/null and b/third_party/flux/annotator/dwpose/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/flux/annotator/dwpose/__pycache__/onnxdet.cpython-310.pyc b/third_party/flux/annotator/dwpose/__pycache__/onnxdet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f00398dc949dff0623a655a919802ddb166e839e Binary files /dev/null and b/third_party/flux/annotator/dwpose/__pycache__/onnxdet.cpython-310.pyc differ diff --git a/third_party/flux/annotator/dwpose/__pycache__/onnxpose.cpython-310.pyc b/third_party/flux/annotator/dwpose/__pycache__/onnxpose.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2586fc953359bfa3a5d5bbb3eada52520ded1eb4 Binary files /dev/null and b/third_party/flux/annotator/dwpose/__pycache__/onnxpose.cpython-310.pyc differ diff --git a/third_party/flux/annotator/dwpose/__pycache__/util.cpython-310.pyc b/third_party/flux/annotator/dwpose/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..988374f95a5985940a5c7bd9c6a334cf571c53bd Binary files /dev/null and b/third_party/flux/annotator/dwpose/__pycache__/util.cpython-310.pyc differ diff --git a/third_party/flux/annotator/dwpose/__pycache__/wholebody.cpython-310.pyc b/third_party/flux/annotator/dwpose/__pycache__/wholebody.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0520ae9cfc6f915748963d929ff260902465ad9b Binary files /dev/null and b/third_party/flux/annotator/dwpose/__pycache__/wholebody.cpython-310.pyc differ diff --git a/third_party/flux/annotator/dwpose/onnxdet.py b/third_party/flux/annotator/dwpose/onnxdet.py new file mode 100644 index 0000000000000000000000000000000000000000..e0411c96a5eef41e981bde5481ef7786b242f1fa --- /dev/null +++ b/third_party/flux/annotator/dwpose/onnxdet.py @@ -0,0 +1,125 @@ +import cv2 +import numpy as np + +import onnxruntime + +def nms(boxes, scores, nms_thr): + """Single class NMS implemented in Numpy.""" + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= nms_thr)[0] + order = order[inds + 1] + + return keep + +def multiclass_nms(boxes, scores, nms_thr, score_thr): + """Multiclass NMS implemented in Numpy. Class-aware version.""" + final_dets = [] + num_classes = scores.shape[1] + for cls_ind in range(num_classes): + cls_scores = scores[:, cls_ind] + valid_score_mask = cls_scores > score_thr + if valid_score_mask.sum() == 0: + continue + else: + valid_scores = cls_scores[valid_score_mask] + valid_boxes = boxes[valid_score_mask] + keep = nms(valid_boxes, valid_scores, nms_thr) + if len(keep) > 0: + cls_inds = np.ones((len(keep), 1)) * cls_ind + dets = np.concatenate( + [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1 + ) + final_dets.append(dets) + if len(final_dets) == 0: + return None + return np.concatenate(final_dets, 0) + +def demo_postprocess(outputs, img_size, p6=False): + grids = [] + expanded_strides = [] + strides = [8, 16, 32] if not p6 else [8, 16, 32, 64] + + hsizes = [img_size[0] // stride for stride in strides] + wsizes = [img_size[1] // stride for stride in strides] + + for hsize, wsize, stride in zip(hsizes, wsizes, strides): + xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) + grid = np.stack((xv, yv), 2).reshape(1, -1, 2) + grids.append(grid) + shape = grid.shape[:2] + expanded_strides.append(np.full((*shape, 1), stride)) + + grids = np.concatenate(grids, 1) + expanded_strides = np.concatenate(expanded_strides, 1) + outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides + outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides + + return outputs + +def preprocess(img, input_size, swap=(2, 0, 1)): + if len(img.shape) == 3: + padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 + else: + padded_img = np.ones(input_size, dtype=np.uint8) * 114 + + r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) + resized_img = cv2.resize( + img, + (int(img.shape[1] * r), int(img.shape[0] * r)), + interpolation=cv2.INTER_LINEAR, + ).astype(np.uint8) + padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img + + padded_img = padded_img.transpose(swap) + padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) + return padded_img, r + +def inference_detector(session, oriImg): + input_shape = (640,640) + img, ratio = preprocess(oriImg, input_shape) + + ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]} + output = session.run(None, ort_inputs) + predictions = demo_postprocess(output[0], input_shape)[0] + + boxes = predictions[:, :4] + scores = predictions[:, 4:5] * predictions[:, 5:] + + boxes_xyxy = np.ones_like(boxes) + boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2. + boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2. + boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2. + boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2. + boxes_xyxy /= ratio + dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1) + if dets is not None: + final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5] + isscore = final_scores>0.3 + iscat = final_cls_inds == 0 + isbbox = [ i and j for (i, j) in zip(isscore, iscat)] + final_boxes = final_boxes[isbbox] + else: + final_boxes = np.array([]) + + return final_boxes diff --git a/third_party/flux/annotator/dwpose/onnxpose.py b/third_party/flux/annotator/dwpose/onnxpose.py new file mode 100644 index 0000000000000000000000000000000000000000..79cd4a06241123af81ea22446a4ca8816716443f --- /dev/null +++ b/third_party/flux/annotator/dwpose/onnxpose.py @@ -0,0 +1,360 @@ +from typing import List, Tuple + +import cv2 +import numpy as np +import onnxruntime as ort + +def preprocess( + img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256) +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Do preprocessing for RTMPose model inference. + + Args: + img (np.ndarray): Input image in shape. + input_size (tuple): Input image size in shape (w, h). + + Returns: + tuple: + - resized_img (np.ndarray): Preprocessed image. + - center (np.ndarray): Center of image. + - scale (np.ndarray): Scale of image. + """ + # get shape of image + img_shape = img.shape[:2] + out_img, out_center, out_scale = [], [], [] + if len(out_bbox) == 0: + out_bbox = [[0, 0, img_shape[1], img_shape[0]]] + for i in range(len(out_bbox)): + x0 = out_bbox[i][0] + y0 = out_bbox[i][1] + x1 = out_bbox[i][2] + y1 = out_bbox[i][3] + bbox = np.array([x0, y0, x1, y1]) + + # get center and scale + center, scale = bbox_xyxy2cs(bbox, padding=1.25) + + # do affine transformation + resized_img, scale = top_down_affine(input_size, scale, center, img) + + # normalize image + mean = np.array([123.675, 116.28, 103.53]) + std = np.array([58.395, 57.12, 57.375]) + resized_img = (resized_img - mean) / std + + out_img.append(resized_img) + out_center.append(center) + out_scale.append(scale) + + return out_img, out_center, out_scale + + +def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray: + """Inference RTMPose model. + + Args: + sess (ort.InferenceSession): ONNXRuntime session. + img (np.ndarray): Input image in shape. + + Returns: + outputs (np.ndarray): Output of RTMPose model. + """ + all_out = [] + # build input + for i in range(len(img)): + input = [img[i].transpose(2, 0, 1)] + + # build output + sess_input = {sess.get_inputs()[0].name: input} + sess_output = [] + for out in sess.get_outputs(): + sess_output.append(out.name) + + # run model + outputs = sess.run(sess_output, sess_input) + all_out.append(outputs) + + return all_out + + +def postprocess(outputs: List[np.ndarray], + model_input_size: Tuple[int, int], + center: Tuple[int, int], + scale: Tuple[int, int], + simcc_split_ratio: float = 2.0 + ) -> Tuple[np.ndarray, np.ndarray]: + """Postprocess for RTMPose model output. + + Args: + outputs (np.ndarray): Output of RTMPose model. + model_input_size (tuple): RTMPose model Input image size. + center (tuple): Center of bbox in shape (x, y). + scale (tuple): Scale of bbox in shape (w, h). + simcc_split_ratio (float): Split ratio of simcc. + + Returns: + tuple: + - keypoints (np.ndarray): Rescaled keypoints. + - scores (np.ndarray): Model predict scores. + """ + all_key = [] + all_score = [] + for i in range(len(outputs)): + # use simcc to decode + simcc_x, simcc_y = outputs[i] + keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio) + + # rescale keypoints + keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2 + all_key.append(keypoints[0]) + all_score.append(scores[0]) + + return np.array(all_key), np.array(all_score) + + +def bbox_xyxy2cs(bbox: np.ndarray, + padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]: + """Transform the bbox format from (x,y,w,h) into (center, scale) + + Args: + bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted + as (left, top, right, bottom) + padding (float): BBox padding factor that will be multilied to scale. + Default: 1.0 + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or + (n, 2) + - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or + (n, 2) + """ + # convert single bbox from (4, ) to (1, 4) + dim = bbox.ndim + if dim == 1: + bbox = bbox[None, :] + + # get bbox center and scale + x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3]) + center = np.hstack([x1 + x2, y1 + y2]) * 0.5 + scale = np.hstack([x2 - x1, y2 - y1]) * padding + + if dim == 1: + center = center[0] + scale = scale[0] + + return center, scale + + +def _fix_aspect_ratio(bbox_scale: np.ndarray, + aspect_ratio: float) -> np.ndarray: + """Extend the scale to match the given aspect ratio. + + Args: + scale (np.ndarray): The image scale (w, h) in shape (2, ) + aspect_ratio (float): The ratio of ``w/h`` + + Returns: + np.ndarray: The reshaped image scale in (2, ) + """ + w, h = np.hsplit(bbox_scale, [1]) + bbox_scale = np.where(w > h * aspect_ratio, + np.hstack([w, w / aspect_ratio]), + np.hstack([h * aspect_ratio, h])) + return bbox_scale + + +def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray: + """Rotate a point by an angle. + + Args: + pt (np.ndarray): 2D point coordinates (x, y) in shape (2, ) + angle_rad (float): rotation angle in radian + + Returns: + np.ndarray: Rotated point in shape (2, ) + """ + sn, cs = np.sin(angle_rad), np.cos(angle_rad) + rot_mat = np.array([[cs, -sn], [sn, cs]]) + return rot_mat @ pt + + +def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray: + """To calculate the affine matrix, three pairs of points are required. This + function is used to get the 3rd point, given 2D points a & b. + + The 3rd point is defined by rotating vector `a - b` by 90 degrees + anticlockwise, using b as the rotation center. + + Args: + a (np.ndarray): The 1st point (x,y) in shape (2, ) + b (np.ndarray): The 2nd point (x,y) in shape (2, ) + + Returns: + np.ndarray: The 3rd point. + """ + direction = a - b + c = b + np.r_[-direction[1], direction[0]] + return c + + +def get_warp_matrix(center: np.ndarray, + scale: np.ndarray, + rot: float, + output_size: Tuple[int, int], + shift: Tuple[float, float] = (0., 0.), + inv: bool = False) -> np.ndarray: + """Calculate the affine transformation matrix that can warp the bbox area + in the input image to the output size. + + Args: + center (np.ndarray[2, ]): Center of the bounding box (x, y). + scale (np.ndarray[2, ]): Scale of the bounding box + wrt [width, height]. + rot (float): Rotation angle (degree). + output_size (np.ndarray[2, ] | list(2,)): Size of the + destination heatmaps. + shift (0-100%): Shift translation ratio wrt the width/height. + Default (0., 0.). + inv (bool): Option to inverse the affine transform direction. + (inv=False: src->dst or inv=True: dst->src) + + Returns: + np.ndarray: A 2x3 transformation matrix + """ + shift = np.array(shift) + src_w = scale[0] + dst_w = output_size[0] + dst_h = output_size[1] + + # compute transformation matrix + rot_rad = np.deg2rad(rot) + src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad) + dst_dir = np.array([0., dst_w * -0.5]) + + # get four corners of the src rectangle in the original image + src = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + scale * shift + src[1, :] = center + src_dir + scale * shift + src[2, :] = _get_3rd_point(src[0, :], src[1, :]) + + # get four corners of the dst rectangle in the input image + dst = np.zeros((3, 2), dtype=np.float32) + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir + dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return warp_mat + + +def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict, + img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get the bbox image as the model input by affine transform. + + Args: + input_size (dict): The input size of the model. + bbox_scale (dict): The bbox scale of the img. + bbox_center (dict): The bbox center of the img. + img (np.ndarray): The original image. + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: img after affine transform. + - np.ndarray[float32]: bbox scale after affine transform. + """ + w, h = input_size + warp_size = (int(w), int(h)) + + # reshape bbox to fixed aspect ratio + bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h) + + # get the affine matrix + center = bbox_center + scale = bbox_scale + rot = 0 + warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h)) + + # do affine transform + img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR) + + return img, bbox_scale + + +def get_simcc_maximum(simcc_x: np.ndarray, + simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get maximum response location and value from simcc representations. + + Note: + instance number: N + num_keypoints: K + heatmap height: H + heatmap width: W + + Args: + simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx) + simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy) + + Returns: + tuple: + - locs (np.ndarray): locations of maximum heatmap responses in shape + (K, 2) or (N, K, 2) + - vals (np.ndarray): values of maximum heatmap responses in shape + (K,) or (N, K) + """ + N, K, Wx = simcc_x.shape + simcc_x = simcc_x.reshape(N * K, -1) + simcc_y = simcc_y.reshape(N * K, -1) + + # get maximum value locations + x_locs = np.argmax(simcc_x, axis=1) + y_locs = np.argmax(simcc_y, axis=1) + locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32) + max_val_x = np.amax(simcc_x, axis=1) + max_val_y = np.amax(simcc_y, axis=1) + + # get maximum value across x and y axis + mask = max_val_x > max_val_y + max_val_x[mask] = max_val_y[mask] + vals = max_val_x + locs[vals <= 0.] = -1 + + # reshape + locs = locs.reshape(N, K, 2) + vals = vals.reshape(N, K) + + return locs, vals + + +def decode(simcc_x: np.ndarray, simcc_y: np.ndarray, + simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]: + """Modulate simcc distribution with Gaussian. + + Args: + simcc_x (np.ndarray[K, Wx]): model predicted simcc in x. + simcc_y (np.ndarray[K, Wy]): model predicted simcc in y. + simcc_split_ratio (int): The split ratio of simcc. + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2) + - np.ndarray[float32]: scores in shape (K,) or (n, K) + """ + keypoints, scores = get_simcc_maximum(simcc_x, simcc_y) + keypoints /= simcc_split_ratio + + return keypoints, scores + + +def inference_pose(session, out_bbox, oriImg): + h, w = session.get_inputs()[0].shape[2:] + model_input_size = (w, h) + resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size) + outputs = inference(session, resized_img) + keypoints, scores = postprocess(outputs, model_input_size, center, scale) + + return keypoints, scores \ No newline at end of file diff --git a/third_party/flux/annotator/dwpose/util.py b/third_party/flux/annotator/dwpose/util.py new file mode 100644 index 0000000000000000000000000000000000000000..73d7d0153b38d143eb8090e07a9784a274b619ed --- /dev/null +++ b/third_party/flux/annotator/dwpose/util.py @@ -0,0 +1,297 @@ +import math +import numpy as np +import matplotlib +import cv2 + + +eps = 0.01 + + +def smart_resize(x, s): + Ht, Wt = s + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2) + + +def smart_resize_k(x, fx, fy): + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + Ht, Wt = Ho * fy, Wo * fx + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2) + + +def padRightDownCorner(img, stride, padValue): + h = img.shape[0] + w = img.shape[1] + + pad = 4 * [None] + pad[0] = 0 # up + pad[1] = 0 # left + pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down + pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right + + img_padded = img + pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1)) + img_padded = np.concatenate((pad_up, img_padded), axis=0) + pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1)) + img_padded = np.concatenate((pad_left, img_padded), axis=1) + pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1)) + img_padded = np.concatenate((img_padded, pad_down), axis=0) + pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1)) + img_padded = np.concatenate((img_padded, pad_right), axis=1) + + return img_padded, pad + + +def transfer(model, model_weights): + transfered_model_weights = {} + for weights_name in model.state_dict().keys(): + transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])] + return transfered_model_weights + + +def draw_bodypose(canvas, candidate, subset): + H, W, C = canvas.shape + candidate = np.array(candidate) + subset = np.array(subset) + + stickwidth = 4 + + limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ + [1, 16], [16, 18], [3, 17], [6, 18]] + + colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + + for i in range(17): + for n in range(len(subset)): + index = subset[n][np.array(limbSeq[i]) - 1] + if -1 in index: + continue + Y = candidate[index.astype(int), 0] * float(W) + X = candidate[index.astype(int), 1] * float(H) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + cv2.fillConvexPoly(canvas, polygon, colors[i]) + + canvas = (canvas * 0.6).astype(np.uint8) + + for i in range(18): + for n in range(len(subset)): + index = int(subset[n][i]) + if index == -1: + continue + x, y = candidate[index][0:2] + x = int(x * W) + y = int(y * H) + cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) + + return canvas + + +def draw_handpose(canvas, all_hand_peaks): + H, W, C = canvas.shape + + edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ + [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] + + for peaks in all_hand_peaks: + peaks = np.array(peaks) + + for ie, e in enumerate(edges): + x1, y1 = peaks[e[0]] + x2, y2 = peaks[e[1]] + x1 = int(x1 * W) + y1 = int(y1 * H) + x2 = int(x2 * W) + y2 = int(y2 * H) + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2) + + for i, keyponit in enumerate(peaks): + x, y = keyponit + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) + return canvas + + +def draw_facepose(canvas, all_lmks): + H, W, C = canvas.shape + for lmks in all_lmks: + lmks = np.array(lmks) + for lmk in lmks: + x, y = lmk + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1) + return canvas + + +# detect hand according to body pose keypoints +# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp +def handDetect(candidate, subset, oriImg): + # right hand: wrist 4, elbow 3, shoulder 2 + # left hand: wrist 7, elbow 6, shoulder 5 + ratioWristElbow = 0.33 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + for person in subset.astype(int): + # if any of three not detected + has_left = np.sum(person[[5, 6, 7]] == -1) == 0 + has_right = np.sum(person[[2, 3, 4]] == -1) == 0 + if not (has_left or has_right): + continue + hands = [] + #left hand + if has_left: + left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]] + x1, y1 = candidate[left_shoulder_index][:2] + x2, y2 = candidate[left_elbow_index][:2] + x3, y3 = candidate[left_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, True]) + # right hand + if has_right: + right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]] + x1, y1 = candidate[right_shoulder_index][:2] + x2, y2 = candidate[right_elbow_index][:2] + x3, y3 = candidate[right_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, False]) + + for x1, y1, x2, y2, x3, y3, is_left in hands: + # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox + # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]); + # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]); + # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow); + # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder); + # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder); + x = x3 + ratioWristElbow * (x3 - x2) + y = y3 + ratioWristElbow * (y3 - y2) + distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) + distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) + width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) + # x-y refers to the center --> offset to topLeft point + # handRectangle.x -= handRectangle.width / 2.f; + # handRectangle.y -= handRectangle.height / 2.f; + x -= width / 2 + y -= width / 2 # width = height + # overflow the image + if x < 0: x = 0 + if y < 0: y = 0 + width1 = width + width2 = width + if x + width > image_width: width1 = image_width - x + if y + width > image_height: width2 = image_height - y + width = min(width1, width2) + # the max hand box value is 20 pixels + if width >= 20: + detect_result.append([int(x), int(y), int(width), is_left]) + + ''' + return value: [[x, y, w, True if left hand else False]]. + width=height since the network require squared input. + x, y is the coordinate of top left + ''' + return detect_result + + +# Written by Lvmin +def faceDetect(candidate, subset, oriImg): + # left right eye ear 14 15 16 17 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + for person in subset.astype(int): + has_head = person[0] > -1 + if not has_head: + continue + + has_left_eye = person[14] > -1 + has_right_eye = person[15] > -1 + has_left_ear = person[16] > -1 + has_right_ear = person[17] > -1 + + if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear): + continue + + head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]] + + width = 0.0 + x0, y0 = candidate[head][:2] + + if has_left_eye: + x1, y1 = candidate[left_eye][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if has_right_eye: + x1, y1 = candidate[right_eye][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if has_left_ear: + x1, y1 = candidate[left_ear][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + if has_right_ear: + x1, y1 = candidate[right_ear][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + x, y = x0, y0 + + x -= width + y -= width + + if x < 0: + x = 0 + + if y < 0: + y = 0 + + width1 = width * 2 + width2 = width * 2 + + if x + width > image_width: + width1 = image_width - x + + if y + width > image_height: + width2 = image_height - y + + width = min(width1, width2) + + if width >= 20: + detect_result.append([int(x), int(y), int(width)]) + + return detect_result + + +# get max index of 2d array +def npmax(array): + arrayindex = array.argmax(1) + arrayvalue = array.max(1) + i = arrayvalue.argmax() + j = arrayindex[i] + return i, j diff --git a/third_party/flux/annotator/dwpose/wholebody.py b/third_party/flux/annotator/dwpose/wholebody.py new file mode 100644 index 0000000000000000000000000000000000000000..d73f19d61c238c47cf7de98d01385b2150a5361f --- /dev/null +++ b/third_party/flux/annotator/dwpose/wholebody.py @@ -0,0 +1,48 @@ +import cv2 +import numpy as np + +import onnxruntime as ort +from huggingface_hub import hf_hub_download +from .onnxdet import inference_detector +from .onnxpose import inference_pose + + +class Wholebody: + def __init__(self, device="cuda:0"): + providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider'] + onnx_det = hf_hub_download("yzd-v/DWPose", "yolox_l.onnx") + onnx_pose = hf_hub_download("yzd-v/DWPose", "dw-ll_ucoco_384.onnx") + + self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers) + self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers) + + def __call__(self, oriImg): + det_result = inference_detector(self.session_det, oriImg) + keypoints, scores = inference_pose(self.session_pose, det_result, oriImg) + + keypoints_info = np.concatenate( + (keypoints, scores[..., None]), axis=-1) + # compute neck joint + neck = np.mean(keypoints_info[:, [5, 6]], axis=1) + # neck score when visualizing pred + neck[:, 2:4] = np.logical_and( + keypoints_info[:, 5, 2:4] > 0.3, + keypoints_info[:, 6, 2:4] > 0.3).astype(int) + new_keypoints_info = np.insert( + keypoints_info, 17, neck, axis=1) + mmpose_idx = [ + 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 + ] + openpose_idx = [ + 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 + ] + new_keypoints_info[:, openpose_idx] = \ + new_keypoints_info[:, mmpose_idx] + keypoints_info = new_keypoints_info + + keypoints, scores = keypoints_info[ + ..., :2], keypoints_info[..., 2] + + return keypoints, scores + + diff --git a/third_party/flux/annotator/hed/__init__.py b/third_party/flux/annotator/hed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..70d11c9e62133149d38091a597a1b6691ff8f1b6 --- /dev/null +++ b/third_party/flux/annotator/hed/__init__.py @@ -0,0 +1,95 @@ +# This is an improved version and model of HED edge detection with Apache License, Version 2.0. +# Please use this implementation in your products +# This implementation may produce slightly different results from Saining Xie's official implementations, +# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations. +# Different from official models and other implementations, this is an RGB-input model (rather than BGR) +# and in this way it works better for gradio's RGB protocol + +import os +import cv2 +import torch +import numpy as np + +from huggingface_hub import hf_hub_download +from einops import rearrange +from ...annotator.util import annotator_ckpts_path + + +class DoubleConvBlock(torch.nn.Module): + def __init__(self, input_channel, output_channel, layer_number): + super().__init__() + self.convs = torch.nn.Sequential() + self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) + for i in range(1, layer_number): + self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) + self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0) + + def __call__(self, x, down_sampling=False): + h = x + if down_sampling: + h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2)) + for conv in self.convs: + h = conv(h) + h = torch.nn.functional.relu(h) + return h, self.projection(h) + + +class ControlNetHED_Apache2(torch.nn.Module): + def __init__(self): + super().__init__() + self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1))) + self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2) + self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2) + self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3) + self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3) + self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3) + + def __call__(self, x): + h = x - self.norm + h, projection1 = self.block1(h) + h, projection2 = self.block2(h, down_sampling=True) + h, projection3 = self.block3(h, down_sampling=True) + h, projection4 = self.block4(h, down_sampling=True) + h, projection5 = self.block5(h, down_sampling=True) + return projection1, projection2, projection3, projection4, projection5 + + +class HEDdetector: + def __init__(self): + modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth") + if not os.path.exists(modelpath): + modelpath = hf_hub_download("lllyasviel/Annotators", "ControlNetHED.pth") + self.netNetwork = ControlNetHED_Apache2().float().cuda().eval() + self.netNetwork.load_state_dict(torch.load(modelpath)) + + def __call__(self, input_image): + assert input_image.ndim == 3 + H, W, C = input_image.shape + with torch.no_grad(): + image_hed = torch.from_numpy(input_image.copy()).float().cuda() + image_hed = rearrange(image_hed, 'h w c -> 1 c h w') + edges = self.netNetwork(image_hed) + edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges] + edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges] + edges = np.stack(edges, axis=2) + edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64))) + edge = (edge * 255.0).clip(0, 255).astype(np.uint8) + return edge + + +def nms(x, t, s): + x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) + + f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) + f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) + f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) + f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) + + y = np.zeros_like(x) + + for f in [f1, f2, f3, f4]: + np.putmask(y, cv2.dilate(x, kernel=f) == x, x) + + z = np.zeros_like(y, dtype=np.uint8) + z[y > t] = 255 + return z diff --git a/third_party/flux/annotator/hed/__pycache__/__init__.cpython-310.pyc b/third_party/flux/annotator/hed/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e93afb2a9f585e3500ae94a479e619a5c466873b Binary files /dev/null and b/third_party/flux/annotator/hed/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/flux/annotator/midas/LICENSE b/third_party/flux/annotator/midas/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..277b5c11be103f028a8d10985139f1da10c2f08e --- /dev/null +++ b/third_party/flux/annotator/midas/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/third_party/flux/annotator/midas/__init__.py b/third_party/flux/annotator/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36789767f35bcc169c2cbf096e2747539df4f14d --- /dev/null +++ b/third_party/flux/annotator/midas/__init__.py @@ -0,0 +1,42 @@ +# Midas Depth Estimation +# From https://github.com/isl-org/MiDaS +# MIT LICENSE + +import cv2 +import numpy as np +import torch + +from einops import rearrange +from .api import MiDaSInference + + +class MidasDetector: + def __init__(self): + self.model = MiDaSInference(model_type="dpt_hybrid").cuda() + + def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1): + assert input_image.ndim == 3 + image_depth = input_image + with torch.no_grad(): + image_depth = torch.from_numpy(image_depth).float().cuda() + image_depth = image_depth / 127.5 - 1.0 + image_depth = rearrange(image_depth, 'h w c -> 1 c h w') + depth = self.model(image_depth)[0] + + depth_pt = depth.clone() + depth_pt -= torch.min(depth_pt) + depth_pt /= torch.max(depth_pt) + depth_pt = depth_pt.cpu().numpy() + depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8) + + depth_np = depth.cpu().numpy() + x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3) + y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3) + z = np.ones_like(x) * a + x[depth_pt < bg_th] = 0 + y[depth_pt < bg_th] = 0 + normal = np.stack([x, y, z], axis=2) + normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5 + normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8) + + return depth_image, normal_image diff --git a/third_party/flux/annotator/midas/__pycache__/__init__.cpython-310.pyc b/third_party/flux/annotator/midas/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f88a2aab13ec985215029ad8e4d8b2680e8bfea0 Binary files /dev/null and b/third_party/flux/annotator/midas/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/flux/annotator/midas/__pycache__/api.cpython-310.pyc b/third_party/flux/annotator/midas/__pycache__/api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45d0e8c9ffb371aa738251abf9f501722d9edc07 Binary files /dev/null and b/third_party/flux/annotator/midas/__pycache__/api.cpython-310.pyc differ diff --git a/third_party/flux/annotator/midas/api.py b/third_party/flux/annotator/midas/api.py new file mode 100644 index 0000000000000000000000000000000000000000..6226a39d80de978162a7238cec1c4d4a64bacbe9 --- /dev/null +++ b/third_party/flux/annotator/midas/api.py @@ -0,0 +1,168 @@ +# based on https://github.com/isl-org/MiDaS + +import cv2 +import os +import torch +import torch.nn as nn +from torchvision.transforms import Compose + +from huggingface_hub import hf_hub_download + +from .midas.dpt_depth import DPTDepthModel +from .midas.midas_net import MidasNet +from .midas.midas_net_custom import MidasNet_small +from .midas.transforms import Resize, NormalizeImage, PrepareForNet +from ...annotator.util import annotator_ckpts_path + + +ISL_PATHS = { + "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"), + "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"), + "midas_v21": "", + "midas_v21_small": "", +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def load_midas_transform(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load transform only + if model_type == "dpt_large": # DPT-Large + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + elif model_type == "midas_v21_small": + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + else: + assert False, f"model_type '{model_type}' not implemented, use: --model_type large" + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return transform + + +def load_model(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load network + model_path = ISL_PATHS[model_type] + if model_type == "dpt_large": # DPT-Large + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + if not os.path.exists(model_path): + model_path = hf_hub_download("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt") + + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return model.eval(), transform + + +class MiDaSInference(nn.Module): + MODEL_TYPES_TORCH_HUB = [ + "DPT_Large", + "DPT_Hybrid", + "MiDaS_small" + ] + MODEL_TYPES_ISL = [ + "dpt_large", + "dpt_hybrid", + "midas_v21", + "midas_v21_small", + ] + + def __init__(self, model_type): + super().__init__() + assert (model_type in self.MODEL_TYPES_ISL) + model, _ = load_model(model_type) + self.model = model + self.model.train = disabled_train + + def forward(self, x): + with torch.no_grad(): + prediction = self.model(x) + return prediction + diff --git a/third_party/flux/annotator/midas/midas/__init__.py b/third_party/flux/annotator/midas/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/flux/annotator/midas/midas/__pycache__/__init__.cpython-310.pyc b/third_party/flux/annotator/midas/midas/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42b12bfc2daa55fcc6d541002e6704c2a6c896d7 Binary files /dev/null and b/third_party/flux/annotator/midas/midas/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/flux/annotator/midas/midas/__pycache__/base_model.cpython-310.pyc b/third_party/flux/annotator/midas/midas/__pycache__/base_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb73a9deb6d4e599a91fffb401ff06d24752306a Binary files /dev/null and b/third_party/flux/annotator/midas/midas/__pycache__/base_model.cpython-310.pyc differ diff --git a/third_party/flux/annotator/midas/midas/__pycache__/blocks.cpython-310.pyc b/third_party/flux/annotator/midas/midas/__pycache__/blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30b672c7d4a0a6f91141be430e93f767b6019c86 Binary files /dev/null and b/third_party/flux/annotator/midas/midas/__pycache__/blocks.cpython-310.pyc differ diff --git a/third_party/flux/annotator/midas/midas/__pycache__/dpt_depth.cpython-310.pyc b/third_party/flux/annotator/midas/midas/__pycache__/dpt_depth.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9515765ccfa7c363d8db36de4528fde761c6645f Binary files /dev/null and b/third_party/flux/annotator/midas/midas/__pycache__/dpt_depth.cpython-310.pyc differ diff --git a/third_party/flux/annotator/midas/midas/__pycache__/midas_net.cpython-310.pyc b/third_party/flux/annotator/midas/midas/__pycache__/midas_net.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..961a9bc4736eedad970cd9e4022e9ccea82809c9 Binary files /dev/null and b/third_party/flux/annotator/midas/midas/__pycache__/midas_net.cpython-310.pyc differ diff --git a/third_party/flux/annotator/midas/midas/__pycache__/midas_net_custom.cpython-310.pyc b/third_party/flux/annotator/midas/midas/__pycache__/midas_net_custom.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f60678bcb38692347fd8a09d3504f25c560f965a Binary files /dev/null and b/third_party/flux/annotator/midas/midas/__pycache__/midas_net_custom.cpython-310.pyc differ diff --git a/third_party/flux/annotator/midas/midas/__pycache__/transforms.cpython-310.pyc b/third_party/flux/annotator/midas/midas/__pycache__/transforms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43e18485af45255a58ec2762347581730ede27e6 Binary files /dev/null and b/third_party/flux/annotator/midas/midas/__pycache__/transforms.cpython-310.pyc differ diff --git a/third_party/flux/annotator/midas/midas/__pycache__/vit.cpython-310.pyc b/third_party/flux/annotator/midas/midas/__pycache__/vit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b54edc61f4fde1132a7c697b4195c0fa7741928a Binary files /dev/null and b/third_party/flux/annotator/midas/midas/__pycache__/vit.cpython-310.pyc differ diff --git a/third_party/flux/annotator/midas/midas/base_model.py b/third_party/flux/annotator/midas/midas/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd --- /dev/null +++ b/third_party/flux/annotator/midas/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/third_party/flux/annotator/midas/midas/blocks.py b/third_party/flux/annotator/midas/midas/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78 --- /dev/null +++ b/third_party/flux/annotator/midas/midas/blocks.py @@ -0,0 +1,342 @@ +import torch +import torch.nn as nn + +from .vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand==True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/third_party/flux/annotator/midas/midas/dpt_depth.py b/third_party/flux/annotator/midas/midas/dpt_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1 --- /dev/null +++ b/third_party/flux/annotator/midas/midas/dpt_depth.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) + diff --git a/third_party/flux/annotator/midas/midas/midas_net.py b/third_party/flux/annotator/midas/midas/midas_net.py new file mode 100644 index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f --- /dev/null +++ b/third_party/flux/annotator/midas/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/third_party/flux/annotator/midas/midas/midas_net_custom.py b/third_party/flux/annotator/midas/midas/midas_net_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c --- /dev/null +++ b/third_party/flux/annotator/midas/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/third_party/flux/annotator/midas/midas/transforms.py b/third_party/flux/annotator/midas/midas/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- /dev/null +++ b/third_party/flux/annotator/midas/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/third_party/flux/annotator/midas/midas/vit.py b/third_party/flux/annotator/midas/midas/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74 --- /dev/null +++ b/third_party/flux/annotator/midas/midas/vit.py @@ -0,0 +1,491 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/third_party/flux/annotator/midas/utils.py b/third_party/flux/annotator/midas/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59 --- /dev/null +++ b/third_party/flux/annotator/midas/utils.py @@ -0,0 +1,189 @@ +"""Utils for monoDepth.""" +import sys +import re +import numpy as np +import cv2 +import torch + + +def read_pfm(path): + """Read pfm file. + + Args: + path (str): path to file + + Returns: + tuple: (data, scale) + """ + with open(path, "rb") as file: + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file: " + path) + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + # little-endian + endian = "<" + scale = -scale + else: + # big-endian + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + + return data, scale + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + + +def resize_image(img): + """Resize image and make it fit for network. + + Args: + img (array): image + + Returns: + tensor: data ready for network + """ + height_orig = img.shape[0] + width_orig = img.shape[1] + + if width_orig > height_orig: + scale = width_orig / 384 + else: + scale = height_orig / 384 + + height = (np.ceil(height_orig / scale / 32) * 32).astype(int) + width = (np.ceil(width_orig / scale / 32) * 32).astype(int) + + img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) + + img_resized = ( + torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() + ) + img_resized = img_resized.unsqueeze(0) + + return img_resized + + +def resize_depth(depth, width, height): + """Resize depth map and bring to CPU (numpy). + + Args: + depth (tensor): depth + width (int): image width + height (int): image height + + Returns: + array: processed depth + """ + depth = torch.squeeze(depth[0, :, :, :]).to("cpu") + + depth_resized = cv2.resize( + depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC + ) + + return depth_resized + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.type) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return diff --git a/third_party/flux/annotator/mlsd/LICENSE b/third_party/flux/annotator/mlsd/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..d855c6db44b4e873eedd750d34fa2eaf22e22363 --- /dev/null +++ b/third_party/flux/annotator/mlsd/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2021-present NAVER Corp. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/third_party/flux/annotator/mlsd/__init__.py b/third_party/flux/annotator/mlsd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5028aef051a1e67caae1f0d23a0b0dbca883a7f8 --- /dev/null +++ b/third_party/flux/annotator/mlsd/__init__.py @@ -0,0 +1,40 @@ +# MLSD Line Detection +# From https://github.com/navervision/mlsd +# Apache-2.0 license + +import cv2 +import numpy as np +import torch +import os + +from einops import rearrange +from huggingface_hub import hf_hub_download +from .models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny +from .models.mbv2_mlsd_large import MobileV2_MLSD_Large +from .utils import pred_lines + +from ...annotator.util import annotator_ckpts_path + + +class MLSDdetector: + def __init__(self): + model_path = os.path.join(annotator_ckpts_path, "mlsd_large_512_fp32.pth") + if not os.path.exists(model_path): + model_path = hf_hub_download("lllyasviel/Annotators", "mlsd_large_512_fp32.pth") + model = MobileV2_MLSD_Large() + model.load_state_dict(torch.load(model_path), strict=True) + self.model = model.cuda().eval() + + def __call__(self, input_image, thr_v, thr_d): + assert input_image.ndim == 3 + img = input_image + img_output = np.zeros_like(img) + try: + with torch.no_grad(): + lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d) + for line in lines: + x_start, y_start, x_end, y_end = [int(val) for val in line] + cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1) + except Exception as e: + pass + return img_output[:, :, 0] diff --git a/third_party/flux/annotator/mlsd/__pycache__/__init__.cpython-310.pyc b/third_party/flux/annotator/mlsd/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43154fc3f8dc6fd62ccb55167e4e1366c93c4ec8 Binary files /dev/null and b/third_party/flux/annotator/mlsd/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/flux/annotator/mlsd/__pycache__/utils.cpython-310.pyc b/third_party/flux/annotator/mlsd/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11313581d9a7f41e1d04eb56c2391af353537bfd Binary files /dev/null and b/third_party/flux/annotator/mlsd/__pycache__/utils.cpython-310.pyc differ diff --git a/third_party/flux/annotator/mlsd/models/__pycache__/mbv2_mlsd_large.cpython-310.pyc b/third_party/flux/annotator/mlsd/models/__pycache__/mbv2_mlsd_large.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1442f5b10548f750e33346a4f0533ddc81426681 Binary files /dev/null and b/third_party/flux/annotator/mlsd/models/__pycache__/mbv2_mlsd_large.cpython-310.pyc differ diff --git a/third_party/flux/annotator/mlsd/models/__pycache__/mbv2_mlsd_tiny.cpython-310.pyc b/third_party/flux/annotator/mlsd/models/__pycache__/mbv2_mlsd_tiny.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0d853c08518a17cbf9debf3a6d80124c2f6eb0e Binary files /dev/null and b/third_party/flux/annotator/mlsd/models/__pycache__/mbv2_mlsd_tiny.cpython-310.pyc differ diff --git a/third_party/flux/annotator/mlsd/models/mbv2_mlsd_large.py b/third_party/flux/annotator/mlsd/models/mbv2_mlsd_large.py new file mode 100644 index 0000000000000000000000000000000000000000..5b9799e7573ca41549b3c3b13ac47b906b369603 --- /dev/null +++ b/third_party/flux/annotator/mlsd/models/mbv2_mlsd_large.py @@ -0,0 +1,292 @@ +import os +import sys +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +from torch.nn import functional as F + + +class BlockTypeA(nn.Module): + def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True): + super(BlockTypeA, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c2, out_c2, kernel_size=1), + nn.BatchNorm2d(out_c2), + nn.ReLU(inplace=True) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c1, out_c1, kernel_size=1), + nn.BatchNorm2d(out_c1), + nn.ReLU(inplace=True) + ) + self.upscale = upscale + + def forward(self, a, b): + b = self.conv1(b) + a = self.conv2(a) + if self.upscale: + b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True) + return torch.cat((a, b), dim=1) + + +class BlockTypeB(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeB, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), + nn.BatchNorm2d(out_c), + nn.ReLU() + ) + + def forward(self, x): + x = self.conv1(x) + x + x = self.conv2(x) + return x + +class BlockTypeC(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeC, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + return x + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNReLU(nn.Sequential): + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + self.channel_pad = out_planes - in_planes + self.stride = stride + #padding = (kernel_size - 1) // 2 + + # TFLite uses slightly different padding than PyTorch + if stride == 2: + padding = 0 + else: + padding = (kernel_size - 1) // 2 + + super(ConvBNReLU, self).__init__( + nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), + nn.BatchNorm2d(out_planes), + nn.ReLU6(inplace=True) + ) + self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) + + + def forward(self, x): + # TFLite uses different padding + if self.stride == 2: + x = F.pad(x, (0, 1, 0, 1), "constant", 0) + #print(x.shape) + + for module in self: + if not isinstance(module, nn.MaxPool2d): + x = module(x) + return x + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + # pw + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__(self, pretrained=True): + """ + MobileNet V2 main class + Args: + num_classes (int): Number of classes + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for mobilenet + """ + super(MobileNetV2, self).__init__() + + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + width_mult = 1.0 + round_nearest = 8 + + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + #[6, 160, 3, 2], + #[6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: + raise ValueError("inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting)) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features = [ConvBNReLU(4, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + + self.features = nn.Sequential(*features) + self.fpn_selected = [1, 3, 6, 10, 13] + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + if pretrained: + self._load_pretrained_model() + + def _forward_impl(self, x): + # This exists since TorchScript doesn't support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass + fpn_features = [] + for i, f in enumerate(self.features): + if i > self.fpn_selected[-1]: + break + x = f(x) + if i in self.fpn_selected: + fpn_features.append(x) + + c1, c2, c3, c4, c5 = fpn_features + return c1, c2, c3, c4, c5 + + + def forward(self, x): + return self._forward_impl(x) + + def _load_pretrained_model(self): + pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth') + model_dict = {} + state_dict = self.state_dict() + for k, v in pretrain_dict.items(): + if k in state_dict: + model_dict[k] = v + state_dict.update(model_dict) + self.load_state_dict(state_dict) + + +class MobileV2_MLSD_Large(nn.Module): + def __init__(self): + super(MobileV2_MLSD_Large, self).__init__() + + self.backbone = MobileNetV2(pretrained=False) + ## A, B + self.block15 = BlockTypeA(in_c1= 64, in_c2= 96, + out_c1= 64, out_c2=64, + upscale=False) + self.block16 = BlockTypeB(128, 64) + + ## A, B + self.block17 = BlockTypeA(in_c1 = 32, in_c2 = 64, + out_c1= 64, out_c2= 64) + self.block18 = BlockTypeB(128, 64) + + ## A, B + self.block19 = BlockTypeA(in_c1=24, in_c2=64, + out_c1=64, out_c2=64) + self.block20 = BlockTypeB(128, 64) + + ## A, B, C + self.block21 = BlockTypeA(in_c1=16, in_c2=64, + out_c1=64, out_c2=64) + self.block22 = BlockTypeB(128, 64) + + self.block23 = BlockTypeC(64, 16) + + def forward(self, x): + c1, c2, c3, c4, c5 = self.backbone(x) + + x = self.block15(c4, c5) + x = self.block16(x) + + x = self.block17(c3, x) + x = self.block18(x) + + x = self.block19(c2, x) + x = self.block20(x) + + x = self.block21(c1, x) + x = self.block22(x) + x = self.block23(x) + x = x[:, 7:, :, :] + + return x \ No newline at end of file diff --git a/third_party/flux/annotator/mlsd/models/mbv2_mlsd_tiny.py b/third_party/flux/annotator/mlsd/models/mbv2_mlsd_tiny.py new file mode 100644 index 0000000000000000000000000000000000000000..e3ed633f2cc23ea1829a627fdb879ab39f641f83 --- /dev/null +++ b/third_party/flux/annotator/mlsd/models/mbv2_mlsd_tiny.py @@ -0,0 +1,275 @@ +import os +import sys +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +from torch.nn import functional as F + + +class BlockTypeA(nn.Module): + def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True): + super(BlockTypeA, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c2, out_c2, kernel_size=1), + nn.BatchNorm2d(out_c2), + nn.ReLU(inplace=True) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c1, out_c1, kernel_size=1), + nn.BatchNorm2d(out_c1), + nn.ReLU(inplace=True) + ) + self.upscale = upscale + + def forward(self, a, b): + b = self.conv1(b) + a = self.conv2(a) + b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True) + return torch.cat((a, b), dim=1) + + +class BlockTypeB(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeB, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), + nn.BatchNorm2d(out_c), + nn.ReLU() + ) + + def forward(self, x): + x = self.conv1(x) + x + x = self.conv2(x) + return x + +class BlockTypeC(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeC, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + return x + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNReLU(nn.Sequential): + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + self.channel_pad = out_planes - in_planes + self.stride = stride + #padding = (kernel_size - 1) // 2 + + # TFLite uses slightly different padding than PyTorch + if stride == 2: + padding = 0 + else: + padding = (kernel_size - 1) // 2 + + super(ConvBNReLU, self).__init__( + nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), + nn.BatchNorm2d(out_planes), + nn.ReLU6(inplace=True) + ) + self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) + + + def forward(self, x): + # TFLite uses different padding + if self.stride == 2: + x = F.pad(x, (0, 1, 0, 1), "constant", 0) + #print(x.shape) + + for module in self: + if not isinstance(module, nn.MaxPool2d): + x = module(x) + return x + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + # pw + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__(self, pretrained=True): + """ + MobileNet V2 main class + Args: + num_classes (int): Number of classes + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for mobilenet + """ + super(MobileNetV2, self).__init__() + + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + width_mult = 1.0 + round_nearest = 8 + + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + #[6, 96, 3, 1], + #[6, 160, 3, 2], + #[6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: + raise ValueError("inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting)) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features = [ConvBNReLU(4, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + self.features = nn.Sequential(*features) + + self.fpn_selected = [3, 6, 10] + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + + #if pretrained: + # self._load_pretrained_model() + + def _forward_impl(self, x): + # This exists since TorchScript doesn't support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass + fpn_features = [] + for i, f in enumerate(self.features): + if i > self.fpn_selected[-1]: + break + x = f(x) + if i in self.fpn_selected: + fpn_features.append(x) + + c2, c3, c4 = fpn_features + return c2, c3, c4 + + + def forward(self, x): + return self._forward_impl(x) + + def _load_pretrained_model(self): + pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth') + model_dict = {} + state_dict = self.state_dict() + for k, v in pretrain_dict.items(): + if k in state_dict: + model_dict[k] = v + state_dict.update(model_dict) + self.load_state_dict(state_dict) + + +class MobileV2_MLSD_Tiny(nn.Module): + def __init__(self): + super(MobileV2_MLSD_Tiny, self).__init__() + + self.backbone = MobileNetV2(pretrained=True) + + self.block12 = BlockTypeA(in_c1= 32, in_c2= 64, + out_c1= 64, out_c2=64) + self.block13 = BlockTypeB(128, 64) + + self.block14 = BlockTypeA(in_c1 = 24, in_c2 = 64, + out_c1= 32, out_c2= 32) + self.block15 = BlockTypeB(64, 64) + + self.block16 = BlockTypeC(64, 16) + + def forward(self, x): + c2, c3, c4 = self.backbone(x) + + x = self.block12(c3, c4) + x = self.block13(x) + x = self.block14(c2, x) + x = self.block15(x) + x = self.block16(x) + x = x[:, 7:, :, :] + #print(x.shape) + x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True) + + return x \ No newline at end of file diff --git a/third_party/flux/annotator/mlsd/utils.py b/third_party/flux/annotator/mlsd/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..848a9fd7f9ff9d909f18c5d3ff55786c5a4b547a --- /dev/null +++ b/third_party/flux/annotator/mlsd/utils.py @@ -0,0 +1,580 @@ +''' +modified by lihaoweicv +pytorch version +''' + +''' +M-LSD +Copyright 2021-present NAVER Corp. +Apache License v2.0 +''' + +import os +import numpy as np +import cv2 +import torch +from torch.nn import functional as F + + +def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5): + ''' + tpMap: + center: tpMap[1, 0, :, :] + displacement: tpMap[1, 1:5, :, :] + ''' + b, c, h, w = tpMap.shape + assert b==1, 'only support bsize==1' + displacement = tpMap[:, 1:5, :, :][0] + center = tpMap[:, 0, :, :] + heat = torch.sigmoid(center) + hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2) + keep = (hmax == heat).float() + heat = heat * keep + heat = heat.reshape(-1, ) + + scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True) + yy = torch.floor_divide(indices, w).unsqueeze(-1) + xx = torch.fmod(indices, w).unsqueeze(-1) + ptss = torch.cat((yy, xx),dim=-1) + + ptss = ptss.detach().cpu().numpy() + scores = scores.detach().cpu().numpy() + displacement = displacement.detach().cpu().numpy() + displacement = displacement.transpose((1,2,0)) + return ptss, scores, displacement + + +def pred_lines(image, model, + input_shape=[512, 512], + score_thr=0.10, + dist_thr=20.0): + h, w, _ = image.shape + h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]] + + resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA), + np.ones([input_shape[0], input_shape[1], 1])], axis=-1) + + resized_image = resized_image.transpose((2,0,1)) + batch_image = np.expand_dims(resized_image, axis=0).astype('float32') + batch_image = (batch_image / 127.5) - 1.0 + + batch_image = torch.from_numpy(batch_image).float().to("cuda:4") + outputs = model(batch_image) + pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3) + start = vmap[:, :, :2] + end = vmap[:, :, 2:] + dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1)) + + segments_list = [] + for center, score in zip(pts, pts_score): + y, x = center + distance = dist_map[y, x] + if score > score_thr and distance > dist_thr: + disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :] + x_start = x + disp_x_start + y_start = y + disp_y_start + x_end = x + disp_x_end + y_end = y + disp_y_end + segments_list.append([x_start, y_start, x_end, y_end]) + + lines = 2 * np.array(segments_list) # 256 > 512 + lines[:, 0] = lines[:, 0] * w_ratio + lines[:, 1] = lines[:, 1] * h_ratio + lines[:, 2] = lines[:, 2] * w_ratio + lines[:, 3] = lines[:, 3] * h_ratio + + return lines + + +def pred_squares(image, + model, + input_shape=[512, 512], + params={'score': 0.06, + 'outside_ratio': 0.28, + 'inside_ratio': 0.45, + 'w_overlap': 0.0, + 'w_degree': 1.95, + 'w_length': 0.0, + 'w_area': 1.86, + 'w_center': 0.14}): + ''' + shape = [height, width] + ''' + h, w, _ = image.shape + original_shape = [h, w] + + resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA), + np.ones([input_shape[0], input_shape[1], 1])], axis=-1) + resized_image = resized_image.transpose((2, 0, 1)) + batch_image = np.expand_dims(resized_image, axis=0).astype('float32') + batch_image = (batch_image / 127.5) - 1.0 + + batch_image = torch.from_numpy(batch_image).float().cuda() + outputs = model(batch_image) + + pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3) + start = vmap[:, :, :2] # (x, y) + end = vmap[:, :, 2:] # (x, y) + dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1)) + + junc_list = [] + segments_list = [] + for junc, score in zip(pts, pts_score): + y, x = junc + distance = dist_map[y, x] + if score > params['score'] and distance > 20.0: + junc_list.append([x, y]) + disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :] + d_arrow = 1.0 + x_start = x + d_arrow * disp_x_start + y_start = y + d_arrow * disp_y_start + x_end = x + d_arrow * disp_x_end + y_end = y + d_arrow * disp_y_end + segments_list.append([x_start, y_start, x_end, y_end]) + + segments = np.array(segments_list) + + ####### post processing for squares + # 1. get unique lines + point = np.array([[0, 0]]) + point = point[0] + start = segments[:, :2] + end = segments[:, 2:] + diff = start - end + a = diff[:, 1] + b = -diff[:, 0] + c = a * start[:, 0] + b * start[:, 1] + + d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10) + theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi + theta[theta < 0.0] += 180 + hough = np.concatenate([d[:, None], theta[:, None]], axis=-1) + + d_quant = 1 + theta_quant = 2 + hough[:, 0] //= d_quant + hough[:, 1] //= theta_quant + _, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True) + + acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32') + idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1 + yx_indices = hough[indices, :].astype('int32') + acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts + idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices + + acc_map_np = acc_map + # acc_map = acc_map[None, :, :, None] + # + # ### fast suppression using tensorflow op + # acc_map = tf.constant(acc_map, dtype=tf.float32) + # max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map) + # acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32) + # flatten_acc_map = tf.reshape(acc_map, [1, -1]) + # topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts)) + # _, h, w, _ = acc_map.shape + # y = tf.expand_dims(topk_indices // w, axis=-1) + # x = tf.expand_dims(topk_indices % w, axis=-1) + # yx = tf.concat([y, x], axis=-1) + + ### fast suppression using pytorch op + acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0) + _,_, h, w = acc_map.shape + max_acc_map = F.max_pool2d(acc_map,kernel_size=5, stride=1, padding=2) + acc_map = acc_map * ( (acc_map == max_acc_map).float() ) + flatten_acc_map = acc_map.reshape([-1, ]) + + scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True) + yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1) + xx = torch.fmod(indices, w).unsqueeze(-1) + yx = torch.cat((yy, xx), dim=-1) + + yx = yx.detach().cpu().numpy() + + topk_values = scores.detach().cpu().numpy() + indices = idx_map[yx[:, 0], yx[:, 1]] + basis = 5 // 2 + + merged_segments = [] + for yx_pt, max_indice, value in zip(yx, indices, topk_values): + y, x = yx_pt + if max_indice == -1 or value == 0: + continue + segment_list = [] + for y_offset in range(-basis, basis + 1): + for x_offset in range(-basis, basis + 1): + indice = idx_map[y + y_offset, x + x_offset] + cnt = int(acc_map_np[y + y_offset, x + x_offset]) + if indice != -1: + segment_list.append(segments[indice]) + if cnt > 1: + check_cnt = 1 + current_hough = hough[indice] + for new_indice, new_hough in enumerate(hough): + if (current_hough == new_hough).all() and indice != new_indice: + segment_list.append(segments[new_indice]) + check_cnt += 1 + if check_cnt == cnt: + break + group_segments = np.array(segment_list).reshape([-1, 2]) + sorted_group_segments = np.sort(group_segments, axis=0) + x_min, y_min = sorted_group_segments[0, :] + x_max, y_max = sorted_group_segments[-1, :] + + deg = theta[max_indice] + if deg >= 90: + merged_segments.append([x_min, y_max, x_max, y_min]) + else: + merged_segments.append([x_min, y_min, x_max, y_max]) + + # 2. get intersections + new_segments = np.array(merged_segments) # (x1, y1, x2, y2) + start = new_segments[:, :2] # (x1, y1) + end = new_segments[:, 2:] # (x2, y2) + new_centers = (start + end) / 2.0 + diff = start - end + dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1)) + + # ax + by = c + a = diff[:, 1] + b = -diff[:, 0] + c = a * start[:, 0] + b * start[:, 1] + pre_det = a[:, None] * b[None, :] + det = pre_det - np.transpose(pre_det) + + pre_inter_y = a[:, None] * c[None, :] + inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10) + pre_inter_x = c[:, None] * b[None, :] + inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10) + inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32') + + # 3. get corner information + # 3.1 get distance + ''' + dist_segments: + | dist(0), dist(1), dist(2), ...| + dist_inter_to_segment1: + | dist(inter,0), dist(inter,0), dist(inter,0), ... | + | dist(inter,1), dist(inter,1), dist(inter,1), ... | + ... + dist_inter_to_semgnet2: + | dist(inter,0), dist(inter,1), dist(inter,2), ... | + | dist(inter,0), dist(inter,1), dist(inter,2), ... | + ... + ''' + + dist_inter_to_segment1_start = np.sqrt( + np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment1_end = np.sqrt( + np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment2_start = np.sqrt( + np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment2_end = np.sqrt( + np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + + # sort ascending + dist_inter_to_segment1 = np.sort( + np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1), + axis=-1) # [n_batch, n_batch, 2] + dist_inter_to_segment2 = np.sort( + np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1), + axis=-1) # [n_batch, n_batch, 2] + + # 3.2 get degree + inter_to_start = new_centers[:, None, :] - inter_pts + deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi + deg_inter_to_start[deg_inter_to_start < 0.0] += 360 + inter_to_end = new_centers[None, :, :] - inter_pts + deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi + deg_inter_to_end[deg_inter_to_end < 0.0] += 360 + + ''' + B -- G + | | + C -- R + B : blue / G: green / C: cyan / R: red + + 0 -- 1 + | | + 3 -- 2 + ''' + # rename variables + deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end + # sort deg ascending + deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1) + + deg_diff_map = np.abs(deg1_map - deg2_map) + # we only consider the smallest degree of intersect + deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180] + + # define available degree range + deg_range = [60, 120] + + corner_dict = {corner_info: [] for corner_info in range(4)} + inter_points = [] + for i in range(inter_pts.shape[0]): + for j in range(i + 1, inter_pts.shape[1]): + # i, j > line index, always i < j + x, y = inter_pts[i, j, :] + deg1, deg2 = deg_sort[i, j, :] + deg_diff = deg_diff_map[i, j] + + check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1] + + outside_ratio = params['outside_ratio'] # over ratio >>> drop it! + inside_ratio = params['inside_ratio'] # over ratio >>> drop it! + check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \ + dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \ + (dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \ + dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \ + ((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \ + dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \ + (dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \ + dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio)) + + if check_degree and check_distance: + corner_info = None + + if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \ + (deg2 >= 315 and deg1 >= 45 and deg1 <= 120): + corner_info, color_info = 0, 'blue' + elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225): + corner_info, color_info = 1, 'green' + elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315): + corner_info, color_info = 2, 'black' + elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \ + (deg2 >= 315 and deg1 >= 225 and deg1 <= 315): + corner_info, color_info = 3, 'cyan' + else: + corner_info, color_info = 4, 'red' # we don't use it + continue + + corner_dict[corner_info].append([x, y, i, j]) + inter_points.append([x, y]) + + square_list = [] + connect_list = [] + segments_list = [] + for corner0 in corner_dict[0]: + for corner1 in corner_dict[1]: + connect01 = False + for corner0_line in corner0[2:]: + if corner0_line in corner1[2:]: + connect01 = True + break + if connect01: + for corner2 in corner_dict[2]: + connect12 = False + for corner1_line in corner1[2:]: + if corner1_line in corner2[2:]: + connect12 = True + break + if connect12: + for corner3 in corner_dict[3]: + connect23 = False + for corner2_line in corner2[2:]: + if corner2_line in corner3[2:]: + connect23 = True + break + if connect23: + for corner3_line in corner3[2:]: + if corner3_line in corner0[2:]: + # SQUARE!!! + ''' + 0 -- 1 + | | + 3 -- 2 + square_list: + order: 0 > 1 > 2 > 3 + | x0, y0, x1, y1, x2, y2, x3, y3 | + | x0, y0, x1, y1, x2, y2, x3, y3 | + ... + connect_list: + order: 01 > 12 > 23 > 30 + | line_idx01, line_idx12, line_idx23, line_idx30 | + | line_idx01, line_idx12, line_idx23, line_idx30 | + ... + segments_list: + order: 0 > 1 > 2 > 3 + | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j | + | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j | + ... + ''' + square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2]) + connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line]) + segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:]) + + def check_outside_inside(segments_info, connect_idx): + # return 'outside or inside', min distance, cover_param, peri_param + if connect_idx == segments_info[0]: + check_dist_mat = dist_inter_to_segment1 + else: + check_dist_mat = dist_inter_to_segment2 + + i, j = segments_info + min_dist, max_dist = check_dist_mat[i, j, :] + connect_dist = dist_segments[connect_idx] + if max_dist > connect_dist: + return 'outside', min_dist, 0, 1 + else: + return 'inside', min_dist, -1, -1 + + top_square = None + + try: + map_size = input_shape[0] / 2 + squares = np.array(square_list).reshape([-1, 4, 2]) + score_array = [] + connect_array = np.array(connect_list) + segments_array = np.array(segments_list).reshape([-1, 4, 2]) + + # get degree of corners: + squares_rollup = np.roll(squares, 1, axis=1) + squares_rolldown = np.roll(squares, -1, axis=1) + vec1 = squares_rollup - squares + normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10) + vec2 = squares_rolldown - squares + normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10) + inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1) # [n_squares, 4] + squares_degree = np.arccos(inner_products) * 180 / np.pi # [n_squares, 4] + + # get square score + overlap_scores = [] + degree_scores = [] + length_scores = [] + + for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree): + ''' + 0 -- 1 + | | + 3 -- 2 + + # segments: [4, 2] + # connects: [4] + ''' + + ###################################### OVERLAP SCORES + cover = 0 + perimeter = 0 + # check 0 > 1 > 2 > 3 + square_length = [] + + for start_idx in range(4): + end_idx = (start_idx + 1) % 4 + + connect_idx = connects[start_idx] # segment idx of segment01 + start_segments = segments[start_idx] + end_segments = segments[end_idx] + + start_point = square[start_idx] + end_point = square[end_idx] + + # check whether outside or inside + start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments, + connect_idx) + end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx) + + cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min + perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min + + square_length.append( + dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min) + + overlap_scores.append(cover / perimeter) + ###################################### + ###################################### DEGREE SCORES + ''' + deg0 vs deg2 + deg1 vs deg3 + ''' + deg0, deg1, deg2, deg3 = degree + deg_ratio1 = deg0 / deg2 + if deg_ratio1 > 1.0: + deg_ratio1 = 1 / deg_ratio1 + deg_ratio2 = deg1 / deg3 + if deg_ratio2 > 1.0: + deg_ratio2 = 1 / deg_ratio2 + degree_scores.append((deg_ratio1 + deg_ratio2) / 2) + ###################################### + ###################################### LENGTH SCORES + ''' + len0 vs len2 + len1 vs len3 + ''' + len0, len1, len2, len3 = square_length + len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0 + len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1 + length_scores.append((len_ratio1 + len_ratio2) / 2) + + ###################################### + + overlap_scores = np.array(overlap_scores) + overlap_scores /= np.max(overlap_scores) + + degree_scores = np.array(degree_scores) + # degree_scores /= np.max(degree_scores) + + length_scores = np.array(length_scores) + + ###################################### AREA SCORES + area_scores = np.reshape(squares, [-1, 4, 2]) + area_x = area_scores[:, :, 0] + area_y = area_scores[:, :, 1] + correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0] + area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1) + area_scores = 0.5 * np.abs(area_scores + correction) + area_scores /= (map_size * map_size) # np.max(area_scores) + ###################################### + + ###################################### CENTER SCORES + centers = np.array([[256 // 2, 256 // 2]], dtype='float32') # [1, 2] + # squares: [n, 4, 2] + square_centers = np.mean(squares, axis=1) # [n, 2] + center2center = np.sqrt(np.sum((centers - square_centers) ** 2)) + center_scores = center2center / (map_size / np.sqrt(2.0)) + + ''' + score_w = [overlap, degree, area, center, length] + ''' + score_w = [0.0, 1.0, 10.0, 0.5, 1.0] + score_array = params['w_overlap'] * overlap_scores \ + + params['w_degree'] * degree_scores \ + + params['w_area'] * area_scores \ + - params['w_center'] * center_scores \ + + params['w_length'] * length_scores + + best_square = [] + + sorted_idx = np.argsort(score_array)[::-1] + score_array = score_array[sorted_idx] + squares = squares[sorted_idx] + + except Exception as e: + pass + + '''return list + merged_lines, squares, scores + ''' + + try: + new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1] + new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0] + new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1] + new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0] + except: + new_segments = [] + + try: + squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1] + squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0] + except: + squares = [] + score_array = [] + + try: + inter_points = np.array(inter_points) + inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1] + inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0] + except: + inter_points = [] + + return new_segments, squares, score_array, inter_points diff --git a/third_party/flux/annotator/tile/__init__.py b/third_party/flux/annotator/tile/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c96899289d6e04796140cd7eff9c08e5f693af02 --- /dev/null +++ b/third_party/flux/annotator/tile/__init__.py @@ -0,0 +1,26 @@ +import random +import cv2 +from .guided_filter import FastGuidedFilter + + +class TileDetector: + # https://huggingface.co/xinsir/controlnet-tile-sdxl-1.0 + def __init__(self): + pass + + def __call__(self, image): + blur_strength = random.sample([i / 10. for i in range(10, 201, 2)], k=1)[0] + radius = random.sample([i for i in range(1, 40, 2)], k=1)[0] + eps = random.sample([i / 1000. for i in range(1, 101, 2)], k=1)[0] + scale_factor = random.sample([i / 10. for i in range(10, 181, 5)], k=1)[0] + + ksize = int(blur_strength) + if ksize % 2 == 0: + ksize += 1 + + if random.random() > 0.5: + image = cv2.GaussianBlur(image, (ksize, ksize), blur_strength / 2) + if random.random() > 0.5: + filter = FastGuidedFilter(image, radius, eps, scale_factor) + image = filter.filter(image) + return image diff --git a/third_party/flux/annotator/tile/__pycache__/__init__.cpython-310.pyc b/third_party/flux/annotator/tile/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39b58b72cc79bc08e31c0438d05ad70bbe61dca4 Binary files /dev/null and b/third_party/flux/annotator/tile/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/flux/annotator/tile/__pycache__/guided_filter.cpython-310.pyc b/third_party/flux/annotator/tile/__pycache__/guided_filter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..416f88ac5dcee11ab5092a46e9c17d1bccd61c5d Binary files /dev/null and b/third_party/flux/annotator/tile/__pycache__/guided_filter.cpython-310.pyc differ diff --git a/third_party/flux/annotator/tile/guided_filter.py b/third_party/flux/annotator/tile/guided_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..3d7172a5e144672eea26551ef75f70b90a2f96d6 --- /dev/null +++ b/third_party/flux/annotator/tile/guided_filter.py @@ -0,0 +1,280 @@ +# -*- coding: utf-8 -*- +## @package guided_filter.core.filters +# +# Implementation of guided filter. +# * GuidedFilter: Original guided filter. +# * FastGuidedFilter: Fast version of the guided filter. +# @author tody +# @date 2015/08/26 + +import numpy as np +import cv2 + +## Convert image into float32 type. +def to32F(img): + if img.dtype == np.float32: + return img + return (1.0 / 255.0) * np.float32(img) + +## Convert image into uint8 type. +def to8U(img): + if img.dtype == np.uint8: + return img + return np.clip(np.uint8(255.0 * img), 0, 255) + +## Return if the input image is gray or not. +def _isGray(I): + return len(I.shape) == 2 + + +## Return down sampled image. +# @param scale (w/s, h/s) image will be created. +# @param shape I.shape[:2]=(h, w). numpy friendly size parameter. +def _downSample(I, scale=4, shape=None): + if shape is not None: + h, w = shape + return cv2.resize(I, (w, h), interpolation=cv2.INTER_NEAREST) + + h, w = I.shape[:2] + return cv2.resize(I, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_NEAREST) + + +## Return up sampled image. +# @param scale (w*s, h*s) image will be created. +# @param shape I.shape[:2]=(h, w). numpy friendly size parameter. +def _upSample(I, scale=2, shape=None): + if shape is not None: + h, w = shape + return cv2.resize(I, (w, h), interpolation=cv2.INTER_LINEAR) + + h, w = I.shape[:2] + return cv2.resize(I, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR) + +## Fast guide filter. +class FastGuidedFilter: + ## Constructor. + # @param I Input guidance image. Color or gray. + # @param radius Radius of Guided Filter. + # @param epsilon Regularization term of Guided Filter. + # @param scale Down sampled scale. + def __init__(self, I, radius=5, epsilon=0.4, scale=4): + I_32F = to32F(I) + self._I = I_32F + h, w = I.shape[:2] + + I_sub = _downSample(I_32F, scale) + + self._I_sub = I_sub + radius = int(radius / scale) + + if _isGray(I): + self._guided_filter = GuidedFilterGray(I_sub, radius, epsilon) + else: + self._guided_filter = GuidedFilterColor(I_sub, radius, epsilon) + + ## Apply filter for the input image. + # @param p Input image for the filtering. + def filter(self, p): + p_32F = to32F(p) + shape_original = p.shape[:2] + + p_sub = _downSample(p_32F, shape=self._I_sub.shape[:2]) + + if _isGray(p_sub): + return self._filterGray(p_sub, shape_original) + + cs = p.shape[2] + q = np.array(p_32F) + + for ci in range(cs): + q[:, :, ci] = self._filterGray(p_sub[:, :, ci], shape_original) + return to8U(q) + + def _filterGray(self, p_sub, shape_original): + ab_sub = self._guided_filter._computeCoefficients(p_sub) + ab = [_upSample(abi, shape=shape_original) for abi in ab_sub] + return self._guided_filter._computeOutput(ab, self._I) + + +## Guide filter. +class GuidedFilter: + ## Constructor. + # @param I Input guidance image. Color or gray. + # @param radius Radius of Guided Filter. + # @param epsilon Regularization term of Guided Filter. + def __init__(self, I, radius=5, epsilon=0.4): + I_32F = to32F(I) + + if _isGray(I): + self._guided_filter = GuidedFilterGray(I_32F, radius, epsilon) + else: + self._guided_filter = GuidedFilterColor(I_32F, radius, epsilon) + + ## Apply filter for the input image. + # @param p Input image for the filtering. + def filter(self, p): + return to8U(self._guided_filter.filter(p)) + + +## Common parts of guided filter. +# +# This class is used by guided_filter class. GuidedFilterGray and GuidedFilterColor. +# Based on guided_filter._computeCoefficients, guided_filter._computeOutput, +# GuidedFilterCommon.filter computes filtered image for color and gray. +class GuidedFilterCommon: + def __init__(self, guided_filter): + self._guided_filter = guided_filter + + ## Apply filter for the input image. + # @param p Input image for the filtering. + def filter(self, p): + p_32F = to32F(p) + if _isGray(p_32F): + return self._filterGray(p_32F) + + cs = p.shape[2] + q = np.array(p_32F) + + for ci in range(cs): + q[:, :, ci] = self._filterGray(p_32F[:, :, ci]) + return q + + def _filterGray(self, p): + ab = self._guided_filter._computeCoefficients(p) + return self._guided_filter._computeOutput(ab, self._guided_filter._I) + + +## Guided filter for gray guidance image. +class GuidedFilterGray: + # @param I Input gray guidance image. + # @param radius Radius of Guided Filter. + # @param epsilon Regularization term of Guided Filter. + def __init__(self, I, radius=5, epsilon=0.4): + self._radius = 2 * radius + 1 + self._epsilon = epsilon + self._I = to32F(I) + self._initFilter() + self._filter_common = GuidedFilterCommon(self) + + ## Apply filter for the input image. + # @param p Input image for the filtering. + def filter(self, p): + return self._filter_common.filter(p) + + def _initFilter(self): + I = self._I + r = self._radius + self._I_mean = cv2.blur(I, (r, r)) + I_mean_sq = cv2.blur(I ** 2, (r, r)) + self._I_var = I_mean_sq - self._I_mean ** 2 + + def _computeCoefficients(self, p): + r = self._radius + p_mean = cv2.blur(p, (r, r)) + p_cov = p_mean - self._I_mean * p_mean + a = p_cov / (self._I_var + self._epsilon) + b = p_mean - a * self._I_mean + a_mean = cv2.blur(a, (r, r)) + b_mean = cv2.blur(b, (r, r)) + return a_mean, b_mean + + def _computeOutput(self, ab, I): + a_mean, b_mean = ab + return a_mean * I + b_mean + + +## Guided filter for color guidance image. +class GuidedFilterColor: + # @param I Input color guidance image. + # @param radius Radius of Guided Filter. + # @param epsilon Regularization term of Guided Filter. + def __init__(self, I, radius=5, epsilon=0.2): + self._radius = 2 * radius + 1 + self._epsilon = epsilon + self._I = to32F(I) + self._initFilter() + self._filter_common = GuidedFilterCommon(self) + + ## Apply filter for the input image. + # @param p Input image for the filtering. + def filter(self, p): + return self._filter_common.filter(p) + + def _initFilter(self): + I = self._I + r = self._radius + eps = self._epsilon + + Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2] + + self._Ir_mean = cv2.blur(Ir, (r, r)) + self._Ig_mean = cv2.blur(Ig, (r, r)) + self._Ib_mean = cv2.blur(Ib, (r, r)) + + Irr_var = cv2.blur(Ir ** 2, (r, r)) - self._Ir_mean ** 2 + eps + Irg_var = cv2.blur(Ir * Ig, (r, r)) - self._Ir_mean * self._Ig_mean + Irb_var = cv2.blur(Ir * Ib, (r, r)) - self._Ir_mean * self._Ib_mean + Igg_var = cv2.blur(Ig * Ig, (r, r)) - self._Ig_mean * self._Ig_mean + eps + Igb_var = cv2.blur(Ig * Ib, (r, r)) - self._Ig_mean * self._Ib_mean + Ibb_var = cv2.blur(Ib * Ib, (r, r)) - self._Ib_mean * self._Ib_mean + eps + + Irr_inv = Igg_var * Ibb_var - Igb_var * Igb_var + Irg_inv = Igb_var * Irb_var - Irg_var * Ibb_var + Irb_inv = Irg_var * Igb_var - Igg_var * Irb_var + Igg_inv = Irr_var * Ibb_var - Irb_var * Irb_var + Igb_inv = Irb_var * Irg_var - Irr_var * Igb_var + Ibb_inv = Irr_var * Igg_var - Irg_var * Irg_var + + I_cov = Irr_inv * Irr_var + Irg_inv * Irg_var + Irb_inv * Irb_var + Irr_inv /= I_cov + Irg_inv /= I_cov + Irb_inv /= I_cov + Igg_inv /= I_cov + Igb_inv /= I_cov + Ibb_inv /= I_cov + + self._Irr_inv = Irr_inv + self._Irg_inv = Irg_inv + self._Irb_inv = Irb_inv + self._Igg_inv = Igg_inv + self._Igb_inv = Igb_inv + self._Ibb_inv = Ibb_inv + + def _computeCoefficients(self, p): + r = self._radius + I = self._I + Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2] + + p_mean = cv2.blur(p, (r, r)) + + Ipr_mean = cv2.blur(Ir * p, (r, r)) + Ipg_mean = cv2.blur(Ig * p, (r, r)) + Ipb_mean = cv2.blur(Ib * p, (r, r)) + + Ipr_cov = Ipr_mean - self._Ir_mean * p_mean + Ipg_cov = Ipg_mean - self._Ig_mean * p_mean + Ipb_cov = Ipb_mean - self._Ib_mean * p_mean + + ar = self._Irr_inv * Ipr_cov + self._Irg_inv * Ipg_cov + self._Irb_inv * Ipb_cov + ag = self._Irg_inv * Ipr_cov + self._Igg_inv * Ipg_cov + self._Igb_inv * Ipb_cov + ab = self._Irb_inv * Ipr_cov + self._Igb_inv * Ipg_cov + self._Ibb_inv * Ipb_cov + b = p_mean - ar * self._Ir_mean - ag * self._Ig_mean - ab * self._Ib_mean + + ar_mean = cv2.blur(ar, (r, r)) + ag_mean = cv2.blur(ag, (r, r)) + ab_mean = cv2.blur(ab, (r, r)) + b_mean = cv2.blur(b, (r, r)) + + return ar_mean, ag_mean, ab_mean, b_mean + + def _computeOutput(self, ab, I): + ar_mean, ag_mean, ab_mean, b_mean = ab + + Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2] + + q = (ar_mean * Ir + + ag_mean * Ig + + ab_mean * Ib + + b_mean) + + return q diff --git a/third_party/flux/annotator/util.py b/third_party/flux/annotator/util.py new file mode 100644 index 0000000000000000000000000000000000000000..90831643d19cc1b9b0940df3d4fd4d846ba74a05 --- /dev/null +++ b/third_party/flux/annotator/util.py @@ -0,0 +1,38 @@ +import numpy as np +import cv2 +import os + + +annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') + + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + + +def resize_image(input_image, resolution): + H, W, C = input_image.shape + H = float(H) + W = float(W) + k = float(resolution) / min(H, W) + H *= k + W *= k + H = int(np.round(H / 64.0)) * 64 + W = int(np.round(W / 64.0)) * 64 + img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) + return img diff --git a/third_party/flux/annotator/zoe/LICENSE b/third_party/flux/annotator/zoe/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..7a1e90d007836c327846ce8e5151013b115042ab --- /dev/null +++ b/third_party/flux/annotator/zoe/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Intelligent Systems Lab Org + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/__init__.py b/third_party/flux/annotator/zoe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7628090932e35bdd71d041069ae62f6a731f60d4 --- /dev/null +++ b/third_party/flux/annotator/zoe/__init__.py @@ -0,0 +1,48 @@ +# ZoeDepth +# https://github.com/isl-org/ZoeDepth + +import os +import cv2 +import numpy as np +import torch + +from einops import rearrange +from .zoedepth.models.zoedepth.zoedepth_v1 import ZoeDepth +from .zoedepth.utils.config import get_config +from ...annotator.util import annotator_ckpts_path +from huggingface_hub import hf_hub_download + + +class ZoeDetector: + def __init__(self): + model_path = os.path.join(annotator_ckpts_path, "ZoeD_M12_N.pt") + if not os.path.exists(model_path): + model_path = hf_hub_download("lllyasviel/Annotators", "ZoeD_M12_N.pt") + conf = get_config("zoedepth", "infer") + model = ZoeDepth.build_from_config(conf) + model.load_state_dict(torch.load(model_path)['model'], strict=False) + model = model.cuda() + model.device = 'cuda' + model.eval() + self.model = model + + def __call__(self, input_image): + assert input_image.ndim == 3 + image_depth = input_image + with torch.no_grad(): + image_depth = torch.from_numpy(image_depth).float().cuda() + image_depth = image_depth / 255.0 + image_depth = rearrange(image_depth, 'h w c -> 1 c h w') + depth = self.model.infer(image_depth) + + depth = depth[0, 0].cpu().numpy() + + vmin = np.percentile(depth, 2) + vmax = np.percentile(depth, 85) + + depth -= vmin + depth /= vmax - vmin + depth = 1.0 - depth + depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8) + + return depth_image diff --git a/third_party/flux/annotator/zoe/__pycache__/__init__.cpython-310.pyc b/third_party/flux/annotator/zoe/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da199faa0fae8b85c816ad4aa3fe1ee0e7ecbe81 Binary files /dev/null and b/third_party/flux/annotator/zoe/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/flux/annotator/zoe/zoedepth/data/__init__.py b/third_party/flux/annotator/zoe/zoedepth/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/data/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/third_party/flux/annotator/zoe/zoedepth/data/data_mono.py b/third_party/flux/annotator/zoe/zoedepth/data/data_mono.py new file mode 100644 index 0000000000000000000000000000000000000000..80a8486f239a35331df553f490e213f9bf71e735 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/data/data_mono.py @@ -0,0 +1,573 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +# This file is partly inspired from BTS (https://github.com/cleinc/bts/blob/master/pytorch/bts_dataloader.py); author: Jin Han Lee + +import itertools +import os +import random + +import numpy as np +import cv2 +import torch +import torch.nn as nn +import torch.utils.data.distributed +from zoedepth.utils.easydict import EasyDict as edict +from PIL import Image, ImageOps +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + +from zoedepth.utils.config import change_dataset + +from .ddad import get_ddad_loader +from .diml_indoor_test import get_diml_indoor_loader +from .diml_outdoor_test import get_diml_outdoor_loader +from .diode import get_diode_loader +from .hypersim import get_hypersim_loader +from .ibims import get_ibims_loader +from .sun_rgbd_loader import get_sunrgbd_loader +from .vkitti import get_vkitti_loader +from .vkitti2 import get_vkitti2_loader + +from .preprocess import CropParams, get_white_border, get_black_border + + +def _is_pil_image(img): + return isinstance(img, Image.Image) + + +def _is_numpy_image(img): + return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) + + +def preprocessing_transforms(mode, **kwargs): + return transforms.Compose([ + ToTensor(mode=mode, **kwargs) + ]) + + +class DepthDataLoader(object): + def __init__(self, config, mode, device='cpu', transform=None, **kwargs): + """ + Data loader for depth datasets + + Args: + config (dict): Config dictionary. Refer to utils/config.py + mode (str): "train" or "online_eval" + device (str, optional): Device to load the data on. Defaults to 'cpu'. + transform (torchvision.transforms, optional): Transform to apply to the data. Defaults to None. + """ + + self.config = config + + if config.dataset == 'ibims': + self.data = get_ibims_loader(config, batch_size=1, num_workers=1) + return + + if config.dataset == 'sunrgbd': + self.data = get_sunrgbd_loader( + data_dir_root=config.sunrgbd_root, batch_size=1, num_workers=1) + return + + if config.dataset == 'diml_indoor': + self.data = get_diml_indoor_loader( + data_dir_root=config.diml_indoor_root, batch_size=1, num_workers=1) + return + + if config.dataset == 'diml_outdoor': + self.data = get_diml_outdoor_loader( + data_dir_root=config.diml_outdoor_root, batch_size=1, num_workers=1) + return + + if "diode" in config.dataset: + self.data = get_diode_loader( + config[config.dataset+"_root"], batch_size=1, num_workers=1) + return + + if config.dataset == 'hypersim_test': + self.data = get_hypersim_loader( + config.hypersim_test_root, batch_size=1, num_workers=1) + return + + if config.dataset == 'vkitti': + self.data = get_vkitti_loader( + config.vkitti_root, batch_size=1, num_workers=1) + return + + if config.dataset == 'vkitti2': + self.data = get_vkitti2_loader( + config.vkitti2_root, batch_size=1, num_workers=1) + return + + if config.dataset == 'ddad': + self.data = get_ddad_loader(config.ddad_root, resize_shape=( + 352, 1216), batch_size=1, num_workers=1) + return + + img_size = self.config.get("img_size", None) + img_size = img_size if self.config.get( + "do_input_resize", False) else None + + if transform is None: + transform = preprocessing_transforms(mode, size=img_size) + + if mode == 'train': + + Dataset = DataLoadPreprocess + self.training_samples = Dataset( + config, mode, transform=transform, device=device) + + if config.distributed: + self.train_sampler = torch.utils.data.distributed.DistributedSampler( + self.training_samples) + else: + self.train_sampler = None + + self.data = DataLoader(self.training_samples, + batch_size=config.batch_size, + shuffle=(self.train_sampler is None), + num_workers=config.workers, + pin_memory=True, + persistent_workers=True, + # prefetch_factor=2, + sampler=self.train_sampler) + + elif mode == 'online_eval': + self.testing_samples = DataLoadPreprocess( + config, mode, transform=transform) + if config.distributed: # redundant. here only for readability and to be more explicit + # Give whole test set to all processes (and report evaluation only on one) regardless + self.eval_sampler = None + else: + self.eval_sampler = None + self.data = DataLoader(self.testing_samples, 1, + shuffle=kwargs.get("shuffle_test", False), + num_workers=1, + pin_memory=False, + sampler=self.eval_sampler) + + elif mode == 'test': + self.testing_samples = DataLoadPreprocess( + config, mode, transform=transform) + self.data = DataLoader(self.testing_samples, + 1, shuffle=False, num_workers=1) + + else: + print( + 'mode should be one of \'train, test, online_eval\'. Got {}'.format(mode)) + + +def repetitive_roundrobin(*iterables): + """ + cycles through iterables but sample wise + first yield first sample from first iterable then first sample from second iterable and so on + then second sample from first iterable then second sample from second iterable and so on + + If one iterable is shorter than the others, it is repeated until all iterables are exhausted + repetitive_roundrobin('ABC', 'D', 'EF') --> A D E B D F C D E + """ + # Repetitive roundrobin + iterables_ = [iter(it) for it in iterables] + exhausted = [False] * len(iterables) + while not all(exhausted): + for i, it in enumerate(iterables_): + try: + yield next(it) + except StopIteration: + exhausted[i] = True + iterables_[i] = itertools.cycle(iterables[i]) + # First elements may get repeated if one iterable is shorter than the others + yield next(iterables_[i]) + + +class RepetitiveRoundRobinDataLoader(object): + def __init__(self, *dataloaders): + self.dataloaders = dataloaders + + def __iter__(self): + return repetitive_roundrobin(*self.dataloaders) + + def __len__(self): + # First samples get repeated, thats why the plus one + return len(self.dataloaders) * (max(len(dl) for dl in self.dataloaders) + 1) + + +class MixedNYUKITTI(object): + def __init__(self, config, mode, device='cpu', **kwargs): + config = edict(config) + config.workers = config.workers // 2 + self.config = config + nyu_conf = change_dataset(edict(config), 'nyu') + kitti_conf = change_dataset(edict(config), 'kitti') + + # make nyu default for testing + self.config = config = nyu_conf + img_size = self.config.get("img_size", None) + img_size = img_size if self.config.get( + "do_input_resize", False) else None + if mode == 'train': + nyu_loader = DepthDataLoader( + nyu_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data + kitti_loader = DepthDataLoader( + kitti_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data + # It has been changed to repetitive roundrobin + self.data = RepetitiveRoundRobinDataLoader( + nyu_loader, kitti_loader) + else: + self.data = DepthDataLoader(nyu_conf, mode, device=device).data + + +def remove_leading_slash(s): + if s[0] == '/' or s[0] == '\\': + return s[1:] + return s + + +class CachedReader: + def __init__(self, shared_dict=None): + if shared_dict: + self._cache = shared_dict + else: + self._cache = {} + + def open(self, fpath): + im = self._cache.get(fpath, None) + if im is None: + im = self._cache[fpath] = Image.open(fpath) + return im + + +class ImReader: + def __init__(self): + pass + + # @cache + def open(self, fpath): + return Image.open(fpath) + + +class DataLoadPreprocess(Dataset): + def __init__(self, config, mode, transform=None, is_for_online_eval=False, **kwargs): + self.config = config + if mode == 'online_eval': + with open(config.filenames_file_eval, 'r') as f: + self.filenames = f.readlines() + else: + with open(config.filenames_file, 'r') as f: + self.filenames = f.readlines() + + self.mode = mode + self.transform = transform + self.to_tensor = ToTensor(mode) + self.is_for_online_eval = is_for_online_eval + if config.use_shared_dict: + self.reader = CachedReader(config.shared_dict) + else: + self.reader = ImReader() + + def postprocess(self, sample): + return sample + + def __getitem__(self, idx): + sample_path = self.filenames[idx] + focal = float(sample_path.split()[2]) + sample = {} + + if self.mode == 'train': + if self.config.dataset == 'kitti' and self.config.use_right and random.random() > 0.5: + image_path = os.path.join( + self.config.data_path, remove_leading_slash(sample_path.split()[3])) + depth_path = os.path.join( + self.config.gt_path, remove_leading_slash(sample_path.split()[4])) + else: + image_path = os.path.join( + self.config.data_path, remove_leading_slash(sample_path.split()[0])) + depth_path = os.path.join( + self.config.gt_path, remove_leading_slash(sample_path.split()[1])) + + image = self.reader.open(image_path) + depth_gt = self.reader.open(depth_path) + w, h = image.size + + if self.config.do_kb_crop: + height = image.height + width = image.width + top_margin = int(height - 352) + left_margin = int((width - 1216) / 2) + depth_gt = depth_gt.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + image = image.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + + # Avoid blank boundaries due to pixel registration? + # Train images have white border. Test images have black border. + if self.config.dataset == 'nyu' and self.config.avoid_boundary: + # print("Avoiding Blank Boundaries!") + # We just crop and pad again with reflect padding to original size + # original_size = image.size + crop_params = get_white_border(np.array(image, dtype=np.uint8)) + image = image.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom)) + depth_gt = depth_gt.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom)) + + # Use reflect padding to fill the blank + image = np.array(image) + image = np.pad(image, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right), (0, 0)), mode='reflect') + image = Image.fromarray(image) + + depth_gt = np.array(depth_gt) + depth_gt = np.pad(depth_gt, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right)), 'constant', constant_values=0) + depth_gt = Image.fromarray(depth_gt) + + + if self.config.do_random_rotate and (self.config.aug): + random_angle = (random.random() - 0.5) * 2 * self.config.degree + image = self.rotate_image(image, random_angle) + depth_gt = self.rotate_image( + depth_gt, random_angle, flag=Image.NEAREST) + + image = np.asarray(image, dtype=np.float32) / 255.0 + depth_gt = np.asarray(depth_gt, dtype=np.float32) + depth_gt = np.expand_dims(depth_gt, axis=2) + + if self.config.dataset == 'nyu': + depth_gt = depth_gt / 1000.0 + else: + depth_gt = depth_gt / 256.0 + + if self.config.aug and (self.config.random_crop): + image, depth_gt = self.random_crop( + image, depth_gt, self.config.input_height, self.config.input_width) + + if self.config.aug and self.config.random_translate: + # print("Random Translation!") + image, depth_gt = self.random_translate(image, depth_gt, self.config.max_translation) + + image, depth_gt = self.train_preprocess(image, depth_gt) + mask = np.logical_and(depth_gt > self.config.min_depth, + depth_gt < self.config.max_depth).squeeze()[None, ...] + sample = {'image': image, 'depth': depth_gt, 'focal': focal, + 'mask': mask, **sample} + + else: + if self.mode == 'online_eval': + data_path = self.config.data_path_eval + else: + data_path = self.config.data_path + + image_path = os.path.join( + data_path, remove_leading_slash(sample_path.split()[0])) + image = np.asarray(self.reader.open(image_path), + dtype=np.float32) / 255.0 + + if self.mode == 'online_eval': + gt_path = self.config.gt_path_eval + depth_path = os.path.join( + gt_path, remove_leading_slash(sample_path.split()[1])) + has_valid_depth = False + try: + depth_gt = self.reader.open(depth_path) + has_valid_depth = True + except IOError: + depth_gt = False + # print('Missing gt for {}'.format(image_path)) + + if has_valid_depth: + depth_gt = np.asarray(depth_gt, dtype=np.float32) + depth_gt = np.expand_dims(depth_gt, axis=2) + if self.config.dataset == 'nyu': + depth_gt = depth_gt / 1000.0 + else: + depth_gt = depth_gt / 256.0 + + mask = np.logical_and( + depth_gt >= self.config.min_depth, depth_gt <= self.config.max_depth).squeeze()[None, ...] + else: + mask = False + + if self.config.do_kb_crop: + height = image.shape[0] + width = image.shape[1] + top_margin = int(height - 352) + left_margin = int((width - 1216) / 2) + image = image[top_margin:top_margin + 352, + left_margin:left_margin + 1216, :] + if self.mode == 'online_eval' and has_valid_depth: + depth_gt = depth_gt[top_margin:top_margin + + 352, left_margin:left_margin + 1216, :] + + if self.mode == 'online_eval': + sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth, + 'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1], + 'mask': mask} + else: + sample = {'image': image, 'focal': focal} + + if (self.mode == 'train') or ('has_valid_depth' in sample and sample['has_valid_depth']): + mask = np.logical_and(depth_gt > self.config.min_depth, + depth_gt < self.config.max_depth).squeeze()[None, ...] + sample['mask'] = mask + + if self.transform: + sample = self.transform(sample) + + sample = self.postprocess(sample) + sample['dataset'] = self.config.dataset + sample = {**sample, 'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1]} + + return sample + + def rotate_image(self, image, angle, flag=Image.BILINEAR): + result = image.rotate(angle, resample=flag) + return result + + def random_crop(self, img, depth, height, width): + assert img.shape[0] >= height + assert img.shape[1] >= width + assert img.shape[0] == depth.shape[0] + assert img.shape[1] == depth.shape[1] + x = random.randint(0, img.shape[1] - width) + y = random.randint(0, img.shape[0] - height) + img = img[y:y + height, x:x + width, :] + depth = depth[y:y + height, x:x + width, :] + + return img, depth + + def random_translate(self, img, depth, max_t=20): + assert img.shape[0] == depth.shape[0] + assert img.shape[1] == depth.shape[1] + p = self.config.translate_prob + do_translate = random.random() + if do_translate > p: + return img, depth + x = random.randint(-max_t, max_t) + y = random.randint(-max_t, max_t) + M = np.float32([[1, 0, x], [0, 1, y]]) + # print(img.shape, depth.shape) + img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0])) + depth = cv2.warpAffine(depth, M, (depth.shape[1], depth.shape[0])) + depth = depth.squeeze()[..., None] # add channel dim back. Affine warp removes it + # print("after", img.shape, depth.shape) + return img, depth + + def train_preprocess(self, image, depth_gt): + if self.config.aug: + # Random flipping + do_flip = random.random() + if do_flip > 0.5: + image = (image[:, ::-1, :]).copy() + depth_gt = (depth_gt[:, ::-1, :]).copy() + + # Random gamma, brightness, color augmentation + do_augment = random.random() + if do_augment > 0.5: + image = self.augment_image(image) + + return image, depth_gt + + def augment_image(self, image): + # gamma augmentation + gamma = random.uniform(0.9, 1.1) + image_aug = image ** gamma + + # brightness augmentation + if self.config.dataset == 'nyu': + brightness = random.uniform(0.75, 1.25) + else: + brightness = random.uniform(0.9, 1.1) + image_aug = image_aug * brightness + + # color augmentation + colors = np.random.uniform(0.9, 1.1, size=3) + white = np.ones((image.shape[0], image.shape[1])) + color_image = np.stack([white * colors[i] for i in range(3)], axis=2) + image_aug *= color_image + image_aug = np.clip(image_aug, 0, 1) + + return image_aug + + def __len__(self): + return len(self.filenames) + + +class ToTensor(object): + def __init__(self, mode, do_normalize=False, size=None): + self.mode = mode + self.normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) if do_normalize else nn.Identity() + self.size = size + if size is not None: + self.resize = transforms.Resize(size=size) + else: + self.resize = nn.Identity() + + def __call__(self, sample): + image, focal = sample['image'], sample['focal'] + image = self.to_tensor(image) + image = self.normalize(image) + image = self.resize(image) + + if self.mode == 'test': + return {'image': image, 'focal': focal} + + depth = sample['depth'] + if self.mode == 'train': + depth = self.to_tensor(depth) + return {**sample, 'image': image, 'depth': depth, 'focal': focal} + else: + has_valid_depth = sample['has_valid_depth'] + image = self.resize(image) + return {**sample, 'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth, + 'image_path': sample['image_path'], 'depth_path': sample['depth_path']} + + def to_tensor(self, pic): + if not (_is_pil_image(pic) or _is_numpy_image(pic)): + raise TypeError( + 'pic should be PIL Image or ndarray. Got {}'.format(type(pic))) + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img diff --git a/third_party/flux/annotator/zoe/zoedepth/data/ddad.py b/third_party/flux/annotator/zoe/zoedepth/data/ddad.py new file mode 100644 index 0000000000000000000000000000000000000000..4bd0492bdec767685d3a21992b4a26e62d002d97 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/data/ddad.py @@ -0,0 +1,117 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self, resize_shape): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + self.resize = transforms.Resize(resize_shape) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "ddad"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class DDAD(Dataset): + def __init__(self, data_dir_root, resize_shape): + import glob + + # image paths are of the form /{outleft, depthmap}/*.png + self.image_files = glob.glob(os.path.join(data_dir_root, '*.png')) + self.depth_files = [r.replace("_rgb.png", "_depth.npy") + for r in self.image_files] + self.transform = ToTensor(resize_shape) + + def __getitem__(self, idx): + + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + depth = np.load(depth_path) # meters + + # depth[depth > 8] = -1 + depth = depth[..., None] + + sample = dict(image=image, depth=depth) + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_ddad_loader(data_dir_root, resize_shape, batch_size=1, **kwargs): + dataset = DDAD(data_dir_root, resize_shape) + return DataLoader(dataset, batch_size, **kwargs) diff --git a/third_party/flux/annotator/zoe/zoedepth/data/diml_indoor_test.py b/third_party/flux/annotator/zoe/zoedepth/data/diml_indoor_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f720ad9aefaee78ef4ec363dfef0f82ace850a6d --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/data/diml_indoor_test.py @@ -0,0 +1,125 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + self.resize = transforms.Resize((480, 640)) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "diml_indoor"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class DIML_Indoor(Dataset): + def __init__(self, data_dir_root): + import glob + + # image paths are of the form /{HR, LR}//{color, depth_filled}/*.png + self.image_files = glob.glob(os.path.join( + data_dir_root, "LR", '*', 'color', '*.png')) + self.depth_files = [r.replace("color", "depth_filled").replace( + "_c.png", "_depth_filled.png") for r in self.image_files] + self.transform = ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + depth = np.asarray(Image.open(depth_path), + dtype='uint16') / 1000.0 # mm to meters + + # print(np.shape(image)) + # print(np.shape(depth)) + + # depth[depth > 8] = -1 + depth = depth[..., None] + + sample = dict(image=image, depth=depth) + + # return sample + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_diml_indoor_loader(data_dir_root, batch_size=1, **kwargs): + dataset = DIML_Indoor(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) + +# get_diml_indoor_loader(data_dir_root="datasets/diml/indoor/test/HR") +# get_diml_indoor_loader(data_dir_root="datasets/diml/indoor/test/LR") diff --git a/third_party/flux/annotator/zoe/zoedepth/data/diml_outdoor_test.py b/third_party/flux/annotator/zoe/zoedepth/data/diml_outdoor_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8670b48f5febafb819dac22848ad79ccb5dd5ae4 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/data/diml_outdoor_test.py @@ -0,0 +1,114 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + return {'image': image, 'depth': depth, 'dataset': "diml_outdoor"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class DIML_Outdoor(Dataset): + def __init__(self, data_dir_root): + import glob + + # image paths are of the form /{outleft, depthmap}/*.png + self.image_files = glob.glob(os.path.join( + data_dir_root, "*", 'outleft', '*.png')) + self.depth_files = [r.replace("outleft", "depthmap") + for r in self.image_files] + self.transform = ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + depth = np.asarray(Image.open(depth_path), + dtype='uint16') / 1000.0 # mm to meters + + # depth[depth > 8] = -1 + depth = depth[..., None] + + sample = dict(image=image, depth=depth, dataset="diml_outdoor") + + # return sample + return self.transform(sample) + + def __len__(self): + return len(self.image_files) + + +def get_diml_outdoor_loader(data_dir_root, batch_size=1, **kwargs): + dataset = DIML_Outdoor(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) + +# get_diml_outdoor_loader(data_dir_root="datasets/diml/outdoor/test/HR") +# get_diml_outdoor_loader(data_dir_root="datasets/diml/outdoor/test/LR") diff --git a/third_party/flux/annotator/zoe/zoedepth/data/diode.py b/third_party/flux/annotator/zoe/zoedepth/data/diode.py new file mode 100644 index 0000000000000000000000000000000000000000..1510c87116b8f70ce2e1428873a8e4da042bee23 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/data/diode.py @@ -0,0 +1,125 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + self.resize = transforms.Resize(480) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "diode"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class DIODE(Dataset): + def __init__(self, data_dir_root): + import glob + + # image paths are of the form /scene_#/scan_#/*.png + self.image_files = glob.glob( + os.path.join(data_dir_root, '*', '*', '*.png')) + self.depth_files = [r.replace(".png", "_depth.npy") + for r in self.image_files] + self.depth_mask_files = [ + r.replace(".png", "_depth_mask.npy") for r in self.image_files] + self.transform = ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + depth_mask_path = self.depth_mask_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + depth = np.load(depth_path) # in meters + valid = np.load(depth_mask_path) # binary + + # depth[depth > 8] = -1 + # depth = depth[..., None] + + sample = dict(image=image, depth=depth, valid=valid) + + # return sample + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_diode_loader(data_dir_root, batch_size=1, **kwargs): + dataset = DIODE(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) + +# get_diode_loader(data_dir_root="datasets/diode/val/outdoor") diff --git a/third_party/flux/annotator/zoe/zoedepth/data/hypersim.py b/third_party/flux/annotator/zoe/zoedepth/data/hypersim.py new file mode 100644 index 0000000000000000000000000000000000000000..4334198971830200f72ea2910d03f4c7d6a43334 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/data/hypersim.py @@ -0,0 +1,138 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import glob +import os + +import h5py +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +def hypersim_distance_to_depth(npyDistance): + intWidth, intHeight, fltFocal = 1024, 768, 886.81 + + npyImageplaneX = np.linspace((-0.5 * intWidth) + 0.5, (0.5 * intWidth) - 0.5, intWidth).reshape( + 1, intWidth).repeat(intHeight, 0).astype(np.float32)[:, :, None] + npyImageplaneY = np.linspace((-0.5 * intHeight) + 0.5, (0.5 * intHeight) - 0.5, + intHeight).reshape(intHeight, 1).repeat(intWidth, 1).astype(np.float32)[:, :, None] + npyImageplaneZ = np.full([intHeight, intWidth, 1], fltFocal, np.float32) + npyImageplane = np.concatenate( + [npyImageplaneX, npyImageplaneY, npyImageplaneZ], 2) + + npyDepth = npyDistance / np.linalg.norm(npyImageplane, 2, 2) * fltFocal + return npyDepth + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x: x + self.resize = transforms.Resize((480, 640)) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "hypersim"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class HyperSim(Dataset): + def __init__(self, data_dir_root): + # image paths are of the form //images/scene_cam_#_final_preview/*.tonemap.jpg + # depth paths are of the form //images/scene_cam_#_final_preview/*.depth_meters.hdf5 + self.image_files = glob.glob(os.path.join( + data_dir_root, '*', 'images', 'scene_cam_*_final_preview', '*.tonemap.jpg')) + self.depth_files = [r.replace("_final_preview", "_geometry_hdf5").replace( + ".tonemap.jpg", ".depth_meters.hdf5") for r in self.image_files] + self.transform = ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + + # depth from hdf5 + depth_fd = h5py.File(depth_path, "r") + # in meters (Euclidean distance) + distance_meters = np.array(depth_fd['dataset']) + depth = hypersim_distance_to_depth( + distance_meters) # in meters (planar depth) + + # depth[depth > 8] = -1 + depth = depth[..., None] + + sample = dict(image=image, depth=depth) + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_hypersim_loader(data_dir_root, batch_size=1, **kwargs): + dataset = HyperSim(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) diff --git a/third_party/flux/annotator/zoe/zoedepth/data/ibims.py b/third_party/flux/annotator/zoe/zoedepth/data/ibims.py new file mode 100644 index 0000000000000000000000000000000000000000..b66abfabcf4cfc617d4a60ec818780c3548d9920 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/data/ibims.py @@ -0,0 +1,81 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms as T + + +class iBims(Dataset): + def __init__(self, config): + root_folder = config.ibims_root + with open(os.path.join(root_folder, "imagelist.txt"), 'r') as f: + imglist = f.read().split() + + samples = [] + for basename in imglist: + img_path = os.path.join(root_folder, 'rgb', basename + ".png") + depth_path = os.path.join(root_folder, 'depth', basename + ".png") + valid_mask_path = os.path.join( + root_folder, 'mask_invalid', basename+".png") + transp_mask_path = os.path.join( + root_folder, 'mask_transp', basename+".png") + + samples.append( + (img_path, depth_path, valid_mask_path, transp_mask_path)) + + self.samples = samples + # self.normalize = T.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + + def __getitem__(self, idx): + img_path, depth_path, valid_mask_path, transp_mask_path = self.samples[idx] + + img = np.asarray(Image.open(img_path), dtype=np.float32) / 255.0 + depth = np.asarray(Image.open(depth_path), + dtype=np.uint16).astype('float')*50.0/65535 + + mask_valid = np.asarray(Image.open(valid_mask_path)) + mask_transp = np.asarray(Image.open(transp_mask_path)) + + # depth = depth * mask_valid * mask_transp + depth = np.where(mask_valid * mask_transp, depth, -1) + + img = torch.from_numpy(img).permute(2, 0, 1) + img = self.normalize(img) + depth = torch.from_numpy(depth).unsqueeze(0) + return dict(image=img, depth=depth, image_path=img_path, depth_path=depth_path, dataset='ibims') + + def __len__(self): + return len(self.samples) + + +def get_ibims_loader(config, batch_size=1, **kwargs): + dataloader = DataLoader(iBims(config), batch_size=batch_size, **kwargs) + return dataloader diff --git a/third_party/flux/annotator/zoe/zoedepth/data/preprocess.py b/third_party/flux/annotator/zoe/zoedepth/data/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..e08cc309dc823ae6efd7cda8db9eb37130dc5499 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/data/preprocess.py @@ -0,0 +1,154 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import numpy as np +from dataclasses import dataclass +from typing import Tuple, List + +# dataclass to store the crop parameters +@dataclass +class CropParams: + top: int + bottom: int + left: int + right: int + + + +def get_border_params(rgb_image, tolerance=0.1, cut_off=20, value=0, level_diff_threshold=5, channel_axis=-1, min_border=5) -> CropParams: + gray_image = np.mean(rgb_image, axis=channel_axis) + h, w = gray_image.shape + + + def num_value_pixels(arr): + return np.sum(np.abs(arr - value) < level_diff_threshold) + + def is_above_tolerance(arr, total_pixels): + return (num_value_pixels(arr) / total_pixels) > tolerance + + # Crop top border until number of value pixels become below tolerance + top = min_border + while is_above_tolerance(gray_image[top, :], w) and top < h-1: + top += 1 + if top > cut_off: + break + + # Crop bottom border until number of value pixels become below tolerance + bottom = h - min_border + while is_above_tolerance(gray_image[bottom, :], w) and bottom > 0: + bottom -= 1 + if h - bottom > cut_off: + break + + # Crop left border until number of value pixels become below tolerance + left = min_border + while is_above_tolerance(gray_image[:, left], h) and left < w-1: + left += 1 + if left > cut_off: + break + + # Crop right border until number of value pixels become below tolerance + right = w - min_border + while is_above_tolerance(gray_image[:, right], h) and right > 0: + right -= 1 + if w - right > cut_off: + break + + + return CropParams(top, bottom, left, right) + + +def get_white_border(rgb_image, value=255, **kwargs) -> CropParams: + """Crops the white border of the RGB. + + Args: + rgb: RGB image, shape (H, W, 3). + Returns: + Crop parameters. + """ + if value == 255: + # assert range of values in rgb image is [0, 255] + assert np.max(rgb_image) <= 255 and np.min(rgb_image) >= 0, "RGB image values are not in range [0, 255]." + assert rgb_image.max() > 1, "RGB image values are not in range [0, 255]." + elif value == 1: + # assert range of values in rgb image is [0, 1] + assert np.max(rgb_image) <= 1 and np.min(rgb_image) >= 0, "RGB image values are not in range [0, 1]." + + return get_border_params(rgb_image, value=value, **kwargs) + +def get_black_border(rgb_image, **kwargs) -> CropParams: + """Crops the black border of the RGB. + + Args: + rgb: RGB image, shape (H, W, 3). + + Returns: + Crop parameters. + """ + + return get_border_params(rgb_image, value=0, **kwargs) + +def crop_image(image: np.ndarray, crop_params: CropParams) -> np.ndarray: + """Crops the image according to the crop parameters. + + Args: + image: RGB or depth image, shape (H, W, 3) or (H, W). + crop_params: Crop parameters. + + Returns: + Cropped image. + """ + return image[crop_params.top:crop_params.bottom, crop_params.left:crop_params.right] + +def crop_images(*images: np.ndarray, crop_params: CropParams) -> Tuple[np.ndarray]: + """Crops the images according to the crop parameters. + + Args: + images: RGB or depth images, shape (H, W, 3) or (H, W). + crop_params: Crop parameters. + + Returns: + Cropped images. + """ + return tuple(crop_image(image, crop_params) for image in images) + +def crop_black_or_white_border(rgb_image, *other_images: np.ndarray, tolerance=0.1, cut_off=20, level_diff_threshold=5) -> Tuple[np.ndarray]: + """Crops the white and black border of the RGB and depth images. + + Args: + rgb: RGB image, shape (H, W, 3). This image is used to determine the border. + other_images: The other images to crop according to the border of the RGB image. + Returns: + Cropped RGB and other images. + """ + # crop black border + crop_params = get_black_border(rgb_image, tolerance=tolerance, cut_off=cut_off, level_diff_threshold=level_diff_threshold) + cropped_images = crop_images(rgb_image, *other_images, crop_params=crop_params) + + # crop white border + crop_params = get_white_border(cropped_images[0], tolerance=tolerance, cut_off=cut_off, level_diff_threshold=level_diff_threshold) + cropped_images = crop_images(*cropped_images, crop_params=crop_params) + + return cropped_images + \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/data/sun_rgbd_loader.py b/third_party/flux/annotator/zoe/zoedepth/data/sun_rgbd_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..9e2bdb9aefe68ca4439f41eff3bba722c49fb976 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/data/sun_rgbd_loader.py @@ -0,0 +1,106 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + return {'image': image, 'depth': depth, 'dataset': "sunrgbd"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class SunRGBD(Dataset): + def __init__(self, data_dir_root): + # test_file_dirs = loadmat(train_test_file)['alltest'].squeeze() + # all_test = [t[0].replace("/n/fs/sun3d/data/", "") for t in test_file_dirs] + # self.all_test = [os.path.join(data_dir_root, t) for t in all_test] + import glob + self.image_files = glob.glob( + os.path.join(data_dir_root, 'rgb', 'rgb', '*')) + self.depth_files = [ + r.replace("rgb/rgb", "gt/gt").replace("jpg", "png") for r in self.image_files] + self.transform = ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + depth = np.asarray(Image.open(depth_path), dtype='uint16') / 1000.0 + depth[depth > 8] = -1 + depth = depth[..., None] + return self.transform(dict(image=image, depth=depth)) + + def __len__(self): + return len(self.image_files) + + +def get_sunrgbd_loader(data_dir_root, batch_size=1, **kwargs): + dataset = SunRGBD(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) diff --git a/third_party/flux/annotator/zoe/zoedepth/data/transforms.py b/third_party/flux/annotator/zoe/zoedepth/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..374416dff24fb4fd55598f3946d6d6b091ddefc9 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/data/transforms.py @@ -0,0 +1,481 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import math +import random + +import cv2 +import numpy as np + + +class RandomFliplr(object): + """Horizontal flip of the sample with given probability. + """ + + def __init__(self, probability=0.5): + """Init. + + Args: + probability (float, optional): Flip probability. Defaults to 0.5. + """ + self.__probability = probability + + def __call__(self, sample): + prob = random.random() + + if prob < self.__probability: + for k, v in sample.items(): + if len(v.shape) >= 2: + sample[k] = np.fliplr(v).copy() + + return sample + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class RandomCrop(object): + """Get a random crop of the sample with the given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_if_needed=False, + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): output width + height (int): output height + resize_if_needed (bool, optional): If True, sample might be upsampled to ensure + that a crop of size (width, height) is possbile. Defaults to False. + """ + self.__size = (height, width) + self.__resize_if_needed = resize_if_needed + self.__image_interpolation_method = image_interpolation_method + + def __call__(self, sample): + + shape = sample["disparity"].shape + + if self.__size[0] > shape[0] or self.__size[1] > shape[1]: + if self.__resize_if_needed: + shape = apply_min_size( + sample, self.__size, self.__image_interpolation_method + ) + else: + raise Exception( + "Output size {} bigger than input size {}.".format( + self.__size, shape + ) + ) + + offset = ( + np.random.randint(shape[0] - self.__size[0] + 1), + np.random.randint(shape[1] - self.__size[1] + 1), + ) + + for k, v in sample.items(): + if k == "code" or k == "basis": + continue + + if len(sample[k].shape) >= 2: + sample[k] = v[ + offset[0]: offset[0] + self.__size[0], + offset[1]: offset[1] + self.__size[1], + ] + + return sample + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + letter_box=False, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + self.__letter_box = letter_box + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def make_letter_box(self, sample): + top = bottom = (self.__height - sample.shape[0]) // 2 + left = right = (self.__width - sample.shape[1]) // 2 + sample = cv2.copyMakeBorder( + sample, top, bottom, left, right, cv2.BORDER_CONSTANT, None, 0) + return sample + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__letter_box: + sample["image"] = self.make_letter_box(sample["image"]) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if self.__letter_box: + sample["disparity"] = self.make_letter_box( + sample["disparity"]) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, + height), interpolation=cv2.INTER_NEAREST + ) + + if self.__letter_box: + sample["depth"] = self.make_letter_box(sample["depth"]) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if self.__letter_box: + sample["mask"] = self.make_letter_box(sample["mask"]) + + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class ResizeFixed(object): + def __init__(self, size): + self.__size = size + + def __call__(self, sample): + sample["image"] = cv2.resize( + sample["image"], self.__size[::-1], interpolation=cv2.INTER_LINEAR + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], self.__size[::- + 1], interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + self.__size[::-1], + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class Rescale(object): + """Rescale target values to the interval [0, max_val]. + If input is constant, values are set to max_val / 2. + """ + + def __init__(self, max_val=1.0, use_mask=True): + """Init. + + Args: + max_val (float, optional): Max output value. Defaults to 1.0. + use_mask (bool, optional): Only operate on valid pixels (mask == True). Defaults to True. + """ + self.__max_val = max_val + self.__use_mask = use_mask + + def __call__(self, sample): + disp = sample["disparity"] + + if self.__use_mask: + mask = sample["mask"] + else: + mask = np.ones_like(disp, dtype=np.bool) + + if np.sum(mask) == 0: + return sample + + min_val = np.min(disp[mask]) + max_val = np.max(disp[mask]) + + if max_val > min_val: + sample["disparity"][mask] = ( + (disp[mask] - min_val) / (max_val - min_val) * self.__max_val + ) + else: + sample["disparity"][mask] = np.ones_like( + disp[mask]) * self.__max_val / 2.0 + + return sample + + +# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class DepthToDisparity(object): + """Convert depth to disparity. Removes depth from sample. + """ + + def __init__(self, eps=1e-4): + self.__eps = eps + + def __call__(self, sample): + assert "depth" in sample + + sample["mask"][sample["depth"] < self.__eps] = False + + sample["disparity"] = np.zeros_like(sample["depth"]) + sample["disparity"][sample["depth"] >= self.__eps] = ( + 1.0 / sample["depth"][sample["depth"] >= self.__eps] + ) + + del sample["depth"] + + return sample + + +class DisparityToDepth(object): + """Convert disparity to depth. Removes disparity from sample. + """ + + def __init__(self, eps=1e-4): + self.__eps = eps + + def __call__(self, sample): + assert "disparity" in sample + + disp = np.abs(sample["disparity"]) + sample["mask"][disp < self.__eps] = False + + # print(sample["disparity"]) + # print(sample["mask"].sum()) + # exit() + + sample["depth"] = np.zeros_like(disp) + sample["depth"][disp >= self.__eps] = ( + 1.0 / disp[disp >= self.__eps] + ) + + del sample["disparity"] + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/third_party/flux/annotator/zoe/zoedepth/data/vkitti.py b/third_party/flux/annotator/zoe/zoedepth/data/vkitti.py new file mode 100644 index 0000000000000000000000000000000000000000..72a2e5a8346f6e630ede0e28d6959725af8d7e72 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/data/vkitti.py @@ -0,0 +1,151 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms +import os + +from PIL import Image +import numpy as np +import cv2 + + +class ToTensor(object): + def __init__(self): + self.normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + # self.resize = transforms.Resize((375, 1242)) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + # image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "vkitti"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class VKITTI(Dataset): + def __init__(self, data_dir_root, do_kb_crop=True): + import glob + # image paths are of the form /{HR, LR}//{color, depth_filled}/*.png + self.image_files = glob.glob(os.path.join( + data_dir_root, "test_color", '*.png')) + self.depth_files = [r.replace("test_color", "test_depth") + for r in self.image_files] + self.do_kb_crop = True + self.transform = ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = Image.open(image_path) + depth = Image.open(depth_path) + depth = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR | + cv2.IMREAD_ANYDEPTH) + print("dpeth min max", depth.min(), depth.max()) + + # print(np.shape(image)) + # print(np.shape(depth)) + + # depth[depth > 8] = -1 + + if self.do_kb_crop and False: + height = image.height + width = image.width + top_margin = int(height - 352) + left_margin = int((width - 1216) / 2) + depth = depth.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + image = image.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + # uv = uv[:, top_margin:top_margin + 352, left_margin:left_margin + 1216] + + image = np.asarray(image, dtype=np.float32) / 255.0 + # depth = np.asarray(depth, dtype=np.uint16) /1. + depth = depth[..., None] + sample = dict(image=image, depth=depth) + + # return sample + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_vkitti_loader(data_dir_root, batch_size=1, **kwargs): + dataset = VKITTI(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) + + +if __name__ == "__main__": + loader = get_vkitti_loader( + data_dir_root="/home/bhatsf/shortcuts/datasets/vkitti_test") + print("Total files", len(loader.dataset)) + for i, sample in enumerate(loader): + print(sample["image"].shape) + print(sample["depth"].shape) + print(sample["dataset"]) + print(sample['depth'].min(), sample['depth'].max()) + if i > 5: + break diff --git a/third_party/flux/annotator/zoe/zoedepth/data/vkitti2.py b/third_party/flux/annotator/zoe/zoedepth/data/vkitti2.py new file mode 100644 index 0000000000000000000000000000000000000000..9bcfb0414b7f3f21859f30ae34bd71689516a3e7 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/data/vkitti2.py @@ -0,0 +1,187 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os + +import cv2 +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x: x + # self.resize = transforms.Resize((375, 1242)) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + # image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "vkitti"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class VKITTI2(Dataset): + def __init__(self, data_dir_root, do_kb_crop=True, split="test"): + import glob + + # image paths are of the form /rgb///frames//Camera<0,1>/rgb_{}.jpg + self.image_files = glob.glob(os.path.join( + data_dir_root, "rgb", "**", "frames", "rgb", "Camera_0", '*.jpg'), recursive=True) + self.depth_files = [r.replace("/rgb/", "/depth/").replace( + "rgb_", "depth_").replace(".jpg", ".png") for r in self.image_files] + self.do_kb_crop = True + self.transform = ToTensor() + + # If train test split is not created, then create one. + # Split is such that 8% of the frames from each scene are used for testing. + if not os.path.exists(os.path.join(data_dir_root, "train.txt")): + import random + scenes = set([os.path.basename(os.path.dirname( + os.path.dirname(os.path.dirname(f)))) for f in self.image_files]) + train_files = [] + test_files = [] + for scene in scenes: + scene_files = [f for f in self.image_files if os.path.basename( + os.path.dirname(os.path.dirname(os.path.dirname(f)))) == scene] + random.shuffle(scene_files) + train_files.extend(scene_files[:int(len(scene_files) * 0.92)]) + test_files.extend(scene_files[int(len(scene_files) * 0.92):]) + with open(os.path.join(data_dir_root, "train.txt"), "w") as f: + f.write("\n".join(train_files)) + with open(os.path.join(data_dir_root, "test.txt"), "w") as f: + f.write("\n".join(test_files)) + + if split == "train": + with open(os.path.join(data_dir_root, "train.txt"), "r") as f: + self.image_files = f.read().splitlines() + self.depth_files = [r.replace("/rgb/", "/depth/").replace( + "rgb_", "depth_").replace(".jpg", ".png") for r in self.image_files] + elif split == "test": + with open(os.path.join(data_dir_root, "test.txt"), "r") as f: + self.image_files = f.read().splitlines() + self.depth_files = [r.replace("/rgb/", "/depth/").replace( + "rgb_", "depth_").replace(".jpg", ".png") for r in self.image_files] + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = Image.open(image_path) + # depth = Image.open(depth_path) + depth = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR | + cv2.IMREAD_ANYDEPTH) / 100.0 # cm to m + depth = Image.fromarray(depth) + # print("dpeth min max", depth.min(), depth.max()) + + # print(np.shape(image)) + # print(np.shape(depth)) + + if self.do_kb_crop: + if idx == 0: + print("Using KB input crop") + height = image.height + width = image.width + top_margin = int(height - 352) + left_margin = int((width - 1216) / 2) + depth = depth.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + image = image.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + # uv = uv[:, top_margin:top_margin + 352, left_margin:left_margin + 1216] + + image = np.asarray(image, dtype=np.float32) / 255.0 + # depth = np.asarray(depth, dtype=np.uint16) /1. + depth = np.asarray(depth, dtype=np.float32) / 1. + depth[depth > 80] = -1 + + depth = depth[..., None] + sample = dict(image=image, depth=depth) + + # return sample + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_vkitti2_loader(data_dir_root, batch_size=1, **kwargs): + dataset = VKITTI2(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) + + +if __name__ == "__main__": + loader = get_vkitti2_loader( + data_dir_root="/home/bhatsf/shortcuts/datasets/vkitti2") + print("Total files", len(loader.dataset)) + for i, sample in enumerate(loader): + print(sample["image"].shape) + print(sample["depth"].shape) + print(sample["dataset"]) + print(sample['depth'].min(), sample['depth'].max()) + if i > 5: + break diff --git a/third_party/flux/annotator/zoe/zoedepth/models/__init__.py b/third_party/flux/annotator/zoe/zoedepth/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/__pycache__/__init__.cpython-310.pyc b/third_party/flux/annotator/zoe/zoedepth/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8dd295c5dd382ad6f1bcdb8d107542445b735f3 Binary files /dev/null and b/third_party/flux/annotator/zoe/zoedepth/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/flux/annotator/zoe/zoedepth/models/__pycache__/depth_model.cpython-310.pyc b/third_party/flux/annotator/zoe/zoedepth/models/__pycache__/depth_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b7fee4dca09e328b61bb9fa7340976ab8a5fc69 Binary files /dev/null and b/third_party/flux/annotator/zoe/zoedepth/models/__pycache__/depth_model.cpython-310.pyc differ diff --git a/third_party/flux/annotator/zoe/zoedepth/models/__pycache__/model_io.cpython-310.pyc b/third_party/flux/annotator/zoe/zoedepth/models/__pycache__/model_io.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e895dc522c8c06810e1570cc2b6fd1ad57765ed Binary files /dev/null and b/third_party/flux/annotator/zoe/zoedepth/models/__pycache__/model_io.cpython-310.pyc differ diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/__init__.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/__pycache__/__init__.cpython-310.pyc b/third_party/flux/annotator/zoe/zoedepth/models/base_models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79fc9497b958c7fe764fd880ee01bd2b6c10b5e0 Binary files /dev/null and b/third_party/flux/annotator/zoe/zoedepth/models/base_models/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/__pycache__/midas.cpython-310.pyc b/third_party/flux/annotator/zoe/zoedepth/models/base_models/__pycache__/midas.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12e191db7efb3a64abafe0622278a0bf87747aa6 Binary files /dev/null and b/third_party/flux/annotator/zoe/zoedepth/models/base_models/__pycache__/midas.cpython-310.pyc differ diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas.py new file mode 100644 index 0000000000000000000000000000000000000000..ee660bc93d44c28efe8d8c674e715ea2ecb4c183 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas.py @@ -0,0 +1,379 @@ +# MIT License +import os + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn +import numpy as np +from torchvision.transforms import Normalize + + +def denormalize(x): + """Reverses the imagenet normalization applied to the input. + + Args: + x (torch.Tensor - shape(N,3,H,W)): input tensor + + Returns: + torch.Tensor - shape(N,3,H,W): Denormalized input + """ + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device) + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device) + return x * std + mean + +def get_activation(name, bank): + def hook(model, input, output): + bank[name] = output + return hook + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + ): + """Init. + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + print("Params passed to Resize transform:") + print("\twidth: ", width) + print("\theight: ", height) + print("\tresize_target: ", resize_target) + print("\tkeep_aspect_ratio: ", keep_aspect_ratio) + print("\tensure_multiple_of: ", ensure_multiple_of) + print("\tresize_method: ", resize_method) + + self.__width = width + self.__height = height + + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, x): + width, height = self.get_size(*x.shape[-2:][::-1]) + return nn.functional.interpolate(x, (height, width), mode='bilinear', align_corners=True) + +class PrepForMidas(object): + def __init__(self, resize_mode="minimal", keep_aspect_ratio=True, img_size=384, do_resize=True): + if isinstance(img_size, int): + img_size = (img_size, img_size) + net_h, net_w = img_size + self.normalization = Normalize( + mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + self.resizer = Resize(net_w, net_h, keep_aspect_ratio=keep_aspect_ratio, ensure_multiple_of=32, resize_method=resize_mode) \ + if do_resize else nn.Identity() + + def __call__(self, x): + return self.normalization(self.resizer(x)) + + +class MidasCore(nn.Module): + def __init__(self, midas, trainable=False, fetch_features=True, layer_names=('out_conv', 'l4_rn', 'r4', 'r3', 'r2', 'r1'), freeze_bn=False, keep_aspect_ratio=True, + img_size=384, **kwargs): + """Midas Base model used for multi-scale feature extraction. + + Args: + midas (torch.nn.Module): Midas model. + trainable (bool, optional): Train midas model. Defaults to False. + fetch_features (bool, optional): Extract multi-scale features. Defaults to True. + layer_names (tuple, optional): Layers used for feature extraction. Order = (head output features, last layer features, ...decoder features). Defaults to ('out_conv', 'l4_rn', 'r4', 'r3', 'r2', 'r1'). + freeze_bn (bool, optional): Freeze BatchNorm. Generally results in better finetuning performance. Defaults to False. + keep_aspect_ratio (bool, optional): Keep the aspect ratio of input images while resizing. Defaults to True. + img_size (int, tuple, optional): Input resolution. Defaults to 384. + """ + super().__init__() + self.core = midas + self.output_channels = None + self.core_out = {} + self.trainable = trainable + self.fetch_features = fetch_features + # midas.scratch.output_conv = nn.Identity() + self.handles = [] + # self.layer_names = ['out_conv','l4_rn', 'r4', 'r3', 'r2', 'r1'] + self.layer_names = layer_names + + self.set_trainable(trainable) + self.set_fetch_features(fetch_features) + + self.prep = PrepForMidas(keep_aspect_ratio=keep_aspect_ratio, + img_size=img_size, do_resize=kwargs.get('do_resize', True)) + + if freeze_bn: + self.freeze_bn() + + def set_trainable(self, trainable): + self.trainable = trainable + if trainable: + self.unfreeze() + else: + self.freeze() + return self + + def set_fetch_features(self, fetch_features): + self.fetch_features = fetch_features + if fetch_features: + if len(self.handles) == 0: + self.attach_hooks(self.core) + else: + self.remove_hooks() + return self + + def freeze(self): + for p in self.parameters(): + p.requires_grad = False + self.trainable = False + return self + + def unfreeze(self): + for p in self.parameters(): + p.requires_grad = True + self.trainable = True + return self + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + return self + + def forward(self, x, denorm=False, return_rel_depth=False): + with torch.no_grad(): + if denorm: + x = denormalize(x) + x = self.prep(x) + # print("Shape after prep: ", x.shape) + + with torch.set_grad_enabled(self.trainable): + + # print("Input size to Midascore", x.shape) + rel_depth = self.core(x) + # print("Output from midas shape", rel_depth.shape) + if not self.fetch_features: + return rel_depth + out = [self.core_out[k] for k in self.layer_names] + + if return_rel_depth: + return rel_depth, out + return out + + def get_rel_pos_params(self): + for name, p in self.core.pretrained.named_parameters(): + if "relative_position" in name: + yield p + + def get_enc_params_except_rel_pos(self): + for name, p in self.core.pretrained.named_parameters(): + if "relative_position" not in name: + yield p + + def freeze_encoder(self, freeze_rel_pos=False): + if freeze_rel_pos: + for p in self.core.pretrained.parameters(): + p.requires_grad = False + else: + for p in self.get_enc_params_except_rel_pos(): + p.requires_grad = False + return self + + def attach_hooks(self, midas): + if len(self.handles) > 0: + self.remove_hooks() + if "out_conv" in self.layer_names: + self.handles.append(list(midas.scratch.output_conv.children())[ + 3].register_forward_hook(get_activation("out_conv", self.core_out))) + if "r4" in self.layer_names: + self.handles.append(midas.scratch.refinenet4.register_forward_hook( + get_activation("r4", self.core_out))) + if "r3" in self.layer_names: + self.handles.append(midas.scratch.refinenet3.register_forward_hook( + get_activation("r3", self.core_out))) + if "r2" in self.layer_names: + self.handles.append(midas.scratch.refinenet2.register_forward_hook( + get_activation("r2", self.core_out))) + if "r1" in self.layer_names: + self.handles.append(midas.scratch.refinenet1.register_forward_hook( + get_activation("r1", self.core_out))) + if "l4_rn" in self.layer_names: + self.handles.append(midas.scratch.layer4_rn.register_forward_hook( + get_activation("l4_rn", self.core_out))) + + return self + + def remove_hooks(self): + for h in self.handles: + h.remove() + return self + + def __del__(self): + self.remove_hooks() + + def set_output_channels(self, model_type): + self.output_channels = MIDAS_SETTINGS[model_type] + + @staticmethod + def build(midas_model_type="DPT_BEiT_L_384", train_midas=False, use_pretrained_midas=True, fetch_features=False, freeze_bn=True, force_keep_ar=False, force_reload=False, **kwargs): + if midas_model_type not in MIDAS_SETTINGS: + raise ValueError( + f"Invalid model type: {midas_model_type}. Must be one of {list(MIDAS_SETTINGS.keys())}") + if "img_size" in kwargs: + kwargs = MidasCore.parse_img_size(kwargs) + img_size = kwargs.pop("img_size", [384, 384]) + print("img_size", img_size) + midas_path = os.path.join(os.path.dirname(__file__), 'midas_repo') + midas = torch.hub.load(midas_path, midas_model_type, + pretrained=use_pretrained_midas, force_reload=force_reload, source='local') + kwargs.update({'keep_aspect_ratio': force_keep_ar}) + midas_core = MidasCore(midas, trainable=train_midas, fetch_features=fetch_features, + freeze_bn=freeze_bn, img_size=img_size, **kwargs) + midas_core.set_output_channels(midas_model_type) + return midas_core + + @staticmethod + def build_from_config(config): + return MidasCore.build(**config) + + @staticmethod + def parse_img_size(config): + assert 'img_size' in config + if isinstance(config['img_size'], str): + assert "," in config['img_size'], "img_size should be a string with comma separated img_size=H,W" + config['img_size'] = list(map(int, config['img_size'].split(","))) + assert len( + config['img_size']) == 2, "img_size should be a string with comma separated img_size=H,W" + elif isinstance(config['img_size'], int): + config['img_size'] = [config['img_size'], config['img_size']] + else: + assert isinstance(config['img_size'], list) and len( + config['img_size']) == 2, "img_size should be a list of H,W" + return config + + +nchannels2models = { + tuple([256]*5): ["DPT_BEiT_L_384", "DPT_BEiT_L_512", "DPT_BEiT_B_384", "DPT_SwinV2_L_384", "DPT_SwinV2_B_384", "DPT_SwinV2_T_256", "DPT_Large", "DPT_Hybrid"], + (512, 256, 128, 64, 64): ["MiDaS_small"] +} + +# Model name to number of output channels +MIDAS_SETTINGS = {m: k for k, v in nchannels2models.items() + for m in v + } diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/.gitignore b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..a13c80028de3d297de4a3f09cee1b20759acc006 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/.gitignore @@ -0,0 +1,110 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +*.png +*.pfm +*.jpg +*.jpeg +*.pt \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/Dockerfile b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..466bc94ba3128ea9cbe4bde82bd2fd1fc9daa8af --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/Dockerfile @@ -0,0 +1,29 @@ +# enables cuda support in docker +FROM nvidia/cuda:10.2-cudnn7-runtime-ubuntu18.04 + +# install python 3.6, pip and requirements for opencv-python +# (see https://github.com/NVIDIA/nvidia-docker/issues/864) +RUN apt-get update && apt-get -y install \ + python3 \ + python3-pip \ + libsm6 \ + libxext6 \ + libxrender-dev \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# install python dependencies +RUN pip3 install --upgrade pip +RUN pip3 install torch~=1.8 torchvision opencv-python-headless~=3.4 timm + +# copy inference code +WORKDIR /opt/MiDaS +COPY ./midas ./midas +COPY ./*.py ./ + +# download model weights so the docker image can be used offline +RUN cd weights && {curl -OL https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt; cd -; } +RUN python3 run.py --model_type dpt_hybrid; exit 0 + +# entrypoint (dont forget to mount input and output directories) +CMD python3 run.py --model_type dpt_hybrid diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/LICENSE b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..277b5c11be103f028a8d10985139f1da10c2f08e --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/README.md b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9568ea71c755b6938ee5482ba9f09be722e75943 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/README.md @@ -0,0 +1,259 @@ +## Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer + +This repository contains code to compute depth from a single image. It accompanies our [paper](https://arxiv.org/abs/1907.01341v3): + +>Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer +René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun + + +and our [preprint](https://arxiv.org/abs/2103.13413): + +> Vision Transformers for Dense Prediction +> René Ranftl, Alexey Bochkovskiy, Vladlen Koltun + + +MiDaS was trained on up to 12 datasets (ReDWeb, DIML, Movies, MegaDepth, WSVD, TartanAir, HRWSI, ApolloScape, BlendedMVS, IRS, KITTI, NYU Depth V2) with +multi-objective optimization. +The original model that was trained on 5 datasets (`MIX 5` in the paper) can be found [here](https://github.com/isl-org/MiDaS/releases/tag/v2). +The figure below shows an overview of the different MiDaS models; the bubble size scales with number of parameters. + +![](figures/Improvement_vs_FPS.png) + +### Setup + +1) Pick one or more models and download the corresponding weights to the `weights` folder: + +MiDaS 3.1 +- For highest quality: [dpt_beit_large_512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) +- For moderately less quality, but better speed-performance trade-off: [dpt_swin2_large_384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt) +- For embedded devices: [dpt_swin2_tiny_256](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt), [dpt_levit_224](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt) +- For inference on Intel CPUs, OpenVINO may be used for the small legacy model: openvino_midas_v21_small [.xml](https://github.com/isl-org/MiDaS/releases/download/v3_1/openvino_midas_v21_small_256.xml), [.bin](https://github.com/isl-org/MiDaS/releases/download/v3_1/openvino_midas_v21_small_256.bin) + +MiDaS 3.0: Legacy transformer models [dpt_large_384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt) and [dpt_hybrid_384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt) + +MiDaS 2.1: Legacy convolutional models [midas_v21_384](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt) and [midas_v21_small_256](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt) + +1) Set up dependencies: + + ```shell + conda env create -f environment.yaml + conda activate midas-py310 + ``` + +#### optional + +For the Next-ViT model, execute + +```shell +git submodule add https://github.com/isl-org/Next-ViT midas/external/next_vit +``` + +For the OpenVINO model, install + +```shell +pip install openvino +``` + +### Usage + +1) Place one or more input images in the folder `input`. + +2) Run the model with + + ```shell + python run.py --model_type --input_path input --output_path output + ``` + where `````` is chosen from [dpt_beit_large_512](#model_type), [dpt_beit_large_384](#model_type), + [dpt_beit_base_384](#model_type), [dpt_swin2_large_384](#model_type), [dpt_swin2_base_384](#model_type), + [dpt_swin2_tiny_256](#model_type), [dpt_swin_large_384](#model_type), [dpt_next_vit_large_384](#model_type), + [dpt_levit_224](#model_type), [dpt_large_384](#model_type), [dpt_hybrid_384](#model_type), + [midas_v21_384](#model_type), [midas_v21_small_256](#model_type), [openvino_midas_v21_small_256](#model_type). + +3) The resulting depth maps are written to the `output` folder. + +#### optional + +1) By default, the inference resizes the height of input images to the size of a model to fit into the encoder. This + size is given by the numbers in the model names of the [accuracy table](#accuracy). Some models do not only support a single + inference height but a range of different heights. Feel free to explore different heights by appending the extra + command line argument `--height`. Unsupported height values will throw an error. Note that using this argument may + decrease the model accuracy. +2) By default, the inference keeps the aspect ratio of input images when feeding them into the encoder if this is + supported by a model (all models except for Swin, Swin2, LeViT). In order to resize to a square resolution, + disregarding the aspect ratio while preserving the height, use the command line argument `--square`. + +#### via Camera + + If you want the input images to be grabbed from the camera and shown in a window, leave the input and output paths + away and choose a model type as shown above: + + ```shell + python run.py --model_type --side + ``` + + The argument `--side` is optional and causes both the input RGB image and the output depth map to be shown + side-by-side for comparison. + +#### via Docker + +1) Make sure you have installed Docker and the + [NVIDIA Docker runtime](https://github.com/NVIDIA/nvidia-docker/wiki/Installation-\(Native-GPU-Support\)). + +2) Build the Docker image: + + ```shell + docker build -t midas . + ``` + +3) Run inference: + + ```shell + docker run --rm --gpus all -v $PWD/input:/opt/MiDaS/input -v $PWD/output:/opt/MiDaS/output -v $PWD/weights:/opt/MiDaS/weights midas + ``` + + This command passes through all of your NVIDIA GPUs to the container, mounts the + `input` and `output` directories and then runs the inference. + +#### via PyTorch Hub + +The pretrained model is also available on [PyTorch Hub](https://pytorch.org/hub/intelisl_midas_v2/) + +#### via TensorFlow or ONNX + +See [README](https://github.com/isl-org/MiDaS/tree/master/tf) in the `tf` subdirectory. + +Currently only supports MiDaS v2.1. + + +#### via Mobile (iOS / Android) + +See [README](https://github.com/isl-org/MiDaS/tree/master/mobile) in the `mobile` subdirectory. + +#### via ROS1 (Robot Operating System) + +See [README](https://github.com/isl-org/MiDaS/tree/master/ros) in the `ros` subdirectory. + +Currently only supports MiDaS v2.1. DPT-based models to be added. + + +### Accuracy + +We provide a **zero-shot error** $\epsilon_d$ which is evaluated for 6 different datasets +(see [paper](https://arxiv.org/abs/1907.01341v3)). **Lower error values are better**. +$\color{green}{\textsf{Overall model quality is represented by the improvement}}$ ([Imp.](#improvement)) with respect to +MiDaS 3.0 DPTL-384. The models are grouped by the height used for inference, whereas the square training resolution is given by +the numbers in the model names. The table also shows the **number of parameters** (in millions) and the +**frames per second** for inference at the training resolution (for GPU RTX 3090): + +| MiDaS Model | DIW
WHDR | Eth3d
AbsRel | Sintel
AbsRel | TUM
δ1 | KITTI
δ1 | NYUv2
δ1 | $\color{green}{\textsf{Imp.}}$
% | Par.
M | FPS
  | +|-----------------------------------------------------------------------------------------------------------------------|-------------------------:|-----------------------------:|------------------------------:|-------------------------:|-------------------------:|-------------------------:|-------------------------------------------------:|----------------------:|--------------------------:| +| **Inference height 512** | | | | | | | | | | +| [v3.1 BEiTL-512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) | 0.1137 | 0.0659 | 0.2366 | **6.13** | 11.56* | **1.86*** | $\color{green}{\textsf{19}}$ | **345** | **5.7** | +| [v3.1 BEiTL-512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt)$\tiny{\square}$ | **0.1121** | **0.0614** | **0.2090** | 6.46 | **5.00*** | 1.90* | $\color{green}{\textsf{34}}$ | **345** | **5.7** | +| | | | | | | | | | | +| **Inference height 384** | | | | | | | | | | +| [v3.1 BEiTL-512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) | 0.1245 | 0.0681 | **0.2176** | **6.13** | 6.28* | **2.16*** | $\color{green}{\textsf{28}}$ | 345 | 12 | +| [v3.1 Swin2L-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt)$\tiny{\square}$ | 0.1106 | 0.0732 | 0.2442 | 8.87 | **5.84*** | 2.92* | $\color{green}{\textsf{22}}$ | 213 | 41 | +| [v3.1 Swin2B-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt)$\tiny{\square}$ | 0.1095 | 0.0790 | 0.2404 | 8.93 | 5.97* | 3.28* | $\color{green}{\textsf{22}}$ | 102 | 39 | +| [v3.1 SwinL-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin_large_384.pt)$\tiny{\square}$ | 0.1126 | 0.0853 | 0.2428 | 8.74 | 6.60* | 3.34* | $\color{green}{\textsf{17}}$ | 213 | 49 | +| [v3.1 BEiTL-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt) | 0.1239 | **0.0667** | 0.2545 | 7.17 | 9.84* | 2.21* | $\color{green}{\textsf{17}}$ | 344 | 13 | +| [v3.1 Next-ViTL-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_next_vit_large_384.pt) | **0.1031** | 0.0954 | 0.2295 | 9.21 | 6.89* | 3.47* | $\color{green}{\textsf{16}}$ | **72** | 30 | +| [v3.1 BEiTB-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt) | 0.1159 | 0.0967 | 0.2901 | 9.88 | 26.60* | 3.91* | $\color{green}{\textsf{-31}}$ | 112 | 31 | +| [v3.0 DPTL-384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt) | 0.1082 | 0.0888 | 0.2697 | 9.97 | 8.46 | 8.32 | $\color{green}{\textsf{0}}$ | 344 | **61** | +| [v3.0 DPTH-384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt) | 0.1106 | 0.0934 | 0.2741 | 10.89 | 11.56 | 8.69 | $\color{green}{\textsf{-10}}$ | 123 | 50 | +| [v2.1 Large384](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt) | 0.1295 | 0.1155 | 0.3285 | 12.51 | 16.08 | 8.71 | $\color{green}{\textsf{-32}}$ | 105 | 47 | +| | | | | | | | | | | +| **Inference height 256** | | | | | | | | | | +| [v3.1 Swin2T-256](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt)$\tiny{\square}$ | **0.1211** | **0.1106** | **0.2868** | **13.43** | **10.13*** | **5.55*** | $\color{green}{\textsf{-11}}$ | 42 | 64 | +| [v2.1 Small256](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt) | 0.1344 | 0.1344 | 0.3370 | 14.53 | 29.27 | 13.43 | $\color{green}{\textsf{-76}}$ | **21** | **90** | +| | | | | | | | | | | +| **Inference height 224** | | | | | | | | | | +| [v3.1 LeViT224](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt)$\tiny{\square}$ | **0.1314** | **0.1206** | **0.3148** | **18.21** | **15.27*** | **8.64*** | $\color{green}{\textsf{-40}}$ | **51** | **73** | + +* No zero-shot error, because models are also trained on KITTI and NYU Depth V2\ +$\square$ Validation performed at **square resolution**, either because the transformer encoder backbone of a model +does not support non-square resolutions (Swin, Swin2, LeViT) or for comparison with these models. All other +validations keep the aspect ratio. A difference in resolution limits the comparability of the zero-shot error and the +improvement, because these quantities are averages over the pixels of an image and do not take into account the +advantage of more details due to a higher resolution.\ +Best values per column and same validation height in bold + +#### Improvement + +The improvement in the above table is defined as the relative zero-shot error with respect to MiDaS v3.0 +DPTL-384 and averaging over the datasets. So, if $\epsilon_d$ is the zero-shot error for dataset $d$, then +the $\color{green}{\textsf{improvement}}$ is given by $100(1-(1/6)\sum_d\epsilon_d/\epsilon_{d,\rm{DPT_{L-384}}})$%. + +Note that the improvements of 10% for MiDaS v2.0 → v2.1 and 21% for MiDaS v2.1 → v3.0 are not visible from the +improvement column (Imp.) in the table but would require an evaluation with respect to MiDaS v2.1 Large384 +and v2.0 Large384 respectively instead of v3.0 DPTL-384. + +### Depth map comparison + +Zoom in for better visibility +![](figures/Comparison.png) + +### Speed on Camera Feed + +Test configuration +- Windows 10 +- 11th Gen Intel Core i7-1185G7 3.00GHz +- 16GB RAM +- Camera resolution 640x480 +- openvino_midas_v21_small_256 + +Speed: 22 FPS + +### Changelog + +* [Dec 2022] Released MiDaS v3.1: + - New models based on 5 different types of transformers ([BEiT](https://arxiv.org/pdf/2106.08254.pdf), [Swin2](https://arxiv.org/pdf/2111.09883.pdf), [Swin](https://arxiv.org/pdf/2103.14030.pdf), [Next-ViT](https://arxiv.org/pdf/2207.05501.pdf), [LeViT](https://arxiv.org/pdf/2104.01136.pdf)) + - Training datasets extended from 10 to 12, including also KITTI and NYU Depth V2 using [BTS](https://github.com/cleinc/bts) split + - Best model, BEiTLarge 512, with resolution 512x512, is on average about [28% more accurate](#Accuracy) than MiDaS v3.0 + - Integrated live depth estimation from camera feed +* [Sep 2021] Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/DPT-Large). +* [Apr 2021] Released MiDaS v3.0: + - New models based on [Dense Prediction Transformers](https://arxiv.org/abs/2103.13413) are on average [21% more accurate](#Accuracy) than MiDaS v2.1 + - Additional models can be found [here](https://github.com/isl-org/DPT) +* [Nov 2020] Released MiDaS v2.1: + - New model that was trained on 10 datasets and is on average about [10% more accurate](#Accuracy) than [MiDaS v2.0](https://github.com/isl-org/MiDaS/releases/tag/v2) + - New light-weight model that achieves [real-time performance](https://github.com/isl-org/MiDaS/tree/master/mobile) on mobile platforms. + - Sample applications for [iOS](https://github.com/isl-org/MiDaS/tree/master/mobile/ios) and [Android](https://github.com/isl-org/MiDaS/tree/master/mobile/android) + - [ROS package](https://github.com/isl-org/MiDaS/tree/master/ros) for easy deployment on robots +* [Jul 2020] Added TensorFlow and ONNX code. Added [online demo](http://35.202.76.57/). +* [Dec 2019] Released new version of MiDaS - the new model is significantly more accurate and robust +* [Jul 2019] Initial release of MiDaS ([Link](https://github.com/isl-org/MiDaS/releases/tag/v1)) + +### Citation + +Please cite our paper if you use this code or any of the models: +``` +@ARTICLE {Ranftl2022, + author = "Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun", + title = "Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-Shot Cross-Dataset Transfer", + journal = "IEEE Transactions on Pattern Analysis and Machine Intelligence", + year = "2022", + volume = "44", + number = "3" +} +``` + +If you use a DPT-based model, please also cite: + +``` +@article{Ranftl2021, + author = {Ren\'{e} Ranftl and Alexey Bochkovskiy and Vladlen Koltun}, + title = {Vision Transformers for Dense Prediction}, + journal = {ICCV}, + year = {2021}, +} +``` + +### Acknowledgements + +Our work builds on and uses code from [timm](https://github.com/rwightman/pytorch-image-models) and [Next-ViT](https://github.com/bytedance/Next-ViT). +We'd like to thank the authors for making these libraries available. + +### License + +MIT License diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/environment.yaml b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b9abe5693b9e0de56b7d20728f4d0e6333c5822d --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/environment.yaml @@ -0,0 +1,16 @@ +name: midas-py310 +channels: + - pytorch + - defaults +dependencies: + - nvidia::cudatoolkit=11.7 + - python=3.10.8 + - pytorch::pytorch=1.13.0 + - torchvision=0.14.0 + - pip=22.3.1 + - numpy=1.23.4 + - pip: + - opencv-python==4.6.0.66 + - imutils==0.5.4 + - timm==0.6.12 + - einops==0.6.0 \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/hubconf.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/hubconf.py new file mode 100644 index 0000000000000000000000000000000000000000..0d638be5151c4e305daff0c47d1ea3fc8066377d --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/hubconf.py @@ -0,0 +1,435 @@ +dependencies = ["torch"] + +import torch + +from midas.dpt_depth import DPTDepthModel +from midas.midas_net import MidasNet +from midas.midas_net_custom import MidasNet_small + +def DPT_BEiT_L_512(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_BEiT_L_512 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="beitl16_512", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_BEiT_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_BEiT_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="beitl16_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_BEiT_B_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_BEiT_B_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="beitb16_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_SwinV2_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_SwinV2_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swin2l24_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_SwinV2_B_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_SwinV2_B_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swin2b24_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_SwinV2_T_256(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_SwinV2_T_256 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swin2t16_256", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Swin_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_Swin_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swinl12_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Next_ViT_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_Next_ViT_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="next_vit_large_6m", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_next_vit_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_LeViT_224(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_LeViT_224 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="levit_384", + non_negative=True, + head_features_1=64, + head_features_2=8, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Large(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT-Large model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="vitl16_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Hybrid(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT-Hybrid model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="vitb_rn50_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def MiDaS(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS v2.1 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = MidasNet() + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def MiDaS_small(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS v2.1 small model for monocular depth estimation on resource-constrained devices + pretrained (bool): load pretrained weights into model + """ + + model = MidasNet_small(None, features=64, backbone="efficientnet_lite3", exportable=True, non_negative=True, blocks={'expand': True}) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + + +def transforms(): + import cv2 + from torchvision.transforms import Compose + from midas.transforms import Resize, NormalizeImage, PrepareForNet + from midas import transforms + + transforms.default_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 384, + 384, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="upper_bound", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.small_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 256, + 256, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="upper_bound", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.dpt_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 384, + 384, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.beit512_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 512, + 512, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.swin384_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 384, + 384, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.swin256_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 256, + 256, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.levit_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 224, + 224, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + return transforms diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/input/.placeholder b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/input/.placeholder new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/beit.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/beit.py new file mode 100644 index 0000000000000000000000000000000000000000..e616dfd4026f448f9e22d35c6ad8b0028732acb9 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/beit.py @@ -0,0 +1,196 @@ +import timm +import torch +import types + +import numpy as np +import torch.nn.functional as F + +from .utils import forward_adapted_unflatten, make_backbone_default +from timm.models.beit import gen_relative_position_index +from torch.utils.checkpoint import checkpoint +from typing import Optional + + +def forward_beit(pretrained, x): + return forward_adapted_unflatten(pretrained, x, "forward_features") + + +def patch_embed_forward(self, x): + """ + Modification of timm.models.layers.patch_embed.py: PatchEmbed.forward to support arbitrary window sizes. + """ + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + return x + + +def _get_rel_pos_bias(self, window_size): + """ + Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes. + """ + old_height = 2 * self.window_size[0] - 1 + old_width = 2 * self.window_size[1] - 1 + + new_height = 2 * window_size[0] - 1 + new_width = 2 * window_size[1] - 1 + + old_relative_position_bias_table = self.relative_position_bias_table + + old_num_relative_distance = self.num_relative_distance + new_num_relative_distance = new_height * new_width + 3 + + old_sub_table = old_relative_position_bias_table[:old_num_relative_distance - 3] + + old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2) + new_sub_table = F.interpolate(old_sub_table, size=(new_height, new_width), mode="bilinear") + new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1) + + new_relative_position_bias_table = torch.cat( + [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3:]]) + + key = str(window_size[1]) + "," + str(window_size[0]) + if key not in self.relative_position_indices.keys(): + self.relative_position_indices[key] = gen_relative_position_index(window_size) + + relative_position_bias = new_relative_position_bias_table[ + self.relative_position_indices[key].view(-1)].view( + window_size[0] * window_size[1] + 1, + window_size[0] * window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + return relative_position_bias.unsqueeze(0) + + +def attention_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None): + """ + Modification of timm.models.beit.py: Attention.forward to support arbitrary window sizes. + """ + B, N, C = x.shape + + qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + window_size = tuple(np.array(resolution) // 16) + attn = attn + self._get_rel_pos_bias(window_size) + if shared_rel_pos_bias is not None: + attn = attn + shared_rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def block_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None): + """ + Modification of timm.models.beit.py: Block.forward to support arbitrary window sizes. + """ + if self.gamma_1 is None: + x = x + self.drop_path1(self.attn(self.norm1(x), resolution, shared_rel_pos_bias=shared_rel_pos_bias)) + x = x + self.drop_path2(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), resolution, + shared_rel_pos_bias=shared_rel_pos_bias)) + x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +def beit_forward_features(self, x): + """ + Modification of timm.models.beit.py: Beit.forward_features to support arbitrary window sizes. + """ + resolution = x.shape[2:] + + x = self.patch_embed(x) + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias) + else: + x = blk(x, resolution, shared_rel_pos_bias=rel_pos_bias) + x = self.norm(x) + return x + + +def _make_beit_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[0, 4, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, + start_index_readout=1, +): + backbone = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index, + start_index_readout) + + backbone.model.patch_embed.forward = types.MethodType(patch_embed_forward, backbone.model.patch_embed) + backbone.model.forward_features = types.MethodType(beit_forward_features, backbone.model) + + for block in backbone.model.blocks: + attn = block.attn + attn._get_rel_pos_bias = types.MethodType(_get_rel_pos_bias, attn) + attn.forward = types.MethodType(attention_forward, attn) + attn.relative_position_indices = {} + + block.forward = types.MethodType(block_forward, block) + + return backbone + + +def _make_pretrained_beitl16_512(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("beit_large_patch16_512", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks is None else hooks + + features = [256, 512, 1024, 1024] + + return _make_beit_backbone( + model, + features=features, + size=[512, 512], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_beitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("beit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks is None else hooks + return _make_beit_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_beitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("beit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks is None else hooks + return _make_beit_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + ) diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/levit.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/levit.py new file mode 100644 index 0000000000000000000000000000000000000000..6d023a98702a0451806d26f33f8bccf931814f10 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/levit.py @@ -0,0 +1,106 @@ +import timm +import torch +import torch.nn as nn +import numpy as np + +from .utils import activations, get_activation, Transpose + + +def forward_levit(pretrained, x): + pretrained.model.forward_features(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + + layer_1 = pretrained.act_postprocess1(layer_1) + layer_2 = pretrained.act_postprocess2(layer_2) + layer_3 = pretrained.act_postprocess3(layer_3) + + return layer_1, layer_2, layer_3 + + +def _make_levit_backbone( + model, + hooks=[3, 11, 21], + patch_grid=[14, 14] +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + + pretrained.activations = activations + + patch_grid_size = np.array(patch_grid, dtype=int) + + pretrained.act_postprocess1 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) + ) + pretrained.act_postprocess2 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist())) + ) + pretrained.act_postprocess3 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist())) + ) + + return pretrained + + +class ConvTransposeNorm(nn.Sequential): + """ + Modification of + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm + such that ConvTranspose2d is used instead of Conv2d. + """ + + def __init__( + self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1, + groups=1, bn_weight_init=1): + super().__init__() + self.add_module('c', + nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False)) + self.add_module('bn', nn.BatchNorm2d(out_chs)) + + nn.init.constant_(self.bn.weight, bn_weight_init) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 + m = nn.ConvTranspose2d( + w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, + padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +def stem_b4_transpose(in_chs, out_chs, activation): + """ + Modification of + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: stem_b16 + such that ConvTranspose2d is used instead of Conv2d and stem is also reduced to the half. + """ + return nn.Sequential( + ConvTransposeNorm(in_chs, out_chs, 3, 2, 1), + activation(), + ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1), + activation()) + + +def _make_pretrained_levit_384(pretrained, hooks=None): + model = timm.create_model("levit_384", pretrained=pretrained) + + hooks = [3, 11, 21] if hooks == None else hooks + return _make_levit_backbone( + model, + hooks=hooks + ) diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/next_vit.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/next_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..8afdd8b743b5ab023a359dc3b721e601b1a40d11 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/next_vit.py @@ -0,0 +1,39 @@ +import timm + +import torch.nn as nn + +from pathlib import Path +from .utils import activations, forward_default, get_activation + +from ..external.next_vit.classification.nextvit import * + + +def forward_next_vit(pretrained, x): + return forward_default(pretrained, x, "forward") + + +def _make_next_vit_backbone( + model, + hooks=[2, 6, 36, 39], +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.features[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.features[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.features[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.features[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + return pretrained + + +def _make_pretrained_next_vit_large_6m(hooks=None): + model = timm.create_model("nextvit_large") + + hooks = [2, 6, 36, 39] if hooks == None else hooks + return _make_next_vit_backbone( + model, + hooks=hooks, + ) diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin.py new file mode 100644 index 0000000000000000000000000000000000000000..f8c71367e3e78b087f80b2ab3e2f495a9c372f1a --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin.py @@ -0,0 +1,13 @@ +import timm + +from .swin_common import _make_swin_backbone + + +def _make_pretrained_swinl12_384(pretrained, hooks=None): + model = timm.create_model("swin_large_patch4_window12_384", pretrained=pretrained) + + hooks = [1, 1, 17, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks + ) diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin2.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin2.py new file mode 100644 index 0000000000000000000000000000000000000000..ce4c8f1d6fc1807a207dc6b9a261c6f7b14a87a3 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin2.py @@ -0,0 +1,34 @@ +import timm + +from .swin_common import _make_swin_backbone + + +def _make_pretrained_swin2l24_384(pretrained, hooks=None): + model = timm.create_model("swinv2_large_window12to24_192to384_22kft1k", pretrained=pretrained) + + hooks = [1, 1, 17, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks + ) + + +def _make_pretrained_swin2b24_384(pretrained, hooks=None): + model = timm.create_model("swinv2_base_window12to24_192to384_22kft1k", pretrained=pretrained) + + hooks = [1, 1, 17, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks + ) + + +def _make_pretrained_swin2t16_256(pretrained, hooks=None): + model = timm.create_model("swinv2_tiny_window16_256", pretrained=pretrained) + + hooks = [1, 1, 5, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks, + patch_grid=[64, 64] + ) diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin_common.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin_common.py new file mode 100644 index 0000000000000000000000000000000000000000..94d63d408f18511179d90b3ac6f697385d1e556d --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin_common.py @@ -0,0 +1,52 @@ +import torch + +import torch.nn as nn +import numpy as np + +from .utils import activations, forward_default, get_activation, Transpose + + +def forward_swin(pretrained, x): + return forward_default(pretrained, x) + + +def _make_swin_backbone( + model, + hooks=[1, 1, 17, 1], + patch_grid=[96, 96] +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.layers[0].blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.layers[1].blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.layers[2].blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.layers[3].blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + if hasattr(model, "patch_grid"): + used_patch_grid = model.patch_grid + else: + used_patch_grid = patch_grid + + patch_grid_size = np.array(used_patch_grid, dtype=int) + + pretrained.act_postprocess1 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) + ) + pretrained.act_postprocess2 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((patch_grid_size // 2).tolist())) + ) + pretrained.act_postprocess3 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((patch_grid_size // 4).tolist())) + ) + pretrained.act_postprocess4 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((patch_grid_size // 8).tolist())) + ) + + return pretrained diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/utils.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0558899dddcfccec5f01a764d4f21738eb612149 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/utils.py @@ -0,0 +1,249 @@ +import torch + +import torch.nn as nn + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index:] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index:] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:]) + features = torch.cat((x[:, self.start_index:], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def forward_default(pretrained, x, function_name="forward_features"): + exec(f"pretrained.model.{function_name}(x)") + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + if hasattr(pretrained, "act_postprocess1"): + layer_1 = pretrained.act_postprocess1(layer_1) + if hasattr(pretrained, "act_postprocess2"): + layer_2 = pretrained.act_postprocess2(layer_2) + if hasattr(pretrained, "act_postprocess3"): + layer_3 = pretrained.act_postprocess3(layer_3) + if hasattr(pretrained, "act_postprocess4"): + layer_4 = pretrained.act_postprocess4(layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def forward_adapted_unflatten(pretrained, x, function_name="forward_features"): + b, c, h, w = x.shape + + exec(f"glob = pretrained.model.{function_name}(x)") + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3: len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3: len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3: len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3: len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def make_backbone_default( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, + start_index_readout=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index_readout) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + return pretrained diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/vit.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..413f9693bd4548342280e329c9128c1a52cea920 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/vit.py @@ -0,0 +1,221 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + +from .utils import (activations, forward_adapted_unflatten, get_activation, get_readout_oper, + make_backbone_default, Transpose) + + +def forward_vit(pretrained, x): + return forward_adapted_unflatten(pretrained, x, "forward_flex") + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index:], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + if self.no_embed_class: + x = x + pos_embed + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + if not self.no_embed_class: + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, + start_index_readout=1, +): + pretrained = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index, + start_index_readout) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + patch_size=[16, 16], + number_stages=2, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + used_number_stages = 0 if use_vit_only else number_stages + for s in range(used_number_stages): + pretrained.model.patch_embed.backbone.stages[s].register_forward_hook( + get_activation(str(s + 1)) + ) + for s in range(used_number_stages, 4): + pretrained.model.blocks[hooks[s]].register_forward_hook(get_activation(str(s + 1))) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + for s in range(used_number_stages): + value = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity()) + exec(f"pretrained.act_postprocess{s + 1}=value") + for s in range(used_number_stages, 4): + if s < number_stages: + final_layer = nn.ConvTranspose2d( + in_channels=features[s], + out_channels=features[s], + kernel_size=4 // (2 ** s), + stride=4 // (2 ** s), + padding=0, + bias=True, + dilation=1, + groups=1, + ) + elif s > number_stages: + final_layer = nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ) + else: + final_layer = None + + layers = [ + readout_oper[s], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[s], + kernel_size=1, + stride=1, + padding=0, + ), + ] + if final_layer is not None: + layers.append(final_layer) + + value = nn.Sequential(*layers) + exec(f"pretrained.act_postprocess{s + 1}=value") + + pretrained.model.start_index = start_index + pretrained.model.patch_size = patch_size + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/base_model.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/blocks.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..6d87a00680bb6ed9a6d7c3043ea30a1e90361794 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/blocks.py @@ -0,0 +1,439 @@ +import torch +import torch.nn as nn + +from .backbones.beit import ( + _make_pretrained_beitl16_512, + _make_pretrained_beitl16_384, + _make_pretrained_beitb16_384, + forward_beit, +) +from .backbones.swin_common import ( + forward_swin, +) +from .backbones.swin2 import ( + _make_pretrained_swin2l24_384, + _make_pretrained_swin2b24_384, + _make_pretrained_swin2t16_256, +) +from .backbones.swin import ( + _make_pretrained_swinl12_384, +) +from .backbones.levit import ( + _make_pretrained_levit_384, + forward_levit, +) +from .backbones.vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, + use_vit_only=False, use_readout="ignore", in_features=[96, 256, 512, 1024]): + if backbone == "beitl16_512": + pretrained = _make_pretrained_beitl16_512( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # BEiT_512-L (backbone) + elif backbone == "beitl16_384": + pretrained = _make_pretrained_beitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # BEiT_384-L (backbone) + elif backbone == "beitb16_384": + pretrained = _make_pretrained_beitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # BEiT_384-B (backbone) + elif backbone == "swin2l24_384": + pretrained = _make_pretrained_swin2l24_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [192, 384, 768, 1536], features, groups=groups, expand=expand + ) # Swin2-L/12to24 (backbone) + elif backbone == "swin2b24_384": + pretrained = _make_pretrained_swin2b24_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [128, 256, 512, 1024], features, groups=groups, expand=expand + ) # Swin2-B/12to24 (backbone) + elif backbone == "swin2t16_256": + pretrained = _make_pretrained_swin2t16_256( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # Swin2-T/16 (backbone) + elif backbone == "swinl12_384": + pretrained = _make_pretrained_swinl12_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [192, 384, 768, 1536], features, groups=groups, expand=expand + ) # Swin-L/12 (backbone) + elif backbone == "next_vit_large_6m": + from .backbones.next_vit import _make_pretrained_next_vit_large_6m + pretrained = _make_pretrained_next_vit_large_6m(hooks=hooks) + scratch = _make_scratch( + in_features, features, groups=groups, expand=expand + ) # Next-ViT-L on ImageNet-1K-6M (backbone) + elif backbone == "levit_384": + pretrained = _make_pretrained_levit_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [384, 512, 768], features, groups=groups, expand=expand + ) # LeViT 384 (backbone) + elif backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + if len(in_shape) >= 4: + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size=size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate( + output, **modifier, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/dpt_depth.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/dpt_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..3129d09cb43a7c79b23916236991fabbedb78f55 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/dpt_depth.py @@ -0,0 +1,166 @@ +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_beit, + forward_swin, + forward_levit, + forward_vit, +) +from .backbones.levit import stem_b4_transpose +from timm.models.layers import get_act_layer + + +def _make_fusion_block(features, use_bn, size = None): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + **kwargs + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + # For the Swin, Swin 2, LeViT and Next-ViT Transformers, the hierarchical architectures prevent setting the + # hooks freely. Instead, the hooks have to be chosen according to the ranges specified in the comments. + hooks = { + "beitl16_512": [5, 11, 17, 23], + "beitl16_384": [5, 11, 17, 23], + "beitb16_384": [2, 5, 8, 11], + "swin2l24_384": [1, 1, 17, 1], # Allowed ranges: [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "swin2b24_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "swin2t16_256": [1, 1, 5, 1], # [0, 1], [0, 1], [ 0, 5], [ 0, 1] + "swinl12_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "next_vit_large_6m": [2, 6, 36, 39], # [0, 2], [3, 6], [ 7, 36], [37, 39] + "levit_384": [3, 11, 21], # [0, 3], [6, 11], [14, 21] + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + }[backbone] + + if "next_vit" in backbone: + in_features = { + "next_vit_large_6m": [96, 256, 512, 1024], + }[backbone] + else: + in_features = None + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks, + use_readout=readout, + in_features=in_features, + ) + + self.number_layers = len(hooks) if hooks is not None else 4 + size_refinenet3 = None + self.scratch.stem_transpose = None + + if "beit" in backbone: + self.forward_transformer = forward_beit + elif "swin" in backbone: + self.forward_transformer = forward_swin + elif "next_vit" in backbone: + from .backbones.next_vit import forward_next_vit + self.forward_transformer = forward_next_vit + elif "levit" in backbone: + self.forward_transformer = forward_levit + size_refinenet3 = 7 + self.scratch.stem_transpose = stem_b4_transpose(256, 128, get_act_layer("hard_swish")) + else: + self.forward_transformer = forward_vit + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn, size_refinenet3) + if self.number_layers >= 4: + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layers = self.forward_transformer(self.pretrained, x) + if self.number_layers == 3: + layer_1, layer_2, layer_3 = layers + else: + layer_1, layer_2, layer_3, layer_4 = layers + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + if self.number_layers >= 4: + layer_4_rn = self.scratch.layer4_rn(layer_4) + + if self.number_layers == 3: + path_3 = self.scratch.refinenet3(layer_3_rn, size=layer_2_rn.shape[2:]) + else: + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + if self.scratch.stem_transpose is not None: + path_1 = self.scratch.stem_transpose(path_1) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + head_features_1 = kwargs["head_features_1"] if "head_features_1" in kwargs else features + head_features_2 = kwargs["head_features_2"] if "head_features_2" in kwargs else 32 + kwargs.pop("head_features_1", None) + kwargs.pop("head_features_2", None) + + head = nn.Sequential( + nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net.py new file mode 100644 index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net_custom.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/model_loader.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/model_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..f1cd1f2d43054bfd3d650587c7b2ed35f1347c9e --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/model_loader.py @@ -0,0 +1,242 @@ +import cv2 +import torch + +from midas.dpt_depth import DPTDepthModel +from midas.midas_net import MidasNet +from midas.midas_net_custom import MidasNet_small +from midas.transforms import Resize, NormalizeImage, PrepareForNet + +from torchvision.transforms import Compose + +default_models = { + "dpt_beit_large_512": "weights/dpt_beit_large_512.pt", + "dpt_beit_large_384": "weights/dpt_beit_large_384.pt", + "dpt_beit_base_384": "weights/dpt_beit_base_384.pt", + "dpt_swin2_large_384": "weights/dpt_swin2_large_384.pt", + "dpt_swin2_base_384": "weights/dpt_swin2_base_384.pt", + "dpt_swin2_tiny_256": "weights/dpt_swin2_tiny_256.pt", + "dpt_swin_large_384": "weights/dpt_swin_large_384.pt", + "dpt_next_vit_large_384": "weights/dpt_next_vit_large_384.pt", + "dpt_levit_224": "weights/dpt_levit_224.pt", + "dpt_large_384": "weights/dpt_large_384.pt", + "dpt_hybrid_384": "weights/dpt_hybrid_384.pt", + "midas_v21_384": "weights/midas_v21_384.pt", + "midas_v21_small_256": "weights/midas_v21_small_256.pt", + "openvino_midas_v21_small_256": "weights/openvino_midas_v21_small_256.xml", +} + + +def load_model(device, model_path, model_type="dpt_large_384", optimize=True, height=None, square=False): + """Load the specified network. + + Args: + device (device): the torch device used + model_path (str): path to saved model + model_type (str): the type of the model to be loaded + optimize (bool): optimize the model to half-integer on CUDA? + height (int): inference encoder image height + square (bool): resize to a square resolution? + + Returns: + The loaded network, the transform which prepares images as input to the network and the dimensions of the + network input + """ + if "openvino" in model_type: + from openvino.runtime import Core + + keep_aspect_ratio = not square + + if model_type == "dpt_beit_large_512": + model = DPTDepthModel( + path=model_path, + backbone="beitl16_512", + non_negative=True, + ) + net_w, net_h = 512, 512 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_beit_large_384": + model = DPTDepthModel( + path=model_path, + backbone="beitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_beit_base_384": + model = DPTDepthModel( + path=model_path, + backbone="beitb16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin2_large_384": + model = DPTDepthModel( + path=model_path, + backbone="swin2l24_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin2_base_384": + model = DPTDepthModel( + path=model_path, + backbone="swin2b24_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin2_tiny_256": + model = DPTDepthModel( + path=model_path, + backbone="swin2t16_256", + non_negative=True, + ) + net_w, net_h = 256, 256 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin_large_384": + model = DPTDepthModel( + path=model_path, + backbone="swinl12_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_next_vit_large_384": + model = DPTDepthModel( + path=model_path, + backbone="next_vit_large_6m", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + # We change the notation from dpt_levit_224 (MiDaS notation) to levit_384 (timm notation) here, where the 224 refers + # to the resolution 224x224 used by LeViT and 384 is the first entry of the embed_dim, see _cfg and model_cfgs of + # https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/levit.py + # (commit id: 927f031293a30afb940fff0bee34b85d9c059b0e) + elif model_type == "dpt_levit_224": + model = DPTDepthModel( + path=model_path, + backbone="levit_384", + non_negative=True, + head_features_1=64, + head_features_2=8, + ) + net_w, net_h = 224, 224 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_large_384": + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid_384": + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21_384": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small_256": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "openvino_midas_v21_small_256": + ie = Core() + uncompiled_model = ie.read_model(model=model_path) + model = ie.compile_model(uncompiled_model, "CPU") + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + if not "openvino" in model_type: + print("Model loaded, number of parameters = {:.0f}M".format(sum(p.numel() for p in model.parameters()) / 1e6)) + else: + print("Model loaded, optimized with OpenVINO") + + if "openvino" in model_type: + keep_aspect_ratio = False + + if height is not None: + net_w, net_h = height, height + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=keep_aspect_ratio, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + if not "openvino" in model_type: + model.eval() + + if optimize and (device == torch.device("cuda")): + if not "openvino" in model_type: + model = model.to(memory_format=torch.channels_last) + model = model.half() + else: + print("Error: OpenVINO models are already optimized. No optimization to half-float possible.") + exit() + + if not "openvino" in model_type: + model.to(device) + + return model, transform, net_w, net_h diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/transforms.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/README.md b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/README.md new file mode 100644 index 0000000000000000000000000000000000000000..45c18f7f0bfe40c0db373e8a94716867705f5827 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/README.md @@ -0,0 +1,70 @@ +## Mobile version of MiDaS for iOS / Android - Monocular Depth Estimation + +### Accuracy + +* Old small model - ResNet50 default-decoder 384x384 +* New small model - EfficientNet-Lite3 small-decoder 256x256 + +**Zero-shot error** (the lower - the better): + +| Model | DIW WHDR | Eth3d AbsRel | Sintel AbsRel | Kitti δ>1.25 | NyuDepthV2 δ>1.25 | TUM δ>1.25 | +|---|---|---|---|---|---|---| +| Old small model 384x384 | **0.1248** | 0.1550 | **0.3300** | **21.81** | 15.73 | 17.00 | +| New small model 256x256 | 0.1344 | **0.1344** | 0.3370 | 29.27 | **13.43** | **14.53** | +| Relative improvement, % | -8 % | **+13 %** | -2 % | -34 % | **+15 %** | **+15 %** | + +None of Train/Valid/Test subsets of datasets (DIW, Eth3d, Sintel, Kitti, NyuDepthV2, TUM) were not involved in Training or Fine Tuning. + +### Inference speed (FPS) on iOS / Android + +**Frames Per Second** (the higher - the better): + +| Model | iPhone CPU | iPhone GPU | iPhone NPU | OnePlus8 CPU | OnePlus8 GPU | OnePlus8 NNAPI | +|---|---|---|---|---|---|---| +| Old small model 384x384 | 0.6 | N/A | N/A | 0.45 | 0.50 | 0.50 | +| New small model 256x256 | 8 | 22 | **30** | 6 | **22** | 4 | +| SpeedUp, X times | **12.8x** | - | - | **13.2x** | **44x** | **8x** | + +N/A - run-time error (no data available) + + +#### Models: + +* Old small model - ResNet50 default-decoder 1x384x384x3, batch=1 FP32 (converters: Pytorch -> ONNX - [onnx_tf](https://github.com/onnx/onnx-tensorflow) -> (saved model) PB -> TFlite) + + (Trained on datasets: RedWeb, MegaDepth, WSVD, 3D Movies, DIML indoor) + +* New small model - EfficientNet-Lite3 small-decoder 1x256x256x3, batch=1 FP32 (custom converter: Pytorch -> TFlite) + + (Trained on datasets: RedWeb, MegaDepth, WSVD, 3D Movies, DIML indoor, HRWSI, IRS, TartanAir, BlendedMVS, ApolloScape) + +#### Frameworks for training and conversions: +``` +pip install torch==1.6.0 torchvision==0.7.0 +pip install tf-nightly-gpu==2.5.0.dev20201031 tensorflow-addons==0.11.2 numpy==1.18.0 +git clone --depth 1 --branch v1.6.0 https://github.com/onnx/onnx-tensorflow +``` + +#### SoC - OS - Library: + +* iPhone 11 (A13 Bionic) - iOS 13.7 - TensorFlowLiteSwift 0.0.1-nightly +* OnePlus 8 (Snapdragon 865) - Andoird 10 - org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly + + +### Citation + +This repository contains code to compute depth from a single image. It accompanies our [paper](https://arxiv.org/abs/1907.01341v3): + +>Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer +René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun + +Please cite our paper if you use this code or any of the models: +``` +@article{Ranftl2020, + author = {Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun}, + title = {Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer}, + journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)}, + year = {2020}, +} +``` + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/.gitignore b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..2fbe357549c64ae2966d5c3013a9179427b7b396 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/.gitignore @@ -0,0 +1,13 @@ +*.iml +.gradle +/local.properties +/.idea/libraries +/.idea/modules.xml +/.idea/workspace.xml +.DS_Store +/build +/captures +.externalNativeBuild + +/.gradle/ +/.idea/ \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/EXPLORE_THE_CODE.md b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/EXPLORE_THE_CODE.md new file mode 100644 index 0000000000000000000000000000000000000000..72014bdfa2cd701a6453debbc8e53fcc15c0a5dc --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/EXPLORE_THE_CODE.md @@ -0,0 +1,414 @@ +# TensorFlow Lite Android image classification example + +This document walks through the code of a simple Android mobile application that +demonstrates +[image classification](https://www.tensorflow.org/lite/models/image_classification/overview) +using the device camera. + +## Explore the code + +We're now going to walk through the most important parts of the sample code. + +### Get camera input + +This mobile application gets the camera input using the functions defined in the +file +[`CameraActivity.java`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java). +This file depends on +[`AndroidManifest.xml`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/app/src/main/AndroidManifest.xml) +to set the camera orientation. + +`CameraActivity` also contains code to capture user preferences from the UI and +make them available to other classes via convenience methods. + +```java +model = Model.valueOf(modelSpinner.getSelectedItem().toString().toUpperCase()); +device = Device.valueOf(deviceSpinner.getSelectedItem().toString()); +numThreads = Integer.parseInt(threadsTextView.getText().toString().trim()); +``` + +### Classifier + +This Image Classification Android reference app demonstrates two implementation +solutions, +[`lib_task_api`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_task_api) +that leverages the out-of-box API from the +[TensorFlow Lite Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/image_classifier), +and +[`lib_support`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_support) +that creates the custom inference pipleline using the +[TensorFlow Lite Support Library](https://www.tensorflow.org/lite/inference_with_metadata/lite_support). + +Both solutions implement the file `Classifier.java` (see +[the one in lib_task_api](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java) +and +[the one in lib_support](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java)) +that contains most of the complex logic for processing the camera input and +running inference. + +Two subclasses of the `Classifier` exist, as in `ClassifierFloatMobileNet.java` +and `ClassifierQuantizedMobileNet.java`, which contain settings for both +floating point and +[quantized](https://www.tensorflow.org/lite/performance/post_training_quantization) +models. + +The `Classifier` class implements a static method, `create`, which is used to +instantiate the appropriate subclass based on the supplied model type (quantized +vs floating point). + +#### Using the TensorFlow Lite Task Library + +Inference can be done using just a few lines of code with the +[`ImageClassifier`](https://www.tensorflow.org/lite/inference_with_metadata/task_library/image_classifier) +in the TensorFlow Lite Task Library. + +##### Load model and create ImageClassifier + +`ImageClassifier` expects a model populated with the +[model metadata](https://www.tensorflow.org/lite/convert/metadata) and the label +file. See the +[model compatibility requirements](https://www.tensorflow.org/lite/inference_with_metadata/task_library/image_classifier#model_compatibility_requirements) +for more details. + +`ImageClassifierOptions` allows manipulation on various inference options, such +as setting the maximum number of top scored results to return using +`setMaxResults(MAX_RESULTS)`, and setting the score threshold using +`setScoreThreshold(scoreThreshold)`. + +```java +// Create the ImageClassifier instance. +ImageClassifierOptions options = + ImageClassifierOptions.builder().setMaxResults(MAX_RESULTS).build(); +imageClassifier = ImageClassifier.createFromFileAndOptions(activity, + getModelPath(), options); +``` + +`ImageClassifier` currently does not support configuring delegates and +multithread, but those are on our roadmap. Please stay tuned! + +##### Run inference + +`ImageClassifier` contains builtin logic to preprocess the input image, such as +rotating and resizing an image. Processing options can be configured through +`ImageProcessingOptions`. In the following example, input images are rotated to +the up-right angle and cropped to the center as the model expects a square input +(`224x224`). See the +[Java doc of `ImageClassifier`](https://github.com/tensorflow/tflite-support/blob/195b574f0aa9856c618b3f1ad87bd185cddeb657/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java#L22) +for more details about how the underlying image processing is performed. + +```java +TensorImage inputImage = TensorImage.fromBitmap(bitmap); +int width = bitmap.getWidth(); +int height = bitmap.getHeight(); +int cropSize = min(width, height); +ImageProcessingOptions imageOptions = + ImageProcessingOptions.builder() + .setOrientation(getOrientation(sensorOrientation)) + // Set the ROI to the center of the image. + .setRoi( + new Rect( + /*left=*/ (width - cropSize) / 2, + /*top=*/ (height - cropSize) / 2, + /*right=*/ (width + cropSize) / 2, + /*bottom=*/ (height + cropSize) / 2)) + .build(); + +List results = imageClassifier.classify(inputImage, + imageOptions); +``` + +The output of `ImageClassifier` is a list of `Classifications` instance, where +each `Classifications` element is a single head classification result. All the +demo models are single head models, therefore, `results` only contains one +`Classifications` object. Use `Classifications.getCategories()` to get a list of +top-k categories as specified with `MAX_RESULTS`. Each `Category` object +contains the srting label and the score of that category. + +To match the implementation of +[`lib_support`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_support), +`results` is converted into `List` in the method, +`getRecognitions`. + +#### Using the TensorFlow Lite Support Library + +##### Load model and create interpreter + +To perform inference, we need to load a model file and instantiate an +`Interpreter`. This happens in the constructor of the `Classifier` class, along +with loading the list of class labels. Information about the device type and +number of threads is used to configure the `Interpreter` via the +`Interpreter.Options` instance passed into its constructor. Note that if a GPU, +DSP (Digital Signal Processor) or NPU (Neural Processing Unit) is available, a +[`Delegate`](https://www.tensorflow.org/lite/performance/delegates) can be used +to take full advantage of these hardware. + +Please note that there are performance edge cases and developers are adviced to +test with a representative set of devices prior to production. + +```java +protected Classifier(Activity activity, Device device, int numThreads) throws + IOException { + tfliteModel = FileUtil.loadMappedFile(activity, getModelPath()); + switch (device) { + case NNAPI: + nnApiDelegate = new NnApiDelegate(); + tfliteOptions.addDelegate(nnApiDelegate); + break; + case GPU: + gpuDelegate = new GpuDelegate(); + tfliteOptions.addDelegate(gpuDelegate); + break; + case CPU: + break; + } + tfliteOptions.setNumThreads(numThreads); + tflite = new Interpreter(tfliteModel, tfliteOptions); + labels = FileUtil.loadLabels(activity, getLabelPath()); +... +``` + +For Android devices, we recommend pre-loading and memory mapping the model file +to offer faster load times and reduce the dirty pages in memory. The method +`FileUtil.loadMappedFile` does this, returning a `MappedByteBuffer` containing +the model. + +The `MappedByteBuffer` is passed into the `Interpreter` constructor, along with +an `Interpreter.Options` object. This object can be used to configure the +interpreter, for example by setting the number of threads (`.setNumThreads(1)`) +or enabling [NNAPI](https://developer.android.com/ndk/guides/neuralnetworks) +(`.addDelegate(nnApiDelegate)`). + +##### Pre-process bitmap image + +Next in the `Classifier` constructor, we take the input camera bitmap image, +convert it to a `TensorImage` format for efficient processing and pre-process +it. The steps are shown in the private 'loadImage' method: + +```java +/** Loads input image, and applys preprocessing. */ +private TensorImage loadImage(final Bitmap bitmap, int sensorOrientation) { + // Loads bitmap into a TensorImage. + image.load(bitmap); + + // Creates processor for the TensorImage. + int cropSize = Math.min(bitmap.getWidth(), bitmap.getHeight()); + int numRoration = sensorOrientation / 90; + ImageProcessor imageProcessor = + new ImageProcessor.Builder() + .add(new ResizeWithCropOrPadOp(cropSize, cropSize)) + .add(new ResizeOp(imageSizeX, imageSizeY, ResizeMethod.BILINEAR)) + .add(new Rot90Op(numRoration)) + .add(getPreprocessNormalizeOp()) + .build(); + return imageProcessor.process(inputImageBuffer); +} +``` + +The pre-processing is largely the same for quantized and float models with one +exception: Normalization. + +In `ClassifierFloatMobileNet`, the normalization parameters are defined as: + +```java +private static final float IMAGE_MEAN = 127.5f; +private static final float IMAGE_STD = 127.5f; +``` + +In `ClassifierQuantizedMobileNet`, normalization is not required. Thus the +nomalization parameters are defined as: + +```java +private static final float IMAGE_MEAN = 0.0f; +private static final float IMAGE_STD = 1.0f; +``` + +##### Allocate output object + +Initiate the output `TensorBuffer` for the output of the model. + +```java +/** Output probability TensorBuffer. */ +private final TensorBuffer outputProbabilityBuffer; + +//... +// Get the array size for the output buffer from the TensorFlow Lite model file +int probabilityTensorIndex = 0; +int[] probabilityShape = + tflite.getOutputTensor(probabilityTensorIndex).shape(); // {1, 1001} +DataType probabilityDataType = + tflite.getOutputTensor(probabilityTensorIndex).dataType(); + +// Creates the output tensor and its processor. +outputProbabilityBuffer = + TensorBuffer.createFixedSize(probabilityShape, probabilityDataType); + +// Creates the post processor for the output probability. +probabilityProcessor = + new TensorProcessor.Builder().add(getPostprocessNormalizeOp()).build(); +``` + +For quantized models, we need to de-quantize the prediction with the NormalizeOp +(as they are all essentially linear transformation). For float model, +de-quantize is not required. But to uniform the API, de-quantize is added to +float model too. Mean and std are set to 0.0f and 1.0f, respectively. To be more +specific, + +In `ClassifierQuantizedMobileNet`, the normalized parameters are defined as: + +```java +private static final float PROBABILITY_MEAN = 0.0f; +private static final float PROBABILITY_STD = 255.0f; +``` + +In `ClassifierFloatMobileNet`, the normalized parameters are defined as: + +```java +private static final float PROBABILITY_MEAN = 0.0f; +private static final float PROBABILITY_STD = 1.0f; +``` + +##### Run inference + +Inference is performed using the following in `Classifier` class: + +```java +tflite.run(inputImageBuffer.getBuffer(), + outputProbabilityBuffer.getBuffer().rewind()); +``` + +##### Recognize image + +Rather than call `run` directly, the method `recognizeImage` is used. It accepts +a bitmap and sensor orientation, runs inference, and returns a sorted `List` of +`Recognition` instances, each corresponding to a label. The method will return a +number of results bounded by `MAX_RESULTS`, which is 3 by default. + +`Recognition` is a simple class that contains information about a specific +recognition result, including its `title` and `confidence`. Using the +post-processing normalization method specified, the confidence is converted to +between 0 and 1 of a given class being represented by the image. + +```java +/** Gets the label to probability map. */ +Map labeledProbability = + new TensorLabel(labels, + probabilityProcessor.process(outputProbabilityBuffer)) + .getMapWithFloatValue(); +``` + +A `PriorityQueue` is used for sorting. + +```java +/** Gets the top-k results. */ +private static List getTopKProbability( + Map labelProb) { + // Find the best classifications. + PriorityQueue pq = + new PriorityQueue<>( + MAX_RESULTS, + new Comparator() { + @Override + public int compare(Recognition lhs, Recognition rhs) { + // Intentionally reversed to put high confidence at the head of + // the queue. + return Float.compare(rhs.getConfidence(), lhs.getConfidence()); + } + }); + + for (Map.Entry entry : labelProb.entrySet()) { + pq.add(new Recognition("" + entry.getKey(), entry.getKey(), + entry.getValue(), null)); + } + + final ArrayList recognitions = new ArrayList<>(); + int recognitionsSize = Math.min(pq.size(), MAX_RESULTS); + for (int i = 0; i < recognitionsSize; ++i) { + recognitions.add(pq.poll()); + } + return recognitions; +} +``` + +### Display results + +The classifier is invoked and inference results are displayed by the +`processImage()` function in +[`ClassifierActivity.java`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java). + +`ClassifierActivity` is a subclass of `CameraActivity` that contains method +implementations that render the camera image, run classification, and display +the results. The method `processImage()` runs classification on a background +thread as fast as possible, rendering information on the UI thread to avoid +blocking inference and creating latency. + +```java +@Override +protected void processImage() { + rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, + previewHeight); + final int imageSizeX = classifier.getImageSizeX(); + final int imageSizeY = classifier.getImageSizeY(); + + runInBackground( + new Runnable() { + @Override + public void run() { + if (classifier != null) { + final long startTime = SystemClock.uptimeMillis(); + final List results = + classifier.recognizeImage(rgbFrameBitmap, sensorOrientation); + lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime; + LOGGER.v("Detect: %s", results); + + runOnUiThread( + new Runnable() { + @Override + public void run() { + showResultsInBottomSheet(results); + showFrameInfo(previewWidth + "x" + previewHeight); + showCropInfo(imageSizeX + "x" + imageSizeY); + showCameraResolution(imageSizeX + "x" + imageSizeY); + showRotationInfo(String.valueOf(sensorOrientation)); + showInference(lastProcessingTimeMs + "ms"); + } + }); + } + readyForNextImage(); + } + }); +} +``` + +Another important role of `ClassifierActivity` is to determine user preferences +(by interrogating `CameraActivity`), and instantiate the appropriately +configured `Classifier` subclass. This happens when the video feed begins (via +`onPreviewSizeChosen()`) and when options are changed in the UI (via +`onInferenceConfigurationChanged()`). + +```java +private void recreateClassifier(Model model, Device device, int numThreads) { + if (classifier != null) { + LOGGER.d("Closing classifier."); + classifier.close(); + classifier = null; + } + if (device == Device.GPU && model == Model.QUANTIZED) { + LOGGER.d("Not creating classifier: GPU doesn't support quantized models."); + runOnUiThread( + () -> { + Toast.makeText(this, "GPU does not yet supported quantized models.", + Toast.LENGTH_LONG) + .show(); + }); + return; + } + try { + LOGGER.d( + "Creating classifier (model=%s, device=%s, numThreads=%d)", model, + device, numThreads); + classifier = Classifier.create(this, model, device, numThreads); + } catch (IOException e) { + LOGGER.e(e, "Failed to create classifier."); + } +} +``` diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/LICENSE b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..6606ec028d1c629986e7019fe3564f5b4bfe425d --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Alexey + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/README.md b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/README.md new file mode 100644 index 0000000000000000000000000000000000000000..faf415eb27ccc1a62357718d1e0a9b8c746de4e8 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/README.md @@ -0,0 +1,21 @@ +# MiDaS on Android smartphone by using TensorFlow-lite (TFLite) + + +* Either use Android Studio for compilation. + +* Or use ready to install apk-file: + * Or use URL: https://i.diawi.com/CVb8a9 + * Or use QR-code: + +Scan QR-code or open URL -> Press `Install application` -> Press `Download` and wait for download -> Open -> Install -> Open -> Press: Allow MiDaS to take photo and video from the camera While using the APP + +![CVb8a9](https://user-images.githubusercontent.com/4096485/97727213-38552500-1ae1-11eb-8b76-4ea11216f76d.png) + +---- + +To use another model, you should convert it to `model_opt.tflite` and place it to the directory: `models\src\main\assets` + + +---- + +Original repository: https://github.com/isl-org/MiDaS diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/.gitignore b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..1ae74c6780c277d75fedfb7511ff51f69941b48b --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/.gitignore @@ -0,0 +1,3 @@ +/build + +/build/ \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/build.gradle b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..94e9886a55c7d54f71b424bb246c849dd6bd795d --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/build.gradle @@ -0,0 +1,56 @@ +apply plugin: 'com.android.application' + +android { + compileSdkVersion 28 + defaultConfig { + applicationId "org.tensorflow.lite.examples.classification" + minSdkVersion 21 + targetSdkVersion 28 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + } + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' + } + } + aaptOptions { + noCompress "tflite" + } + compileOptions { + sourceCompatibility = '1.8' + targetCompatibility = '1.8' + } + lintOptions { + abortOnError false + } + flavorDimensions "tfliteInference" + productFlavors { + // The TFLite inference is built using the TFLite Support library. + support { + dimension "tfliteInference" + } + // The TFLite inference is built using the TFLite Task library. + taskApi { + dimension "tfliteInference" + } + } + +} + +dependencies { + implementation fileTree(dir: 'libs', include: ['*.jar']) + supportImplementation project(":lib_support") + taskApiImplementation project(":lib_task_api") + implementation 'androidx.appcompat:appcompat:1.0.0' + implementation 'androidx.coordinatorlayout:coordinatorlayout:1.0.0' + implementation 'com.google.android.material:material:1.0.0' + + androidTestImplementation 'androidx.test.ext:junit:1.1.1' + androidTestImplementation 'com.google.truth:truth:1.0.1' + androidTestImplementation 'androidx.test:runner:1.2.0' + androidTestImplementation 'androidx.test:rules:1.1.0' +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/proguard-rules.pro b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/proguard-rules.pro new file mode 100644 index 0000000000000000000000000000000000000000..f1b424510da51fd82143bc74a0a801ae5a1e2fcd --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/assets/fox-mobilenet_v1_1.0_224_support.txt b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/assets/fox-mobilenet_v1_1.0_224_support.txt new file mode 100644 index 0000000000000000000000000000000000000000..bdfad31f9b3e694817025d8b8f2ca0b40aa436bb --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/assets/fox-mobilenet_v1_1.0_224_support.txt @@ -0,0 +1,3 @@ +red_fox 0.79403335 +kit_fox 0.16753247 +grey_fox 0.03619214 diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/assets/fox-mobilenet_v1_1.0_224_task_api.txt b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/assets/fox-mobilenet_v1_1.0_224_task_api.txt new file mode 100644 index 0000000000000000000000000000000000000000..3668ce54df0d1e57e31c58281d6085b83928f991 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/assets/fox-mobilenet_v1_1.0_224_task_api.txt @@ -0,0 +1,3 @@ +red_fox 0.85 +kit_fox 0.13 +grey_fox 0.02 diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/AndroidManifest.xml b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..3653d8799092492ebbb16c7c956eb50e3d404aa4 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/AndroidManifest.xml @@ -0,0 +1,5 @@ + + + + \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/org/tensorflow/lite/examples/classification/ClassifierTest.java b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/org/tensorflow/lite/examples/classification/ClassifierTest.java new file mode 100644 index 0000000000000000000000000000000000000000..0194132890aae659c2a70d33106306ed665b22e8 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/org/tensorflow/lite/examples/classification/ClassifierTest.java @@ -0,0 +1,121 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.lite.examples.classification; + +import static com.google.common.truth.Truth.assertThat; + +import android.content.res.AssetManager; +import android.graphics.Bitmap; +import android.graphics.BitmapFactory; +import android.util.Log; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import androidx.test.platform.app.InstrumentationRegistry; +import androidx.test.rule.ActivityTestRule; +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Scanner; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.tensorflow.lite.examples.classification.tflite.Classifier; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Model; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Recognition; + +/** Golden test for Image Classification Reference app. */ +@RunWith(AndroidJUnit4.class) +public class ClassifierTest { + + @Rule + public ActivityTestRule rule = + new ActivityTestRule<>(ClassifierActivity.class); + + private static final String[] INPUTS = {"fox.jpg"}; + private static final String[] GOLDEN_OUTPUTS_SUPPORT = {"fox-mobilenet_v1_1.0_224_support.txt"}; + private static final String[] GOLDEN_OUTPUTS_TASK = {"fox-mobilenet_v1_1.0_224_task_api.txt"}; + + @Test + public void classificationResultsShouldNotChange() throws IOException { + ClassifierActivity activity = rule.getActivity(); + Classifier classifier = Classifier.create(activity, Model.FLOAT_MOBILENET, Device.CPU, 1); + for (int i = 0; i < INPUTS.length; i++) { + String imageFileName = INPUTS[i]; + String goldenOutputFileName; + // TODO(b/169379396): investigate the impact of the resize algorithm on accuracy. + // This is a temporary workaround to set different golden rest results as the preprocessing + // of lib_support and lib_task_api are different. Will merge them once the above TODO is + // resolved. + if (Classifier.TAG.equals("ClassifierWithSupport")) { + goldenOutputFileName = GOLDEN_OUTPUTS_SUPPORT[i]; + } else { + goldenOutputFileName = GOLDEN_OUTPUTS_TASK[i]; + } + Bitmap input = loadImage(imageFileName); + List goldenOutput = loadRecognitions(goldenOutputFileName); + + List result = classifier.recognizeImage(input, 0); + Iterator goldenOutputIterator = goldenOutput.iterator(); + + for (Recognition actual : result) { + Assert.assertTrue(goldenOutputIterator.hasNext()); + Recognition expected = goldenOutputIterator.next(); + assertThat(actual.getTitle()).isEqualTo(expected.getTitle()); + assertThat(actual.getConfidence()).isWithin(0.01f).of(expected.getConfidence()); + } + } + } + + private static Bitmap loadImage(String fileName) { + AssetManager assetManager = + InstrumentationRegistry.getInstrumentation().getContext().getAssets(); + InputStream inputStream = null; + try { + inputStream = assetManager.open(fileName); + } catch (IOException e) { + Log.e("Test", "Cannot load image from assets"); + } + return BitmapFactory.decodeStream(inputStream); + } + + private static List loadRecognitions(String fileName) { + AssetManager assetManager = + InstrumentationRegistry.getInstrumentation().getContext().getAssets(); + InputStream inputStream = null; + try { + inputStream = assetManager.open(fileName); + } catch (IOException e) { + Log.e("Test", "Cannot load probability results from assets"); + } + Scanner scanner = new Scanner(inputStream); + List result = new ArrayList<>(); + while (scanner.hasNext()) { + String category = scanner.next(); + category = category.replace('_', ' '); + if (!scanner.hasNextFloat()) { + break; + } + float probability = scanner.nextFloat(); + Recognition recognition = new Recognition(null, category, probability, null); + result.add(recognition); + } + return result; + } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/AndroidManifest.xml b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..7a414d5176a117262dce56c2220e6b71791287de --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/AndroidManifest.xml @@ -0,0 +1,28 @@ + + + + + + + + + + + + + + + + + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java new file mode 100644 index 0000000000000000000000000000000000000000..d1eb26c862c04bf573ecc4eb127e7460f0b100fc --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java @@ -0,0 +1,717 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.lite.examples.classification; + +import android.Manifest; +import android.app.Fragment; +import android.content.Context; +import android.content.pm.PackageManager; +import android.graphics.Bitmap; +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Paint; +import android.graphics.RectF; +import android.hardware.Camera; +import android.hardware.camera2.CameraAccessException; +import android.hardware.camera2.CameraCharacteristics; +import android.hardware.camera2.CameraManager; +import android.hardware.camera2.params.StreamConfigurationMap; +import android.media.Image; +import android.media.Image.Plane; +import android.media.ImageReader; +import android.media.ImageReader.OnImageAvailableListener; +import android.os.Build; +import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; +import android.os.Trace; +import androidx.annotation.NonNull; +import androidx.annotation.UiThread; +import androidx.appcompat.app.AppCompatActivity; +import android.util.Size; +import android.view.Surface; +import android.view.TextureView; +import android.view.View; +import android.view.ViewTreeObserver; +import android.view.WindowManager; +import android.widget.AdapterView; +import android.widget.ImageView; +import android.widget.LinearLayout; +import android.widget.Spinner; +import android.widget.TextView; +import android.widget.Toast; +import com.google.android.material.bottomsheet.BottomSheetBehavior; +import java.nio.ByteBuffer; +import java.util.List; +import org.tensorflow.lite.examples.classification.env.ImageUtils; +import org.tensorflow.lite.examples.classification.env.Logger; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Model; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Recognition; + +public abstract class CameraActivity extends AppCompatActivity + implements OnImageAvailableListener, + Camera.PreviewCallback, + View.OnClickListener, + AdapterView.OnItemSelectedListener { + private static final Logger LOGGER = new Logger(); + + private static final int PERMISSIONS_REQUEST = 1; + + private static final String PERMISSION_CAMERA = Manifest.permission.CAMERA; + protected int previewWidth = 0; + protected int previewHeight = 0; + private Handler handler; + private HandlerThread handlerThread; + private boolean useCamera2API; + private boolean isProcessingFrame = false; + private byte[][] yuvBytes = new byte[3][]; + private int[] rgbBytes = null; + private int yRowStride; + private Runnable postInferenceCallback; + private Runnable imageConverter; + private LinearLayout bottomSheetLayout; + private LinearLayout gestureLayout; + private BottomSheetBehavior sheetBehavior; + protected TextView recognitionTextView, + recognition1TextView, + recognition2TextView, + recognitionValueTextView, + recognition1ValueTextView, + recognition2ValueTextView; + protected TextView frameValueTextView, + cropValueTextView, + cameraResolutionTextView, + rotationTextView, + inferenceTimeTextView; + protected ImageView bottomSheetArrowImageView; + private ImageView plusImageView, minusImageView; + private Spinner modelSpinner; + private Spinner deviceSpinner; + private TextView threadsTextView; + + //private Model model = Model.QUANTIZED_EFFICIENTNET; + //private Device device = Device.CPU; + private Model model = Model.FLOAT_EFFICIENTNET; + private Device device = Device.GPU; + private int numThreads = -1; + + @Override + protected void onCreate(final Bundle savedInstanceState) { + LOGGER.d("onCreate " + this); + super.onCreate(null); + getWindow().addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON); + + setContentView(R.layout.tfe_ic_activity_camera); + + if (hasPermission()) { + setFragment(); + } else { + requestPermission(); + } + + threadsTextView = findViewById(R.id.threads); + plusImageView = findViewById(R.id.plus); + minusImageView = findViewById(R.id.minus); + modelSpinner = findViewById(R.id.model_spinner); + deviceSpinner = findViewById(R.id.device_spinner); + bottomSheetLayout = findViewById(R.id.bottom_sheet_layout); + gestureLayout = findViewById(R.id.gesture_layout); + sheetBehavior = BottomSheetBehavior.from(bottomSheetLayout); + bottomSheetArrowImageView = findViewById(R.id.bottom_sheet_arrow); + + ViewTreeObserver vto = gestureLayout.getViewTreeObserver(); + vto.addOnGlobalLayoutListener( + new ViewTreeObserver.OnGlobalLayoutListener() { + @Override + public void onGlobalLayout() { + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.JELLY_BEAN) { + gestureLayout.getViewTreeObserver().removeGlobalOnLayoutListener(this); + } else { + gestureLayout.getViewTreeObserver().removeOnGlobalLayoutListener(this); + } + // int width = bottomSheetLayout.getMeasuredWidth(); + int height = gestureLayout.getMeasuredHeight(); + + sheetBehavior.setPeekHeight(height); + } + }); + sheetBehavior.setHideable(false); + + sheetBehavior.setBottomSheetCallback( + new BottomSheetBehavior.BottomSheetCallback() { + @Override + public void onStateChanged(@NonNull View bottomSheet, int newState) { + switch (newState) { + case BottomSheetBehavior.STATE_HIDDEN: + break; + case BottomSheetBehavior.STATE_EXPANDED: + { + bottomSheetArrowImageView.setImageResource(R.drawable.icn_chevron_down); + } + break; + case BottomSheetBehavior.STATE_COLLAPSED: + { + bottomSheetArrowImageView.setImageResource(R.drawable.icn_chevron_up); + } + break; + case BottomSheetBehavior.STATE_DRAGGING: + break; + case BottomSheetBehavior.STATE_SETTLING: + bottomSheetArrowImageView.setImageResource(R.drawable.icn_chevron_up); + break; + } + } + + @Override + public void onSlide(@NonNull View bottomSheet, float slideOffset) {} + }); + + recognitionTextView = findViewById(R.id.detected_item); + recognitionValueTextView = findViewById(R.id.detected_item_value); + recognition1TextView = findViewById(R.id.detected_item1); + recognition1ValueTextView = findViewById(R.id.detected_item1_value); + recognition2TextView = findViewById(R.id.detected_item2); + recognition2ValueTextView = findViewById(R.id.detected_item2_value); + + frameValueTextView = findViewById(R.id.frame_info); + cropValueTextView = findViewById(R.id.crop_info); + cameraResolutionTextView = findViewById(R.id.view_info); + rotationTextView = findViewById(R.id.rotation_info); + inferenceTimeTextView = findViewById(R.id.inference_info); + + modelSpinner.setOnItemSelectedListener(this); + deviceSpinner.setOnItemSelectedListener(this); + + plusImageView.setOnClickListener(this); + minusImageView.setOnClickListener(this); + + model = Model.valueOf(modelSpinner.getSelectedItem().toString().toUpperCase()); + device = Device.valueOf(deviceSpinner.getSelectedItem().toString()); + numThreads = Integer.parseInt(threadsTextView.getText().toString().trim()); + } + + protected int[] getRgbBytes() { + imageConverter.run(); + return rgbBytes; + } + + protected int getLuminanceStride() { + return yRowStride; + } + + protected byte[] getLuminance() { + return yuvBytes[0]; + } + + /** Callback for android.hardware.Camera API */ + @Override + public void onPreviewFrame(final byte[] bytes, final Camera camera) { + if (isProcessingFrame) { + LOGGER.w("Dropping frame!"); + return; + } + + try { + // Initialize the storage bitmaps once when the resolution is known. + if (rgbBytes == null) { + Camera.Size previewSize = camera.getParameters().getPreviewSize(); + previewHeight = previewSize.height; + previewWidth = previewSize.width; + rgbBytes = new int[previewWidth * previewHeight]; + onPreviewSizeChosen(new Size(previewSize.width, previewSize.height), 90); + } + } catch (final Exception e) { + LOGGER.e(e, "Exception!"); + return; + } + + isProcessingFrame = true; + yuvBytes[0] = bytes; + yRowStride = previewWidth; + + imageConverter = + new Runnable() { + @Override + public void run() { + ImageUtils.convertYUV420SPToARGB8888(bytes, previewWidth, previewHeight, rgbBytes); + } + }; + + postInferenceCallback = + new Runnable() { + @Override + public void run() { + camera.addCallbackBuffer(bytes); + isProcessingFrame = false; + } + }; + processImage(); + } + + /** Callback for Camera2 API */ + @Override + public void onImageAvailable(final ImageReader reader) { + // We need wait until we have some size from onPreviewSizeChosen + if (previewWidth == 0 || previewHeight == 0) { + return; + } + if (rgbBytes == null) { + rgbBytes = new int[previewWidth * previewHeight]; + } + try { + final Image image = reader.acquireLatestImage(); + + if (image == null) { + return; + } + + if (isProcessingFrame) { + image.close(); + return; + } + isProcessingFrame = true; + Trace.beginSection("imageAvailable"); + final Plane[] planes = image.getPlanes(); + fillBytes(planes, yuvBytes); + yRowStride = planes[0].getRowStride(); + final int uvRowStride = planes[1].getRowStride(); + final int uvPixelStride = planes[1].getPixelStride(); + + imageConverter = + new Runnable() { + @Override + public void run() { + ImageUtils.convertYUV420ToARGB8888( + yuvBytes[0], + yuvBytes[1], + yuvBytes[2], + previewWidth, + previewHeight, + yRowStride, + uvRowStride, + uvPixelStride, + rgbBytes); + } + }; + + postInferenceCallback = + new Runnable() { + @Override + public void run() { + image.close(); + isProcessingFrame = false; + } + }; + + processImage(); + } catch (final Exception e) { + LOGGER.e(e, "Exception!"); + Trace.endSection(); + return; + } + Trace.endSection(); + } + + @Override + public synchronized void onStart() { + LOGGER.d("onStart " + this); + super.onStart(); + } + + @Override + public synchronized void onResume() { + LOGGER.d("onResume " + this); + super.onResume(); + + handlerThread = new HandlerThread("inference"); + handlerThread.start(); + handler = new Handler(handlerThread.getLooper()); + } + + @Override + public synchronized void onPause() { + LOGGER.d("onPause " + this); + + handlerThread.quitSafely(); + try { + handlerThread.join(); + handlerThread = null; + handler = null; + } catch (final InterruptedException e) { + LOGGER.e(e, "Exception!"); + } + + super.onPause(); + } + + @Override + public synchronized void onStop() { + LOGGER.d("onStop " + this); + super.onStop(); + } + + @Override + public synchronized void onDestroy() { + LOGGER.d("onDestroy " + this); + super.onDestroy(); + } + + protected synchronized void runInBackground(final Runnable r) { + if (handler != null) { + handler.post(r); + } + } + + @Override + public void onRequestPermissionsResult( + final int requestCode, final String[] permissions, final int[] grantResults) { + super.onRequestPermissionsResult(requestCode, permissions, grantResults); + if (requestCode == PERMISSIONS_REQUEST) { + if (allPermissionsGranted(grantResults)) { + setFragment(); + } else { + requestPermission(); + } + } + } + + private static boolean allPermissionsGranted(final int[] grantResults) { + for (int result : grantResults) { + if (result != PackageManager.PERMISSION_GRANTED) { + return false; + } + } + return true; + } + + private boolean hasPermission() { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { + return checkSelfPermission(PERMISSION_CAMERA) == PackageManager.PERMISSION_GRANTED; + } else { + return true; + } + } + + private void requestPermission() { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { + if (shouldShowRequestPermissionRationale(PERMISSION_CAMERA)) { + Toast.makeText( + CameraActivity.this, + "Camera permission is required for this demo", + Toast.LENGTH_LONG) + .show(); + } + requestPermissions(new String[] {PERMISSION_CAMERA}, PERMISSIONS_REQUEST); + } + } + + // Returns true if the device supports the required hardware level, or better. + private boolean isHardwareLevelSupported( + CameraCharacteristics characteristics, int requiredLevel) { + int deviceLevel = characteristics.get(CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL); + if (deviceLevel == CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL_LEGACY) { + return requiredLevel == deviceLevel; + } + // deviceLevel is not LEGACY, can use numerical sort + return requiredLevel <= deviceLevel; + } + + private String chooseCamera() { + final CameraManager manager = (CameraManager) getSystemService(Context.CAMERA_SERVICE); + try { + for (final String cameraId : manager.getCameraIdList()) { + final CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId); + + // We don't use a front facing camera in this sample. + final Integer facing = characteristics.get(CameraCharacteristics.LENS_FACING); + if (facing != null && facing == CameraCharacteristics.LENS_FACING_FRONT) { + continue; + } + + final StreamConfigurationMap map = + characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP); + + if (map == null) { + continue; + } + + // Fallback to camera1 API for internal cameras that don't have full support. + // This should help with legacy situations where using the camera2 API causes + // distorted or otherwise broken previews. + useCamera2API = + (facing == CameraCharacteristics.LENS_FACING_EXTERNAL) + || isHardwareLevelSupported( + characteristics, CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL_FULL); + LOGGER.i("Camera API lv2?: %s", useCamera2API); + return cameraId; + } + } catch (CameraAccessException e) { + LOGGER.e(e, "Not allowed to access camera"); + } + + return null; + } + + protected void setFragment() { + String cameraId = chooseCamera(); + + Fragment fragment; + if (useCamera2API) { + CameraConnectionFragment camera2Fragment = + CameraConnectionFragment.newInstance( + new CameraConnectionFragment.ConnectionCallback() { + @Override + public void onPreviewSizeChosen(final Size size, final int rotation) { + previewHeight = size.getHeight(); + previewWidth = size.getWidth(); + CameraActivity.this.onPreviewSizeChosen(size, rotation); + } + }, + this, + getLayoutId(), + getDesiredPreviewFrameSize()); + + camera2Fragment.setCamera(cameraId); + fragment = camera2Fragment; + } else { + fragment = + new LegacyCameraConnectionFragment(this, getLayoutId(), getDesiredPreviewFrameSize()); + } + + getFragmentManager().beginTransaction().replace(R.id.container, fragment).commit(); + } + + protected void fillBytes(final Plane[] planes, final byte[][] yuvBytes) { + // Because of the variable row stride it's not possible to know in + // advance the actual necessary dimensions of the yuv planes. + for (int i = 0; i < planes.length; ++i) { + final ByteBuffer buffer = planes[i].getBuffer(); + if (yuvBytes[i] == null) { + LOGGER.d("Initializing buffer %d at size %d", i, buffer.capacity()); + yuvBytes[i] = new byte[buffer.capacity()]; + } + buffer.get(yuvBytes[i]); + } + } + + protected void readyForNextImage() { + if (postInferenceCallback != null) { + postInferenceCallback.run(); + } + } + + protected int getScreenOrientation() { + switch (getWindowManager().getDefaultDisplay().getRotation()) { + case Surface.ROTATION_270: + return 270; + case Surface.ROTATION_180: + return 180; + case Surface.ROTATION_90: + return 90; + default: + return 0; + } + } + + @UiThread + protected void showResultsInTexture(float[] img_array, int imageSizeX, int imageSizeY) { + float maxval = Float.NEGATIVE_INFINITY; + float minval = Float.POSITIVE_INFINITY; + for (float cur : img_array) { + maxval = Math.max(maxval, cur); + minval = Math.min(minval, cur); + } + float multiplier = 0; + if ((maxval - minval) > 0) multiplier = 255 / (maxval - minval); + + int[] img_normalized = new int[img_array.length]; + for (int i = 0; i < img_array.length; ++i) { + float val = (float) (multiplier * (img_array[i] - minval)); + img_normalized[i] = (int) val; + } + + + + TextureView textureView = findViewById(R.id.textureView3); + //AutoFitTextureView textureView = (AutoFitTextureView) findViewById(R.id.texture); + + if(textureView.isAvailable()) { + int width = imageSizeX; + int height = imageSizeY; + + Canvas canvas = textureView.lockCanvas(); + canvas.drawColor(Color.BLUE); + Paint paint = new Paint(); + paint.setStyle(Paint.Style.FILL); + paint.setARGB(255, 150, 150, 150); + + int canvas_size = Math.min(canvas.getWidth(), canvas.getHeight()); + + Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.RGB_565); + + for (int ii = 0; ii < width; ii++) //pass the screen pixels in 2 directions + { + for (int jj = 0; jj < height; jj++) { + //int val = img_normalized[ii + jj * width]; + int index = (width - ii - 1) + (height - jj - 1) * width; + if(index < img_array.length) { + int val = img_normalized[index]; + bitmap.setPixel(ii, jj, Color.rgb(val, val, val)); + } + } + } + + canvas.drawBitmap(bitmap, null, new RectF(0, 0, canvas_size, canvas_size), null); + + textureView.unlockCanvasAndPost(canvas); + + } + + } + + protected void showResultsInBottomSheet(List results) { + if (results != null && results.size() >= 3) { + Recognition recognition = results.get(0); + if (recognition != null) { + if (recognition.getTitle() != null) recognitionTextView.setText(recognition.getTitle()); + if (recognition.getConfidence() != null) + recognitionValueTextView.setText( + String.format("%.2f", (100 * recognition.getConfidence())) + "%"); + } + + Recognition recognition1 = results.get(1); + if (recognition1 != null) { + if (recognition1.getTitle() != null) recognition1TextView.setText(recognition1.getTitle()); + if (recognition1.getConfidence() != null) + recognition1ValueTextView.setText( + String.format("%.2f", (100 * recognition1.getConfidence())) + "%"); + } + + Recognition recognition2 = results.get(2); + if (recognition2 != null) { + if (recognition2.getTitle() != null) recognition2TextView.setText(recognition2.getTitle()); + if (recognition2.getConfidence() != null) + recognition2ValueTextView.setText( + String.format("%.2f", (100 * recognition2.getConfidence())) + "%"); + } + } + } + + protected void showFrameInfo(String frameInfo) { + frameValueTextView.setText(frameInfo); + } + + protected void showCropInfo(String cropInfo) { + cropValueTextView.setText(cropInfo); + } + + protected void showCameraResolution(String cameraInfo) { + cameraResolutionTextView.setText(cameraInfo); + } + + protected void showRotationInfo(String rotation) { + rotationTextView.setText(rotation); + } + + protected void showInference(String inferenceTime) { + inferenceTimeTextView.setText(inferenceTime); + } + + protected Model getModel() { + return model; + } + + private void setModel(Model model) { + if (this.model != model) { + LOGGER.d("Updating model: " + model); + this.model = model; + onInferenceConfigurationChanged(); + } + } + + protected Device getDevice() { + return device; + } + + private void setDevice(Device device) { + if (this.device != device) { + LOGGER.d("Updating device: " + device); + this.device = device; + final boolean threadsEnabled = device == Device.CPU; + plusImageView.setEnabled(threadsEnabled); + minusImageView.setEnabled(threadsEnabled); + threadsTextView.setText(threadsEnabled ? String.valueOf(numThreads) : "N/A"); + onInferenceConfigurationChanged(); + } + } + + protected int getNumThreads() { + return numThreads; + } + + private void setNumThreads(int numThreads) { + if (this.numThreads != numThreads) { + LOGGER.d("Updating numThreads: " + numThreads); + this.numThreads = numThreads; + onInferenceConfigurationChanged(); + } + } + + protected abstract void processImage(); + + protected abstract void onPreviewSizeChosen(final Size size, final int rotation); + + protected abstract int getLayoutId(); + + protected abstract Size getDesiredPreviewFrameSize(); + + protected abstract void onInferenceConfigurationChanged(); + + @Override + public void onClick(View v) { + if (v.getId() == R.id.plus) { + String threads = threadsTextView.getText().toString().trim(); + int numThreads = Integer.parseInt(threads); + if (numThreads >= 9) return; + setNumThreads(++numThreads); + threadsTextView.setText(String.valueOf(numThreads)); + } else if (v.getId() == R.id.minus) { + String threads = threadsTextView.getText().toString().trim(); + int numThreads = Integer.parseInt(threads); + if (numThreads == 1) { + return; + } + setNumThreads(--numThreads); + threadsTextView.setText(String.valueOf(numThreads)); + } + } + + @Override + public void onItemSelected(AdapterView parent, View view, int pos, long id) { + if (parent == modelSpinner) { + setModel(Model.valueOf(parent.getItemAtPosition(pos).toString().toUpperCase())); + } else if (parent == deviceSpinner) { + setDevice(Device.valueOf(parent.getItemAtPosition(pos).toString())); + } + } + + @Override + public void onNothingSelected(AdapterView parent) { + // Do nothing. + } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraConnectionFragment.java b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraConnectionFragment.java new file mode 100644 index 0000000000000000000000000000000000000000..13e5c0dc341a86b1ddd66c4b562e0bf767641b42 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraConnectionFragment.java @@ -0,0 +1,575 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.lite.examples.classification; + +import android.annotation.SuppressLint; +import android.app.Activity; +import android.app.AlertDialog; +import android.app.Dialog; +import android.app.DialogFragment; +import android.app.Fragment; +import android.content.Context; +import android.content.DialogInterface; +import android.content.res.Configuration; +import android.graphics.ImageFormat; +import android.graphics.Matrix; +import android.graphics.RectF; +import android.graphics.SurfaceTexture; +import android.hardware.camera2.CameraAccessException; +import android.hardware.camera2.CameraCaptureSession; +import android.hardware.camera2.CameraCharacteristics; +import android.hardware.camera2.CameraDevice; +import android.hardware.camera2.CameraManager; +import android.hardware.camera2.CaptureRequest; +import android.hardware.camera2.CaptureResult; +import android.hardware.camera2.TotalCaptureResult; +import android.hardware.camera2.params.StreamConfigurationMap; +import android.media.ImageReader; +import android.media.ImageReader.OnImageAvailableListener; +import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; +import android.text.TextUtils; +import android.util.Size; +import android.util.SparseIntArray; +import android.view.LayoutInflater; +import android.view.Surface; +import android.view.TextureView; +import android.view.View; +import android.view.ViewGroup; +import android.widget.Toast; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import org.tensorflow.lite.examples.classification.customview.AutoFitTextureView; +import org.tensorflow.lite.examples.classification.env.Logger; + +/** + * Camera Connection Fragment that captures images from camera. + * + *

Instantiated by newInstance.

+ */ +@SuppressWarnings("FragmentNotInstantiable") +public class CameraConnectionFragment extends Fragment { + private static final Logger LOGGER = new Logger(); + + /** + * The camera preview size will be chosen to be the smallest frame by pixel size capable of + * containing a DESIRED_SIZE x DESIRED_SIZE square. + */ + private static final int MINIMUM_PREVIEW_SIZE = 320; + + /** Conversion from screen rotation to JPEG orientation. */ + private static final SparseIntArray ORIENTATIONS = new SparseIntArray(); + + private static final String FRAGMENT_DIALOG = "dialog"; + + static { + ORIENTATIONS.append(Surface.ROTATION_0, 90); + ORIENTATIONS.append(Surface.ROTATION_90, 0); + ORIENTATIONS.append(Surface.ROTATION_180, 270); + ORIENTATIONS.append(Surface.ROTATION_270, 180); + } + + /** A {@link Semaphore} to prevent the app from exiting before closing the camera. */ + private final Semaphore cameraOpenCloseLock = new Semaphore(1); + /** A {@link OnImageAvailableListener} to receive frames as they are available. */ + private final OnImageAvailableListener imageListener; + /** The input size in pixels desired by TensorFlow (width and height of a square bitmap). */ + private final Size inputSize; + /** The layout identifier to inflate for this Fragment. */ + private final int layout; + + private final ConnectionCallback cameraConnectionCallback; + private final CameraCaptureSession.CaptureCallback captureCallback = + new CameraCaptureSession.CaptureCallback() { + @Override + public void onCaptureProgressed( + final CameraCaptureSession session, + final CaptureRequest request, + final CaptureResult partialResult) {} + + @Override + public void onCaptureCompleted( + final CameraCaptureSession session, + final CaptureRequest request, + final TotalCaptureResult result) {} + }; + /** ID of the current {@link CameraDevice}. */ + private String cameraId; + /** An {@link AutoFitTextureView} for camera preview. */ + private AutoFitTextureView textureView; + /** A {@link CameraCaptureSession } for camera preview. */ + private CameraCaptureSession captureSession; + /** A reference to the opened {@link CameraDevice}. */ + private CameraDevice cameraDevice; + /** The rotation in degrees of the camera sensor from the display. */ + private Integer sensorOrientation; + /** The {@link Size} of camera preview. */ + private Size previewSize; + /** An additional thread for running tasks that shouldn't block the UI. */ + private HandlerThread backgroundThread; + /** A {@link Handler} for running tasks in the background. */ + private Handler backgroundHandler; + /** + * {@link TextureView.SurfaceTextureListener} handles several lifecycle events on a {@link + * TextureView}. + */ + private final TextureView.SurfaceTextureListener surfaceTextureListener = + new TextureView.SurfaceTextureListener() { + @Override + public void onSurfaceTextureAvailable( + final SurfaceTexture texture, final int width, final int height) { + openCamera(width, height); + } + + @Override + public void onSurfaceTextureSizeChanged( + final SurfaceTexture texture, final int width, final int height) { + configureTransform(width, height); + } + + @Override + public boolean onSurfaceTextureDestroyed(final SurfaceTexture texture) { + return true; + } + + @Override + public void onSurfaceTextureUpdated(final SurfaceTexture texture) {} + }; + /** An {@link ImageReader} that handles preview frame capture. */ + private ImageReader previewReader; + /** {@link CaptureRequest.Builder} for the camera preview */ + private CaptureRequest.Builder previewRequestBuilder; + /** {@link CaptureRequest} generated by {@link #previewRequestBuilder} */ + private CaptureRequest previewRequest; + /** {@link CameraDevice.StateCallback} is called when {@link CameraDevice} changes its state. */ + private final CameraDevice.StateCallback stateCallback = + new CameraDevice.StateCallback() { + @Override + public void onOpened(final CameraDevice cd) { + // This method is called when the camera is opened. We start camera preview here. + cameraOpenCloseLock.release(); + cameraDevice = cd; + createCameraPreviewSession(); + } + + @Override + public void onDisconnected(final CameraDevice cd) { + cameraOpenCloseLock.release(); + cd.close(); + cameraDevice = null; + } + + @Override + public void onError(final CameraDevice cd, final int error) { + cameraOpenCloseLock.release(); + cd.close(); + cameraDevice = null; + final Activity activity = getActivity(); + if (null != activity) { + activity.finish(); + } + } + }; + + @SuppressLint("ValidFragment") + private CameraConnectionFragment( + final ConnectionCallback connectionCallback, + final OnImageAvailableListener imageListener, + final int layout, + final Size inputSize) { + this.cameraConnectionCallback = connectionCallback; + this.imageListener = imageListener; + this.layout = layout; + this.inputSize = inputSize; + } + + /** + * Given {@code choices} of {@code Size}s supported by a camera, chooses the smallest one whose + * width and height are at least as large as the minimum of both, or an exact match if possible. + * + * @param choices The list of sizes that the camera supports for the intended output class + * @param width The minimum desired width + * @param height The minimum desired height + * @return The optimal {@code Size}, or an arbitrary one if none were big enough + */ + protected static Size chooseOptimalSize(final Size[] choices, final int width, final int height) { + final int minSize = Math.max(Math.min(width, height), MINIMUM_PREVIEW_SIZE); + final Size desiredSize = new Size(width, height); + + // Collect the supported resolutions that are at least as big as the preview Surface + boolean exactSizeFound = false; + final List bigEnough = new ArrayList(); + final List tooSmall = new ArrayList(); + for (final Size option : choices) { + if (option.equals(desiredSize)) { + // Set the size but don't return yet so that remaining sizes will still be logged. + exactSizeFound = true; + } + + if (option.getHeight() >= minSize && option.getWidth() >= minSize) { + bigEnough.add(option); + } else { + tooSmall.add(option); + } + } + + LOGGER.i("Desired size: " + desiredSize + ", min size: " + minSize + "x" + minSize); + LOGGER.i("Valid preview sizes: [" + TextUtils.join(", ", bigEnough) + "]"); + LOGGER.i("Rejected preview sizes: [" + TextUtils.join(", ", tooSmall) + "]"); + + if (exactSizeFound) { + LOGGER.i("Exact size match found."); + return desiredSize; + } + + // Pick the smallest of those, assuming we found any + if (bigEnough.size() > 0) { + final Size chosenSize = Collections.min(bigEnough, new CompareSizesByArea()); + LOGGER.i("Chosen size: " + chosenSize.getWidth() + "x" + chosenSize.getHeight()); + return chosenSize; + } else { + LOGGER.e("Couldn't find any suitable preview size"); + return choices[0]; + } + } + + public static CameraConnectionFragment newInstance( + final ConnectionCallback callback, + final OnImageAvailableListener imageListener, + final int layout, + final Size inputSize) { + return new CameraConnectionFragment(callback, imageListener, layout, inputSize); + } + + /** + * Shows a {@link Toast} on the UI thread. + * + * @param text The message to show + */ + private void showToast(final String text) { + final Activity activity = getActivity(); + if (activity != null) { + activity.runOnUiThread( + new Runnable() { + @Override + public void run() { + Toast.makeText(activity, text, Toast.LENGTH_SHORT).show(); + } + }); + } + } + + @Override + public View onCreateView( + final LayoutInflater inflater, final ViewGroup container, final Bundle savedInstanceState) { + return inflater.inflate(layout, container, false); + } + + @Override + public void onViewCreated(final View view, final Bundle savedInstanceState) { + textureView = (AutoFitTextureView) view.findViewById(R.id.texture); + } + + @Override + public void onActivityCreated(final Bundle savedInstanceState) { + super.onActivityCreated(savedInstanceState); + } + + @Override + public void onResume() { + super.onResume(); + startBackgroundThread(); + + // When the screen is turned off and turned back on, the SurfaceTexture is already + // available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open + // a camera and start preview from here (otherwise, we wait until the surface is ready in + // the SurfaceTextureListener). + if (textureView.isAvailable()) { + openCamera(textureView.getWidth(), textureView.getHeight()); + } else { + textureView.setSurfaceTextureListener(surfaceTextureListener); + } + } + + @Override + public void onPause() { + closeCamera(); + stopBackgroundThread(); + super.onPause(); + } + + public void setCamera(String cameraId) { + this.cameraId = cameraId; + } + + /** Sets up member variables related to camera. */ + private void setUpCameraOutputs() { + final Activity activity = getActivity(); + final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE); + try { + final CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId); + + final StreamConfigurationMap map = + characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP); + + sensorOrientation = characteristics.get(CameraCharacteristics.SENSOR_ORIENTATION); + + // Danger, W.R.! Attempting to use too large a preview size could exceed the camera + // bus' bandwidth limitation, resulting in gorgeous previews but the storage of + // garbage capture data. + previewSize = + chooseOptimalSize( + map.getOutputSizes(SurfaceTexture.class), + inputSize.getWidth(), + inputSize.getHeight()); + + // We fit the aspect ratio of TextureView to the size of preview we picked. + final int orientation = getResources().getConfiguration().orientation; + if (orientation == Configuration.ORIENTATION_LANDSCAPE) { + textureView.setAspectRatio(previewSize.getWidth(), previewSize.getHeight()); + } else { + textureView.setAspectRatio(previewSize.getHeight(), previewSize.getWidth()); + } + } catch (final CameraAccessException e) { + LOGGER.e(e, "Exception!"); + } catch (final NullPointerException e) { + // Currently an NPE is thrown when the Camera2API is used but not supported on the + // device this code runs. + ErrorDialog.newInstance(getString(R.string.tfe_ic_camera_error)) + .show(getChildFragmentManager(), FRAGMENT_DIALOG); + throw new IllegalStateException(getString(R.string.tfe_ic_camera_error)); + } + + cameraConnectionCallback.onPreviewSizeChosen(previewSize, sensorOrientation); + } + + /** Opens the camera specified by {@link CameraConnectionFragment#cameraId}. */ + private void openCamera(final int width, final int height) { + setUpCameraOutputs(); + configureTransform(width, height); + final Activity activity = getActivity(); + final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE); + try { + if (!cameraOpenCloseLock.tryAcquire(2500, TimeUnit.MILLISECONDS)) { + throw new RuntimeException("Time out waiting to lock camera opening."); + } + manager.openCamera(cameraId, stateCallback, backgroundHandler); + } catch (final CameraAccessException e) { + LOGGER.e(e, "Exception!"); + } catch (final InterruptedException e) { + throw new RuntimeException("Interrupted while trying to lock camera opening.", e); + } + } + + /** Closes the current {@link CameraDevice}. */ + private void closeCamera() { + try { + cameraOpenCloseLock.acquire(); + if (null != captureSession) { + captureSession.close(); + captureSession = null; + } + if (null != cameraDevice) { + cameraDevice.close(); + cameraDevice = null; + } + if (null != previewReader) { + previewReader.close(); + previewReader = null; + } + } catch (final InterruptedException e) { + throw new RuntimeException("Interrupted while trying to lock camera closing.", e); + } finally { + cameraOpenCloseLock.release(); + } + } + + /** Starts a background thread and its {@link Handler}. */ + private void startBackgroundThread() { + backgroundThread = new HandlerThread("ImageListener"); + backgroundThread.start(); + backgroundHandler = new Handler(backgroundThread.getLooper()); + } + + /** Stops the background thread and its {@link Handler}. */ + private void stopBackgroundThread() { + backgroundThread.quitSafely(); + try { + backgroundThread.join(); + backgroundThread = null; + backgroundHandler = null; + } catch (final InterruptedException e) { + LOGGER.e(e, "Exception!"); + } + } + + /** Creates a new {@link CameraCaptureSession} for camera preview. */ + private void createCameraPreviewSession() { + try { + final SurfaceTexture texture = textureView.getSurfaceTexture(); + assert texture != null; + + // We configure the size of default buffer to be the size of camera preview we want. + texture.setDefaultBufferSize(previewSize.getWidth(), previewSize.getHeight()); + + // This is the output Surface we need to start preview. + final Surface surface = new Surface(texture); + + // We set up a CaptureRequest.Builder with the output Surface. + previewRequestBuilder = cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW); + previewRequestBuilder.addTarget(surface); + + LOGGER.i("Opening camera preview: " + previewSize.getWidth() + "x" + previewSize.getHeight()); + + // Create the reader for the preview frames. + previewReader = + ImageReader.newInstance( + previewSize.getWidth(), previewSize.getHeight(), ImageFormat.YUV_420_888, 2); + + previewReader.setOnImageAvailableListener(imageListener, backgroundHandler); + previewRequestBuilder.addTarget(previewReader.getSurface()); + + // Here, we create a CameraCaptureSession for camera preview. + cameraDevice.createCaptureSession( + Arrays.asList(surface, previewReader.getSurface()), + new CameraCaptureSession.StateCallback() { + + @Override + public void onConfigured(final CameraCaptureSession cameraCaptureSession) { + // The camera is already closed + if (null == cameraDevice) { + return; + } + + // When the session is ready, we start displaying the preview. + captureSession = cameraCaptureSession; + try { + // Auto focus should be continuous for camera preview. + previewRequestBuilder.set( + CaptureRequest.CONTROL_AF_MODE, + CaptureRequest.CONTROL_AF_MODE_CONTINUOUS_PICTURE); + // Flash is automatically enabled when necessary. + previewRequestBuilder.set( + CaptureRequest.CONTROL_AE_MODE, CaptureRequest.CONTROL_AE_MODE_ON_AUTO_FLASH); + + // Finally, we start displaying the camera preview. + previewRequest = previewRequestBuilder.build(); + captureSession.setRepeatingRequest( + previewRequest, captureCallback, backgroundHandler); + } catch (final CameraAccessException e) { + LOGGER.e(e, "Exception!"); + } + } + + @Override + public void onConfigureFailed(final CameraCaptureSession cameraCaptureSession) { + showToast("Failed"); + } + }, + null); + } catch (final CameraAccessException e) { + LOGGER.e(e, "Exception!"); + } + } + + /** + * Configures the necessary {@link Matrix} transformation to `mTextureView`. This method should be + * called after the camera preview size is determined in setUpCameraOutputs and also the size of + * `mTextureView` is fixed. + * + * @param viewWidth The width of `mTextureView` + * @param viewHeight The height of `mTextureView` + */ + private void configureTransform(final int viewWidth, final int viewHeight) { + final Activity activity = getActivity(); + if (null == textureView || null == previewSize || null == activity) { + return; + } + final int rotation = activity.getWindowManager().getDefaultDisplay().getRotation(); + final Matrix matrix = new Matrix(); + final RectF viewRect = new RectF(0, 0, viewWidth, viewHeight); + final RectF bufferRect = new RectF(0, 0, previewSize.getHeight(), previewSize.getWidth()); + final float centerX = viewRect.centerX(); + final float centerY = viewRect.centerY(); + if (Surface.ROTATION_90 == rotation || Surface.ROTATION_270 == rotation) { + bufferRect.offset(centerX - bufferRect.centerX(), centerY - bufferRect.centerY()); + matrix.setRectToRect(viewRect, bufferRect, Matrix.ScaleToFit.FILL); + final float scale = + Math.max( + (float) viewHeight / previewSize.getHeight(), + (float) viewWidth / previewSize.getWidth()); + matrix.postScale(scale, scale, centerX, centerY); + matrix.postRotate(90 * (rotation - 2), centerX, centerY); + } else if (Surface.ROTATION_180 == rotation) { + matrix.postRotate(180, centerX, centerY); + } + textureView.setTransform(matrix); + } + + /** + * Callback for Activities to use to initialize their data once the selected preview size is + * known. + */ + public interface ConnectionCallback { + void onPreviewSizeChosen(Size size, int cameraRotation); + } + + /** Compares two {@code Size}s based on their areas. */ + static class CompareSizesByArea implements Comparator { + @Override + public int compare(final Size lhs, final Size rhs) { + // We cast here to ensure the multiplications won't overflow + return Long.signum( + (long) lhs.getWidth() * lhs.getHeight() - (long) rhs.getWidth() * rhs.getHeight()); + } + } + + /** Shows an error message dialog. */ + public static class ErrorDialog extends DialogFragment { + private static final String ARG_MESSAGE = "message"; + + public static ErrorDialog newInstance(final String message) { + final ErrorDialog dialog = new ErrorDialog(); + final Bundle args = new Bundle(); + args.putString(ARG_MESSAGE, message); + dialog.setArguments(args); + return dialog; + } + + @Override + public Dialog onCreateDialog(final Bundle savedInstanceState) { + final Activity activity = getActivity(); + return new AlertDialog.Builder(activity) + .setMessage(getArguments().getString(ARG_MESSAGE)) + .setPositiveButton( + android.R.string.ok, + new DialogInterface.OnClickListener() { + @Override + public void onClick(final DialogInterface dialogInterface, final int i) { + activity.finish(); + } + }) + .create(); + } + } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java new file mode 100644 index 0000000000000000000000000000000000000000..24b5d72fdb42d47e5d2c87e3f944b71105748c1b --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java @@ -0,0 +1,238 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.lite.examples.classification; + +import android.graphics.Bitmap; +import android.graphics.Bitmap.Config; +import android.graphics.Typeface; +import android.media.ImageReader.OnImageAvailableListener; +import android.os.SystemClock; +import android.util.Size; +import android.util.TypedValue; +import android.view.TextureView; +import android.view.ViewStub; +import android.widget.TextView; +import android.widget.Toast; +import java.io.IOException; +import java.util.List; +import java.util.ArrayList; + +import org.tensorflow.lite.examples.classification.customview.AutoFitTextureView; +import org.tensorflow.lite.examples.classification.env.BorderedText; +import org.tensorflow.lite.examples.classification.env.Logger; +import org.tensorflow.lite.examples.classification.tflite.Classifier; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Model; + +import android.widget.ImageView; +import android.graphics.Bitmap; +import android.graphics.BitmapFactory; +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Paint; +import android.graphics.Rect; +import android.graphics.RectF; +import android.graphics.PixelFormat; +import java.nio.ByteBuffer; + +public class ClassifierActivity extends CameraActivity implements OnImageAvailableListener { + private static final Logger LOGGER = new Logger(); + private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480); + private static final float TEXT_SIZE_DIP = 10; + private Bitmap rgbFrameBitmap = null; + private long lastProcessingTimeMs; + private Integer sensorOrientation; + private Classifier classifier; + private BorderedText borderedText; + /** Input image size of the model along x axis. */ + private int imageSizeX; + /** Input image size of the model along y axis. */ + private int imageSizeY; + + @Override + protected int getLayoutId() { + return R.layout.tfe_ic_camera_connection_fragment; + } + + @Override + protected Size getDesiredPreviewFrameSize() { + return DESIRED_PREVIEW_SIZE; + } + + @Override + public void onPreviewSizeChosen(final Size size, final int rotation) { + final float textSizePx = + TypedValue.applyDimension( + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); + borderedText = new BorderedText(textSizePx); + borderedText.setTypeface(Typeface.MONOSPACE); + + recreateClassifier(getModel(), getDevice(), getNumThreads()); + if (classifier == null) { + LOGGER.e("No classifier on preview!"); + return; + } + + previewWidth = size.getWidth(); + previewHeight = size.getHeight(); + + sensorOrientation = rotation - getScreenOrientation(); + LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation); + + LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight); + rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888); + } + + @Override + protected void processImage() { + rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight); + final int cropSize = Math.min(previewWidth, previewHeight); + + runInBackground( + new Runnable() { + @Override + public void run() { + if (classifier != null) { + final long startTime = SystemClock.uptimeMillis(); + //final List results = + // classifier.recognizeImage(rgbFrameBitmap, sensorOrientation); + final List results = new ArrayList<>(); + + float[] img_array = classifier.recognizeImage(rgbFrameBitmap, sensorOrientation); + + + /* + float maxval = Float.NEGATIVE_INFINITY; + float minval = Float.POSITIVE_INFINITY; + for (float cur : img_array) { + maxval = Math.max(maxval, cur); + minval = Math.min(minval, cur); + } + float multiplier = 0; + if ((maxval - minval) > 0) multiplier = 255 / (maxval - minval); + + int[] img_normalized = new int[img_array.length]; + for (int i = 0; i < img_array.length; ++i) { + float val = (float) (multiplier * (img_array[i] - minval)); + img_normalized[i] = (int) val; + } + + + + TextureView textureView = findViewById(R.id.textureView3); + //AutoFitTextureView textureView = (AutoFitTextureView) findViewById(R.id.texture); + + if(textureView.isAvailable()) { + int width = imageSizeX; + int height = imageSizeY; + + Canvas canvas = textureView.lockCanvas(); + canvas.drawColor(Color.BLUE); + Paint paint = new Paint(); + paint.setStyle(Paint.Style.FILL); + paint.setARGB(255, 150, 150, 150); + + int canvas_size = Math.min(canvas.getWidth(), canvas.getHeight()); + + Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.RGB_565); + + for (int ii = 0; ii < width; ii++) //pass the screen pixels in 2 directions + { + for (int jj = 0; jj < height; jj++) { + //int val = img_normalized[ii + jj * width]; + int index = (width - ii - 1) + (height - jj - 1) * width; + if(index < img_array.length) { + int val = img_normalized[index]; + bitmap.setPixel(ii, jj, Color.rgb(val, val, val)); + } + } + } + + canvas.drawBitmap(bitmap, null, new RectF(0, 0, canvas_size, canvas_size), null); + + textureView.unlockCanvasAndPost(canvas); + + } + */ + + lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime; + LOGGER.v("Detect: %s", results); + + runOnUiThread( + new Runnable() { + @Override + public void run() { + //showResultsInBottomSheet(results); + showResultsInTexture(img_array, imageSizeX, imageSizeY); + showFrameInfo(previewWidth + "x" + previewHeight); + showCropInfo(imageSizeX + "x" + imageSizeY); + showCameraResolution(cropSize + "x" + cropSize); + showRotationInfo(String.valueOf(sensorOrientation)); + showInference(lastProcessingTimeMs + "ms"); + } + }); + } + readyForNextImage(); + } + }); + } + + @Override + protected void onInferenceConfigurationChanged() { + if (rgbFrameBitmap == null) { + // Defer creation until we're getting camera frames. + return; + } + final Device device = getDevice(); + final Model model = getModel(); + final int numThreads = getNumThreads(); + runInBackground(() -> recreateClassifier(model, device, numThreads)); + } + + private void recreateClassifier(Model model, Device device, int numThreads) { + if (classifier != null) { + LOGGER.d("Closing classifier."); + classifier.close(); + classifier = null; + } + if (device == Device.GPU + && (model == Model.QUANTIZED_MOBILENET || model == Model.QUANTIZED_EFFICIENTNET)) { + LOGGER.d("Not creating classifier: GPU doesn't support quantized models."); + runOnUiThread( + () -> { + Toast.makeText(this, R.string.tfe_ic_gpu_quant_error, Toast.LENGTH_LONG).show(); + }); + return; + } + try { + LOGGER.d( + "Creating classifier (model=%s, device=%s, numThreads=%d)", model, device, numThreads); + classifier = Classifier.create(this, model, device, numThreads); + } catch (IOException | IllegalArgumentException e) { + LOGGER.e(e, "Failed to create classifier."); + runOnUiThread( + () -> { + Toast.makeText(this, e.getMessage(), Toast.LENGTH_LONG).show(); + }); + return; + } + + // Updates the input image size. + imageSizeX = classifier.getImageSizeX(); + imageSizeY = classifier.getImageSizeY(); + } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/LegacyCameraConnectionFragment.java b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/LegacyCameraConnectionFragment.java new file mode 100644 index 0000000000000000000000000000000000000000..760fe90375450c7b1356603c83fb37a68548ca13 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/LegacyCameraConnectionFragment.java @@ -0,0 +1,203 @@ +package org.tensorflow.lite.examples.classification; + +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import android.annotation.SuppressLint; +import android.app.Fragment; +import android.graphics.SurfaceTexture; +import android.hardware.Camera; +import android.hardware.Camera.CameraInfo; +import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; +import android.util.Size; +import android.util.SparseIntArray; +import android.view.LayoutInflater; +import android.view.Surface; +import android.view.TextureView; +import android.view.View; +import android.view.ViewGroup; +import java.io.IOException; +import java.util.List; +import org.tensorflow.lite.examples.classification.customview.AutoFitTextureView; +import org.tensorflow.lite.examples.classification.env.ImageUtils; +import org.tensorflow.lite.examples.classification.env.Logger; + +public class LegacyCameraConnectionFragment extends Fragment { + private static final Logger LOGGER = new Logger(); + /** Conversion from screen rotation to JPEG orientation. */ + private static final SparseIntArray ORIENTATIONS = new SparseIntArray(); + + static { + ORIENTATIONS.append(Surface.ROTATION_0, 90); + ORIENTATIONS.append(Surface.ROTATION_90, 0); + ORIENTATIONS.append(Surface.ROTATION_180, 270); + ORIENTATIONS.append(Surface.ROTATION_270, 180); + } + + private Camera camera; + private Camera.PreviewCallback imageListener; + private Size desiredSize; + /** The layout identifier to inflate for this Fragment. */ + private int layout; + /** An {@link AutoFitTextureView} for camera preview. */ + private AutoFitTextureView textureView; + /** + * {@link TextureView.SurfaceTextureListener} handles several lifecycle events on a {@link + * TextureView}. + */ + private final TextureView.SurfaceTextureListener surfaceTextureListener = + new TextureView.SurfaceTextureListener() { + @Override + public void onSurfaceTextureAvailable( + final SurfaceTexture texture, final int width, final int height) { + + int index = getCameraId(); + camera = Camera.open(index); + + try { + Camera.Parameters parameters = camera.getParameters(); + List focusModes = parameters.getSupportedFocusModes(); + if (focusModes != null + && focusModes.contains(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE)) { + parameters.setFocusMode(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE); + } + List cameraSizes = parameters.getSupportedPreviewSizes(); + Size[] sizes = new Size[cameraSizes.size()]; + int i = 0; + for (Camera.Size size : cameraSizes) { + sizes[i++] = new Size(size.width, size.height); + } + Size previewSize = + CameraConnectionFragment.chooseOptimalSize( + sizes, desiredSize.getWidth(), desiredSize.getHeight()); + parameters.setPreviewSize(previewSize.getWidth(), previewSize.getHeight()); + camera.setDisplayOrientation(90); + camera.setParameters(parameters); + camera.setPreviewTexture(texture); + } catch (IOException exception) { + camera.release(); + } + + camera.setPreviewCallbackWithBuffer(imageListener); + Camera.Size s = camera.getParameters().getPreviewSize(); + camera.addCallbackBuffer(new byte[ImageUtils.getYUVByteSize(s.height, s.width)]); + + textureView.setAspectRatio(s.height, s.width); + + camera.startPreview(); + } + + @Override + public void onSurfaceTextureSizeChanged( + final SurfaceTexture texture, final int width, final int height) {} + + @Override + public boolean onSurfaceTextureDestroyed(final SurfaceTexture texture) { + return true; + } + + @Override + public void onSurfaceTextureUpdated(final SurfaceTexture texture) {} + }; + /** An additional thread for running tasks that shouldn't block the UI. */ + private HandlerThread backgroundThread; + + @SuppressLint("ValidFragment") + public LegacyCameraConnectionFragment( + final Camera.PreviewCallback imageListener, final int layout, final Size desiredSize) { + this.imageListener = imageListener; + this.layout = layout; + this.desiredSize = desiredSize; + } + + @Override + public View onCreateView( + final LayoutInflater inflater, final ViewGroup container, final Bundle savedInstanceState) { + return inflater.inflate(layout, container, false); + } + + @Override + public void onViewCreated(final View view, final Bundle savedInstanceState) { + textureView = (AutoFitTextureView) view.findViewById(R.id.texture); + } + + @Override + public void onActivityCreated(final Bundle savedInstanceState) { + super.onActivityCreated(savedInstanceState); + } + + @Override + public void onResume() { + super.onResume(); + startBackgroundThread(); + // When the screen is turned off and turned back on, the SurfaceTexture is already + // available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open + // a camera and start preview from here (otherwise, we wait until the surface is ready in + // the SurfaceTextureListener). + + if (textureView.isAvailable()) { + if (camera != null) { + camera.startPreview(); + } + } else { + textureView.setSurfaceTextureListener(surfaceTextureListener); + } + } + + @Override + public void onPause() { + stopCamera(); + stopBackgroundThread(); + super.onPause(); + } + + /** Starts a background thread and its {@link Handler}. */ + private void startBackgroundThread() { + backgroundThread = new HandlerThread("CameraBackground"); + backgroundThread.start(); + } + + /** Stops the background thread and its {@link Handler}. */ + private void stopBackgroundThread() { + backgroundThread.quitSafely(); + try { + backgroundThread.join(); + backgroundThread = null; + } catch (final InterruptedException e) { + LOGGER.e(e, "Exception!"); + } + } + + protected void stopCamera() { + if (camera != null) { + camera.stopPreview(); + camera.setPreviewCallback(null); + camera.release(); + camera = null; + } + } + + private int getCameraId() { + CameraInfo ci = new CameraInfo(); + for (int i = 0; i < Camera.getNumberOfCameras(); i++) { + Camera.getCameraInfo(i, ci); + if (ci.facing == CameraInfo.CAMERA_FACING_BACK) return i; + } + return -1; // No camera found + } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/AutoFitTextureView.java b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/AutoFitTextureView.java new file mode 100644 index 0000000000000000000000000000000000000000..62e99ae70c2a7c4c60a776e7490742c5339e85f3 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/AutoFitTextureView.java @@ -0,0 +1,72 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.lite.examples.classification.customview; + +import android.content.Context; +import android.util.AttributeSet; +import android.view.TextureView; + +/** A {@link TextureView} that can be adjusted to a specified aspect ratio. */ +public class AutoFitTextureView extends TextureView { + private int ratioWidth = 0; + private int ratioHeight = 0; + + public AutoFitTextureView(final Context context) { + this(context, null); + } + + public AutoFitTextureView(final Context context, final AttributeSet attrs) { + this(context, attrs, 0); + } + + public AutoFitTextureView(final Context context, final AttributeSet attrs, final int defStyle) { + super(context, attrs, defStyle); + } + + /** + * Sets the aspect ratio for this view. The size of the view will be measured based on the ratio + * calculated from the parameters. Note that the actual sizes of parameters don't matter, that is, + * calling setAspectRatio(2, 3) and setAspectRatio(4, 6) make the same result. + * + * @param width Relative horizontal size + * @param height Relative vertical size + */ + public void setAspectRatio(final int width, final int height) { + if (width < 0 || height < 0) { + throw new IllegalArgumentException("Size cannot be negative."); + } + ratioWidth = width; + ratioHeight = height; + requestLayout(); + } + + @Override + protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) { + super.onMeasure(widthMeasureSpec, heightMeasureSpec); + final int width = MeasureSpec.getSize(widthMeasureSpec); + final int height = MeasureSpec.getSize(heightMeasureSpec); + if (0 == ratioWidth || 0 == ratioHeight) { + setMeasuredDimension(width, height); + } else { + if (width < height * ratioWidth / ratioHeight) { + setMeasuredDimension(width, width * ratioHeight / ratioWidth); + } else { + setMeasuredDimension(height * ratioWidth / ratioHeight, height); + } + } + } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/OverlayView.java b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/OverlayView.java new file mode 100644 index 0000000000000000000000000000000000000000..dc302ac04f145c9a1673a2d7e630a94a05ab1b1a --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/OverlayView.java @@ -0,0 +1,48 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.customview; + +import android.content.Context; +import android.graphics.Canvas; +import android.util.AttributeSet; +import android.view.View; +import java.util.LinkedList; +import java.util.List; + +/** A simple View providing a render callback to other classes. */ +public class OverlayView extends View { + private final List callbacks = new LinkedList(); + + public OverlayView(final Context context, final AttributeSet attrs) { + super(context, attrs); + } + + public void addCallback(final DrawCallback callback) { + callbacks.add(callback); + } + + @Override + public synchronized void draw(final Canvas canvas) { + for (final DrawCallback callback : callbacks) { + callback.drawCallback(canvas); + } + } + + /** Interface defining the callback for client classes. */ + public interface DrawCallback { + public void drawCallback(final Canvas canvas); + } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/RecognitionScoreView.java b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/RecognitionScoreView.java new file mode 100644 index 0000000000000000000000000000000000000000..2c57f603f12200079c888793cfa40d9b10dabde3 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/RecognitionScoreView.java @@ -0,0 +1,67 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.customview; + +import android.content.Context; +import android.graphics.Canvas; +import android.graphics.Paint; +import android.util.AttributeSet; +import android.util.TypedValue; +import android.view.View; +import java.util.List; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Recognition; + +public class RecognitionScoreView extends View implements ResultsView { + private static final float TEXT_SIZE_DIP = 16; + private final float textSizePx; + private final Paint fgPaint; + private final Paint bgPaint; + private List results; + + public RecognitionScoreView(final Context context, final AttributeSet set) { + super(context, set); + + textSizePx = + TypedValue.applyDimension( + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); + fgPaint = new Paint(); + fgPaint.setTextSize(textSizePx); + + bgPaint = new Paint(); + bgPaint.setColor(0xcc4285f4); + } + + @Override + public void setResults(final List results) { + this.results = results; + postInvalidate(); + } + + @Override + public void onDraw(final Canvas canvas) { + final int x = 10; + int y = (int) (fgPaint.getTextSize() * 1.5f); + + canvas.drawPaint(bgPaint); + + if (results != null) { + for (final Recognition recog : results) { + canvas.drawText(recog.getTitle() + ": " + recog.getConfidence(), x, y, fgPaint); + y += (int) (fgPaint.getTextSize() * 1.5f); + } + } + } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/ResultsView.java b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/ResultsView.java new file mode 100644 index 0000000000000000000000000000000000000000..d055eb5f161a57fc439716efe6d49b7e45ef3fc7 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/ResultsView.java @@ -0,0 +1,23 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.customview; + +import java.util.List; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Recognition; + +public interface ResultsView { + public void setResults(final List results); +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable-v24/ic_launcher_foreground.xml b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable-v24/ic_launcher_foreground.xml new file mode 100644 index 0000000000000000000000000000000000000000..b1517edf496ef5800b97d046b92012a9f94a34d0 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable-v24/ic_launcher_foreground.xml @@ -0,0 +1,34 @@ + + + + + + + + + + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/bottom_sheet_bg.xml b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/bottom_sheet_bg.xml new file mode 100644 index 0000000000000000000000000000000000000000..70f4b24e35039e6bfc35989bcbe570a4bdc2ae07 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/bottom_sheet_bg.xml @@ -0,0 +1,9 @@ + + + + + + \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_add.xml b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_add.xml new file mode 100644 index 0000000000000000000000000000000000000000..757f4503314fb9e5837f68ac515f4487d9b5fc2c --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_add.xml @@ -0,0 +1,9 @@ + + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_remove.xml b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_remove.xml new file mode 100644 index 0000000000000000000000000000000000000000..a64b853e79137f0fd95f9d5fa6e0552cc255c7ae --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_remove.xml @@ -0,0 +1,9 @@ + + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_launcher_background.xml b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_launcher_background.xml new file mode 100644 index 0000000000000000000000000000000000000000..d5fccc538c179838bfdce779c26eebb4fa0b5ce9 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_launcher_background.xml @@ -0,0 +1,170 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/rectangle.xml b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/rectangle.xml new file mode 100644 index 0000000000000000000000000000000000000000..b8f5d3559c4e83072d5d73a3241d240aa68daccf --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/rectangle.xml @@ -0,0 +1,13 @@ + + + + + + \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_activity_camera.xml b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_activity_camera.xml new file mode 100644 index 0000000000000000000000000000000000000000..f0e1dae7afa15cf4a832de708f345482a6dfeff6 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_activity_camera.xml @@ -0,0 +1,56 @@ + + + + + + + + + + + + + + + + + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_camera_connection_fragment.xml b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_camera_connection_fragment.xml new file mode 100644 index 0000000000000000000000000000000000000000..97e5e7c6df25da48977f9064a888fd3735e4986f --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_camera_connection_fragment.xml @@ -0,0 +1,32 @@ + + + + + + + + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_layout_bottom_sheet.xml b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_layout_bottom_sheet.xml new file mode 100644 index 0000000000000000000000000000000000000000..77a348af90e2ed995ff106cd209cbf304c6b9153 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_layout_bottom_sheet.xml @@ -0,0 +1,321 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml new file mode 100644 index 0000000000000000000000000000000000000000..0c2a915e91af65a077d2e01db4ca21acd42906f3 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml new file mode 100644 index 0000000000000000000000000000000000000000..0c2a915e91af65a077d2e01db4ca21acd42906f3 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/colors.xml b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/colors.xml new file mode 100644 index 0000000000000000000000000000000000000000..ed82bafb536474c6a88c996b439a2781f31f3d3e --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/colors.xml @@ -0,0 +1,8 @@ + + + #ffa800 + #ff6f00 + #425066 + + #66000000 + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/dimens.xml b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/dimens.xml new file mode 100644 index 0000000000000000000000000000000000000000..5d3609029ca66b612c88b4f395e4e2e3cfc1f0e6 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/dimens.xml @@ -0,0 +1,5 @@ + + + 15dp + 8dp + \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/strings.xml b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/strings.xml new file mode 100644 index 0000000000000000000000000000000000000000..7d763d85efc49879c8d3c0641484f5f472bfaca0 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/strings.xml @@ -0,0 +1,21 @@ + + Midas + This device doesn\'t support Camera2 API. + GPU does not yet supported quantized models. + Model: + + Float_EfficientNet + + + + Device: + + GPU + CPU + NNAPI + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/styles.xml b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/styles.xml new file mode 100644 index 0000000000000000000000000000000000000000..ad09a13ec6b2de8920a7441c9992f3cc0eedcfda --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/styles.xml @@ -0,0 +1,11 @@ + + + + + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/build.gradle b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..14492756847191ca3beff4c2e012d378c4e44be6 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/build.gradle @@ -0,0 +1,27 @@ +// Top-level build file where you can add configuration options common to all sub-projects/modules. + +buildscript { + + repositories { + google() + jcenter() + } + dependencies { + classpath 'com.android.tools.build:gradle:4.0.0' + classpath 'de.undercouch:gradle-download-task:4.0.2' + // NOTE: Do not place your application dependencies here; they belong + // in the individual module build.gradle files + } +} + +allprojects { + repositories { + google() + jcenter() + } +} + +task clean(type: Delete) { + delete rootProject.buildDir +} + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle.properties b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle.properties new file mode 100644 index 0000000000000000000000000000000000000000..9592636c07d9d5e6f61c0cfce1311d3e1ffcf34d --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle.properties @@ -0,0 +1,15 @@ +# Project-wide Gradle settings. +# IDE (e.g. Android Studio) users: +# Gradle settings configured through the IDE *will override* +# any settings specified in this file. +# For more details on how to configure your build environment visit +# http://www.gradle.org/docs/current/userguide/build_environment.html +# Specifies the JVM arguments used for the daemon process. +# The setting is particularly useful for tweaking memory settings. +org.gradle.jvmargs=-Xmx1536m +# When configured, Gradle will run in incubating parallel mode. +# This option should only be used with decoupled projects. More details, visit +# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects +# org.gradle.parallel=true +android.useAndroidX=true +android.enableJetifier=true diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.jar b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 0000000000000000000000000000000000000000..f3d88b1c2faf2fc91d853cd5d4242b5547257070 Binary files /dev/null and b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.jar differ diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.properties b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000000000000000000000000000000000000..1b16c34a71cf212ed0cfb883d14d1b8511903eb2 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,5 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-6.1.1-bin.zip +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew new file mode 100644 index 0000000000000000000000000000000000000000..2fe81a7d95e4f9ad2c9b2a046707d36ceb3980b3 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew @@ -0,0 +1,183 @@ +#!/usr/bin/env sh + +# +# Copyright 2015 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD="maximum" + +warn () { + echo "$*" +} + +die () { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; + NONSTOP* ) + nonstop=true + ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? -eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? -ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin or MSYS, switch paths to Windows format before running java +if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + JAVACMD=`cygpath --unix "$JAVACMD"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=`expr $i + 1` + done + case $i in + 0) set -- ;; + 1) set -- "$args0" ;; + 2) set -- "$args0" "$args1" ;; + 3) set -- "$args0" "$args1" "$args2" ;; + 4) set -- "$args0" "$args1" "$args2" "$args3" ;; + 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + esac +fi + +# Escape application args +save () { + for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done + echo " " +} +APP_ARGS=`save "$@"` + +# Collect all arguments for the java command, following the shell quoting and substitution rules +eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" + +exec "$JAVACMD" "$@" diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew.bat b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew.bat new file mode 100644 index 0000000000000000000000000000000000000000..9618d8d9607cd91a0efb866bcac4810064ba6fac --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew.bat @@ -0,0 +1,100 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto init + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto init + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:init +@rem Get command-line arguments, handling Windows variants + +if not "%OS%" == "Windows_NT" goto win9xME_args + +:win9xME_args +@rem Slurp the command line arguments. +set CMD_LINE_ARGS= +set _SKIP=2 + +:win9xME_args_slurp +if "x%~1" == "x" goto execute + +set CMD_LINE_ARGS=%* + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/build.gradle b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..5d463975293264765a941795601cddb6cfc84f00 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/build.gradle @@ -0,0 +1,47 @@ +apply plugin: 'com.android.library' + +android { + compileSdkVersion 28 + buildToolsVersion "28.0.0" + + defaultConfig { + minSdkVersion 21 + targetSdkVersion 28 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + + aaptOptions { + noCompress "tflite" + } + + lintOptions { + checkReleaseBuilds false + // Or, if you prefer, you can continue to check for errors in release builds, + // but continue the build even when errors are found: + abortOnError false + } +} + +dependencies { + implementation fileTree(dir: 'libs', include: ['*.jar']) + implementation project(":models") + implementation 'androidx.appcompat:appcompat:1.1.0' + + // Build off of nightly TensorFlow Lite + implementation('org.tensorflow:tensorflow-lite:0.0.0-nightly') { changing = true } + implementation('org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly') { changing = true } + implementation('org.tensorflow:tensorflow-lite-support:0.0.0-nightly') { changing = true } + // Use local TensorFlow library + // implementation 'org.tensorflow:tensorflow-lite-local:0.0.0' +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/proguard-rules.pro b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/proguard-rules.pro new file mode 100644 index 0000000000000000000000000000000000000000..f1b424510da51fd82143bc74a0a801ae5a1e2fcd --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/AndroidManifest.xml b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..ebe3c56c60a9b67eec218d969aecfdf5311d7b49 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/AndroidManifest.xml @@ -0,0 +1,3 @@ + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java new file mode 100644 index 0000000000000000000000000000000000000000..24ec573e7d184e7d64118a723d6645fd92d6e6d9 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java @@ -0,0 +1,376 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import static java.lang.Math.min; + +import android.app.Activity; +import android.graphics.Bitmap; +import android.graphics.RectF; +import android.os.SystemClock; +import android.os.Trace; +import android.util.Log; +import android.view.TextureView; +import android.view.ViewStub; + +import java.io.IOException; +import java.nio.MappedByteBuffer; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.Interpreter; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.gpu.GpuDelegate; +import org.tensorflow.lite.nnapi.NnApiDelegate; +import org.tensorflow.lite.support.common.FileUtil; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.TensorProcessor; +import org.tensorflow.lite.support.image.ImageProcessor; +import org.tensorflow.lite.support.image.TensorImage; +import org.tensorflow.lite.support.image.ops.ResizeOp; +import org.tensorflow.lite.support.image.ops.ResizeOp.ResizeMethod; +import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp; +import org.tensorflow.lite.support.image.ops.Rot90Op; +import org.tensorflow.lite.support.label.TensorLabel; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** A classifier specialized to label images using TensorFlow Lite. */ +public abstract class Classifier { + public static final String TAG = "ClassifierWithSupport"; + + /** The model type used for classification. */ + public enum Model { + FLOAT_MOBILENET, + QUANTIZED_MOBILENET, + QUANTIZED_EFFICIENTNET, + FLOAT_EFFICIENTNET + } + + /** The runtime device type used for executing classification. */ + public enum Device { + CPU, + NNAPI, + GPU + } + + /** Number of results to show in the UI. */ + private static final int MAX_RESULTS = 3; + + /** The loaded TensorFlow Lite model. */ + + /** Image size along the x axis. */ + private final int imageSizeX; + + /** Image size along the y axis. */ + private final int imageSizeY; + + /** Optional GPU delegate for accleration. */ + private GpuDelegate gpuDelegate = null; + + /** Optional NNAPI delegate for accleration. */ + private NnApiDelegate nnApiDelegate = null; + + /** An instance of the driver class to run model inference with Tensorflow Lite. */ + protected Interpreter tflite; + + /** Options for configuring the Interpreter. */ + private final Interpreter.Options tfliteOptions = new Interpreter.Options(); + + /** Labels corresponding to the output of the vision model. */ + private final List labels; + + /** Input image TensorBuffer. */ + private TensorImage inputImageBuffer; + + /** Output probability TensorBuffer. */ + private final TensorBuffer outputProbabilityBuffer; + + /** Processer to apply post processing of the output probability. */ + private final TensorProcessor probabilityProcessor; + + /** + * Creates a classifier with the provided configuration. + * + * @param activity The current Activity. + * @param model The model to use for classification. + * @param device The device to use for classification. + * @param numThreads The number of threads to use for classification. + * @return A classifier with the desired configuration. + */ + public static Classifier create(Activity activity, Model model, Device device, int numThreads) + throws IOException { + if (model == Model.QUANTIZED_MOBILENET) { + return new ClassifierQuantizedMobileNet(activity, device, numThreads); + } else if (model == Model.FLOAT_MOBILENET) { + return new ClassifierFloatMobileNet(activity, device, numThreads); + } else if (model == Model.FLOAT_EFFICIENTNET) { + return new ClassifierFloatEfficientNet(activity, device, numThreads); + } else if (model == Model.QUANTIZED_EFFICIENTNET) { + return new ClassifierQuantizedEfficientNet(activity, device, numThreads); + } else { + throw new UnsupportedOperationException(); + } + } + + /** An immutable result returned by a Classifier describing what was recognized. */ + public static class Recognition { + /** + * A unique identifier for what has been recognized. Specific to the class, not the instance of + * the object. + */ + private final String id; + + /** Display name for the recognition. */ + private final String title; + + /** + * A sortable score for how good the recognition is relative to others. Higher should be better. + */ + private final Float confidence; + + /** Optional location within the source image for the location of the recognized object. */ + private RectF location; + + public Recognition( + final String id, final String title, final Float confidence, final RectF location) { + this.id = id; + this.title = title; + this.confidence = confidence; + this.location = location; + } + + public String getId() { + return id; + } + + public String getTitle() { + return title; + } + + public Float getConfidence() { + return confidence; + } + + public RectF getLocation() { + return new RectF(location); + } + + public void setLocation(RectF location) { + this.location = location; + } + + @Override + public String toString() { + String resultString = ""; + if (id != null) { + resultString += "[" + id + "] "; + } + + if (title != null) { + resultString += title + " "; + } + + if (confidence != null) { + resultString += String.format("(%.1f%%) ", confidence * 100.0f); + } + + if (location != null) { + resultString += location + " "; + } + + return resultString.trim(); + } + } + + /** Initializes a {@code Classifier}. */ + protected Classifier(Activity activity, Device device, int numThreads) throws IOException { + MappedByteBuffer tfliteModel = FileUtil.loadMappedFile(activity, getModelPath()); + switch (device) { + case NNAPI: + nnApiDelegate = new NnApiDelegate(); + tfliteOptions.addDelegate(nnApiDelegate); + break; + case GPU: + gpuDelegate = new GpuDelegate(); + tfliteOptions.addDelegate(gpuDelegate); + break; + case CPU: + break; + } + tfliteOptions.setNumThreads(numThreads); + tflite = new Interpreter(tfliteModel, tfliteOptions); + + // Loads labels out from the label file. + labels = FileUtil.loadLabels(activity, getLabelPath()); + + // Reads type and shape of input and output tensors, respectively. + int imageTensorIndex = 0; + int[] imageShape = tflite.getInputTensor(imageTensorIndex).shape(); // {1, height, width, 3} + if(imageShape[1] != imageShape[2]) { + imageSizeY = imageShape[2]; + imageSizeX = imageShape[3]; + } else { + imageSizeY = imageShape[1]; + imageSizeX = imageShape[2]; + } + DataType imageDataType = tflite.getInputTensor(imageTensorIndex).dataType(); + int probabilityTensorIndex = 0; + int[] probabilityShape = + tflite.getOutputTensor(probabilityTensorIndex).shape(); // {1, NUM_CLASSES} + DataType probabilityDataType = tflite.getOutputTensor(probabilityTensorIndex).dataType(); + + // Creates the input tensor. + inputImageBuffer = new TensorImage(imageDataType); + + // Creates the output tensor and its processor. + outputProbabilityBuffer = TensorBuffer.createFixedSize(probabilityShape, probabilityDataType); + + // Creates the post processor for the output probability. + probabilityProcessor = new TensorProcessor.Builder().add(getPostprocessNormalizeOp()).build(); + + Log.d(TAG, "Created a Tensorflow Lite Image Classifier."); + } + + /** Runs inference and returns the classification results. */ + //public List recognizeImage(final Bitmap bitmap, int sensorOrientation) { + public float[] recognizeImage(final Bitmap bitmap, int sensorOrientation) { + // Logs this method so that it can be analyzed with systrace. + Trace.beginSection("recognizeImage"); + + Trace.beginSection("loadImage"); + long startTimeForLoadImage = SystemClock.uptimeMillis(); + inputImageBuffer = loadImage(bitmap, sensorOrientation); + long endTimeForLoadImage = SystemClock.uptimeMillis(); + Trace.endSection(); + Log.v(TAG, "Timecost to load the image: " + (endTimeForLoadImage - startTimeForLoadImage)); + + // Runs the inference call. + Trace.beginSection("runInference"); + long startTimeForReference = SystemClock.uptimeMillis(); + tflite.run(inputImageBuffer.getBuffer(), outputProbabilityBuffer.getBuffer().rewind()); + long endTimeForReference = SystemClock.uptimeMillis(); + Trace.endSection(); + Log.v(TAG, "Timecost to run model inference: " + (endTimeForReference - startTimeForReference)); + + float[] img_array = outputProbabilityBuffer.getFloatArray(); + + // Gets the map of label and probability. + //Map labeledProbability = + // new TensorLabel(labels, probabilityProcessor.process(outputProbabilityBuffer)) + // .getMapWithFloatValue(); + Trace.endSection(); + + // Gets top-k results. + return img_array;//getTopKProbability(labeledProbability); + } + + /** Closes the interpreter and model to release resources. */ + public void close() { + if (tflite != null) { + tflite.close(); + tflite = null; + } + if (gpuDelegate != null) { + gpuDelegate.close(); + gpuDelegate = null; + } + if (nnApiDelegate != null) { + nnApiDelegate.close(); + nnApiDelegate = null; + } + } + + /** Get the image size along the x axis. */ + public int getImageSizeX() { + return imageSizeX; + } + + /** Get the image size along the y axis. */ + public int getImageSizeY() { + return imageSizeY; + } + + /** Loads input image, and applies preprocessing. */ + private TensorImage loadImage(final Bitmap bitmap, int sensorOrientation) { + // Loads bitmap into a TensorImage. + inputImageBuffer.load(bitmap); + + // Creates processor for the TensorImage. + int cropSize = min(bitmap.getWidth(), bitmap.getHeight()); + int numRotation = sensorOrientation / 90; + // TODO(b/143564309): Fuse ops inside ImageProcessor. + ImageProcessor imageProcessor = + new ImageProcessor.Builder() + .add(new ResizeWithCropOrPadOp(cropSize, cropSize)) + // TODO(b/169379396): investigate the impact of the resize algorithm on accuracy. + // To get the same inference results as lib_task_api, which is built on top of the Task + // Library, use ResizeMethod.BILINEAR. + .add(new ResizeOp(imageSizeX, imageSizeY, ResizeMethod.NEAREST_NEIGHBOR)) + //.add(new ResizeOp(224, 224, ResizeMethod.NEAREST_NEIGHBOR)) + .add(new Rot90Op(numRotation)) + .add(getPreprocessNormalizeOp()) + .build(); + return imageProcessor.process(inputImageBuffer); + } + + /** Gets the top-k results. */ + private static List getTopKProbability(Map labelProb) { + // Find the best classifications. + PriorityQueue pq = + new PriorityQueue<>( + MAX_RESULTS, + new Comparator() { + @Override + public int compare(Recognition lhs, Recognition rhs) { + // Intentionally reversed to put high confidence at the head of the queue. + return Float.compare(rhs.getConfidence(), lhs.getConfidence()); + } + }); + + for (Map.Entry entry : labelProb.entrySet()) { + pq.add(new Recognition("" + entry.getKey(), entry.getKey(), entry.getValue(), null)); + } + + final ArrayList recognitions = new ArrayList<>(); + int recognitionsSize = min(pq.size(), MAX_RESULTS); + for (int i = 0; i < recognitionsSize; ++i) { + recognitions.add(pq.poll()); + } + return recognitions; + } + + /** Gets the name of the model file stored in Assets. */ + protected abstract String getModelPath(); + + /** Gets the name of the label file stored in Assets. */ + protected abstract String getLabelPath(); + + /** Gets the TensorOperator to nomalize the input image in preprocessing. */ + protected abstract TensorOperator getPreprocessNormalizeOp(); + + /** + * Gets the TensorOperator to dequantize the output probability in post processing. + * + *

For quantized model, we need de-quantize the prediction with NormalizeOp (as they are all + * essentially linear transformation). For float model, de-quantize is not required. But to + * uniform the API, de-quantize is added to float model too. Mean and std are set to 0.0f and + * 1.0f, respectively. + */ + protected abstract TensorOperator getPostprocessNormalizeOp(); +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java new file mode 100644 index 0000000000000000000000000000000000000000..14dd027b26baefaedd979a8ac37f0bf984210ed4 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java @@ -0,0 +1,71 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.ops.NormalizeOp; + +/** This TensorFlowLite classifier works with the float EfficientNet model. */ +public class ClassifierFloatEfficientNet extends Classifier { + + private static final float IMAGE_MEAN = 115.0f; //127.0f; + private static final float IMAGE_STD = 58.0f; //128.0f; + + /** + * Float model does not need dequantization in the post-processing. Setting mean and std as 0.0f + * and 1.0f, repectively, to bypass the normalization. + */ + private static final float PROBABILITY_MEAN = 0.0f; + + private static final float PROBABILITY_STD = 1.0f; + + /** + * Initializes a {@code ClassifierFloatMobileNet}. + * + * @param activity + */ + public ClassifierFloatEfficientNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. + //return "efficientnet-lite0-fp32.tflite"; + return "model_opt.tflite"; + } + + @Override + protected String getLabelPath() { + return "labels_without_background.txt"; + } + + @Override + protected TensorOperator getPreprocessNormalizeOp() { + return new NormalizeOp(IMAGE_MEAN, IMAGE_STD); + } + + @Override + protected TensorOperator getPostprocessNormalizeOp() { + return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD); + } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java new file mode 100644 index 0000000000000000000000000000000000000000..40519de07cf5e887773250a4609a832b6060d684 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java @@ -0,0 +1,72 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.ops.NormalizeOp; + +/** This TensorFlowLite classifier works with the float MobileNet model. */ +public class ClassifierFloatMobileNet extends Classifier { + + /** Float MobileNet requires additional normalization of the used input. */ + private static final float IMAGE_MEAN = 127.5f; + + private static final float IMAGE_STD = 127.5f; + + /** + * Float model does not need dequantization in the post-processing. Setting mean and std as 0.0f + * and 1.0f, repectively, to bypass the normalization. + */ + private static final float PROBABILITY_MEAN = 0.0f; + + private static final float PROBABILITY_STD = 1.0f; + + /** + * Initializes a {@code ClassifierFloatMobileNet}. + * + * @param activity + */ + public ClassifierFloatMobileNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. + return "model_0.tflite"; + } + + @Override + protected String getLabelPath() { + return "labels.txt"; + } + + @Override + protected TensorOperator getPreprocessNormalizeOp() { + return new NormalizeOp(IMAGE_MEAN, IMAGE_STD); + } + + @Override + protected TensorOperator getPostprocessNormalizeOp() { + return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD); + } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java new file mode 100644 index 0000000000000000000000000000000000000000..d0d62f58d18333b6360ec30a4c85c9f1d38955ce --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java @@ -0,0 +1,71 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.ops.NormalizeOp; + +/** This TensorFlow Lite classifier works with the quantized EfficientNet model. */ +public class ClassifierQuantizedEfficientNet extends Classifier { + + /** + * The quantized model does not require normalization, thus set mean as 0.0f, and std as 1.0f to + * bypass the normalization. + */ + private static final float IMAGE_MEAN = 0.0f; + + private static final float IMAGE_STD = 1.0f; + + /** Quantized MobileNet requires additional dequantization to the output probability. */ + private static final float PROBABILITY_MEAN = 0.0f; + + private static final float PROBABILITY_STD = 255.0f; + + /** + * Initializes a {@code ClassifierQuantizedMobileNet}. + * + * @param activity + */ + public ClassifierQuantizedEfficientNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. + return "model_quant.tflite"; + } + + @Override + protected String getLabelPath() { + return "labels_without_background.txt"; + } + + @Override + protected TensorOperator getPreprocessNormalizeOp() { + return new NormalizeOp(IMAGE_MEAN, IMAGE_STD); + } + + @Override + protected TensorOperator getPostprocessNormalizeOp() { + return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD); + } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java new file mode 100644 index 0000000000000000000000000000000000000000..94b06e3df659005c287733a8a37672863fdadd71 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.ops.NormalizeOp; + +/** This TensorFlow Lite classifier works with the quantized MobileNet model. */ +public class ClassifierQuantizedMobileNet extends Classifier { + + /** + * The quantized model does not require normalization, thus set mean as 0.0f, and std as 1.0f to + * bypass the normalization. + */ + private static final float IMAGE_MEAN = 0.0f; + + private static final float IMAGE_STD = 1.0f; + + /** Quantized MobileNet requires additional dequantization to the output probability. */ + private static final float PROBABILITY_MEAN = 0.0f; + + private static final float PROBABILITY_STD = 255.0f; + + /** + * Initializes a {@code ClassifierQuantizedMobileNet}. + * + * @param activity + */ + public ClassifierQuantizedMobileNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. + return "model_quant_0.tflite"; + } + + @Override + protected String getLabelPath() { + return "labels.txt"; + } + + @Override + protected TensorOperator getPreprocessNormalizeOp() { + return new NormalizeOp(IMAGE_MEAN, IMAGE_STD); + } + + @Override + protected TensorOperator getPostprocessNormalizeOp() { + return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD); + } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/build.gradle b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..b5983986e3d56a77a41676b9195b0d0882b5fb96 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/build.gradle @@ -0,0 +1,47 @@ +apply plugin: 'com.android.library' + +android { + compileSdkVersion 28 + buildToolsVersion "28.0.0" + + defaultConfig { + minSdkVersion 21 + targetSdkVersion 28 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + compileOptions { + sourceCompatibility = '1.8' + targetCompatibility = '1.8' + } + aaptOptions { + noCompress "tflite" + } + + lintOptions { + checkReleaseBuilds false + // Or, if you prefer, you can continue to check for errors in release builds, + // but continue the build even when errors are found: + abortOnError false + } +} + +dependencies { + implementation fileTree(dir: 'libs', include: ['*.jar']) + implementation project(":models") + implementation 'androidx.appcompat:appcompat:1.1.0' + + // Build off of nightly TensorFlow Lite Task Library + implementation('org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly') { changing = true } + implementation('org.tensorflow:tensorflow-lite-metadata:0.0.0-nightly') { changing = true } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/proguard-rules.pro b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/proguard-rules.pro new file mode 100644 index 0000000000000000000000000000000000000000..f1b424510da51fd82143bc74a0a801ae5a1e2fcd --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/AndroidManifest.xml b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..ebe3c56c60a9b67eec218d969aecfdf5311d7b49 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/AndroidManifest.xml @@ -0,0 +1,3 @@ + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java new file mode 100644 index 0000000000000000000000000000000000000000..45da52a0d0dfa203255e0f2d44901ee0618e739f --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java @@ -0,0 +1,278 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import static java.lang.Math.min; + +import android.app.Activity; +import android.graphics.Bitmap; +import android.graphics.Rect; +import android.graphics.RectF; +import android.os.SystemClock; +import android.os.Trace; +import android.util.Log; +import java.io.IOException; +import java.nio.MappedByteBuffer; +import java.util.ArrayList; +import java.util.List; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.support.common.FileUtil; +import org.tensorflow.lite.support.image.TensorImage; +import org.tensorflow.lite.support.label.Category; +import org.tensorflow.lite.support.metadata.MetadataExtractor; +import org.tensorflow.lite.task.core.vision.ImageProcessingOptions; +import org.tensorflow.lite.task.core.vision.ImageProcessingOptions.Orientation; +import org.tensorflow.lite.task.vision.classifier.Classifications; +import org.tensorflow.lite.task.vision.classifier.ImageClassifier; +import org.tensorflow.lite.task.vision.classifier.ImageClassifier.ImageClassifierOptions; + +/** A classifier specialized to label images using TensorFlow Lite. */ +public abstract class Classifier { + public static final String TAG = "ClassifierWithTaskApi"; + + /** The model type used for classification. */ + public enum Model { + FLOAT_MOBILENET, + QUANTIZED_MOBILENET, + FLOAT_EFFICIENTNET, + QUANTIZED_EFFICIENTNET + } + + /** The runtime device type used for executing classification. */ + public enum Device { + CPU, + NNAPI, + GPU + } + + /** Number of results to show in the UI. */ + private static final int MAX_RESULTS = 3; + + /** Image size along the x axis. */ + private final int imageSizeX; + + /** Image size along the y axis. */ + private final int imageSizeY; + /** An instance of the driver class to run model inference with Tensorflow Lite. */ + protected final ImageClassifier imageClassifier; + + /** + * Creates a classifier with the provided configuration. + * + * @param activity The current Activity. + * @param model The model to use for classification. + * @param device The device to use for classification. + * @param numThreads The number of threads to use for classification. + * @return A classifier with the desired configuration. + */ + public static Classifier create(Activity activity, Model model, Device device, int numThreads) + throws IOException { + if (model == Model.QUANTIZED_MOBILENET) { + return new ClassifierQuantizedMobileNet(activity, device, numThreads); + } else if (model == Model.FLOAT_MOBILENET) { + return new ClassifierFloatMobileNet(activity, device, numThreads); + } else if (model == Model.FLOAT_EFFICIENTNET) { + return new ClassifierFloatEfficientNet(activity, device, numThreads); + } else if (model == Model.QUANTIZED_EFFICIENTNET) { + return new ClassifierQuantizedEfficientNet(activity, device, numThreads); + } else { + throw new UnsupportedOperationException(); + } + } + + /** An immutable result returned by a Classifier describing what was recognized. */ + public static class Recognition { + /** + * A unique identifier for what has been recognized. Specific to the class, not the instance of + * the object. + */ + private final String id; + + /** Display name for the recognition. */ + private final String title; + + /** + * A sortable score for how good the recognition is relative to others. Higher should be better. + */ + private final Float confidence; + + /** Optional location within the source image for the location of the recognized object. */ + private RectF location; + + public Recognition( + final String id, final String title, final Float confidence, final RectF location) { + this.id = id; + this.title = title; + this.confidence = confidence; + this.location = location; + } + + public String getId() { + return id; + } + + public String getTitle() { + return title; + } + + public Float getConfidence() { + return confidence; + } + + public RectF getLocation() { + return new RectF(location); + } + + public void setLocation(RectF location) { + this.location = location; + } + + @Override + public String toString() { + String resultString = ""; + if (id != null) { + resultString += "[" + id + "] "; + } + + if (title != null) { + resultString += title + " "; + } + + if (confidence != null) { + resultString += String.format("(%.1f%%) ", confidence * 100.0f); + } + + if (location != null) { + resultString += location + " "; + } + + return resultString.trim(); + } + } + + /** Initializes a {@code Classifier}. */ + protected Classifier(Activity activity, Device device, int numThreads) throws IOException { + if (device != Device.CPU || numThreads != 1) { + throw new IllegalArgumentException( + "Manipulating the hardware accelerators and numbers of threads is not allowed in the Task" + + " library currently. Only CPU + single thread is allowed."); + } + + // Create the ImageClassifier instance. + ImageClassifierOptions options = + ImageClassifierOptions.builder().setMaxResults(MAX_RESULTS).build(); + imageClassifier = ImageClassifier.createFromFileAndOptions(activity, getModelPath(), options); + Log.d(TAG, "Created a Tensorflow Lite Image Classifier."); + + // Get the input image size information of the underlying tflite model. + MappedByteBuffer tfliteModel = FileUtil.loadMappedFile(activity, getModelPath()); + MetadataExtractor metadataExtractor = new MetadataExtractor(tfliteModel); + // Image shape is in the format of {1, height, width, 3}. + int[] imageShape = metadataExtractor.getInputTensorShape(/*inputIndex=*/ 0); + imageSizeY = imageShape[1]; + imageSizeX = imageShape[2]; + } + + /** Runs inference and returns the classification results. */ + public List recognizeImage(final Bitmap bitmap, int sensorOrientation) { + // Logs this method so that it can be analyzed with systrace. + Trace.beginSection("recognizeImage"); + + TensorImage inputImage = TensorImage.fromBitmap(bitmap); + int width = bitmap.getWidth(); + int height = bitmap.getHeight(); + int cropSize = min(width, height); + // TODO(b/169379396): investigate the impact of the resize algorithm on accuracy. + // Task Library resize the images using bilinear interpolation, which is slightly different from + // the nearest neighbor sampling algorithm used in lib_support. See + // https://github.com/tensorflow/examples/blob/0ef3d93e2af95d325c70ef3bcbbd6844d0631e07/lite/examples/image_classification/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java#L310. + ImageProcessingOptions imageOptions = + ImageProcessingOptions.builder() + .setOrientation(getOrientation(sensorOrientation)) + // Set the ROI to the center of the image. + .setRoi( + new Rect( + /*left=*/ (width - cropSize) / 2, + /*top=*/ (height - cropSize) / 2, + /*right=*/ (width + cropSize) / 2, + /*bottom=*/ (height + cropSize) / 2)) + .build(); + + // Runs the inference call. + Trace.beginSection("runInference"); + long startTimeForReference = SystemClock.uptimeMillis(); + List results = imageClassifier.classify(inputImage, imageOptions); + long endTimeForReference = SystemClock.uptimeMillis(); + Trace.endSection(); + Log.v(TAG, "Timecost to run model inference: " + (endTimeForReference - startTimeForReference)); + + Trace.endSection(); + + return getRecognitions(results); + } + + /** Closes the interpreter and model to release resources. */ + public void close() { + if (imageClassifier != null) { + imageClassifier.close(); + } + } + + /** Get the image size along the x axis. */ + public int getImageSizeX() { + return imageSizeX; + } + + /** Get the image size along the y axis. */ + public int getImageSizeY() { + return imageSizeY; + } + + /** + * Converts a list of {@link Classifications} objects into a list of {@link Recognition} objects + * to match the interface of other inference method, such as using the TFLite + * Support Library.. + */ + private static List getRecognitions(List classifications) { + + final ArrayList recognitions = new ArrayList<>(); + // All the demo models are single head models. Get the first Classifications in the results. + for (Category category : classifications.get(0).getCategories()) { + recognitions.add( + new Recognition( + "" + category.getLabel(), category.getLabel(), category.getScore(), null)); + } + return recognitions; + } + + /* Convert the camera orientation in degree into {@link ImageProcessingOptions#Orientation}.*/ + private static Orientation getOrientation(int cameraOrientation) { + switch (cameraOrientation / 90) { + case 3: + return Orientation.BOTTOM_LEFT; + case 2: + return Orientation.BOTTOM_RIGHT; + case 1: + return Orientation.TOP_RIGHT; + default: + return Orientation.TOP_LEFT; + } + } + + /** Gets the name of the model file stored in Assets. */ + protected abstract String getModelPath(); +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java new file mode 100644 index 0000000000000000000000000000000000000000..250794cc12d0e603aa47502322dc646d50689848 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java @@ -0,0 +1,45 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; + +/** This TensorFlowLite classifier works with the float EfficientNet model. */ +public class ClassifierFloatEfficientNet extends Classifier { + + /** + * Initializes a {@code ClassifierFloatMobileNet}. + * + * @param device a {@link Device} object to configure the hardware accelerator + * @param numThreads the number of threads during the inference + * @throws IOException if the model is not loaded correctly + */ + public ClassifierFloatEfficientNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. + //return "efficientnet-lite0-fp32.tflite"; + return "model.tflite"; + } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java new file mode 100644 index 0000000000000000000000000000000000000000..0707de98de41395eaf3ddcfd74d6e36229a63760 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java @@ -0,0 +1,43 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; + +/** This TensorFlowLite classifier works with the float MobileNet model. */ +public class ClassifierFloatMobileNet extends Classifier { + /** + * Initializes a {@code ClassifierFloatMobileNet}. + * + * @param device a {@link Device} object to configure the hardware accelerator + * @param numThreads the number of threads during the inference + * @throws IOException if the model is not loaded correctly + */ + public ClassifierFloatMobileNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. + return "mobilenet_v1_1.0_224.tflite"; + } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java new file mode 100644 index 0000000000000000000000000000000000000000..05ca4fa6c409d0274a396c9b26c3c39ca8a8194e --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; + +/** This TensorFlow Lite classifier works with the quantized EfficientNet model. */ +public class ClassifierQuantizedEfficientNet extends Classifier { + + /** + * Initializes a {@code ClassifierQuantizedMobileNet}. + * + * @param device a {@link Device} object to configure the hardware accelerator + * @param numThreads the number of threads during the inference + * @throws IOException if the model is not loaded correctly + */ + public ClassifierQuantizedEfficientNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. + return "efficientnet-lite0-int8.tflite"; + } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java new file mode 100644 index 0000000000000000000000000000000000000000..978b08eeaf52a23eede437d61045db08d1dff163 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; + +/** This TensorFlow Lite classifier works with the quantized MobileNet model. */ +public class ClassifierQuantizedMobileNet extends Classifier { + + /** + * Initializes a {@code ClassifierQuantizedMobileNet}. + * + * @param device a {@link Device} object to configure the hardware accelerator + * @param numThreads the number of threads during the inference + * @throws IOException if the model is not loaded correctly + */ + public ClassifierQuantizedMobileNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. + return "mobilenet_v1_1.0_224_quant.tflite"; + } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/build.gradle b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..8d825707af20cbbead6c4599f075599148e3511c --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/build.gradle @@ -0,0 +1,40 @@ +apply plugin: 'com.android.library' +apply plugin: 'de.undercouch.download' + +android { + compileSdkVersion 28 + buildToolsVersion "28.0.0" + + defaultConfig { + minSdkVersion 21 + targetSdkVersion 28 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + + aaptOptions { + noCompress "tflite" + } + + lintOptions { + checkReleaseBuilds false + // Or, if you prefer, you can continue to check for errors in release builds, + // but continue the build even when errors are found: + abortOnError false + } +} + +// Download default models; if you wish to use your own models then +// place them in the "assets" directory and comment out this line. +project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets' +apply from:'download.gradle' diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/download.gradle b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/download.gradle new file mode 100644 index 0000000000000000000000000000000000000000..ce76974a2c3bc6f8214461028e0dfa6ebc25d588 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/download.gradle @@ -0,0 +1,10 @@ +def modelFloatDownloadUrl = "https://github.com/isl-org/MiDaS/releases/download/v2_1/model_opt.tflite" +def modelFloatFile = "model_opt.tflite" + +task downloadModelFloat(type: Download) { + src "${modelFloatDownloadUrl}" + dest project.ext.ASSET_DIR + "/${modelFloatFile}" + overwrite false +} + +preBuild.dependsOn downloadModelFloat diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/proguard-rules.pro b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/proguard-rules.pro new file mode 100644 index 0000000000000000000000000000000000000000..f1b424510da51fd82143bc74a0a801ae5a1e2fcd --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/AndroidManifest.xml b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..42951a56497c5f947efe4aea6a07462019fb152c --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/AndroidManifest.xml @@ -0,0 +1,3 @@ + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/labels.txt b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/labels.txt new file mode 100644 index 0000000000000000000000000000000000000000..fe811239d8e2989de19fecabb1ebb0c9dddac514 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/labels.txt @@ -0,0 +1,1001 @@ +background +tench +goldfish +great white shark +tiger shark +hammerhead +electric ray +stingray +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +African crocodile +American alligator +triceratops +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +tusker +echidna +platypus +wallaby +koala +wombat +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +chambered nautilus +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +isopod +white stork +black stork +spoonbill +flamingo +little blue heron +American egret +bittern +crane +limpkin +European gallinule +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +pelican +king penguin +albatross +grey whale +killer whale +dugong +sea lion +Chihuahua +Japanese spaniel +Maltese dog +Pekinese +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound +basset +beagle +bloodhound +bluetick +black-and-tan coonhound +Walker hound +English foxhound +redbone +borzoi +Irish wolfhound +Italian greyhound +whippet +Ibizan hound +Norwegian elkhound +otterhound +Saluki +Scottish deerhound +Weimaraner +Staffordshire bullterrier +American Staffordshire terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier +Airedale +cairn +Australian terrier +Dandie Dinmont +Boston bull +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier +Tibetan terrier +silky terrier +soft-coated wheaten terrier +West Highland white terrier +Lhasa +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla +English setter +Irish setter +Gordon setter +Brittany spaniel +clumber +English springer +Welsh springer spaniel +cocker spaniel +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog +Shetland sheepdog +collie +Border collie +Bouvier des Flandres +Rottweiler +German shepherd +Doberman +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard +Eskimo dog +malamute +Siberian husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great Pyrenees +Samoyed +Pomeranian +chow +keeshond +Brabancon griffon +Pembroke +Cardigan +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf +white wolf +red wolf +coyote +dingo +dhole +African hunting dog +hyena +red fox +kit fox +Arctic fox +grey fox +tabby +tiger cat +Persian cat +Siamese cat +Egyptian cat +cougar +lynx +leopard +snow leopard +jaguar +lion +tiger +cheetah +brown bear +American black bear +ice bear +sloth bear +mongoose +meerkat +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +ant +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +starfish +sea urchin +sea cucumber +wood rabbit +hare +Angora +hamster +porcupine +fox squirrel +marmot +beaver +guinea pig +sorrel +zebra +hog +wild boar +warthog +hippopotamus +ox +water buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian camel +llama +weasel +mink +polecat +black-footed ferret +otter +skunk +badger +armadillo +three-toed sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis monkey +marmoset +capuchin +howler monkey +titi +spider monkey +squirrel monkey +Madagascar cat +indri +Indian elephant +African elephant +lesser panda +giant panda +barracouta +eel +coho +rock beauty +anemone fish +sturgeon +gar +lionfish +puffer +abacus +abaya +academic gown +accordion +acoustic guitar +aircraft carrier +airliner +airship +altar +ambulance +amphibian +analog clock +apiary +apron +ashcan +assault rifle +backpack +bakery +balance beam +balloon +ballpoint +Band Aid +banjo +bannister +barbell +barber chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing cap +bath towel +bathtub +beach wagon +beacon +beaker +bearskin +beer bottle +beer glass +bell cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo tie +bonnet +bookcase +bookshop +bottlecap +bow +bow tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof vest +bullet train +butcher shop +cab +caldron +candle +cannon +canoe +can opener +cardigan +car mirror +carousel +carpenter's kit +carton +car wheel +cash machine +cassette +cassette player +castle +catamaran +CD player +cello +cellular telephone +chain +chainlink fence +chain mail +chain saw +chest +chiffonier +chime +china cabinet +Christmas stocking +church +cinema +cleaver +cliff dwelling +cloak +clog +cocktail shaker +coffee mug +coffeepot +coil +combination lock +computer keyboard +confectionery +container ship +convertible +corkscrew +cornet +cowboy boot +cowboy hat +cradle +crane +crash helmet +crate +crib +Crock Pot +croquet ball +crutch +cuirass +dam +desk +desktop computer +dial telephone +diaper +digital clock +digital watch +dining table +dishrag +dishwasher +disk brake +dock +dogsled +dome +doormat +drilling platform +drum +drumstick +dumbbell +Dutch oven +electric fan +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa +file +fireboat +fire engine +fire screen +flagpole +flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn +frying pan +fur coat +garbage truck +gasmask +gas pump +goblet +go-kart +golf ball +golfcart +gondola +gong +gown +grand piano +greenhouse +grille +grocery store +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower +hand-held computer +handkerchief +hard disc +harmonica +harp +harvester +hatchet +holster +home theater +honeycomb +hook +hoopskirt +horizontal bar +horse cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw puzzle +jinrikisha +joystick +kimono +knee pad +knot +lab coat +ladle +lampshade +laptop +lawn mower +lens cap +letter opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic compass +mailbag +mailbox +maillot +maillot +manhole cover +maraca +marimba +mask +matchstick +maypole +maze +measuring cup +medicine chest +megalith +microphone +microwave +military uniform +milk can +minibus +miniskirt +minivan +missile +mitten +mixing bowl +mobile home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter +mountain bike +mountain tent +mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil filter +organ +oscilloscope +overskirt +oxcart +oxygen mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper towel +parachute +parallel bars +park bench +parking meter +passenger car +patio +pay-phone +pedestal +pencil box +pencil sharpener +perfume +Petri dish +photocopier +pick +pickelhaube +picket fence +pickup +pier +piggy bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate +pitcher +plane +planetarium +plastic bag +plate rack +plow +plunger +Polaroid camera +pole +police van +poncho +pool table +pop bottle +pot +potter's wheel +power drill +prayer rug +printer +prison +projectile +projector +puck +punching bag +purse +quill +quilt +racer +racket +radiator +radio +radio telescope +rain barrel +recreational vehicle +reel +reflex camera +refrigerator +remote control +restaurant +revolver +rifle +rocking chair +rotisserie +rubber eraser +rugby ball +rule +running shoe +safe +safety pin +saltshaker +sandal +sarong +sax +scabbard +scale +school bus +schooner +scoreboard +screen +screw +screwdriver +seat belt +sewing machine +shield +shoe shop +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule +sliding door +slot +snorkel +snowmobile +snowplow +soap dispenser +soccer ball +sock +solar dish +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web +spindle +sports car +spotlight +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch +stove +strainer +streetcar +stretcher +studio couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension bridge +swab +sweatshirt +swimming trunks +swing +switch +syringe +table lamp +tank +tape player +teapot +teddy +television +tennis ball +thatch +theater curtain +thimble +thresher +throne +tile roof +toaster +tobacco shop +toilet seat +torch +totem pole +tow truck +toyshop +tractor +trailer truck +tray +trench coat +tricycle +trimaran +tripod +triumphal arch +trolleybus +trombone +tub +turnstile +typewriter keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending machine +vestment +viaduct +violin +volleyball +waffle iron +wall clock +wallet +wardrobe +warplane +washbasin +washer +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool +worm fence +wreck +yawl +yurt +web site +comic book +crossword puzzle +street sign +traffic light +book jacket +menu +plate +guacamole +consomme +hot pot +trifle +ice cream +ice lolly +French loaf +bagel +pretzel +cheeseburger +hotdog +mashed potato +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +hay +carbonara +chocolate sauce +dough +meat loaf +pizza +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff +coral reef +geyser +lakeside +promontory +sandbar +seashore +valley +volcano +ballplayer +groom +scuba diver +rapeseed +daisy +yellow lady's slipper +corn +acorn +hip +buckeye +coral fungus +agaric +gyromitra +stinkhorn +earthstar +hen-of-the-woods +bolete +ear +toilet tissue diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/labels_without_background.txt b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/labels_without_background.txt new file mode 100644 index 0000000000000000000000000000000000000000..f40829ed0fc318c673860fae4be6c48529da116e --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/labels_without_background.txt @@ -0,0 +1,1000 @@ +tench +goldfish +great white shark +tiger shark +hammerhead +electric ray +stingray +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +African crocodile +American alligator +triceratops +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +tusker +echidna +platypus +wallaby +koala +wombat +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +chambered nautilus +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +isopod +white stork +black stork +spoonbill +flamingo +little blue heron +American egret +bittern +crane +limpkin +European gallinule +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +pelican +king penguin +albatross +grey whale +killer whale +dugong +sea lion +Chihuahua +Japanese spaniel +Maltese dog +Pekinese +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound +basset +beagle +bloodhound +bluetick +black-and-tan coonhound +Walker hound +English foxhound +redbone +borzoi +Irish wolfhound +Italian greyhound +whippet +Ibizan hound +Norwegian elkhound +otterhound +Saluki +Scottish deerhound +Weimaraner +Staffordshire bullterrier +American Staffordshire terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier +Airedale +cairn +Australian terrier +Dandie Dinmont +Boston bull +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier +Tibetan terrier +silky terrier +soft-coated wheaten terrier +West Highland white terrier +Lhasa +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla +English setter +Irish setter +Gordon setter +Brittany spaniel +clumber +English springer +Welsh springer spaniel +cocker spaniel +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog +Shetland sheepdog +collie +Border collie +Bouvier des Flandres +Rottweiler +German shepherd +Doberman +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard +Eskimo dog +malamute +Siberian husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great Pyrenees +Samoyed +Pomeranian +chow +keeshond +Brabancon griffon +Pembroke +Cardigan +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf +white wolf +red wolf +coyote +dingo +dhole +African hunting dog +hyena +red fox +kit fox +Arctic fox +grey fox +tabby +tiger cat +Persian cat +Siamese cat +Egyptian cat +cougar +lynx +leopard +snow leopard +jaguar +lion +tiger +cheetah +brown bear +American black bear +ice bear +sloth bear +mongoose +meerkat +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +ant +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +starfish +sea urchin +sea cucumber +wood rabbit +hare +Angora +hamster +porcupine +fox squirrel +marmot +beaver +guinea pig +sorrel +zebra +hog +wild boar +warthog +hippopotamus +ox +water buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian camel +llama +weasel +mink +polecat +black-footed ferret +otter +skunk +badger +armadillo +three-toed sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis monkey +marmoset +capuchin +howler monkey +titi +spider monkey +squirrel monkey +Madagascar cat +indri +Indian elephant +African elephant +lesser panda +giant panda +barracouta +eel +coho +rock beauty +anemone fish +sturgeon +gar +lionfish +puffer +abacus +abaya +academic gown +accordion +acoustic guitar +aircraft carrier +airliner +airship +altar +ambulance +amphibian +analog clock +apiary +apron +ashcan +assault rifle +backpack +bakery +balance beam +balloon +ballpoint +Band Aid +banjo +bannister +barbell +barber chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing cap +bath towel +bathtub +beach wagon +beacon +beaker +bearskin +beer bottle +beer glass +bell cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo tie +bonnet +bookcase +bookshop +bottlecap +bow +bow tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof vest +bullet train +butcher shop +cab +caldron +candle +cannon +canoe +can opener +cardigan +car mirror +carousel +carpenter's kit +carton +car wheel +cash machine +cassette +cassette player +castle +catamaran +CD player +cello +cellular telephone +chain +chainlink fence +chain mail +chain saw +chest +chiffonier +chime +china cabinet +Christmas stocking +church +cinema +cleaver +cliff dwelling +cloak +clog +cocktail shaker +coffee mug +coffeepot +coil +combination lock +computer keyboard +confectionery +container ship +convertible +corkscrew +cornet +cowboy boot +cowboy hat +cradle +crane +crash helmet +crate +crib +Crock Pot +croquet ball +crutch +cuirass +dam +desk +desktop computer +dial telephone +diaper +digital clock +digital watch +dining table +dishrag +dishwasher +disk brake +dock +dogsled +dome +doormat +drilling platform +drum +drumstick +dumbbell +Dutch oven +electric fan +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa +file +fireboat +fire engine +fire screen +flagpole +flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn +frying pan +fur coat +garbage truck +gasmask +gas pump +goblet +go-kart +golf ball +golfcart +gondola +gong +gown +grand piano +greenhouse +grille +grocery store +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower +hand-held computer +handkerchief +hard disc +harmonica +harp +harvester +hatchet +holster +home theater +honeycomb +hook +hoopskirt +horizontal bar +horse cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw puzzle +jinrikisha +joystick +kimono +knee pad +knot +lab coat +ladle +lampshade +laptop +lawn mower +lens cap +letter opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic compass +mailbag +mailbox +maillot +maillot +manhole cover +maraca +marimba +mask +matchstick +maypole +maze +measuring cup +medicine chest +megalith +microphone +microwave +military uniform +milk can +minibus +miniskirt +minivan +missile +mitten +mixing bowl +mobile home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter +mountain bike +mountain tent +mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil filter +organ +oscilloscope +overskirt +oxcart +oxygen mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper towel +parachute +parallel bars +park bench +parking meter +passenger car +patio +pay-phone +pedestal +pencil box +pencil sharpener +perfume +Petri dish +photocopier +pick +pickelhaube +picket fence +pickup +pier +piggy bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate +pitcher +plane +planetarium +plastic bag +plate rack +plow +plunger +Polaroid camera +pole +police van +poncho +pool table +pop bottle +pot +potter's wheel +power drill +prayer rug +printer +prison +projectile +projector +puck +punching bag +purse +quill +quilt +racer +racket +radiator +radio +radio telescope +rain barrel +recreational vehicle +reel +reflex camera +refrigerator +remote control +restaurant +revolver +rifle +rocking chair +rotisserie +rubber eraser +rugby ball +rule +running shoe +safe +safety pin +saltshaker +sandal +sarong +sax +scabbard +scale +school bus +schooner +scoreboard +screen +screw +screwdriver +seat belt +sewing machine +shield +shoe shop +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule +sliding door +slot +snorkel +snowmobile +snowplow +soap dispenser +soccer ball +sock +solar dish +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web +spindle +sports car +spotlight +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch +stove +strainer +streetcar +stretcher +studio couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension bridge +swab +sweatshirt +swimming trunks +swing +switch +syringe +table lamp +tank +tape player +teapot +teddy +television +tennis ball +thatch +theater curtain +thimble +thresher +throne +tile roof +toaster +tobacco shop +toilet seat +torch +totem pole +tow truck +toyshop +tractor +trailer truck +tray +trench coat +tricycle +trimaran +tripod +triumphal arch +trolleybus +trombone +tub +turnstile +typewriter keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending machine +vestment +viaduct +violin +volleyball +waffle iron +wall clock +wallet +wardrobe +warplane +washbasin +washer +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool +worm fence +wreck +yawl +yurt +web site +comic book +crossword puzzle +street sign +traffic light +book jacket +menu +plate +guacamole +consomme +hot pot +trifle +ice cream +ice lolly +French loaf +bagel +pretzel +cheeseburger +hotdog +mashed potato +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +hay +carbonara +chocolate sauce +dough +meat loaf +pizza +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff +coral reef +geyser +lakeside +promontory +sandbar +seashore +valley +volcano +ballplayer +groom +scuba diver +rapeseed +daisy +yellow lady's slipper +corn +acorn +hip +buckeye +coral fungus +agaric +gyromitra +stinkhorn +earthstar +hen-of-the-woods +bolete +ear +toilet tissue diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/run_tflite.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/run_tflite.py new file mode 100644 index 0000000000000000000000000000000000000000..4b8ebe235758d3d0f3d357c51ed54d78ac7eea8e --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/run_tflite.py @@ -0,0 +1,75 @@ +# Flex ops are included in the nightly build of the TensorFlow Python package. You can use TFLite models containing Flex ops by the same Python API as normal TFLite models. The nightly TensorFlow build can be installed with this command: +# Flex ops will be added to the TensorFlow Python package's and the tflite_runtime package from version 2.3 for Linux and 2.4 for other environments. +# https://www.tensorflow.org/lite/guide/ops_select#running_the_model + +# You must use: tf-nightly +# pip install tf-nightly + +import os +import glob +import cv2 +import numpy as np + +import tensorflow as tf + +width=256 +height=256 +model_name="model.tflite" +#model_name="model_quant.tflite" +image_name="dog.jpg" + +# input +img = cv2.imread(image_name) +img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + +mean=[0.485, 0.456, 0.406] +std=[0.229, 0.224, 0.225] +img = (img - mean) / std + +img_resized = tf.image.resize(img, [width,height], method='bicubic', preserve_aspect_ratio=False) +#img_resized = tf.transpose(img_resized, [2, 0, 1]) +img_input = img_resized.numpy() +reshape_img = img_input.reshape(1,width,height,3) +tensor = tf.convert_to_tensor(reshape_img, dtype=tf.float32) + +# load model +print("Load model...") +interpreter = tf.lite.Interpreter(model_path=model_name) +print("Allocate tensor...") +interpreter.allocate_tensors() +print("Get input/output details...") +input_details = interpreter.get_input_details() +output_details = interpreter.get_output_details() +print("Get input shape...") +input_shape = input_details[0]['shape'] +print(input_shape) +print(input_details) +print(output_details) +#input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32) +print("Set input tensor...") +interpreter.set_tensor(input_details[0]['index'], tensor) + +print("invoke()...") +interpreter.invoke() + +# The function `get_tensor()` returns a copy of the tensor data. +# Use `tensor()` in order to get a pointer to the tensor. +print("get output tensor...") +output = interpreter.get_tensor(output_details[0]['index']) +#output = np.squeeze(output) +output = output.reshape(width, height) +#print(output) +prediction = np.array(output) +print("reshape prediction...") +prediction = prediction.reshape(width, height) + +# output file +#prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC) +print(" Write image to: output.png") +depth_min = prediction.min() +depth_max = prediction.max() +img_out = (255 * (prediction - depth_min) / (depth_max - depth_min)).astype("uint8") +print("save output image...") +cv2.imwrite("output.png", img_out) + +print("finished") \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/settings.gradle b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/settings.gradle new file mode 100644 index 0000000000000000000000000000000000000000..e86d89d2483f92b7e778589011fad60fbba3a318 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/settings.gradle @@ -0,0 +1,2 @@ +rootProject.name = 'TFLite Image Classification Demo App' +include ':app', ':lib_support', ':lib_task_api', ':models' \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/.gitignore b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f1150e3379e4a38d31ca7bb46dc4f31d79f482c2 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/.gitignore @@ -0,0 +1,2 @@ +# ignore model file +#*.tflite diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/LICENSE b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.pbxproj b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.pbxproj new file mode 100644 index 0000000000000000000000000000000000000000..4917371aa33a65fdfc66c02d914f05489c446430 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.pbxproj @@ -0,0 +1,538 @@ +// !$*UTF8*$! +{ + archiveVersion = 1; + classes = { + }; + objectVersion = 50; + objects = { + +/* Begin PBXBuildFile section */ + 0CDA8C85042ADF65D0787629 /* Pods_Midas.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = A1CE41C09920CCEC31985547 /* Pods_Midas.framework */; }; + 8402440123D9834600704ABD /* README.md in Resources */ = {isa = PBXBuildFile; fileRef = 8402440023D9834600704ABD /* README.md */; }; + 840ECB20238BAA2300C7D88A /* InfoCell.swift in Sources */ = {isa = PBXBuildFile; fileRef = 840ECB1F238BAA2300C7D88A /* InfoCell.swift */; }; + 840EDCFD2341DDD30017ED42 /* Launch Screen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 840EDCFB2341DDD30017ED42 /* Launch Screen.storyboard */; }; + 840EDD022341DE380017ED42 /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 840EDD002341DE380017ED42 /* Main.storyboard */; }; + 842DDB6E2372A82000F6BB94 /* OverlayView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 842DDB6D2372A82000F6BB94 /* OverlayView.swift */; }; + 846499C2235DAB0D009CBBC7 /* ModelDataHandler.swift in Sources */ = {isa = PBXBuildFile; fileRef = 846499C1235DAB0D009CBBC7 /* ModelDataHandler.swift */; }; + 846BAF7623E7FE13006FC136 /* Constants.swift in Sources */ = {isa = PBXBuildFile; fileRef = 846BAF7523E7FE13006FC136 /* Constants.swift */; }; + 8474FEC92341D36E00377D34 /* PreviewView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8474FEC82341D36E00377D34 /* PreviewView.swift */; }; + 8474FECB2341D39800377D34 /* CameraFeedManager.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8474FECA2341D39800377D34 /* CameraFeedManager.swift */; }; + 84952CB5236186BE0052C104 /* CVPixelBufferExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84952CB4236186BE0052C104 /* CVPixelBufferExtension.swift */; }; + 84952CB92361874A0052C104 /* TFLiteExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84952CB82361874A0052C104 /* TFLiteExtension.swift */; }; + 84B67CEF2326338300A11A08 /* AppDelegate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84B67CEE2326338300A11A08 /* AppDelegate.swift */; }; + 84B67CF12326338300A11A08 /* ViewController.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84B67CF02326338300A11A08 /* ViewController.swift */; }; + 84B67CF62326338400A11A08 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 84B67CF52326338400A11A08 /* Assets.xcassets */; }; + 84D6576D2387BB7E0048171E /* CGSizeExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84D6576C2387BB7E0048171E /* CGSizeExtension.swift */; }; + 84F232D5254C831E0011862E /* model_opt.tflite in Resources */ = {isa = PBXBuildFile; fileRef = 84F232D4254C831E0011862E /* model_opt.tflite */; }; + 84FCF5922387BD7900663812 /* tfl_logo.png in Resources */ = {isa = PBXBuildFile; fileRef = 84FCF5912387BD7900663812 /* tfl_logo.png */; }; +/* End PBXBuildFile section */ + +/* Begin PBXFileReference section */ + 8402440023D9834600704ABD /* README.md */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = ""; }; + 840ECB1F238BAA2300C7D88A /* InfoCell.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = InfoCell.swift; sourceTree = ""; }; + 840EDCFC2341DDD30017ED42 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = "Base.lproj/Launch Screen.storyboard"; sourceTree = ""; }; + 840EDD012341DE380017ED42 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/Main.storyboard; sourceTree = ""; }; + 842DDB6D2372A82000F6BB94 /* OverlayView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OverlayView.swift; sourceTree = ""; }; + 846499C1235DAB0D009CBBC7 /* ModelDataHandler.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelDataHandler.swift; sourceTree = ""; }; + 846BAF7523E7FE13006FC136 /* Constants.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Constants.swift; sourceTree = ""; }; + 8474FEC82341D36E00377D34 /* PreviewView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PreviewView.swift; sourceTree = ""; }; + 8474FECA2341D39800377D34 /* CameraFeedManager.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CameraFeedManager.swift; sourceTree = ""; }; + 84884291236FF0A30043FC4C /* download_models.sh */ = {isa = PBXFileReference; lastKnownFileType = text.script.sh; path = download_models.sh; sourceTree = ""; }; + 84952CB4236186BE0052C104 /* CVPixelBufferExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CVPixelBufferExtension.swift; sourceTree = ""; }; + 84952CB82361874A0052C104 /* TFLiteExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TFLiteExtension.swift; sourceTree = ""; }; + 84B67CEB2326338300A11A08 /* Midas.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = Midas.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 84B67CEE2326338300A11A08 /* AppDelegate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AppDelegate.swift; sourceTree = ""; }; + 84B67CF02326338300A11A08 /* ViewController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ViewController.swift; sourceTree = ""; }; + 84B67CF52326338400A11A08 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; + 84B67CFA2326338400A11A08 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = ""; }; + 84D6576C2387BB7E0048171E /* CGSizeExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CGSizeExtension.swift; sourceTree = ""; }; + 84F232D4254C831E0011862E /* model_opt.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = model_opt.tflite; sourceTree = ""; }; + 84FCF5912387BD7900663812 /* tfl_logo.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; name = tfl_logo.png; path = Assets.xcassets/tfl_logo.png; sourceTree = ""; }; + A1CE41C09920CCEC31985547 /* Pods_Midas.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_Midas.framework; sourceTree = BUILT_PRODUCTS_DIR; }; + D2BFF06D0AE9137D332447F3 /* Pods-Midas.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-Midas.release.xcconfig"; path = "Target Support Files/Pods-Midas/Pods-Midas.release.xcconfig"; sourceTree = ""; }; + FCA88463911267B1001A596F /* Pods-Midas.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-Midas.debug.xcconfig"; path = "Target Support Files/Pods-Midas/Pods-Midas.debug.xcconfig"; sourceTree = ""; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + 84B67CE82326338300A11A08 /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + 0CDA8C85042ADF65D0787629 /* Pods_Midas.framework in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + 840ECB1E238BAA0D00C7D88A /* Cells */ = { + isa = PBXGroup; + children = ( + 840ECB1F238BAA2300C7D88A /* InfoCell.swift */, + ); + path = Cells; + sourceTree = ""; + }; + 842DDB6C2372A80E00F6BB94 /* Views */ = { + isa = PBXGroup; + children = ( + 842DDB6D2372A82000F6BB94 /* OverlayView.swift */, + ); + path = Views; + sourceTree = ""; + }; + 846499C0235DAAE7009CBBC7 /* ModelDataHandler */ = { + isa = PBXGroup; + children = ( + 846499C1235DAB0D009CBBC7 /* ModelDataHandler.swift */, + ); + path = ModelDataHandler; + sourceTree = ""; + }; + 8474FEC62341D2BE00377D34 /* ViewControllers */ = { + isa = PBXGroup; + children = ( + 84B67CF02326338300A11A08 /* ViewController.swift */, + ); + path = ViewControllers; + sourceTree = ""; + }; + 8474FEC72341D35800377D34 /* Camera Feed */ = { + isa = PBXGroup; + children = ( + 8474FEC82341D36E00377D34 /* PreviewView.swift */, + 8474FECA2341D39800377D34 /* CameraFeedManager.swift */, + ); + path = "Camera Feed"; + sourceTree = ""; + }; + 84884290236FF07F0043FC4C /* RunScripts */ = { + isa = PBXGroup; + children = ( + 84884291236FF0A30043FC4C /* download_models.sh */, + ); + path = RunScripts; + sourceTree = ""; + }; + 848842A22370180C0043FC4C /* Model */ = { + isa = PBXGroup; + children = ( + 84F232D4254C831E0011862E /* model_opt.tflite */, + ); + path = Model; + sourceTree = ""; + }; + 84952CB3236186A20052C104 /* Extensions */ = { + isa = PBXGroup; + children = ( + 84952CB4236186BE0052C104 /* CVPixelBufferExtension.swift */, + 84952CB82361874A0052C104 /* TFLiteExtension.swift */, + 84D6576C2387BB7E0048171E /* CGSizeExtension.swift */, + ); + path = Extensions; + sourceTree = ""; + }; + 84B67CE22326338300A11A08 = { + isa = PBXGroup; + children = ( + 8402440023D9834600704ABD /* README.md */, + 84884290236FF07F0043FC4C /* RunScripts */, + 84B67CED2326338300A11A08 /* Midas */, + 84B67CEC2326338300A11A08 /* Products */, + B4DFDCC28443B641BC36251D /* Pods */, + A3DA804B8D3F6891E3A02852 /* Frameworks */, + ); + sourceTree = ""; + }; + 84B67CEC2326338300A11A08 /* Products */ = { + isa = PBXGroup; + children = ( + 84B67CEB2326338300A11A08 /* Midas.app */, + ); + name = Products; + sourceTree = ""; + }; + 84B67CED2326338300A11A08 /* Midas */ = { + isa = PBXGroup; + children = ( + 840ECB1E238BAA0D00C7D88A /* Cells */, + 842DDB6C2372A80E00F6BB94 /* Views */, + 848842A22370180C0043FC4C /* Model */, + 84952CB3236186A20052C104 /* Extensions */, + 846499C0235DAAE7009CBBC7 /* ModelDataHandler */, + 8474FEC72341D35800377D34 /* Camera Feed */, + 8474FEC62341D2BE00377D34 /* ViewControllers */, + 84B67D002326339000A11A08 /* Storyboards */, + 84B67CEE2326338300A11A08 /* AppDelegate.swift */, + 846BAF7523E7FE13006FC136 /* Constants.swift */, + 84B67CF52326338400A11A08 /* Assets.xcassets */, + 84FCF5912387BD7900663812 /* tfl_logo.png */, + 84B67CFA2326338400A11A08 /* Info.plist */, + ); + path = Midas; + sourceTree = ""; + }; + 84B67D002326339000A11A08 /* Storyboards */ = { + isa = PBXGroup; + children = ( + 840EDCFB2341DDD30017ED42 /* Launch Screen.storyboard */, + 840EDD002341DE380017ED42 /* Main.storyboard */, + ); + path = Storyboards; + sourceTree = ""; + }; + A3DA804B8D3F6891E3A02852 /* Frameworks */ = { + isa = PBXGroup; + children = ( + A1CE41C09920CCEC31985547 /* Pods_Midas.framework */, + ); + name = Frameworks; + sourceTree = ""; + }; + B4DFDCC28443B641BC36251D /* Pods */ = { + isa = PBXGroup; + children = ( + FCA88463911267B1001A596F /* Pods-Midas.debug.xcconfig */, + D2BFF06D0AE9137D332447F3 /* Pods-Midas.release.xcconfig */, + ); + path = Pods; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + 84B67CEA2326338300A11A08 /* Midas */ = { + isa = PBXNativeTarget; + buildConfigurationList = 84B67CFD2326338400A11A08 /* Build configuration list for PBXNativeTarget "Midas" */; + buildPhases = ( + 14067F3CF309C9DB723C9F6F /* [CP] Check Pods Manifest.lock */, + 84884298237010B90043FC4C /* Download TensorFlow Lite model */, + 84B67CE72326338300A11A08 /* Sources */, + 84B67CE82326338300A11A08 /* Frameworks */, + 84B67CE92326338300A11A08 /* Resources */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = Midas; + productName = Midas; + productReference = 84B67CEB2326338300A11A08 /* Midas.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + 84B67CE32326338300A11A08 /* Project object */ = { + isa = PBXProject; + attributes = { + LastSwiftUpdateCheck = 1030; + LastUpgradeCheck = 1030; + ORGANIZATIONNAME = tensorflow; + TargetAttributes = { + 84B67CEA2326338300A11A08 = { + CreatedOnToolsVersion = 10.3; + }; + }; + }; + buildConfigurationList = 84B67CE62326338300A11A08 /* Build configuration list for PBXProject "Midas" */; + compatibilityVersion = "Xcode 9.3"; + developmentRegion = en; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = 84B67CE22326338300A11A08; + productRefGroup = 84B67CEC2326338300A11A08 /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + 84B67CEA2326338300A11A08 /* Midas */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + 84B67CE92326338300A11A08 /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 8402440123D9834600704ABD /* README.md in Resources */, + 84F232D5254C831E0011862E /* model_opt.tflite in Resources */, + 840EDD022341DE380017ED42 /* Main.storyboard in Resources */, + 840EDCFD2341DDD30017ED42 /* Launch Screen.storyboard in Resources */, + 84FCF5922387BD7900663812 /* tfl_logo.png in Resources */, + 84B67CF62326338400A11A08 /* Assets.xcassets in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXShellScriptBuildPhase section */ + 14067F3CF309C9DB723C9F6F /* [CP] Check Pods Manifest.lock */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputFileListPaths = ( + ); + inputPaths = ( + "${PODS_PODFILE_DIR_PATH}/Podfile.lock", + "${PODS_ROOT}/Manifest.lock", + ); + name = "[CP] Check Pods Manifest.lock"; + outputFileListPaths = ( + ); + outputPaths = ( + "$(DERIVED_FILE_DIR)/Pods-Midas-checkManifestLockResult.txt", + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n"; + showEnvVarsInLog = 0; + }; + 84884298237010B90043FC4C /* Download TensorFlow Lite model */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputFileListPaths = ( + ); + inputPaths = ( + ); + name = "Download TensorFlow Lite model"; + outputFileListPaths = ( + ); + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/bash; + shellScript = "\"$SRCROOT/RunScripts/download_models.sh\"\n"; + }; +/* End PBXShellScriptBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + 84B67CE72326338300A11A08 /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 842DDB6E2372A82000F6BB94 /* OverlayView.swift in Sources */, + 846BAF7623E7FE13006FC136 /* Constants.swift in Sources */, + 84952CB92361874A0052C104 /* TFLiteExtension.swift in Sources */, + 84D6576D2387BB7E0048171E /* CGSizeExtension.swift in Sources */, + 84B67CF12326338300A11A08 /* ViewController.swift in Sources */, + 84B67CEF2326338300A11A08 /* AppDelegate.swift in Sources */, + 8474FECB2341D39800377D34 /* CameraFeedManager.swift in Sources */, + 846499C2235DAB0D009CBBC7 /* ModelDataHandler.swift in Sources */, + 8474FEC92341D36E00377D34 /* PreviewView.swift in Sources */, + 84952CB5236186BE0052C104 /* CVPixelBufferExtension.swift in Sources */, + 840ECB20238BAA2300C7D88A /* InfoCell.swift in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin PBXVariantGroup section */ + 840EDCFB2341DDD30017ED42 /* Launch Screen.storyboard */ = { + isa = PBXVariantGroup; + children = ( + 840EDCFC2341DDD30017ED42 /* Base */, + ); + name = "Launch Screen.storyboard"; + sourceTree = ""; + }; + 840EDD002341DE380017ED42 /* Main.storyboard */ = { + isa = PBXVariantGroup; + children = ( + 840EDD012341DE380017ED42 /* Base */, + ); + name = Main.storyboard; + sourceTree = ""; + }; +/* End PBXVariantGroup section */ + +/* Begin XCBuildConfiguration section */ + 84B67CFB2326338400A11A08 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_IDENTITY = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 12.4; + MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; + MTL_FAST_MATH = YES; + ONLY_ACTIVE_ARCH = YES; + SDKROOT = iphoneos; + SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + }; + name = Debug; + }; + 84B67CFC2326338400A11A08 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_IDENTITY = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 12.4; + MTL_ENABLE_DEBUG_INFO = NO; + MTL_FAST_MATH = YES; + SDKROOT = iphoneos; + SWIFT_COMPILATION_MODE = wholemodule; + SWIFT_OPTIMIZATION_LEVEL = "-O"; + VALIDATE_PRODUCT = YES; + }; + name = Release; + }; + 84B67CFE2326338400A11A08 /* Debug */ = { + isa = XCBuildConfiguration; + baseConfigurationReference = FCA88463911267B1001A596F /* Pods-Midas.debug.xcconfig */; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CODE_SIGN_IDENTITY = "iPhone Developer"; + CODE_SIGN_STYLE = Automatic; + DEVELOPMENT_TEAM = BV6M48J3RX; + INFOPLIST_FILE = Midas/Info.plist; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + PRODUCT_BUNDLE_IDENTIFIER = "org.midas.midas-tflite-npu"; + PRODUCT_NAME = Midas; + PROVISIONING_PROFILE_SPECIFIER = ""; + SWIFT_VERSION = 5.0; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Debug; + }; + 84B67CFF2326338400A11A08 /* Release */ = { + isa = XCBuildConfiguration; + baseConfigurationReference = D2BFF06D0AE9137D332447F3 /* Pods-Midas.release.xcconfig */; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CODE_SIGN_IDENTITY = "iPhone Developer"; + CODE_SIGN_STYLE = Automatic; + DEVELOPMENT_TEAM = BV6M48J3RX; + INFOPLIST_FILE = Midas/Info.plist; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + PRODUCT_BUNDLE_IDENTIFIER = "org.midas.midas-tflite-npu"; + PRODUCT_NAME = Midas; + PROVISIONING_PROFILE_SPECIFIER = ""; + SWIFT_VERSION = 5.0; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + 84B67CE62326338300A11A08 /* Build configuration list for PBXProject "Midas" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 84B67CFB2326338400A11A08 /* Debug */, + 84B67CFC2326338400A11A08 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + 84B67CFD2326338400A11A08 /* Build configuration list for PBXNativeTarget "Midas" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 84B67CFE2326338400A11A08 /* Debug */, + 84B67CFF2326338400A11A08 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + }; + rootObject = 84B67CE32326338300A11A08 /* Project object */; +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/contents.xcworkspacedata b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/contents.xcworkspacedata new file mode 100644 index 0000000000000000000000000000000000000000..919434a6254f0e9651f402737811be6634a03e9c --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/contents.xcworkspacedata @@ -0,0 +1,7 @@ + + + + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist new file mode 100644 index 0000000000000000000000000000000000000000..18d981003d68d0546c4804ac2ff47dd97c6e7921 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist @@ -0,0 +1,8 @@ + + + + + IDEDidComputeMac32BitWarning + + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcuserdata/admin.xcuserdatad/UserInterfaceState.xcuserstate b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcuserdata/admin.xcuserdatad/UserInterfaceState.xcuserstate new file mode 100644 index 0000000000000000000000000000000000000000..1d20756ee57b79e9f9f886453bdb7997ca2ee2d4 Binary files /dev/null and b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcuserdata/admin.xcuserdatad/UserInterfaceState.xcuserstate differ diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/xcuserdata/admin.xcuserdatad/xcschemes/xcschememanagement.plist b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/xcuserdata/admin.xcuserdatad/xcschemes/xcschememanagement.plist new file mode 100644 index 0000000000000000000000000000000000000000..6093f6160eedfdfc20e96396247a7dbc9247cc55 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/xcuserdata/admin.xcuserdatad/xcschemes/xcschememanagement.plist @@ -0,0 +1,14 @@ + + + + + SchemeUserState + + PoseNet.xcscheme_^#shared#^_ + + orderHint + 3 + + + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/AppDelegate.swift b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/AppDelegate.swift new file mode 100644 index 0000000000000000000000000000000000000000..233f0291ab4f379067543bdad3cc198a2dc3ab0f --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/AppDelegate.swift @@ -0,0 +1,41 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import UIKit + +@UIApplicationMain +class AppDelegate: UIResponder, UIApplicationDelegate { + + var window: UIWindow? + + func application(_ application: UIApplication, didFinishLaunchingWithOptions launchOptions: [UIApplication.LaunchOptionsKey: Any]?) -> Bool { + return true + } + + func applicationWillResignActive(_ application: UIApplication) { + } + + func applicationDidEnterBackground(_ application: UIApplication) { + } + + func applicationWillEnterForeground(_ application: UIApplication) { + } + + func applicationDidBecomeActive(_ application: UIApplication) { + } + + func applicationWillTerminate(_ application: UIApplication) { + } +} + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/Contents.json b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/Contents.json new file mode 100644 index 0000000000000000000000000000000000000000..65b74d7ef11fa59fafa829e681ac90906f3ac8b2 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -0,0 +1 @@ +{"images":[{"size":"60x60","expected-size":"180","filename":"180.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"40x40","expected-size":"80","filename":"80.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"40x40","expected-size":"120","filename":"120.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"60x60","expected-size":"120","filename":"120.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"57x57","expected-size":"57","filename":"57.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"1x"},{"size":"29x29","expected-size":"58","filename":"58.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"29x29","expected-size":"29","filename":"29.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"1x"},{"size":"29x29","expected-size":"87","filename":"87.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"57x57","expected-size":"114","filename":"114.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"20x20","expected-size":"40","filename":"40.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"20x20","expected-size":"60","filename":"60.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"1024x1024","filename":"1024.png","expected-size":"1024","idiom":"ios-marketing","folder":"Assets.xcassets/AppIcon.appiconset/","scale":"1x"},{"size":"40x40","expected-size":"80","filename":"80.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"72x72","expected-size":"72","filename":"72.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"76x76","expected-size":"152","filename":"152.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"50x50","expected-size":"100","filename":"100.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"29x29","expected-size":"58","filename":"58.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"76x76","expected-size":"76","filename":"76.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"29x29","expected-size":"29","filename":"29.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"50x50","expected-size":"50","filename":"50.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"72x72","expected-size":"144","filename":"144.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"40x40","expected-size":"40","filename":"40.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"83.5x83.5","expected-size":"167","filename":"167.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"20x20","expected-size":"20","filename":"20.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"20x20","expected-size":"40","filename":"40.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"}]} \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/Contents.json b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/Contents.json new file mode 100644 index 0000000000000000000000000000000000000000..da4a164c918651cdd1e11dca5cc62c333f097601 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/Contents.json @@ -0,0 +1,6 @@ +{ + "info" : { + "version" : 1, + "author" : "xcode" + } +} \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/CameraFeedManager.swift b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/CameraFeedManager.swift new file mode 100644 index 0000000000000000000000000000000000000000..48d65b88ee220e722fbad2570c8e879a431cd0f5 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/CameraFeedManager.swift @@ -0,0 +1,316 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import AVFoundation +import UIKit +import os + +// MARK: - CameraFeedManagerDelegate Declaration +@objc protocol CameraFeedManagerDelegate: class { + /// This method delivers the pixel buffer of the current frame seen by the device's camera. + @objc optional func cameraFeedManager( + _ manager: CameraFeedManager, didOutput pixelBuffer: CVPixelBuffer + ) + + /// This method initimates that a session runtime error occured. + func cameraFeedManagerDidEncounterSessionRunTimeError(_ manager: CameraFeedManager) + + /// This method initimates that the session was interrupted. + func cameraFeedManager( + _ manager: CameraFeedManager, sessionWasInterrupted canResumeManually: Bool + ) + + /// This method initimates that the session interruption has ended. + func cameraFeedManagerDidEndSessionInterruption(_ manager: CameraFeedManager) + + /// This method initimates that there was an error in video configurtion. + func presentVideoConfigurationErrorAlert(_ manager: CameraFeedManager) + + /// This method initimates that the camera permissions have been denied. + func presentCameraPermissionsDeniedAlert(_ manager: CameraFeedManager) +} + +/// This enum holds the state of the camera initialization. +// MARK: - Camera Initialization State Enum +enum CameraConfiguration { + case success + case failed + case permissionDenied +} + +/// This class manages all camera related functionalities. +// MARK: - Camera Related Functionalies Manager +class CameraFeedManager: NSObject { + // MARK: Camera Related Instance Variables + private let session: AVCaptureSession = AVCaptureSession() + + private let previewView: PreviewView + private let sessionQueue = DispatchQueue(label: "sessionQueue") + private var cameraConfiguration: CameraConfiguration = .failed + private lazy var videoDataOutput = AVCaptureVideoDataOutput() + private var isSessionRunning = false + + // MARK: CameraFeedManagerDelegate + weak var delegate: CameraFeedManagerDelegate? + + // MARK: Initializer + init(previewView: PreviewView) { + self.previewView = previewView + super.init() + + // Initializes the session + session.sessionPreset = .high + self.previewView.session = session + self.previewView.previewLayer.connection?.videoOrientation = .portrait + self.previewView.previewLayer.videoGravity = .resizeAspectFill + self.attemptToConfigureSession() + } + + // MARK: Session Start and End methods + + /// This method starts an AVCaptureSession based on whether the camera configuration was successful. + func checkCameraConfigurationAndStartSession() { + sessionQueue.async { + switch self.cameraConfiguration { + case .success: + self.addObservers() + self.startSession() + case .failed: + DispatchQueue.main.async { + self.delegate?.presentVideoConfigurationErrorAlert(self) + } + case .permissionDenied: + DispatchQueue.main.async { + self.delegate?.presentCameraPermissionsDeniedAlert(self) + } + } + } + } + + /// This method stops a running an AVCaptureSession. + func stopSession() { + self.removeObservers() + sessionQueue.async { + if self.session.isRunning { + self.session.stopRunning() + self.isSessionRunning = self.session.isRunning + } + } + + } + + /// This method resumes an interrupted AVCaptureSession. + func resumeInterruptedSession(withCompletion completion: @escaping (Bool) -> Void) { + sessionQueue.async { + self.startSession() + + DispatchQueue.main.async { + completion(self.isSessionRunning) + } + } + } + + /// This method starts the AVCaptureSession + private func startSession() { + self.session.startRunning() + self.isSessionRunning = self.session.isRunning + } + + // MARK: Session Configuration Methods. + /// This method requests for camera permissions and handles the configuration of the session and stores the result of configuration. + private func attemptToConfigureSession() { + switch AVCaptureDevice.authorizationStatus(for: .video) { + case .authorized: + self.cameraConfiguration = .success + case .notDetermined: + self.sessionQueue.suspend() + self.requestCameraAccess(completion: { granted in + self.sessionQueue.resume() + }) + case .denied: + self.cameraConfiguration = .permissionDenied + default: + break + } + + self.sessionQueue.async { + self.configureSession() + } + } + + /// This method requests for camera permissions. + private func requestCameraAccess(completion: @escaping (Bool) -> Void) { + AVCaptureDevice.requestAccess(for: .video) { (granted) in + if !granted { + self.cameraConfiguration = .permissionDenied + } else { + self.cameraConfiguration = .success + } + completion(granted) + } + } + + /// This method handles all the steps to configure an AVCaptureSession. + private func configureSession() { + guard cameraConfiguration == .success else { + return + } + session.beginConfiguration() + + // Tries to add an AVCaptureDeviceInput. + guard addVideoDeviceInput() == true else { + self.session.commitConfiguration() + self.cameraConfiguration = .failed + return + } + + // Tries to add an AVCaptureVideoDataOutput. + guard addVideoDataOutput() else { + self.session.commitConfiguration() + self.cameraConfiguration = .failed + return + } + + session.commitConfiguration() + self.cameraConfiguration = .success + } + + /// This method tries to an AVCaptureDeviceInput to the current AVCaptureSession. + private func addVideoDeviceInput() -> Bool { + /// Tries to get the default back camera. + guard + let camera = AVCaptureDevice.default(.builtInWideAngleCamera, for: .video, position: .back) + else { + fatalError("Cannot find camera") + } + + do { + let videoDeviceInput = try AVCaptureDeviceInput(device: camera) + if session.canAddInput(videoDeviceInput) { + session.addInput(videoDeviceInput) + return true + } else { + return false + } + } catch { + fatalError("Cannot create video device input") + } + } + + /// This method tries to an AVCaptureVideoDataOutput to the current AVCaptureSession. + private func addVideoDataOutput() -> Bool { + let sampleBufferQueue = DispatchQueue(label: "sampleBufferQueue") + videoDataOutput.setSampleBufferDelegate(self, queue: sampleBufferQueue) + videoDataOutput.alwaysDiscardsLateVideoFrames = true + videoDataOutput.videoSettings = [ + String(kCVPixelBufferPixelFormatTypeKey): kCMPixelFormat_32BGRA + ] + + if session.canAddOutput(videoDataOutput) { + session.addOutput(videoDataOutput) + videoDataOutput.connection(with: .video)?.videoOrientation = .portrait + return true + } + return false + } + + // MARK: Notification Observer Handling + private func addObservers() { + NotificationCenter.default.addObserver( + self, selector: #selector(CameraFeedManager.sessionRuntimeErrorOccured(notification:)), + name: NSNotification.Name.AVCaptureSessionRuntimeError, object: session) + NotificationCenter.default.addObserver( + self, selector: #selector(CameraFeedManager.sessionWasInterrupted(notification:)), + name: NSNotification.Name.AVCaptureSessionWasInterrupted, object: session) + NotificationCenter.default.addObserver( + self, selector: #selector(CameraFeedManager.sessionInterruptionEnded), + name: NSNotification.Name.AVCaptureSessionInterruptionEnded, object: session) + } + + private func removeObservers() { + NotificationCenter.default.removeObserver( + self, name: NSNotification.Name.AVCaptureSessionRuntimeError, object: session) + NotificationCenter.default.removeObserver( + self, name: NSNotification.Name.AVCaptureSessionWasInterrupted, object: session) + NotificationCenter.default.removeObserver( + self, name: NSNotification.Name.AVCaptureSessionInterruptionEnded, object: session) + } + + // MARK: Notification Observers + @objc func sessionWasInterrupted(notification: Notification) { + if let userInfoValue = notification.userInfo?[AVCaptureSessionInterruptionReasonKey] + as AnyObject?, + let reasonIntegerValue = userInfoValue.integerValue, + let reason = AVCaptureSession.InterruptionReason(rawValue: reasonIntegerValue) + { + os_log("Capture session was interrupted with reason: %s", type: .error, reason.rawValue) + + var canResumeManually = false + if reason == .videoDeviceInUseByAnotherClient { + canResumeManually = true + } else if reason == .videoDeviceNotAvailableWithMultipleForegroundApps { + canResumeManually = false + } + + delegate?.cameraFeedManager(self, sessionWasInterrupted: canResumeManually) + + } + } + + @objc func sessionInterruptionEnded(notification: Notification) { + delegate?.cameraFeedManagerDidEndSessionInterruption(self) + } + + @objc func sessionRuntimeErrorOccured(notification: Notification) { + guard let error = notification.userInfo?[AVCaptureSessionErrorKey] as? AVError else { + return + } + + os_log("Capture session runtime error: %s", type: .error, error.localizedDescription) + + if error.code == .mediaServicesWereReset { + sessionQueue.async { + if self.isSessionRunning { + self.startSession() + } else { + DispatchQueue.main.async { + self.delegate?.cameraFeedManagerDidEncounterSessionRunTimeError(self) + } + } + } + } else { + delegate?.cameraFeedManagerDidEncounterSessionRunTimeError(self) + } + } +} + +/// AVCaptureVideoDataOutputSampleBufferDelegate +extension CameraFeedManager: AVCaptureVideoDataOutputSampleBufferDelegate { + /// This method delegates the CVPixelBuffer of the frame seen by the camera currently. + func captureOutput( + _ output: AVCaptureOutput, didOutput sampleBuffer: CMSampleBuffer, + from connection: AVCaptureConnection + ) { + + // Converts the CMSampleBuffer to a CVPixelBuffer. + let pixelBuffer: CVPixelBuffer? = CMSampleBufferGetImageBuffer(sampleBuffer) + + guard let imagePixelBuffer = pixelBuffer else { + return + } + + // Delegates the pixel buffer to the ViewController. + delegate?.cameraFeedManager?(self, didOutput: imagePixelBuffer) + } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/PreviewView.swift b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/PreviewView.swift new file mode 100644 index 0000000000000000000000000000000000000000..308c7ec54308af5c152ff6038670b26501a8e82c --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/PreviewView.swift @@ -0,0 +1,39 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import UIKit +import AVFoundation + + /// The camera frame is displayed on this view. +class PreviewView: UIView { + var previewLayer: AVCaptureVideoPreviewLayer { + guard let layer = layer as? AVCaptureVideoPreviewLayer else { + fatalError("Layer expected is of type VideoPreviewLayer") + } + return layer + } + + var session: AVCaptureSession? { + get { + return previewLayer.session + } + set { + previewLayer.session = newValue + } + } + + override class var layerClass: AnyClass { + return AVCaptureVideoPreviewLayer.self + } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Cells/InfoCell.swift b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Cells/InfoCell.swift new file mode 100644 index 0000000000000000000000000000000000000000..c6be64af5678541ec09fc367b03c80155876f0ba --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Cells/InfoCell.swift @@ -0,0 +1,21 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import UIKit + +/// Table cell for inference result in bottom view. +class InfoCell: UITableViewCell { + @IBOutlet weak var fieldNameLabel: UILabel! + @IBOutlet weak var infoLabel: UILabel! +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Constants.swift b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Constants.swift new file mode 100644 index 0000000000000000000000000000000000000000..b0789ee58a1ea373d441f05333d8ce8914adadb7 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Constants.swift @@ -0,0 +1,25 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +enum Constants { + // MARK: - Constants related to the image processing + static let bgraPixel = (channels: 4, alphaComponent: 3, lastBgrComponent: 2) + static let rgbPixelChannels = 3 + static let maxRGBValue: Float32 = 255.0 + + // MARK: - Constants related to the model interperter + static let defaultThreadCount = 2 + static let defaultDelegate: Delegates = .CPU +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CGSizeExtension.swift b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CGSizeExtension.swift new file mode 100644 index 0000000000000000000000000000000000000000..031550ea0081963d18b5b83712854babaf7c0a34 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CGSizeExtension.swift @@ -0,0 +1,45 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +import Accelerate +import Foundation + +extension CGSize { + /// Returns `CGAfineTransform` to resize `self` to fit in destination size, keeping aspect ratio + /// of `self`. `self` image is resized to be inscribe to destination size and located in center of + /// destination. + /// + /// - Parameter toFitIn: destination size to be filled. + /// - Returns: `CGAffineTransform` to transform `self` image to `dest` image. + func transformKeepAspect(toFitIn dest: CGSize) -> CGAffineTransform { + let sourceRatio = self.height / self.width + let destRatio = dest.height / dest.width + + // Calculates ratio `self` to `dest`. + var ratio: CGFloat + var x: CGFloat = 0 + var y: CGFloat = 0 + if sourceRatio > destRatio { + // Source size is taller than destination. Resized to fit in destination height, and find + // horizontal starting point to be centered. + ratio = dest.height / self.height + x = (dest.width - self.width * ratio) / 2 + } else { + ratio = dest.width / self.width + y = (dest.height - self.height * ratio) / 2 + } + return CGAffineTransform(a: ratio, b: 0, c: 0, d: ratio, tx: x, ty: y) + } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CVPixelBufferExtension.swift b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CVPixelBufferExtension.swift new file mode 100644 index 0000000000000000000000000000000000000000..4899c76562a546c513736fbf4556629b08d2c929 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CVPixelBufferExtension.swift @@ -0,0 +1,172 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +import Accelerate +import Foundation + +extension CVPixelBuffer { + var size: CGSize { + return CGSize(width: CVPixelBufferGetWidth(self), height: CVPixelBufferGetHeight(self)) + } + + /// Returns a new `CVPixelBuffer` created by taking the self area and resizing it to the + /// specified target size. Aspect ratios of source image and destination image are expected to be + /// same. + /// + /// - Parameters: + /// - from: Source area of image to be cropped and resized. + /// - to: Size to scale the image to(i.e. image size used while training the model). + /// - Returns: The cropped and resized image of itself. + func resize(from source: CGRect, to size: CGSize) -> CVPixelBuffer? { + let rect = CGRect(origin: CGPoint(x: 0, y: 0), size: self.size) + guard rect.contains(source) else { + os_log("Resizing Error: source area is out of index", type: .error) + return nil + } + guard rect.size.width / rect.size.height - source.size.width / source.size.height < 1e-5 + else { + os_log( + "Resizing Error: source image ratio and destination image ratio is different", + type: .error) + return nil + } + + let inputImageRowBytes = CVPixelBufferGetBytesPerRow(self) + let imageChannels = 4 + + CVPixelBufferLockBaseAddress(self, CVPixelBufferLockFlags(rawValue: 0)) + defer { CVPixelBufferUnlockBaseAddress(self, CVPixelBufferLockFlags(rawValue: 0)) } + + // Finds the address of the upper leftmost pixel of the source area. + guard + let inputBaseAddress = CVPixelBufferGetBaseAddress(self)?.advanced( + by: Int(source.minY) * inputImageRowBytes + Int(source.minX) * imageChannels) + else { + return nil + } + + // Crops given area as vImage Buffer. + var croppedImage = vImage_Buffer( + data: inputBaseAddress, height: UInt(source.height), width: UInt(source.width), + rowBytes: inputImageRowBytes) + + let resultRowBytes = Int(size.width) * imageChannels + guard let resultAddress = malloc(Int(size.height) * resultRowBytes) else { + return nil + } + + // Allocates a vacant vImage buffer for resized image. + var resizedImage = vImage_Buffer( + data: resultAddress, + height: UInt(size.height), width: UInt(size.width), + rowBytes: resultRowBytes + ) + + // Performs the scale operation on cropped image and stores it in result image buffer. + guard vImageScale_ARGB8888(&croppedImage, &resizedImage, nil, vImage_Flags(0)) == kvImageNoError + else { + return nil + } + + let releaseCallBack: CVPixelBufferReleaseBytesCallback = { mutablePointer, pointer in + if let pointer = pointer { + free(UnsafeMutableRawPointer(mutating: pointer)) + } + } + + var result: CVPixelBuffer? + + // Converts the thumbnail vImage buffer to CVPixelBuffer + let conversionStatus = CVPixelBufferCreateWithBytes( + nil, + Int(size.width), Int(size.height), + CVPixelBufferGetPixelFormatType(self), + resultAddress, + resultRowBytes, + releaseCallBack, + nil, + nil, + &result + ) + + guard conversionStatus == kCVReturnSuccess else { + free(resultAddress) + return nil + } + + return result + } + + /// Returns the RGB `Data` representation of the given image buffer. + /// + /// - Parameters: + /// - isModelQuantized: Whether the model is quantized (i.e. fixed point values rather than + /// floating point values). + /// - Returns: The RGB data representation of the image buffer or `nil` if the buffer could not be + /// converted. + func rgbData( + isModelQuantized: Bool + ) -> Data? { + CVPixelBufferLockBaseAddress(self, .readOnly) + defer { CVPixelBufferUnlockBaseAddress(self, .readOnly) } + guard let sourceData = CVPixelBufferGetBaseAddress(self) else { + return nil + } + + let width = CVPixelBufferGetWidth(self) + let height = CVPixelBufferGetHeight(self) + let sourceBytesPerRow = CVPixelBufferGetBytesPerRow(self) + let destinationBytesPerRow = Constants.rgbPixelChannels * width + + // Assign input image to `sourceBuffer` to convert it. + var sourceBuffer = vImage_Buffer( + data: sourceData, + height: vImagePixelCount(height), + width: vImagePixelCount(width), + rowBytes: sourceBytesPerRow) + + // Make `destinationBuffer` and `destinationData` for its data to be assigned. + guard let destinationData = malloc(height * destinationBytesPerRow) else { + os_log("Error: out of memory", type: .error) + return nil + } + defer { free(destinationData) } + var destinationBuffer = vImage_Buffer( + data: destinationData, + height: vImagePixelCount(height), + width: vImagePixelCount(width), + rowBytes: destinationBytesPerRow) + + // Convert image type. + switch CVPixelBufferGetPixelFormatType(self) { + case kCVPixelFormatType_32BGRA: + vImageConvert_BGRA8888toRGB888(&sourceBuffer, &destinationBuffer, UInt32(kvImageNoFlags)) + case kCVPixelFormatType_32ARGB: + vImageConvert_BGRA8888toRGB888(&sourceBuffer, &destinationBuffer, UInt32(kvImageNoFlags)) + default: + os_log("The type of this image is not supported.", type: .error) + return nil + } + + // Make `Data` with converted image. + let imageByteData = Data( + bytes: destinationBuffer.data, count: destinationBuffer.rowBytes * height) + + if isModelQuantized { return imageByteData } + + let imageBytes = [UInt8](imageByteData) + return Data(copyingBufferOf: imageBytes.map { Float($0) / Constants.maxRGBValue }) + } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/TFLiteExtension.swift b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/TFLiteExtension.swift new file mode 100644 index 0000000000000000000000000000000000000000..63f7ced786e2b550391c77af534d1d3c431522c6 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/TFLiteExtension.swift @@ -0,0 +1,75 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +import Accelerate +import CoreImage +import Foundation +import TensorFlowLite + +// MARK: - Data +extension Data { + /// Creates a new buffer by copying the buffer pointer of the given array. + /// + /// - Warning: The given array's element type `T` must be trivial in that it can be copied bit + /// for bit with no indirection or reference-counting operations; otherwise, reinterpreting + /// data from the resulting buffer has undefined behavior. + /// - Parameter array: An array with elements of type `T`. + init(copyingBufferOf array: [T]) { + self = array.withUnsafeBufferPointer(Data.init) + } + + /// Convert a Data instance to Array representation. + func toArray(type: T.Type) -> [T] where T: AdditiveArithmetic { + var array = [T](repeating: T.zero, count: self.count / MemoryLayout.stride) + _ = array.withUnsafeMutableBytes { self.copyBytes(to: $0) } + return array + } +} + +// MARK: - Wrappers +/// Struct for handling multidimension `Data` in flat `Array`. +struct FlatArray { + private var array: [Element] + var dimensions: [Int] + + init(tensor: Tensor) { + dimensions = tensor.shape.dimensions + array = tensor.data.toArray(type: Element.self) + } + + private func flatIndex(_ index: [Int]) -> Int { + guard index.count == dimensions.count else { + fatalError("Invalid index: got \(index.count) index(es) for \(dimensions.count) index(es).") + } + + var result = 0 + for i in 0.. index[i] else { + fatalError("Invalid index: \(index[i]) is bigger than \(dimensions[i])") + } + result = dimensions[i] * result + index[i] + } + return result + } + + subscript(_ index: Int...) -> Element { + get { + return array[flatIndex(index)] + } + set(newValue) { + array[flatIndex(index)] = newValue + } + } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Info.plist b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Info.plist new file mode 100644 index 0000000000000000000000000000000000000000..4330d9b33f31010549802febc6f6f2bc9fd9b950 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Info.plist @@ -0,0 +1,42 @@ + + + + + CFBundleDevelopmentRegion + $(DEVELOPMENT_LANGUAGE) + CFBundleExecutable + $(EXECUTABLE_NAME) + CFBundleIdentifier + $(PRODUCT_BUNDLE_IDENTIFIER) + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + $(PRODUCT_NAME) + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleVersion + 1 + LSRequiresIPhoneOS + + NSCameraUsageDescription + This app will use camera to continuously estimate the depth map. + UILaunchStoryboardName + LaunchScreen + UIMainStoryboardFile + Main + UIRequiredDeviceCapabilities + + armv7 + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ModelDataHandler/ModelDataHandler.swift b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ModelDataHandler/ModelDataHandler.swift new file mode 100644 index 0000000000000000000000000000000000000000..144cfe1fa3a65af5adcb572237f2bf9718e570ae --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ModelDataHandler/ModelDataHandler.swift @@ -0,0 +1,464 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Accelerate +import CoreImage +import Foundation +import TensorFlowLite +import UIKit + +/// This class handles all data preprocessing and makes calls to run inference on a given frame +/// by invoking the `Interpreter`. It then formats the inferences obtained. +class ModelDataHandler { + // MARK: - Private Properties + + /// TensorFlow Lite `Interpreter` object for performing inference on a given model. + private var interpreter: Interpreter + + /// TensorFlow lite `Tensor` of model input and output. + private var inputTensor: Tensor + + //private var heatsTensor: Tensor + //private var offsetsTensor: Tensor + private var outputTensor: Tensor + // MARK: - Initialization + + /// A failable initializer for `ModelDataHandler`. A new instance is created if the model is + /// successfully loaded from the app's main bundle. Default `threadCount` is 2. + init( + threadCount: Int = Constants.defaultThreadCount, + delegate: Delegates = Constants.defaultDelegate + ) throws { + // Construct the path to the model file. + guard + let modelPath = Bundle.main.path( + forResource: Model.file.name, + ofType: Model.file.extension + ) + else { + fatalError("Failed to load the model file with name: \(Model.file.name).") + } + + // Specify the options for the `Interpreter`. + var options = Interpreter.Options() + options.threadCount = threadCount + + // Specify the delegates for the `Interpreter`. + var delegates: [Delegate]? + switch delegate { + case .Metal: + delegates = [MetalDelegate()] + case .CoreML: + if let coreMLDelegate = CoreMLDelegate() { + delegates = [coreMLDelegate] + } else { + delegates = nil + } + default: + delegates = nil + } + + // Create the `Interpreter`. + interpreter = try Interpreter(modelPath: modelPath, options: options, delegates: delegates) + + // Initialize input and output `Tensor`s. + // Allocate memory for the model's input `Tensor`s. + try interpreter.allocateTensors() + + // Get allocated input and output `Tensor`s. + inputTensor = try interpreter.input(at: 0) + outputTensor = try interpreter.output(at: 0) + //heatsTensor = try interpreter.output(at: 0) + //offsetsTensor = try interpreter.output(at: 1) + + /* + // Check if input and output `Tensor`s are in the expected formats. + guard (inputTensor.dataType == .uInt8) == Model.isQuantized else { + fatalError("Unexpected Model: quantization is \(!Model.isQuantized)") + } + + guard inputTensor.shape.dimensions[0] == Model.input.batchSize, + inputTensor.shape.dimensions[1] == Model.input.height, + inputTensor.shape.dimensions[2] == Model.input.width, + inputTensor.shape.dimensions[3] == Model.input.channelSize + else { + fatalError("Unexpected Model: input shape") + } + + + guard heatsTensor.shape.dimensions[0] == Model.output.batchSize, + heatsTensor.shape.dimensions[1] == Model.output.height, + heatsTensor.shape.dimensions[2] == Model.output.width, + heatsTensor.shape.dimensions[3] == Model.output.keypointSize + else { + fatalError("Unexpected Model: heat tensor") + } + + guard offsetsTensor.shape.dimensions[0] == Model.output.batchSize, + offsetsTensor.shape.dimensions[1] == Model.output.height, + offsetsTensor.shape.dimensions[2] == Model.output.width, + offsetsTensor.shape.dimensions[3] == Model.output.offsetSize + else { + fatalError("Unexpected Model: offset tensor") + } + */ + + } + + /// Runs Midas model with given image with given source area to destination area. + /// + /// - Parameters: + /// - on: Input image to run the model. + /// - from: Range of input image to run the model. + /// - to: Size of view to render the result. + /// - Returns: Result of the inference and the times consumed in every steps. + func runMidas(on pixelbuffer: CVPixelBuffer, from source: CGRect, to dest: CGSize) + //-> (Result, Times)? + //-> (FlatArray, Times)? + -> ([Float], Int, Int, Times)? + { + // Start times of each process. + let preprocessingStartTime: Date + let inferenceStartTime: Date + let postprocessingStartTime: Date + + // Processing times in miliseconds. + let preprocessingTime: TimeInterval + let inferenceTime: TimeInterval + let postprocessingTime: TimeInterval + + preprocessingStartTime = Date() + guard let data = preprocess(of: pixelbuffer, from: source) else { + os_log("Preprocessing failed", type: .error) + return nil + } + preprocessingTime = Date().timeIntervalSince(preprocessingStartTime) * 1000 + + inferenceStartTime = Date() + inference(from: data) + inferenceTime = Date().timeIntervalSince(inferenceStartTime) * 1000 + + postprocessingStartTime = Date() + //guard let result = postprocess(to: dest) else { + // os_log("Postprocessing failed", type: .error) + // return nil + //} + postprocessingTime = Date().timeIntervalSince(postprocessingStartTime) * 1000 + + + let results: [Float] + switch outputTensor.dataType { + case .uInt8: + guard let quantization = outputTensor.quantizationParameters else { + print("No results returned because the quantization values for the output tensor are nil.") + return nil + } + let quantizedResults = [UInt8](outputTensor.data) + results = quantizedResults.map { + quantization.scale * Float(Int($0) - quantization.zeroPoint) + } + case .float32: + results = [Float32](unsafeData: outputTensor.data) ?? [] + default: + print("Output tensor data type \(outputTensor.dataType) is unsupported for this example app.") + return nil + } + + + let times = Times( + preprocessing: preprocessingTime, + inference: inferenceTime, + postprocessing: postprocessingTime) + + return (results, Model.input.width, Model.input.height, times) + } + + // MARK: - Private functions to run model + /// Preprocesses given rectangle image to be `Data` of disired size by croping and resizing it. + /// + /// - Parameters: + /// - of: Input image to crop and resize. + /// - from: Target area to be cropped and resized. + /// - Returns: The cropped and resized image. `nil` if it can not be processed. + private func preprocess(of pixelBuffer: CVPixelBuffer, from targetSquare: CGRect) -> Data? { + let sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer) + assert(sourcePixelFormat == kCVPixelFormatType_32BGRA) + + // Resize `targetSquare` of input image to `modelSize`. + let modelSize = CGSize(width: Model.input.width, height: Model.input.height) + guard let thumbnail = pixelBuffer.resize(from: targetSquare, to: modelSize) + else { + return nil + } + + // Remove the alpha component from the image buffer to get the initialized `Data`. + let byteCount = + Model.input.batchSize + * Model.input.height * Model.input.width + * Model.input.channelSize + guard + let inputData = thumbnail.rgbData( + isModelQuantized: Model.isQuantized + ) + else { + os_log("Failed to convert the image buffer to RGB data.", type: .error) + return nil + } + + return inputData + } + + + + /* + /// Postprocesses output `Tensor`s to `Result` with size of view to render the result. + /// + /// - Parameters: + /// - to: Size of view to be displaied. + /// - Returns: Postprocessed `Result`. `nil` if it can not be processed. + private func postprocess(to viewSize: CGSize) -> Result? { + // MARK: Formats output tensors + // Convert `Tensor` to `FlatArray`. As Midas is not quantized, convert them to Float type + // `FlatArray`. + let heats = FlatArray(tensor: heatsTensor) + let offsets = FlatArray(tensor: offsetsTensor) + + // MARK: Find position of each key point + // Finds the (row, col) locations of where the keypoints are most likely to be. The highest + // `heats[0, row, col, keypoint]` value, the more likely `keypoint` being located in (`row`, + // `col`). + let keypointPositions = (0.. (Int, Int) in + var maxValue = heats[0, 0, 0, keypoint] + var maxRow = 0 + var maxCol = 0 + for row in 0.. maxValue { + maxValue = heats[0, row, col, keypoint] + maxRow = row + maxCol = col + } + } + } + return (maxRow, maxCol) + } + + // MARK: Calculates total confidence score + // Calculates total confidence score of each key position. + let totalScoreSum = keypointPositions.enumerated().reduce(0.0) { accumulator, elem -> Float32 in + accumulator + sigmoid(heats[0, elem.element.0, elem.element.1, elem.offset]) + } + let totalScore = totalScoreSum / Float32(Model.output.keypointSize) + + // MARK: Calculate key point position on model input + // Calculates `KeyPoint` coordination model input image with `offsets` adjustment. + let coords = keypointPositions.enumerated().map { index, elem -> (y: Float32, x: Float32) in + let (y, x) = elem + let yCoord = + Float32(y) / Float32(Model.output.height - 1) * Float32(Model.input.height) + + offsets[0, y, x, index] + let xCoord = + Float32(x) / Float32(Model.output.width - 1) * Float32(Model.input.width) + + offsets[0, y, x, index + Model.output.keypointSize] + return (y: yCoord, x: xCoord) + } + + // MARK: Transform key point position and make lines + // Make `Result` from `keypointPosition'. Each point is adjusted to `ViewSize` to be drawn. + var result = Result(dots: [], lines: [], score: totalScore) + var bodyPartToDotMap = [BodyPart: CGPoint]() + for (index, part) in BodyPart.allCases.enumerated() { + let position = CGPoint( + x: CGFloat(coords[index].x) * viewSize.width / CGFloat(Model.input.width), + y: CGFloat(coords[index].y) * viewSize.height / CGFloat(Model.input.height) + ) + bodyPartToDotMap[part] = position + result.dots.append(position) + } + + do { + try result.lines = BodyPart.lines.map { map throws -> Line in + guard let from = bodyPartToDotMap[map.from] else { + throw PostprocessError.missingBodyPart(of: map.from) + } + guard let to = bodyPartToDotMap[map.to] else { + throw PostprocessError.missingBodyPart(of: map.to) + } + return Line(from: from, to: to) + } + } catch PostprocessError.missingBodyPart(let missingPart) { + os_log("Postprocessing error: %s is missing.", type: .error, missingPart.rawValue) + return nil + } catch { + os_log("Postprocessing error: %s", type: .error, error.localizedDescription) + return nil + } + + return result + } +*/ + + + + /// Run inference with given `Data` + /// + /// Parameter `from`: `Data` of input image to run model. + private func inference(from data: Data) { + // Copy the initialized `Data` to the input `Tensor`. + do { + try interpreter.copy(data, toInputAt: 0) + + // Run inference by invoking the `Interpreter`. + try interpreter.invoke() + + // Get the output `Tensor` to process the inference results. + outputTensor = try interpreter.output(at: 0) + //heatsTensor = try interpreter.output(at: 0) + //offsetsTensor = try interpreter.output(at: 1) + + + } catch let error { + os_log( + "Failed to invoke the interpreter with error: %s", type: .error, + error.localizedDescription) + return + } + } + + /// Returns value within [0,1]. + private func sigmoid(_ x: Float32) -> Float32 { + return (1.0 / (1.0 + exp(-x))) + } +} + +// MARK: - Data types for inference result +struct KeyPoint { + var bodyPart: BodyPart = BodyPart.NOSE + var position: CGPoint = CGPoint() + var score: Float = 0.0 +} + +struct Line { + let from: CGPoint + let to: CGPoint +} + +struct Times { + var preprocessing: Double + var inference: Double + var postprocessing: Double +} + +struct Result { + var dots: [CGPoint] + var lines: [Line] + var score: Float +} + +enum BodyPart: String, CaseIterable { + case NOSE = "nose" + case LEFT_EYE = "left eye" + case RIGHT_EYE = "right eye" + case LEFT_EAR = "left ear" + case RIGHT_EAR = "right ear" + case LEFT_SHOULDER = "left shoulder" + case RIGHT_SHOULDER = "right shoulder" + case LEFT_ELBOW = "left elbow" + case RIGHT_ELBOW = "right elbow" + case LEFT_WRIST = "left wrist" + case RIGHT_WRIST = "right wrist" + case LEFT_HIP = "left hip" + case RIGHT_HIP = "right hip" + case LEFT_KNEE = "left knee" + case RIGHT_KNEE = "right knee" + case LEFT_ANKLE = "left ankle" + case RIGHT_ANKLE = "right ankle" + + /// List of lines connecting each part. + static let lines = [ + (from: BodyPart.LEFT_WRIST, to: BodyPart.LEFT_ELBOW), + (from: BodyPart.LEFT_ELBOW, to: BodyPart.LEFT_SHOULDER), + (from: BodyPart.LEFT_SHOULDER, to: BodyPart.RIGHT_SHOULDER), + (from: BodyPart.RIGHT_SHOULDER, to: BodyPart.RIGHT_ELBOW), + (from: BodyPart.RIGHT_ELBOW, to: BodyPart.RIGHT_WRIST), + (from: BodyPart.LEFT_SHOULDER, to: BodyPart.LEFT_HIP), + (from: BodyPart.LEFT_HIP, to: BodyPart.RIGHT_HIP), + (from: BodyPart.RIGHT_HIP, to: BodyPart.RIGHT_SHOULDER), + (from: BodyPart.LEFT_HIP, to: BodyPart.LEFT_KNEE), + (from: BodyPart.LEFT_KNEE, to: BodyPart.LEFT_ANKLE), + (from: BodyPart.RIGHT_HIP, to: BodyPart.RIGHT_KNEE), + (from: BodyPart.RIGHT_KNEE, to: BodyPart.RIGHT_ANKLE), + ] +} + +// MARK: - Delegates Enum +enum Delegates: Int, CaseIterable { + case CPU + case Metal + case CoreML + + var description: String { + switch self { + case .CPU: + return "CPU" + case .Metal: + return "GPU" + case .CoreML: + return "NPU" + } + } +} + +// MARK: - Custom Errors +enum PostprocessError: Error { + case missingBodyPart(of: BodyPart) +} + +// MARK: - Information about the model file. +typealias FileInfo = (name: String, extension: String) + +enum Model { + static let file: FileInfo = ( + name: "model_opt", extension: "tflite" + ) + + static let input = (batchSize: 1, height: 256, width: 256, channelSize: 3) + static let output = (batchSize: 1, height: 256, width: 256, channelSize: 1) + static let isQuantized = false +} + + +extension Array { + /// Creates a new array from the bytes of the given unsafe data. + /// + /// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit + /// with no indirection or reference-counting operations; otherwise, copying the raw bytes in + /// the `unsafeData`'s buffer to a new array returns an unsafe copy. + /// - Note: Returns `nil` if `unsafeData.count` is not a multiple of + /// `MemoryLayout.stride`. + /// - Parameter unsafeData: The data containing the bytes to turn into an array. + init?(unsafeData: Data) { + guard unsafeData.count % MemoryLayout.stride == 0 else { return nil } + #if swift(>=5.0) + self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) } + #else + self = unsafeData.withUnsafeBytes { + .init(UnsafeBufferPointer( + start: $0, + count: unsafeData.count / MemoryLayout.stride + )) + } + #endif // swift(>=5.0) + } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Launch Screen.storyboard b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Launch Screen.storyboard new file mode 100644 index 0000000000000000000000000000000000000000..a04c79f554777863bd0dc8287bfd60704ce28bf2 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Launch Screen.storyboard @@ -0,0 +1,48 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Main.storyboard b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Main.storyboard new file mode 100644 index 0000000000000000000000000000000000000000..5f5623794bd35b9bb75efd7b7e249fd7357fdfbd --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Main.storyboard @@ -0,0 +1,236 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ViewControllers/ViewController.swift b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ViewControllers/ViewController.swift new file mode 100644 index 0000000000000000000000000000000000000000..fbb51b5a303412c0bbd158d76d025cf88fee6f8f --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ViewControllers/ViewController.swift @@ -0,0 +1,489 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import AVFoundation +import UIKit +import os + + +public struct PixelData { + var a: UInt8 + var r: UInt8 + var g: UInt8 + var b: UInt8 +} + +extension UIImage { + convenience init?(pixels: [PixelData], width: Int, height: Int) { + guard width > 0 && height > 0, pixels.count == width * height else { return nil } + var data = pixels + guard let providerRef = CGDataProvider(data: Data(bytes: &data, count: data.count * MemoryLayout.size) as CFData) + else { return nil } + guard let cgim = CGImage( + width: width, + height: height, + bitsPerComponent: 8, + bitsPerPixel: 32, + bytesPerRow: width * MemoryLayout.size, + space: CGColorSpaceCreateDeviceRGB(), + bitmapInfo: CGBitmapInfo(rawValue: CGImageAlphaInfo.premultipliedFirst.rawValue), + provider: providerRef, + decode: nil, + shouldInterpolate: false, + intent: .defaultIntent) + else { return nil } + self.init(cgImage: cgim) + } +} + + +class ViewController: UIViewController { + // MARK: Storyboards Connections + @IBOutlet weak var previewView: PreviewView! + + //@IBOutlet weak var overlayView: OverlayView! + @IBOutlet weak var overlayView: UIImageView! + + private var imageView : UIImageView = UIImageView(frame:CGRect(x:0, y:0, width:400, height:400)) + + private var imageViewInitialized: Bool = false + + @IBOutlet weak var resumeButton: UIButton! + @IBOutlet weak var cameraUnavailableLabel: UILabel! + + @IBOutlet weak var tableView: UITableView! + + @IBOutlet weak var threadCountLabel: UILabel! + @IBOutlet weak var threadCountStepper: UIStepper! + + @IBOutlet weak var delegatesControl: UISegmentedControl! + + // MARK: ModelDataHandler traits + var threadCount: Int = Constants.defaultThreadCount + var delegate: Delegates = Constants.defaultDelegate + + // MARK: Result Variables + // Inferenced data to render. + private var inferencedData: InferencedData? + + // Minimum score to render the result. + private let minimumScore: Float = 0.5 + + private var avg_latency: Double = 0.0 + + // Relative location of `overlayView` to `previewView`. + private var overlayViewFrame: CGRect? + + private var previewViewFrame: CGRect? + + // MARK: Controllers that manage functionality + // Handles all the camera related functionality + private lazy var cameraCapture = CameraFeedManager(previewView: previewView) + + // Handles all data preprocessing and makes calls to run inference. + private var modelDataHandler: ModelDataHandler? + + // MARK: View Handling Methods + override func viewDidLoad() { + super.viewDidLoad() + + do { + modelDataHandler = try ModelDataHandler() + } catch let error { + fatalError(error.localizedDescription) + } + + cameraCapture.delegate = self + tableView.delegate = self + tableView.dataSource = self + + // MARK: UI Initialization + // Setup thread count stepper with white color. + // https://forums.developer.apple.com/thread/121495 + threadCountStepper.setDecrementImage( + threadCountStepper.decrementImage(for: .normal), for: .normal) + threadCountStepper.setIncrementImage( + threadCountStepper.incrementImage(for: .normal), for: .normal) + // Setup initial stepper value and its label. + threadCountStepper.value = Double(Constants.defaultThreadCount) + threadCountLabel.text = Constants.defaultThreadCount.description + + // Setup segmented controller's color. + delegatesControl.setTitleTextAttributes( + [NSAttributedString.Key.foregroundColor: UIColor.lightGray], + for: .normal) + delegatesControl.setTitleTextAttributes( + [NSAttributedString.Key.foregroundColor: UIColor.black], + for: .selected) + // Remove existing segments to initialize it with `Delegates` entries. + delegatesControl.removeAllSegments() + Delegates.allCases.forEach { delegate in + delegatesControl.insertSegment( + withTitle: delegate.description, + at: delegate.rawValue, + animated: false) + } + delegatesControl.selectedSegmentIndex = 0 + } + + override func viewWillAppear(_ animated: Bool) { + super.viewWillAppear(animated) + + cameraCapture.checkCameraConfigurationAndStartSession() + } + + override func viewWillDisappear(_ animated: Bool) { + cameraCapture.stopSession() + } + + override func viewDidLayoutSubviews() { + overlayViewFrame = overlayView.frame + previewViewFrame = previewView.frame + } + + // MARK: Button Actions + @IBAction func didChangeThreadCount(_ sender: UIStepper) { + let changedCount = Int(sender.value) + if threadCountLabel.text == changedCount.description { + return + } + + do { + modelDataHandler = try ModelDataHandler(threadCount: changedCount, delegate: delegate) + } catch let error { + fatalError(error.localizedDescription) + } + threadCount = changedCount + threadCountLabel.text = changedCount.description + os_log("Thread count is changed to: %d", threadCount) + } + + @IBAction func didChangeDelegate(_ sender: UISegmentedControl) { + guard let changedDelegate = Delegates(rawValue: delegatesControl.selectedSegmentIndex) else { + fatalError("Unexpected value from delegates segemented controller.") + } + do { + modelDataHandler = try ModelDataHandler(threadCount: threadCount, delegate: changedDelegate) + } catch let error { + fatalError(error.localizedDescription) + } + delegate = changedDelegate + os_log("Delegate is changed to: %s", delegate.description) + } + + @IBAction func didTapResumeButton(_ sender: Any) { + cameraCapture.resumeInterruptedSession { complete in + + if complete { + self.resumeButton.isHidden = true + self.cameraUnavailableLabel.isHidden = true + } else { + self.presentUnableToResumeSessionAlert() + } + } + } + + func presentUnableToResumeSessionAlert() { + let alert = UIAlertController( + title: "Unable to Resume Session", + message: "There was an error while attempting to resume session.", + preferredStyle: .alert + ) + alert.addAction(UIAlertAction(title: "OK", style: .default, handler: nil)) + + self.present(alert, animated: true) + } +} + +// MARK: - CameraFeedManagerDelegate Methods +extension ViewController: CameraFeedManagerDelegate { + func cameraFeedManager(_ manager: CameraFeedManager, didOutput pixelBuffer: CVPixelBuffer) { + runModel(on: pixelBuffer) + } + + // MARK: Session Handling Alerts + func cameraFeedManagerDidEncounterSessionRunTimeError(_ manager: CameraFeedManager) { + // Handles session run time error by updating the UI and providing a button if session can be + // manually resumed. + self.resumeButton.isHidden = false + } + + func cameraFeedManager( + _ manager: CameraFeedManager, sessionWasInterrupted canResumeManually: Bool + ) { + // Updates the UI when session is interupted. + if canResumeManually { + self.resumeButton.isHidden = false + } else { + self.cameraUnavailableLabel.isHidden = false + } + } + + func cameraFeedManagerDidEndSessionInterruption(_ manager: CameraFeedManager) { + // Updates UI once session interruption has ended. + self.cameraUnavailableLabel.isHidden = true + self.resumeButton.isHidden = true + } + + func presentVideoConfigurationErrorAlert(_ manager: CameraFeedManager) { + let alertController = UIAlertController( + title: "Confirguration Failed", message: "Configuration of camera has failed.", + preferredStyle: .alert) + let okAction = UIAlertAction(title: "OK", style: .cancel, handler: nil) + alertController.addAction(okAction) + + present(alertController, animated: true, completion: nil) + } + + func presentCameraPermissionsDeniedAlert(_ manager: CameraFeedManager) { + let alertController = UIAlertController( + title: "Camera Permissions Denied", + message: + "Camera permissions have been denied for this app. You can change this by going to Settings", + preferredStyle: .alert) + + let cancelAction = UIAlertAction(title: "Cancel", style: .cancel, handler: nil) + let settingsAction = UIAlertAction(title: "Settings", style: .default) { action in + if let url = URL.init(string: UIApplication.openSettingsURLString) { + UIApplication.shared.open(url, options: [:], completionHandler: nil) + } + } + + alertController.addAction(cancelAction) + alertController.addAction(settingsAction) + + present(alertController, animated: true, completion: nil) + } + + @objc func runModel(on pixelBuffer: CVPixelBuffer) { + guard let overlayViewFrame = overlayViewFrame, let previewViewFrame = previewViewFrame + else { + return + } + // To put `overlayView` area as model input, transform `overlayViewFrame` following transform + // from `previewView` to `pixelBuffer`. `previewView` area is transformed to fit in + // `pixelBuffer`, because `pixelBuffer` as a camera output is resized to fill `previewView`. + // https://developer.apple.com/documentation/avfoundation/avlayervideogravity/1385607-resizeaspectfill + let modelInputRange = overlayViewFrame.applying( + previewViewFrame.size.transformKeepAspect(toFitIn: pixelBuffer.size)) + + // Run Midas model. + guard + let (result, width, height, times) = self.modelDataHandler?.runMidas( + on: pixelBuffer, + from: modelInputRange, + to: overlayViewFrame.size) + else { + os_log("Cannot get inference result.", type: .error) + return + } + + if avg_latency == 0 { + avg_latency = times.inference + } else { + avg_latency = times.inference*0.1 + avg_latency*0.9 + } + + // Udpate `inferencedData` to render data in `tableView`. + inferencedData = InferencedData(score: Float(avg_latency), times: times) + + //let height = 256 + //let width = 256 + + let outputs = result + let outputs_size = width * height; + + var multiplier : Float = 1.0; + + let max_val : Float = outputs.max() ?? 0 + let min_val : Float = outputs.min() ?? 0 + + if((max_val - min_val) > 0) { + multiplier = 255 / (max_val - min_val); + } + + // Draw result. + DispatchQueue.main.async { + self.tableView.reloadData() + + var pixels: [PixelData] = .init(repeating: .init(a: 255, r: 0, g: 0, b: 0), count: width * height) + + for i in pixels.indices { + //if(i < 1000) + //{ + let val = UInt8((outputs[i] - min_val) * multiplier) + + pixels[i].r = val + pixels[i].g = val + pixels[i].b = val + //} + } + + + /* + pixels[i].a = 255 + pixels[i].r = .random(in: 0...255) + pixels[i].g = .random(in: 0...255) + pixels[i].b = .random(in: 0...255) + } + */ + + DispatchQueue.main.async { + let image = UIImage(pixels: pixels, width: width, height: height) + + self.imageView.image = image + + if (self.imageViewInitialized == false) { + self.imageViewInitialized = true + self.overlayView.addSubview(self.imageView) + self.overlayView.setNeedsDisplay() + } + } + + /* + let image = UIImage(pixels: pixels, width: width, height: height) + + var imageView : UIImageView + imageView = UIImageView(frame:CGRect(x:0, y:0, width:400, height:400)); + imageView.image = image + self.overlayView.addSubview(imageView) + self.overlayView.setNeedsDisplay() + */ + } + } +/* + func drawResult(of result: Result) { + self.overlayView.dots = result.dots + self.overlayView.lines = result.lines + self.overlayView.setNeedsDisplay() + } + + func clearResult() { + self.overlayView.clear() + self.overlayView.setNeedsDisplay() + } + */ + +} + + +// MARK: - TableViewDelegate, TableViewDataSource Methods +extension ViewController: UITableViewDelegate, UITableViewDataSource { + func numberOfSections(in tableView: UITableView) -> Int { + return InferenceSections.allCases.count + } + + func tableView(_ tableView: UITableView, numberOfRowsInSection section: Int) -> Int { + guard let section = InferenceSections(rawValue: section) else { + return 0 + } + + return section.subcaseCount + } + + func tableView(_ tableView: UITableView, cellForRowAt indexPath: IndexPath) -> UITableViewCell { + let cell = tableView.dequeueReusableCell(withIdentifier: "InfoCell") as! InfoCell + guard let section = InferenceSections(rawValue: indexPath.section) else { + return cell + } + guard let data = inferencedData else { return cell } + + var fieldName: String + var info: String + + switch section { + case .Score: + fieldName = section.description + info = String(format: "%.3f", data.score) + case .Time: + guard let row = ProcessingTimes(rawValue: indexPath.row) else { + return cell + } + var time: Double + switch row { + case .InferenceTime: + time = data.times.inference + } + fieldName = row.description + info = String(format: "%.2fms", time) + } + + cell.fieldNameLabel.text = fieldName + cell.infoLabel.text = info + + return cell + } + + func tableView(_ tableView: UITableView, heightForRowAt indexPath: IndexPath) -> CGFloat { + guard let section = InferenceSections(rawValue: indexPath.section) else { + return 0 + } + + var height = Traits.normalCellHeight + if indexPath.row == section.subcaseCount - 1 { + height = Traits.separatorCellHeight + Traits.bottomSpacing + } + return height + } + +} + +// MARK: - Private enums +/// UI coinstraint values +fileprivate enum Traits { + static let normalCellHeight: CGFloat = 35.0 + static let separatorCellHeight: CGFloat = 25.0 + static let bottomSpacing: CGFloat = 30.0 +} + +fileprivate struct InferencedData { + var score: Float + var times: Times +} + +/// Type of sections in Info Cell +fileprivate enum InferenceSections: Int, CaseIterable { + case Score + case Time + + var description: String { + switch self { + case .Score: + return "Average" + case .Time: + return "Processing Time" + } + } + + var subcaseCount: Int { + switch self { + case .Score: + return 1 + case .Time: + return ProcessingTimes.allCases.count + } + } +} + +/// Type of processing times in Time section in Info Cell +fileprivate enum ProcessingTimes: Int, CaseIterable { + case InferenceTime + + var description: String { + switch self { + case .InferenceTime: + return "Inference Time" + } + } +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Views/OverlayView.swift b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Views/OverlayView.swift new file mode 100644 index 0000000000000000000000000000000000000000..3b53910b57563b6a195fd53321fa2a24ebaf3d3f --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Views/OverlayView.swift @@ -0,0 +1,63 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import UIKit + +/// UIView for rendering inference output. +class OverlayView: UIView { + + var dots = [CGPoint]() + var lines = [Line]() + + override func draw(_ rect: CGRect) { + for dot in dots { + drawDot(of: dot) + } + for line in lines { + drawLine(of: line) + } + } + + func drawDot(of dot: CGPoint) { + let dotRect = CGRect( + x: dot.x - Traits.dot.radius / 2, y: dot.y - Traits.dot.radius / 2, + width: Traits.dot.radius, height: Traits.dot.radius) + let dotPath = UIBezierPath(ovalIn: dotRect) + + Traits.dot.color.setFill() + dotPath.fill() + } + + func drawLine(of line: Line) { + let linePath = UIBezierPath() + linePath.move(to: CGPoint(x: line.from.x, y: line.from.y)) + linePath.addLine(to: CGPoint(x: line.to.x, y: line.to.y)) + linePath.close() + + linePath.lineWidth = Traits.line.width + Traits.line.color.setStroke() + + linePath.stroke() + } + + func clear() { + self.dots = [] + self.lines = [] + } +} + +private enum Traits { + static let dot = (radius: CGFloat(5), color: UIColor.orange) + static let line = (width: CGFloat(1.0), color: UIColor.orange) +} diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Podfile b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Podfile new file mode 100644 index 0000000000000000000000000000000000000000..5e9461fc96dbbe3c22ca6bbf2bfd7df3981b9462 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Podfile @@ -0,0 +1,12 @@ +# Uncomment the next line to define a global platform for your project + platform :ios, '12.0' + +target 'Midas' do + # Comment the next line if you're not using Swift and don't want to use dynamic frameworks + use_frameworks! + + # Pods for Midas + pod 'TensorFlowLiteSwift', '~> 0.0.1-nightly' + pod 'TensorFlowLiteSwift/CoreML', '~> 0.0.1-nightly' + pod 'TensorFlowLiteSwift/Metal', '~> 0.0.1-nightly' +end diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/README.md b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7b8eb29feaa21e67814b035dbd5c5fb2c62a4151 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/README.md @@ -0,0 +1,105 @@ +# Tensorflow Lite MiDaS iOS Example + +### Requirements + +- XCode 11.0 or above +- iOS 12.0 or above, [iOS 14 breaks the NPU Delegate](https://github.com/tensorflow/tensorflow/issues/43339) +- TensorFlow 2.4.0, TensorFlowLiteSwift -> 0.0.1-nightly + +## Quick Start with a MiDaS Example + +MiDaS is a neural network to compute depth from a single image. It uses TensorFlowLiteSwift / C++ libraries on iOS. The code is written in Swift. + +Paper: https://arxiv.org/abs/1907.01341 + +> Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer +> René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun + +### Install TensorFlow + +Set default python version to python3: + +``` +echo 'export PATH=/usr/local/opt/python/libexec/bin:$PATH' >> ~/.zshenv +echo 'alias python=python3' >> ~/.zshenv +echo 'alias pip=pip3' >> ~/.zshenv +``` + +Install TensorFlow + +```shell +pip install tensorflow +``` + +### Install TensorFlowLiteSwift via Cocoapods + +Set required TensorFlowLiteSwift version in the file (`0.0.1-nightly` is recommended): https://github.com/isl-org/MiDaS/blob/master/mobile/ios/Podfile#L9 + +Install: brew, ruby, cocoapods + +``` +ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)" +brew install mc rbenv ruby-build +sudo gem install cocoapods +``` + + +The TensorFlowLiteSwift library is available in [Cocoapods](https://cocoapods.org/), to integrate it to our project, we can run in the root directory of the project: + +```ruby +pod install +``` + +Now open the `Midas.xcworkspace` file in XCode, select your iPhone device (XCode->Product->Destination->iPhone) and launch it (cmd + R). If everything works well, you should see a real-time depth map from your camera. + +### Model + +The TensorFlow (TFlite) model `midas.tflite` is in the folder `/Midas/Model` + + +To use another model, you should convert it from TensorFlow saved-model to TFlite model (so that it can be deployed): + +```python +saved_model_export_dir = "./saved_model" +converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_export_dir) +tflite_model = converter.convert() +open(model_tflite_name, "wb").write("model.tflite") +``` + +### Setup XCode + +* Open directory `.xcworkspace` from the XCode + +* Press on your ProjectName (left-top corner) -> change Bundle Identifier to `com.midas.tflite-npu` or something like this (it should be unique) + +* select your Developer Team (your should be signed-in by using your AppleID) + +* Connect your iPhone (if you want to run it on real device instead of simulator), select your iPhone device (XCode->Product->Destination->iPhone) + +* Click in the XCode: Product -> Run + +* On your iPhone device go to the: Settings -> General -> Device Management (or Profiles) -> Apple Development -> Trust Apple Development + +---- + +Original repository: https://github.com/isl-org/MiDaS + + +### Examples: + +| ![photo_2020-09-27_17-43-20](https://user-images.githubusercontent.com/4096485/94367804-9610de80-00e9-11eb-8a23-8b32a6f52d41.jpg) | ![photo_2020-09-27_17-49-22](https://user-images.githubusercontent.com/4096485/94367974-7201cd00-00ea-11eb-8e0a-68eb9ea10f63.jpg) | ![photo_2020-09-27_17-52-30](https://user-images.githubusercontent.com/4096485/94367976-729a6380-00ea-11eb-8ce0-39d3e26dd550.jpg) | ![photo_2020-09-27_17-43-21](https://user-images.githubusercontent.com/4096485/94367807-97420b80-00e9-11eb-9dcd-848ad9e89e03.jpg) | +|---|---|---|---| + +## LICENSE + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/RunScripts/download_models.sh b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/RunScripts/download_models.sh new file mode 100644 index 0000000000000000000000000000000000000000..d737b39d966278f5c6bc29802526ab86f8473de4 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/RunScripts/download_models.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# Download TF Lite model from the internet if it does not exist. + +TFLITE_MODEL="model_opt.tflite" +TFLITE_FILE="Midas/Model/${TFLITE_MODEL}" +MODEL_SRC="https://github.com/isl-org/MiDaS/releases/download/v2/${TFLITE_MODEL}" + +if test -f "${TFLITE_FILE}"; then + echo "INFO: TF Lite model already exists. Skip downloading and use the local model." +else + curl --create-dirs -o "${TFLITE_FILE}" -LJO "${MODEL_SRC}" + echo "INFO: Downloaded TensorFlow Lite model to ${TFLITE_FILE}." +fi + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/output/.placeholder b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/output/.placeholder new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/LICENSE b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..6606ec028d1c629986e7019fe3564f5b4bfe425d --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Alexey + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/README.md b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1d43c2606767798ee46b34292e0483197424ec23 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/README.md @@ -0,0 +1,131 @@ +# MiDaS for ROS1 by using LibTorch in C++ + +### Requirements + +- Ubuntu 17.10 / 18.04 / 20.04, Debian Stretch +- ROS Melodic for Ubuntu (17.10 / 18.04) / Debian Stretch, ROS Noetic for Ubuntu 20.04 +- C++11 +- LibTorch >= 1.6 + +## Quick Start with a MiDaS Example + +MiDaS is a neural network to compute depth from a single image. + +* input from `image_topic`: `sensor_msgs/Image` - `RGB8` image with any shape +* output to `midas_topic`: `sensor_msgs/Image` - `TYPE_32FC1` inverse relative depth maps in range [0 - 255] with original size and channels=1 + +### Install Dependecies + +* install ROS Melodic for Ubuntu 17.10 / 18.04: +```bash +wget https://raw.githubusercontent.com/isl-org/MiDaS/master/ros/additions/install_ros_melodic_ubuntu_17_18.sh +./install_ros_melodic_ubuntu_17_18.sh +``` + +or Noetic for Ubuntu 20.04: + +```bash +wget https://raw.githubusercontent.com/isl-org/MiDaS/master/ros/additions/install_ros_noetic_ubuntu_20.sh +./install_ros_noetic_ubuntu_20.sh +``` + + +* install LibTorch 1.7 with CUDA 11.0: + +On **Jetson (ARM)**: +```bash +wget https://nvidia.box.com/shared/static/wa34qwrwtk9njtyarwt5nvo6imenfy26.whl -O torch-1.7.0-cp36-cp36m-linux_aarch64.whl +sudo apt-get install python3-pip libopenblas-base libopenmpi-dev +pip3 install Cython +pip3 install numpy torch-1.7.0-cp36-cp36m-linux_aarch64.whl +``` +Or compile LibTorch from source: https://github.com/pytorch/pytorch#from-source + +On **Linux (x86_64)**: +```bash +cd ~/ +wget https://download.pytorch.org/libtorch/cu110/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu110.zip +unzip libtorch-cxx11-abi-shared-with-deps-1.7.0+cu110.zip +``` + +* create symlink for OpenCV: + +```bash +sudo ln -s /usr/include/opencv4 /usr/include/opencv +``` + +* download and install MiDaS: + +```bash +source ~/.bashrc +cd ~/ +mkdir catkin_ws +cd catkin_ws +git clone https://github.com/isl-org/MiDaS +mkdir src +cp -r MiDaS/ros/* src + +chmod +x src/additions/*.sh +chmod +x src/*.sh +chmod +x src/midas_cpp/scripts/*.py +cp src/additions/do_catkin_make.sh ./do_catkin_make.sh +./do_catkin_make.sh +./src/additions/downloads.sh +``` + +### Usage + +* run only `midas` node: `~/catkin_ws/src/launch_midas_cpp.sh` + +#### Test + +* Test - capture video and show result in the window: + * place any `test.mp4` video file to the directory `~/catkin_ws/src/` + * run `midas` node: `~/catkin_ws/src/launch_midas_cpp.sh` + * run test nodes in another terminal: `cd ~/catkin_ws/src && ./run_talker_listener_test.sh` and wait 30 seconds + + (to use Python 2, run command `sed -i 's/python3/python2/' ~/catkin_ws/src/midas_cpp/scripts/*.py` ) + +## Mobile version of MiDaS - Monocular Depth Estimation + +### Accuracy + +* MiDaS v2 small - ResNet50 default-decoder 384x384 +* MiDaS v2.1 small - EfficientNet-Lite3 small-decoder 256x256 + +**Zero-shot error** (the lower - the better): + +| Model | DIW WHDR | Eth3d AbsRel | Sintel AbsRel | Kitti δ>1.25 | NyuDepthV2 δ>1.25 | TUM δ>1.25 | +|---|---|---|---|---|---|---| +| MiDaS v2 small 384x384 | **0.1248** | 0.1550 | **0.3300** | **21.81** | 15.73 | 17.00 | +| MiDaS v2.1 small 256x256 | 0.1344 | **0.1344** | 0.3370 | 29.27 | **13.43** | **14.53** | +| Relative improvement, % | -8 % | **+13 %** | -2 % | -34 % | **+15 %** | **+15 %** | + +None of Train/Valid/Test subsets of datasets (DIW, Eth3d, Sintel, Kitti, NyuDepthV2, TUM) were not involved in Training or Fine Tuning. + +### Inference speed (FPS) on nVidia GPU + +Inference speed excluding pre and post processing, batch=1, **Frames Per Second** (the higher - the better): + +| Model | Jetson Nano, FPS | RTX 2080Ti, FPS | +|---|---|---| +| MiDaS v2 small 384x384 | 1.6 | 117 | +| MiDaS v2.1 small 256x256 | 8.1 | 232 | +| SpeedUp, X times | **5x** | **2x** | + +### Citation + +This repository contains code to compute depth from a single image. It accompanies our [paper](https://arxiv.org/abs/1907.01341v3): + +>Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer +René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun + +Please cite our paper if you use this code or any of the models: +``` +@article{Ranftl2020, + author = {Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun}, + title = {Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer}, + journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)}, + year = {2020}, +} +``` diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/do_catkin_make.sh b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/do_catkin_make.sh new file mode 100644 index 0000000000000000000000000000000000000000..0d416fc00282aab146326bbba12a9274e1ba29b8 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/do_catkin_make.sh @@ -0,0 +1,5 @@ +mkdir src +catkin_make +source devel/setup.bash +echo $ROS_PACKAGE_PATH +chmod +x ./devel/setup.bash diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/downloads.sh b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/downloads.sh new file mode 100644 index 0000000000000000000000000000000000000000..9c967d4e2dc7997da26399a063b5a54ecc314eb1 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/downloads.sh @@ -0,0 +1,5 @@ +mkdir ~/.ros +wget https://github.com/isl-org/MiDaS/releases/download/v2_1/model-small-traced.pt +cp ./model-small-traced.pt ~/.ros/model-small-traced.pt + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_melodic_ubuntu_17_18.sh b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_melodic_ubuntu_17_18.sh new file mode 100644 index 0000000000000000000000000000000000000000..b868112631e9d9bc7bccb601407dfc857b8a99d5 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_melodic_ubuntu_17_18.sh @@ -0,0 +1,34 @@ +#@title { display-mode: "code" } + +#from http://wiki.ros.org/indigo/Installation/Ubuntu + +#1.2 Setup sources.list +sudo sh -c 'echo "deb http://packages.ros.org/ros/ubuntu $(lsb_release -sc) main" > /etc/apt/sources.list.d/ros-latest.list' + +# 1.3 Setup keys +sudo apt-key adv --keyserver 'hkp://keyserver.ubuntu.com:80' --recv-key C1CF6E31E6BADE8868B172B4F42ED6FBAB17C654 +sudo apt-key adv --keyserver 'hkp://ha.pool.sks-keyservers.net:80' --recv-key 421C365BD9FF1F717815A3895523BAEEB01FA116 + +curl -sSL 'http://keyserver.ubuntu.com/pks/lookup?op=get&search=0xC1CF6E31E6BADE8868B172B4F42ED6FBAB17C654' | sudo apt-key add - + +# 1.4 Installation +sudo apt-get update +sudo apt-get upgrade + +# Desktop-Full Install: +sudo apt-get install ros-melodic-desktop-full + +printf "\nsource /opt/ros/melodic/setup.bash\n" >> ~/.bashrc + +# 1.5 Initialize rosdep +sudo rosdep init +rosdep update + + +# 1.7 Getting rosinstall (python) +sudo apt-get install python-rosinstall +sudo apt-get install python-catkin-tools +sudo apt-get install python-rospy +sudo apt-get install python-rosdep +sudo apt-get install python-roscd +sudo apt-get install python-pip \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_noetic_ubuntu_20.sh b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_noetic_ubuntu_20.sh new file mode 100644 index 0000000000000000000000000000000000000000..d73ea1a3d92359819167d735a92d2a650b9bc245 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_noetic_ubuntu_20.sh @@ -0,0 +1,33 @@ +#@title { display-mode: "code" } + +#from http://wiki.ros.org/indigo/Installation/Ubuntu + +#1.2 Setup sources.list +sudo sh -c 'echo "deb http://packages.ros.org/ros/ubuntu $(lsb_release -sc) main" > /etc/apt/sources.list.d/ros-latest.list' + +# 1.3 Setup keys +sudo apt-key adv --keyserver 'hkp://keyserver.ubuntu.com:80' --recv-key C1CF6E31E6BADE8868B172B4F42ED6FBAB17C654 + +curl -sSL 'http://keyserver.ubuntu.com/pks/lookup?op=get&search=0xC1CF6E31E6BADE8868B172B4F42ED6FBAB17C654' | sudo apt-key add - + +# 1.4 Installation +sudo apt-get update +sudo apt-get upgrade + +# Desktop-Full Install: +sudo apt-get install ros-noetic-desktop-full + +printf "\nsource /opt/ros/noetic/setup.bash\n" >> ~/.bashrc + +# 1.5 Initialize rosdep +sudo rosdep init +rosdep update + + +# 1.7 Getting rosinstall (python) +sudo apt-get install python3-rosinstall +sudo apt-get install python3-catkin-tools +sudo apt-get install python3-rospy +sudo apt-get install python3-rosdep +sudo apt-get install python3-roscd +sudo apt-get install python3-pip \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/make_package_cpp.sh b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/make_package_cpp.sh new file mode 100644 index 0000000000000000000000000000000000000000..d0ef6073a9c9ce40744e1c81d557c1c68255b95e --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/make_package_cpp.sh @@ -0,0 +1,16 @@ +cd ~/catkin_ws/src +catkin_create_pkg midas_cpp std_msgs roscpp cv_bridge sensor_msgs image_transport +cd ~/catkin_ws +catkin_make + +chmod +x ~/catkin_ws/devel/setup.bash +printf "\nsource ~/catkin_ws/devel/setup.bash" >> ~/.bashrc +source ~/catkin_ws/devel/setup.bash + + +sudo rosdep init +rosdep update +#rospack depends1 midas_cpp +roscd midas_cpp +#cat package.xml +#rospack depends midas_cpp \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/launch_midas_cpp.sh b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/launch_midas_cpp.sh new file mode 100644 index 0000000000000000000000000000000000000000..5a0d1583fffdc49216c625dfd07af2ae3b01a7a0 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/launch_midas_cpp.sh @@ -0,0 +1,2 @@ +source ~/catkin_ws/devel/setup.bash +roslaunch midas_cpp midas_cpp.launch model_name:="model-small-traced.pt" input_topic:="image_topic" output_topic:="midas_topic" out_orig_size:="true" \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/CMakeLists.txt b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..885341691d217f9c4c8fcb1e4ff568d87788c7b8 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/CMakeLists.txt @@ -0,0 +1,189 @@ +cmake_minimum_required(VERSION 3.0.2) +project(midas_cpp) + +## Compile as C++11, supported in ROS Kinetic and newer +# add_compile_options(-std=c++11) + +## Find catkin macros and libraries +## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz) +## is used, also find other catkin packages +find_package(catkin REQUIRED COMPONENTS + cv_bridge + image_transport + roscpp + rospy + sensor_msgs + std_msgs +) + +## System dependencies are found with CMake's conventions +# find_package(Boost REQUIRED COMPONENTS system) + +list(APPEND CMAKE_PREFIX_PATH "~/libtorch") +list(APPEND CMAKE_PREFIX_PATH "/usr/local/lib/python3.6/dist-packages/torch/lib") +list(APPEND CMAKE_PREFIX_PATH "/usr/local/lib/python2.7/dist-packages/torch/lib") + +if(NOT EXISTS "~/libtorch") + if (EXISTS "/usr/local/lib/python3.6/dist-packages/torch") + include_directories(/usr/local/include) + include_directories(/usr/local/lib/python3.6/dist-packages/torch/include/torch/csrc/api/include) + include_directories(/usr/local/lib/python3.6/dist-packages/torch/include) + + link_directories(/usr/local/lib) + link_directories(/usr/local/lib/python3.6/dist-packages/torch/lib) + + set(CMAKE_PREFIX_PATH /usr/local/lib/python3.6/dist-packages/torch) + set(Boost_USE_MULTITHREADED ON) + set(Torch_DIR /usr/local/lib/python3.6/dist-packages/torch) + + elseif (EXISTS "/usr/local/lib/python2.7/dist-packages/torch") + + include_directories(/usr/local/include) + include_directories(/usr/local/lib/python2.7/dist-packages/torch/include/torch/csrc/api/include) + include_directories(/usr/local/lib/python2.7/dist-packages/torch/include) + + link_directories(/usr/local/lib) + link_directories(/usr/local/lib/python2.7/dist-packages/torch/lib) + + set(CMAKE_PREFIX_PATH /usr/local/lib/python2.7/dist-packages/torch) + set(Boost_USE_MULTITHREADED ON) + set(Torch_DIR /usr/local/lib/python2.7/dist-packages/torch) + endif() +endif() + + + +find_package(Torch REQUIRED) +find_package(OpenCV REQUIRED) +include_directories( ${OpenCV_INCLUDE_DIRS} ) + +add_executable(midas_cpp src/main.cpp) +target_link_libraries(midas_cpp "${TORCH_LIBRARIES}" "${OpenCV_LIBS} ${catkin_LIBRARIES}") +set_property(TARGET midas_cpp PROPERTY CXX_STANDARD 14) + + + +################################### +## catkin specific configuration ## +################################### +## The catkin_package macro generates cmake config files for your package +## Declare things to be passed to dependent projects +## INCLUDE_DIRS: uncomment this if your package contains header files +## LIBRARIES: libraries you create in this project that dependent projects also need +## CATKIN_DEPENDS: catkin_packages dependent projects also need +## DEPENDS: system dependencies of this project that dependent projects also need +catkin_package( +# INCLUDE_DIRS include +# LIBRARIES midas_cpp +# CATKIN_DEPENDS cv_bridge image_transport roscpp sensor_msgs std_msgs +# DEPENDS system_lib +) + +########### +## Build ## +########### + +## Specify additional locations of header files +## Your package locations should be listed before other locations +include_directories( +# include + ${catkin_INCLUDE_DIRS} +) + +## Declare a C++ library +# add_library(${PROJECT_NAME} +# src/${PROJECT_NAME}/midas_cpp.cpp +# ) + +## Add cmake target dependencies of the library +## as an example, code may need to be generated before libraries +## either from message generation or dynamic reconfigure +# add_dependencies(${PROJECT_NAME} ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS}) + +## Declare a C++ executable +## With catkin_make all packages are built within a single CMake context +## The recommended prefix ensures that target names across packages don't collide +# add_executable(${PROJECT_NAME}_node src/midas_cpp_node.cpp) + +## Rename C++ executable without prefix +## The above recommended prefix causes long target names, the following renames the +## target back to the shorter version for ease of user use +## e.g. "rosrun someones_pkg node" instead of "rosrun someones_pkg someones_pkg_node" +# set_target_properties(${PROJECT_NAME}_node PROPERTIES OUTPUT_NAME node PREFIX "") + +## Add cmake target dependencies of the executable +## same as for the library above +# add_dependencies(${PROJECT_NAME}_node ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS}) + +## Specify libraries to link a library or executable target against +# target_link_libraries(${PROJECT_NAME}_node +# ${catkin_LIBRARIES} +# ) + +############# +## Install ## +############# + +# all install targets should use catkin DESTINATION variables +# See http://ros.org/doc/api/catkin/html/adv_user_guide/variables.html + +## Mark executable scripts (Python etc.) for installation +## in contrast to setup.py, you can choose the destination +# catkin_install_python(PROGRAMS +# scripts/my_python_script +# DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +# ) + +## Mark executables for installation +## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_executables.html +# install(TARGETS ${PROJECT_NAME}_node +# RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +# ) + +## Mark libraries for installation +## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_libraries.html +# install(TARGETS ${PROJECT_NAME} +# ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} +# LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} +# RUNTIME DESTINATION ${CATKIN_GLOBAL_BIN_DESTINATION} +# ) + +## Mark cpp header files for installation +# install(DIRECTORY include/${PROJECT_NAME}/ +# DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION} +# FILES_MATCHING PATTERN "*.h" +# PATTERN ".svn" EXCLUDE +# ) + +## Mark other files for installation (e.g. launch and bag files, etc.) +# install(FILES +# # myfile1 +# # myfile2 +# DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} +# ) + +############# +## Testing ## +############# + +## Add gtest based cpp test target and link libraries +# catkin_add_gtest(${PROJECT_NAME}-test test/test_midas_cpp.cpp) +# if(TARGET ${PROJECT_NAME}-test) +# target_link_libraries(${PROJECT_NAME}-test ${PROJECT_NAME}) +# endif() + +## Add folders to be run by python nosetests +# catkin_add_nosetests(test) + +install(TARGETS ${PROJECT_NAME} + ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} + LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} + RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +) + +add_custom_command( + TARGET midas_cpp POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + ${CMAKE_CURRENT_BINARY_DIR}/midas_cpp + ${CMAKE_SOURCE_DIR}/midas_cpp +) \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_cpp.launch b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_cpp.launch new file mode 100644 index 0000000000000000000000000000000000000000..88e86f42f668e76ad4976ec6794a8cb0f20cac65 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_cpp.launch @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_talker_listener.launch b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_talker_listener.launch new file mode 100644 index 0000000000000000000000000000000000000000..8817a4f4933c56986fe0edc0886b2fded3d3406d --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_talker_listener.launch @@ -0,0 +1,23 @@ + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/package.xml b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/package.xml new file mode 100644 index 0000000000000000000000000000000000000000..9cac90eba75409bd170f73531c54c83c52ff047a --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/package.xml @@ -0,0 +1,77 @@ + + + midas_cpp + 0.1.0 + The midas_cpp package + + Alexey Bochkovskiy + MIT + https://github.com/isl-org/MiDaS/tree/master/ros + + + + + + + TODO + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + catkin + cv_bridge + image_transport + roscpp + rospy + sensor_msgs + std_msgs + cv_bridge + image_transport + roscpp + rospy + sensor_msgs + std_msgs + cv_bridge + image_transport + roscpp + rospy + sensor_msgs + std_msgs + + + + + + + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener.py new file mode 100644 index 0000000000000000000000000000000000000000..6927ea7a83ac9309e5f883ee974a5dcfa8a2aa3b --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +from __future__ import print_function + +import roslib +#roslib.load_manifest('my_package') +import sys +import rospy +import cv2 +import numpy as np +from std_msgs.msg import String +from sensor_msgs.msg import Image +from cv_bridge import CvBridge, CvBridgeError + +class video_show: + + def __init__(self): + self.show_output = rospy.get_param('~show_output', True) + self.save_output = rospy.get_param('~save_output', False) + self.output_video_file = rospy.get_param('~output_video_file','result.mp4') + # rospy.loginfo(f"Listener - params: show_output={self.show_output}, save_output={self.save_output}, output_video_file={self.output_video_file}") + + self.bridge = CvBridge() + self.image_sub = rospy.Subscriber("midas_topic", Image, self.callback) + + def callback(self, data): + try: + cv_image = self.bridge.imgmsg_to_cv2(data) + except CvBridgeError as e: + print(e) + return + + if cv_image.size == 0: + return + + rospy.loginfo("Listener: Received new frame") + cv_image = cv_image.astype("uint8") + + if self.show_output==True: + cv2.imshow("video_show", cv_image) + cv2.waitKey(10) + + if self.save_output==True: + if self.video_writer_init==False: + fourcc = cv2.VideoWriter_fourcc(*'XVID') + self.out = cv2.VideoWriter(self.output_video_file, fourcc, 25, (cv_image.shape[1], cv_image.shape[0])) + + self.out.write(cv_image) + + + +def main(args): + rospy.init_node('listener', anonymous=True) + ic = video_show() + try: + rospy.spin() + except KeyboardInterrupt: + print("Shutting down") + cv2.destroyAllWindows() + +if __name__ == '__main__': + main(sys.argv) \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener_original.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener_original.py new file mode 100644 index 0000000000000000000000000000000000000000..20e235f6958d644b89383752ab18e9e2275f55e5 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener_original.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +from __future__ import print_function + +import roslib +#roslib.load_manifest('my_package') +import sys +import rospy +import cv2 +import numpy as np +from std_msgs.msg import String +from sensor_msgs.msg import Image +from cv_bridge import CvBridge, CvBridgeError + +class video_show: + + def __init__(self): + self.show_output = rospy.get_param('~show_output', True) + self.save_output = rospy.get_param('~save_output', False) + self.output_video_file = rospy.get_param('~output_video_file','result.mp4') + # rospy.loginfo(f"Listener original - params: show_output={self.show_output}, save_output={self.save_output}, output_video_file={self.output_video_file}") + + self.bridge = CvBridge() + self.image_sub = rospy.Subscriber("image_topic", Image, self.callback) + + def callback(self, data): + try: + cv_image = self.bridge.imgmsg_to_cv2(data) + except CvBridgeError as e: + print(e) + return + + if cv_image.size == 0: + return + + rospy.loginfo("Listener_original: Received new frame") + cv_image = cv_image.astype("uint8") + + if self.show_output==True: + cv2.imshow("video_show_orig", cv_image) + cv2.waitKey(10) + + if self.save_output==True: + if self.video_writer_init==False: + fourcc = cv2.VideoWriter_fourcc(*'XVID') + self.out = cv2.VideoWriter(self.output_video_file, fourcc, 25, (cv_image.shape[1], cv_image.shape[0])) + + self.out.write(cv_image) + + + +def main(args): + rospy.init_node('listener_original', anonymous=True) + ic = video_show() + try: + rospy.spin() + except KeyboardInterrupt: + print("Shutting down") + cv2.destroyAllWindows() + +if __name__ == '__main__': + main(sys.argv) \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/talker.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/talker.py new file mode 100644 index 0000000000000000000000000000000000000000..8219cc8632484a2efd02984347c615efad6b78b2 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/talker.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 + + +import roslib +#roslib.load_manifest('my_package') +import sys +import rospy +import cv2 +from std_msgs.msg import String +from sensor_msgs.msg import Image +from cv_bridge import CvBridge, CvBridgeError + + +def talker(): + rospy.init_node('talker', anonymous=True) + + use_camera = rospy.get_param('~use_camera', False) + input_video_file = rospy.get_param('~input_video_file','test.mp4') + # rospy.loginfo(f"Talker - params: use_camera={use_camera}, input_video_file={input_video_file}") + + # rospy.loginfo("Talker: Trying to open a video stream") + if use_camera == True: + cap = cv2.VideoCapture(0) + else: + cap = cv2.VideoCapture(input_video_file) + + pub = rospy.Publisher('image_topic', Image, queue_size=1) + rate = rospy.Rate(30) # 30hz + bridge = CvBridge() + + while not rospy.is_shutdown(): + ret, cv_image = cap.read() + if ret==False: + print("Talker: Video is over") + rospy.loginfo("Video is over") + return + + try: + image = bridge.cv2_to_imgmsg(cv_image, "bgr8") + except CvBridgeError as e: + rospy.logerr("Talker: cv2image conversion failed: ", e) + print(e) + continue + + rospy.loginfo("Talker: Publishing frame") + pub.publish(image) + rate.sleep() + +if __name__ == '__main__': + try: + talker() + except rospy.ROSInterruptException: + pass diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/src/main.cpp b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/src/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e4fc72c6955f66af71c9cb1fc7a7b1f643129685 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/src/main.cpp @@ -0,0 +1,285 @@ +#include +#include +#include +#include + +#include + +#include // One-stop header. + +#include +#include +#include +#include + +#include +#include + +// includes for OpenCV >= 3.x +#ifndef CV_VERSION_EPOCH +#include +#include +#include +#endif + +// OpenCV includes for OpenCV 2.x +#ifdef CV_VERSION_EPOCH +#include +#include +#include +#include +#endif + +static const std::string OPENCV_WINDOW = "Image window"; + +class Midas +{ + ros::NodeHandle nh_; + image_transport::ImageTransport it_; + image_transport::Subscriber image_sub_; + image_transport::Publisher image_pub_; + + torch::jit::script::Module module; + torch::Device device; + + auto ToTensor(cv::Mat img, bool show_output = false, bool unsqueeze = false, int unsqueeze_dim = 0) + { + //std::cout << "image shape: " << img.size() << std::endl; + at::Tensor tensor_image = torch::from_blob(img.data, { img.rows, img.cols, 3 }, at::kByte); + + if (unsqueeze) + { + tensor_image.unsqueeze_(unsqueeze_dim); + //std::cout << "tensors new shape: " << tensor_image.sizes() << std::endl; + } + + if (show_output) + { + std::cout << tensor_image.slice(2, 0, 1) << std::endl; + } + //std::cout << "tenor shape: " << tensor_image.sizes() << std::endl; + return tensor_image; + } + + auto ToInput(at::Tensor tensor_image) + { + // Create a vector of inputs. + return std::vector{tensor_image}; + } + + auto ToCvImage(at::Tensor tensor, int cv_type = CV_8UC3) + { + int width = tensor.sizes()[0]; + int height = tensor.sizes()[1]; + try + { + cv::Mat output_mat; + if (cv_type == CV_8UC4 || cv_type == CV_8UC3 || cv_type == CV_8UC2 || cv_type == CV_8UC1) { + cv::Mat cv_image(cv::Size{ height, width }, cv_type, tensor.data_ptr()); + output_mat = cv_image; + } + else if (cv_type == CV_32FC4 || cv_type == CV_32FC3 || cv_type == CV_32FC2 || cv_type == CV_32FC1) { + cv::Mat cv_image(cv::Size{ height, width }, cv_type, tensor.data_ptr()); + output_mat = cv_image; + } + else if (cv_type == CV_64FC4 || cv_type == CV_64FC3 || cv_type == CV_64FC2 || cv_type == CV_64FC1) { + cv::Mat cv_image(cv::Size{ height, width }, cv_type, tensor.data_ptr()); + output_mat = cv_image; + } + + //show_image(output_mat, "converted image from tensor"); + return output_mat.clone(); + } + catch (const c10::Error& e) + { + std::cout << "an error has occured : " << e.msg() << std::endl; + } + return cv::Mat(height, width, CV_8UC3); + } + + std::string input_topic, output_topic, model_name; + bool out_orig_size; + int net_width, net_height; + torch::NoGradGuard guard; + at::Tensor mean, std; + at::Tensor output, tensor; + +public: + Midas() + : nh_(), it_(nh_), device(torch::Device(torch::kCPU)) + { + ros::param::param("~input_topic", input_topic, "image_topic"); + ros::param::param("~output_topic", output_topic, "midas_topic"); + ros::param::param("~model_name", model_name, "model-small-traced.pt"); + ros::param::param("~out_orig_size", out_orig_size, true); + ros::param::param("~net_width", net_width, 256); + ros::param::param("~net_height", net_height, 256); + + std::cout << ", input_topic = " << input_topic << + ", output_topic = " << output_topic << + ", model_name = " << model_name << + ", out_orig_size = " << out_orig_size << + ", net_width = " << net_width << + ", net_height = " << net_height << + std::endl; + + // Subscrive to input video feed and publish output video feed + image_sub_ = it_.subscribe(input_topic, 1, &Midas::imageCb, this); + image_pub_ = it_.advertise(output_topic, 1); + + std::cout << "Try to load torchscript model \n"; + + try { + // Deserialize the ScriptModule from a file using torch::jit::load(). + module = torch::jit::load(model_name); + } + catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + exit(0); + } + + std::cout << "ok\n"; + + try { + module.eval(); + torch::jit::getProfilingMode() = false; + torch::jit::setGraphExecutorOptimize(true); + + mean = torch::tensor({ 0.485, 0.456, 0.406 }); + std = torch::tensor({ 0.229, 0.224, 0.225 }); + + if (torch::hasCUDA()) { + std::cout << "cuda is available" << std::endl; + at::globalContext().setBenchmarkCuDNN(true); + device = torch::Device(torch::kCUDA); + module.to(device); + mean = mean.to(device); + std = std.to(device); + } + } + catch (const c10::Error& e) + { + std::cerr << " module initialization: " << e.msg() << std::endl; + } + } + + ~Midas() + { + } + + void imageCb(const sensor_msgs::ImageConstPtr& msg) + { + cv_bridge::CvImagePtr cv_ptr; + try + { + // sensor_msgs::Image to cv::Mat + cv_ptr = cv_bridge::toCvCopy(msg, sensor_msgs::image_encodings::RGB8); + } + catch (cv_bridge::Exception& e) + { + ROS_ERROR("cv_bridge exception: %s", e.what()); + return; + } + + // pre-processing + auto tensor_cpu = ToTensor(cv_ptr->image); // OpenCV-image -> Libtorch-tensor + + try { + tensor = tensor_cpu.to(device); // move to device (CPU or GPU) + + tensor = tensor.toType(c10::kFloat); + tensor = tensor.permute({ 2, 0, 1 }); // HWC -> CHW + tensor = tensor.unsqueeze(0); + tensor = at::upsample_bilinear2d(tensor, { net_height, net_width }, true); // resize + tensor = tensor.squeeze(0); + tensor = tensor.permute({ 1, 2, 0 }); // CHW -> HWC + + tensor = tensor.div(255).sub(mean).div(std); // normalization + tensor = tensor.permute({ 2, 0, 1 }); // HWC -> CHW + tensor.unsqueeze_(0); // CHW -> NCHW + } + catch (const c10::Error& e) + { + std::cerr << " pre-processing exception: " << e.msg() << std::endl; + return; + } + + auto input_to_net = ToInput(tensor); // input to the network + + // inference + output; + try { + output = module.forward(input_to_net).toTensor(); // run inference + } + catch (const c10::Error& e) + { + std::cerr << " module.forward() exception: " << e.msg() << std::endl; + return; + } + + output = output.detach().to(torch::kF32); + + // move to CPU temporary + at::Tensor output_tmp = output; + output_tmp = output_tmp.to(torch::kCPU); + + // normalization + float min_val = std::numeric_limits::max(); + float max_val = std::numeric_limits::min(); + + for (int i = 0; i < net_width * net_height; ++i) { + float val = output_tmp.data_ptr()[i]; + if (min_val > val) min_val = val; + if (max_val < val) max_val = val; + } + float range_val = max_val - min_val; + + output = output.sub(min_val).div(range_val).mul(255.0F).clamp(0, 255).to(torch::kF32); // .to(torch::kU8); + + // resize to the original size if required + if (out_orig_size) { + try { + output = at::upsample_bilinear2d(output.unsqueeze(0), { cv_ptr->image.size().height, cv_ptr->image.size().width }, true); + output = output.squeeze(0); + } + catch (const c10::Error& e) + { + std::cout << " upsample_bilinear2d() exception: " << e.msg() << std::endl; + return; + } + } + output = output.permute({ 1, 2, 0 }).to(torch::kCPU); + + int cv_type = CV_32FC1; // CV_8UC1; + auto cv_img = ToCvImage(output, cv_type); + + sensor_msgs::Image img_msg; + + try { + // cv::Mat -> sensor_msgs::Image + std_msgs::Header header; // empty header + header.seq = 0; // user defined counter + header.stamp = ros::Time::now();// time + //cv_bridge::CvImage img_bridge = cv_bridge::CvImage(header, sensor_msgs::image_encodings::MONO8, cv_img); + cv_bridge::CvImage img_bridge = cv_bridge::CvImage(header, sensor_msgs::image_encodings::TYPE_32FC1, cv_img); + + img_bridge.toImageMsg(img_msg); // cv_bridge -> sensor_msgs::Image + } + catch (cv_bridge::Exception& e) + { + ROS_ERROR("cv_bridge exception: %s", e.what()); + return; + } + + // Output modified video stream + image_pub_.publish(img_msg); + } +}; + +int main(int argc, char** argv) +{ + ros::init(argc, argv, "midas", ros::init_options::AnonymousName); + Midas ic; + ros::spin(); + return 0; +} \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/run_talker_listener_test.sh b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/run_talker_listener_test.sh new file mode 100644 index 0000000000000000000000000000000000000000..a997c4261072d0d627598fe06a723fcc7522d347 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/run_talker_listener_test.sh @@ -0,0 +1,16 @@ +# place any test.mp4 file near with this file + +# roscore +# rosnode kill -a + +source ~/catkin_ws/devel/setup.bash + +roscore & +P1=$! +rosrun midas_cpp talker.py & +P2=$! +rosrun midas_cpp listener_original.py & +P3=$! +rosrun midas_cpp listener.py & +P4=$! +wait $P1 $P2 $P3 $P4 \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/run.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/run.py new file mode 100644 index 0000000000000000000000000000000000000000..5696ef0547af093713ea416d18edd77d11879d0a --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/run.py @@ -0,0 +1,277 @@ +"""Compute depth maps for images in the input folder. +""" +import os +import glob +import torch +import utils +import cv2 +import argparse +import time + +import numpy as np + +from imutils.video import VideoStream +from midas.model_loader import default_models, load_model + +first_execution = True +def process(device, model, model_type, image, input_size, target_size, optimize, use_camera): + """ + Run the inference and interpolate. + + Args: + device (torch.device): the torch device used + model: the model used for inference + model_type: the type of the model + image: the image fed into the neural network + input_size: the size (width, height) of the neural network input (for OpenVINO) + target_size: the size (width, height) the neural network output is interpolated to + optimize: optimize the model to half-floats on CUDA? + use_camera: is the camera used? + + Returns: + the prediction + """ + global first_execution + + if "openvino" in model_type: + if first_execution or not use_camera: + print(f" Input resized to {input_size[0]}x{input_size[1]} before entering the encoder") + first_execution = False + + sample = [np.reshape(image, (1, 3, *input_size))] + prediction = model(sample)[model.output(0)][0] + prediction = cv2.resize(prediction, dsize=target_size, + interpolation=cv2.INTER_CUBIC) + else: + sample = torch.from_numpy(image).to(device).unsqueeze(0) + + if optimize and device == torch.device("cuda"): + if first_execution: + print(" Optimization to half-floats activated. Use with caution, because models like Swin require\n" + " float precision to work properly and may yield non-finite depth values to some extent for\n" + " half-floats.") + sample = sample.to(memory_format=torch.channels_last) + sample = sample.half() + + if first_execution or not use_camera: + height, width = sample.shape[2:] + print(f" Input resized to {width}x{height} before entering the encoder") + first_execution = False + + prediction = model.forward(sample) + prediction = ( + torch.nn.functional.interpolate( + prediction.unsqueeze(1), + size=target_size[::-1], + mode="bicubic", + align_corners=False, + ) + .squeeze() + .cpu() + .numpy() + ) + + return prediction + + +def create_side_by_side(image, depth, grayscale): + """ + Take an RGB image and depth map and place them side by side. This includes a proper normalization of the depth map + for better visibility. + + Args: + image: the RGB image + depth: the depth map + grayscale: use a grayscale colormap? + + Returns: + the image and depth map place side by side + """ + depth_min = depth.min() + depth_max = depth.max() + normalized_depth = 255 * (depth - depth_min) / (depth_max - depth_min) + normalized_depth *= 3 + + right_side = np.repeat(np.expand_dims(normalized_depth, 2), 3, axis=2) / 3 + if not grayscale: + right_side = cv2.applyColorMap(np.uint8(right_side), cv2.COLORMAP_INFERNO) + + if image is None: + return right_side + else: + return np.concatenate((image, right_side), axis=1) + + +def run(input_path, output_path, model_path, model_type="dpt_beit_large_512", optimize=False, side=False, height=None, + square=False, grayscale=False): + """Run MonoDepthNN to compute depth maps. + + Args: + input_path (str): path to input folder + output_path (str): path to output folder + model_path (str): path to saved model + model_type (str): the model type + optimize (bool): optimize the model to half-floats on CUDA? + side (bool): RGB and depth side by side in output images? + height (int): inference encoder image height + square (bool): resize to a square resolution? + grayscale (bool): use a grayscale colormap? + """ + print("Initialize") + + # select device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print("Device: %s" % device) + + model, transform, net_w, net_h = load_model(device, model_path, model_type, optimize, height, square) + + # get input + if input_path is not None: + image_names = glob.glob(os.path.join(input_path, "*")) + num_images = len(image_names) + else: + print("No input path specified. Grabbing images from camera.") + + # create output folder + if output_path is not None: + os.makedirs(output_path, exist_ok=True) + + print("Start processing") + + if input_path is not None: + if output_path is None: + print("Warning: No output path specified. Images will be processed but not shown or stored anywhere.") + for index, image_name in enumerate(image_names): + + print(" Processing {} ({}/{})".format(image_name, index + 1, num_images)) + + # input + original_image_rgb = utils.read_image(image_name) # in [0, 1] + image = transform({"image": original_image_rgb})["image"] + + # compute + with torch.no_grad(): + prediction = process(device, model, model_type, image, (net_w, net_h), original_image_rgb.shape[1::-1], + optimize, False) + + # output + if output_path is not None: + filename = os.path.join( + output_path, os.path.splitext(os.path.basename(image_name))[0] + '-' + model_type + ) + if not side: + utils.write_depth(filename, prediction, grayscale, bits=2) + else: + original_image_bgr = np.flip(original_image_rgb, 2) + content = create_side_by_side(original_image_bgr*255, prediction, grayscale) + cv2.imwrite(filename + ".png", content) + utils.write_pfm(filename + ".pfm", prediction.astype(np.float32)) + + else: + with torch.no_grad(): + fps = 1 + video = VideoStream(0).start() + time_start = time.time() + frame_index = 0 + while True: + frame = video.read() + if frame is not None: + original_image_rgb = np.flip(frame, 2) # in [0, 255] (flip required to get RGB) + image = transform({"image": original_image_rgb/255})["image"] + + prediction = process(device, model, model_type, image, (net_w, net_h), + original_image_rgb.shape[1::-1], optimize, True) + + original_image_bgr = np.flip(original_image_rgb, 2) if side else None + content = create_side_by_side(original_image_bgr, prediction, grayscale) + cv2.imshow('MiDaS Depth Estimation - Press Escape to close window ', content/255) + + if output_path is not None: + filename = os.path.join(output_path, 'Camera' + '-' + model_type + '_' + str(frame_index)) + cv2.imwrite(filename + ".png", content) + + alpha = 0.1 + if time.time()-time_start > 0: + fps = (1 - alpha) * fps + alpha * 1 / (time.time()-time_start) # exponential moving average + time_start = time.time() + print(f"\rFPS: {round(fps,2)}", end="") + + if cv2.waitKey(1) == 27: # Escape key + break + + frame_index += 1 + print() + + print("Finished") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument('-i', '--input_path', + default=None, + help='Folder with input images (if no input path is specified, images are tried to be grabbed ' + 'from camera)' + ) + + parser.add_argument('-o', '--output_path', + default=None, + help='Folder for output images' + ) + + parser.add_argument('-m', '--model_weights', + default=None, + help='Path to the trained weights of model' + ) + + parser.add_argument('-t', '--model_type', + default='dpt_beit_large_512', + help='Model type: ' + 'dpt_beit_large_512, dpt_beit_large_384, dpt_beit_base_384, dpt_swin2_large_384, ' + 'dpt_swin2_base_384, dpt_swin2_tiny_256, dpt_swin_large_384, dpt_next_vit_large_384, ' + 'dpt_levit_224, dpt_large_384, dpt_hybrid_384, midas_v21_384, midas_v21_small_256 or ' + 'openvino_midas_v21_small_256' + ) + + parser.add_argument('-s', '--side', + action='store_true', + help='Output images contain RGB and depth images side by side' + ) + + parser.add_argument('--optimize', dest='optimize', action='store_true', help='Use half-float optimization') + parser.set_defaults(optimize=False) + + parser.add_argument('--height', + type=int, default=None, + help='Preferred height of images feed into the encoder during inference. Note that the ' + 'preferred height may differ from the actual height, because an alignment to multiples of ' + '32 takes place. Many models support only the height chosen during training, which is ' + 'used automatically if this parameter is not set.' + ) + parser.add_argument('--square', + action='store_true', + help='Option to resize images to a square resolution by changing their widths when images are ' + 'fed into the encoder during inference. If this parameter is not set, the aspect ratio of ' + 'images is tried to be preserved if supported by the model.' + ) + parser.add_argument('--grayscale', + action='store_true', + help='Use a grayscale colormap instead of the inferno one. Although the inferno colormap, ' + 'which is used by default, is better for visibility, it does not allow storing 16-bit ' + 'depth values in PNGs but only 8-bit ones due to the precision limitation of this ' + 'colormap.' + ) + + args = parser.parse_args() + + + if args.model_weights is None: + args.model_weights = default_models[args.model_type] + + # set torch options + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + + # compute depth maps + run(args.input_path, args.output_path, args.model_weights, args.model_type, args.optimize, args.side, args.height, + args.square, args.grayscale) diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/README.md b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5b5fe0e63668eab45a55b140826cb3762862b17c --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/README.md @@ -0,0 +1,147 @@ +## Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer + +### TensorFlow inference using `.pb` and `.onnx` models + +1. [Run inference on TensorFlow-model by using TensorFlow](#run-inference-on-tensorflow-model-by-using-tensorFlow) + +2. [Run inference on ONNX-model by using TensorFlow](#run-inference-on-onnx-model-by-using-tensorflow) + +3. [Make ONNX model from downloaded Pytorch model file](#make-onnx-model-from-downloaded-pytorch-model-file) + + +### Run inference on TensorFlow-model by using TensorFlow + +1) Download the model weights [model-f6b98070.pb](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-f6b98070.pb) +and [model-small.pb](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-small.pb) and place the +file in the `/tf/` folder. + +2) Set up dependencies: + +```shell +# install OpenCV +pip install --upgrade pip +pip install opencv-python + +# install TensorFlow +pip install -I grpcio tensorflow==2.3.0 tensorflow-addons==0.11.2 numpy==1.18.0 +``` + +#### Usage + +1) Place one or more input images in the folder `tf/input`. + +2) Run the model: + + ```shell + python tf/run_pb.py + ``` + + Or run the small model: + + ```shell + python tf/run_pb.py --model_weights model-small.pb --model_type small + ``` + +3) The resulting inverse depth maps are written to the `tf/output` folder. + + +### Run inference on ONNX-model by using ONNX-Runtime + +1) Download the model weights [model-f6b98070.onnx](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-f6b98070.onnx) +and [model-small.onnx](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-small.onnx) and place the +file in the `/tf/` folder. + +2) Set up dependencies: + +```shell +# install OpenCV +pip install --upgrade pip +pip install opencv-python + +# install ONNX +pip install onnx==1.7.0 + +# install ONNX Runtime +pip install onnxruntime==1.5.2 +``` + +#### Usage + +1) Place one or more input images in the folder `tf/input`. + +2) Run the model: + + ```shell + python tf/run_onnx.py + ``` + + Or run the small model: + + ```shell + python tf/run_onnx.py --model_weights model-small.onnx --model_type small + ``` + +3) The resulting inverse depth maps are written to the `tf/output` folder. + + + +### Make ONNX model from downloaded Pytorch model file + +1) Download the model weights [model-f6b98070.pt](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-f6b98070.pt) and place the +file in the root folder. + +2) Set up dependencies: + +```shell +# install OpenCV +pip install --upgrade pip +pip install opencv-python + +# install PyTorch TorchVision +pip install -I torch==1.7.0 torchvision==0.8.0 + +# install TensorFlow +pip install -I grpcio tensorflow==2.3.0 tensorflow-addons==0.11.2 numpy==1.18.0 + +# install ONNX +pip install onnx==1.7.0 + +# install ONNX-TensorFlow +git clone https://github.com/onnx/onnx-tensorflow.git +cd onnx-tensorflow +git checkout 095b51b88e35c4001d70f15f80f31014b592b81e +pip install -e . +``` + +#### Usage + +1) Run the converter: + + ```shell + python tf/make_onnx_model.py + ``` + +2) The resulting `model-f6b98070.onnx` file is written to the `/tf/` folder. + + +### Requirements + + The code was tested with Python 3.6.9, PyTorch 1.5.1, TensorFlow 2.2.0, TensorFlow-addons 0.8.3, ONNX 1.7.0, ONNX-TensorFlow (GitHub-master-17.07.2020) and OpenCV 4.3.0. + +### Citation + +Please cite our paper if you use this code or any of the models: +``` +@article{Ranftl2019, + author = {Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun}, + title = {Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer}, + journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)}, + year = {2020}, +} +``` + +### License + +MIT License + + diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/input/.placeholder b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/input/.placeholder new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/make_onnx_model.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/make_onnx_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d14b0e4e1d2ea70fa315fd7ca7dfd72440a19376 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/make_onnx_model.py @@ -0,0 +1,112 @@ +"""Compute depth maps for images in the input folder. +""" +import os +import ntpath +import glob +import torch +import utils +import cv2 +import numpy as np +from torchvision.transforms import Compose, Normalize +from torchvision import transforms + +from shutil import copyfile +import fileinput +import sys +sys.path.append(os.getcwd() + '/..') + +def modify_file(): + modify_filename = '../midas/blocks.py' + copyfile(modify_filename, modify_filename+'.bak') + + with open(modify_filename, 'r') as file : + filedata = file.read() + + filedata = filedata.replace('align_corners=True', 'align_corners=False') + filedata = filedata.replace('import torch.nn as nn', 'import torch.nn as nn\nimport torchvision.models as models') + filedata = filedata.replace('torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")', 'models.resnext101_32x8d()') + + with open(modify_filename, 'w') as file: + file.write(filedata) + +def restore_file(): + modify_filename = '../midas/blocks.py' + copyfile(modify_filename+'.bak', modify_filename) + +modify_file() + +from midas.midas_net import MidasNet +from midas.transforms import Resize, NormalizeImage, PrepareForNet + +restore_file() + + +class MidasNet_preprocessing(MidasNet): + """Network for monocular depth estimation. + """ + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + mean = torch.tensor([0.485, 0.456, 0.406]) + std = torch.tensor([0.229, 0.224, 0.225]) + x.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) + + return MidasNet.forward(self, x) + + +def run(model_path): + """Run MonoDepthNN to compute depth maps. + + Args: + model_path (str): path to saved model + """ + print("initialize") + + # select device + + # load network + #model = MidasNet(model_path, non_negative=True) + model = MidasNet_preprocessing(model_path, non_negative=True) + + model.eval() + + print("start processing") + + # input + img_input = np.zeros((3, 384, 384), np.float32) + + # compute + with torch.no_grad(): + sample = torch.from_numpy(img_input).unsqueeze(0) + prediction = model.forward(sample) + prediction = ( + torch.nn.functional.interpolate( + prediction.unsqueeze(1), + size=img_input.shape[:2], + mode="bicubic", + align_corners=False, + ) + .squeeze() + .cpu() + .numpy() + ) + + torch.onnx.export(model, sample, ntpath.basename(model_path).rsplit('.', 1)[0]+'.onnx', opset_version=9) + + print("finished") + + +if __name__ == "__main__": + # set paths + # MODEL_PATH = "model.pt" + MODEL_PATH = "../model-f6b98070.pt" + + # compute depth maps + run(MODEL_PATH) diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/output/.placeholder b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/output/.placeholder new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_onnx.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..7107b99969a127f951814f743d5c562a436b2430 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_onnx.py @@ -0,0 +1,119 @@ +"""Compute depth maps for images in the input folder. +""" +import os +import glob +import utils +import cv2 +import sys +import numpy as np +import argparse + +import onnx +import onnxruntime as rt + +from transforms import Resize, NormalizeImage, PrepareForNet + + +def run(input_path, output_path, model_path, model_type="large"): + """Run MonoDepthNN to compute depth maps. + + Args: + input_path (str): path to input folder + output_path (str): path to output folder + model_path (str): path to saved model + """ + print("initialize") + + # select device + device = "CUDA:0" + #device = "CPU" + print("device: %s" % device) + + # network resolution + if model_type == "large": + net_w, net_h = 384, 384 + elif model_type == "small": + net_w, net_h = 256, 256 + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + # load network + print("loading model...") + model = rt.InferenceSession(model_path) + input_name = model.get_inputs()[0].name + output_name = model.get_outputs()[0].name + + resize_image = Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="upper_bound", + image_interpolation_method=cv2.INTER_CUBIC, + ) + + def compose2(f1, f2): + return lambda x: f2(f1(x)) + + transform = compose2(resize_image, PrepareForNet()) + + # get input + img_names = glob.glob(os.path.join(input_path, "*")) + num_images = len(img_names) + + # create output folder + os.makedirs(output_path, exist_ok=True) + + print("start processing") + + for ind, img_name in enumerate(img_names): + + print(" processing {} ({}/{})".format(img_name, ind + 1, num_images)) + + # input + img = utils.read_image(img_name) + img_input = transform({"image": img})["image"] + + # compute + output = model.run([output_name], {input_name: img_input.reshape(1, 3, net_h, net_w).astype(np.float32)})[0] + prediction = np.array(output).reshape(net_h, net_w) + prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC) + + # output + filename = os.path.join( + output_path, os.path.splitext(os.path.basename(img_name))[0] + ) + utils.write_depth(filename, prediction, bits=2) + + print("finished") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument('-i', '--input_path', + default='input', + help='folder with input images' + ) + + parser.add_argument('-o', '--output_path', + default='output', + help='folder for output images' + ) + + parser.add_argument('-m', '--model_weights', + default='model-f6b98070.onnx', + help='path to the trained weights of model' + ) + + parser.add_argument('-t', '--model_type', + default='large', + help='model type: large or small' + ) + + args = parser.parse_args() + + # compute depth maps + run(args.input_path, args.output_path, args.model_weights, args.model_type) diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_pb.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_pb.py new file mode 100644 index 0000000000000000000000000000000000000000..e46254f7b37f72e7d87672d70fd4b2f393ad7658 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_pb.py @@ -0,0 +1,135 @@ +"""Compute depth maps for images in the input folder. +""" +import os +import glob +import utils +import cv2 +import argparse + +import tensorflow as tf + +from transforms import Resize, NormalizeImage, PrepareForNet + +def run(input_path, output_path, model_path, model_type="large"): + """Run MonoDepthNN to compute depth maps. + + Args: + input_path (str): path to input folder + output_path (str): path to output folder + model_path (str): path to saved model + """ + print("initialize") + + # the runtime initialization will not allocate all memory on the device to avoid out of GPU memory + gpus = tf.config.experimental.list_physical_devices('GPU') + if gpus: + try: + for gpu in gpus: + #tf.config.experimental.set_memory_growth(gpu, True) + tf.config.experimental.set_virtual_device_configuration(gpu, + [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4000)]) + except RuntimeError as e: + print(e) + + # network resolution + if model_type == "large": + net_w, net_h = 384, 384 + elif model_type == "small": + net_w, net_h = 256, 256 + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + # load network + graph_def = tf.compat.v1.GraphDef() + with tf.io.gfile.GFile(model_path, 'rb') as f: + graph_def.ParseFromString(f.read()) + tf.import_graph_def(graph_def, name='') + + + model_operations = tf.compat.v1.get_default_graph().get_operations() + input_node = '0:0' + output_layer = model_operations[len(model_operations) - 1].name + ':0' + print("Last layer name: ", output_layer) + + resize_image = Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="upper_bound", + image_interpolation_method=cv2.INTER_CUBIC, + ) + + def compose2(f1, f2): + return lambda x: f2(f1(x)) + + transform = compose2(resize_image, PrepareForNet()) + + # get input + img_names = glob.glob(os.path.join(input_path, "*")) + num_images = len(img_names) + + # create output folder + os.makedirs(output_path, exist_ok=True) + + print("start processing") + + with tf.compat.v1.Session() as sess: + try: + # load images + for ind, img_name in enumerate(img_names): + + print(" processing {} ({}/{})".format(img_name, ind + 1, num_images)) + + # input + img = utils.read_image(img_name) + img_input = transform({"image": img})["image"] + + # compute + prob_tensor = sess.graph.get_tensor_by_name(output_layer) + prediction, = sess.run(prob_tensor, {input_node: [img_input] }) + prediction = prediction.reshape(net_h, net_w) + prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC) + + # output + filename = os.path.join( + output_path, os.path.splitext(os.path.basename(img_name))[0] + ) + utils.write_depth(filename, prediction, bits=2) + + except KeyError: + print ("Couldn't find input node: ' + input_node + ' or output layer: " + output_layer + ".") + exit(-1) + + print("finished") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument('-i', '--input_path', + default='input', + help='folder with input images' + ) + + parser.add_argument('-o', '--output_path', + default='output', + help='folder for output images' + ) + + parser.add_argument('-m', '--model_weights', + default='model-f6b98070.pb', + help='path to the trained weights of model' + ) + + parser.add_argument('-t', '--model_type', + default='large', + help='model type: large or small' + ) + + args = parser.parse_args() + + # compute depth maps + run(args.input_path, args.output_path, args.model_weights, args.model_type) diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/transforms.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/utils.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ff9a54bd55f5e31a90fad21242efbfda5a6cc1a7 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/utils.py @@ -0,0 +1,82 @@ +import numpy as np +import sys +import cv2 + + +def write_pfm(path, image, scale=1): + """Write pfm file. + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + +def read_image(path): + """Read image and output RGB image (0-1). + Args: + path (str): path to file + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = 0 + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/utils.py b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3976fd97dfe6a9dc7d4fa144be8fcb0b18b2db --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/base_models/midas_repo/utils.py @@ -0,0 +1,199 @@ +"""Utils for monoDepth. +""" +import sys +import re +import numpy as np +import cv2 +import torch + + +def read_pfm(path): + """Read pfm file. + + Args: + path (str): path to file + + Returns: + tuple: (data, scale) + """ + with open(path, "rb") as file: + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file: " + path) + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + # little-endian + endian = "<" + scale = -scale + else: + # big-endian + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + + return data, scale + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + + +def resize_image(img): + """Resize image and make it fit for network. + + Args: + img (array): image + + Returns: + tensor: data ready for network + """ + height_orig = img.shape[0] + width_orig = img.shape[1] + + if width_orig > height_orig: + scale = width_orig / 384 + else: + scale = height_orig / 384 + + height = (np.ceil(height_orig / scale / 32) * 32).astype(int) + width = (np.ceil(width_orig / scale / 32) * 32).astype(int) + + img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) + + img_resized = ( + torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() + ) + img_resized = img_resized.unsqueeze(0) + + return img_resized + + +def resize_depth(depth, width, height): + """Resize depth map and bring to CPU (numpy). + + Args: + depth (tensor): depth + width (int): image width + height (int): image height + + Returns: + array: processed depth + """ + depth = torch.squeeze(depth[0, :, :, :]).to("cpu") + + depth_resized = cv2.resize( + depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC + ) + + return depth_resized + +def write_depth(path, depth, grayscale, bits=1): + """Write depth map to png file. + + Args: + path (str): filepath without extension + depth (array): depth + grayscale (bool): use a grayscale colormap? + """ + if not grayscale: + bits = 1 + + if not np.isfinite(depth).all(): + depth=np.nan_to_num(depth, nan=0.0, posinf=0.0, neginf=0.0) + print("WARNING: Non-finite depth values present") + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.dtype) + + if not grayscale: + out = cv2.applyColorMap(np.uint8(out), cv2.COLORMAP_INFERNO) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return diff --git a/third_party/flux/annotator/zoe/zoedepth/models/builder.py b/third_party/flux/annotator/zoe/zoedepth/models/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..0818311b642561712a03a66655c638ce09a04cca --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/builder.py @@ -0,0 +1,51 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +from importlib import import_module +from .depth_model import DepthModel + +def build_model(config) -> DepthModel: + """Builds a model from a config. The model is specified by the model name and version in the config. The model is then constructed using the build_from_config function of the model interface. + This function should be used to construct models for training and evaluation. + + Args: + config (dict): Config dict. Config is constructed in utils/config.py. Each model has its own config file(s) saved in its root model folder. + + Returns: + torch.nn.Module: Model corresponding to name and version as specified in config + """ + module_name = f"zoedepth.models.{config.model}" + try: + module = import_module(module_name) + except ModuleNotFoundError as e: + # print the original error message + print(e) + raise ValueError( + f"Model {config.model} not found. Refer above error for details.") from e + try: + get_version = getattr(module, "get_version") + except AttributeError as e: + raise ValueError( + f"Model {config.model} has no get_version function.") from e + return get_version(config.version_name).build_from_config(config) diff --git a/third_party/flux/annotator/zoe/zoedepth/models/depth_model.py b/third_party/flux/annotator/zoe/zoedepth/models/depth_model.py new file mode 100644 index 0000000000000000000000000000000000000000..fc421c108ea3928c9add62b4c190500d9bd4eda1 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/depth_model.py @@ -0,0 +1,152 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import transforms +import PIL.Image +from PIL import Image +from typing import Union + + +class DepthModel(nn.Module): + def __init__(self): + super().__init__() + self.device = 'cpu' + + def to(self, device) -> nn.Module: + self.device = device + return super().to(device) + + def forward(self, x, *args, **kwargs): + raise NotImplementedError + + def _infer(self, x: torch.Tensor): + """ + Inference interface for the model + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + return self(x)['metric_depth'] + + def _infer_with_pad_aug(self, x: torch.Tensor, pad_input: bool=True, fh: float=3, fw: float=3, upsampling_mode: str='bicubic', padding_mode="reflect", **kwargs) -> torch.Tensor: + """ + Inference interface for the model with padding augmentation + Padding augmentation fixes the boundary artifacts in the output depth map. + Boundary artifacts are sometimes caused by the fact that the model is trained on NYU raw dataset which has a black or white border around the image. + This augmentation pads the input image and crops the prediction back to the original size / view. + + Note: This augmentation is not required for the models trained with 'avoid_boundary'=True. + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + pad_input (bool, optional): whether to pad the input or not. Defaults to True. + fh (float, optional): height padding factor. The padding is calculated as sqrt(h/2) * fh. Defaults to 3. + fw (float, optional): width padding factor. The padding is calculated as sqrt(w/2) * fw. Defaults to 3. + upsampling_mode (str, optional): upsampling mode. Defaults to 'bicubic'. + padding_mode (str, optional): padding mode. Defaults to "reflect". + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + # assert x is nchw and c = 3 + assert x.dim() == 4, "x must be 4 dimensional, got {}".format(x.dim()) + assert x.shape[1] == 3, "x must have 3 channels, got {}".format(x.shape[1]) + + if pad_input: + assert fh > 0 or fw > 0, "atlease one of fh and fw must be greater than 0" + pad_h = int(np.sqrt(x.shape[2]/2) * fh) + pad_w = int(np.sqrt(x.shape[3]/2) * fw) + padding = [pad_w, pad_w] + if pad_h > 0: + padding += [pad_h, pad_h] + + x = F.pad(x, padding, mode=padding_mode, **kwargs) + out = self._infer(x) + if out.shape[-2:] != x.shape[-2:]: + out = F.interpolate(out, size=(x.shape[2], x.shape[3]), mode=upsampling_mode, align_corners=False) + if pad_input: + # crop to the original size, handling the case where pad_h and pad_w is 0 + if pad_h > 0: + out = out[:, :, pad_h:-pad_h,:] + if pad_w > 0: + out = out[:, :, :, pad_w:-pad_w] + return out + + def infer_with_flip_aug(self, x, pad_input: bool=True, **kwargs) -> torch.Tensor: + """ + Inference interface for the model with horizontal flip augmentation + Horizontal flip augmentation improves the accuracy of the model by averaging the output of the model with and without horizontal flip. + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + pad_input (bool, optional): whether to use padding augmentation. Defaults to True. + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + # infer with horizontal flip and average + out = self._infer_with_pad_aug(x, pad_input=pad_input, **kwargs) + out_flip = self._infer_with_pad_aug(torch.flip(x, dims=[3]), pad_input=pad_input, **kwargs) + out = (out + torch.flip(out_flip, dims=[3])) / 2 + return out + + def infer(self, x, pad_input: bool=True, with_flip_aug: bool=True, **kwargs) -> torch.Tensor: + """ + Inference interface for the model + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + pad_input (bool, optional): whether to use padding augmentation. Defaults to True. + with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True. + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + if with_flip_aug: + return self.infer_with_flip_aug(x, pad_input=pad_input, **kwargs) + else: + return self._infer_with_pad_aug(x, pad_input=pad_input, **kwargs) + + @torch.no_grad() + def infer_pil(self, pil_img, pad_input: bool=True, with_flip_aug: bool=True, output_type: str="numpy", **kwargs) -> Union[np.ndarray, PIL.Image.Image, torch.Tensor]: + """ + Inference interface for the model for PIL image + Args: + pil_img (PIL.Image.Image): input PIL image + pad_input (bool, optional): whether to use padding augmentation. Defaults to True. + with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True. + output_type (str, optional): output type. Supported values are 'numpy', 'pil' and 'tensor'. Defaults to "numpy". + """ + x = transforms.ToTensor()(pil_img).unsqueeze(0).to(self.device) + out_tensor = self.infer(x, pad_input=pad_input, with_flip_aug=with_flip_aug, **kwargs) + if output_type == "numpy": + return out_tensor.squeeze().cpu().numpy() + elif output_type == "pil": + # uint16 is required for depth pil image + out_16bit_numpy = (out_tensor.squeeze().cpu().numpy()*256).astype(np.uint16) + return Image.fromarray(out_16bit_numpy) + elif output_type == "tensor": + return out_tensor.squeeze().cpu() + else: + raise ValueError(f"output_type {output_type} not supported. Supported values are 'numpy', 'pil' and 'tensor'") + \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/layers/__pycache__/attractor.cpython-310.pyc b/third_party/flux/annotator/zoe/zoedepth/models/layers/__pycache__/attractor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b50e806232fae8bbbbe2e93b1d3f67d79f783d61 Binary files /dev/null and b/third_party/flux/annotator/zoe/zoedepth/models/layers/__pycache__/attractor.cpython-310.pyc differ diff --git a/third_party/flux/annotator/zoe/zoedepth/models/layers/__pycache__/dist_layers.cpython-310.pyc b/third_party/flux/annotator/zoe/zoedepth/models/layers/__pycache__/dist_layers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..adb22d9b60f29eb972ba73b23961f07b38b5a613 Binary files /dev/null and b/third_party/flux/annotator/zoe/zoedepth/models/layers/__pycache__/dist_layers.cpython-310.pyc differ diff --git a/third_party/flux/annotator/zoe/zoedepth/models/layers/__pycache__/localbins_layers.cpython-310.pyc b/third_party/flux/annotator/zoe/zoedepth/models/layers/__pycache__/localbins_layers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..529251a9e4bc9ea3c337a0c07b409c758e4de2f9 Binary files /dev/null and b/third_party/flux/annotator/zoe/zoedepth/models/layers/__pycache__/localbins_layers.cpython-310.pyc differ diff --git a/third_party/flux/annotator/zoe/zoedepth/models/layers/attractor.py b/third_party/flux/annotator/zoe/zoedepth/models/layers/attractor.py new file mode 100644 index 0000000000000000000000000000000000000000..2a8efe645adea1d88a12e2ac5cc6bb2a251eef9d --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/layers/attractor.py @@ -0,0 +1,208 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +@torch.jit.script +def exp_attractor(dx, alpha: float = 300, gamma: int = 2): + """Exponential attractor: dc = exp(-alpha*|dx|^gamma) * dx , where dx = a - c, a = attractor point, c = bin center, dc = shift in bin centermmary for exp_attractor + + Args: + dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center. + alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300. + gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2. + + Returns: + torch.Tensor : Delta shifts - dc; New bin centers = Old bin centers + dc + """ + return torch.exp(-alpha*(torch.abs(dx)**gamma)) * (dx) + + +@torch.jit.script +def inv_attractor(dx, alpha: float = 300, gamma: int = 2): + """Inverse attractor: dc = dx / (1 + alpha*dx^gamma), where dx = a - c, a = attractor point, c = bin center, dc = shift in bin center + This is the default one according to the accompanying paper. + + Args: + dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center. + alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300. + gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2. + + Returns: + torch.Tensor: Delta shifts - dc; New bin centers = Old bin centers + dc + """ + return dx.div(1+alpha*dx.pow(gamma)) + + +class AttractorLayer(nn.Module): + def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10, + alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False): + """ + Attractor layer for bin centers. Bin centers are bounded on the interval (min_depth, max_depth) + """ + super().__init__() + + self.n_attractors = n_attractors + self.n_bins = n_bins + self.min_depth = min_depth + self.max_depth = max_depth + self.alpha = alpha + self.gamma = gamma + self.kind = kind + self.attractor_type = attractor_type + self.memory_efficient = memory_efficient + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_attractors*2, 1, 1, 0), # x2 for linear norm + nn.ReLU(inplace=True) + ) + + def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): + """ + Args: + x (torch.Tensor) : feature block; shape - n, c, h, w + b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w + + Returns: + tuple(torch.Tensor,torch.Tensor) : new bin centers normed and scaled; shape - n, nbins, h, w + """ + if prev_b_embedding is not None: + if interpolate: + prev_b_embedding = nn.functional.interpolate( + prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) + x = x + prev_b_embedding + + A = self._net(x) + eps = 1e-3 + A = A + eps + n, c, h, w = A.shape + A = A.view(n, self.n_attractors, 2, h, w) + A_normed = A / A.sum(dim=2, keepdim=True) # n, a, 2, h, w + A_normed = A[:, :, 0, ...] # n, na, h, w + + b_prev = nn.functional.interpolate( + b_prev, (h, w), mode='bilinear', align_corners=True) + b_centers = b_prev + + if self.attractor_type == 'exp': + dist = exp_attractor + else: + dist = inv_attractor + + if not self.memory_efficient: + func = {'mean': torch.mean, 'sum': torch.sum}[self.kind] + # .shape N, nbins, h, w + delta_c = func(dist(A_normed.unsqueeze( + 2) - b_centers.unsqueeze(1)), dim=1) + else: + delta_c = torch.zeros_like(b_centers, device=b_centers.device) + for i in range(self.n_attractors): + # .shape N, nbins, h, w + delta_c += dist(A_normed[:, i, ...].unsqueeze(1) - b_centers) + + if self.kind == 'mean': + delta_c = delta_c / self.n_attractors + + b_new_centers = b_centers + delta_c + B_centers = (self.max_depth - self.min_depth) * \ + b_new_centers + self.min_depth + B_centers, _ = torch.sort(B_centers, dim=1) + B_centers = torch.clip(B_centers, self.min_depth, self.max_depth) + return b_new_centers, B_centers + + +class AttractorLayerUnnormed(nn.Module): + def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10, + alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False): + """ + Attractor layer for bin centers. Bin centers are unbounded + """ + super().__init__() + + self.n_attractors = n_attractors + self.n_bins = n_bins + self.min_depth = min_depth + self.max_depth = max_depth + self.alpha = alpha + self.gamma = gamma + self.kind = kind + self.attractor_type = attractor_type + self.memory_efficient = memory_efficient + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_attractors, 1, 1, 0), + nn.Softplus() + ) + + def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): + """ + Args: + x (torch.Tensor) : feature block; shape - n, c, h, w + b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w + + Returns: + tuple(torch.Tensor,torch.Tensor) : new bin centers unbounded; shape - n, nbins, h, w. Two outputs just to keep the API consistent with the normed version + """ + if prev_b_embedding is not None: + if interpolate: + prev_b_embedding = nn.functional.interpolate( + prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) + x = x + prev_b_embedding + + A = self._net(x) + n, c, h, w = A.shape + + b_prev = nn.functional.interpolate( + b_prev, (h, w), mode='bilinear', align_corners=True) + b_centers = b_prev + + if self.attractor_type == 'exp': + dist = exp_attractor + else: + dist = inv_attractor + + if not self.memory_efficient: + func = {'mean': torch.mean, 'sum': torch.sum}[self.kind] + # .shape N, nbins, h, w + delta_c = func( + dist(A.unsqueeze(2) - b_centers.unsqueeze(1)), dim=1) + else: + delta_c = torch.zeros_like(b_centers, device=b_centers.device) + for i in range(self.n_attractors): + delta_c += dist(A[:, i, ...].unsqueeze(1) - + b_centers) # .shape N, nbins, h, w + + if self.kind == 'mean': + delta_c = delta_c / self.n_attractors + + b_new_centers = b_centers + delta_c + B_centers = b_new_centers + + return b_new_centers, B_centers diff --git a/third_party/flux/annotator/zoe/zoedepth/models/layers/dist_layers.py b/third_party/flux/annotator/zoe/zoedepth/models/layers/dist_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..3208405dfb78fdfc28d5765e5a6d5dbe31967a23 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/layers/dist_layers.py @@ -0,0 +1,121 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +def log_binom(n, k, eps=1e-7): + """ log(nCk) using stirling approximation """ + n = n + eps + k = k + eps + return n * torch.log(n) - k * torch.log(k) - (n-k) * torch.log(n-k+eps) + + +class LogBinomial(nn.Module): + def __init__(self, n_classes=256, act=torch.softmax): + """Compute log binomial distribution for n_classes + + Args: + n_classes (int, optional): number of output classes. Defaults to 256. + """ + super().__init__() + self.K = n_classes + self.act = act + self.register_buffer('k_idx', torch.arange( + 0, n_classes).view(1, -1, 1, 1)) + self.register_buffer('K_minus_1', torch.Tensor( + [self.K-1]).view(1, -1, 1, 1)) + + def forward(self, x, t=1., eps=1e-4): + """Compute log binomial distribution for x + + Args: + x (torch.Tensor - NCHW): probabilities + t (float, torch.Tensor - NCHW, optional): Temperature of distribution. Defaults to 1.. + eps (float, optional): Small number for numerical stability. Defaults to 1e-4. + + Returns: + torch.Tensor -NCHW: log binomial distribution logbinomial(p;t) + """ + if x.ndim == 3: + x = x.unsqueeze(1) # make it nchw + + one_minus_x = torch.clamp(1 - x, eps, 1) + x = torch.clamp(x, eps, 1) + y = log_binom(self.K_minus_1, self.k_idx) + self.k_idx * \ + torch.log(x) + (self.K - 1 - self.k_idx) * torch.log(one_minus_x) + return self.act(y/t, dim=1) + + +class ConditionalLogBinomial(nn.Module): + def __init__(self, in_features, condition_dim, n_classes=256, bottleneck_factor=2, p_eps=1e-4, max_temp=50, min_temp=1e-7, act=torch.softmax): + """Conditional Log Binomial distribution + + Args: + in_features (int): number of input channels in main feature + condition_dim (int): number of input channels in condition feature + n_classes (int, optional): Number of classes. Defaults to 256. + bottleneck_factor (int, optional): Hidden dim factor. Defaults to 2. + p_eps (float, optional): small eps value. Defaults to 1e-4. + max_temp (float, optional): Maximum temperature of output distribution. Defaults to 50. + min_temp (float, optional): Minimum temperature of output distribution. Defaults to 1e-7. + """ + super().__init__() + self.p_eps = p_eps + self.max_temp = max_temp + self.min_temp = min_temp + self.log_binomial_transform = LogBinomial(n_classes, act=act) + bottleneck = (in_features + condition_dim) // bottleneck_factor + self.mlp = nn.Sequential( + nn.Conv2d(in_features + condition_dim, bottleneck, + kernel_size=1, stride=1, padding=0), + nn.GELU(), + # 2 for p linear norm, 2 for t linear norm + nn.Conv2d(bottleneck, 2+2, kernel_size=1, stride=1, padding=0), + nn.Softplus() + ) + + def forward(self, x, cond): + """Forward pass + + Args: + x (torch.Tensor - NCHW): Main feature + cond (torch.Tensor - NCHW): condition feature + + Returns: + torch.Tensor: Output log binomial distribution + """ + pt = self.mlp(torch.concat((x, cond), dim=1)) + p, t = pt[:, :2, ...], pt[:, 2:, ...] + + p = p + self.p_eps + p = p[:, 0, ...] / (p[:, 0, ...] + p[:, 1, ...]) + + t = t + self.p_eps + t = t[:, 0, ...] / (t[:, 0, ...] + t[:, 1, ...]) + t = t.unsqueeze(1) + t = (self.max_temp - self.min_temp) * t + self.min_temp + + return self.log_binomial_transform(p, t) diff --git a/third_party/flux/annotator/zoe/zoedepth/models/layers/localbins_layers.py b/third_party/flux/annotator/zoe/zoedepth/models/layers/localbins_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..f94481605c3e6958ce50e73b2eb31d9f0c07dc67 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/layers/localbins_layers.py @@ -0,0 +1,169 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +class SeedBinRegressor(nn.Module): + def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10): + """Bin center regressor network. Bin centers are bounded on (min_depth, max_depth) interval. + + Args: + in_features (int): input channels + n_bins (int, optional): Number of bin centers. Defaults to 16. + mlp_dim (int, optional): Hidden dimension. Defaults to 256. + min_depth (float, optional): Min depth value. Defaults to 1e-3. + max_depth (float, optional): Max depth value. Defaults to 10. + """ + super().__init__() + self.version = "1_1" + self.min_depth = min_depth + self.max_depth = max_depth + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_bins, 1, 1, 0), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + """ + Returns tensor of bin_width vectors (centers). One vector b for every pixel + """ + B = self._net(x) + eps = 1e-3 + B = B + eps + B_widths_normed = B / B.sum(dim=1, keepdim=True) + B_widths = (self.max_depth - self.min_depth) * \ + B_widths_normed # .shape NCHW + # pad has the form (left, right, top, bottom, front, back) + B_widths = nn.functional.pad( + B_widths, (0, 0, 0, 0, 1, 0), mode='constant', value=self.min_depth) + B_edges = torch.cumsum(B_widths, dim=1) # .shape NCHW + + B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:, 1:, ...]) + return B_widths_normed, B_centers + + +class SeedBinRegressorUnnormed(nn.Module): + def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10): + """Bin center regressor network. Bin centers are unbounded + + Args: + in_features (int): input channels + n_bins (int, optional): Number of bin centers. Defaults to 16. + mlp_dim (int, optional): Hidden dimension. Defaults to 256. + min_depth (float, optional): Not used. (for compatibility with SeedBinRegressor) + max_depth (float, optional): Not used. (for compatibility with SeedBinRegressor) + """ + super().__init__() + self.version = "1_1" + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_bins, 1, 1, 0), + nn.Softplus() + ) + + def forward(self, x): + """ + Returns tensor of bin_width vectors (centers). One vector b for every pixel + """ + B_centers = self._net(x) + return B_centers, B_centers + + +class Projector(nn.Module): + def __init__(self, in_features, out_features, mlp_dim=128): + """Projector MLP + + Args: + in_features (int): input channels + out_features (int): output channels + mlp_dim (int, optional): hidden dimension. Defaults to 128. + """ + super().__init__() + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, out_features, 1, 1, 0), + ) + + def forward(self, x): + return self._net(x) + + + +class LinearSplitter(nn.Module): + def __init__(self, in_features, prev_nbins, split_factor=2, mlp_dim=128, min_depth=1e-3, max_depth=10): + super().__init__() + + self.prev_nbins = prev_nbins + self.split_factor = split_factor + self.min_depth = min_depth + self.max_depth = max_depth + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.GELU(), + nn.Conv2d(mlp_dim, prev_nbins * split_factor, 1, 1, 0), + nn.ReLU() + ) + + def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): + """ + x : feature block; shape - n, c, h, w + b_prev : previous bin widths normed; shape - n, prev_nbins, h, w + """ + if prev_b_embedding is not None: + if interpolate: + prev_b_embedding = nn.functional.interpolate(prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) + x = x + prev_b_embedding + S = self._net(x) + eps = 1e-3 + S = S + eps + n, c, h, w = S.shape + S = S.view(n, self.prev_nbins, self.split_factor, h, w) + S_normed = S / S.sum(dim=2, keepdim=True) # fractional splits + + b_prev = nn.functional.interpolate(b_prev, (h,w), mode='bilinear', align_corners=True) + + + b_prev = b_prev / b_prev.sum(dim=1, keepdim=True) # renormalize for gurantees + # print(b_prev.shape, S_normed.shape) + # if is_for_query:(1).expand(-1, b_prev.size(0)//n, -1, -1, -1, -1).flatten(0,1) # TODO ? can replace all this with a single torch.repeat? + b = b_prev.unsqueeze(2) * S_normed + b = b.flatten(1,2) # .shape n, prev_nbins * split_factor, h, w + + # calculate bin centers for loss calculation + B_widths = (self.max_depth - self.min_depth) * b # .shape N, nprev * splitfactor, H, W + # pad has the form (left, right, top, bottom, front, back) + B_widths = nn.functional.pad(B_widths, (0,0,0,0,1,0), mode='constant', value=self.min_depth) + B_edges = torch.cumsum(B_widths, dim=1) # .shape NCHW + + B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:,1:,...]) + return b, B_centers \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/layers/patch_transformer.py b/third_party/flux/annotator/zoe/zoedepth/models/layers/patch_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..99d9e51a06b981bae45ce7dd64eaef19a4121991 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/layers/patch_transformer.py @@ -0,0 +1,91 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +class PatchTransformerEncoder(nn.Module): + def __init__(self, in_channels, patch_size=10, embedding_dim=128, num_heads=4, use_class_token=False): + """ViT-like transformer block + + Args: + in_channels (int): Input channels + patch_size (int, optional): patch size. Defaults to 10. + embedding_dim (int, optional): Embedding dimension in transformer model. Defaults to 128. + num_heads (int, optional): number of attention heads. Defaults to 4. + use_class_token (bool, optional): Whether to use extra token at the start for global accumulation (called as "class token"). Defaults to False. + """ + super(PatchTransformerEncoder, self).__init__() + self.use_class_token = use_class_token + encoder_layers = nn.TransformerEncoderLayer( + embedding_dim, num_heads, dim_feedforward=1024) + self.transformer_encoder = nn.TransformerEncoder( + encoder_layers, num_layers=4) # takes shape S,N,E + + self.embedding_convPxP = nn.Conv2d(in_channels, embedding_dim, + kernel_size=patch_size, stride=patch_size, padding=0) + + def positional_encoding_1d(self, sequence_length, batch_size, embedding_dim, device='cpu'): + """Generate positional encodings + + Args: + sequence_length (int): Sequence length + embedding_dim (int): Embedding dimension + + Returns: + torch.Tensor SBE: Positional encodings + """ + position = torch.arange( + 0, sequence_length, dtype=torch.float32, device=device).unsqueeze(1) + index = torch.arange( + 0, embedding_dim, 2, dtype=torch.float32, device=device).unsqueeze(0) + div_term = torch.exp(index * (-torch.log(torch.tensor(10000.0, device=device)) / embedding_dim)) + pos_encoding = position * div_term + pos_encoding = torch.cat([torch.sin(pos_encoding), torch.cos(pos_encoding)], dim=1) + pos_encoding = pos_encoding.unsqueeze(1).repeat(1, batch_size, 1) + return pos_encoding + + + def forward(self, x): + """Forward pass + + Args: + x (torch.Tensor - NCHW): Input feature tensor + + Returns: + torch.Tensor - SNE: Transformer output embeddings. S - sequence length (=HW/patch_size^2), N - batch size, E - embedding dim + """ + embeddings = self.embedding_convPxP(x).flatten( + 2) # .shape = n,c,s = n, embedding_dim, s + if self.use_class_token: + # extra special token at start ? + embeddings = nn.functional.pad(embeddings, (1, 0)) + + # change to S,N,E format required by transformer + embeddings = embeddings.permute(2, 0, 1) + S, N, E = embeddings.shape + embeddings = embeddings + self.positional_encoding_1d(S, N, E, device=embeddings.device) + x = self.transformer_encoder(embeddings) # .shape = S, N, E + return x diff --git a/third_party/flux/annotator/zoe/zoedepth/models/model_io.py b/third_party/flux/annotator/zoe/zoedepth/models/model_io.py new file mode 100644 index 0000000000000000000000000000000000000000..78b6579631dd847ac76651238cb5a948b5a66286 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/model_io.py @@ -0,0 +1,92 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch + +def load_state_dict(model, state_dict): + """Load state_dict into model, handling DataParallel and DistributedDataParallel. Also checks for "model" key in state_dict. + + DataParallel prefixes state_dict keys with 'module.' when saving. + If the model is not a DataParallel model but the state_dict is, then prefixes are removed. + If the model is a DataParallel model but the state_dict is not, then prefixes are added. + """ + state_dict = state_dict.get('model', state_dict) + # if model is a DataParallel model, then state_dict keys are prefixed with 'module.' + + do_prefix = isinstance( + model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)) + state = {} + for k, v in state_dict.items(): + if k.startswith('module.') and not do_prefix: + k = k[7:] + + if not k.startswith('module.') and do_prefix: + k = 'module.' + k + + state[k] = v + + model.load_state_dict(state) + print("Loaded successfully") + return model + + +def load_wts(model, checkpoint_path): + ckpt = torch.load(checkpoint_path, map_location='cpu') + return load_state_dict(model, ckpt) + + +def load_state_dict_from_url(model, url, **kwargs): + state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu', **kwargs) + return load_state_dict(model, state_dict) + + +def load_state_from_resource(model, resource: str): + """Loads weights to the model from a given resource. A resource can be of following types: + 1. URL. Prefixed with "url::" + e.g. url::http(s)://url.resource.com/ckpt.pt + + 2. Local path. Prefixed with "local::" + e.g. local::/path/to/ckpt.pt + + + Args: + model (torch.nn.Module): Model + resource (str): resource string + + Returns: + torch.nn.Module: Model with loaded weights + """ + print(f"Using pretrained resource {resource}") + + if resource.startswith('url::'): + url = resource.split('url::')[1] + return load_state_dict_from_url(model, url, progress=True) + + elif resource.startswith('local::'): + path = resource.split('local::')[1] + return load_wts(model, path) + + else: + raise ValueError("Invalid resource type, only url:: and local:: are supported") + \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/zoedepth/__init__.py b/third_party/flux/annotator/zoe/zoedepth/models/zoedepth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cc33f737d238766559f0e3a8def3c0b568f23b7f --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/zoedepth/__init__.py @@ -0,0 +1,31 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +from .zoedepth_v1 import ZoeDepth + +all_versions = { + "v1": ZoeDepth, +} + +get_version = lambda v : all_versions[v] \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/zoedepth/__pycache__/__init__.cpython-310.pyc b/third_party/flux/annotator/zoe/zoedepth/models/zoedepth/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bf3ad8449fd3045264fe8638af4af3ff5118afc Binary files /dev/null and b/third_party/flux/annotator/zoe/zoedepth/models/zoedepth/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/flux/annotator/zoe/zoedepth/models/zoedepth/__pycache__/zoedepth_v1.cpython-310.pyc b/third_party/flux/annotator/zoe/zoedepth/models/zoedepth/__pycache__/zoedepth_v1.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28cae54a66fc178eea881fd7ff5caa52084e8e37 Binary files /dev/null and b/third_party/flux/annotator/zoe/zoedepth/models/zoedepth/__pycache__/zoedepth_v1.cpython-310.pyc differ diff --git a/third_party/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth.json b/third_party/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth.json new file mode 100644 index 0000000000000000000000000000000000000000..3112ed78c89f00e1d13f5d6e5be87cd3216b6dc7 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth.json @@ -0,0 +1,58 @@ +{ + "model": { + "name": "ZoeDepth", + "version_name": "v1", + "n_bins": 64, + "bin_embedding_dim": 128, + "bin_centers_type": "softplus", + "n_attractors":[16, 8, 4, 1], + "attractor_alpha": 1000, + "attractor_gamma": 2, + "attractor_kind" : "mean", + "attractor_type" : "inv", + "midas_model_type" : "DPT_BEiT_L_384", + "min_temp": 0.0212, + "max_temp": 50.0, + "output_distribution": "logbinomial", + "memory_efficient": true, + "inverse_midas": false, + "img_size": [384, 512] + }, + + "train": { + "train_midas": true, + "use_pretrained_midas": true, + "trainer": "zoedepth", + "epochs": 5, + "bs": 16, + "optim_kwargs": {"lr": 0.000161, "wd": 0.01}, + "sched_kwargs": {"div_factor": 1, "final_div_factor": 10000, "pct_start": 0.7, "three_phase":false, "cycle_momentum": true}, + "same_lr": false, + "w_si": 1, + "w_domain": 0.2, + "w_reg": 0, + "w_grad": 0, + "avoid_boundary": false, + "random_crop": false, + "input_width": 640, + "input_height": 480, + "midas_lr_factor": 1, + "encoder_lr_factor":10, + "pos_enc_lr_factor":10, + "freeze_midas_bn": true + + }, + + "infer":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : null, + "force_keep_ar": true + }, + + "eval":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : null + } +} \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json b/third_party/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json new file mode 100644 index 0000000000000000000000000000000000000000..b51802aa44b91c39e15aacaac4b5ab6bec884414 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json @@ -0,0 +1,22 @@ +{ + "model": { + "bin_centers_type": "normed", + "img_size": [384, 768] + }, + + "train": { + }, + + "infer":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt", + "force_keep_ar": true + }, + + "eval":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt" + } +} \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/zoedepth/zoedepth_v1.py b/third_party/flux/annotator/zoe/zoedepth/models/zoedepth/zoedepth_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..bc931b059d6165c84e8ff4f09d5c62d19930cee9 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/zoedepth/zoedepth_v1.py @@ -0,0 +1,250 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import itertools + +import torch +import torch.nn as nn +from ..depth_model import DepthModel +from ..base_models.midas import MidasCore +from ..layers.attractor import AttractorLayer, AttractorLayerUnnormed +from ..layers.dist_layers import ConditionalLogBinomial +from ..layers.localbins_layers import (Projector, SeedBinRegressor, + SeedBinRegressorUnnormed) +from ..model_io import load_state_from_resource + + +class ZoeDepth(DepthModel): + def __init__(self, core, n_bins=64, bin_centers_type="softplus", bin_embedding_dim=128, min_depth=1e-3, max_depth=10, + n_attractors=[16, 8, 4, 1], attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp', min_temp=5, max_temp=50, train_midas=True, + midas_lr_factor=10, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, **kwargs): + """ZoeDepth model. This is the version of ZoeDepth that has a single metric head + + Args: + core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features + n_bins (int, optional): Number of bin centers. Defaults to 64. + bin_centers_type (str, optional): "normed" or "softplus". Activation type used for bin centers. For "normed" bin centers, linear normalization trick is applied. This results in bounded bin centers. + For "softplus", softplus activation is used and thus are unbounded. Defaults to "softplus". + bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128. + min_depth (float, optional): Lower bound for normed bin centers. Defaults to 1e-3. + max_depth (float, optional): Upper bound for normed bin centers. Defaults to 10. + n_attractors (List[int], optional): Number of bin attractors at decoder layers. Defaults to [16, 8, 4, 1]. + attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300. + attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2. + attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'. + attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'. + min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5. + max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50. + train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True. + midas_lr_factor (int, optional): Learning rate reduction factor for base midas model except its encoder and positional encodings. Defaults to 10. + encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10. + pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10. + """ + super().__init__() + + self.core = core + self.max_depth = max_depth + self.min_depth = min_depth + self.min_temp = min_temp + self.bin_centers_type = bin_centers_type + + self.midas_lr_factor = midas_lr_factor + self.encoder_lr_factor = encoder_lr_factor + self.pos_enc_lr_factor = pos_enc_lr_factor + self.train_midas = train_midas + self.inverse_midas = inverse_midas + + if self.encoder_lr_factor <= 0: + self.core.freeze_encoder( + freeze_rel_pos=self.pos_enc_lr_factor <= 0) + + N_MIDAS_OUT = 32 + btlnck_features = self.core.output_channels[0] + num_out_features = self.core.output_channels[1:] + + self.conv2 = nn.Conv2d(btlnck_features, btlnck_features, + kernel_size=1, stride=1, padding=0) # btlnck conv + + if bin_centers_type == "normed": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayer + elif bin_centers_type == "softplus": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid1": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid2": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayer + else: + raise ValueError( + "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'") + + self.seed_bin_regressor = SeedBinRegressorLayer( + btlnck_features, n_bins=n_bins, min_depth=min_depth, max_depth=max_depth) + self.seed_projector = Projector(btlnck_features, bin_embedding_dim) + self.projectors = nn.ModuleList([ + Projector(num_out, bin_embedding_dim) + for num_out in num_out_features + ]) + self.attractors = nn.ModuleList([ + Attractor(bin_embedding_dim, n_bins, n_attractors=n_attractors[i], min_depth=min_depth, max_depth=max_depth, + alpha=attractor_alpha, gamma=attractor_gamma, kind=attractor_kind, attractor_type=attractor_type) + for i in range(len(num_out_features)) + ]) + + last_in = N_MIDAS_OUT + 1 # +1 for relative depth + + # use log binomial instead of softmax + self.conditional_log_binomial = ConditionalLogBinomial( + last_in, bin_embedding_dim, n_classes=n_bins, min_temp=min_temp, max_temp=max_temp) + + def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs): + """ + Args: + x (torch.Tensor): Input image tensor of shape (B, C, H, W) + return_final_centers (bool, optional): Whether to return the final bin centers. Defaults to False. + denorm (bool, optional): Whether to denormalize the input image. This reverses ImageNet normalization as midas normalization is different. Defaults to False. + return_probs (bool, optional): Whether to return the output probability distribution. Defaults to False. + + Returns: + dict: Dictionary containing the following keys: + - rel_depth (torch.Tensor): Relative depth map of shape (B, H, W) + - metric_depth (torch.Tensor): Metric depth map of shape (B, 1, H, W) + - bin_centers (torch.Tensor): Bin centers of shape (B, n_bins). Present only if return_final_centers is True + - probs (torch.Tensor): Output probability distribution of shape (B, n_bins, H, W). Present only if return_probs is True + + """ + b, c, h, w = x.shape + # print("input shape ", x.shape) + self.orig_input_width = w + self.orig_input_height = h + rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True) + # print("output shapes", rel_depth.shape, out.shape) + + outconv_activation = out[0] + btlnck = out[1] + x_blocks = out[2:] + + x_d0 = self.conv2(btlnck) + x = x_d0 + _, seed_b_centers = self.seed_bin_regressor(x) + + if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2': + b_prev = (seed_b_centers - self.min_depth) / \ + (self.max_depth - self.min_depth) + else: + b_prev = seed_b_centers + + prev_b_embedding = self.seed_projector(x) + + # unroll this loop for better performance + for projector, attractor, x in zip(self.projectors, self.attractors, x_blocks): + b_embedding = projector(x) + b, b_centers = attractor( + b_embedding, b_prev, prev_b_embedding, interpolate=True) + b_prev = b.clone() + prev_b_embedding = b_embedding.clone() + + last = outconv_activation + + if self.inverse_midas: + # invert depth followed by normalization + rel_depth = 1.0 / (rel_depth + 1e-6) + rel_depth = (rel_depth - rel_depth.min()) / \ + (rel_depth.max() - rel_depth.min()) + # concat rel depth with last. First interpolate rel depth to last size + rel_cond = rel_depth.unsqueeze(1) + rel_cond = nn.functional.interpolate( + rel_cond, size=last.shape[2:], mode='bilinear', align_corners=True) + last = torch.cat([last, rel_cond], dim=1) + + b_embedding = nn.functional.interpolate( + b_embedding, last.shape[-2:], mode='bilinear', align_corners=True) + x = self.conditional_log_binomial(last, b_embedding) + + # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor + # print(x.shape, b_centers.shape) + b_centers = nn.functional.interpolate( + b_centers, x.shape[-2:], mode='bilinear', align_corners=True) + out = torch.sum(x * b_centers, dim=1, keepdim=True) + + # Structure output dict + output = dict(metric_depth=out) + if return_final_centers or return_probs: + output['bin_centers'] = b_centers + + if return_probs: + output['probs'] = x + + return output + + def get_lr_params(self, lr): + """ + Learning rate configuration for different layers of the model + Args: + lr (float) : Base learning rate + Returns: + list : list of parameters to optimize and their learning rates, in the format required by torch optimizers. + """ + param_conf = [] + if self.train_midas: + if self.encoder_lr_factor > 0: + param_conf.append({'params': self.core.get_enc_params_except_rel_pos( + ), 'lr': lr / self.encoder_lr_factor}) + + if self.pos_enc_lr_factor > 0: + param_conf.append( + {'params': self.core.get_rel_pos_params(), 'lr': lr / self.pos_enc_lr_factor}) + + midas_params = self.core.core.scratch.parameters() + midas_lr_factor = self.midas_lr_factor + param_conf.append( + {'params': midas_params, 'lr': lr / midas_lr_factor}) + + remaining_modules = [] + for name, child in self.named_children(): + if name != 'core': + remaining_modules.append(child) + remaining_params = itertools.chain( + *[child.parameters() for child in remaining_modules]) + + param_conf.append({'params': remaining_params, 'lr': lr}) + + return param_conf + + @staticmethod + def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs): + core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas, + train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs) + model = ZoeDepth(core, **kwargs) + if pretrained_resource: + assert isinstance(pretrained_resource, str), "pretrained_resource must be a string" + model = load_state_from_resource(model, pretrained_resource) + return model + + @staticmethod + def build_from_config(config): + return ZoeDepth.build(**config) diff --git a/third_party/flux/annotator/zoe/zoedepth/models/zoedepth_nk/__init__.py b/third_party/flux/annotator/zoe/zoedepth/models/zoedepth_nk/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..513a278b939c10c010e3c0250ec73544d5663886 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/zoedepth_nk/__init__.py @@ -0,0 +1,31 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +from .zoedepth_nk_v1 import ZoeDepthNK + +all_versions = { + "v1": ZoeDepthNK, +} + +get_version = lambda v : all_versions[v] \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json b/third_party/flux/annotator/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json new file mode 100644 index 0000000000000000000000000000000000000000..42bab2a3ad159a09599a5aba270c491021a3cf1a --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json @@ -0,0 +1,67 @@ +{ + "model": { + "name": "ZoeDepthNK", + "version_name": "v1", + "bin_conf" : [ + { + "name": "nyu", + "n_bins": 64, + "min_depth": 1e-3, + "max_depth": 10.0 + }, + { + "name": "kitti", + "n_bins": 64, + "min_depth": 1e-3, + "max_depth": 80.0 + } + ], + "bin_embedding_dim": 128, + "bin_centers_type": "softplus", + "n_attractors":[16, 8, 4, 1], + "attractor_alpha": 1000, + "attractor_gamma": 2, + "attractor_kind" : "mean", + "attractor_type" : "inv", + "min_temp": 0.0212, + "max_temp": 50.0, + "memory_efficient": true, + "midas_model_type" : "DPT_BEiT_L_384", + "img_size": [384, 512] + }, + + "train": { + "train_midas": true, + "use_pretrained_midas": true, + "trainer": "zoedepth_nk", + "epochs": 5, + "bs": 16, + "optim_kwargs": {"lr": 0.0002512, "wd": 0.01}, + "sched_kwargs": {"div_factor": 1, "final_div_factor": 10000, "pct_start": 0.7, "three_phase":false, "cycle_momentum": true}, + "same_lr": false, + "w_si": 1, + "w_domain": 100, + "avoid_boundary": false, + "random_crop": false, + "input_width": 640, + "input_height": 480, + "w_grad": 0, + "w_reg": 0, + "midas_lr_factor": 10, + "encoder_lr_factor":10, + "pos_enc_lr_factor":10 + }, + + "infer": { + "train_midas": false, + "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt", + "use_pretrained_midas": false, + "force_keep_ar": true + }, + + "eval": { + "train_midas": false, + "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt", + "use_pretrained_midas": false + } +} \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py b/third_party/flux/annotator/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..7368ae8031188a9f946d9d3f29633c96e791e68e --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py @@ -0,0 +1,333 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import itertools + +import torch +import torch.nn as nn + +from zoedepth.models.depth_model import DepthModel +from zoedepth.models.base_models.midas import MidasCore +from zoedepth.models.layers.attractor import AttractorLayer, AttractorLayerUnnormed +from zoedepth.models.layers.dist_layers import ConditionalLogBinomial +from zoedepth.models.layers.localbins_layers import (Projector, SeedBinRegressor, + SeedBinRegressorUnnormed) +from zoedepth.models.layers.patch_transformer import PatchTransformerEncoder +from zoedepth.models.model_io import load_state_from_resource + + +class ZoeDepthNK(DepthModel): + def __init__(self, core, bin_conf, bin_centers_type="softplus", bin_embedding_dim=128, + n_attractors=[16, 8, 4, 1], attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp', + min_temp=5, max_temp=50, + memory_efficient=False, train_midas=True, + is_midas_pretrained=True, midas_lr_factor=1, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, **kwargs): + """ZoeDepthNK model. This is the version of ZoeDepth that has two metric heads and uses a learned router to route to experts. + + Args: + core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features + + bin_conf (List[dict]): A list of dictionaries that contain the bin configuration for each metric head. Each dictionary should contain the following keys: + "name" (str, typically same as the dataset name), "n_bins" (int), "min_depth" (float), "max_depth" (float) + + The length of this list determines the number of metric heads. + bin_centers_type (str, optional): "normed" or "softplus". Activation type used for bin centers. For "normed" bin centers, linear normalization trick is applied. This results in bounded bin centers. + For "softplus", softplus activation is used and thus are unbounded. Defaults to "normed". + bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128. + + n_attractors (List[int], optional): Number of bin attractors at decoder layers. Defaults to [16, 8, 4, 1]. + attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300. + attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2. + attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'. + attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'. + + min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5. + max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50. + + memory_efficient (bool, optional): Whether to use memory efficient version of attractor layers. Memory efficient version is slower but is recommended incase of multiple metric heads in order save GPU memory. Defaults to False. + + train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True. + is_midas_pretrained (bool, optional): Is "core" pretrained? Defaults to True. + midas_lr_factor (int, optional): Learning rate reduction factor for base midas model except its encoder and positional encodings. Defaults to 10. + encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10. + pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10. + + """ + + super().__init__() + + self.core = core + self.bin_conf = bin_conf + self.min_temp = min_temp + self.max_temp = max_temp + self.memory_efficient = memory_efficient + self.train_midas = train_midas + self.is_midas_pretrained = is_midas_pretrained + self.midas_lr_factor = midas_lr_factor + self.encoder_lr_factor = encoder_lr_factor + self.pos_enc_lr_factor = pos_enc_lr_factor + self.inverse_midas = inverse_midas + + N_MIDAS_OUT = 32 + btlnck_features = self.core.output_channels[0] + num_out_features = self.core.output_channels[1:] + # self.scales = [16, 8, 4, 2] # spatial scale factors + + self.conv2 = nn.Conv2d( + btlnck_features, btlnck_features, kernel_size=1, stride=1, padding=0) + + # Transformer classifier on the bottleneck + self.patch_transformer = PatchTransformerEncoder( + btlnck_features, 1, 128, use_class_token=True) + self.mlp_classifier = nn.Sequential( + nn.Linear(128, 128), + nn.ReLU(), + nn.Linear(128, 2) + ) + + if bin_centers_type == "normed": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayer + elif bin_centers_type == "softplus": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid1": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid2": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayer + else: + raise ValueError( + "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'") + self.bin_centers_type = bin_centers_type + # We have bins for each bin conf. + # Create a map (ModuleDict) of 'name' -> seed_bin_regressor + self.seed_bin_regressors = nn.ModuleDict( + {conf['name']: SeedBinRegressorLayer(btlnck_features, conf["n_bins"], mlp_dim=bin_embedding_dim//2, min_depth=conf["min_depth"], max_depth=conf["max_depth"]) + for conf in bin_conf} + ) + + self.seed_projector = Projector( + btlnck_features, bin_embedding_dim, mlp_dim=bin_embedding_dim//2) + self.projectors = nn.ModuleList([ + Projector(num_out, bin_embedding_dim, mlp_dim=bin_embedding_dim//2) + for num_out in num_out_features + ]) + + # Create a map (ModuleDict) of 'name' -> attractors (ModuleList) + self.attractors = nn.ModuleDict( + {conf['name']: nn.ModuleList([ + Attractor(bin_embedding_dim, n_attractors[i], + mlp_dim=bin_embedding_dim, alpha=attractor_alpha, + gamma=attractor_gamma, kind=attractor_kind, + attractor_type=attractor_type, memory_efficient=memory_efficient, + min_depth=conf["min_depth"], max_depth=conf["max_depth"]) + for i in range(len(n_attractors)) + ]) + for conf in bin_conf} + ) + + last_in = N_MIDAS_OUT + # conditional log binomial for each bin conf + self.conditional_log_binomial = nn.ModuleDict( + {conf['name']: ConditionalLogBinomial(last_in, bin_embedding_dim, conf['n_bins'], bottleneck_factor=4, min_temp=self.min_temp, max_temp=self.max_temp) + for conf in bin_conf} + ) + + def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs): + """ + Args: + x (torch.Tensor): Input image tensor of shape (B, C, H, W). Assumes all images are from the same domain. + return_final_centers (bool, optional): Whether to return the final centers of the attractors. Defaults to False. + denorm (bool, optional): Whether to denormalize the input image. Defaults to False. + return_probs (bool, optional): Whether to return the probabilities of the bins. Defaults to False. + + Returns: + dict: Dictionary of outputs with keys: + - "rel_depth": Relative depth map of shape (B, 1, H, W) + - "metric_depth": Metric depth map of shape (B, 1, H, W) + - "domain_logits": Domain logits of shape (B, 2) + - "bin_centers": Bin centers of shape (B, N, H, W). Present only if return_final_centers is True + - "probs": Bin probabilities of shape (B, N, H, W). Present only if return_probs is True + """ + b, c, h, w = x.shape + self.orig_input_width = w + self.orig_input_height = h + rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True) + + outconv_activation = out[0] + btlnck = out[1] + x_blocks = out[2:] + + x_d0 = self.conv2(btlnck) + x = x_d0 + + # Predict which path to take + embedding = self.patch_transformer(x)[0] # N, E + domain_logits = self.mlp_classifier(embedding) # N, 2 + domain_vote = torch.softmax(domain_logits.sum( + dim=0, keepdim=True), dim=-1) # 1, 2 + + # Get the path + bin_conf_name = ["nyu", "kitti"][torch.argmax( + domain_vote, dim=-1).squeeze().item()] + + try: + conf = [c for c in self.bin_conf if c.name == bin_conf_name][0] + except IndexError: + raise ValueError( + f"bin_conf_name {bin_conf_name} not found in bin_confs") + + min_depth = conf['min_depth'] + max_depth = conf['max_depth'] + + seed_bin_regressor = self.seed_bin_regressors[bin_conf_name] + _, seed_b_centers = seed_bin_regressor(x) + if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2': + b_prev = (seed_b_centers - min_depth)/(max_depth - min_depth) + else: + b_prev = seed_b_centers + prev_b_embedding = self.seed_projector(x) + + attractors = self.attractors[bin_conf_name] + for projector, attractor, x in zip(self.projectors, attractors, x_blocks): + b_embedding = projector(x) + b, b_centers = attractor( + b_embedding, b_prev, prev_b_embedding, interpolate=True) + b_prev = b + prev_b_embedding = b_embedding + + last = outconv_activation + + b_centers = nn.functional.interpolate( + b_centers, last.shape[-2:], mode='bilinear', align_corners=True) + b_embedding = nn.functional.interpolate( + b_embedding, last.shape[-2:], mode='bilinear', align_corners=True) + + clb = self.conditional_log_binomial[bin_conf_name] + x = clb(last, b_embedding) + + # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor + # print(x.shape, b_centers.shape) + # b_centers = nn.functional.interpolate(b_centers, x.shape[-2:], mode='bilinear', align_corners=True) + out = torch.sum(x * b_centers, dim=1, keepdim=True) + + output = dict(domain_logits=domain_logits, metric_depth=out) + if return_final_centers or return_probs: + output['bin_centers'] = b_centers + + if return_probs: + output['probs'] = x + return output + + def get_lr_params(self, lr): + """ + Learning rate configuration for different layers of the model + + Args: + lr (float) : Base learning rate + Returns: + list : list of parameters to optimize and their learning rates, in the format required by torch optimizers. + """ + param_conf = [] + if self.train_midas: + def get_rel_pos_params(): + for name, p in self.core.core.pretrained.named_parameters(): + if "relative_position" in name: + yield p + + def get_enc_params_except_rel_pos(): + for name, p in self.core.core.pretrained.named_parameters(): + if "relative_position" not in name: + yield p + + encoder_params = get_enc_params_except_rel_pos() + rel_pos_params = get_rel_pos_params() + midas_params = self.core.core.scratch.parameters() + midas_lr_factor = self.midas_lr_factor if self.is_midas_pretrained else 1.0 + param_conf.extend([ + {'params': encoder_params, 'lr': lr / self.encoder_lr_factor}, + {'params': rel_pos_params, 'lr': lr / self.pos_enc_lr_factor}, + {'params': midas_params, 'lr': lr / midas_lr_factor} + ]) + + remaining_modules = [] + for name, child in self.named_children(): + if name != 'core': + remaining_modules.append(child) + remaining_params = itertools.chain( + *[child.parameters() for child in remaining_modules]) + param_conf.append({'params': remaining_params, 'lr': lr}) + return param_conf + + def get_conf_parameters(self, conf_name): + """ + Returns parameters of all the ModuleDicts children that are exclusively used for the given bin configuration + """ + params = [] + for name, child in self.named_children(): + if isinstance(child, nn.ModuleDict): + for bin_conf_name, module in child.items(): + if bin_conf_name == conf_name: + params += list(module.parameters()) + return params + + def freeze_conf(self, conf_name): + """ + Freezes all the parameters of all the ModuleDicts children that are exclusively used for the given bin configuration + """ + for p in self.get_conf_parameters(conf_name): + p.requires_grad = False + + def unfreeze_conf(self, conf_name): + """ + Unfreezes all the parameters of all the ModuleDicts children that are exclusively used for the given bin configuration + """ + for p in self.get_conf_parameters(conf_name): + p.requires_grad = True + + def freeze_all_confs(self): + """ + Freezes all the parameters of all the ModuleDicts children + """ + for name, child in self.named_children(): + if isinstance(child, nn.ModuleDict): + for bin_conf_name, module in child.items(): + for p in module.parameters(): + p.requires_grad = False + + @staticmethod + def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs): + core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas, + train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs) + model = ZoeDepthNK(core, **kwargs) + if pretrained_resource: + assert isinstance(pretrained_resource, str), "pretrained_resource must be a string" + model = load_state_from_resource(model, pretrained_resource) + return model + + @staticmethod + def build_from_config(config): + return ZoeDepthNK.build(**config) diff --git a/third_party/flux/annotator/zoe/zoedepth/trainers/base_trainer.py b/third_party/flux/annotator/zoe/zoedepth/trainers/base_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..33fbbea3a7d49efe11b005adb5127f441eabfaf6 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/trainers/base_trainer.py @@ -0,0 +1,326 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os +import uuid +import warnings +from datetime import datetime as dt +from typing import Dict + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.optim as optim +import wandb +from tqdm import tqdm + +from zoedepth.utils.config import flatten +from zoedepth.utils.misc import RunningAverageDict, colorize, colors + + +def is_rank_zero(args): + return args.rank == 0 + + +class BaseTrainer: + def __init__(self, config, model, train_loader, test_loader=None, device=None): + """ Base Trainer class for training a model.""" + + self.config = config + self.metric_criterion = "abs_rel" + if device is None: + device = torch.device( + 'cuda') if torch.cuda.is_available() else torch.device('cpu') + self.device = device + self.model = model + self.train_loader = train_loader + self.test_loader = test_loader + self.optimizer = self.init_optimizer() + self.scheduler = self.init_scheduler() + + def resize_to_target(self, prediction, target): + if prediction.shape[2:] != target.shape[-2:]: + prediction = nn.functional.interpolate( + prediction, size=target.shape[-2:], mode="bilinear", align_corners=True + ) + return prediction + + def load_ckpt(self, checkpoint_dir="./checkpoints", ckpt_type="best"): + import glob + import os + + from zoedepth.models.model_io import load_wts + + if hasattr(self.config, "checkpoint"): + checkpoint = self.config.checkpoint + elif hasattr(self.config, "ckpt_pattern"): + pattern = self.config.ckpt_pattern + matches = glob.glob(os.path.join( + checkpoint_dir, f"*{pattern}*{ckpt_type}*")) + if not (len(matches) > 0): + raise ValueError(f"No matches found for the pattern {pattern}") + checkpoint = matches[0] + else: + return + model = load_wts(self.model, checkpoint) + # TODO : Resuming training is not properly supported in this repo. Implement loading / saving of optimizer and scheduler to support it. + print("Loaded weights from {0}".format(checkpoint)) + warnings.warn( + "Resuming training is not properly supported in this repo. Implement loading / saving of optimizer and scheduler to support it.") + self.model = model + + def init_optimizer(self): + m = self.model.module if self.config.multigpu else self.model + + if self.config.same_lr: + print("Using same LR") + if hasattr(m, 'core'): + m.core.unfreeze() + params = self.model.parameters() + else: + print("Using diff LR") + if not hasattr(m, 'get_lr_params'): + raise NotImplementedError( + f"Model {m.__class__.__name__} does not implement get_lr_params. Please implement it or use the same LR for all parameters.") + + params = m.get_lr_params(self.config.lr) + + return optim.AdamW(params, lr=self.config.lr, weight_decay=self.config.wd) + + def init_scheduler(self): + lrs = [l['lr'] for l in self.optimizer.param_groups] + return optim.lr_scheduler.OneCycleLR(self.optimizer, lrs, epochs=self.config.epochs, steps_per_epoch=len(self.train_loader), + cycle_momentum=self.config.cycle_momentum, + base_momentum=0.85, max_momentum=0.95, div_factor=self.config.div_factor, final_div_factor=self.config.final_div_factor, pct_start=self.config.pct_start, three_phase=self.config.three_phase) + + def train_on_batch(self, batch, train_step): + raise NotImplementedError + + def validate_on_batch(self, batch, val_step): + raise NotImplementedError + + def raise_if_nan(self, losses): + for key, value in losses.items(): + if torch.isnan(value): + raise ValueError(f"{key} is NaN, Stopping training") + + @property + def iters_per_epoch(self): + return len(self.train_loader) + + @property + def total_iters(self): + return self.config.epochs * self.iters_per_epoch + + def should_early_stop(self): + if self.config.get('early_stop', False) and self.step > self.config.early_stop: + return True + + def train(self): + print(f"Training {self.config.name}") + if self.config.uid is None: + self.config.uid = str(uuid.uuid4()).split('-')[-1] + run_id = f"{dt.now().strftime('%d-%h_%H-%M')}-{self.config.uid}" + self.config.run_id = run_id + self.config.experiment_id = f"{self.config.name}{self.config.version_name}_{run_id}" + self.should_write = ((not self.config.distributed) + or self.config.rank == 0) + self.should_log = self.should_write # and logging + if self.should_log: + tags = self.config.tags.split( + ',') if self.config.tags != '' else None + wandb.init(project=self.config.project, name=self.config.experiment_id, config=flatten(self.config), dir=self.config.root, + tags=tags, notes=self.config.notes, settings=wandb.Settings(start_method="fork")) + + self.model.train() + self.step = 0 + best_loss = np.inf + validate_every = int(self.config.validate_every * self.iters_per_epoch) + + + if self.config.prefetch: + + for i, batch in tqdm(enumerate(self.train_loader), desc=f"Prefetching...", + total=self.iters_per_epoch) if is_rank_zero(self.config) else enumerate(self.train_loader): + pass + + losses = {} + def stringify_losses(L): return "; ".join(map( + lambda kv: f"{colors.fg.purple}{kv[0]}{colors.reset}: {round(kv[1].item(),3):.4e}", L.items())) + for epoch in range(self.config.epochs): + if self.should_early_stop(): + break + + self.epoch = epoch + ################################# Train loop ########################################################## + if self.should_log: + wandb.log({"Epoch": epoch}, step=self.step) + pbar = tqdm(enumerate(self.train_loader), desc=f"Epoch: {epoch + 1}/{self.config.epochs}. Loop: Train", + total=self.iters_per_epoch) if is_rank_zero(self.config) else enumerate(self.train_loader) + for i, batch in pbar: + if self.should_early_stop(): + print("Early stopping") + break + # print(f"Batch {self.step+1} on rank {self.config.rank}") + losses = self.train_on_batch(batch, i) + # print(f"trained batch {self.step+1} on rank {self.config.rank}") + + self.raise_if_nan(losses) + if is_rank_zero(self.config) and self.config.print_losses: + pbar.set_description( + f"Epoch: {epoch + 1}/{self.config.epochs}. Loop: Train. Losses: {stringify_losses(losses)}") + self.scheduler.step() + + if self.should_log and self.step % 50 == 0: + wandb.log({f"Train/{name}": loss.item() + for name, loss in losses.items()}, step=self.step) + + self.step += 1 + + ######################################################################################################## + + if self.test_loader: + if (self.step % validate_every) == 0: + self.model.eval() + if self.should_write: + self.save_checkpoint( + f"{self.config.experiment_id}_latest.pt") + + ################################# Validation loop ################################################## + # validate on the entire validation set in every process but save only from rank 0, I know, inefficient, but avoids divergence of processes + metrics, test_losses = self.validate() + # print("Validated: {}".format(metrics)) + if self.should_log: + wandb.log( + {f"Test/{name}": tloss for name, tloss in test_losses.items()}, step=self.step) + + wandb.log({f"Metrics/{k}": v for k, + v in metrics.items()}, step=self.step) + + if (metrics[self.metric_criterion] < best_loss) and self.should_write: + self.save_checkpoint( + f"{self.config.experiment_id}_best.pt") + best_loss = metrics[self.metric_criterion] + + self.model.train() + + if self.config.distributed: + dist.barrier() + # print(f"Validated: {metrics} on device {self.config.rank}") + + # print(f"Finished step {self.step} on device {self.config.rank}") + ################################################################################################# + + # Save / validate at the end + self.step += 1 # log as final point + self.model.eval() + self.save_checkpoint(f"{self.config.experiment_id}_latest.pt") + if self.test_loader: + + ################################# Validation loop ################################################## + metrics, test_losses = self.validate() + # print("Validated: {}".format(metrics)) + if self.should_log: + wandb.log({f"Test/{name}": tloss for name, + tloss in test_losses.items()}, step=self.step) + wandb.log({f"Metrics/{k}": v for k, + v in metrics.items()}, step=self.step) + + if (metrics[self.metric_criterion] < best_loss) and self.should_write: + self.save_checkpoint( + f"{self.config.experiment_id}_best.pt") + best_loss = metrics[self.metric_criterion] + + self.model.train() + + def validate(self): + with torch.no_grad(): + losses_avg = RunningAverageDict() + metrics_avg = RunningAverageDict() + for i, batch in tqdm(enumerate(self.test_loader), desc=f"Epoch: {self.epoch + 1}/{self.config.epochs}. Loop: Validation", total=len(self.test_loader), disable=not is_rank_zero(self.config)): + metrics, losses = self.validate_on_batch(batch, val_step=i) + + if losses: + losses_avg.update(losses) + if metrics: + metrics_avg.update(metrics) + + return metrics_avg.get_value(), losses_avg.get_value() + + def save_checkpoint(self, filename): + if not self.should_write: + return + root = self.config.save_dir + if not os.path.isdir(root): + os.makedirs(root) + + fpath = os.path.join(root, filename) + m = self.model.module if self.config.multigpu else self.model + torch.save( + { + "model": m.state_dict(), + "optimizer": None, # TODO : Change to self.optimizer.state_dict() if resume support is needed, currently None to reduce file size + "epoch": self.epoch + }, fpath) + + def log_images(self, rgb: Dict[str, list] = {}, depth: Dict[str, list] = {}, scalar_field: Dict[str, list] = {}, prefix="", scalar_cmap="jet", min_depth=None, max_depth=None): + if not self.should_log: + return + + if min_depth is None: + try: + min_depth = self.config.min_depth + max_depth = self.config.max_depth + except AttributeError: + min_depth = None + max_depth = None + + depth = {k: colorize(v, vmin=min_depth, vmax=max_depth) + for k, v in depth.items()} + scalar_field = {k: colorize( + v, vmin=None, vmax=None, cmap=scalar_cmap) for k, v in scalar_field.items()} + images = {**rgb, **depth, **scalar_field} + wimages = { + prefix+"Predictions": [wandb.Image(v, caption=k) for k, v in images.items()]} + wandb.log(wimages, step=self.step) + + def log_line_plot(self, data): + if not self.should_log: + return + + plt.plot(data) + plt.ylabel("Scale factors") + wandb.log({"Scale factors": wandb.Image(plt)}, step=self.step) + plt.close() + + def log_bar_plot(self, title, labels, values): + if not self.should_log: + return + + data = [[label, val] for (label, val) in zip(labels, values)] + table = wandb.Table(data=data, columns=["label", "value"]) + wandb.log({title: wandb.plot.bar(table, "label", + "value", title=title)}, step=self.step) diff --git a/third_party/flux/annotator/zoe/zoedepth/trainers/builder.py b/third_party/flux/annotator/zoe/zoedepth/trainers/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..a663541b08912ebedce21a68c7599ce4c06e85d0 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/trainers/builder.py @@ -0,0 +1,48 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +from importlib import import_module + + +def get_trainer(config): + """Builds and returns a trainer based on the config. + + Args: + config (dict): the config dict (typically constructed using utils.config.get_config) + config.trainer (str): the name of the trainer to use. The module named "{config.trainer}_trainer" must exist in trainers root module + + Raises: + ValueError: If the specified trainer does not exist under trainers/ folder + + Returns: + Trainer (inherited from zoedepth.trainers.BaseTrainer): The Trainer object + """ + assert "trainer" in config and config.trainer is not None and config.trainer != '', "Trainer not specified. Config: {0}".format( + config) + try: + Trainer = getattr(import_module( + f"zoedepth.trainers.{config.trainer}_trainer"), 'Trainer') + except ModuleNotFoundError as e: + raise ValueError(f"Trainer {config.trainer}_trainer not found.") from e + return Trainer diff --git a/third_party/flux/annotator/zoe/zoedepth/trainers/loss.py b/third_party/flux/annotator/zoe/zoedepth/trainers/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..0c5a1c15cdf5628c1474c566fdc6e58159d7f5ab --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/trainers/loss.py @@ -0,0 +1,316 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.cuda.amp as amp +import numpy as np + + +KEY_OUTPUT = 'metric_depth' + + +def extract_key(prediction, key): + if isinstance(prediction, dict): + return prediction[key] + return prediction + + +# Main loss function used for ZoeDepth. Copy/paste from AdaBins repo (https://github.com/shariqfarooq123/AdaBins/blob/0952d91e9e762be310bb4cd055cbfe2448c0ce20/loss.py#L7) +class SILogLoss(nn.Module): + """SILog loss (pixel-wise)""" + def __init__(self, beta=0.15): + super(SILogLoss, self).__init__() + self.name = 'SILog' + self.beta = beta + + def forward(self, input, target, mask=None, interpolate=True, return_interpolated=False): + input = extract_key(input, KEY_OUTPUT) + if input.shape[-1] != target.shape[-1] and interpolate: + input = nn.functional.interpolate( + input, target.shape[-2:], mode='bilinear', align_corners=True) + intr_input = input + else: + intr_input = input + + if target.ndim == 3: + target = target.unsqueeze(1) + + if mask is not None: + if mask.ndim == 3: + mask = mask.unsqueeze(1) + + input = input[mask] + target = target[mask] + + with amp.autocast(enabled=False): # amp causes NaNs in this loss function + alpha = 1e-7 + g = torch.log(input + alpha) - torch.log(target + alpha) + + # n, c, h, w = g.shape + # norm = 1/(h*w) + # Dg = norm * torch.sum(g**2) - (0.85/(norm**2)) * (torch.sum(g))**2 + + Dg = torch.var(g) + self.beta * torch.pow(torch.mean(g), 2) + + loss = 10 * torch.sqrt(Dg) + + if torch.isnan(loss): + print("Nan SILog loss") + print("input:", input.shape) + print("target:", target.shape) + print("G", torch.sum(torch.isnan(g))) + print("Input min max", torch.min(input), torch.max(input)) + print("Target min max", torch.min(target), torch.max(target)) + print("Dg", torch.isnan(Dg)) + print("loss", torch.isnan(loss)) + + if not return_interpolated: + return loss + + return loss, intr_input + + +def grad(x): + # x.shape : n, c, h, w + diff_x = x[..., 1:, 1:] - x[..., 1:, :-1] + diff_y = x[..., 1:, 1:] - x[..., :-1, 1:] + mag = diff_x**2 + diff_y**2 + # angle_ratio + angle = torch.atan(diff_y / (diff_x + 1e-10)) + return mag, angle + + +def grad_mask(mask): + return mask[..., 1:, 1:] & mask[..., 1:, :-1] & mask[..., :-1, 1:] + + +class GradL1Loss(nn.Module): + """Gradient loss""" + def __init__(self): + super(GradL1Loss, self).__init__() + self.name = 'GradL1' + + def forward(self, input, target, mask=None, interpolate=True, return_interpolated=False): + input = extract_key(input, KEY_OUTPUT) + if input.shape[-1] != target.shape[-1] and interpolate: + input = nn.functional.interpolate( + input, target.shape[-2:], mode='bilinear', align_corners=True) + intr_input = input + else: + intr_input = input + + grad_gt = grad(target) + grad_pred = grad(input) + mask_g = grad_mask(mask) + + loss = nn.functional.l1_loss(grad_pred[0][mask_g], grad_gt[0][mask_g]) + loss = loss + \ + nn.functional.l1_loss(grad_pred[1][mask_g], grad_gt[1][mask_g]) + if not return_interpolated: + return loss + return loss, intr_input + + +class OrdinalRegressionLoss(object): + + def __init__(self, ord_num, beta, discretization="SID"): + self.ord_num = ord_num + self.beta = beta + self.discretization = discretization + + def _create_ord_label(self, gt): + N,one, H, W = gt.shape + # print("gt shape:", gt.shape) + + ord_c0 = torch.ones(N, self.ord_num, H, W).to(gt.device) + if self.discretization == "SID": + label = self.ord_num * torch.log(gt) / np.log(self.beta) + else: + label = self.ord_num * (gt - 1.0) / (self.beta - 1.0) + label = label.long() + mask = torch.linspace(0, self.ord_num - 1, self.ord_num, requires_grad=False) \ + .view(1, self.ord_num, 1, 1).to(gt.device) + mask = mask.repeat(N, 1, H, W).contiguous().long() + mask = (mask > label) + ord_c0[mask] = 0 + ord_c1 = 1 - ord_c0 + # implementation according to the paper. + # ord_label = torch.ones(N, self.ord_num * 2, H, W).to(gt.device) + # ord_label[:, 0::2, :, :] = ord_c0 + # ord_label[:, 1::2, :, :] = ord_c1 + # reimplementation for fast speed. + ord_label = torch.cat((ord_c0, ord_c1), dim=1) + return ord_label, mask + + def __call__(self, prob, gt): + """ + :param prob: ordinal regression probability, N x 2*Ord Num x H x W, torch.Tensor + :param gt: depth ground truth, NXHxW, torch.Tensor + :return: loss: loss value, torch.float + """ + # N, C, H, W = prob.shape + valid_mask = gt > 0. + ord_label, mask = self._create_ord_label(gt) + # print("prob shape: {}, ord label shape: {}".format(prob.shape, ord_label.shape)) + entropy = -prob * ord_label + loss = torch.sum(entropy, dim=1)[valid_mask.squeeze(1)] + return loss.mean() + + +class DiscreteNLLLoss(nn.Module): + """Cross entropy loss""" + def __init__(self, min_depth=1e-3, max_depth=10, depth_bins=64): + super(DiscreteNLLLoss, self).__init__() + self.name = 'CrossEntropy' + self.ignore_index = -(depth_bins + 1) + # self._loss_func = nn.NLLLoss(ignore_index=self.ignore_index) + self._loss_func = nn.CrossEntropyLoss(ignore_index=self.ignore_index) + self.min_depth = min_depth + self.max_depth = max_depth + self.depth_bins = depth_bins + self.alpha = 1 + self.zeta = 1 - min_depth + self.beta = max_depth + self.zeta + + def quantize_depth(self, depth): + # depth : N1HW + # output : NCHW + + # Quantize depth log-uniformly on [1, self.beta] into self.depth_bins bins + depth = torch.log(depth / self.alpha) / np.log(self.beta / self.alpha) + depth = depth * (self.depth_bins - 1) + depth = torch.round(depth) + depth = depth.long() + return depth + + + + def _dequantize_depth(self, depth): + """ + Inverse of quantization + depth : NCHW -> N1HW + """ + # Get the center of the bin + + + + + def forward(self, input, target, mask=None, interpolate=True, return_interpolated=False): + input = extract_key(input, KEY_OUTPUT) + # assert torch.all(input <= 0), "Input should be negative" + + if input.shape[-1] != target.shape[-1] and interpolate: + input = nn.functional.interpolate( + input, target.shape[-2:], mode='bilinear', align_corners=True) + intr_input = input + else: + intr_input = input + + # assert torch.all(input)<=1) + if target.ndim == 3: + target = target.unsqueeze(1) + + target = self.quantize_depth(target) + if mask is not None: + if mask.ndim == 3: + mask = mask.unsqueeze(1) + + # Set the mask to ignore_index + mask = mask.long() + input = input * mask + (1 - mask) * self.ignore_index + target = target * mask + (1 - mask) * self.ignore_index + + + + input = input.flatten(2) # N, nbins, H*W + target = target.flatten(1) # N, H*W + loss = self._loss_func(input, target) + + if not return_interpolated: + return loss + return loss, intr_input + + + + +def compute_scale_and_shift(prediction, target, mask): + # system matrix: A = [[a_00, a_01], [a_10, a_11]] + a_00 = torch.sum(mask * prediction * prediction, (1, 2)) + a_01 = torch.sum(mask * prediction, (1, 2)) + a_11 = torch.sum(mask, (1, 2)) + + # right hand side: b = [b_0, b_1] + b_0 = torch.sum(mask * prediction * target, (1, 2)) + b_1 = torch.sum(mask * target, (1, 2)) + + # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b + x_0 = torch.zeros_like(b_0) + x_1 = torch.zeros_like(b_1) + + det = a_00 * a_11 - a_01 * a_01 + # A needs to be a positive definite matrix. + valid = det > 0 + + x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid] + x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid] + + return x_0, x_1 +class ScaleAndShiftInvariantLoss(nn.Module): + def __init__(self): + super().__init__() + self.name = "SSILoss" + + def forward(self, prediction, target, mask, interpolate=True, return_interpolated=False): + + if prediction.shape[-1] != target.shape[-1] and interpolate: + prediction = nn.functional.interpolate(prediction, target.shape[-2:], mode='bilinear', align_corners=True) + intr_input = prediction + else: + intr_input = prediction + + + prediction, target, mask = prediction.squeeze(), target.squeeze(), mask.squeeze() + assert prediction.shape == target.shape, f"Shape mismatch: Expected same shape but got {prediction.shape} and {target.shape}." + + scale, shift = compute_scale_and_shift(prediction, target, mask) + + scaled_prediction = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1) + + loss = nn.functional.l1_loss(scaled_prediction[mask], target[mask]) + if not return_interpolated: + return loss + return loss, intr_input + + + + +if __name__ == '__main__': + # Tests for DiscreteNLLLoss + celoss = DiscreteNLLLoss() + print(celoss(torch.rand(4, 64, 26, 32)*10, torch.rand(4, 1, 26, 32)*10, )) + + d = torch.Tensor([6.59, 3.8, 10.0]) + print(celoss.dequantize_depth(celoss.quantize_depth(d))) diff --git a/third_party/flux/annotator/zoe/zoedepth/trainers/zoedepth_nk_trainer.py b/third_party/flux/annotator/zoe/zoedepth/trainers/zoedepth_nk_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..d528ae126f1c51b2f25fd31f94a39591ceb2f43a --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/trainers/zoedepth_nk_trainer.py @@ -0,0 +1,143 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.cuda.amp as amp +import torch.nn as nn + +from zoedepth.trainers.loss import GradL1Loss, SILogLoss +from zoedepth.utils.config import DATASETS_CONFIG +from zoedepth.utils.misc import compute_metrics + +from .base_trainer import BaseTrainer + + +class Trainer(BaseTrainer): + def __init__(self, config, model, train_loader, test_loader=None, device=None): + super().__init__(config, model, train_loader, + test_loader=test_loader, device=device) + self.device = device + self.silog_loss = SILogLoss() + self.grad_loss = GradL1Loss() + self.domain_classifier_loss = nn.CrossEntropyLoss() + + self.scaler = amp.GradScaler(enabled=self.config.use_amp) + + def train_on_batch(self, batch, train_step): + """ + Expects a batch of images and depth as input + batch["image"].shape : batch_size, c, h, w + batch["depth"].shape : batch_size, 1, h, w + + Assumes all images in a batch are from the same dataset + """ + + images, depths_gt = batch['image'].to( + self.device), batch['depth'].to(self.device) + # batch['dataset'] is a tensor strings all valued either 'nyu' or 'kitti'. labels nyu -> 0, kitti -> 1 + dataset = batch['dataset'][0] + # Convert to 0s or 1s + domain_labels = torch.Tensor([dataset == 'kitti' for _ in range( + images.size(0))]).to(torch.long).to(self.device) + + # m = self.model.module if self.config.multigpu else self.model + + b, c, h, w = images.size() + mask = batch["mask"].to(self.device).to(torch.bool) + + losses = {} + + with amp.autocast(enabled=self.config.use_amp): + output = self.model(images) + pred_depths = output['metric_depth'] + domain_logits = output['domain_logits'] + + l_si, pred = self.silog_loss( + pred_depths, depths_gt, mask=mask, interpolate=True, return_interpolated=True) + loss = self.config.w_si * l_si + losses[self.silog_loss.name] = l_si + + if self.config.w_grad > 0: + l_grad = self.grad_loss(pred, depths_gt, mask=mask) + loss = loss + self.config.w_grad * l_grad + losses[self.grad_loss.name] = l_grad + else: + l_grad = torch.Tensor([0]) + + if self.config.w_domain > 0: + l_domain = self.domain_classifier_loss( + domain_logits, domain_labels) + loss = loss + self.config.w_domain * l_domain + losses["DomainLoss"] = l_domain + else: + l_domain = torch.Tensor([0.]) + + self.scaler.scale(loss).backward() + + if self.config.clip_grad > 0: + self.scaler.unscale_(self.optimizer) + nn.utils.clip_grad_norm_( + self.model.parameters(), self.config.clip_grad) + + self.scaler.step(self.optimizer) + + if self.should_log and self.step > 1 and (self.step % int(self.config.log_images_every * self.iters_per_epoch)) == 0: + depths_gt[torch.logical_not(mask)] = -99 + self.log_images(rgb={"Input": images[0, ...]}, depth={"GT": depths_gt[0], "PredictedMono": pred[0]}, prefix="Train", + min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth']) + + self.scaler.update() + self.optimizer.zero_grad(set_to_none=True) + + return losses + + def validate_on_batch(self, batch, val_step): + images = batch['image'].to(self.device) + depths_gt = batch['depth'].to(self.device) + dataset = batch['dataset'][0] + if 'has_valid_depth' in batch: + if not batch['has_valid_depth']: + return None, None + + depths_gt = depths_gt.squeeze().unsqueeze(0).unsqueeze(0) + with amp.autocast(enabled=self.config.use_amp): + m = self.model.module if self.config.multigpu else self.model + pred_depths = m(images)["metric_depth"] + pred_depths = pred_depths.squeeze().unsqueeze(0).unsqueeze(0) + + mask = torch.logical_and( + depths_gt > self.config.min_depth, depths_gt < self.config.max_depth) + with amp.autocast(enabled=self.config.use_amp): + l_depth = self.silog_loss( + pred_depths, depths_gt, mask=mask.to(torch.bool), interpolate=True) + + metrics = compute_metrics(depths_gt, pred_depths, **self.config) + losses = {f"{self.silog_loss.name}": l_depth.item()} + + if val_step == 1 and self.should_log: + depths_gt[torch.logical_not(mask)] = -99 + self.log_images(rgb={"Input": images[0]}, depth={"GT": depths_gt[0], "PredictedMono": pred_depths[0]}, prefix="Test", + min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth']) + + return metrics, losses diff --git a/third_party/flux/annotator/zoe/zoedepth/trainers/zoedepth_trainer.py b/third_party/flux/annotator/zoe/zoedepth/trainers/zoedepth_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..3ac1c24c0512c1c1b191670a7c24abb4fca47ba1 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/trainers/zoedepth_trainer.py @@ -0,0 +1,177 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.cuda.amp as amp +import torch.nn as nn + +from zoedepth.trainers.loss import GradL1Loss, SILogLoss +from zoedepth.utils.config import DATASETS_CONFIG +from zoedepth.utils.misc import compute_metrics +from zoedepth.data.preprocess import get_black_border + +from .base_trainer import BaseTrainer +from torchvision import transforms +from PIL import Image +import numpy as np + +class Trainer(BaseTrainer): + def __init__(self, config, model, train_loader, test_loader=None, device=None): + super().__init__(config, model, train_loader, + test_loader=test_loader, device=device) + self.device = device + self.silog_loss = SILogLoss() + self.grad_loss = GradL1Loss() + self.scaler = amp.GradScaler(enabled=self.config.use_amp) + + def train_on_batch(self, batch, train_step): + """ + Expects a batch of images and depth as input + batch["image"].shape : batch_size, c, h, w + batch["depth"].shape : batch_size, 1, h, w + """ + + images, depths_gt = batch['image'].to( + self.device), batch['depth'].to(self.device) + dataset = batch['dataset'][0] + + b, c, h, w = images.size() + mask = batch["mask"].to(self.device).to(torch.bool) + + losses = {} + + with amp.autocast(enabled=self.config.use_amp): + + output = self.model(images) + pred_depths = output['metric_depth'] + + l_si, pred = self.silog_loss( + pred_depths, depths_gt, mask=mask, interpolate=True, return_interpolated=True) + loss = self.config.w_si * l_si + losses[self.silog_loss.name] = l_si + + if self.config.w_grad > 0: + l_grad = self.grad_loss(pred, depths_gt, mask=mask) + loss = loss + self.config.w_grad * l_grad + losses[self.grad_loss.name] = l_grad + else: + l_grad = torch.Tensor([0]) + + self.scaler.scale(loss).backward() + + if self.config.clip_grad > 0: + self.scaler.unscale_(self.optimizer) + nn.utils.clip_grad_norm_( + self.model.parameters(), self.config.clip_grad) + + self.scaler.step(self.optimizer) + + if self.should_log and (self.step % int(self.config.log_images_every * self.iters_per_epoch)) == 0: + # -99 is treated as invalid depth in the log_images function and is colored grey. + depths_gt[torch.logical_not(mask)] = -99 + + self.log_images(rgb={"Input": images[0, ...]}, depth={"GT": depths_gt[0], "PredictedMono": pred[0]}, prefix="Train", + min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth']) + + if self.config.get("log_rel", False): + self.log_images( + scalar_field={"RelPred": output["relative_depth"][0]}, prefix="TrainRel") + + self.scaler.update() + self.optimizer.zero_grad() + + return losses + + @torch.no_grad() + def eval_infer(self, x): + with amp.autocast(enabled=self.config.use_amp): + m = self.model.module if self.config.multigpu else self.model + pred_depths = m(x)['metric_depth'] + return pred_depths + + @torch.no_grad() + def crop_aware_infer(self, x): + # if we are not avoiding the black border, we can just use the normal inference + if not self.config.get("avoid_boundary", False): + return self.eval_infer(x) + + # otherwise, we need to crop the image to avoid the black border + # For now, this may be a bit slow due to converting to numpy and back + # We assume no normalization is done on the input image + + # get the black border + assert x.shape[0] == 1, "Only batch size 1 is supported for now" + x_pil = transforms.ToPILImage()(x[0].cpu()) + x_np = np.array(x_pil, dtype=np.uint8) + black_border_params = get_black_border(x_np) + top, bottom, left, right = black_border_params.top, black_border_params.bottom, black_border_params.left, black_border_params.right + x_np_cropped = x_np[top:bottom, left:right, :] + x_cropped = transforms.ToTensor()(Image.fromarray(x_np_cropped)) + + # run inference on the cropped image + pred_depths_cropped = self.eval_infer(x_cropped.unsqueeze(0).to(self.device)) + + # resize the prediction to x_np_cropped's size + pred_depths_cropped = nn.functional.interpolate( + pred_depths_cropped, size=(x_np_cropped.shape[0], x_np_cropped.shape[1]), mode="bilinear", align_corners=False) + + + # pad the prediction back to the original size + pred_depths = torch.zeros((1, 1, x_np.shape[0], x_np.shape[1]), device=pred_depths_cropped.device, dtype=pred_depths_cropped.dtype) + pred_depths[:, :, top:bottom, left:right] = pred_depths_cropped + + return pred_depths + + + + def validate_on_batch(self, batch, val_step): + images = batch['image'].to(self.device) + depths_gt = batch['depth'].to(self.device) + dataset = batch['dataset'][0] + mask = batch["mask"].to(self.device) + if 'has_valid_depth' in batch: + if not batch['has_valid_depth']: + return None, None + + depths_gt = depths_gt.squeeze().unsqueeze(0).unsqueeze(0) + mask = mask.squeeze().unsqueeze(0).unsqueeze(0) + if dataset == 'nyu': + pred_depths = self.crop_aware_infer(images) + else: + pred_depths = self.eval_infer(images) + pred_depths = pred_depths.squeeze().unsqueeze(0).unsqueeze(0) + + with amp.autocast(enabled=self.config.use_amp): + l_depth = self.silog_loss( + pred_depths, depths_gt, mask=mask.to(torch.bool), interpolate=True) + + metrics = compute_metrics(depths_gt, pred_depths, **self.config) + losses = {f"{self.silog_loss.name}": l_depth.item()} + + if val_step == 1 and self.should_log: + depths_gt[torch.logical_not(mask)] = -99 + self.log_images(rgb={"Input": images[0]}, depth={"GT": depths_gt[0], "PredictedMono": pred_depths[0]}, prefix="Test", + min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth']) + + return metrics, losses diff --git a/third_party/flux/annotator/zoe/zoedepth/utils/__init__.py b/third_party/flux/annotator/zoe/zoedepth/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/utils/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/third_party/flux/annotator/zoe/zoedepth/utils/__pycache__/__init__.cpython-310.pyc b/third_party/flux/annotator/zoe/zoedepth/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebda26a0c03f07b9ef46d3b64cbc5f3aa1eecadb Binary files /dev/null and b/third_party/flux/annotator/zoe/zoedepth/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/flux/annotator/zoe/zoedepth/utils/__pycache__/arg_utils.cpython-310.pyc b/third_party/flux/annotator/zoe/zoedepth/utils/__pycache__/arg_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8426b8b2e363309f73556059765db163fff28492 Binary files /dev/null and b/third_party/flux/annotator/zoe/zoedepth/utils/__pycache__/arg_utils.cpython-310.pyc differ diff --git a/third_party/flux/annotator/zoe/zoedepth/utils/__pycache__/config.cpython-310.pyc b/third_party/flux/annotator/zoe/zoedepth/utils/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d19de6ac31783618a5aa1386b3b1a57d298ad586 Binary files /dev/null and b/third_party/flux/annotator/zoe/zoedepth/utils/__pycache__/config.cpython-310.pyc differ diff --git a/third_party/flux/annotator/zoe/zoedepth/utils/arg_utils.py b/third_party/flux/annotator/zoe/zoedepth/utils/arg_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8a3004ec3679c0a40fd8961253733fb4343ad545 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/utils/arg_utils.py @@ -0,0 +1,33 @@ + + +def infer_type(x): # hacky way to infer type from string args + if not isinstance(x, str): + return x + + try: + x = int(x) + return x + except ValueError: + pass + + try: + x = float(x) + return x + except ValueError: + pass + + return x + + +def parse_unknown(unknown_args): + clean = [] + for a in unknown_args: + if "=" in a: + k, v = a.split("=") + clean.extend([k, v]) + else: + clean.append(a) + + keys = clean[::2] + values = clean[1::2] + return {k.replace("--", ""): infer_type(v) for k, v in zip(keys, values)} diff --git a/third_party/flux/annotator/zoe/zoedepth/utils/config.py b/third_party/flux/annotator/zoe/zoedepth/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..84996564663dadf0e720de2a68ef8c53106ed666 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/utils/config.py @@ -0,0 +1,437 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import json +import os + +from .easydict import EasyDict as edict +from .arg_utils import infer_type + +import pathlib +import platform + +ROOT = pathlib.Path(__file__).parent.parent.resolve() + +HOME_DIR = os.path.expanduser("~") + +COMMON_CONFIG = { + "save_dir": os.path.expanduser("~/shortcuts/monodepth3_checkpoints"), + "project": "ZoeDepth", + "tags": '', + "notes": "", + "gpu": None, + "root": ".", + "uid": None, + "print_losses": False +} + +DATASETS_CONFIG = { + "kitti": { + "dataset": "kitti", + "min_depth": 0.001, + "max_depth": 80, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt", + "input_height": 352, + "input_width": 1216, # 704 + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt", + + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + + "do_random_rotate": True, + "degree": 1.0, + "do_kb_crop": True, + "garg_crop": True, + "eigen_crop": False, + "use_right": False + }, + "kitti_test": { + "dataset": "kitti", + "min_depth": 0.001, + "max_depth": 80, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt", + "input_height": 352, + "input_width": 1216, + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt", + + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + + "do_random_rotate": False, + "degree": 1.0, + "do_kb_crop": True, + "garg_crop": True, + "eigen_crop": False, + "use_right": False + }, + "nyu": { + "dataset": "nyu", + "avoid_boundary": False, + "min_depth": 1e-3, # originally 0.1 + "max_depth": 10, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"), + "filenames_file": "./train_test_inputs/nyudepthv2_train_files_with_gt.txt", + "input_height": 480, + "input_width": 640, + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"), + "filenames_file_eval": "./train_test_inputs/nyudepthv2_test_files_with_gt.txt", + "min_depth_eval": 1e-3, + "max_depth_eval": 10, + "min_depth_diff": -10, + "max_depth_diff": 10, + + "do_random_rotate": True, + "degree": 1.0, + "do_kb_crop": False, + "garg_crop": False, + "eigen_crop": True + }, + "ibims": { + "dataset": "ibims", + "ibims_root": os.path.join(HOME_DIR, "shortcuts/datasets/ibims/ibims1_core_raw/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "sunrgbd": { + "dataset": "sunrgbd", + "sunrgbd_root": os.path.join(HOME_DIR, "shortcuts/datasets/SUNRGBD/test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 8, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diml_indoor": { + "dataset": "diml_indoor", + "diml_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_indoor_test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diml_outdoor": { + "dataset": "diml_outdoor", + "diml_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_outdoor_test/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": False, + "min_depth_eval": 2, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "diode_indoor": { + "dataset": "diode_indoor", + "diode_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_indoor/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diode_outdoor": { + "dataset": "diode_outdoor", + "diode_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_outdoor/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "hypersim_test": { + "dataset": "hypersim_test", + "hypersim_test_root": os.path.join(HOME_DIR, "shortcuts/datasets/hypersim_test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 10 + }, + "vkitti": { + "dataset": "vkitti", + "vkitti_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti_test/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "vkitti2": { + "dataset": "vkitti2", + "vkitti2_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti2/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80, + }, + "ddad": { + "dataset": "ddad", + "ddad_root": os.path.join(HOME_DIR, "shortcuts/datasets/ddad/ddad_val/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80, + }, +} + +ALL_INDOOR = ["nyu", "ibims", "sunrgbd", "diode_indoor", "hypersim_test"] +ALL_OUTDOOR = ["kitti", "diml_outdoor", "diode_outdoor", "vkitti2", "ddad"] +ALL_EVAL_DATASETS = ALL_INDOOR + ALL_OUTDOOR + +COMMON_TRAINING_CONFIG = { + "dataset": "nyu", + "distributed": True, + "workers": 16, + "clip_grad": 0.1, + "use_shared_dict": False, + "shared_dict": None, + "use_amp": False, + + "aug": True, + "random_crop": False, + "random_translate": False, + "translate_prob": 0.2, + "max_translation": 100, + + "validate_every": 0.25, + "log_images_every": 0.1, + "prefetch": False, +} + + +def flatten(config, except_keys=('bin_conf')): + def recurse(inp): + if isinstance(inp, dict): + for key, value in inp.items(): + if key in except_keys: + yield (key, value) + if isinstance(value, dict): + yield from recurse(value) + else: + yield (key, value) + + return dict(list(recurse(config))) + + +def split_combined_args(kwargs): + """Splits the arguments that are combined with '__' into multiple arguments. + Combined arguments should have equal number of keys and values. + Keys are separated by '__' and Values are separated with ';'. + For example, '__n_bins__lr=256;0.001' + + Args: + kwargs (dict): key-value pairs of arguments where key-value is optionally combined according to the above format. + + Returns: + dict: Parsed dict with the combined arguments split into individual key-value pairs. + """ + new_kwargs = dict(kwargs) + for key, value in kwargs.items(): + if key.startswith("__"): + keys = key.split("__")[1:] + values = value.split(";") + assert len(keys) == len( + values), f"Combined arguments should have equal number of keys and values. Keys are separated by '__' and Values are separated with ';'. For example, '__n_bins__lr=256;0.001. Given (keys,values) is ({keys}, {values})" + for k, v in zip(keys, values): + new_kwargs[k] = v + return new_kwargs + + +def parse_list(config, key, dtype=int): + """Parse a list of values for the key if the value is a string. The values are separated by a comma. + Modifies the config in place. + """ + if key in config: + if isinstance(config[key], str): + config[key] = list(map(dtype, config[key].split(','))) + assert isinstance(config[key], list) and all([isinstance(e, dtype) for e in config[key]] + ), f"{key} should be a list of values dtype {dtype}. Given {config[key]} of type {type(config[key])} with values of type {[type(e) for e in config[key]]}." + + +def get_model_config(model_name, model_version=None): + """Find and parse the .json config file for the model. + + Args: + model_name (str): name of the model. The config file should be named config_{model_name}[_{model_version}].json under the models/{model_name} directory. + model_version (str, optional): Specific config version. If specified config_{model_name}_{model_version}.json is searched for and used. Otherwise config_{model_name}.json is used. Defaults to None. + + Returns: + easydict: the config dictionary for the model. + """ + config_fname = f"config_{model_name}_{model_version}.json" if model_version is not None else f"config_{model_name}.json" + config_file = os.path.join(ROOT, "models", model_name, config_fname) + if not os.path.exists(config_file): + return None + + with open(config_file, "r") as f: + config = edict(json.load(f)) + + # handle dictionary inheritance + # only training config is supported for inheritance + if "inherit" in config.train and config.train.inherit is not None: + inherit_config = get_model_config(config.train["inherit"]).train + for key, value in inherit_config.items(): + if key not in config.train: + config.train[key] = value + return edict(config) + + +def update_model_config(config, mode, model_name, model_version=None, strict=False): + model_config = get_model_config(model_name, model_version) + if model_config is not None: + config = {**config, ** + flatten({**model_config.model, **model_config[mode]})} + elif strict: + raise ValueError(f"Config file for model {model_name} not found.") + return config + + +def check_choices(name, value, choices): + # return # No checks in dev branch + if value not in choices: + raise ValueError(f"{name} {value} not in supported choices {choices}") + + +KEYS_TYPE_BOOL = ["use_amp", "distributed", "use_shared_dict", "same_lr", "aug", "three_phase", + "prefetch", "cycle_momentum"] # Casting is not necessary as their int casted values in config are 0 or 1 + + +def get_config(model_name, mode='train', dataset=None, **overwrite_kwargs): + """Main entry point to get the config for the model. + + Args: + model_name (str): name of the desired model. + mode (str, optional): "train" or "infer". Defaults to 'train'. + dataset (str, optional): If specified, the corresponding dataset configuration is loaded as well. Defaults to None. + + Keyword Args: key-value pairs of arguments to overwrite the default config. + + The order of precedence for overwriting the config is (Higher precedence first): + # 1. overwrite_kwargs + # 2. "config_version": Config file version if specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{config_version}.json + # 3. "version_name": Default Model version specific config specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{version_name}.json + # 4. common_config: Default config for all models specified in COMMON_CONFIG + + Returns: + easydict: The config dictionary for the model. + """ + + + check_choices("Model", model_name, ["zoedepth", "zoedepth_nk"]) + check_choices("Mode", mode, ["train", "infer", "eval"]) + if mode == "train": + check_choices("Dataset", dataset, ["nyu", "kitti", "mix", None]) + + config = flatten({**COMMON_CONFIG, **COMMON_TRAINING_CONFIG}) + config = update_model_config(config, mode, model_name) + + # update with model version specific config + version_name = overwrite_kwargs.get("version_name", config["version_name"]) + config = update_model_config(config, mode, model_name, version_name) + + # update with config version if specified + config_version = overwrite_kwargs.get("config_version", None) + if config_version is not None: + print("Overwriting config with config_version", config_version) + config = update_model_config(config, mode, model_name, config_version) + + # update with overwrite_kwargs + # Combined args are useful for hyperparameter search + overwrite_kwargs = split_combined_args(overwrite_kwargs) + config = {**config, **overwrite_kwargs} + + # Casting to bool # TODO: Not necessary. Remove and test + for key in KEYS_TYPE_BOOL: + if key in config: + config[key] = bool(config[key]) + + # Model specific post processing of config + parse_list(config, "n_attractors") + + # adjust n_bins for each bin configuration if bin_conf is given and n_bins is passed in overwrite_kwargs + if 'bin_conf' in config and 'n_bins' in overwrite_kwargs: + bin_conf = config['bin_conf'] # list of dicts + n_bins = overwrite_kwargs['n_bins'] + new_bin_conf = [] + for conf in bin_conf: + conf['n_bins'] = n_bins + new_bin_conf.append(conf) + config['bin_conf'] = new_bin_conf + + if mode == "train": + orig_dataset = dataset + if dataset == "mix": + dataset = 'nyu' # Use nyu as default for mix. Dataset config is changed accordingly while loading the dataloader + if dataset is not None: + config['project'] = f"MonoDepth3-{orig_dataset}" # Set project for wandb + + if dataset is not None: + config['dataset'] = dataset + config = {**DATASETS_CONFIG[dataset], **config} + + + config['model'] = model_name + typed_config = {k: infer_type(v) for k, v in config.items()} + # add hostname to config + config['hostname'] = platform.node() + return edict(typed_config) + + +def change_dataset(config, new_dataset): + config.update(DATASETS_CONFIG[new_dataset]) + return config diff --git a/third_party/flux/annotator/zoe/zoedepth/utils/easydict/__init__.py b/third_party/flux/annotator/zoe/zoedepth/utils/easydict/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15928179b0182c6045d98bc0a7be1c6ca45f675e --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/utils/easydict/__init__.py @@ -0,0 +1,158 @@ +""" +EasyDict +Copy/pasted from https://github.com/makinacorpus/easydict +Original author: Mathieu Leplatre +""" + +class EasyDict(dict): + """ + Get attributes + + >>> d = EasyDict({'foo':3}) + >>> d['foo'] + 3 + >>> d.foo + 3 + >>> d.bar + Traceback (most recent call last): + ... + AttributeError: 'EasyDict' object has no attribute 'bar' + + Works recursively + + >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}}) + >>> isinstance(d.bar, dict) + True + >>> d.bar.x + 1 + + Bullet-proof + + >>> EasyDict({}) + {} + >>> EasyDict(d={}) + {} + >>> EasyDict(None) + {} + >>> d = {'a': 1} + >>> EasyDict(**d) + {'a': 1} + >>> EasyDict((('a', 1), ('b', 2))) + {'a': 1, 'b': 2} + + Set attributes + + >>> d = EasyDict() + >>> d.foo = 3 + >>> d.foo + 3 + >>> d.bar = {'prop': 'value'} + >>> d.bar.prop + 'value' + >>> d + {'foo': 3, 'bar': {'prop': 'value'}} + >>> d.bar.prop = 'newer' + >>> d.bar.prop + 'newer' + + + Values extraction + + >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]}) + >>> isinstance(d.bar, list) + True + >>> from operator import attrgetter + >>> list(map(attrgetter('x'), d.bar)) + [1, 3] + >>> list(map(attrgetter('y'), d.bar)) + [2, 4] + >>> d = EasyDict() + >>> list(d.keys()) + [] + >>> d = EasyDict(foo=3, bar=dict(x=1, y=2)) + >>> d.foo + 3 + >>> d.bar.x + 1 + + Still like a dict though + + >>> o = EasyDict({'clean':True}) + >>> list(o.items()) + [('clean', True)] + + And like a class + + >>> class Flower(EasyDict): + ... power = 1 + ... + >>> f = Flower() + >>> f.power + 1 + >>> f = Flower({'height': 12}) + >>> f.height + 12 + >>> f['power'] + 1 + >>> sorted(f.keys()) + ['height', 'power'] + + update and pop items + >>> d = EasyDict(a=1, b='2') + >>> e = EasyDict(c=3.0, a=9.0) + >>> d.update(e) + >>> d.c + 3.0 + >>> d['c'] + 3.0 + >>> d.get('c') + 3.0 + >>> d.update(a=4, b=4) + >>> d.b + 4 + >>> d.pop('a') + 4 + >>> d.a + Traceback (most recent call last): + ... + AttributeError: 'EasyDict' object has no attribute 'a' + """ + def __init__(self, d=None, **kwargs): + if d is None: + d = {} + else: + d = dict(d) + if kwargs: + d.update(**kwargs) + for k, v in d.items(): + setattr(self, k, v) + # Class attributes + for k in self.__class__.__dict__.keys(): + if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'): + setattr(self, k, getattr(self, k)) + + def __setattr__(self, name, value): + if isinstance(value, (list, tuple)): + value = [self.__class__(x) + if isinstance(x, dict) else x for x in value] + elif isinstance(value, dict) and not isinstance(value, self.__class__): + value = self.__class__(value) + super(EasyDict, self).__setattr__(name, value) + super(EasyDict, self).__setitem__(name, value) + + __setitem__ = __setattr__ + + def update(self, e=None, **f): + d = e or dict() + d.update(f) + for k in d: + setattr(self, k, d[k]) + + def pop(self, k, d=None): + delattr(self, k) + return super(EasyDict, self).pop(k, d) + + +if __name__ == "__main__": + import doctest + doctest.testmod() \ No newline at end of file diff --git a/third_party/flux/annotator/zoe/zoedepth/utils/easydict/__pycache__/__init__.cpython-310.pyc b/third_party/flux/annotator/zoe/zoedepth/utils/easydict/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c28f57320be0e23bbad110abd6d419e4cfcdb6e0 Binary files /dev/null and b/third_party/flux/annotator/zoe/zoedepth/utils/easydict/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/flux/annotator/zoe/zoedepth/utils/geometry.py b/third_party/flux/annotator/zoe/zoedepth/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..e3da8c75b5a8e39b4b58a4dcd827b84d79b9115c --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/utils/geometry.py @@ -0,0 +1,98 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import numpy as np + +def get_intrinsics(H,W): + """ + Intrinsics for a pinhole camera model. + Assume fov of 55 degrees and central principal point. + """ + f = 0.5 * W / np.tan(0.5 * 55 * np.pi / 180.0) + cx = 0.5 * W + cy = 0.5 * H + return np.array([[f, 0, cx], + [0, f, cy], + [0, 0, 1]]) + +def depth_to_points(depth, R=None, t=None): + + K = get_intrinsics(depth.shape[1], depth.shape[2]) + Kinv = np.linalg.inv(K) + if R is None: + R = np.eye(3) + if t is None: + t = np.zeros(3) + + # M converts from your coordinate to PyTorch3D's coordinate system + M = np.eye(3) + M[0, 0] = -1.0 + M[1, 1] = -1.0 + + height, width = depth.shape[1:3] + + x = np.arange(width) + y = np.arange(height) + coord = np.stack(np.meshgrid(x, y), -1) + coord = np.concatenate((coord, np.ones_like(coord)[:, :, [0]]), -1) # z=1 + coord = coord.astype(np.float32) + # coord = torch.as_tensor(coord, dtype=torch.float32, device=device) + coord = coord[None] # bs, h, w, 3 + + D = depth[:, :, :, None, None] + # print(D.shape, Kinv[None, None, None, ...].shape, coord[:, :, :, :, None].shape ) + pts3D_1 = D * Kinv[None, None, None, ...] @ coord[:, :, :, :, None] + # pts3D_1 live in your coordinate system. Convert them to Py3D's + pts3D_1 = M[None, None, None, ...] @ pts3D_1 + # from reference to targe tviewpoint + pts3D_2 = R[None, None, None, ...] @ pts3D_1 + t[None, None, None, :, None] + # pts3D_2 = pts3D_1 + # depth_2 = pts3D_2[:, :, :, 2, :] # b,1,h,w + return pts3D_2[:, :, :, :3, 0][0] + + +def create_triangles(h, w, mask=None): + """ + Reference: https://github.com/google-research/google-research/blob/e96197de06613f1b027d20328e06d69829fa5a89/infinite_nature/render_utils.py#L68 + Creates mesh triangle indices from a given pixel grid size. + This function is not and need not be differentiable as triangle indices are + fixed. + Args: + h: (int) denoting the height of the image. + w: (int) denoting the width of the image. + Returns: + triangles: 2D numpy array of indices (int) with shape (2(W-1)(H-1) x 3) + """ + x, y = np.meshgrid(range(w - 1), range(h - 1)) + tl = y * w + x + tr = y * w + x + 1 + bl = (y + 1) * w + x + br = (y + 1) * w + x + 1 + triangles = np.array([tl, bl, tr, br, tr, bl]) + triangles = np.transpose(triangles, (1, 2, 0)).reshape( + ((w - 1) * (h - 1) * 2, 3)) + if mask is not None: + mask = mask.reshape(-1) + triangles = triangles[mask[triangles].all(1)] + return triangles diff --git a/third_party/flux/annotator/zoe/zoedepth/utils/misc.py b/third_party/flux/annotator/zoe/zoedepth/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..4bbe403d3669829eecdf658458c76aa5e87e2b33 --- /dev/null +++ b/third_party/flux/annotator/zoe/zoedepth/utils/misc.py @@ -0,0 +1,368 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +"""Miscellaneous utility functions.""" + +from scipy import ndimage + +import base64 +import math +import re +from io import BytesIO + +import matplotlib +import matplotlib.cm +import numpy as np +import requests +import torch +import torch.distributed as dist +import torch.nn +import torch.nn as nn +import torch.utils.data.distributed +from PIL import Image +from torchvision.transforms import ToTensor + + +class RunningAverage: + def __init__(self): + self.avg = 0 + self.count = 0 + + def append(self, value): + self.avg = (value + self.count * self.avg) / (self.count + 1) + self.count += 1 + + def get_value(self): + return self.avg + + +def denormalize(x): + """Reverses the imagenet normalization applied to the input. + + Args: + x (torch.Tensor - shape(N,3,H,W)): input tensor + + Returns: + torch.Tensor - shape(N,3,H,W): Denormalized input + """ + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device) + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device) + return x * std + mean + + +class RunningAverageDict: + """A dictionary of running averages.""" + def __init__(self): + self._dict = None + + def update(self, new_dict): + if new_dict is None: + return + + if self._dict is None: + self._dict = dict() + for key, value in new_dict.items(): + self._dict[key] = RunningAverage() + + for key, value in new_dict.items(): + self._dict[key].append(value) + + def get_value(self): + if self._dict is None: + return None + return {key: value.get_value() for key, value in self._dict.items()} + + +def colorize(value, vmin=None, vmax=None, cmap='gray_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None): + """Converts a depth map to a color image. + + Args: + value (torch.Tensor, numpy.ndarry): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed + vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None. + vmax (float, optional): vmax-valued entries are mapped to end color of cmap. If None, value.max() is used. Defaults to None. + cmap (str, optional): matplotlib colormap to use. Defaults to 'magma_r'. + invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99. + invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None. + background_color (tuple[int], optional): 4-tuple RGB color to give to invalid pixels. Defaults to (128, 128, 128, 255). + gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False. + value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None. + + Returns: + numpy.ndarray, dtype - uint8: Colored depth map. Shape: (H, W, 4) + """ + if isinstance(value, torch.Tensor): + value = value.detach().cpu().numpy() + + value = value.squeeze() + if invalid_mask is None: + invalid_mask = value == invalid_val + mask = np.logical_not(invalid_mask) + + # normalize + vmin = np.percentile(value[mask],2) if vmin is None else vmin + vmax = np.percentile(value[mask],85) if vmax is None else vmax + if vmin != vmax: + value = (value - vmin) / (vmax - vmin) # vmin..vmax + else: + # Avoid 0-division + value = value * 0. + + # squeeze last dim if it exists + # grey out the invalid values + + value[invalid_mask] = np.nan + cmapper = matplotlib.cm.get_cmap(cmap) + if value_transform: + value = value_transform(value) + # value = value / value.max() + value = cmapper(value, bytes=True) # (nxmx4) + + # img = value[:, :, :] + img = value[...] + img[invalid_mask] = background_color + + # return img.transpose((2, 0, 1)) + if gamma_corrected: + # gamma correction + img = img / 255 + img = np.power(img, 2.2) + img = img * 255 + img = img.astype(np.uint8) + return img + + +def count_parameters(model, include_all=False): + return sum(p.numel() for p in model.parameters() if p.requires_grad or include_all) + + +def compute_errors(gt, pred): + """Compute metrics for 'pred' compared to 'gt' + + Args: + gt (numpy.ndarray): Ground truth values + pred (numpy.ndarray): Predicted values + + gt.shape should be equal to pred.shape + + Returns: + dict: Dictionary containing the following metrics: + 'a1': Delta1 accuracy: Fraction of pixels that are within a scale factor of 1.25 + 'a2': Delta2 accuracy: Fraction of pixels that are within a scale factor of 1.25^2 + 'a3': Delta3 accuracy: Fraction of pixels that are within a scale factor of 1.25^3 + 'abs_rel': Absolute relative error + 'rmse': Root mean squared error + 'log_10': Absolute log10 error + 'sq_rel': Squared relative error + 'rmse_log': Root mean squared error on the log scale + 'silog': Scale invariant log error + """ + thresh = np.maximum((gt / pred), (pred / gt)) + a1 = (thresh < 1.25).mean() + a2 = (thresh < 1.25 ** 2).mean() + a3 = (thresh < 1.25 ** 3).mean() + + abs_rel = np.mean(np.abs(gt - pred) / gt) + sq_rel = np.mean(((gt - pred) ** 2) / gt) + + rmse = (gt - pred) ** 2 + rmse = np.sqrt(rmse.mean()) + + rmse_log = (np.log(gt) - np.log(pred)) ** 2 + rmse_log = np.sqrt(rmse_log.mean()) + + err = np.log(pred) - np.log(gt) + silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100 + + log_10 = (np.abs(np.log10(gt) - np.log10(pred))).mean() + return dict(a1=a1, a2=a2, a3=a3, abs_rel=abs_rel, rmse=rmse, log_10=log_10, rmse_log=rmse_log, + silog=silog, sq_rel=sq_rel) + + +def compute_metrics(gt, pred, interpolate=True, garg_crop=False, eigen_crop=True, dataset='nyu', min_depth_eval=0.1, max_depth_eval=10, **kwargs): + """Compute metrics of predicted depth maps. Applies cropping and masking as necessary or specified via arguments. Refer to compute_errors for more details on metrics. + """ + if 'config' in kwargs: + config = kwargs['config'] + garg_crop = config.garg_crop + eigen_crop = config.eigen_crop + min_depth_eval = config.min_depth_eval + max_depth_eval = config.max_depth_eval + + if gt.shape[-2:] != pred.shape[-2:] and interpolate: + pred = nn.functional.interpolate( + pred, gt.shape[-2:], mode='bilinear', align_corners=True) + + pred = pred.squeeze().cpu().numpy() + pred[pred < min_depth_eval] = min_depth_eval + pred[pred > max_depth_eval] = max_depth_eval + pred[np.isinf(pred)] = max_depth_eval + pred[np.isnan(pred)] = min_depth_eval + + gt_depth = gt.squeeze().cpu().numpy() + valid_mask = np.logical_and( + gt_depth > min_depth_eval, gt_depth < max_depth_eval) + + if garg_crop or eigen_crop: + gt_height, gt_width = gt_depth.shape + eval_mask = np.zeros(valid_mask.shape) + + if garg_crop: + eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), + int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1 + + elif eigen_crop: + # print("-"*10, " EIGEN CROP ", "-"*10) + if dataset == 'kitti': + eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), + int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1 + else: + # assert gt_depth.shape == (480, 640), "Error: Eigen crop is currently only valid for (480, 640) images" + eval_mask[45:471, 41:601] = 1 + else: + eval_mask = np.ones(valid_mask.shape) + valid_mask = np.logical_and(valid_mask, eval_mask) + return compute_errors(gt_depth[valid_mask], pred[valid_mask]) + + +#################################### Model uilts ################################################ + + +def parallelize(config, model, find_unused_parameters=True): + + if config.gpu is not None: + torch.cuda.set_device(config.gpu) + model = model.cuda(config.gpu) + + config.multigpu = False + if config.distributed: + # Use DDP + config.multigpu = True + config.rank = config.rank * config.ngpus_per_node + config.gpu + dist.init_process_group(backend=config.dist_backend, init_method=config.dist_url, + world_size=config.world_size, rank=config.rank) + config.batch_size = int(config.batch_size / config.ngpus_per_node) + # config.batch_size = 8 + config.workers = int( + (config.num_workers + config.ngpus_per_node - 1) / config.ngpus_per_node) + print("Device", config.gpu, "Rank", config.rank, "batch size", + config.batch_size, "Workers", config.workers) + torch.cuda.set_device(config.gpu) + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = model.cuda(config.gpu) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.gpu], output_device=config.gpu, + find_unused_parameters=find_unused_parameters) + + elif config.gpu is None: + # Use DP + config.multigpu = True + model = model.cuda() + model = torch.nn.DataParallel(model) + + return model + + +################################################################################################# + + +##################################################################################################### + + +class colors: + '''Colors class: + Reset all colors with colors.reset + Two subclasses fg for foreground and bg for background. + Use as colors.subclass.colorname. + i.e. colors.fg.red or colors.bg.green + Also, the generic bold, disable, underline, reverse, strikethrough, + and invisible work with the main class + i.e. colors.bold + ''' + reset = '\033[0m' + bold = '\033[01m' + disable = '\033[02m' + underline = '\033[04m' + reverse = '\033[07m' + strikethrough = '\033[09m' + invisible = '\033[08m' + + class fg: + black = '\033[30m' + red = '\033[31m' + green = '\033[32m' + orange = '\033[33m' + blue = '\033[34m' + purple = '\033[35m' + cyan = '\033[36m' + lightgrey = '\033[37m' + darkgrey = '\033[90m' + lightred = '\033[91m' + lightgreen = '\033[92m' + yellow = '\033[93m' + lightblue = '\033[94m' + pink = '\033[95m' + lightcyan = '\033[96m' + + class bg: + black = '\033[40m' + red = '\033[41m' + green = '\033[42m' + orange = '\033[43m' + blue = '\033[44m' + purple = '\033[45m' + cyan = '\033[46m' + lightgrey = '\033[47m' + + +def printc(text, color): + print(f"{color}{text}{colors.reset}") + +############################################ + +def get_image_from_url(url): + response = requests.get(url) + img = Image.open(BytesIO(response.content)).convert("RGB") + return img + +def url_to_torch(url, size=(384, 384)): + img = get_image_from_url(url) + img = img.resize(size, Image.ANTIALIAS) + img = torch.from_numpy(np.asarray(img)).float() + img = img.permute(2, 0, 1) + img.div_(255) + return img + +def pil_to_batched_tensor(img): + return ToTensor()(img).unsqueeze(0) + +def save_raw_16bit(depth, fpath="raw.png"): + if isinstance(depth, torch.Tensor): + depth = depth.squeeze().cpu().numpy() + + assert isinstance(depth, np.ndarray), "Depth must be a torch tensor or numpy array" + assert depth.ndim == 2, "Depth must be 2D" + depth = depth * 256 # scale for 16-bit png + depth = depth.astype(np.uint16) + depth = Image.fromarray(depth) + depth.save(fpath) + print("Saved raw depth to", fpath) \ No newline at end of file diff --git a/third_party/flux/api.py b/third_party/flux/api.py new file mode 100644 index 0000000000000000000000000000000000000000..b08202adb35d2ffae320bb9b47f567e538837836 --- /dev/null +++ b/third_party/flux/api.py @@ -0,0 +1,194 @@ +import io +import os +import time +from pathlib import Path + +import requests +from PIL import Image + +API_ENDPOINT = "https://api.bfl.ml" + + +class ApiException(Exception): + def __init__(self, status_code: int, detail: str | list[dict] | None = None): + super().__init__() + self.detail = detail + self.status_code = status_code + + def __str__(self) -> str: + return self.__repr__() + + def __repr__(self) -> str: + if self.detail is None: + message = None + elif isinstance(self.detail, str): + message = self.detail + else: + message = "[" + ",".join(d["msg"] for d in self.detail) + "]" + return f"ApiException({self.status_code=}, {message=}, detail={self.detail})" + + +class ImageRequest: + def __init__( + self, + prompt: str, + width: int = 1024, + height: int = 1024, + name: str = "flux.1-pro", + num_steps: int = 50, + prompt_upsampling: bool = False, + seed: int | None = None, + validate: bool = True, + launch: bool = True, + api_key: str | None = None, + ): + """ + Manages an image generation request to the API. + + Args: + prompt: Prompt to sample + width: Width of the image in pixel + height: Height of the image in pixel + name: Name of the model + num_steps: Number of network evaluations + prompt_upsampling: Use prompt upsampling + seed: Fix the generation seed + validate: Run input validation + launch: Directly launches request + api_key: Your API key if not provided by the environment + + Raises: + ValueError: For invalid input + ApiException: For errors raised from the API + """ + if validate: + if name not in ["flux.1-pro"]: + raise ValueError(f"Invalid model {name}") + elif width % 32 != 0: + raise ValueError(f"width must be divisible by 32, got {width}") + elif not (256 <= width <= 1440): + raise ValueError(f"width must be between 256 and 1440, got {width}") + elif height % 32 != 0: + raise ValueError(f"height must be divisible by 32, got {height}") + elif not (256 <= height <= 1440): + raise ValueError(f"height must be between 256 and 1440, got {height}") + elif not (1 <= num_steps <= 50): + raise ValueError(f"steps must be between 1 and 50, got {num_steps}") + + self.request_json = { + "prompt": prompt, + "width": width, + "height": height, + "variant": name, + "steps": num_steps, + "prompt_upsampling": prompt_upsampling, + } + if seed is not None: + self.request_json["seed"] = seed + + self.request_id: str | None = None + self.result: dict | None = None + self._image_bytes: bytes | None = None + self._url: str | None = None + if api_key is None: + self.api_key = os.environ.get("BFL_API_KEY") + else: + self.api_key = api_key + + if launch: + self.request() + + def request(self): + """ + Request to generate the image. + """ + if self.request_id is not None: + return + response = requests.post( + f"{API_ENDPOINT}/v1/image", + headers={ + "accept": "application/json", + "x-key": self.api_key, + "Content-Type": "application/json", + }, + json=self.request_json, + ) + result = response.json() + if response.status_code != 200: + raise ApiException(status_code=response.status_code, detail=result.get("detail")) + self.request_id = response.json()["id"] + + def retrieve(self) -> dict: + """ + Wait for the generation to finish and retrieve response. + """ + if self.request_id is None: + self.request() + while self.result is None: + response = requests.get( + f"{API_ENDPOINT}/v1/get_result", + headers={ + "accept": "application/json", + "x-key": self.api_key, + }, + params={ + "id": self.request_id, + }, + ) + result = response.json() + if "status" not in result: + raise ApiException(status_code=response.status_code, detail=result.get("detail")) + elif result["status"] == "Ready": + self.result = result["result"] + elif result["status"] == "Pending": + time.sleep(0.5) + else: + raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'") + return self.result + + @property + def bytes(self) -> bytes: + """ + Generated image as bytes. + """ + if self._image_bytes is None: + response = requests.get(self.url) + if response.status_code == 200: + self._image_bytes = response.content + else: + raise ApiException(status_code=response.status_code) + return self._image_bytes + + @property + def url(self) -> str: + """ + Public url to retrieve the image from + """ + if self._url is None: + result = self.retrieve() + self._url = result["sample"] + return self._url + + @property + def image(self) -> Image.Image: + """ + Load the image as a PIL Image + """ + return Image.open(io.BytesIO(self.bytes)) + + def save(self, path: str): + """ + Save the generated image to a local path + """ + suffix = Path(self.url).suffix + if not path.endswith(suffix): + path = path + suffix + Path(path).resolve().parent.mkdir(parents=True, exist_ok=True) + with open(path, "wb") as file: + file.write(self.bytes) + + +if __name__ == "__main__": + from fire import Fire + + Fire(ImageRequest) diff --git a/third_party/flux/cli.py b/third_party/flux/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..f3624bc6c387f359162e68f46995b12ce341970a --- /dev/null +++ b/third_party/flux/cli.py @@ -0,0 +1,254 @@ +import os +import re +import time +from dataclasses import dataclass +from glob import iglob + +import torch +from einops import rearrange +from fire import Fire +from PIL import ExifTags, Image + +from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack +from flux.util import (configs, embed_watermark, load_ae, load_clip, + load_flow_model, load_t5) +from transformers import pipeline + +NSFW_THRESHOLD = 0.85 + +@dataclass +class SamplingOptions: + prompt: str + width: int + height: int + num_steps: int + guidance: float + seed: int | None + + +def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: + user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n" + usage = ( + "Usage: Either write your prompt directly, leave this field empty " + "to repeat the prompt or write a command starting with a slash:\n" + "- '/w ' will set the width of the generated image\n" + "- '/h ' will set the height of the generated image\n" + "- '/s ' sets the next seed\n" + "- '/g ' sets the guidance (flux-dev only)\n" + "- '/n ' sets the number of steps\n" + "- '/q' to quit" + ) + + while (prompt := input(user_question)).startswith("/"): + if prompt.startswith("/w"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, width = prompt.split() + options.width = 16 * (int(width) // 16) + print( + f"Setting resolution to {options.width} x {options.height} " + f"({options.height *options.width/1e6:.2f}MP)" + ) + elif prompt.startswith("/h"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, height = prompt.split() + options.height = 16 * (int(height) // 16) + print( + f"Setting resolution to {options.width} x {options.height} " + f"({options.height *options.width/1e6:.2f}MP)" + ) + elif prompt.startswith("/g"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, guidance = prompt.split() + options.guidance = float(guidance) + print(f"Setting guidance to {options.guidance}") + elif prompt.startswith("/s"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, seed = prompt.split() + options.seed = int(seed) + print(f"Setting seed to {options.seed}") + elif prompt.startswith("/n"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, steps = prompt.split() + options.num_steps = int(steps) + print(f"Setting seed to {options.num_steps}") + elif prompt.startswith("/q"): + print("Quitting") + return None + else: + if not prompt.startswith("/h"): + print(f"Got invalid command '{prompt}'\n{usage}") + print(usage) + if prompt != "": + options.prompt = prompt + return options + + +@torch.inference_mode() +def main( + name: str = "flux-schnell", + width: int = 1360, + height: int = 768, + seed: int | None = None, + prompt: str = ( + "a photo of a forest with mist swirling around the tree trunks. The word " + '"FLUX" is painted over it in big, red brush strokes with visible texture' + ), + device: str = "cuda" if torch.cuda.is_available() else "cpu", + num_steps: int | None = None, + loop: bool = False, + guidance: float = 3.5, + offload: bool = False, + output_dir: str = "output", + add_sampling_metadata: bool = True, +): + """ + Sample the flux model. Either interactively (set `--loop`) or run for a + single image. + + Args: + name: Name of the model to load + height: height of the sample in pixels (should be a multiple of 16) + width: width of the sample in pixels (should be a multiple of 16) + seed: Set a seed for sampling + output_name: where to save the output image, `{idx}` will be replaced + by the index of the sample + prompt: Prompt used for sampling + device: Pytorch device + num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) + loop: start an interactive session and sample multiple times + guidance: guidance value used for guidance distillation + add_sampling_metadata: Add the prompt to the image Exif metadata + """ + nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection") + + if name not in configs: + available = ", ".join(configs.keys()) + raise ValueError(f"Got unknown model name: {name}, chose from {available}") + + torch_device = torch.device(device) + if num_steps is None: + num_steps = 4 if name == "flux-schnell" else 50 + + # allow for packing and conversion to latent space + height = 16 * (height // 16) + width = 16 * (width // 16) + + output_name = os.path.join(output_dir, "img_{idx}.jpg") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + idx = 0 + else: + fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]\.jpg$", fn)] + if len(fns) > 0: + idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 + else: + idx = 0 + + # init all components + t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512) + clip = load_clip(torch_device) + model = load_flow_model(name, device="cpu" if offload else torch_device) + ae = load_ae(name, device="cpu" if offload else torch_device) + + rng = torch.Generator(device="cpu") + opts = SamplingOptions( + prompt=prompt, + width=width, + height=height, + num_steps=num_steps, + guidance=guidance, + seed=seed, + ) + + if loop: + opts = parse_prompt(opts) + + while opts is not None: + if opts.seed is None: + opts.seed = rng.seed() + print(f"Generating with seed {opts.seed}:\n{opts.prompt}") + t0 = time.perf_counter() + + # prepare input + x = get_noise( + 1, + opts.height, + opts.width, + device=torch_device, + dtype=torch.bfloat16, + seed=opts.seed, + ) + opts.seed = None + if offload: + ae = ae.cpu() + torch.cuda.empty_cache() + t5, clip = t5.to(torch_device), clip.to(torch_device) + inp = prepare(t5, clip, x, prompt=opts.prompt) + timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) + + # offload TEs to CPU, load model to gpu + if offload: + t5, clip = t5.cpu(), clip.cpu() + torch.cuda.empty_cache() + model = model.to(torch_device) + + # denoise initial noise + x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) + + # offload model, load autoencoder to gpu + if offload: + model.cpu() + torch.cuda.empty_cache() + ae.decoder.to(x.device) + + # decode latents to pixel space + x = unpack(x.float(), opts.height, opts.width) + with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): + x = ae.decode(x) + t1 = time.perf_counter() + + fn = output_name.format(idx=idx) + print(f"Done in {t1 - t0:.1f}s. Saving {fn}") + # bring into PIL format and save + x = x.clamp(-1, 1) + x = embed_watermark(x.float()) + x = rearrange(x[0], "c h w -> h w c") + + img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) + nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0] + + if nsfw_score < NSFW_THRESHOLD: + exif_data = Image.Exif() + exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux" + exif_data[ExifTags.Base.Make] = "Black Forest Labs" + exif_data[ExifTags.Base.Model] = name + if add_sampling_metadata: + exif_data[ExifTags.Base.ImageDescription] = prompt + img.save(fn, exif=exif_data, quality=95, subsampling=0) + idx += 1 + else: + print("Your generated image may contain NSFW content.") + + if loop: + print("-" * 80) + opts = parse_prompt(opts) + else: + opts = None + + +def app(): + Fire(main) + + +if __name__ == "__main__": + app() diff --git a/third_party/flux/controlnet.py b/third_party/flux/controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c5a04cc0234b2b726a550cbe62d027943f6bbcbb --- /dev/null +++ b/third_party/flux/controlnet.py @@ -0,0 +1,222 @@ +from dataclasses import dataclass + +import torch +from torch import Tensor, nn +from einops import rearrange + +from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, + MLPEmbedder, SingleStreamBlock, + timestep_embedding) + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class ControlNetFlux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + _supports_gradient_checkpointing = True + + def __init__(self, params: FluxParams, controlnet_depth=2): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(controlnet_depth) + ] + ) + + # add ControlNet blocks + self.controlnet_blocks = nn.ModuleList([]) + for _ in range(controlnet_depth): + controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks.append(controlnet_block) + self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.gradient_checkpointing = False + self.input_hint_block = nn.Sequential( + nn.Conv2d(3, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + zero_module(nn.Conv2d(16, 16, 3, padding=1)) + ) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + + @property + def attn_processors(self): + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + controlnet_cond: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + controlnet_cond = self.input_hint_block(controlnet_cond) + controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_cond = self.pos_embed_input(controlnet_cond) + img = img + controlnet_cond + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + block_res_samples = () + + for block in self.double_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + img, + txt, + vec, + pe, + ) + else: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + block_res_samples = block_res_samples + (img,) + + controlnet_block_res_samples = () + for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks): + block_res_sample = controlnet_block(block_res_sample) + controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) + + return controlnet_block_res_samples diff --git a/third_party/flux/math.py b/third_party/flux/math.py new file mode 100644 index 0000000000000000000000000000000000000000..0156bb6a205dec340e029f0c87cf70ae8709ae12 --- /dev/null +++ b/third_party/flux/math.py @@ -0,0 +1,30 @@ +import torch +from einops import rearrange +from torch import Tensor + + +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: + q, k = apply_rope(q, k, pe) + + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "B H L D -> B L (H D)") + + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) diff --git a/third_party/flux/model.py b/third_party/flux/model.py new file mode 100644 index 0000000000000000000000000000000000000000..51531c114babcea3b7a365ca44ee458bfce9a673 --- /dev/null +++ b/third_party/flux/model.py @@ -0,0 +1,228 @@ +from dataclasses import dataclass + +import torch +from torch import Tensor, nn +from einops import rearrange + +from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, + MLPEmbedder, SingleStreamBlock, + timestep_embedding) + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + + +class Flux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + _supports_gradient_checkpointing = True + + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + @property + def attn_processors(self): + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + block_controlnet_hidden_states=None, + guidance: Tensor | None = None, + image_proj: Tensor | None = None, + ip_scale: Tensor | float = 1.0, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + if block_controlnet_hidden_states is not None: + controlnet_depth = len(block_controlnet_hidden_states) + for index_block, block in enumerate(self.double_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + img, + txt, + vec, + pe, + image_proj, + ip_scale, + ) + else: + img, txt = block( + img=img, + txt=txt, + vec=vec, + pe=pe, + image_proj=image_proj, + ip_scale=ip_scale, + ) + # controlnet residual + if block_controlnet_hidden_states is not None: + img = img + block_controlnet_hidden_states[index_block % 2] + + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + img, + vec, + pe, + ) + else: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img diff --git a/third_party/flux/modules/__pycache__/autoencoder.cpython-310.pyc b/third_party/flux/modules/__pycache__/autoencoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b5f72492b788deb8d6f0c35a507dab6cde23319 Binary files /dev/null and b/third_party/flux/modules/__pycache__/autoencoder.cpython-310.pyc differ diff --git a/third_party/flux/modules/__pycache__/conditioner.cpython-310.pyc b/third_party/flux/modules/__pycache__/conditioner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f652ccc04ff0ff4044e20ac39d75c64a9df50a51 Binary files /dev/null and b/third_party/flux/modules/__pycache__/conditioner.cpython-310.pyc differ diff --git a/third_party/flux/modules/__pycache__/layers.cpython-310.pyc b/third_party/flux/modules/__pycache__/layers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d26945b76a26719da75d50ad79c13db1608e8406 Binary files /dev/null and b/third_party/flux/modules/__pycache__/layers.cpython-310.pyc differ diff --git a/third_party/flux/modules/autoencoder.py b/third_party/flux/modules/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..75159f711f65f064107a1a1b9be6f09fc9872028 --- /dev/null +++ b/third_party/flux/modules/autoencoder.py @@ -0,0 +1,312 @@ +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + ch: int + out_ch: int + ch_mult: list[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) diff --git a/third_party/flux/modules/conditioner.py b/third_party/flux/modules/conditioner.py new file mode 100644 index 0000000000000000000000000000000000000000..7cdd881878ace848745da7d723c60f03392916ab --- /dev/null +++ b/third_party/flux/modules/conditioner.py @@ -0,0 +1,38 @@ +from torch import Tensor, nn +from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel, + T5Tokenizer) + + +class HFEmbedder(nn.Module): + def __init__(self, version: str, max_length: int, **hf_kwargs): + super().__init__() + self.is_clip = version.startswith("openai") + self.max_length = max_length + self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" + + if self.is_clip: + self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length) + self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs) + else: + self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length) + self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs) + + self.hf_module = self.hf_module.eval().requires_grad_(False) + + def forward(self, text: list[str]) -> Tensor: + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + + outputs = self.hf_module( + input_ids=batch_encoding["input_ids"].to(self.hf_module.device), + attention_mask=None, + output_hidden_states=False, + ) + return outputs[self.output_key] diff --git a/third_party/flux/modules/layers.py b/third_party/flux/modules/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..c5489698671c6ed32dcb790a2f83d682d898b872 --- /dev/null +++ b/third_party/flux/modules/layers.py @@ -0,0 +1,595 @@ +import math +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn + +from ..math import attention, rope +import torch.nn.functional as F + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + t.device + ) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return (x * rrms).to(dtype=x_dtype) * self.scale + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) + self.key_norm = RMSNorm(dim) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + +class LoRALinearLayer(nn.Module): + def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): + super().__init__() + + self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) + self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. + # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + self.network_alpha = network_alpha + self.rank = rank + + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + + return up_hidden_states.to(orig_dtype) + +class FLuxSelfAttnProcessor: + def __call__(self, attn, x, pe, **attention_kwargs): + print('2' * 30) + + qkv = attn.qkv(x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = attn.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = attn.proj(x) + return x + +class LoraFluxAttnProcessor(nn.Module): + + def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1): + super().__init__() + self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha) + self.lora_weight = lora_weight + + + def __call__(self, attn, x, pe, **attention_kwargs): + qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = attn.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = attn.proj(x) + self.proj_lora(x) * self.lora_weight + print('1' * 30) + print(x.norm(), (self.proj_lora(x) * self.lora_weight).norm(), 'norm') + return x + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + def forward(): + pass + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + +class DoubleStreamBlockLoraProcessor(nn.Module): + def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1): + super().__init__() + self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha) + self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha) + self.lora_weight = lora_weight + + def forward(self, attn, img, txt, vec, pe, **attention_kwargs): + img_mod1, img_mod2 = attn.img_mod(vec) + txt_mod1, txt_mod2 = attn.txt_mod(vec) + + # prepare image for attention + img_modulated = attn.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = attn.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn1 = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] + + # calculate the img bloks + img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img_mod1.gate * self.proj_lora1(img_attn) * self.lora_weight + img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt_mod1.gate * self.proj_lora2(txt_attn) * self.lora_weight + txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) + return img, txt + +class IPDoubleStreamBlockProcessor(nn.Module): + """Attention processor for handling IP-adapter with double stream block.""" + + def __init__(self, context_dim, hidden_dim): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "IPDoubleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch." + ) + + # Ensure context_dim matches the dimension of image_proj + self.context_dim = context_dim + self.hidden_dim = hidden_dim + + # Initialize projections for IP-adapter + self.ip_adapter_double_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=True) + self.ip_adapter_double_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=True) + + nn.init.zeros_(self.ip_adapter_double_stream_k_proj.weight) + nn.init.zeros_(self.ip_adapter_double_stream_k_proj.bias) + + nn.init.zeros_(self.ip_adapter_double_stream_v_proj.weight) + nn.init.zeros_(self.ip_adapter_double_stream_v_proj.bias) + + def __call__(self, attn, img, txt, vec, pe, image_proj, ip_scale=1.0, **attention_kwargs): + + # Prepare image for attention + img_mod1, img_mod2 = attn.img_mod(vec) + txt_mod1, txt_mod2 = attn.txt_mod(vec) + + img_modulated = attn.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = attn.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) + img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) + + txt_modulated = attn.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = attn.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) + txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) + + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn1 = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn1[:, :txt.shape[1]], attn1[:, txt.shape[1]:] + + # print(f"txt_attn shape: {txt_attn.size()}") + # print(f"img_attn shape: {img_attn.size()}") + + img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) + + txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) + + + # IP-adapter processing + ip_query = img_q # latent sample query + ip_key = self.ip_adapter_double_stream_k_proj(image_proj) + ip_value = self.ip_adapter_double_stream_v_proj(image_proj) + + # Reshape projections for multi-head attention + ip_key = rearrange(ip_key, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim) + ip_value = rearrange(ip_value, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim) + + # Compute attention between IP projections and the latent query + ip_attention = F.scaled_dot_product_attention( + ip_query, + ip_key, + ip_value, + dropout_p=0.0, + is_causal=False + ) + ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)", H=attn.num_heads, D=attn.head_dim) + + img = img + ip_scale * ip_attention + + return img, txt + +class DoubleStreamBlockProcessor: + def __call__(self, attn, img, txt, vec, pe, **attention_kwargs): + img_mod1, img_mod2 = attn.img_mod(vec) + txt_mod1, txt_mod2 = attn.txt_mod(vec) + + # prepare image for attention + img_modulated = attn.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = attn.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) + img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = attn.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = attn.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) + txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn1 = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] + + # calculate the img bloks + img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) + return img, txt + +class DoubleStreamBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): + super().__init__() + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.head_dim = hidden_size // num_heads + + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = Modulation(hidden_size, double=True) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + processor = DoubleStreamBlockProcessor() + self.set_processor(processor) + + def set_processor(self, processor) -> None: + self.processor = processor + + def get_processor(self): + return self.processor + + def forward( + self, + img: Tensor, + txt: Tensor, + vec: Tensor, + pe: Tensor, + image_proj: Tensor = None, + ip_scale: float =1.0, + ) -> tuple[Tensor, Tensor]: + if image_proj is None: + return self.processor(self, img, txt, vec, pe) + else: + return self.processor(self, img, txt, vec, pe, image_proj, ip_scale) + +class IPSingleStreamBlockProcessor(nn.Module): + """Attention processor for handling IP-adapter with single stream block.""" + def __init__(self, context_dim, hidden_dim): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "IPSingleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch." + ) + + # Ensure context_dim matches the dimension of image_proj + self.context_dim = context_dim + self.hidden_dim = hidden_dim + + # Initialize projections for IP-adapter + self.ip_adapter_single_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=False) + self.ip_adapter_single_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=False) + + nn.init.zeros_(self.ip_adapter_single_stream_k_proj.weight) + nn.init.zeros_(self.ip_adapter_single_stream_v_proj.weight) + + def __call__( + self, + attn: nn.Module, + x: Tensor, + vec: Tensor, + pe: Tensor, + image_proj: Tensor | None = None, + ip_scale: float = 1.0 + ) -> Tensor: + + mod, _ = attn.modulation(vec) + x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift + qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) + q, k = attn.norm(q, k, v) + + # compute attention + attn_1 = attention(q, k, v, pe=pe) + + # IP-adapter processing + ip_query = q + ip_key = self.ip_adapter_single_stream_k_proj(image_proj) + ip_value = self.ip_adapter_single_stream_v_proj(image_proj) + + # Reshape projections for multi-head attention + ip_key = rearrange(ip_key, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim) + ip_value = rearrange(ip_value, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim) + + + # Compute attention between IP projections and the latent query + ip_attention = F.scaled_dot_product_attention( + ip_query, + ip_key, + ip_value + ) + ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)") + + attn_out = attn_1 + ip_scale * ip_attention + + # compute activation in mlp stream, cat again and run second linear layer + output = attn.linear2(torch.cat((attn_out, attn.mlp_act(mlp)), 2)) + out = x + mod.gate * output + + return out + + +class SingleStreamBlockLoraProcessor(nn.Module): + def __init__(self, dim: int, rank: int = 4, network_alpha = None, lora_weight: float = 1): + super().__init__() + self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora = LoRALinearLayer(15360, dim, rank, network_alpha) + self.lora_weight = lora_weight + + def forward(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + + mod, _ = attn.modulation(vec) + x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift + qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1) + qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + q, k = attn.norm(q, k, v) + + # compute attention + attn_1 = attention(q, k, v, pe=pe) + + # compute activation in mlp stream, cat again and run second linear layer + output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) + output = output + self.proj_lora(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) * self.lora_weight + output = x + mod.gate * output + return output + + +class SingleStreamBlockProcessor: + def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + + mod, _ = attn.modulation(vec) + x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift + qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + q, k = attn.norm(q, k, v) + + # compute attention + attn_1 = attention(q, k, v, pe=pe) + + # compute activation in mlp stream, cat again and run second linear layer + output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) + output = x + mod.gate * output + return output + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scale = qk_scale or self.head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(self.head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = Modulation(hidden_size, double=False) + + processor = SingleStreamBlockProcessor() + self.set_processor(processor) + + + def set_processor(self, processor) -> None: + self.processor = processor + + def get_processor(self): + return self.processor + + def forward( + self, + x: Tensor, + vec: Tensor, + pe: Tensor, + image_proj: Tensor | None = None, + ip_scale: float = 1.0 + ) -> Tensor: + if image_proj is None: + return self.processor(self, x, vec, pe) + else: + return self.processor(self, x, vec, pe, image_proj, ip_scale) + + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x + +class ImageProjModel(torch.nn.Module): + """Projection Model + https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter.py#L28 + """ + + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): + super().__init__() + + self.generator = None + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + embeds = image_embeds + clip_extra_context_tokens = self.proj(embeds).reshape( + -1, self.clip_extra_context_tokens, self.cross_attention_dim + ) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens + diff --git a/third_party/flux/sampling.py b/third_party/flux/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..e7a97f7d8e5a2586f8c9849f0d6470af9c201041 --- /dev/null +++ b/third_party/flux/sampling.py @@ -0,0 +1,242 @@ +import math +from typing import Callable + +import torch +from einops import rearrange, repeat +from torch import Tensor + +from .model import Flux +from .modules.conditioner import HFEmbedder + + +def get_noise( + num_samples: int, + height: int, + width: int, + device: torch.device, + dtype: torch.dtype, + seed: int, +): + return torch.randn( + num_samples, + 16, + # allow for packing + 2 * math.ceil(height / 16), + 2 * math.ceil(width / 16), + device=device, + dtype=dtype, + generator=torch.Generator(device=device).manual_seed(seed), + ) + + +def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]: + bs, c, h, w = img.shape + if bs == 1 and not isinstance(prompt, str): + bs = len(prompt) + + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img.shape[0] == 1 and bs > 1: + img = repeat(img, "1 ... -> bs ...", bs=bs) + + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + if isinstance(prompt, str): + prompt = [prompt] + txt = t5(prompt) + if txt.shape[0] == 1 and bs > 1: + txt = repeat(txt, "1 ... -> bs ...", bs=bs) + txt_ids = torch.zeros(bs, txt.shape[1], 3) + + vec = clip(prompt) + if vec.shape[0] == 1 and bs > 1: + vec = repeat(vec, "1 ... -> bs ...", bs=bs) + + return { + "img": img, + "img_ids": img_ids.to(img.device), + "txt": txt.to(img.device), + "txt_ids": txt_ids.to(img.device), + "vec": vec.to(img.device), + } + + +def time_shift(mu: float, sigma: float, t: Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function( + x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 +) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model: Flux, + # model input + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + vec: Tensor, + neg_txt: Tensor, + neg_txt_ids: Tensor, + neg_vec: Tensor, + # sampling parameters + timesteps: list[float], + guidance: float = 4.0, + true_gs = 1, + timestep_to_start_cfg=0, + # ip-adapter parameters + image_proj: Tensor=None, + neg_image_proj: Tensor=None, + ip_scale: Tensor | float = 1.0, + neg_ip_scale: Tensor | float = 1.0 +): + i = 0 + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + image_proj=image_proj, + ip_scale=ip_scale, + ) + if i >= timestep_to_start_cfg: + neg_pred = model( + img=img, + img_ids=img_ids, + txt=neg_txt, + txt_ids=neg_txt_ids, + y=neg_vec, + timesteps=t_vec, + guidance=guidance_vec, + image_proj=neg_image_proj, + ip_scale=neg_ip_scale, + ) + pred = neg_pred + true_gs * (pred - neg_pred) + img = img + (t_prev - t_curr) * pred + i += 1 + return img + +def denoise_controlnet( + model: Flux, + controlnet:None, + # model input + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + vec: Tensor, + neg_txt: Tensor, + neg_txt_ids: Tensor, + neg_vec: Tensor, + controlnet_cond, + # sampling parameters + timesteps: list[float], + guidance: float = 4.0, + true_gs = 1, + controlnet_gs=0.7, + timestep_to_start_cfg=0, + # ip-adapter parameters + image_proj: Tensor=None, + neg_image_proj: Tensor=None, + ip_scale: Tensor | float = 1, + neg_ip_scale: Tensor | float = 1, +): + # this is ignored for schnell + i = 0 + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + block_res_samples = controlnet( + img=img, + img_ids=img_ids, + controlnet_cond=controlnet_cond, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + ) + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + block_controlnet_hidden_states=[i * controlnet_gs for i in block_res_samples], + image_proj=image_proj, + ip_scale=ip_scale, + ) + if i >= timestep_to_start_cfg: + neg_block_res_samples = controlnet( + img=img, + img_ids=img_ids, + controlnet_cond=controlnet_cond, + txt=neg_txt, + txt_ids=neg_txt_ids, + y=neg_vec, + timesteps=t_vec, + guidance=guidance_vec, + ) + neg_pred = model( + img=img, + img_ids=img_ids, + txt=neg_txt, + txt_ids=neg_txt_ids, + y=neg_vec, + timesteps=t_vec, + guidance=guidance_vec, + block_controlnet_hidden_states=[i * controlnet_gs for i in neg_block_res_samples], + image_proj=neg_image_proj, + ip_scale=neg_ip_scale, + ) + pred = neg_pred + true_gs * (pred - neg_pred) + + img = img + (t_prev - t_curr) * pred + + i += 1 + return img + +def unpack(x: Tensor, height: int, width: int) -> Tensor: + return rearrange( + x, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(height / 16), + w=math.ceil(width / 16), + ph=2, + pw=2, + ) diff --git a/third_party/flux/util.py b/third_party/flux/util.py new file mode 100644 index 0000000000000000000000000000000000000000..9dd496213f0acc9cd05ddc621d504dd9e87373e9 --- /dev/null +++ b/third_party/flux/util.py @@ -0,0 +1,433 @@ +import os +from dataclasses import dataclass + +import torch +import json +import cv2 +import numpy as np +from PIL import Image +from huggingface_hub import hf_hub_download +from safetensors import safe_open +from safetensors.torch import load_file as load_sft + +from optimum.quanto import requantize + +from .model import Flux, FluxParams +from .controlnet import ControlNetFlux +from .modules.autoencoder import AutoEncoder, AutoEncoderParams +from .modules.conditioner import HFEmbedder +from .annotator.dwpose import DWposeDetector +from .annotator.mlsd import MLSDdetector +from .annotator.canny import CannyDetector +from .annotator.midas import MidasDetector +from .annotator.hed import HEDdetector +from .annotator.tile import TileDetector +from .annotator.zoe import ZoeDetector + + +def load_safetensors(path): + tensors = {} + with safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + tensors[key] = f.get_tensor(key) + return tensors + +def get_lora_rank(checkpoint): + for k in checkpoint.keys(): + if k.endswith(".down.weight"): + return checkpoint[k].shape[0] + +def load_checkpoint(local_path, repo_id, name): + if local_path is not None: + if '.safetensors' in local_path: + print(f"Loading .safetensors checkpoint from {local_path}") + checkpoint = load_safetensors(local_path) + else: + print(f"Loading checkpoint from {local_path}") + checkpoint = torch.load(local_path, map_location='cpu') + elif repo_id is not None and name is not None: + print(f"Loading checkpoint {name} from repo id {repo_id}") + checkpoint = load_from_repo_id(repo_id, name) + else: + raise ValueError( + "LOADING ERROR: you must specify local_path or repo_id with name in HF to download" + ) + return checkpoint + + +def c_crop(image): + width, height = image.size + new_size = min(width, height) + left = (width - new_size) / 2 + top = (height - new_size) / 2 + right = (width + new_size) / 2 + bottom = (height + new_size) / 2 + return image.crop((left, top, right, bottom)) + +def pad64(x): + return int(np.ceil(float(x) / 64.0) * 64 - x) + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + +def safer_memory(x): + # Fix many MAC/AMD problems + return np.ascontiguousarray(x.copy()).copy() + +#https://github.com/Mikubill/sd-webui-controlnet/blob/main/scripts/processor.py#L17 +#Added upscale_method, mode params +def resize_image_with_pad(input_image, resolution, skip_hwc3=False, mode='edge'): + if skip_hwc3: + img = input_image + else: + img = HWC3(input_image) + H_raw, W_raw, _ = img.shape + if resolution == 0: + return img, lambda x: x + k = float(resolution) / float(min(H_raw, W_raw)) + H_target = int(np.round(float(H_raw) * k)) + W_target = int(np.round(float(W_raw) * k)) + img = cv2.resize(img, (W_target, H_target), interpolation=cv2.INTER_AREA) + H_pad, W_pad = pad64(H_target), pad64(W_target) + img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode) + + def remove_pad(x): + return safer_memory(x[:H_target, :W_target, ...]) + + return safer_memory(img_padded), remove_pad + +class Annotator: + def __init__(self, name: str, device: str): + if name == "canny": + processor = CannyDetector() + elif name == "openpose": + processor = DWposeDetector(device) + elif name == "depth": + processor = MidasDetector() + elif name == "hed": + processor = HEDdetector() + elif name == "hough": + processor = MLSDdetector() + elif name == "tile": + processor = TileDetector() + elif name == "zoe": + processor = ZoeDetector() + self.name = name + self.processor = processor + + def __call__(self, image: Image, width: int, height: int): + image = np.array(image) + detect_resolution = max(width, height) + image, remove_pad = resize_image_with_pad(image, detect_resolution) + + image = np.array(image) + if self.name == "canny": + result = self.processor(image, low_threshold=100, high_threshold=200) + elif self.name == "hough": + result = self.processor(image, thr_v=0.05, thr_d=5) + elif self.name == "depth": + result = self.processor(image) + result, _ = result + else: + result = self.processor(image) + + result = HWC3(remove_pad(result)) + result = cv2.resize(result, (width, height)) + return result + + +@dataclass +class ModelSpec: + params: FluxParams + ae_params: AutoEncoderParams + ckpt_path: str | None + ae_path: str | None + repo_id: str | None + repo_flow: str | None + repo_ae: str | None + repo_id_ae: str | None + + +configs = { + "flux-dev": ModelSpec( + repo_id="black-forest-labs/FLUX.1-dev", + repo_id_ae="black-forest-labs/FLUX.1-dev", + repo_flow="flux1-dev.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-dev-fp8": ModelSpec( + repo_id="XLabs-AI/flux-dev-fp8", + repo_id_ae="black-forest-labs/FLUX.1-dev", + repo_flow="flux-dev-fp8.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV_FP8"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-schnell": ModelSpec( + repo_id="black-forest-labs/FLUX.1-schnell", + repo_id_ae="black-forest-labs/FLUX.1-dev", + repo_flow="flux1-schnell.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_SCHNELL"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), +} + + +def print_load_warning(missing: list[str], unexpected: list[str]) -> None: + if len(missing) > 0 and len(unexpected) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + print("\n" + "-" * 79 + "\n") + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + elif len(missing) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + elif len(unexpected) > 0: + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + +def load_from_repo_id(repo_id, checkpoint_name): + ckpt_path = hf_hub_download(repo_id, checkpoint_name) + sd = load_sft(ckpt_path, device='cpu') + return sd + +def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True): + # Loading Flux + print("Init model") + ckpt_path = configs[name].ckpt_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) + + with torch.device("meta" if ckpt_path is not None else device): + model = Flux(configs[name].params).to(torch.bfloat16) + + if ckpt_path is not None: + print("Loading checkpoint") + # load_sft doesn't support torch.device + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return model + +def load_flow_model2(name: str, device: str | torch.device = "cuda", hf_download: bool = True): + # Loading Flux + print("Init model") + ckpt_path = configs[name].ckpt_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors")) + + with torch.device("meta" if ckpt_path is not None else device): + model = Flux(configs[name].params) + + if ckpt_path is not None: + print("Loading checkpoint") + # load_sft doesn't support torch.device + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return model + +def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf_download: bool = True): + # Loading Flux + print("Init model") + ckpt_path = configs[name].ckpt_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) + json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json') + + + model = Flux(configs[name].params).to(torch.bfloat16) + + print("Loading checkpoint") + # load_sft doesn't support torch.device + sd = load_sft(ckpt_path, device='cpu') + with open(json_path, "r") as f: + quantization_map = json.load(f) + print("Start a quantization process...") + requantize(model, sd, quantization_map, device=device) + print("Model is quantized!") + return model + +def load_controlnet(name, device, transformer=None): + with torch.device(device): + controlnet = ControlNetFlux(configs[name].params) + if transformer is not None: + controlnet.load_state_dict(transformer.state_dict(), strict=False) + return controlnet + +def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder: + # max length 64, 128, 256 and 512 should work (if your sequence is short enough) + return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device) + +def load_clip(device: str | torch.device = "cuda") -> HFEmbedder: + return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device) + + +def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder: + ckpt_path = configs[name].ae_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_ae is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae) + + # Loading the autoencoder + print("Init AE") + with torch.device("meta" if ckpt_path is not None else device): + ae = AutoEncoder(configs[name].ae_params) + + if ckpt_path is not None: + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return ae + + +class WatermarkEmbedder: + def __init__(self, watermark): + self.watermark = watermark + self.num_bits = len(WATERMARK_BITS) + self.encoder = WatermarkEncoder() + self.encoder.set_watermark("bits", self.watermark) + + def __call__(self, image: torch.Tensor) -> torch.Tensor: + """ + Adds a predefined watermark to the input image + + Args: + image: ([N,] B, RGB, H, W) in range [-1, 1] + + Returns: + same as input but watermarked + """ + image = 0.5 * image + 0.5 + squeeze = len(image.shape) == 4 + if squeeze: + image = image[None, ...] + n = image.shape[0] + image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1] + # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] + # watermarking libary expects input as cv2 BGR format + for k in range(image_np.shape[0]): + image_np[k] = self.encoder.encode(image_np[k], "dwtDct") + image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to( + image.device + ) + image = torch.clamp(image / 255, min=0.0, max=1.0) + if squeeze: + image = image[0] + image = 2 * image - 1 + return image + + +# A fixed 48-bit message that was choosen at random +WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110 +# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 +WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] diff --git a/third_party/flux/xflux_pipeline.py b/third_party/flux/xflux_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..50707544b3531a6611a3d186445f3532cc23dace --- /dev/null +++ b/third_party/flux/xflux_pipeline.py @@ -0,0 +1,364 @@ +from PIL import Image, ExifTags +import numpy as np +import torch +from torch import Tensor + +from einops import rearrange +import uuid +import os + +from .modules.layers import ( + SingleStreamBlockProcessor, + DoubleStreamBlockProcessor, + SingleStreamBlockLoraProcessor, + DoubleStreamBlockLoraProcessor, + IPDoubleStreamBlockProcessor, + ImageProjModel, +) +from .sampling import denoise, denoise_controlnet, get_noise, get_schedule, prepare, unpack +from .util import ( + load_ae, + load_clip, + load_flow_model, + load_t5, + load_controlnet, + load_flow_model_quintized, + Annotator, + get_lora_rank, + load_checkpoint +) + +from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor + +class XFluxPipeline: + def __init__(self, model_type, device, offload: bool = False): + self.device = torch.device(device) + self.offload = offload + self.model_type = model_type + + self.clip = load_clip(self.device) + self.t5 = load_t5(self.device, max_length=512) + self.ae = load_ae(model_type, device="cpu" if offload else self.device) + if "fp8" in model_type: + self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.device) + else: + self.model = load_flow_model(model_type, device="cpu" if offload else self.device) + + self.image_encoder_path = "openai/clip-vit-large-patch14" + self.hf_lora_collection = "XLabs-AI/flux-lora-collection" + self.lora_types_to_names = { + "realism": "lora.safetensors", + } + self.controlnet_loaded = False + self.ip_loaded = False + + def set_ip(self, local_path: str = None, repo_id = None, name: str = None, token: int = 4): + self.model.to(self.device) + + # unpack checkpoint + checkpoint = load_checkpoint(local_path, repo_id, name) + prefix = "double_blocks." + blocks = {} + proj = {} + + for key, value in checkpoint.items(): + if key.startswith(prefix): + blocks[key[len(prefix):].replace('.processor.', '.')] = value + if key.startswith("ip_adapter_proj_model"): + proj[key[len("ip_adapter_proj_model."):]] = value + + # load image encoder + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( + self.device, dtype=torch.float16 + ) + self.clip_image_processor = CLIPImageProcessor() + + # setup image embedding projection model + self.improj = ImageProjModel(4096, 768, token) + self.improj.load_state_dict(proj) + self.improj = self.improj.to(self.device, dtype=torch.bfloat16) + + ip_attn_procs = {} + + for name, _ in self.model.attn_processors.items(): + ip_state_dict = {} + for k in checkpoint.keys(): + if name in k: + ip_state_dict[k.replace(f'{name}.', '')] = checkpoint[k] + if ip_state_dict: + ip_attn_procs[name] = IPDoubleStreamBlockProcessor(4096, 3072) + ip_attn_procs[name].load_state_dict(ip_state_dict) + ip_attn_procs[name].to(self.device, dtype=torch.bfloat16) + else: + ip_attn_procs[name] = self.model.attn_processors[name] + + self.model.set_attn_processor(ip_attn_procs) + self.ip_loaded = True + + def set_lora(self, local_path: str = None, repo_id: str = None, + name: str = None, lora_weight: int = 0.7): + checkpoint = load_checkpoint(local_path, repo_id, name) + self.update_model_with_lora(checkpoint, lora_weight) + + def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: int = 0.7): + checkpoint = load_checkpoint( + None, self.hf_lora_collection, self.lora_types_to_names[lora_type] + ) + self.update_model_with_lora(checkpoint, lora_weight) + + def update_model_with_lora(self, checkpoint, lora_weight): + rank = get_lora_rank(checkpoint) + lora_attn_procs = {} + + for name, _ in self.model.attn_processors.items(): + lora_state_dict = {} + for k in checkpoint.keys(): + if name in k: + lora_state_dict[k[len(name) + 1:]] = checkpoint[k] * lora_weight + + if len(lora_state_dict): + if name.startswith("single_blocks"): + lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=3072, rank=rank) + else: + lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank) + lora_attn_procs[name].load_state_dict(lora_state_dict) + lora_attn_procs[name].to(self.device) + else: + if name.startswith("single_blocks"): + lora_attn_procs[name] = SingleStreamBlockProcessor() + else: + lora_attn_procs[name] = DoubleStreamBlockProcessor() + + self.model.set_attn_processor(lora_attn_procs) + + def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None, name: str = None): + self.model.to(self.device) + self.controlnet = load_controlnet(self.model_type, self.device).to(torch.bfloat16) + + checkpoint = load_checkpoint(local_path, repo_id, name) + self.controlnet.load_state_dict(checkpoint, strict=False) + self.annotator = Annotator(control_type, self.device) + self.controlnet_loaded = True + self.control_type = control_type + + def get_image_proj( + self, + image_prompt: Tensor, + ): + # encode image-prompt embeds + image_prompt = self.clip_image_processor( + images=image_prompt, + return_tensors="pt" + ).pixel_values + image_prompt = image_prompt.to(self.image_encoder.device) + image_prompt_embeds = self.image_encoder( + image_prompt + ).image_embeds.to( + device=self.device, dtype=torch.bfloat16, + ) + # encode image + image_proj = self.improj(image_prompt_embeds) + return image_proj + + def __call__(self, + prompt: str, + image_prompt: Image = None, + controlnet_image: Image = None, + width: int = 512, + height: int = 512, + guidance: float = 4, + num_steps: int = 50, + seed: int = 123456789, + true_gs: float = 3, + control_weight: float = 0.9, + ip_scale: float = 1.0, + neg_ip_scale: float = 1.0, + neg_prompt: str = '', + neg_image_prompt: Image = None, + timestep_to_start_cfg: int = 0, + ): + width = 16 * (width // 16) + height = 16 * (height // 16) + image_proj = None + neg_image_proj = None + if not (image_prompt is None and neg_image_prompt is None) : + assert self.ip_loaded, 'You must setup IP-Adapter to add image prompt as input' + + if image_prompt is None: + image_prompt = np.zeros((width, height, 3), dtype=np.uint8) + if neg_image_prompt is None: + neg_image_prompt = np.zeros((width, height, 3), dtype=np.uint8) + + image_proj = self.get_image_proj(image_prompt) + neg_image_proj = self.get_image_proj(neg_image_prompt) + + if self.controlnet_loaded: + controlnet_image = self.annotator(controlnet_image, width, height) + controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1) + controlnet_image = controlnet_image.permute( + 2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.device) + + return self.forward( + prompt, + width, + height, + guidance, + num_steps, + seed, + controlnet_image, + timestep_to_start_cfg=timestep_to_start_cfg, + true_gs=true_gs, + control_weight=control_weight, + neg_prompt=neg_prompt, + image_proj=image_proj, + neg_image_proj=neg_image_proj, + ip_scale=ip_scale, + neg_ip_scale=neg_ip_scale, + ) + + @torch.inference_mode() + def gradio_generate(self, prompt, image_prompt, controlnet_image, width, height, guidance, + num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt, + neg_image_prompt, timestep_to_start_cfg, control_type, control_weight, + lora_weight, local_path, lora_local_path, ip_local_path): + if controlnet_image is not None: + controlnet_image = Image.fromarray(controlnet_image) + if ((self.controlnet_loaded and control_type != self.control_type) + or not self.controlnet_loaded): + if local_path is not None: + self.set_controlnet(control_type, local_path=local_path) + else: + self.set_controlnet(control_type, local_path=None, + repo_id=f"xlabs-ai/flux-controlnet-{control_type}-v3", + name=f"flux-{control_type}-controlnet-v3.safetensors") + if lora_local_path is not None: + self.set_lora(local_path=lora_local_path, lora_weight=lora_weight) + if image_prompt is not None: + image_prompt = Image.fromarray(image_prompt) + if neg_image_prompt is not None: + neg_image_prompt = Image.fromarray(neg_image_prompt) + if not self.ip_loaded: + if ip_local_path is not None: + self.set_ip(local_path=ip_local_path) + else: + self.set_ip(repo_id="xlabs-ai/flux-ip-adapter", + name="flux-ip-adapter.safetensors") + seed = int(seed) + if seed == -1: + seed = torch.Generator(device="cpu").seed() + + img = self(prompt, image_prompt, controlnet_image, width, height, guidance, + num_steps, seed, true_gs, control_weight, ip_scale, neg_ip_scale, neg_prompt, + neg_image_prompt, timestep_to_start_cfg) + + filename = f"output/gradio/{uuid.uuid4()}.jpg" + os.makedirs(os.path.dirname(filename), exist_ok=True) + exif_data = Image.Exif() + exif_data[ExifTags.Base.Make] = "XLabs AI" + exif_data[ExifTags.Base.Model] = self.model_type + img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0) + return img, filename + + def forward( + self, + prompt, + width, + height, + guidance, + num_steps, + seed, + controlnet_image = None, + timestep_to_start_cfg = 0, + true_gs = 3.5, + control_weight = 0.9, + neg_prompt="", + image_proj=None, + neg_image_proj=None, + ip_scale=1.0, + neg_ip_scale=1.0, + ): + x = get_noise( + 1, height, width, device=self.device, + dtype=torch.bfloat16, seed=seed + ) + timesteps = get_schedule( + num_steps, + (width // 8) * (height // 8) // (16 * 16), + shift=True, + ) + torch.manual_seed(seed) + with torch.no_grad(): + if self.offload: + self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device) + inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt) + neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt) + + if self.offload: + self.offload_model_to_cpu(self.t5, self.clip) + self.model = self.model.to(self.device) + if self.controlnet_loaded: + x = denoise_controlnet( + self.model, + **inp_cond, + controlnet=self.controlnet, + timesteps=timesteps, + guidance=guidance, + controlnet_cond=controlnet_image, + timestep_to_start_cfg=timestep_to_start_cfg, + neg_txt=neg_inp_cond['txt'], + neg_txt_ids=neg_inp_cond['txt_ids'], + neg_vec=neg_inp_cond['vec'], + true_gs=true_gs, + controlnet_gs=control_weight, + image_proj=image_proj, + neg_image_proj=neg_image_proj, + ip_scale=ip_scale, + neg_ip_scale=neg_ip_scale, + ) + else: + x = denoise( + self.model, + **inp_cond, + timesteps=timesteps, + guidance=guidance, + timestep_to_start_cfg=timestep_to_start_cfg, + neg_txt=neg_inp_cond['txt'], + neg_txt_ids=neg_inp_cond['txt_ids'], + neg_vec=neg_inp_cond['vec'], + true_gs=true_gs, + image_proj=image_proj, + neg_image_proj=neg_image_proj, + ip_scale=ip_scale, + neg_ip_scale=neg_ip_scale, + ) + + if self.offload: + self.offload_model_to_cpu(self.model) + self.ae.decoder.to(x.device) + x = unpack(x.float(), height, width) + x = self.ae.decode(x) + self.offload_model_to_cpu(self.ae.decoder) + + x1 = x.clamp(-1, 1) + x1 = rearrange(x1[-1], "c h w -> h w c") + output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy()) + return output_img + + def offload_model_to_cpu(self, *models): + if not self.offload: return + for model in models: + model.cpu() + torch.cuda.empty_cache() + + +class XFluxSampler(XFluxPipeline): + def __init__(self, clip, t5, ae, model, device): + self.clip = clip + self.t5 = t5 + self.ae = ae + self.model = model + self.model.eval() + self.device = device + self.controlnet_loaded = False + self.ip_loaded = False + self.offload = False diff --git a/third_party/ip_adapter/__init__.py b/third_party/ip_adapter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5657c48fb9f6808b1fd1fce4de2cbd8ad91a23cd --- /dev/null +++ b/third_party/ip_adapter/__init__.py @@ -0,0 +1,12 @@ +from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull, IPAdapterPlusXLVLM +from .ip_adapter_fp16 import IPAdapterPlusXL as IPAdapterPlusXL_fp16 + +__all__ = [ + "IPAdapter", + "IPAdapterPlus", + "IPAdapterPlusXL", + "IPAdapterXL", + "IPAdapterFull", + "IPAdapterPlusXLVLM", + "IPAdapterPlusXL_fp16", +] diff --git a/third_party/ip_adapter/__pycache__/__init__.cpython-310.pyc b/third_party/ip_adapter/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c39394c7f3f6af5e3510c863f52ea738fdd88d50 Binary files /dev/null and b/third_party/ip_adapter/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/ip_adapter/__pycache__/__init__.cpython-38.pyc b/third_party/ip_adapter/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fe06db2e41a571e5260051dda0adff874af2055 Binary files /dev/null and b/third_party/ip_adapter/__pycache__/__init__.cpython-38.pyc differ diff --git a/third_party/ip_adapter/__pycache__/attention_processor.cpython-310.pyc b/third_party/ip_adapter/__pycache__/attention_processor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9039489e7dc3260fd78872de218be9245c0d6862 Binary files /dev/null and b/third_party/ip_adapter/__pycache__/attention_processor.cpython-310.pyc differ diff --git a/third_party/ip_adapter/__pycache__/attention_processor.cpython-38.pyc b/third_party/ip_adapter/__pycache__/attention_processor.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b057b32babc4f2aa4017a063bf82b81e9e50af2 Binary files /dev/null and b/third_party/ip_adapter/__pycache__/attention_processor.cpython-38.pyc differ diff --git a/third_party/ip_adapter/__pycache__/ip_adapter.cpython-310.pyc b/third_party/ip_adapter/__pycache__/ip_adapter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..434a737f782955fc2f196d8886038ef266bcb41d Binary files /dev/null and b/third_party/ip_adapter/__pycache__/ip_adapter.cpython-310.pyc differ diff --git a/third_party/ip_adapter/__pycache__/ip_adapter.cpython-38.pyc b/third_party/ip_adapter/__pycache__/ip_adapter.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7904f8c62bd2a11ca8f64b652a4efb120ad14adc Binary files /dev/null and b/third_party/ip_adapter/__pycache__/ip_adapter.cpython-38.pyc differ diff --git a/third_party/ip_adapter/__pycache__/ip_adapter_fp16.cpython-310.pyc b/third_party/ip_adapter/__pycache__/ip_adapter_fp16.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3ee04f027d52c13b94a99e17168b9f6092dc4a0 Binary files /dev/null and b/third_party/ip_adapter/__pycache__/ip_adapter_fp16.cpython-310.pyc differ diff --git a/third_party/ip_adapter/__pycache__/resampler.cpython-310.pyc b/third_party/ip_adapter/__pycache__/resampler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c61d43ee70f2b4f557c80f801e39e129c4bfa1f Binary files /dev/null and b/third_party/ip_adapter/__pycache__/resampler.cpython-310.pyc differ diff --git a/third_party/ip_adapter/__pycache__/resampler.cpython-38.pyc b/third_party/ip_adapter/__pycache__/resampler.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3cd039f5a933b9658ef1872f108f635b1d2fe256 Binary files /dev/null and b/third_party/ip_adapter/__pycache__/resampler.cpython-38.pyc differ diff --git a/third_party/ip_adapter/__pycache__/utils.cpython-310.pyc b/third_party/ip_adapter/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..042adda1a3ac7a0fe849c4ee81af398877c13294 Binary files /dev/null and b/third_party/ip_adapter/__pycache__/utils.cpython-310.pyc differ diff --git a/third_party/ip_adapter/__pycache__/utils.cpython-38.pyc b/third_party/ip_adapter/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..979e21962bdccf0caa1b5cd0bf465234bb5fe0b9 Binary files /dev/null and b/third_party/ip_adapter/__pycache__/utils.cpython-38.pyc differ diff --git a/third_party/ip_adapter/attention_processor.py b/third_party/ip_adapter/attention_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..93592bb1a7e7b3329fc8a400c51920dad519a42f --- /dev/null +++ b/third_party/ip_adapter/attention_processor.py @@ -0,0 +1,568 @@ +# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class AttnProcessor(nn.Module): + r""" + Default processor for performing attention-related computations. + """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + ): + super().__init__() + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + *args, + **kwargs, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class IPAttnProcessor(nn.Module): + r""" + Attention processor for IP-Adapater. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + *args, + **kwargs, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + self.attn_map = ip_attention_probs + ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class AttnProcessor2_0(torch.nn.Module): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + ): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + *args, + **kwargs, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class IPAttnProcessor2_0(torch.nn.Module): + r""" + Attention processor for IP-Adapater for PyTorch 2.0. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + *args, + **kwargs, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + with torch.no_grad(): + self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1) + #print(self.attn_map.shape) + + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +## for controlnet +class CNAttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __init__(self, num_tokens=4): + self.num_tokens = num_tokens + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, *args, **kwargs,): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class CNAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self, num_tokens=4): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + self.num_tokens = num_tokens + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + *args, + **kwargs, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states diff --git a/third_party/ip_adapter/attention_processor_faceid.py b/third_party/ip_adapter/attention_processor_faceid.py new file mode 100644 index 0000000000000000000000000000000000000000..ca1546b24dfb2647f9fe529d31293d85a5c58c43 --- /dev/null +++ b/third_party/ip_adapter/attention_processor_faceid.py @@ -0,0 +1,433 @@ +# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.models.lora import LoRALinearLayer + + +class LoRAAttnProcessor(nn.Module): + r""" + Default processor for performing attention-related computations. + """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + rank=4, + network_alpha=None, + lora_scale=1.0, + ): + super().__init__() + + self.rank = rank + self.lora_scale = lora_scale + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + *args, + **kwargs, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class LoRAIPAttnProcessor(nn.Module): + r""" + Attention processor for IP-Adapater. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4): + super().__init__() + + self.rank = rank + self.lora_scale = lora_scale + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + *args, + **kwargs, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + self.attn_map = ip_attention_probs + ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class LoRAAttnProcessor2_0(nn.Module): + + r""" + Default processor for performing attention-related computations. + """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + rank=4, + network_alpha=None, + lora_scale=1.0, + ): + super().__init__() + + self.rank = rank + self.lora_scale = lora_scale + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + *args, + **kwargs, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class LoRAIPAttnProcessor2_0(nn.Module): + r""" + Processor for implementing the LoRA attention mechanism. + + Args: + hidden_size (`int`, *optional*): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*): + The number of channels in the `encoder_hidden_states`. + rank (`int`, defaults to 4): + The dimension of the LoRA update matrices. + network_alpha (`int`, *optional*): + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, lora_scale=1.0, scale=1.0, num_tokens=4): + super().__init__() + + self.rank = rank + self.lora_scale = lora_scale + self.num_tokens = num_tokens + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None, *args, **kwargs, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) + #query = attn.head_to_batch_dim(query) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + # for text + key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # for ip + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states diff --git a/third_party/ip_adapter/check_default_atten.py b/third_party/ip_adapter/check_default_atten.py new file mode 100644 index 0000000000000000000000000000000000000000..19ebaad2d47fdb05751b60c42d56000ea574b35a --- /dev/null +++ b/third_party/ip_adapter/check_default_atten.py @@ -0,0 +1,73 @@ +class AttnProcessor(nn.Module): + r""" + Default processor for performing attention-related computations. + """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + ): + super().__init__() + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + *args, + **kwargs, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states \ No newline at end of file diff --git a/third_party/ip_adapter/check_default_attenip.py b/third_party/ip_adapter/check_default_attenip.py new file mode 100644 index 0000000000000000000000000000000000000000..024490ea013541e1e0b87cd97e0db17f09562d87 --- /dev/null +++ b/third_party/ip_adapter/check_default_attenip.py @@ -0,0 +1,107 @@ +class IPAttnProcessor(nn.Module): + r""" + Attention processor for IP-Adapater. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + *args, + **kwargs, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + self.attn_map = ip_attention_probs + ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states \ No newline at end of file diff --git a/third_party/ip_adapter/custom_pipelines.py b/third_party/ip_adapter/custom_pipelines.py new file mode 100644 index 0000000000000000000000000000000000000000..7d43d2c34db9b83f6148fac53425a9fd4c60fc93 --- /dev/null +++ b/third_party/ip_adapter/custom_pipelines.py @@ -0,0 +1,394 @@ +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from diffusers import StableDiffusionXLPipeline +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg + +from .utils import is_torch2_available + +if is_torch2_available(): + from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor +else: + from .attention_processor import IPAttnProcessor + + +class StableDiffusionXLCustomPipeline(StableDiffusionXLPipeline): + def set_scale(self, scale): + for attn_processor in self.unet.attn_processors.values(): + if isinstance(attn_processor, IPAttnProcessor): + attn_processor.scale = scale + + @torch.no_grad() + def __call__( # noqa: C901 + self, + prompt: Optional[Union[str, List[str]]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + control_guidance_start: float = 0.0, + control_guidance_end: float = 1.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + control_guidance_start (`float`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 7.1 Apply denoising_end + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # get init conditioning scale + for attn_processor in self.unet.attn_processors.values(): + if isinstance(attn_processor, IPAttnProcessor): + conditioning_scale = attn_processor.scale + break + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if (i / len(timesteps) < control_guidance_start) or ((i + 1) / len(timesteps) > control_guidance_end): + self.set_scale(0.0) + else: + self.set_scale(conditioning_scale) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if output_type != "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/third_party/ip_adapter/ip_adapter.py b/third_party/ip_adapter/ip_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..2bdeaf6004605743391c38b01a44391a84080457 --- /dev/null +++ b/third_party/ip_adapter/ip_adapter.py @@ -0,0 +1,745 @@ +import os +from typing import List + +import torch +from diffusers import StableDiffusionPipeline +from diffusers.pipelines.controlnet import MultiControlNetModel +from PIL import Image +from safetensors import safe_open +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from .utils import is_torch2_available, get_generator + +if is_torch2_available(): + from .attention_processor import ( + AttnProcessor2_0 as AttnProcessor, + ) + from .attention_processor import ( + CNAttnProcessor2_0 as CNAttnProcessor, + ) + from .attention_processor import ( + IPAttnProcessor2_0 as IPAttnProcessor, + ) +else: + from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor +from .resampler import Resampler + + +class ImageProjModel(torch.nn.Module): + """Projection Model""" + + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): + super().__init__() + + self.generator = None + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + embeds = image_embeds + clip_extra_context_tokens = self.proj(embeds).reshape( + -1, self.clip_extra_context_tokens, self.cross_attention_dim + ) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens + + +class MLPProjModel(torch.nn.Module): + """SD model with image prompt""" + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024): + super().__init__() + + self.proj = torch.nn.Sequential( + torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim), + torch.nn.GELU(), + torch.nn.Linear(clip_embeddings_dim, cross_attention_dim), + torch.nn.LayerNorm(cross_attention_dim) + ) + + def forward(self, image_embeds): + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens + + +class IPAdapter: + def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4): + self.device = device + self.image_encoder_path = image_encoder_path + self.ip_ckpt = ip_ckpt + self.num_tokens = num_tokens + + self.pipe = sd_pipe.to(self.device) + self.set_ip_adapter() + + # load image encoder + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( + self.device, dtype=torch.bfloat16 + ) + self.clip_image_processor = CLIPImageProcessor() + # image proj model + self.image_proj_model = self.init_proj() + + self.load_ip_adapter() + + def init_proj(self): + image_proj_model = ImageProjModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + clip_embeddings_dim=self.image_encoder.config.projection_dim, + clip_extra_context_tokens=self.num_tokens, + ).to(self.device, dtype=torch.bfloat16) + return image_proj_model + + def set_ip_adapter(self): + unet = self.pipe.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor() + else: + attn_procs[name] = IPAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=self.num_tokens, + ).to(self.device, dtype=torch.bfloat16) + unet.set_attn_processor(attn_procs) + if hasattr(self.pipe, "controlnet"): + if isinstance(self.pipe.controlnet, MultiControlNetModel): + for controlnet in self.pipe.controlnet.nets: + controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) + else: + self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) + + def load_ip_adapter(self): + if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(self.ip_ckpt, map_location="cpu") + self.image_proj_model.load_state_dict(state_dict["image_proj"]) + ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False) + + @torch.inference_mode() + def get_image_embeds(self, pil_image=None, clip_image_embeds=None): + if pil_image is not None: + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.bfloat16)).image_embeds + else: + clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.bfloat16) + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) + return image_prompt_embeds, uncond_image_prompt_embeds + + def set_scale(self, scale): + for attn_processor in self.pipe.unet.attn_processors.values(): + if isinstance(attn_processor, IPAttnProcessor): + attn_processor.scale = scale + + def generate( + self, + pil_image=None, + clip_image_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + guidance_scale=7.5, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + if pil_image is not None: + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + else: + num_prompts = clip_image_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( + pil_image=pil_image, clip_image_embeds=clip_image_embeds + ) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( + prompt, + device=self.device, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + +class IPAdapterXL(IPAdapter): + """SDXL""" + + def generate( + self, + pil_image, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + self.generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=self.generator, + **kwargs, + ).images + + return images + + +class IPAdapterPlus(IPAdapter): + """IP-Adapter with fine-grained features""" + + def init_proj(self): + image_proj_model = Resampler( + dim=self.pipe.unet.config.cross_attention_dim, + depth=4, + dim_head=64, + heads=12, + num_queries=self.num_tokens, + embedding_dim=self.image_encoder.config.hidden_size, + output_dim=self.pipe.unet.config.cross_attention_dim, + ff_mult=4, + ).to(self.device, dtype=torch.bfloat16) + return image_proj_model + + def save_proj(self, path): + torch.save(self.image_proj_model.state_dict(), path) + + @torch.inference_mode() + def get_image_embeds(self, pil_image=None, clip_image_embeds=None): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.bfloat16) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + def get_image_embeds_feat(self, pil_image=None, clip_image_embeds=None): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.bfloat16) + # clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = clip_image_embeds + # self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + +class IPAdapterFull(IPAdapterPlus): + """IP-Adapter with full features""" + + def init_proj(self): + image_proj_model = MLPProjModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + clip_embeddings_dim=self.image_encoder.config.hidden_size, + ).to(self.device, dtype=torch.bfloat16) + return image_proj_model + + +class IPAdapterPlusXL(IPAdapter): + """SDXL""" + + def init_proj(self): + image_proj_model = Resampler( + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=self.num_tokens, + embedding_dim=self.image_encoder.config.hidden_size, + output_dim=self.pipe.unet.config.cross_attention_dim, + ff_mult=4, + ).to(self.device, dtype=torch.bfloat16) + print(self.num_tokens, self.image_encoder.config.hidden_size, self.pipe.unet.config.cross_attention_dim) + return image_proj_model + + def save_proj(self, path): + torch.save(self.image_proj_model.state_dict(), path) + + @torch.inference_mode() + def get_image_embeds(self, pil_image): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.bfloat16) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + @torch.inference_mode() + def get_image_embeds_feat(self, pil_image=None, clip_image_embeds=None): + # if isinstance(pil_image, Image.Image): + # pil_image = [pil_image] + # clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + # clip_image = clip_image.to(self.device, dtype=torch.bfloat16) + # print('clip_image shape', clip_image.shape) + # clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = clip_image_embeds + # self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros([clip_image_embeds.shape[0], 3, 224, 224]).to(device=clip_image_embeds.device, dtype=clip_image_embeds.dtype), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + # print(uncond_clip_image_embeds.shape, uncond_image_prompt_embeds.shape) + return image_prompt_embeds, uncond_image_prompt_embeds + + def generate( + self, + pil_image, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + image_prompt_embeds = image_prompt_embeds[: , torch.randperm(image_prompt_embeds.shape[1]) ,:] + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + print(prompt_embeds.shape, image_prompt_embeds.shape) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + def generate_from_feat( + self, + feat, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + # num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + num_prompts = len(feat) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds_feat(clip_image_embeds=feat) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + image_prompt_embeds = image_prompt_embeds[: , torch.randperm(image_prompt_embeds.shape[1]) ,:] + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + # print(prompt_embeds.shape, image_prompt_embeds.shape, uncond_image_prompt_embeds.shape, negative_prompt_embeds.shape) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + + def generate_dual( + self, + pil_image, + pil_image1, + improve=1.0, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) + image_prompt_embeds1, uncond_image_prompt_embeds1 = self.get_image_embeds(pil_image1) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + bs_embed, seq_len, _ = image_prompt_embeds1.shape + image_prompt_embeds1 = image_prompt_embeds1.repeat(1, num_samples, 1) + image_prompt_embeds1 = image_prompt_embeds1.view(bs_embed * num_samples, seq_len, -1) + # uncond_image_prompt_embeds1 = uncond_image_prompt_embeds1.repeat(1, num_samples, 1) + # uncond_image_prompt_embeds1 = uncond_image_prompt_embeds1.view(bs_embed * num_samples, seq_len, -1) + # image_prompt_embeds = image_prompt_embeds[: , torch.randperm(image_prompt_embeds.shape[1]) ,:] + # if improve < 1.0: + # pass + # else: + image_prompt_embeds = image_prompt_embeds1 + improve * (image_prompt_embeds - image_prompt_embeds1) + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + print(prompt_embeds.shape, image_prompt_embeds.shape) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + +class IPAdapterPlusXLVLM(IPAdapter): + """SDXL""" + def __init__(self, sd_pipe, ip_ckpt, device, num_tokens=16): + self.device = device + # self.image_encoder_path = image_encoder_path + self.ip_ckpt = ip_ckpt + self.num_tokens = num_tokens + + self.pipe = sd_pipe.to(self.device) + self.set_ip_adapter() + + # load image encoder + # self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( + # self.device, dtype=torch.bfloat16 + # ) + # self.clip_image_processor = CLIPImageProcessor() + # image proj model + self.image_proj_model = self.init_proj() + + self.load_ip_adapter() + + def init_proj(self): + image_proj_model = Resampler( + dim=1280, + # dim=768, + depth=4, + dim_head=64, + heads=20, + # heads=12, + num_queries=self.num_tokens, + # num_queries=257, + embedding_dim=5120, + output_dim=2048, + # output_dim=768, + ff_mult=4, + ).to(self.device, dtype=torch.bfloat16) + # Resampler( + # dim=1280, + # depth=4, + # dim_head=64, + # heads=20, + # num_queries=self.num_tokens, + # embedding_dim=self.image_encoder.config.hidden_size, + # output_dim=self.pipe.unet.config.cross_attention_dim, + # ff_mult=4, + # ).to(self.device, dtype=torch.bfloat16) + # print(self.num_tokens, self.image_encoder.config.hidden_size, self.pipe.unet.config.cross_attention_dim) + return image_proj_model + + def save_proj(self, path): + torch.save(self.image_proj_model.state_dict(), path) + + @torch.inference_mode() + def get_image_embeds(self, pil_image): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.bfloat16) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + @torch.inference_mode() + def get_image_embeds_feat(self, pil_image=None, clip_image_embeds=None): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.bfloat16) + # clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = clip_image_embeds + # self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + + def generate_from_vlm( + self, + image_prompt_embeds, + uncond_image_prompt_embeds, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = 1 + # if isinstance(pil_image, Image.Image) else len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + # image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + image_prompt_embeds = image_prompt_embeds[: , torch.randperm(image_prompt_embeds.shape[1]) ,:] + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + print(prompt_embeds.shape, image_prompt_embeds.shape) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + diff --git a/third_party/ip_adapter/ip_adapter_faceid.py b/third_party/ip_adapter/ip_adapter_faceid.py new file mode 100644 index 0000000000000000000000000000000000000000..fe98ad540648d429a3a227733c8607626fe0784e --- /dev/null +++ b/third_party/ip_adapter/ip_adapter_faceid.py @@ -0,0 +1,542 @@ +import os +from typing import List + +import torch +from diffusers import StableDiffusionPipeline +from diffusers.pipelines.controlnet import MultiControlNetModel +from PIL import Image +from safetensors import safe_open +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor +from .utils import is_torch2_available, get_generator + +USE_DAFAULT_ATTN = False # should be True for visualization_attnmap +if is_torch2_available() and (not USE_DAFAULT_ATTN): + from .attention_processor_faceid import ( + LoRAAttnProcessor2_0 as LoRAAttnProcessor, + ) + from .attention_processor_faceid import ( + LoRAIPAttnProcessor2_0 as LoRAIPAttnProcessor, + ) +else: + from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor +from .resampler import PerceiverAttention, FeedForward + + +class FacePerceiverResampler(torch.nn.Module): + def __init__( + self, + *, + dim=768, + depth=4, + dim_head=64, + heads=16, + embedding_dim=1280, + output_dim=768, + ff_mult=4, + ): + super().__init__() + + self.proj_in = torch.nn.Linear(embedding_dim, dim) + self.proj_out = torch.nn.Linear(dim, output_dim) + self.norm_out = torch.nn.LayerNorm(output_dim) + self.layers = torch.nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + torch.nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, latents, x): + x = self.proj_in(x) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + latents = self.proj_out(latents) + return self.norm_out(latents) + + +class MLPProjModel(torch.nn.Module): + def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + + self.proj = torch.nn.Sequential( + torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), + torch.nn.GELU(), + torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), + ) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, id_embeds): + x = self.proj(id_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.norm(x) + return x + + +class ProjPlusModel(torch.nn.Module): + def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + + self.proj = torch.nn.Sequential( + torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), + torch.nn.GELU(), + torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), + ) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + self.perceiver_resampler = FacePerceiverResampler( + dim=cross_attention_dim, + depth=4, + dim_head=64, + heads=cross_attention_dim // 64, + embedding_dim=clip_embeddings_dim, + output_dim=cross_attention_dim, + ff_mult=4, + ) + + def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0): + + x = self.proj(id_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.norm(x) + out = self.perceiver_resampler(x, clip_embeds) + if shortcut: + out = x + scale * out + return out + + +class IPAdapterFaceID: + def __init__(self, sd_pipe, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16): + self.device = device + self.ip_ckpt = ip_ckpt + self.lora_rank = lora_rank + self.num_tokens = num_tokens + self.torch_dtype = torch_dtype + + self.pipe = sd_pipe.to(self.device) + self.set_ip_adapter() + + # image proj model + self.image_proj_model = self.init_proj() + + self.load_ip_adapter() + + def init_proj(self): + image_proj_model = MLPProjModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + id_embeddings_dim=512, + num_tokens=self.num_tokens, + ).to(self.device, dtype=self.torch_dtype) + return image_proj_model + + def set_ip_adapter(self): + unet = self.pipe.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = LoRAAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank, + ).to(self.device, dtype=self.torch_dtype) + else: + attn_procs[name] = LoRAIPAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens, + ).to(self.device, dtype=self.torch_dtype) + unet.set_attn_processor(attn_procs) + + def load_ip_adapter(self): + if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(self.ip_ckpt, map_location="cpu") + self.image_proj_model.load_state_dict(state_dict["image_proj"]) + ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + ip_layers.load_state_dict(state_dict["ip_adapter"]) + + @torch.inference_mode() + def get_image_embeds(self, faceid_embeds): + + faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype) + image_prompt_embeds = self.image_proj_model(faceid_embeds) + uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds)) + return image_prompt_embeds, uncond_image_prompt_embeds + + def set_scale(self, scale): + for attn_processor in self.pipe.unet.attn_processors.values(): + if isinstance(attn_processor, LoRAIPAttnProcessor): + attn_processor.scale = scale + + def generate( + self, + faceid_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + guidance_scale=7.5, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + + num_prompts = faceid_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( + prompt, + device=self.device, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + +class IPAdapterFaceIDPlus: + def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16): + self.device = device + self.image_encoder_path = image_encoder_path + self.ip_ckpt = ip_ckpt + self.lora_rank = lora_rank + self.num_tokens = num_tokens + self.torch_dtype = torch_dtype + + self.pipe = sd_pipe.to(self.device) + self.set_ip_adapter() + + # load image encoder + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( + self.device, dtype=self.torch_dtype + ) + self.clip_image_processor = CLIPImageProcessor() + # image proj model + self.image_proj_model = self.init_proj() + + self.load_ip_adapter() + + def init_proj(self): + image_proj_model = ProjPlusModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + id_embeddings_dim=512, + clip_embeddings_dim=self.image_encoder.config.hidden_size, + num_tokens=self.num_tokens, + ).to(self.device, dtype=self.torch_dtype) + return image_proj_model + + def set_ip_adapter(self): + unet = self.pipe.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = LoRAAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank, + ).to(self.device, dtype=self.torch_dtype) + else: + attn_procs[name] = LoRAIPAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens, + ).to(self.device, dtype=self.torch_dtype) + unet.set_attn_processor(attn_procs) + + def load_ip_adapter(self): + if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(self.ip_ckpt, map_location="cpu") + self.image_proj_model.load_state_dict(state_dict["image_proj"]) + ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + ip_layers.load_state_dict(state_dict["ip_adapter"]) + + @torch.inference_mode() + def get_image_embeds(self, faceid_embeds, face_image, s_scale, shortcut): + if isinstance(face_image, Image.Image): + pil_image = [face_image] + clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=self.torch_dtype) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + + faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype) + image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale) + uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale) + return image_prompt_embeds, uncond_image_prompt_embeds + + def set_scale(self, scale): + for attn_processor in self.pipe.unet.attn_processors.values(): + if isinstance(attn_processor, LoRAIPAttnProcessor): + attn_processor.scale = scale + + def generate( + self, + face_image=None, + faceid_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + guidance_scale=7.5, + num_inference_steps=30, + s_scale=1.0, + shortcut=False, + **kwargs, + ): + self.set_scale(scale) + + + num_prompts = faceid_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( + prompt, + device=self.device, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + +class IPAdapterFaceIDXL(IPAdapterFaceID): + """SDXL""" + + def generate( + self, + faceid_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = faceid_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + +class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus): + """SDXL""" + + def generate( + self, + face_image=None, + faceid_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + guidance_scale=7.5, + num_inference_steps=30, + s_scale=1.0, + shortcut=True, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = faceid_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + guidance_scale=guidance_scale, + **kwargs, + ).images + + return images diff --git a/third_party/ip_adapter/ip_adapter_faceid_separate.py b/third_party/ip_adapter/ip_adapter_faceid_separate.py new file mode 100644 index 0000000000000000000000000000000000000000..7c34e7c3bfd009e69823e1a76b91fe423230cfb6 --- /dev/null +++ b/third_party/ip_adapter/ip_adapter_faceid_separate.py @@ -0,0 +1,556 @@ +import os +from typing import List + +import torch +from diffusers import StableDiffusionPipeline +from diffusers.pipelines.controlnet import MultiControlNetModel +from PIL import Image +from safetensors import safe_open +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from .utils import is_torch2_available, get_generator + +USE_DAFAULT_ATTN = False # should be True for visualization_attnmap +if is_torch2_available() and (not USE_DAFAULT_ATTN): + from .attention_processor import ( + AttnProcessor2_0 as AttnProcessor, + ) + from .attention_processor import ( + IPAttnProcessor2_0 as IPAttnProcessor, + ) +else: + from .attention_processor import AttnProcessor, IPAttnProcessor +from .resampler import PerceiverAttention, FeedForward + + +class FacePerceiverResampler(torch.nn.Module): + def __init__( + self, + *, + dim=768, + depth=4, + dim_head=64, + heads=16, + embedding_dim=1280, + output_dim=768, + ff_mult=4, + ): + super().__init__() + + self.proj_in = torch.nn.Linear(embedding_dim, dim) + self.proj_out = torch.nn.Linear(dim, output_dim) + self.norm_out = torch.nn.LayerNorm(output_dim) + self.layers = torch.nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + torch.nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, latents, x): + x = self.proj_in(x) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + latents = self.proj_out(latents) + return self.norm_out(latents) + + +class MLPProjModel(torch.nn.Module): + def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + + self.proj = torch.nn.Sequential( + torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), + torch.nn.GELU(), + torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), + ) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, id_embeds): + x = self.proj(id_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.norm(x) + return x + + +class ProjPlusModel(torch.nn.Module): + def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + + self.proj = torch.nn.Sequential( + torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), + torch.nn.GELU(), + torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), + ) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + self.perceiver_resampler = FacePerceiverResampler( + dim=cross_attention_dim, + depth=4, + dim_head=64, + heads=cross_attention_dim // 64, + embedding_dim=clip_embeddings_dim, + output_dim=cross_attention_dim, + ff_mult=4, + ) + + def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0): + + x = self.proj(id_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.norm(x) + out = self.perceiver_resampler(x, clip_embeds) + if shortcut: + out = x + scale * out + return out + + +class IPAdapterFaceID: + def __init__(self, sd_pipe, ip_ckpt, device, num_tokens=4, n_cond=1, torch_dtype=torch.float16): + self.device = device + self.ip_ckpt = ip_ckpt + self.num_tokens = num_tokens + self.n_cond = n_cond + self.torch_dtype = torch_dtype + + self.pipe = sd_pipe.to(self.device) + self.set_ip_adapter() + + # image proj model + self.image_proj_model = self.init_proj() + + self.load_ip_adapter() + + def init_proj(self): + image_proj_model = MLPProjModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + id_embeddings_dim=512, + num_tokens=self.num_tokens, + ).to(self.device, dtype=self.torch_dtype) + return image_proj_model + + def set_ip_adapter(self): + unet = self.pipe.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor() + else: + attn_procs[name] = IPAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, num_tokens=self.num_tokens*self.n_cond, + ).to(self.device, dtype=self.torch_dtype) + unet.set_attn_processor(attn_procs) + + def load_ip_adapter(self): + if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(self.ip_ckpt, map_location="cpu") + self.image_proj_model.load_state_dict(state_dict["image_proj"]) + ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False) + + @torch.inference_mode() + def get_image_embeds(self, faceid_embeds): + + multi_face = False + if faceid_embeds.dim() == 3: + multi_face = True + b, n, c = faceid_embeds.shape + faceid_embeds = faceid_embeds.reshape(b*n, c) + + faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype) + image_prompt_embeds = self.image_proj_model(faceid_embeds) + uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds)) + if multi_face: + c = image_prompt_embeds.size(-1) + image_prompt_embeds = image_prompt_embeds.reshape(b, -1, c) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.reshape(b, -1, c) + + return image_prompt_embeds, uncond_image_prompt_embeds + + def set_scale(self, scale): + for attn_processor in self.pipe.unet.attn_processors.values(): + if isinstance(attn_processor, IPAttnProcessor): + attn_processor.scale = scale + + def generate( + self, + faceid_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + guidance_scale=7.5, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = faceid_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + else: + faceid_embeds = faceid_embeds.repeat(num_samples, 1, 1) + num_samples = 1 + + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( + prompt, + device=self.device, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator=generator, + num_images_per_prompt=num_samples, + **kwargs, + ).images + + return images + + +class IPAdapterFaceIDPlus: + def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, torch_dtype=torch.float16): + self.device = device + self.image_encoder_path = image_encoder_path + self.ip_ckpt = ip_ckpt + self.num_tokens = num_tokens + self.torch_dtype = torch_dtype + + self.pipe = sd_pipe.to(self.device) + self.set_ip_adapter() + + # load image encoder + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( + self.device, dtype=self.torch_dtype + ) + self.clip_image_processor = CLIPImageProcessor() + # image proj model + self.image_proj_model = self.init_proj() + + self.load_ip_adapter() + + def init_proj(self): + image_proj_model = ProjPlusModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + id_embeddings_dim=512, + clip_embeddings_dim=self.image_encoder.config.hidden_size, + num_tokens=self.num_tokens, + ).to(self.device, dtype=self.torch_dtype) + return image_proj_model + + def set_ip_adapter(self): + unet = self.pipe.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor() + else: + attn_procs[name] = IPAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, num_tokens=self.num_tokens, + ).to(self.device, dtype=self.torch_dtype) + unet.set_attn_processor(attn_procs) + + def load_ip_adapter(self): + if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(self.ip_ckpt, map_location="cpu") + self.image_proj_model.load_state_dict(state_dict["image_proj"]) + ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False) + + @torch.inference_mode() + def get_image_embeds(self, faceid_embeds, face_image, s_scale, shortcut): + if isinstance(face_image, Image.Image): + pil_image = [face_image] + clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=self.torch_dtype) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + + faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype) + image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale) + uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale) + return image_prompt_embeds, uncond_image_prompt_embeds + + def set_scale(self, scale): + for attn_processor in self.pipe.unet.attn_processors.values(): + if isinstance(attn_processor, LoRAIPAttnProcessor): + attn_processor.scale = scale + + def generate( + self, + face_image=None, + faceid_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + guidance_scale=7.5, + num_inference_steps=30, + s_scale=1.0, + shortcut=False, + **kwargs, + ): + self.set_scale(scale) + + + num_prompts = faceid_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( + prompt, + device=self.device, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + +class IPAdapterFaceIDXL(IPAdapterFaceID): + """SDXL""" + + def generate( + self, + faceid_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = faceid_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + else: + faceid_embeds = faceid_embeds.repeat(num_samples, 1, 1) + num_samples = 1 + + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + num_images_per_prompt=num_samples, + **kwargs, + ).images + + return images + + +class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus): + """SDXL""" + + def generate( + self, + face_image=None, + faceid_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + guidance_scale=7.5, + num_inference_steps=30, + s_scale=1.0, + shortcut=True, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = faceid_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + guidance_scale=guidance_scale, + **kwargs, + ).images + + return images diff --git a/third_party/ip_adapter/ip_adapter_fp16.py b/third_party/ip_adapter/ip_adapter_fp16.py new file mode 100644 index 0000000000000000000000000000000000000000..32285e93192ba4062b965b0ad930f24268d991b5 --- /dev/null +++ b/third_party/ip_adapter/ip_adapter_fp16.py @@ -0,0 +1,747 @@ +import os +from typing import List + +import torch +from diffusers import StableDiffusionPipeline +from diffusers.pipelines.controlnet import MultiControlNetModel +from PIL import Image +from safetensors import safe_open +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from .utils import is_torch2_available, get_generator + +if is_torch2_available(): + from .attention_processor import ( + AttnProcessor2_0 as AttnProcessor, + ) + from .attention_processor import ( + CNAttnProcessor2_0 as CNAttnProcessor, + ) + from .attention_processor import ( + IPAttnProcessor2_0 as IPAttnProcessor, + ) +else: + from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor +from .resampler import Resampler + + +class ImageProjModel(torch.nn.Module): + """Projection Model""" + + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): + super().__init__() + + self.generator = None + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + embeds = image_embeds + clip_extra_context_tokens = self.proj(embeds).reshape( + -1, self.clip_extra_context_tokens, self.cross_attention_dim + ) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens + + +class MLPProjModel(torch.nn.Module): + """SD model with image prompt""" + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024): + super().__init__() + + self.proj = torch.nn.Sequential( + torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim), + torch.nn.GELU(), + torch.nn.Linear(clip_embeddings_dim, cross_attention_dim), + torch.nn.LayerNorm(cross_attention_dim) + ) + + def forward(self, image_embeds): + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens + + +class IPAdapter: + def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4): + self.device = device + self.image_encoder_path = image_encoder_path + self.ip_ckpt = ip_ckpt + self.num_tokens = num_tokens + + self.pipe = sd_pipe.to(self.device) + self.set_ip_adapter() + + # load image encoder + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( + self.device, dtype=torch.float16 + ) + self.clip_image_processor = CLIPImageProcessor() + # image proj model + self.image_proj_model = self.init_proj() + + self.load_ip_adapter() + print('fp16 ipadapter loaded') + + def init_proj(self): + image_proj_model = ImageProjModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + clip_embeddings_dim=self.image_encoder.config.projection_dim, + clip_extra_context_tokens=self.num_tokens, + ).to(self.device, dtype=torch.float16) + return image_proj_model + + def set_ip_adapter(self): + unet = self.pipe.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor() + else: + attn_procs[name] = IPAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=self.num_tokens, + ).to(self.device, dtype=torch.float16) + unet.set_attn_processor(attn_procs) + if hasattr(self.pipe, "controlnet"): + if isinstance(self.pipe.controlnet, MultiControlNetModel): + for controlnet in self.pipe.controlnet.nets: + controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) + else: + self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) + + def load_ip_adapter(self): + if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(self.ip_ckpt, map_location="cpu") + self.image_proj_model.load_state_dict(state_dict["image_proj"]) + ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False) + + @torch.inference_mode() + def get_image_embeds(self, pil_image=None, clip_image_embeds=None): + if pil_image is not None: + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds + else: + clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16) + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) + return image_prompt_embeds, uncond_image_prompt_embeds + + def set_scale(self, scale): + for attn_processor in self.pipe.unet.attn_processors.values(): + if isinstance(attn_processor, IPAttnProcessor): + attn_processor.scale = scale + + def generate( + self, + pil_image=None, + clip_image_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + guidance_scale=7.5, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + if pil_image is not None: + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + else: + num_prompts = clip_image_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( + pil_image=pil_image, clip_image_embeds=clip_image_embeds + ) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( + prompt, + device=self.device, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + +class IPAdapterXL(IPAdapter): + """SDXL""" + + def generate( + self, + pil_image, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + self.generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=self.generator, + **kwargs, + ).images + + return images + + +class IPAdapterPlus(IPAdapter): + """IP-Adapter with fine-grained features""" + + def init_proj(self): + image_proj_model = Resampler( + dim=self.pipe.unet.config.cross_attention_dim, + depth=4, + dim_head=64, + heads=12, + num_queries=self.num_tokens, + embedding_dim=self.image_encoder.config.hidden_size, + output_dim=self.pipe.unet.config.cross_attention_dim, + ff_mult=4, + ).to(self.device, dtype=torch.float16) + return image_proj_model + + def save_proj(self, path): + torch.save(self.image_proj_model.state_dict(), path) + + @torch.inference_mode() + def get_image_embeds(self, pil_image=None, clip_image_embeds=None): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.float16) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + def get_image_embeds_feat(self, pil_image=None, clip_image_embeds=None): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.float16) + # clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = clip_image_embeds + # self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + +class IPAdapterFull(IPAdapterPlus): + """IP-Adapter with full features""" + + def init_proj(self): + image_proj_model = MLPProjModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + clip_embeddings_dim=self.image_encoder.config.hidden_size, + ).to(self.device, dtype=torch.float16) + return image_proj_model + + +class IPAdapterPlusXL(IPAdapter): + """SDXL""" + + def init_proj(self): + image_proj_model = Resampler( + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=self.num_tokens, + embedding_dim=self.image_encoder.config.hidden_size, + output_dim=self.pipe.unet.config.cross_attention_dim, + ff_mult=4, + ).to(self.device, dtype=torch.float16) + print(self.num_tokens, self.image_encoder.config.hidden_size, self.pipe.unet.config.cross_attention_dim) + return image_proj_model + + def save_proj(self, path): + torch.save(self.image_proj_model.state_dict(), path) + + @torch.inference_mode() + def get_image_embeds(self, pil_image): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.float16) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + @torch.inference_mode() + def get_image_embeds_feat(self, pil_image=None, clip_image_embeds=None): + # if isinstance(pil_image, Image.Image): + # pil_image = [pil_image] + # clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + # clip_image = clip_image.to(self.device, dtype=torch.float16) + # print('clip_image shape', clip_image.shape) + # clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = clip_image_embeds + # self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros([clip_image_embeds.shape[0], 3, 224, 224]).to(device=clip_image_embeds.device, dtype=clip_image_embeds.dtype), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + # print(uncond_clip_image_embeds.shape, uncond_image_prompt_embeds.shape) + return image_prompt_embeds, uncond_image_prompt_embeds + + def generate( + self, + pil_image, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + image_prompt_embeds = image_prompt_embeds[: , torch.randperm(image_prompt_embeds.shape[1]) ,:] + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + # negative_prompt=negative_prompt, + ) + print(prompt_embeds.shape, image_prompt_embeds.shape) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + # generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=torch.Generator(self.device).manual_seed(seed), + **kwargs, + ).images + + return images + + def generate_from_feat( + self, + feat, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + # num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + num_prompts = len(feat) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds_feat(clip_image_embeds=feat) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + # image_prompt_embeds = image_prompt_embeds[: , torch.randperm(image_prompt_embeds.shape[1]) ,:] + # uncond_image_prompt_embeds = image_prompt_embeds + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + # negative_prompt=negative_prompt, + ) + # print(prompt_embeds.shape, image_prompt_embeds.shape, uncond_image_prompt_embeds.shape, negative_prompt_embeds.shape) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + # generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=torch.Generator().manual_seed(seed), + **kwargs, + ).images + + return images + + + def generate_dual( + self, + pil_image, + pil_image1, + improve=1.0, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) + image_prompt_embeds1, uncond_image_prompt_embeds1 = self.get_image_embeds(pil_image1) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + bs_embed, seq_len, _ = image_prompt_embeds1.shape + image_prompt_embeds1 = image_prompt_embeds1.repeat(1, num_samples, 1) + image_prompt_embeds1 = image_prompt_embeds1.view(bs_embed * num_samples, seq_len, -1) + # uncond_image_prompt_embeds1 = uncond_image_prompt_embeds1.repeat(1, num_samples, 1) + # uncond_image_prompt_embeds1 = uncond_image_prompt_embeds1.view(bs_embed * num_samples, seq_len, -1) + # image_prompt_embeds = image_prompt_embeds[: , torch.randperm(image_prompt_embeds.shape[1]) ,:] + # if improve < 1.0: + # pass + # else: + image_prompt_embeds = image_prompt_embeds1 + improve * (image_prompt_embeds - image_prompt_embeds1) + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + print(prompt_embeds.shape, image_prompt_embeds.shape) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + +class IPAdapterPlusXLVLM(IPAdapter): + """SDXL""" + def __init__(self, sd_pipe, ip_ckpt, device, num_tokens=16): + self.device = device + # self.image_encoder_path = image_encoder_path + self.ip_ckpt = ip_ckpt + self.num_tokens = num_tokens + + self.pipe = sd_pipe.to(self.device) + self.set_ip_adapter() + + # load image encoder + # self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( + # self.device, dtype=torch.float16 + # ) + # self.clip_image_processor = CLIPImageProcessor() + # image proj model + self.image_proj_model = self.init_proj() + + self.load_ip_adapter() + + def init_proj(self): + image_proj_model = Resampler( + dim=1280, + # dim=768, + depth=4, + dim_head=64, + heads=20, + # heads=12, + num_queries=self.num_tokens, + # num_queries=257, + embedding_dim=5120, + output_dim=2048, + # output_dim=768, + ff_mult=4, + ).to(self.device, dtype=torch.float16) + # Resampler( + # dim=1280, + # depth=4, + # dim_head=64, + # heads=20, + # num_queries=self.num_tokens, + # embedding_dim=self.image_encoder.config.hidden_size, + # output_dim=self.pipe.unet.config.cross_attention_dim, + # ff_mult=4, + # ).to(self.device, dtype=torch.float16) + # print(self.num_tokens, self.image_encoder.config.hidden_size, self.pipe.unet.config.cross_attention_dim) + return image_proj_model + + def save_proj(self, path): + torch.save(self.image_proj_model.state_dict(), path) + + @torch.inference_mode() + def get_image_embeds(self, pil_image): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.float16) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + @torch.inference_mode() + def get_image_embeds_feat(self, pil_image=None, clip_image_embeds=None): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.float16) + # clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = clip_image_embeds + # self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + + def generate_from_vlm( + self, + image_prompt_embeds, + uncond_image_prompt_embeds, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = 1 + # if isinstance(pil_image, Image.Image) else len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + # image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + image_prompt_embeds = image_prompt_embeds[: , torch.randperm(image_prompt_embeds.shape[1]) ,:] + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + print(prompt_embeds.shape, image_prompt_embeds.shape) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + diff --git a/third_party/ip_adapter/ip_adapter_fp16_backup.py b/third_party/ip_adapter/ip_adapter_fp16_backup.py new file mode 100644 index 0000000000000000000000000000000000000000..8cd4849db0d79c60d78e6dbeca70a7394f0dbd87 --- /dev/null +++ b/third_party/ip_adapter/ip_adapter_fp16_backup.py @@ -0,0 +1,745 @@ +import os +from typing import List + +import torch +from diffusers import StableDiffusionPipeline +from diffusers.pipelines.controlnet import MultiControlNetModel +from PIL import Image +from safetensors import safe_open +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from .utils import is_torch2_available, get_generator + +if is_torch2_available(): + from .attention_processor import ( + AttnProcessor2_0 as AttnProcessor, + ) + from .attention_processor import ( + CNAttnProcessor2_0 as CNAttnProcessor, + ) + from .attention_processor import ( + IPAttnProcessor2_0 as IPAttnProcessor, + ) +else: + from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor +from .resampler import Resampler + + +class ImageProjModel(torch.nn.Module): + """Projection Model""" + + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): + super().__init__() + + self.generator = None + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + embeds = image_embeds + clip_extra_context_tokens = self.proj(embeds).reshape( + -1, self.clip_extra_context_tokens, self.cross_attention_dim + ) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens + + +class MLPProjModel(torch.nn.Module): + """SD model with image prompt""" + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024): + super().__init__() + + self.proj = torch.nn.Sequential( + torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim), + torch.nn.GELU(), + torch.nn.Linear(clip_embeddings_dim, cross_attention_dim), + torch.nn.LayerNorm(cross_attention_dim) + ) + + def forward(self, image_embeds): + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens + + +class IPAdapter: + def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4): + self.device = device + self.image_encoder_path = image_encoder_path + self.ip_ckpt = ip_ckpt + self.num_tokens = num_tokens + + self.pipe = sd_pipe.to(self.device) + self.set_ip_adapter() + + # load image encoder + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( + self.device, dtype=torch.float16 + ) + self.clip_image_processor = CLIPImageProcessor() + # image proj model + self.image_proj_model = self.init_proj() + + self.load_ip_adapter() + + def init_proj(self): + image_proj_model = ImageProjModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + clip_embeddings_dim=self.image_encoder.config.projection_dim, + clip_extra_context_tokens=self.num_tokens, + ).to(self.device, dtype=torch.float16) + return image_proj_model + + def set_ip_adapter(self): + unet = self.pipe.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor() + else: + attn_procs[name] = IPAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=self.num_tokens, + ).to(self.device, dtype=torch.float16) + unet.set_attn_processor(attn_procs) + if hasattr(self.pipe, "controlnet"): + if isinstance(self.pipe.controlnet, MultiControlNetModel): + for controlnet in self.pipe.controlnet.nets: + controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) + else: + self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) + + def load_ip_adapter(self): + if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(self.ip_ckpt, map_location="cpu") + self.image_proj_model.load_state_dict(state_dict["image_proj"]) + ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False) + + @torch.inference_mode() + def get_image_embeds(self, pil_image=None, clip_image_embeds=None): + if pil_image is not None: + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds + else: + clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16) + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) + return image_prompt_embeds, uncond_image_prompt_embeds + + def set_scale(self, scale): + for attn_processor in self.pipe.unet.attn_processors.values(): + if isinstance(attn_processor, IPAttnProcessor): + attn_processor.scale = scale + + def generate( + self, + pil_image=None, + clip_image_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + guidance_scale=7.5, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + if pil_image is not None: + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + else: + num_prompts = clip_image_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( + pil_image=pil_image, clip_image_embeds=clip_image_embeds + ) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( + prompt, + device=self.device, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + +class IPAdapterXL(IPAdapter): + """SDXL""" + + def generate( + self, + pil_image, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + self.generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=self.generator, + **kwargs, + ).images + + return images + + +class IPAdapterPlus(IPAdapter): + """IP-Adapter with fine-grained features""" + + def init_proj(self): + image_proj_model = Resampler( + dim=self.pipe.unet.config.cross_attention_dim, + depth=4, + dim_head=64, + heads=12, + num_queries=self.num_tokens, + embedding_dim=self.image_encoder.config.hidden_size, + output_dim=self.pipe.unet.config.cross_attention_dim, + ff_mult=4, + ).to(self.device, dtype=torch.float16) + return image_proj_model + + def save_proj(self, path): + torch.save(self.image_proj_model.state_dict(), path) + + @torch.inference_mode() + def get_image_embeds(self, pil_image=None, clip_image_embeds=None): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.float16) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + def get_image_embeds_feat(self, pil_image=None, clip_image_embeds=None): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.float16) + # clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = clip_image_embeds + # self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + +class IPAdapterFull(IPAdapterPlus): + """IP-Adapter with full features""" + + def init_proj(self): + image_proj_model = MLPProjModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + clip_embeddings_dim=self.image_encoder.config.hidden_size, + ).to(self.device, dtype=torch.float16) + return image_proj_model + + +class IPAdapterPlusXL(IPAdapter): + """SDXL""" + + def init_proj(self): + image_proj_model = Resampler( + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=self.num_tokens, + embedding_dim=self.image_encoder.config.hidden_size, + output_dim=self.pipe.unet.config.cross_attention_dim, + ff_mult=4, + ).to(self.device, dtype=torch.float16) + print(self.num_tokens, self.image_encoder.config.hidden_size, self.pipe.unet.config.cross_attention_dim) + return image_proj_model + + def save_proj(self, path): + torch.save(self.image_proj_model.state_dict(), path) + + @torch.inference_mode() + def get_image_embeds(self, pil_image): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.float16) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + @torch.inference_mode() + def get_image_embeds_feat(self, pil_image=None, clip_image_embeds=None): + # if isinstance(pil_image, Image.Image): + # pil_image = [pil_image] + # clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + # clip_image = clip_image.to(self.device, dtype=torch.float16) + # print('clip_image shape', clip_image.shape) + # clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = clip_image_embeds + # self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros([clip_image_embeds.shape[0], 3, 224, 224]).to(device=clip_image_embeds.device, dtype=clip_image_embeds.dtype), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + # print(uncond_clip_image_embeds.shape, uncond_image_prompt_embeds.shape) + return image_prompt_embeds, uncond_image_prompt_embeds + + def generate( + self, + pil_image, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + image_prompt_embeds = image_prompt_embeds[: , torch.randperm(image_prompt_embeds.shape[1]) ,:] + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + print(prompt_embeds.shape, image_prompt_embeds.shape) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + def generate_from_feat( + self, + feat, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + # num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + num_prompts = len(feat) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds_feat(clip_image_embeds=feat) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + image_prompt_embeds = image_prompt_embeds[: , torch.randperm(image_prompt_embeds.shape[1]) ,:] + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + # print(prompt_embeds.shape, image_prompt_embeds.shape, uncond_image_prompt_embeds.shape, negative_prompt_embeds.shape) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + + def generate_dual( + self, + pil_image, + pil_image1, + improve=1.0, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) + image_prompt_embeds1, uncond_image_prompt_embeds1 = self.get_image_embeds(pil_image1) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + bs_embed, seq_len, _ = image_prompt_embeds1.shape + image_prompt_embeds1 = image_prompt_embeds1.repeat(1, num_samples, 1) + image_prompt_embeds1 = image_prompt_embeds1.view(bs_embed * num_samples, seq_len, -1) + # uncond_image_prompt_embeds1 = uncond_image_prompt_embeds1.repeat(1, num_samples, 1) + # uncond_image_prompt_embeds1 = uncond_image_prompt_embeds1.view(bs_embed * num_samples, seq_len, -1) + # image_prompt_embeds = image_prompt_embeds[: , torch.randperm(image_prompt_embeds.shape[1]) ,:] + # if improve < 1.0: + # pass + # else: + image_prompt_embeds = image_prompt_embeds1 + improve * (image_prompt_embeds - image_prompt_embeds1) + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + print(prompt_embeds.shape, image_prompt_embeds.shape) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + +class IPAdapterPlusXLVLM(IPAdapter): + """SDXL""" + def __init__(self, sd_pipe, ip_ckpt, device, num_tokens=16): + self.device = device + # self.image_encoder_path = image_encoder_path + self.ip_ckpt = ip_ckpt + self.num_tokens = num_tokens + + self.pipe = sd_pipe.to(self.device) + self.set_ip_adapter() + + # load image encoder + # self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( + # self.device, dtype=torch.float16 + # ) + # self.clip_image_processor = CLIPImageProcessor() + # image proj model + self.image_proj_model = self.init_proj() + + self.load_ip_adapter() + + def init_proj(self): + image_proj_model = Resampler( + dim=1280, + # dim=768, + depth=4, + dim_head=64, + heads=20, + # heads=12, + num_queries=self.num_tokens, + # num_queries=257, + embedding_dim=5120, + output_dim=2048, + # output_dim=768, + ff_mult=4, + ).to(self.device, dtype=torch.float16) + # Resampler( + # dim=1280, + # depth=4, + # dim_head=64, + # heads=20, + # num_queries=self.num_tokens, + # embedding_dim=self.image_encoder.config.hidden_size, + # output_dim=self.pipe.unet.config.cross_attention_dim, + # ff_mult=4, + # ).to(self.device, dtype=torch.float16) + # print(self.num_tokens, self.image_encoder.config.hidden_size, self.pipe.unet.config.cross_attention_dim) + return image_proj_model + + def save_proj(self, path): + torch.save(self.image_proj_model.state_dict(), path) + + @torch.inference_mode() + def get_image_embeds(self, pil_image): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.float16) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + @torch.inference_mode() + def get_image_embeds_feat(self, pil_image=None, clip_image_embeds=None): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.float16) + # clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = clip_image_embeds + # self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + + def generate_from_vlm( + self, + image_prompt_embeds, + uncond_image_prompt_embeds, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = 1 + # if isinstance(pil_image, Image.Image) else len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + # image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + image_prompt_embeds = image_prompt_embeds[: , torch.randperm(image_prompt_embeds.shape[1]) ,:] + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + print(prompt_embeds.shape, image_prompt_embeds.shape) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + diff --git a/third_party/ip_adapter/ip_adapter_old.py b/third_party/ip_adapter/ip_adapter_old.py new file mode 100644 index 0000000000000000000000000000000000000000..6a1d602a87feff5f66e5949396ffbf4f0bde39e5 --- /dev/null +++ b/third_party/ip_adapter/ip_adapter_old.py @@ -0,0 +1,813 @@ +import os +from typing import List + +import torch +from diffusers import StableDiffusionPipeline +from diffusers.pipelines.controlnet import MultiControlNetModel +from PIL import Image +from safetensors import safe_open +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from .utils import is_torch2_available, get_generator + +if is_torch2_available(): + from .attention_processor import ( + AttnProcessor2_0 as AttnProcessor, + ) + from .attention_processor import ( + CNAttnProcessor2_0 as CNAttnProcessor, + ) + from .attention_processor import ( + IPAttnProcessor2_0 as IPAttnProcessor, + ) +else: + from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor +from .resampler import Resampler + + +class ImageProjModel(torch.nn.Module): + """Projection Model""" + + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): + super().__init__() + + self.generator = None + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + embeds = image_embeds + clip_extra_context_tokens = self.proj(embeds).reshape( + -1, self.clip_extra_context_tokens, self.cross_attention_dim + ) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens + + +class MLPProjModel(torch.nn.Module): + """SD model with image prompt""" + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024): + super().__init__() + + self.proj = torch.nn.Sequential( + torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim), + torch.nn.GELU(), + torch.nn.Linear(clip_embeddings_dim, cross_attention_dim), + torch.nn.LayerNorm(cross_attention_dim) + ) + + def forward(self, image_embeds): + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens + + +class IPAdapter: + def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4): + self.device = device + self.image_encoder_path = image_encoder_path + self.ip_ckpt = ip_ckpt + self.num_tokens = num_tokens + + self.pipe = sd_pipe.to(self.device) + self.set_ip_adapter() + + # load image encoder + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( + self.device, dtype=torch.bfloat16 + ) + self.clip_image_processor = CLIPImageProcessor() + # image proj model + self.image_proj_model = self.init_proj() + + self.load_ip_adapter() + + def init_proj(self): + image_proj_model = ImageProjModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + clip_embeddings_dim=self.image_encoder.config.projection_dim, + clip_extra_context_tokens=self.num_tokens, + ).to(self.device, dtype=torch.bfloat16) + return image_proj_model + + def set_ip_adapter(self): + unet = self.pipe.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor() + else: + attn_procs[name] = IPAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=self.num_tokens, + ).to(self.device, dtype=torch.bfloat16) + unet.set_attn_processor(attn_procs) + if hasattr(self.pipe, "controlnet"): + if isinstance(self.pipe.controlnet, MultiControlNetModel): + for controlnet in self.pipe.controlnet.nets: + controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) + else: + self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) + + def load_ip_adapter(self): + if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(self.ip_ckpt, map_location="cpu") + self.image_proj_model.load_state_dict(state_dict["image_proj"]) + ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False) + + @torch.inference_mode() + def get_image_embeds(self, pil_image=None, clip_image_embeds=None): + if pil_image is not None: + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.bfloat16)).image_embeds + else: + clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.bfloat16) + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) + return image_prompt_embeds, uncond_image_prompt_embeds + + def set_scale(self, scale): + for attn_processor in self.pipe.unet.attn_processors.values(): + if isinstance(attn_processor, IPAttnProcessor): + attn_processor.scale = scale + + def generate( + self, + pil_image=None, + clip_image_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + guidance_scale=7.5, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + if pil_image is not None: + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + else: + num_prompts = clip_image_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( + pil_image=pil_image, clip_image_embeds=clip_image_embeds + ) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( + prompt, + device=self.device, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + +class IPAdapterXL(IPAdapter): + """SDXL""" + + def generate( + self, + pil_image, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + self.generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=self.generator, + **kwargs, + ).images + + return images + + +class IPAdapterPlus(IPAdapter): + """IP-Adapter with fine-grained features""" + + def init_proj(self): + image_proj_model = Resampler( + dim=self.pipe.unet.config.cross_attention_dim, + depth=4, + dim_head=64, + heads=12, + num_queries=self.num_tokens, + embedding_dim=self.image_encoder.config.hidden_size, + output_dim=self.pipe.unet.config.cross_attention_dim, + ff_mult=4, + ).to(self.device, dtype=torch.bfloat16) + return image_proj_model + + def save_proj(self, path): + torch.save(self.image_proj_model.state_dict(), path) + + @torch.inference_mode() + def get_image_embeds(self, pil_image=None, clip_image_embeds=None): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.bfloat16) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + def get_image_embeds_feat(self, pil_image=None, clip_image_embeds=None): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.bfloat16) + # clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = clip_image_embeds + # self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + +class IPAdapterFull(IPAdapterPlus): + """IP-Adapter with full features""" + + def init_proj(self): + image_proj_model = MLPProjModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + clip_embeddings_dim=self.image_encoder.config.hidden_size, + ).to(self.device, dtype=torch.bfloat16) + return image_proj_model + + +class IPAdapterPlusXL(IPAdapter): + """SDXL""" + + def init_proj(self): + image_proj_model = Resampler( + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=self.num_tokens, + embedding_dim=self.image_encoder.config.hidden_size, + output_dim=self.pipe.unet.config.cross_attention_dim, + ff_mult=4, + ).to(self.device, dtype=torch.bfloat16) + print(self.num_tokens, self.image_encoder.config.hidden_size, self.pipe.unet.config.cross_attention_dim) + return image_proj_model + + def save_proj(self, path): + torch.save(self.image_proj_model.state_dict(), path) + + @torch.inference_mode() + def get_image_embeds(self, pil_image): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.bfloat16) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + @torch.inference_mode() + def get_image_embeds_feat(self, pil_image=None, clip_image_embeds=None): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.bfloat16) + # clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = clip_image_embeds + # self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + def generate( + self, + pil_image, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + image_prompt_embeds = image_prompt_embeds[: , torch.randperm(image_prompt_embeds.shape[1]) ,:] + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + print(prompt_embeds.shape, image_prompt_embeds.shape) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + def generate_from_feat( + self, + pil_image, + feat, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + # num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + num_prompts = len(feat) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds_feat(pil_image=pil_image, clip_image_embeds=feat) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + image_prompt_embeds = image_prompt_embeds[: , torch.randperm(image_prompt_embeds.shape[1]) ,:] + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + print(prompt_embeds.shape, image_prompt_embeds.shape) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + def generate_from_feat_and_image( + self, + pil_image, + feat, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + source_image=None, + strength=0.5, + **kwargs, + ): + self.set_scale(scale) + + # num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + num_prompts = len(feat) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds_feat(pil_image=pil_image, clip_image_embeds=feat) + + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + image_prompt_embeds = image_prompt_embeds[: , torch.randperm(image_prompt_embeds.shape[1]) ,:] + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + print(prompt_embeds.shape, image_prompt_embeds.shape) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + image=source_image, + strength=strength, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + + def generate_dual( + self, + pil_image, + pil_image1, + improve=1.0, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) + image_prompt_embeds1, uncond_image_prompt_embeds1 = self.get_image_embeds(pil_image1) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + bs_embed, seq_len, _ = image_prompt_embeds1.shape + image_prompt_embeds1 = image_prompt_embeds1.repeat(1, num_samples, 1) + image_prompt_embeds1 = image_prompt_embeds1.view(bs_embed * num_samples, seq_len, -1) + # uncond_image_prompt_embeds1 = uncond_image_prompt_embeds1.repeat(1, num_samples, 1) + # uncond_image_prompt_embeds1 = uncond_image_prompt_embeds1.view(bs_embed * num_samples, seq_len, -1) + # image_prompt_embeds = image_prompt_embeds[: , torch.randperm(image_prompt_embeds.shape[1]) ,:] + # if improve < 1.0: + # pass + # else: + image_prompt_embeds = image_prompt_embeds1 + improve * (image_prompt_embeds - image_prompt_embeds1) + + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + print(prompt_embeds.shape, image_prompt_embeds.shape) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + +class IPAdapterPlusXLVLM(IPAdapter): + """SDXL""" + def __init__(self, sd_pipe, ip_ckpt, device, num_tokens=16): + self.device = device + # self.image_encoder_path = image_encoder_path + self.ip_ckpt = ip_ckpt + self.num_tokens = num_tokens + + self.pipe = sd_pipe.to(self.device) + self.set_ip_adapter() + + # load image encoder + # self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( + # self.device, dtype=torch.bfloat16 + # ) + # self.clip_image_processor = CLIPImageProcessor() + # image proj model + self.image_proj_model = self.init_proj() + + self.load_ip_adapter() + + def init_proj(self): + image_proj_model = Resampler( + dim=1280, + # dim=768, + depth=4, + dim_head=64, + heads=20, + # heads=12, + num_queries=self.num_tokens, + # num_queries=257, + embedding_dim=5120, + output_dim=2048, + # output_dim=768, + ff_mult=4, + ).to(self.device, dtype=torch.bfloat16) + # Resampler( + # dim=1280, + # depth=4, + # dim_head=64, + # heads=20, + # num_queries=self.num_tokens, + # embedding_dim=self.image_encoder.config.hidden_size, + # output_dim=self.pipe.unet.config.cross_attention_dim, + # ff_mult=4, + # ).to(self.device, dtype=torch.bfloat16) + # print(self.num_tokens, self.image_encoder.config.hidden_size, self.pipe.unet.config.cross_attention_dim) + return image_proj_model + + def save_proj(self, path): + torch.save(self.image_proj_model.state_dict(), path) + + @torch.inference_mode() + def get_image_embeds(self, pil_image): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.bfloat16) + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + @torch.inference_mode() + def get_image_embeds_feat(self, pil_image=None, clip_image_embeds=None): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.bfloat16) + # clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + image_prompt_embeds = clip_image_embeds + # self.image_proj_model(clip_image_embeds) + uncond_clip_image_embeds = self.image_encoder( + torch.zeros_like(clip_image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) + return image_prompt_embeds, uncond_image_prompt_embeds + + + def generate_from_vlm( + self, + image_prompt_embeds, + uncond_image_prompt_embeds, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + num_prompts = 1 + # if isinstance(pil_image, Image.Image) else len(pil_image) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + # image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + image_prompt_embeds = image_prompt_embeds[: , torch.randperm(image_prompt_embeds.shape[1]) ,:] + with torch.inference_mode(): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + print(prompt_embeds.shape, image_prompt_embeds.shape) + prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + num_inference_steps=num_inference_steps, + generator=generator, + **kwargs, + ).images + + return images + + diff --git a/third_party/ip_adapter/resampler.py b/third_party/ip_adapter/resampler.py new file mode 100644 index 0000000000000000000000000000000000000000..24266671d02092438ae6576336a59659fef9c054 --- /dev/null +++ b/third_party/ip_adapter/resampler.py @@ -0,0 +1,158 @@ +# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py + +import math + +import torch +import torch.nn as nn +from einops import rearrange +from einops.layers.torch import Rearrange + + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class Resampler(nn.Module): + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + max_seq_len: int = 257, # CLIP tokens + CLS token + apply_pos_emb: bool = False, + num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence + ): + super().__init__() + self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.to_latents_from_mean_pooled_seq = ( + nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, dim * num_latents_mean_pooled), + Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled), + ) + if num_latents_mean_pooled > 0 + else None + ) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, x): + if self.pos_emb is not None: + n, device = x.shape[1], x.device + pos_emb = self.pos_emb(torch.arange(n, device=device)) + x = x + pos_emb + + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + if self.to_latents_from_mean_pooled_seq: + meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)) + meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq) + latents = torch.cat((meanpooled_latents, latents), dim=-2) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) + + +def masked_mean(t, *, dim, mask=None): + if mask is None: + return t.mean(dim=dim) + + denom = mask.sum(dim=dim, keepdim=True) + mask = rearrange(mask, "b n -> b n 1") + masked_t = t.masked_fill(~mask, 0.0) + + return masked_t.sum(dim=dim) / denom.clamp(min=1e-5) diff --git a/third_party/ip_adapter/sd3_attention_processor.py b/third_party/ip_adapter/sd3_attention_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..32b8ff7b6a215f4611f3dbdafaba6fd7e55e782a --- /dev/null +++ b/third_party/ip_adapter/sd3_attention_processor.py @@ -0,0 +1,179 @@ +from typing import Callable, List, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn +from diffusers.models.attention_processor import Attention + + +class JointAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + context_input_ndim = encoder_hidden_states.ndim + if context_input_ndim == 4: + batch_size, channel, height, width = encoder_hidden_states.shape + encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size = encoder_hidden_states.shape[0] + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + # attention + query = torch.cat([query, encoder_hidden_states_query_proj], dim=1) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=1) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # Split the attention outputs. + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + if context_input_ndim == 4: + encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + return hidden_states, encoder_hidden_states + + +class IPJointAttnProcessor2_0(torch.nn.Module): + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self, context_dim, hidden_dim, scale=1.0): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + super().__init__() + self.scale = scale + + self.add_k_proj_ip = nn.Linear(context_dim, hidden_dim) + self.add_v_proj_ip = nn.Linear(context_dim, hidden_dim) + + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + ip_hidden_states: torch.FloatTensor = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + context_input_ndim = encoder_hidden_states.ndim + if context_input_ndim == 4: + batch_size, channel, height, width = encoder_hidden_states.shape + encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size = encoder_hidden_states.shape[0] + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + sample_query = query # latent query + + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + # attention + query = torch.cat([query, encoder_hidden_states_query_proj], dim=1) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=1) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # Split the attention outputs. + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + + # for ip-adapter + ip_key = self.add_k_proj_ip(ip_hidden_states) + ip_value = self.add_v_proj_ip(ip_hidden_states) + ip_query = sample_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + ip_hidden_states = F.scaled_dot_product_attention(ip_query, ip_key, ip_value, dropout_p=0.0, is_causal=False) + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(ip_query.dtype) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + if context_input_ndim == 4: + encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + return hidden_states, encoder_hidden_states + diff --git a/third_party/ip_adapter/test_resampler.py b/third_party/ip_adapter/test_resampler.py new file mode 100644 index 0000000000000000000000000000000000000000..8978c8e19c0f6326fc849930086253db53a8a17b --- /dev/null +++ b/third_party/ip_adapter/test_resampler.py @@ -0,0 +1,44 @@ +import torch +from resampler import Resampler +from transformers import CLIPVisionModel + +BATCH_SIZE = 2 +OUTPUT_DIM = 1280 +NUM_QUERIES = 8 +NUM_LATENTS_MEAN_POOLED = 4 # 0 for no mean pooling (previous behavior) +APPLY_POS_EMB = True # False for no positional embeddings (previous behavior) +IMAGE_ENCODER_NAME_OR_PATH = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" + + +def main(): + image_encoder = CLIPVisionModel.from_pretrained(IMAGE_ENCODER_NAME_OR_PATH) + embedding_dim = image_encoder.config.hidden_size + print(f"image_encoder hidden size: ", embedding_dim) + + image_proj_model = Resampler( + dim=1024, + depth=2, + dim_head=64, + heads=16, + num_queries=NUM_QUERIES, + embedding_dim=embedding_dim, + output_dim=OUTPUT_DIM, + ff_mult=2, + max_seq_len=257, + apply_pos_emb=APPLY_POS_EMB, + num_latents_mean_pooled=NUM_LATENTS_MEAN_POOLED, + ) + + dummy_images = torch.randn(BATCH_SIZE, 3, 224, 224) + with torch.no_grad(): + image_embeds = image_encoder(dummy_images, output_hidden_states=True).hidden_states[-2] + print("image_embds shape: ", image_embeds.shape) + + with torch.no_grad(): + ip_tokens = image_proj_model(image_embeds) + print("ip_tokens shape:", ip_tokens.shape) + assert ip_tokens.shape == (BATCH_SIZE, NUM_QUERIES + NUM_LATENTS_MEAN_POOLED, OUTPUT_DIM) + + +if __name__ == "__main__": + main() diff --git a/third_party/ip_adapter/utils.py b/third_party/ip_adapter/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6a273358585962fdf383d0bb7a0e1c654b4999b8 --- /dev/null +++ b/third_party/ip_adapter/utils.py @@ -0,0 +1,93 @@ +import torch +import torch.nn.functional as F +import numpy as np +from PIL import Image + +attn_maps = {} +def hook_fn(name): + def forward_hook(module, input, output): + if hasattr(module.processor, "attn_map"): + attn_maps[name] = module.processor.attn_map + del module.processor.attn_map + + return forward_hook + +def register_cross_attention_hook(unet): + for name, module in unet.named_modules(): + if name.split('.')[-1].startswith('attn2'): + module.register_forward_hook(hook_fn(name)) + + return unet + +def upscale(attn_map, target_size): + attn_map = torch.mean(attn_map, dim=0) + attn_map = attn_map.permute(1,0) + temp_size = None + + for i in range(0,5): + scale = 2 ** i + if ( target_size[0] // scale ) * ( target_size[1] // scale) == attn_map.shape[1]*64: + temp_size = (target_size[0]//(scale*8), target_size[1]//(scale*8)) + break + + assert temp_size is not None, "temp_size cannot is None" + + attn_map = attn_map.view(attn_map.shape[0], *temp_size) + + attn_map = F.interpolate( + attn_map.unsqueeze(0).to(dtype=torch.float32), + size=target_size, + mode='bilinear', + align_corners=False + )[0] + + attn_map = torch.softmax(attn_map, dim=0) + return attn_map +def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True): + + idx = 0 if instance_or_negative else 1 + net_attn_maps = [] + + for name, attn_map in attn_maps.items(): + attn_map = attn_map.cpu() if detach else attn_map + attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze() + attn_map = upscale(attn_map, image_size) + net_attn_maps.append(attn_map) + + net_attn_maps = torch.mean(torch.stack(net_attn_maps,dim=0),dim=0) + + return net_attn_maps + +def attnmaps2images(net_attn_maps): + + #total_attn_scores = 0 + images = [] + + for attn_map in net_attn_maps: + attn_map = attn_map.cpu().numpy() + #total_attn_scores += attn_map.mean().item() + + normalized_attn_map = (attn_map - np.min(attn_map)) / (np.max(attn_map) - np.min(attn_map)) * 255 + normalized_attn_map = normalized_attn_map.astype(np.uint8) + #print("norm: ", normalized_attn_map.shape) + image = Image.fromarray(normalized_attn_map) + + #image = fix_save_attn_map(attn_map) + images.append(image) + + #print(total_attn_scores) + return images +def is_torch2_available(): + return hasattr(F, "scaled_dot_product_attention") + +def get_generator(seed, device): + + if seed is not None: + if isinstance(seed, list): + generator = [torch.Generator(device).manual_seed(seed_item) for seed_item in seed] + else: + generator = torch.Generator(device).manual_seed(seed) + else: + generator = None + + return generator \ No newline at end of file diff --git a/third_party/open_clip/__init__.py b/third_party/open_clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c328ed24f54803a32e10f712a540fff59ef50175 --- /dev/null +++ b/third_party/open_clip/__init__.py @@ -0,0 +1,14 @@ +from .coca_model import CoCa +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss +from .factory import list_models, add_model_config, get_model_config, load_checkpoint +from .loss import ClipLoss, DistillClipLoss, CoCaLoss +from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ + convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype +from .openai import load_openai_model, list_openai_models +from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ + get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained +from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub +from .tokenizer import SimpleTokenizer, tokenize, decode +from .transform import image_transform, AugmentationCfg +from .utils import freeze_batch_norm_2d diff --git a/third_party/open_clip/__pycache__/__init__.cpython-310.pyc b/third_party/open_clip/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d24fecd1ce1ccb0debbc60f120f7a9eb5e09a0a9 Binary files /dev/null and b/third_party/open_clip/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/open_clip/__pycache__/__init__.cpython-38.pyc b/third_party/open_clip/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5b97d1bac8e74c7a9b58c9621204252793f91c5 Binary files /dev/null and b/third_party/open_clip/__pycache__/__init__.cpython-38.pyc differ diff --git a/third_party/open_clip/__pycache__/__init__.cpython-39.pyc b/third_party/open_clip/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8af96bdc26440c65338813769d8b5ae65b692b5 Binary files /dev/null and b/third_party/open_clip/__pycache__/__init__.cpython-39.pyc differ diff --git a/third_party/open_clip/__pycache__/coca_model.cpython-310.pyc b/third_party/open_clip/__pycache__/coca_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93f118af6eaf22070b70e8fa03f4645ee04407bc Binary files /dev/null and b/third_party/open_clip/__pycache__/coca_model.cpython-310.pyc differ diff --git a/third_party/open_clip/__pycache__/coca_model.cpython-38.pyc b/third_party/open_clip/__pycache__/coca_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e9922d5e747ce3e2b29983953f00357bbbc9ea5 Binary files /dev/null and b/third_party/open_clip/__pycache__/coca_model.cpython-38.pyc differ diff --git a/third_party/open_clip/__pycache__/coca_model.cpython-39.pyc b/third_party/open_clip/__pycache__/coca_model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59f7b895897ce7c0b11304262a9f925b4398bec0 Binary files /dev/null and b/third_party/open_clip/__pycache__/coca_model.cpython-39.pyc differ diff --git a/third_party/open_clip/__pycache__/constants.cpython-310.pyc b/third_party/open_clip/__pycache__/constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..715b3943ccd116efa137b7134e2d067e0a368dff Binary files /dev/null and b/third_party/open_clip/__pycache__/constants.cpython-310.pyc differ diff --git a/third_party/open_clip/__pycache__/constants.cpython-38.pyc b/third_party/open_clip/__pycache__/constants.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..291a24fd997b6c2af8c29782de8351939ac32bff Binary files /dev/null and b/third_party/open_clip/__pycache__/constants.cpython-38.pyc differ diff --git a/third_party/open_clip/__pycache__/constants.cpython-39.pyc b/third_party/open_clip/__pycache__/constants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f6af69b3be21c451a2dbbe69deefb2bd3da555c Binary files /dev/null and b/third_party/open_clip/__pycache__/constants.cpython-39.pyc differ diff --git a/third_party/open_clip/__pycache__/factory.cpython-310.pyc b/third_party/open_clip/__pycache__/factory.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0f1c581290d153138ad21adf0174fcc6321abd8 Binary files /dev/null and b/third_party/open_clip/__pycache__/factory.cpython-310.pyc differ diff --git a/third_party/open_clip/__pycache__/factory.cpython-38.pyc b/third_party/open_clip/__pycache__/factory.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55576583820f1a2cd6b83328a633e8a04ae5aa5c Binary files /dev/null and b/third_party/open_clip/__pycache__/factory.cpython-38.pyc differ diff --git a/third_party/open_clip/__pycache__/factory.cpython-39.pyc b/third_party/open_clip/__pycache__/factory.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93c7883a9ae21b85152ffc883bbc307ab59e6db6 Binary files /dev/null and b/third_party/open_clip/__pycache__/factory.cpython-39.pyc differ diff --git a/third_party/open_clip/__pycache__/hf_configs.cpython-310.pyc b/third_party/open_clip/__pycache__/hf_configs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6c85eb6120611ab306968e62d68cac095368233 Binary files /dev/null and b/third_party/open_clip/__pycache__/hf_configs.cpython-310.pyc differ diff --git a/third_party/open_clip/__pycache__/hf_configs.cpython-38.pyc b/third_party/open_clip/__pycache__/hf_configs.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a2412e692afb63cd03cdb1b1fe3ea9e196a82d1 Binary files /dev/null and b/third_party/open_clip/__pycache__/hf_configs.cpython-38.pyc differ diff --git a/third_party/open_clip/__pycache__/hf_configs.cpython-39.pyc b/third_party/open_clip/__pycache__/hf_configs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b769dd7b68d9d293d0045111b1effbd58e972f1 Binary files /dev/null and b/third_party/open_clip/__pycache__/hf_configs.cpython-39.pyc differ diff --git a/third_party/open_clip/__pycache__/hf_model.cpython-310.pyc b/third_party/open_clip/__pycache__/hf_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33fc7751cee8b6bc2a1279fbf94313f9e1291793 Binary files /dev/null and b/third_party/open_clip/__pycache__/hf_model.cpython-310.pyc differ diff --git a/third_party/open_clip/__pycache__/hf_model.cpython-38.pyc b/third_party/open_clip/__pycache__/hf_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dd540de284ae4ed7818945b953732707288e307 Binary files /dev/null and b/third_party/open_clip/__pycache__/hf_model.cpython-38.pyc differ diff --git a/third_party/open_clip/__pycache__/hf_model.cpython-39.pyc b/third_party/open_clip/__pycache__/hf_model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8cc28d69acf45a4f826ba1d606ab07d76734c0e Binary files /dev/null and b/third_party/open_clip/__pycache__/hf_model.cpython-39.pyc differ diff --git a/third_party/open_clip/__pycache__/loss.cpython-310.pyc b/third_party/open_clip/__pycache__/loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebecae0f3c9a3eba9de584003cf7f04822a92ced Binary files /dev/null and b/third_party/open_clip/__pycache__/loss.cpython-310.pyc differ diff --git a/third_party/open_clip/__pycache__/loss.cpython-38.pyc b/third_party/open_clip/__pycache__/loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6779c3cbc1f263c6e093e0530da0a164f5b1b6bb Binary files /dev/null and b/third_party/open_clip/__pycache__/loss.cpython-38.pyc differ diff --git a/third_party/open_clip/__pycache__/loss.cpython-39.pyc b/third_party/open_clip/__pycache__/loss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1aaec0ecb1800d1a62c4115da0b550a10980b5c Binary files /dev/null and b/third_party/open_clip/__pycache__/loss.cpython-39.pyc differ diff --git a/third_party/open_clip/__pycache__/model.cpython-310.pyc b/third_party/open_clip/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e283455e4fa119d14e6158d60d77a3303d1b062e Binary files /dev/null and b/third_party/open_clip/__pycache__/model.cpython-310.pyc differ diff --git a/third_party/open_clip/__pycache__/model.cpython-38.pyc b/third_party/open_clip/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4aadde3139c403d2c4668d6ab2752c17628a3dd9 Binary files /dev/null and b/third_party/open_clip/__pycache__/model.cpython-38.pyc differ diff --git a/third_party/open_clip/__pycache__/model.cpython-39.pyc b/third_party/open_clip/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e95cf754179067db3e71a57be0dfd9b49d513e50 Binary files /dev/null and b/third_party/open_clip/__pycache__/model.cpython-39.pyc differ diff --git a/third_party/open_clip/__pycache__/modified_resnet.cpython-310.pyc b/third_party/open_clip/__pycache__/modified_resnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a226fbca9cffae3e8c16bd58446f477cbe90fb94 Binary files /dev/null and b/third_party/open_clip/__pycache__/modified_resnet.cpython-310.pyc differ diff --git a/third_party/open_clip/__pycache__/modified_resnet.cpython-38.pyc b/third_party/open_clip/__pycache__/modified_resnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05d8541fd805e07c0566548f2d0e903fc8e27082 Binary files /dev/null and b/third_party/open_clip/__pycache__/modified_resnet.cpython-38.pyc differ diff --git a/third_party/open_clip/__pycache__/modified_resnet.cpython-39.pyc b/third_party/open_clip/__pycache__/modified_resnet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffedf95fcde8dbe04c9314899ad532a96cd6841a Binary files /dev/null and b/third_party/open_clip/__pycache__/modified_resnet.cpython-39.pyc differ diff --git a/third_party/open_clip/__pycache__/openai.cpython-310.pyc b/third_party/open_clip/__pycache__/openai.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..217965522eabef20f1bae857e9fe13ae67cc41b9 Binary files /dev/null and b/third_party/open_clip/__pycache__/openai.cpython-310.pyc differ diff --git a/third_party/open_clip/__pycache__/openai.cpython-38.pyc b/third_party/open_clip/__pycache__/openai.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3042913d80ea4eecc14a6ac9be97ca460ef669e3 Binary files /dev/null and b/third_party/open_clip/__pycache__/openai.cpython-38.pyc differ diff --git a/third_party/open_clip/__pycache__/openai.cpython-39.pyc b/third_party/open_clip/__pycache__/openai.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84d4b1e2ad5ee6e60f60b6f3e16f919b0f569c94 Binary files /dev/null and b/third_party/open_clip/__pycache__/openai.cpython-39.pyc differ diff --git a/third_party/open_clip/__pycache__/pretrained.cpython-310.pyc b/third_party/open_clip/__pycache__/pretrained.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..768ac72ae1b8c1b4466cbf32e9ff4c8e224fe743 Binary files /dev/null and b/third_party/open_clip/__pycache__/pretrained.cpython-310.pyc differ diff --git a/third_party/open_clip/__pycache__/pretrained.cpython-38.pyc b/third_party/open_clip/__pycache__/pretrained.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab33706d23c81f5d7005f9e0967865bde33654ed Binary files /dev/null and b/third_party/open_clip/__pycache__/pretrained.cpython-38.pyc differ diff --git a/third_party/open_clip/__pycache__/pretrained.cpython-39.pyc b/third_party/open_clip/__pycache__/pretrained.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b24cc462921dbdcbd17f93d21fc5414dbe33a6fd Binary files /dev/null and b/third_party/open_clip/__pycache__/pretrained.cpython-39.pyc differ diff --git a/third_party/open_clip/__pycache__/push_to_hf_hub.cpython-310.pyc b/third_party/open_clip/__pycache__/push_to_hf_hub.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a161c86fe8b79f0db27def277017c08922283c44 Binary files /dev/null and b/third_party/open_clip/__pycache__/push_to_hf_hub.cpython-310.pyc differ diff --git a/third_party/open_clip/__pycache__/timm_model.cpython-310.pyc b/third_party/open_clip/__pycache__/timm_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d85afb39359ff4f2d361dbee4a05ac89f51c910d Binary files /dev/null and b/third_party/open_clip/__pycache__/timm_model.cpython-310.pyc differ diff --git a/third_party/open_clip/__pycache__/timm_model.cpython-38.pyc b/third_party/open_clip/__pycache__/timm_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b16665175d3eb77cbe4b603460bab491a99d724 Binary files /dev/null and b/third_party/open_clip/__pycache__/timm_model.cpython-38.pyc differ diff --git a/third_party/open_clip/__pycache__/timm_model.cpython-39.pyc b/third_party/open_clip/__pycache__/timm_model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..031eb6be61bd38964fa3314edee9fc1f266c499a Binary files /dev/null and b/third_party/open_clip/__pycache__/timm_model.cpython-39.pyc differ diff --git a/third_party/open_clip/__pycache__/tokenizer.cpython-310.pyc b/third_party/open_clip/__pycache__/tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b314c12fce3b966919ceebdcbce7914bd2cbca5a Binary files /dev/null and b/third_party/open_clip/__pycache__/tokenizer.cpython-310.pyc differ diff --git a/third_party/open_clip/__pycache__/tokenizer.cpython-38.pyc b/third_party/open_clip/__pycache__/tokenizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c384d6d535039c531610904f6c09f30fa1d56d63 Binary files /dev/null and b/third_party/open_clip/__pycache__/tokenizer.cpython-38.pyc differ diff --git a/third_party/open_clip/__pycache__/tokenizer.cpython-39.pyc b/third_party/open_clip/__pycache__/tokenizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..755f97405da4a37e5b52d51b578426dca130ac63 Binary files /dev/null and b/third_party/open_clip/__pycache__/tokenizer.cpython-39.pyc differ diff --git a/third_party/open_clip/__pycache__/transform.cpython-310.pyc b/third_party/open_clip/__pycache__/transform.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b17b5ba224505df1407392870d5724e6d42c1ba Binary files /dev/null and b/third_party/open_clip/__pycache__/transform.cpython-310.pyc differ diff --git a/third_party/open_clip/__pycache__/transform.cpython-38.pyc b/third_party/open_clip/__pycache__/transform.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e38df8a6fb7c5fc7c2ce91f59c0fc9150584ab5a Binary files /dev/null and b/third_party/open_clip/__pycache__/transform.cpython-38.pyc differ diff --git a/third_party/open_clip/__pycache__/transform.cpython-39.pyc b/third_party/open_clip/__pycache__/transform.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dceb09e7e62bd1d883de71453880baf8aa7e4845 Binary files /dev/null and b/third_party/open_clip/__pycache__/transform.cpython-39.pyc differ diff --git a/third_party/open_clip/__pycache__/transformer.cpython-310.pyc b/third_party/open_clip/__pycache__/transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..804876f5136f70b160d488e302c65f855e63f083 Binary files /dev/null and b/third_party/open_clip/__pycache__/transformer.cpython-310.pyc differ diff --git a/third_party/open_clip/__pycache__/transformer.cpython-38.pyc b/third_party/open_clip/__pycache__/transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a03c3e2cac1fe75d71d2468be1bb158ce887ffa8 Binary files /dev/null and b/third_party/open_clip/__pycache__/transformer.cpython-38.pyc differ diff --git a/third_party/open_clip/__pycache__/transformer.cpython-39.pyc b/third_party/open_clip/__pycache__/transformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3c13b42043a4d5ef60ee275e68a61c0e18ad632 Binary files /dev/null and b/third_party/open_clip/__pycache__/transformer.cpython-39.pyc differ diff --git a/third_party/open_clip/__pycache__/utils.cpython-310.pyc b/third_party/open_clip/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e387ad397e7c7c34d86bc16c7d7aa3ac6197583c Binary files /dev/null and b/third_party/open_clip/__pycache__/utils.cpython-310.pyc differ diff --git a/third_party/open_clip/__pycache__/utils.cpython-38.pyc b/third_party/open_clip/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4670b2339b32e5cd98596e550823b8b95ac42021 Binary files /dev/null and b/third_party/open_clip/__pycache__/utils.cpython-38.pyc differ diff --git a/third_party/open_clip/__pycache__/utils.cpython-39.pyc b/third_party/open_clip/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1606dcb91cdf1d2b7e9108e7afeb7ea966caaa58 Binary files /dev/null and b/third_party/open_clip/__pycache__/utils.cpython-39.pyc differ diff --git a/third_party/open_clip/__pycache__/version.cpython-310.pyc b/third_party/open_clip/__pycache__/version.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f81e4b2a61668ab63f6c27e881a959fef92ac232 Binary files /dev/null and b/third_party/open_clip/__pycache__/version.cpython-310.pyc differ diff --git a/third_party/open_clip/__pycache__/version.cpython-38.pyc b/third_party/open_clip/__pycache__/version.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86e013f1efefcd1b61993cecc2e0b59560ea848a Binary files /dev/null and b/third_party/open_clip/__pycache__/version.cpython-38.pyc differ diff --git a/third_party/open_clip/__pycache__/version.cpython-39.pyc b/third_party/open_clip/__pycache__/version.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..537e94be8e754d72e22bcfb8870682acc4f6b70c Binary files /dev/null and b/third_party/open_clip/__pycache__/version.cpython-39.pyc differ diff --git a/third_party/open_clip/bpe_simple_vocab_16e6.txt.gz b/third_party/open_clip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/third_party/open_clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/third_party/open_clip/coca_model.py b/third_party/open_clip/coca_model.py new file mode 100644 index 0000000000000000000000000000000000000000..039453af70d1c865dd7cc6016f732aff2f7dc3d2 --- /dev/null +++ b/third_party/open_clip/coca_model.py @@ -0,0 +1,458 @@ +from typing import Optional + +import torch +from torch import nn +from torch.nn import functional as F +import numpy as np +from dataclasses import dataclass + +from .transformer import ( + LayerNormFp32, + LayerNorm, + QuickGELU, + MultimodalTransformer, +) +from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower + +try: + from transformers import ( + BeamSearchScorer, + LogitsProcessorList, + TopPLogitsWarper, + TopKLogitsWarper, + RepetitionPenaltyLogitsProcessor, + MinLengthLogitsProcessor, + MaxLengthCriteria, + StoppingCriteriaList + ) + + GENERATION_TYPES = { + "top_k": TopKLogitsWarper, + "top_p": TopPLogitsWarper, + "beam_search": "beam_search" + } + _has_transformers = True +except ImportError as e: + GENERATION_TYPES = { + "top_k": None, + "top_p": None, + "beam_search": "beam_search" + } + _has_transformers = False + + +@dataclass +class MultimodalCfg(CLIPTextCfg): + mlp_ratio: int = 4 + dim_head: int = 64 + heads: int = 8 + n_queries: int = 256 + attn_pooler_heads: int = 8 + + +def _build_text_decoder_tower( + embed_dim, + multimodal_cfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = ( + LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + ) + + decoder = MultimodalTransformer( + context_length=multimodal_cfg.context_length, + width=multimodal_cfg.width, + heads=multimodal_cfg.heads, + layers=multimodal_cfg.layers, + ls_init_value=multimodal_cfg.ls_init_value, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return decoder + + +class CoCa(nn.Module): + def __init__( + self, + embed_dim, + multimodal_cfg: MultimodalCfg, + text_cfg: CLIPTextCfg, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + pad_id: int = 0, + ): + super().__init__() + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg + vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg + + self.text = _build_text_tower( + embed_dim=embed_dim, + text_cfg=text_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + vocab_size = ( + text_cfg.vocab_size # for hf models + if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None + else text_cfg.vocab_size + ) + + self.visual = _build_vision_tower( + embed_dim=embed_dim, + vision_cfg=vision_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + self.text_decoder = _build_text_decoder_tower( + vocab_size, + multimodal_cfg=multimodal_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.pad_id = pad_id + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + self.text_decoder.set_grad_checkpointing(enable) + + def _encode_image(self, images, normalize=True): + image_latent, tokens_embs = self.visual(images) + image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent + return image_latent, tokens_embs + + def _encode_text(self, text, normalize=True, embed_cls=True): + text = text[:, :-1] if embed_cls else text # make space for CLS token + text_latent, token_emb = self.text(text) + text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent + return text_latent, token_emb + + def encode_image(self, images, normalize=True): + image_latent, _ = self._encode_image(images, normalize=normalize) + return image_latent + + def encode_text(self, text, normalize=True, embed_cls=True): + text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls) + return text_latent + + def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None): + text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls) + if image_latent is None or image_embs is None: + image_latent, image_embs = self._encode_image(image) + + # TODO: add assertion to avoid bugs? + labels = text[:, -token_embs.shape[1]:] + + logits = self.text_decoder(image_embs, token_embs) + return { + "image_features": image_latent, + "text_features": text_latent, + "logits": logits, + "labels": labels, + "logit_scale": self.logit_scale.exp() + } + + def generate( + self, + image, + text=None, + seq_len=30, + max_seq_len=77, + temperature=1., + generation_type="beam_search", + top_p=0.1, # keep tokens in the 1 - top_p quantile + top_k=1, # keeps the top_k most probable tokens + pad_token_id=None, + eos_token_id=None, + sot_token_id=None, + num_beams=6, + num_beam_groups=3, + min_seq_len=5, + stopping_criteria=None, + repetition_penalty=1.0, + fixed_output_length=False # if True output.shape == (batch_size, seq_len) + ): + # taking many ideas and components from HuggingFace GenerationMixin + # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation + assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`." + assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len" + + with torch.no_grad(): + sot_token_id = 49406 if sot_token_id is None else sot_token_id + eos_token_id = 49407 if eos_token_id is None else eos_token_id + pad_token_id = self.pad_id if pad_token_id is None else pad_token_id + logit_processor = LogitsProcessorList( + [ + MinLengthLogitsProcessor(min_seq_len, eos_token_id), + RepetitionPenaltyLogitsProcessor(repetition_penalty), + ] + ) + + if stopping_criteria is None: + stopping_criteria = [MaxLengthCriteria(max_length=seq_len)] + + stopping_criteria = StoppingCriteriaList( + stopping_criteria + ) + + device = image.device + + if generation_type == "beam_search": + output = self._generate_beamsearch( + image_inputs = image, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + sot_token_id=sot_token_id, + num_beams=num_beams, + num_beam_groups=num_beam_groups, + min_seq_len=min_seq_len, + stopping_criteria=stopping_criteria, + logit_processor=logit_processor, + ) + if fixed_output_length and output.shape[1] < seq_len: + return torch.cat( + (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id), + dim=1 + ) + return output + + elif generation_type == "top_p": + logit_warper = GENERATION_TYPES[generation_type](top_p) + elif generation_type == "top_k": + logit_warper = GENERATION_TYPES[generation_type](top_k) + else: + raise ValueError( + f"generation_type has to be one of " + f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}." + ) + + image_latent, image_embs = self._encode_image(image) + + if text is None: + text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id + + was_training = self.training + num_dims = len(text.shape) + + if num_dims == 1: + text = text[None, :] + + cur_len = text.shape[1] + self.eval() + out = text + + while True: + x = out[:, -max_seq_len:] + cur_len = x.shape[1] + logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1] + mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id) + sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id + + if mask.all(): + if not fixed_output_length: + break + else: + logits = logits[~mask, :] + filtered_logits = logit_processor(x[~mask, :], logits) + filtered_logits = logit_warper(x[~mask, :], filtered_logits) + probs = F.softmax(filtered_logits / temperature, dim=-1) + + if (cur_len + 1 == seq_len): + sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id + else: + sample[~mask, :] = torch.multinomial(probs, 1) + + out = torch.cat((out, sample), dim=-1) + + cur_len += 1 + + if stopping_criteria(out, None): + break + + if num_dims == 1: + out = out.squeeze(0) + + self.train(was_training) + return out + + def _generate_beamsearch( + self, + image_inputs, + pad_token_id=None, + eos_token_id=None, + sot_token_id=None, + num_beams=6, + num_beam_groups=3, + min_seq_len=5, + stopping_criteria=None, + logit_processor=None, + logit_warper=None, + ): + device = image_inputs.device + batch_size = image_inputs.shape[0] + image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0) + image_latent, image_embs = self._encode_image(image_inputs) + + input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long) + input_ids = input_ids * sot_token_id + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=num_beams, + device=device, + num_beam_groups=num_beam_groups, + ) + # instantiate logits processors + logits_processor = ( + LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)]) + if logit_processor is None + else logit_processor + ) + + batch_size = len(beam_scorer._beam_hyps) + num_beams = beam_scorer.num_beams + num_beam_groups = beam_scorer.num_beam_groups + num_sub_beams = num_beams // num_beam_groups + batch_beam_size, cur_len = input_ids.shape + beam_indices = None + + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) + + beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) + # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in + # the same group don't produce same tokens everytime. + beam_scores[:, ::num_sub_beams] = 0 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + while True: + + # predicted tokens in cur_len step + current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) + + # indices which will form the beams in the next time step + reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) + + # do one decoder step on all beams of all sentences in batch + model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs) + outputs = self( + model_inputs['images'], + model_inputs['text'], + embed_cls=False, + image_latent=image_latent, + image_embs=image_embs + ) + + for beam_group_idx in range(num_beam_groups): + group_start_idx = beam_group_idx * num_sub_beams + group_end_idx = min(group_start_idx + num_sub_beams, num_beams) + group_size = group_end_idx - group_start_idx + + # indices of beams of current group among all sentences in batch + batch_group_indices = [] + + for batch_idx in range(batch_size): + batch_group_indices.extend( + [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] + ) + group_input_ids = input_ids[batch_group_indices] + + # select outputs of beams of currentg group only + next_token_logits = outputs['logits'][batch_group_indices, -1, :] + vocab_size = next_token_logits.shape[-1] + + next_token_scores_processed = logits_processor( + group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx + ) + next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) + next_token_scores = next_token_scores.expand_as(next_token_scores_processed) + + # reshape for beam search + next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) + + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True + ) + + next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") + next_tokens = next_tokens % vocab_size + + # stateless + process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + beam_outputs = beam_scorer.process( + group_input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + beam_indices=process_beam_indices, + ) + beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids[batch_group_indices] = group_input_ids[beam_idx] + group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + current_tokens[batch_group_indices] = group_input_ids[:, -1] + + # (beam_idx // group_size) -> batch_idx + # (beam_idx % group_size) -> offset of idx inside the group + reordering_indices[batch_group_indices] = ( + num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size) + ) + + input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) + + # increase cur_len + cur_len = cur_len + 1 + if beam_scorer.is_done or stopping_criteria(input_ids, None): + break + + final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + beam_indices=final_beam_indices, + ) + return sequence_outputs['sequences'] + + +def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs): + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + else: + position_ids = None + return { + "text": input_ids, + "images": image_inputs, + "past_key_values": past, + "position_ids": position_ids, + "attention_mask": attention_mask, + } diff --git a/third_party/open_clip/constants.py b/third_party/open_clip/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..a670bb3fab442baeb9af53b91c312e6982af57ee --- /dev/null +++ b/third_party/open_clip/constants.py @@ -0,0 +1,2 @@ +OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) diff --git a/third_party/open_clip/factory.py b/third_party/open_clip/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..00f0bb440de3af56a1a0bf3a6537f832780a91bb --- /dev/null +++ b/third_party/open_clip/factory.py @@ -0,0 +1,433 @@ +import json +import logging +import os +import pathlib +import re +from copy import deepcopy +from pathlib import Path +from turtle import forward +from typing import Any, Dict, Optional, Tuple, Union + +import torch + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ + resize_pos_embed, get_cast_dtype +from .coca_model import CoCa +from .loss import ClipLoss, DistillClipLoss, CoCaLoss +from .openai import load_openai_model +from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf +from .transform import image_transform, AugmentationCfg +from .tokenizer import HFTokenizer, tokenize + + +HF_HUB_PREFIX = 'hf-hub:' +_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] +_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def _rescan_model_configs(): + global _MODEL_CONFIGS + + config_ext = ('.json',) + config_files = [] + for config_path in _MODEL_CONFIG_PATHS: + if config_path.is_file() and config_path.suffix in config_ext: + config_files.append(config_path) + elif config_path.is_dir(): + for ext in config_ext: + config_files.extend(config_path.glob(f'*{ext}')) + + for cf in config_files: + with open(cf, 'r') as f: + model_cfg = json.load(f) + if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): + _MODEL_CONFIGS[cf.stem] = model_cfg + + _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} + + +_rescan_model_configs() # initial populate of model config registry + + +def list_models(): + """ enumerate available model architectures based on config files """ + return list(_MODEL_CONFIGS.keys()) + + +def add_model_config(path): + """ add model config path or file and update registry """ + if not isinstance(path, Path): + path = Path(path) + _MODEL_CONFIG_PATHS.append(path) + _rescan_model_configs() + + +def get_model_config(model_name): + if model_name in _MODEL_CONFIGS: + return deepcopy(_MODEL_CONFIGS[model_name]) + else: + return None + + +def get_tokenizer(model_name): + if model_name.startswith(HF_HUB_PREFIX): + tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):]) + else: + config = get_model_config(model_name) + tokenizer = HFTokenizer( + config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize + return tokenizer + + +def load_state_dict(checkpoint_path: str, map_location='cpu'): + checkpoint = torch.load(checkpoint_path, map_location=map_location) + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + if next(iter(state_dict.items()))[0].startswith('module'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + return state_dict + + +def load_checkpoint(model, checkpoint_path, strict=True): + state_dict = load_state_dict(checkpoint_path) + # detect old format and make compatible with new format + if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): + state_dict = convert_to_custom_text_state_dict(state_dict) + resize_pos_embed(state_dict, model) + incompatible_keys = model.load_state_dict(state_dict, strict=strict) + return incompatible_keys + + +def create_model( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_patch_dropout: Optional[float] = None, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, + require_pretrained: bool = False, +): + has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) + if has_hf_hub_prefix: + model_id = model_name[len(HF_HUB_PREFIX):] + checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir) + + with open(config_path, 'r', encoding='utf-8') as f: + config = json.load(f) + pretrained_cfg = config['preprocess_cfg'] + model_cfg = config['model_cfg'] + else: + model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names + checkpoint_path = None + pretrained_cfg = {} + model_cfg = None + + if isinstance(device, str): + device = torch.device(device) + + if pretrained and pretrained.lower() == 'openai': + logging.info(f'Loading pretrained {model_name} from OpenAI.') + model = load_openai_model( + model_name, + precision=precision, + device=device, + jit=jit, + cache_dir=cache_dir, + ) + + # to always output dict even if it is clip + if output_dict and hasattr(model, "output_dict"): + model.output_dict = True + else: + model_cfg = model_cfg or get_model_config(model_name) + if model_cfg is not None: + logging.info(f'Loaded {model_name} model config.') + else: + logging.error(f'Model config for {model_name} not found; available models {list_models()}.') + raise RuntimeError(f'Model config for {model_name} not found.') + + if force_quick_gelu: + # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + if force_patch_dropout is not None: + # override the default patch dropout value + model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout + + if force_image_size is not None: + # override model config's image size + model_cfg["vision_cfg"]["image_size"] = force_image_size + + if pretrained_image: + if 'timm_model_name' in model_cfg.get('vision_cfg', {}): + # pretrained weight loading for timm models set via vision_cfg + model_cfg['vision_cfg']['timm_model_pretrained'] = True + else: + assert False, 'pretrained image towers currently only supported for timm models' + + cast_dtype = get_cast_dtype(precision) + is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) + custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model + + if custom_text: + if is_hf_model: + model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf + if "coca" in model_name: + model = CoCa(**model_cfg, cast_dtype=cast_dtype) + else: + model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) + else: + model = CLIP(**model_cfg, cast_dtype=cast_dtype) + + pretrained_loaded = False + if pretrained: + checkpoint_path = '' + pretrained_cfg = get_pretrained_cfg(model_name, pretrained) + if pretrained_cfg: + checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) + elif os.path.exists(pretrained): + checkpoint_path = pretrained + + if checkpoint_path: + logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') + load_checkpoint(model, checkpoint_path) + else: + error_str = ( + f'Pretrained weights ({pretrained}) not found for model {model_name}.' + f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') + logging.warning(error_str) + raise RuntimeError(error_str) + pretrained_loaded = True + elif has_hf_hub_prefix: + logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') + load_checkpoint(model, checkpoint_path) + pretrained_loaded = True + + if require_pretrained and not pretrained_loaded: + # callers of create_model_from_pretrained always expect pretrained weights + raise RuntimeError( + f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.') + + model.to(device=device) + if precision in ("fp16", "bf16"): + convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16) + + # set image / mean metadata from pretrained_cfg if available, or use default + model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN + model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD + + # to always output dict even if it is clip + if output_dict and hasattr(model, "output_dict"): + model.output_dict = True + + if jit: + model = torch.jit.script(model) + + return model + + +def create_loss(args): + if args.distill: + return DistillClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + elif "coca" in args.model.lower(): + return CoCaLoss( + caption_loss_weight=args.coca_caption_loss_weight, + clip_loss_weight=args.coca_contrastive_loss_weight, + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + return ClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + ) + +class MLP(torch.nn.Module): + def __init__(self, input_size): + super().__init__() + self.input_size = input_size + self.layers = torch.nn.Sequential( + torch.nn.Linear(self.input_size, 1024), + torch.nn.Dropout(0.2), + torch.nn.Linear(1024, 128), + torch.nn.Dropout(0.2), + torch.nn.Linear(128, 64), + torch.nn.Dropout(0.1), + torch.nn.Linear(64, 16), + torch.nn.Linear(16, 1) + ) + + def forward(self, x): + return self.layers(x) + +# class semantic_head(torch.nn.Module): +# def __init__(self, input_size): +# super().__init__() +# self.input_size = input_size # for ViT-L-14 is 1024 +# self.seg_head = torch.nn.Sequential( +# torch.nn.Linear(input_size, 128), +# torch.nn.Dropout(0.2), +# torch.nn.Linear(128, 64), +# torch.nn.Dropout(0.1), +# torch.nn.Linear(64, 16), +# torch.nn.Linear(16, 1), +# ) +# self.sigmoid = torch.nn.Sigmoid() + +# def forward(self, x): +# return self.sigmoid(self.seg_head(x)) + +def create_model_and_transforms( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_patch_dropout: Optional[float] = None, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, + cache_dir: Optional[str] = None, + light_augmentation = False, + output_dict: Optional[bool] = None, + with_score_predictor: bool = False, + with_region_predictor: bool = False +): + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_text=force_custom_text, + force_patch_dropout=force_patch_dropout, + force_image_size=force_image_size, + pretrained_image=pretrained_image, + pretrained_hf=pretrained_hf, + cache_dir=cache_dir, + output_dict=output_dict, + ) + + image_mean = image_mean or getattr(model.visual, 'image_mean', None) + image_std = image_std or getattr(model.visual, 'image_std', None) + + if with_score_predictor: + model.score_predictor = MLP(model.visual.proj.size(1)).to(device=device, dtype=model.visual.proj.dtype) + + if with_region_predictor: + # model.region_predictor = semantic_head(model.visual.proj.size(1)).to(device=device, dtype=model.visual.proj.dtype) + model.region_predictor = torch.nn.Linear(model.visual.proj.size(0), 1).to(device=device, dtype=model.visual.proj.dtype) + # preprocess_train = image_transform_region( + # model.visual.image_size, + # is_train=True, + # mean=image_mean, + # std=image_std + # ) + # preprocess_val = image_transform_region( + # model.visual.image_size, + # is_train=False, + # mean=image_mean, + # std=image_std + # ) + + if light_augmentation: + preprocess_val = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std, + resize_longest_max=True, + ) + preprocess_train = preprocess_val + else: + preprocess_train = image_transform( + model.visual.image_size, + is_train=True, + mean=image_mean, + std=image_std + ) + preprocess_val = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std + ) + + return model, preprocess_train, preprocess_val + + +def create_model_from_pretrained( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + return_transform: bool = True, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + cache_dir: Optional[str] = None, +): + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_text=force_custom_text, + force_image_size=force_image_size, + cache_dir=cache_dir, + require_pretrained=True, + ) + + if not return_transform: + return model + + image_mean = image_mean or getattr(model.visual, 'image_mean', None) + image_std = image_std or getattr(model.visual, 'image_std', None) + preprocess = image_transform( + model.visual.image_size, + is_train=False, + mean=image_mean, + std=image_std, + ) + + return model, preprocess diff --git a/third_party/open_clip/generation_utils.py b/third_party/open_clip/generation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/open_clip/hf_configs.py b/third_party/open_clip/hf_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..e236222bafce0358445ea16953ca0b2d5a84758a --- /dev/null +++ b/third_party/open_clip/hf_configs.py @@ -0,0 +1,45 @@ +# HF architecture dict: +arch_dict = { + # https://huggingface.co/docs/transformers/model_doc/roberta#roberta + "roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig + "xlm-roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 + "mt5": { + "config_names": { + # unlimited seqlen + # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 + # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 + "context_length": "", + "vocab_size": "vocab_size", + "width": "d_model", + "heads": "num_heads", + "layers": "num_layers", + "layer_attr": "block", + "token_embeddings_attr": "embed_tokens" + }, + "pooler": "mean_pooler", + }, +} diff --git a/third_party/open_clip/hf_model.py b/third_party/open_clip/hf_model.py new file mode 100644 index 0000000000000000000000000000000000000000..fbccc812757bf10b122ff14096980e0e38d1d221 --- /dev/null +++ b/third_party/open_clip/hf_model.py @@ -0,0 +1,176 @@ +""" huggingface model adapter + +Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. +""" + +import re + +import torch +import torch.nn as nn +from torch import TensorType + +try: + import transformers + from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig + from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ + BaseModelOutputWithPoolingAndCrossAttentions +except ImportError as e: + transformers = None + + + class BaseModelOutput: + pass + + + class PretrainedConfig: + pass + +from .hf_configs import arch_dict + + +# utils +def _camel2snake(s): + return re.sub(r'(? torch.Tensor: + # calculated ground-truth and cache if enabled + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + return labels + + def get_logits(self, image_features, text_features, logit_scale): + if self.world_size > 1: + all_image_features, all_text_features = gather_features( + image_features, text_features, + self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) + + if self.local_loss: + logits_per_image = logit_scale * image_features @ all_text_features.T + logits_per_text = logit_scale * text_features @ all_image_features.T + else: + logits_per_image = logit_scale * all_image_features @ all_text_features.T + logits_per_text = logits_per_image.T + else: + logits_per_image = logit_scale * image_features @ text_features.T + logits_per_text = logit_scale * text_features @ image_features.T + + return logits_per_image, logits_per_text + + def forward(self, image_features, text_features, logit_scale, output_dict=False): + device = image_features.device + logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale) + + labels = self.get_ground_truth(device, logits_per_image.shape[0]) + + total_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + return total_loss + +class PreferenceLoss(nn.Module): + + def forward(self, logits_per_image, num_images, labels): + + paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))] + paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-999) + + ce_loss = F.cross_entropy(paired_logits, labels) + return ce_loss + +class HPSLoss(nn.Module): + + def forward(self, text_logits, labels): + + device = text_logits.device + text_0_logits, text_1_logits = text_logits.chunk(2, dim=-1) + label_0, label_1 = labels.chunk(2, dim=-1) + + index = torch.arange(text_0_logits.shape[0], device=device, dtype=torch.long) + text_0_logits = text_0_logits[index, index] + text_1_logits = text_1_logits[index, index] + text_logits = torch.stack([text_0_logits, text_1_logits], dim=-1) + text_0_labels = torch.zeros(text_logits.shape[0], device=device, dtype=torch.long) + text_1_labels = text_0_labels + 1 + + text_0_loss = torch.nn.functional.cross_entropy(text_logits, text_0_labels, reduction="none") + text_1_loss = torch.nn.functional.cross_entropy(text_logits, text_1_labels, reduction="none") + + text_loss = label_0 * text_0_loss + label_1 * text_1_loss + + # absolute_example_weight = 1 / num_per_prompt + # denominator = absolute_example_weight.sum() + # weight_per_example = absolute_example_weight / denominator + # text_loss *= weight_per_example + + text_loss = text_loss.sum() + return text_loss + +class RankingLoss(nn.Module): + + def forward(self, logits_per_image, num_images, labels, margin = 1.0): + paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))] + label_list = [label for label in labels.split(num_images.tolist())] + # ranked_logits = [torch.index_select(paired_logits_list[i], 0, rank) for i, rank in enumerate(label_list)] + + paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-1) + padded_labels = pad_sequence(label_list, batch_first=True, padding_value=10) + + # regulized_logits = torch.log(torch.sigmoid(paired_logits)) + + diff = paired_logits.unsqueeze(1) - paired_logits.unsqueeze(2) + # diff = paired_logits.unsqueeze(1) - paired_logits.unsqueeze(2) + # diff_label = torch.clamp(padded_labels.unsqueeze(1) - padded_labels.unsqueeze(2), min=-1, max=1) + diff_label = - (padded_labels.unsqueeze(1) - padded_labels.unsqueeze(2)) + mask = torch.triu(torch.ones(diff.shape[1], diff.shape[1]), diagonal=1).bool().detach() + + loss = torch.clamp(margin - torch.mul(diff[:, ~mask],diff_label[:,~mask]), min=0).mean() + return loss + +class CoCaLoss(ClipLoss): + def __init__( + self, + caption_loss_weight, + clip_loss_weight, + pad_id=0, # pad_token for open_clip custom tokenizer + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + ): + super().__init__( + local_loss=local_loss, + gather_with_grad=gather_with_grad, + cache_labels=cache_labels, + rank=rank, + world_size=world_size, + use_horovod=use_horovod + ) + + self.clip_loss_weight = clip_loss_weight + self.caption_loss_weight = caption_loss_weight + self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) + + def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False): + clip_loss = super().forward(image_features, text_features, logit_scale) + clip_loss = self.clip_loss_weight * clip_loss + + caption_loss = self.caption_loss( + logits.permute(0, 2, 1), + labels, + ) + caption_loss = caption_loss * self.caption_loss_weight + + if output_dict: + return {"contrastive_loss": clip_loss, "caption_loss": caption_loss} + + return clip_loss, caption_loss + + +class DistillClipLoss(ClipLoss): + + def dist_loss(self, teacher_logits, student_logits): + return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0) + + def forward( + self, + image_features, + text_features, + logit_scale, + dist_image_features, + dist_text_features, + dist_logit_scale, + output_dict=False, + ): + logits_per_image, logits_per_text = \ + self.get_logits(image_features, text_features, logit_scale) + + dist_logits_per_image, dist_logits_per_text = \ + self.get_logits(dist_image_features, dist_text_features, dist_logit_scale) + + labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0]) + + contrastive_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + distill_loss = ( + self.dist_loss(dist_logits_per_image, logits_per_image) + + self.dist_loss(dist_logits_per_text, logits_per_text) + ) / 2 + + if output_dict: + return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss} + + return contrastive_loss, distill_loss diff --git a/third_party/open_clip/model.py b/third_party/open_clip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e347c42fc8df6464ca28e59adadba61e53a38add --- /dev/null +++ b/third_party/open_clip/model.py @@ -0,0 +1,461 @@ +""" CLIP Model + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +from dataclasses import dataclass +import logging +import math +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.checkpoint import checkpoint + +from .hf_model import HFTextEncoder +from .modified_resnet import ModifiedResNet +from .timm_model import TimmModel +from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer +from .utils import to_2tuple + + +@dataclass +class CLIPVisionCfg: + layers: Union[Tuple[int, int, int, int], int] = 12 + width: int = 768 + head_width: int = 64 + mlp_ratio: float = 4.0 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + ls_init_value: Optional[float] = None # layer scale initial value + patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results + input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design + global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) + attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer + n_queries: int = 256 # n_queries for attentional pooler + attn_pooler_heads: int = 8 # n heads for attentional_pooling + timm_model_name: str = None # a valid model name overrides layers, width, patch_size + timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model + timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') + timm_proj_bias: bool = False # enable bias final projection + timm_drop: float = 0. # head dropout + timm_drop_path: Optional[float] = None # backbone stochastic depth + output_tokens: bool = False + + +@dataclass +class CLIPTextCfg: + context_length: int = 77 + vocab_size: int = 49408 + width: int = 512 + heads: int = 8 + layers: int = 12 + ls_init_value: Optional[float] = None # layer scale initial value + hf_model_name: str = None + hf_tokenizer_name: str = None + hf_model_pretrained: bool = True + proj: str = 'mlp' + pooler_type: str = 'mean_pooler' + embed_cls: bool = False + pad_id: int = 0 + output_tokens: bool = False + + +def get_cast_dtype(precision: str): + cast_dtype = None + if precision == 'bf16': + cast_dtype = torch.bfloat16 + elif precision == 'fp16': + cast_dtype = torch.float16 + return cast_dtype + + +def _build_vision_tower( + embed_dim: int, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None +): + if isinstance(vision_cfg, dict): + vision_cfg = CLIPVisionCfg(**vision_cfg) + + # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. + act_layer = QuickGELU if quick_gelu else nn.GELU + + if vision_cfg.timm_model_name: + visual = TimmModel( + vision_cfg.timm_model_name, + pretrained=vision_cfg.timm_model_pretrained, + pool=vision_cfg.timm_pool, + proj=vision_cfg.timm_proj, + proj_bias=vision_cfg.timm_proj_bias, + drop=vision_cfg.timm_drop, + drop_path=vision_cfg.timm_drop_path, + embed_dim=embed_dim, + image_size=vision_cfg.image_size, + ) + act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models + elif isinstance(vision_cfg.layers, (tuple, list)): + vision_heads = vision_cfg.width * 32 // vision_cfg.head_width + visual = ModifiedResNet( + layers=vision_cfg.layers, + output_dim=embed_dim, + heads=vision_heads, + image_size=vision_cfg.image_size, + width=vision_cfg.width, + ) + else: + vision_heads = vision_cfg.width // vision_cfg.head_width + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + visual = VisionTransformer( + image_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + width=vision_cfg.width, + layers=vision_cfg.layers, + heads=vision_heads, + mlp_ratio=vision_cfg.mlp_ratio, + ls_init_value=vision_cfg.ls_init_value, + patch_dropout=vision_cfg.patch_dropout, + input_patchnorm=vision_cfg.input_patchnorm, + global_average_pool=vision_cfg.global_average_pool, + attentional_pool=vision_cfg.attentional_pool, + n_queries=vision_cfg.n_queries, + attn_pooler_heads=vision_cfg.attn_pooler_heads, + output_tokens=vision_cfg.output_tokens, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return visual + + +def _build_text_tower( + embed_dim: int, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + if isinstance(text_cfg, dict): + text_cfg = CLIPTextCfg(**text_cfg) + + if text_cfg.hf_model_name: + text = HFTextEncoder( + text_cfg.hf_model_name, + output_dim=embed_dim, + proj=text_cfg.proj, + pooler_type=text_cfg.pooler_type, + pretrained=text_cfg.hf_model_pretrained, + output_tokens=text_cfg.output_tokens, + ) + else: + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + + text = TextTransformer( + context_length=text_cfg.context_length, + vocab_size=text_cfg.vocab_size, + width=text_cfg.width, + heads=text_cfg.heads, + layers=text_cfg.layers, + ls_init_value=text_cfg.ls_init_value, + output_dim=embed_dim, + embed_cls=text_cfg.embed_cls, + output_tokens=text_cfg.output_tokens, + pad_id=text_cfg.pad_id, + act_layer=act_layer, + norm_layer=norm_layer, + ) + return text + + +class CLIP(nn.Module): + output_dict: torch.jit.Final[bool] + + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, + ): + super().__init__() + self.output_dict = output_dict + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + + text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.transformer = text.transformer + self.vocab_size = text.vocab_size + self.token_embedding = text.token_embedding + self.positional_embedding = text.positional_embedding + self.ln_final = text.ln_final + self.text_projection = text.text_projection + self.register_buffer('attn_mask', text.attn_mask, persistent=False) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): + locked_layers = [] + locked_layers.append(self.token_embedding) + self.positional_embedding.requires_grad = False + if unlocked_layers > 0: + locked_layers.append(self.transformer.resblocks[:-unlocked_layers]) + else: + locked_layers.append(self.transformer) + locked_layers.append(self.ln_final) + self.text_projection.requires_grad = False + + # freeze layers + for module in locked_layers: + for n, p in module.named_parameters(): + p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.transformer.grad_checkpointing = enable + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + cast_dtype = self.transformer.get_cast_dtype() + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + return F.normalize(x, dim=-1) if normalize else x + + def forward(self, image, text): + image_features = self.encode_image(image, normalize=True) + text_features = self.encode_text(text, normalize=True) + if self.output_dict: + return { + "image_features": image_features, + "text_features": text_features, + "logit_scale": self.logit_scale.exp() + } + return image_features, text_features, self.logit_scale.exp() + + +class CustomTextCLIP(nn.Module): + output_dict: torch.jit.Final[bool] + + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, + ): + super().__init__() + self.output_dict = output_dict + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): + self.text.lock(unlocked_layers, freeze_layer_norm) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + features = self.text(text) + return F.normalize(features, dim=-1) if normalize else features + + def forward(self, image, text): + image_features = self.encode_image(image, normalize=True) + text_features = self.encode_text(text, normalize=True) + if self.output_dict: + return { + "image_features": image_features, + "text_features": text_features, + "logit_scale": self.logit_scale.exp() + } + return image_features, text_features, self.logit_scale.exp() + + +def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): + """Convert applicable model parameters to low-precision (bf16 or fp16)""" + + def _convert_weights(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.to(dtype) + if l.bias is not None: + l.bias.data = l.bias.data.to(dtype) + + if isinstance(l, (nn.MultiheadAttention, Attention)): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.to(dtype) + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.to(dtype) + + model.apply(_convert_weights) + + +convert_weights_to_fp16 = convert_weights_to_lp # backwards compat + + +# used to maintain checkpoint compatibility +def convert_to_custom_text_state_dict(state_dict: dict): + if 'text_projection' in state_dict: + # old format state_dict, move text tower -> .text + new_state_dict = {} + for k, v in state_dict.items(): + if any(k.startswith(p) for p in ( + 'text_projection', + 'positional_embedding', + 'token_embedding', + 'transformer', + 'ln_final', + )): + k = 'text.' + k + new_state_dict[k] = v + return new_state_dict + return state_dict + + +def build_model_from_openai_state_dict( + state_dict: dict, + quick_gelu=True, + cast_dtype=torch.float16, +): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len( + [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_size = vision_patch_size * grid_size + else: + counts: list = [ + len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_size = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + vision_cfg = CLIPVisionCfg( + layers=vision_layers, + width=vision_width, + patch_size=vision_patch_size, + image_size=image_size, + ) + text_cfg = CLIPTextCfg( + context_length=context_length, + vocab_size=vocab_size, + width=transformer_width, + heads=transformer_heads, + layers=transformer_layers, + ) + model = CLIP( + embed_dim, + vision_cfg=vision_cfg, + text_cfg=text_cfg, + quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU + cast_dtype=cast_dtype, + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + + convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 + model.load_state_dict(state_dict) + return model.eval() + + +def trace_model(model, batch_size=256, device=torch.device('cpu')): + model.eval() + image_size = model.visual.image_size + example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) + example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) + model = torch.jit.trace_module( + model, + inputs=dict( + forward=(example_images, example_text), + encode_text=(example_text,), + encode_image=(example_images,) + )) + model.visual.image_size = image_size + return model + + +def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True): + # Rescale the grid of position embeddings when loading from state_dict + old_pos_embed = state_dict.get('visual.positional_embedding', None) + if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): + return + grid_size = to_2tuple(model.visual.grid_size) + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + return + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) + + logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode=interpolation, + antialias=antialias, + align_corners=False, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + state_dict['visual.positional_embedding'] = new_pos_embed diff --git a/third_party/open_clip/model_configs/RN101-quickgelu.json b/third_party/open_clip/model_configs/RN101-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..d0db2c161d13138788c4609d373b023b8454d624 --- /dev/null +++ b/third_party/open_clip/model_configs/RN101-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/RN101.json b/third_party/open_clip/model_configs/RN101.json new file mode 100644 index 0000000000000000000000000000000000000000..b88b4d3acbaa701c614ab0ea65fc88fcfe289c32 --- /dev/null +++ b/third_party/open_clip/model_configs/RN101.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/RN50-quickgelu.json b/third_party/open_clip/model_configs/RN50-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..8c2f91260cdeb043434dc1e893cce81d4ce7f0d1 --- /dev/null +++ b/third_party/open_clip/model_configs/RN50-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 1024, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/third_party/open_clip/model_configs/RN50.json b/third_party/open_clip/model_configs/RN50.json new file mode 100644 index 0000000000000000000000000000000000000000..33aa884d54fee0076c33676831e49d5e1ffcb8f2 --- /dev/null +++ b/third_party/open_clip/model_configs/RN50.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/RN50x16.json b/third_party/open_clip/model_configs/RN50x16.json new file mode 100644 index 0000000000000000000000000000000000000000..3161e1a2c9a839161e652a4d729c2cdc971161db --- /dev/null +++ b/third_party/open_clip/model_configs/RN50x16.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 384, + "layers": [ + 6, + 8, + 18, + 8 + ], + "width": 96, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/RN50x4.json b/third_party/open_clip/model_configs/RN50x4.json new file mode 100644 index 0000000000000000000000000000000000000000..e155237f8ce1026aaaeecc80751eabe6f329f0bb --- /dev/null +++ b/third_party/open_clip/model_configs/RN50x4.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 288, + "layers": [ + 4, + 6, + 10, + 6 + ], + "width": 80, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/RN50x64.json b/third_party/open_clip/model_configs/RN50x64.json new file mode 100644 index 0000000000000000000000000000000000000000..f5aaa2ee3de21ddb03cbd12766a3419bf34898c7 --- /dev/null +++ b/third_party/open_clip/model_configs/RN50x64.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 448, + "layers": [ + 3, + 15, + 36, + 10 + ], + "width": 128, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-B-16-plus-240.json b/third_party/open_clip/model_configs/ViT-B-16-plus-240.json new file mode 100644 index 0000000000000000000000000000000000000000..5bbd12bcd01f64d6d0a0aa8316b129327a0d169a --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-B-16-plus-240.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 240, + "layers": 12, + "width": 896, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-B-16-plus.json b/third_party/open_clip/model_configs/ViT-B-16-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..5dc1e09baccef2b15055c1bffeb9903e760101c6 --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-B-16-plus.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 896, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-B-16.json b/third_party/open_clip/model_configs/ViT-B-16.json new file mode 100644 index 0000000000000000000000000000000000000000..395eea77ec3907c0611531aba63459b193e67b9c --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-B-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-B-32-plus-256.json b/third_party/open_clip/model_configs/ViT-B-32-plus-256.json new file mode 100644 index 0000000000000000000000000000000000000000..2f09c857de9a4c01ae51297a7e2451984879f9de --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-B-32-plus-256.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 256, + "layers": 12, + "width": 896, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-B-32-quickgelu.json b/third_party/open_clip/model_configs/ViT-B-32-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..ce6bd923593293ed50dfcfb28b73ca7403bcf3c5 --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-B-32-quickgelu.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-B-32.json b/third_party/open_clip/model_configs/ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..07c8e28eb06fa1813ba932fe4eec668262d1c47f --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-B-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-H-14.json b/third_party/open_clip/model_configs/ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..3e3a7e934e7f02e41f4829996c4950e05f015a74 --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-H-14.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-H-16.json b/third_party/open_clip/model_configs/ViT-H-16.json new file mode 100644 index 0000000000000000000000000000000000000000..588485455fdf8193ec16474450b94e31c91ea93c --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-H-16.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-L-14-280.json b/third_party/open_clip/model_configs/ViT-L-14-280.json new file mode 100644 index 0000000000000000000000000000000000000000..2262deaefa82792d35d73c0d7c8e620525092581 --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-L-14-280.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 280, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-L-14-336.json b/third_party/open_clip/model_configs/ViT-L-14-336.json new file mode 100644 index 0000000000000000000000000000000000000000..8d1f74c2639c3a3705df9865b9c08215675ddc97 --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-L-14-336.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 336, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-L-14.json b/third_party/open_clip/model_configs/ViT-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241 --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-L-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-L-16-320.json b/third_party/open_clip/model_configs/ViT-L-16-320.json new file mode 100644 index 0000000000000000000000000000000000000000..fc2d13ca9ec7f0b56a886ddaf66c4a7ba7a442ba --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-L-16-320.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 320, + "layers": 24, + "width": 1024, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-L-16.json b/third_party/open_clip/model_configs/ViT-L-16.json new file mode 100644 index 0000000000000000000000000000000000000000..82a1cedfa290adacbbdc02bc5d589734c22d41d3 --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-L-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-M-16-alt.json b/third_party/open_clip/model_configs/ViT-M-16-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..1a317aad8e02d9c26d2decc7cc49a18dfdf9e0d8 --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-M-16-alt.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 16, + "ls_init_value": 1e-4 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-M-16.json b/third_party/open_clip/model_configs/ViT-M-16.json new file mode 100644 index 0000000000000000000000000000000000000000..f2f3225a46e09237730a151d161f70c86b985172 --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-M-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-M-32-alt.json b/third_party/open_clip/model_configs/ViT-M-32-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..fd222aeac0f582ef6a1a33f1b3fec70a5b386ac0 --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-M-32-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-M-32.json b/third_party/open_clip/model_configs/ViT-M-32.json new file mode 100644 index 0000000000000000000000000000000000000000..4f718642821035d9776d1e006817d65ede074366 --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-M-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 512, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-S-16-alt.json b/third_party/open_clip/model_configs/ViT-S-16-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..a8c056555e4da3ba0d1475a61fc316362ecce76f --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-S-16-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 256, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 256, + "heads": 4, + "layers": 10 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-S-16.json b/third_party/open_clip/model_configs/ViT-S-16.json new file mode 100644 index 0000000000000000000000000000000000000000..1d8504e59658803f3093e5b05de45f30a09b8185 --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-S-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-S-32-alt.json b/third_party/open_clip/model_configs/ViT-S-32-alt.json new file mode 100644 index 0000000000000000000000000000000000000000..e1dfdec9824df09a2010e991ccfa1d9ee2f45807 --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-S-32-alt.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 256, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 256, + "heads": 4, + "layers": 10 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-S-32.json b/third_party/open_clip/model_configs/ViT-S-32.json new file mode 100644 index 0000000000000000000000000000000000000000..9b8b4191b268de267268cfcb90fc01c6b9df07d8 --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-S-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 384, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 384, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 384, + "heads": 6, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-bigG-14.json b/third_party/open_clip/model_configs/ViT-bigG-14.json new file mode 100644 index 0000000000000000000000000000000000000000..2cfba479a2e8f3737e71ce240732bf3bc743d8b7 --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-bigG-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 48, + "width": 1664, + "head_width": 104, + "mlp_ratio": 4.9231, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 32 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-e-14.json b/third_party/open_clip/model_configs/ViT-e-14.json new file mode 100644 index 0000000000000000000000000000000000000000..91a0fe14d25a107fb8ec48dd7faae313fd26ed7b --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-e-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1280, + "vision_cfg": { + "image_size": 224, + "layers": 56, + "width": 1792, + "head_width": 112, + "mlp_ratio": 8.5715, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1280, + "heads": 20, + "layers": 36 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/ViT-g-14.json b/third_party/open_clip/model_configs/ViT-g-14.json new file mode 100644 index 0000000000000000000000000000000000000000..8c4b7325cc75b6112be7107d36ae2cb5762d9091 --- /dev/null +++ b/third_party/open_clip/model_configs/ViT-g-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 40, + "width": 1408, + "head_width": 88, + "mlp_ratio": 4.3637, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/coca_ViT-B-32.json b/third_party/open_clip/model_configs/coca_ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..7e7eb520a6a0096e5602d509ecd6186e278f4725 --- /dev/null +++ b/third_party/open_clip/model_configs/coca_ViT-B-32.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32, + "attentional_pool": true, + "attn_pooler_heads": 8, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "embed_cls": true, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "attn_pooler_heads": 8 + }, + "custom_text": true +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/coca_ViT-L-14.json b/third_party/open_clip/model_configs/coca_ViT-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..3d5ca4ca2338540f06852df5ff35ea6277e64555 --- /dev/null +++ b/third_party/open_clip/model_configs/coca_ViT-L-14.json @@ -0,0 +1,30 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14, + "attentional_pool": true, + "attn_pooler_heads": 8, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "embed_cls": true, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "attn_pooler_heads": 12 + }, + "custom_text": true +} diff --git a/third_party/open_clip/model_configs/coca_base.json b/third_party/open_clip/model_configs/coca_base.json new file mode 100644 index 0000000000000000000000000000000000000000..cf8c6cecb78a49d7e7140145a0307cbd561077c2 --- /dev/null +++ b/third_party/open_clip/model_configs/coca_base.json @@ -0,0 +1,31 @@ +{ + "embed_dim": 512, + "multimodal_cfg": { + "width": 768, + "context_length": 76, + "vocab_size": 64000, + "mlp_ratio": 4, + "layers": 12, + "dim_head": 64, + "heads": 12, + "n_queries": 256, + "attn_pooler_heads": 8 + }, + "vision_cfg": { + "image_size": 288, + "layers": 12, + "width": 768, + "patch_size": 18, + "output_tokens": true + }, + "text_cfg": { + "context_length": 76, + "vocab_size": 64000, + "layers": 12, + "heads": 12, + "width": 768, + "embed_cls": true, + "output_tokens": true + }, + "custom_text": true +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/coca_roberta-ViT-B-32.json b/third_party/open_clip/model_configs/coca_roberta-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..fb46354b95a17a46d7fcfd9d504e917ee6c1608c --- /dev/null +++ b/third_party/open_clip/model_configs/coca_roberta-ViT-B-32.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32, + "output_tokens": true + }, + "text_cfg": { + "hf_model_name": "roberta-base", + "hf_tokenizer_name": "roberta-base", + "proj": "linear", + "width": 768, + "output_tokens": true + }, + "multimodal_cfg": { + "context_length": 76, + "width": 768, + "heads": 8, + "layers": 12 + }, + "custom_text": true +} diff --git a/third_party/open_clip/model_configs/convnext_base.json b/third_party/open_clip/model_configs/convnext_base.json new file mode 100644 index 0000000000000000000000000000000000000000..bb6dba181d950ea5081155c90d47e72c94816b80 --- /dev/null +++ b/third_party/open_clip/model_configs/convnext_base.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/convnext_base_w.json b/third_party/open_clip/model_configs/convnext_base_w.json new file mode 100644 index 0000000000000000000000000000000000000000..82ea7ae3659e5514f37ff982f0ab1141dff4bd18 --- /dev/null +++ b/third_party/open_clip/model_configs/convnext_base_w.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/convnext_base_w_320.json b/third_party/open_clip/model_configs/convnext_base_w_320.json new file mode 100644 index 0000000000000000000000000000000000000000..0a07c4e16abaa4015ecc5f82ec845de16e1f9d88 --- /dev/null +++ b/third_party/open_clip/model_configs/convnext_base_w_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "convnext_base", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/convnext_large.json b/third_party/open_clip/model_configs/convnext_large.json new file mode 100644 index 0000000000000000000000000000000000000000..c4a1fea73dbead71c218a0e74b9b15f9b252e3ef --- /dev/null +++ b/third_party/open_clip/model_configs/convnext_large.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/convnext_large_d.json b/third_party/open_clip/model_configs/convnext_large_d.json new file mode 100644 index 0000000000000000000000000000000000000000..ae8fed21b58e1a6a411daf8b792ee50f0ab42346 --- /dev/null +++ b/third_party/open_clip/model_configs/convnext_large_d.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "mlp", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 16 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/convnext_large_d_320.json b/third_party/open_clip/model_configs/convnext_large_d_320.json new file mode 100644 index 0000000000000000000000000000000000000000..54c3df36a6f56ace0b12ada24c13058de96feed8 --- /dev/null +++ b/third_party/open_clip/model_configs/convnext_large_d_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "convnext_large", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "mlp", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 16 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/convnext_small.json b/third_party/open_clip/model_configs/convnext_small.json new file mode 100644 index 0000000000000000000000000000000000000000..3592c2a5cd21aae8d2544931773cf7603f67ea28 --- /dev/null +++ b/third_party/open_clip/model_configs/convnext_small.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "convnext_small", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/convnext_tiny.json b/third_party/open_clip/model_configs/convnext_tiny.json new file mode 100644 index 0000000000000000000000000000000000000000..ad11470f5ec40ffec771096971ce58d3d5b9249b --- /dev/null +++ b/third_party/open_clip/model_configs/convnext_tiny.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_tiny", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/convnext_xlarge.json b/third_party/open_clip/model_configs/convnext_xlarge.json new file mode 100644 index 0000000000000000000000000000000000000000..2a909965932eef994177c829fefc2bdc1c219b3f --- /dev/null +++ b/third_party/open_clip/model_configs/convnext_xlarge.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 20 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/convnext_xxlarge.json b/third_party/open_clip/model_configs/convnext_xxlarge.json new file mode 100644 index 0000000000000000000000000000000000000000..23a55a681c346d1a315d8a163c1cb6ad495e6a91 --- /dev/null +++ b/third_party/open_clip/model_configs/convnext_xxlarge.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xxlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/convnext_xxlarge_320.json b/third_party/open_clip/model_configs/convnext_xxlarge_320.json new file mode 100644 index 0000000000000000000000000000000000000000..ac5134ca12cbaa97772cde059270d345386a74c7 --- /dev/null +++ b/third_party/open_clip/model_configs/convnext_xxlarge_320.json @@ -0,0 +1,19 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "convnext_xxlarge", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "timm_drop": 0.0, + "timm_drop_path": 0.1, + "image_size": 320 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/mt5-base-ViT-B-32.json b/third_party/open_clip/model_configs/mt5-base-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..58cad89cf0f446bbe15e4e25b1ac43424a828017 --- /dev/null +++ b/third_party/open_clip/model_configs/mt5-base-ViT-B-32.json @@ -0,0 +1,15 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "google/mt5-base", + "hf_tokenizer_name": "google/mt5-base", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/third_party/open_clip/model_configs/mt5-xl-ViT-H-14.json b/third_party/open_clip/model_configs/mt5-xl-ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..b432810777ba7269dbb0e89edfe65cdd27e7d255 --- /dev/null +++ b/third_party/open_clip/model_configs/mt5-xl-ViT-H-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "hf_model_name": "google/mt5-xl", + "hf_tokenizer_name": "google/mt5-xl", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/third_party/open_clip/model_configs/roberta-ViT-B-32.json b/third_party/open_clip/model_configs/roberta-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..ed687d472a73bb2ac96025f355f80437ab14c260 --- /dev/null +++ b/third_party/open_clip/model_configs/roberta-ViT-B-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "roberta-base", + "hf_tokenizer_name": "roberta-base", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/third_party/open_clip/model_configs/swin_base_patch4_window7_224.json b/third_party/open_clip/model_configs/swin_base_patch4_window7_224.json new file mode 100644 index 0000000000000000000000000000000000000000..bd6820f0cf2aa655e0a2723287f4b78895a58e6a --- /dev/null +++ b/third_party/open_clip/model_configs/swin_base_patch4_window7_224.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "timm_model_name": "swin_base_patch4_window7_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/vit_medium_patch16_gap_256.json b/third_party/open_clip/model_configs/vit_medium_patch16_gap_256.json new file mode 100644 index 0000000000000000000000000000000000000000..8843eaf08cad16c3e7b5f496fd650715c9573f65 --- /dev/null +++ b/third_party/open_clip/model_configs/vit_medium_patch16_gap_256.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vit_medium_patch16_gap_256", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 256 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json b/third_party/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json new file mode 100644 index 0000000000000000000000000000000000000000..ed217b202d5e6071c5307f4547c97ff4cfe2abd1 --- /dev/null +++ b/third_party/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vit_relpos_medium_patch16_cls_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/third_party/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json b/third_party/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..751bccc2c6fc41bc4ff20182de88d86739d518d9 --- /dev/null +++ b/third_party/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json @@ -0,0 +1,15 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "hf_model_name": "xlm-roberta-base", + "hf_tokenizer_name": "xlm-roberta-base", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/third_party/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json b/third_party/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..31f271faa9bbb7a9da53900b483a4c00a16f3c4a --- /dev/null +++ b/third_party/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "hf_model_name": "xlm-roberta-large", + "hf_tokenizer_name": "xlm-roberta-large", + "proj": "mlp", + "pooler_type": "mean_pooler" + } +} diff --git a/third_party/open_clip/modified_resnet.py b/third_party/open_clip/modified_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..6a8d3aeda91ecb394303becbbfccc8acd8cddcd9 --- /dev/null +++ b/third_party/open_clip/modified_resnet.py @@ -0,0 +1,181 @@ +from collections import OrderedDict + +import torch +from torch import nn +from torch.nn import functional as F + +from .utils import freeze_batch_norm_2d + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.act1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.act2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.act3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.act1(self.bn1(self.conv1(x))) + out = self.act2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.act3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0., + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, image_size=224, width=64): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.act1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.act2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.act3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + if self.attnpool is not None: + std = self.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, 'partial locking not currently supported for this model' + for param in self.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + # FIXME support for non-transformer + pass + + def stem(self, x): + x = self.act1(self.bn1(self.conv1(x))) + x = self.act2(self.bn2(self.conv2(x))) + x = self.act3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + def forward(self, x): + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x diff --git a/third_party/open_clip/openai.py b/third_party/open_clip/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..cc4e13e876d6a7a3463b457e62c517cb063b1356 --- /dev/null +++ b/third_party/open_clip/openai.py @@ -0,0 +1,144 @@ +""" OpenAI pretrained model functions + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" + +import os +import warnings +from typing import List, Optional, Union + +import torch + +from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype +from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url + +__all__ = ["list_openai_models", "load_openai_model"] + + +def list_openai_models() -> List[str]: + """Returns the names of available CLIP models""" + return list_pretrained_models_by_tag('openai') + + +def load_openai_model( + name: str, + precision: Optional[str] = None, + device: Optional[Union[str, torch.device]] = None, + jit: bool = True, + cache_dir: Optional[str] = None, +): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + precision: str + Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. + device : Union[str, torch.device] + The device to put the loaded model + jit : bool + Whether to load the optimized JIT model (default) or more hackable non-JIT model. + cache_dir : Optional[str] + The directory to cache the downloaded model weights + + Returns + ------- + model : torch.nn.Module + The CLIP model + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + if precision is None: + precision = 'fp32' if device == 'cpu' else 'fp16' + + if get_pretrained_url(name, 'openai'): + model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(model_path, map_location="cpu") + + if not jit: + # Build a non-jit model from the OpenAI jitted model state dict + cast_dtype = get_cast_dtype(precision) + try: + model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) + except KeyError: + sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} + model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) + + # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use + model = model.to(device) + if precision.startswith('amp') or precision == 'fp32': + model.float() + elif precision == 'bf16': + convert_weights_to_lp(model, dtype=torch.bfloat16) + + return model + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 (typically for CPU) + if precision == 'fp32': + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + model.float() + + # ensure image_size attr available at consistent location for both jit and non-jit + model.visual.image_size = model.input_resolution.item() + return model diff --git a/third_party/open_clip/pretrained.py b/third_party/open_clip/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..87e7e527497d643fdf6ac931ac73b6e887a90d0d --- /dev/null +++ b/third_party/open_clip/pretrained.py @@ -0,0 +1,376 @@ +import hashlib +import os +import urllib +import warnings +from functools import partial +from typing import Dict, Union + +from tqdm import tqdm + +from .version import __version__ + +try: + from huggingface_hub import hf_hub_download + hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__) + _has_hf_hub = True +except ImportError: + hf_hub_download = None + _has_hf_hub = False + + +def _pcfg(url='', hf_hub='', mean=None, std=None): + return dict( + url=url, + hf_hub=hf_hub, + mean=mean, + std=std, + ) + + +_RN50 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), + cc12m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), +) + +_RN50_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), + cc12m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), +) + +_RN101 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), +) + +_RN101_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), + yfcc15m=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), +) + +_RN50x4 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"), +) + +_RN50x16 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"), +) + +_RN50x64 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"), +) + +_VITB32 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), + laion2b_e16=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), + laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/') +) + +_VITB32_quickgelu = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), +) + +_VITB16 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), + # laion400m_32k=_pcfg( + # url="", + # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + # laion400m_64k=_pcfg( + # url="", + # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), +) + +_VITB16_PLUS_240 = dict( + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), +) + +_VITL14 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), + laion400m_e31=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), + laion400m_e32=_pcfg( + "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), + laion2b_s32b_b82k=_pcfg( + hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), +) + +_VITL14_336 = dict( + openai=_pcfg( + "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), +) + +_VITH14 = dict( + laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), +) + +_VITg14 = dict( + laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'), +) + +_VITbigG14 = dict( + laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), +) + +_robertaViTB32 = dict( + laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'), +) + +_xlmRobertaBaseViTB32 = dict( + laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'), +) + +_xlmRobertaLargeFrozenViTH14 = dict( + frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'), +) + +_convnext_base = dict( + laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'), +) + +_convnext_base_w = dict( + laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'), + laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'), + laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'), +) + +_convnext_base_w_320 = dict( + laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'), + laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'), +) + +_convnext_large_d = dict( + laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'), +) + +_convnext_large_d_320 = dict( + laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'), + laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'), +) + +_convnext_xxlarge = dict( + laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'), + laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'), + laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'), +) + +_coca_VITB32 = dict( + laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'), + mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/') +) + +_coca_VITL14 = dict( + laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'), + mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/') +) + + +_PRETRAINED = { + "RN50": _RN50, + "RN50-quickgelu": _RN50_quickgelu, + "RN101": _RN101, + "RN101-quickgelu": _RN101_quickgelu, + "RN50x4": _RN50x4, + "RN50x16": _RN50x16, + "RN50x64": _RN50x64, + "ViT-B-32": _VITB32, + "ViT-B-32-quickgelu": _VITB32_quickgelu, + "ViT-B-16": _VITB16, + "ViT-B-16-plus-240": _VITB16_PLUS_240, + "ViT-L-14": _VITL14, + "ViT-L-14-336": _VITL14_336, + "ViT-H-14": _VITH14, + "ViT-g-14": _VITg14, + "ViT-bigG-14": _VITbigG14, + "roberta-ViT-B-32": _robertaViTB32, + "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32, + "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14, + "convnext_base": _convnext_base, + "convnext_base_w": _convnext_base_w, + "convnext_base_w_320": _convnext_base_w_320, + "convnext_large_d": _convnext_large_d, + "convnext_large_d_320": _convnext_large_d_320, + "convnext_xxlarge": _convnext_xxlarge, + "coca_ViT-B-32": _coca_VITB32, + "coca_ViT-L-14": _coca_VITL14, +} + + +def _clean_tag(tag: str): + # normalize pretrained tags + return tag.lower().replace('-', '_') + + +def list_pretrained(as_str: bool = False): + """ returns list of pretrained models + Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True + """ + return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] + + +def list_pretrained_models_by_tag(tag: str): + """ return all models having the specified pretrain tag """ + models = [] + tag = _clean_tag(tag) + for k in _PRETRAINED.keys(): + if tag in _PRETRAINED[k]: + models.append(k) + return models + + +def list_pretrained_tags_by_model(model: str): + """ return all pretrain tags for the specified model architecture """ + tags = [] + if model in _PRETRAINED: + tags.extend(_PRETRAINED[model].keys()) + return tags + + +def is_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return False + return _clean_tag(tag) in _PRETRAINED[model] + + +def get_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return {} + model_pretrained = _PRETRAINED[model] + return model_pretrained.get(_clean_tag(tag), {}) + + +def get_pretrained_url(model: str, tag: str): + cfg = get_pretrained_cfg(model, _clean_tag(tag)) + return cfg.get('url', '') + + +def download_pretrained_from_url( + url: str, + cache_dir: Union[str, None] = None, +): + if not cache_dir: + cache_dir = os.path.expanduser("~/.cache/clip") + os.makedirs(cache_dir, exist_ok=True) + filename = os.path.basename(url) + + if 'openaipublic' in url: + expected_sha256 = url.split("/")[-2] + elif 'mlfoundations' in url: + expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] + else: + expected_sha256 = '' + + download_target = os.path.join(cache_dir, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if expected_sha256: + if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + else: + return download_target + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def has_hf_hub(necessary=False): + if not _has_hf_hub and necessary: + # if no HF Hub module installed, and it is necessary to continue, raise error + raise RuntimeError( + 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') + return _has_hf_hub + + +def download_pretrained_from_hf( + model_id: str, + filename: str = 'open_clip_pytorch_model.bin', + revision=None, + cache_dir: Union[str, None] = None, +): + has_hf_hub(True) + cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) + return cached_file + + +def download_pretrained( + cfg: Dict, + force_hf_hub: bool = False, + cache_dir: Union[str, None] = None, +): + target = '' + if not cfg: + return target + + download_url = cfg.get('url', '') + download_hf_hub = cfg.get('hf_hub', '') + if download_hf_hub and force_hf_hub: + # use HF hub even if url exists + download_url = '' + + if download_url: + target = download_pretrained_from_url(download_url, cache_dir=cache_dir) + elif download_hf_hub: + has_hf_hub(True) + # we assume the hf_hub entries in pretrained config combine model_id + filename in + # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and + # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. + model_id, filename = os.path.split(download_hf_hub) + if filename: + target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) + else: + target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + + return target diff --git a/third_party/open_clip/push_to_hf_hub.py b/third_party/open_clip/push_to_hf_hub.py new file mode 100644 index 0000000000000000000000000000000000000000..23c0631c81dcb43829b7374fac09406ecefcb436 --- /dev/null +++ b/third_party/open_clip/push_to_hf_hub.py @@ -0,0 +1,243 @@ +import argparse +import json +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Optional, Tuple + +import torch + +try: + from huggingface_hub import ( + create_repo, + get_hf_file_metadata, + hf_hub_download, + hf_hub_url, + repo_type_and_id_from_hf_id, + upload_folder, + ) + from huggingface_hub.utils import EntryNotFoundError + _has_hf_hub = True +except ImportError: + _has_hf_hub = False + +from .factory import create_model_from_pretrained, get_model_config, get_tokenizer +from .tokenizer import HFTokenizer + + +def save_config_for_hf( + model, + config_path: str, + model_config: Optional[dict] +): + preprocess_cfg = { + 'mean': model.visual.image_mean, + 'std': model.visual.image_std, + } + hf_config = { + 'model_cfg': model_config, + 'preprocess_cfg': preprocess_cfg, + } + + with config_path.open('w') as f: + json.dump(hf_config, f, indent=2) + + +def save_for_hf( + model, + tokenizer: HFTokenizer, + model_config: dict, + save_directory: str, + weights_filename='open_clip_pytorch_model.bin', + config_filename='open_clip_config.json', +): + save_directory = Path(save_directory) + save_directory.mkdir(exist_ok=True, parents=True) + + weights_path = save_directory / weights_filename + torch.save(model.state_dict(), weights_path) + + tokenizer.save_pretrained(save_directory) + + config_path = save_directory / config_filename + save_config_for_hf(model, config_path, model_config=model_config) + + +def push_to_hf_hub( + model, + tokenizer, + model_config: Optional[dict], + repo_id: str, + commit_message: str = 'Add model', + token: Optional[str] = None, + revision: Optional[str] = None, + private: bool = False, + create_pr: bool = False, + model_card: Optional[dict] = None, +): + if not isinstance(tokenizer, HFTokenizer): + # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14 + tokenizer = HFTokenizer('openai/clip-vit-large-patch14') + + # Create repo if it doesn't exist yet + repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) + + # Infer complete repo_id from repo_url + # Can be different from the input `repo_id` if repo_owner was implicit + _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) + repo_id = f"{repo_owner}/{repo_name}" + + # Check if README file already exist in repo + try: + get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) + has_readme = True + except EntryNotFoundError: + has_readme = False + + # Dump model and push to Hub + with TemporaryDirectory() as tmpdir: + # Save model weights and config. + save_for_hf( + model, + tokenizer=tokenizer, + model_config=model_config, + save_directory=tmpdir, + ) + + # Add readme if it does not exist + if not has_readme: + model_card = model_card or {} + model_name = repo_id.split('/')[-1] + readme_path = Path(tmpdir) / "README.md" + readme_text = generate_readme(model_card, model_name) + readme_path.write_text(readme_text) + + # Upload model and return + return upload_folder( + repo_id=repo_id, + folder_path=tmpdir, + revision=revision, + create_pr=create_pr, + commit_message=commit_message, + ) + + +def push_pretrained_to_hf_hub( + model_name, + pretrained: str, + repo_id: str, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + commit_message: str = 'Add model', + token: Optional[str] = None, + revision: Optional[str] = None, + private: bool = False, + create_pr: bool = False, + model_card: Optional[dict] = None, +): + model, preprocess_eval = create_model_from_pretrained( + model_name, + pretrained=pretrained, + image_mean=image_mean, + image_std=image_std, + ) + + model_config = get_model_config(model_name) + assert model_config + + tokenizer = get_tokenizer(model_name) + + push_to_hf_hub( + model=model, + tokenizer=tokenizer, + model_config=model_config, + repo_id=repo_id, + commit_message=commit_message, + token=token, + revision=revision, + private=private, + create_pr=create_pr, + model_card=model_card, + ) + + +def generate_readme(model_card: dict, model_name: str): + readme_text = "---\n" + readme_text += "tags:\n- zero-shot-image-classification\n- clip\n" + readme_text += "library_tag: open_clip\n" + readme_text += f"license: {model_card.get('license', 'mit')}\n" + if 'details' in model_card and 'Dataset' in model_card['details']: + readme_text += 'datasets:\n' + readme_text += f"- {model_card['details']['Dataset'].lower()}\n" + readme_text += "---\n" + readme_text += f"# Model card for {model_name}\n" + if 'description' in model_card: + readme_text += f"\n{model_card['description']}\n" + if 'details' in model_card: + readme_text += f"\n## Model Details\n" + for k, v in model_card['details'].items(): + if isinstance(v, (list, tuple)): + readme_text += f"- **{k}:**\n" + for vi in v: + readme_text += f" - {vi}\n" + elif isinstance(v, dict): + readme_text += f"- **{k}:**\n" + for ki, vi in v.items(): + readme_text += f" - {ki}: {vi}\n" + else: + readme_text += f"- **{k}:** {v}\n" + if 'usage' in model_card: + readme_text += f"\n## Model Usage\n" + readme_text += model_card['usage'] + readme_text += '\n' + + if 'comparison' in model_card: + readme_text += f"\n## Model Comparison\n" + readme_text += model_card['comparison'] + readme_text += '\n' + + if 'citation' in model_card: + readme_text += f"\n## Citation\n" + if not isinstance(model_card['citation'], (list, tuple)): + citations = [model_card['citation']] + else: + citations = model_card['citation'] + for c in citations: + readme_text += f"```bibtex\n{c}\n```\n" + + return readme_text + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Push to Hugging Face Hub") + parser.add_argument( + "--model", type=str, help="Name of the model to use.", + ) + parser.add_argument( + "--pretrained", type=str, + help="Use a pretrained CLIP model weights with the specified tag or file path.", + ) + parser.add_argument( + "--repo-id", type=str, + help="Destination HF Hub repo-id ie 'organization/model_id'.", + ) + parser.add_argument( + '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override default image mean value of dataset') + parser.add_argument( + '--image-std', type=float, nargs='+', default=None, metavar='STD', + help='Override default image std deviation of of dataset') + args = parser.parse_args() + + print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}') + + # FIXME add support to pass model_card json / template from file via cmd line + + push_pretrained_to_hf_hub( + args.model, + args.pretrained, + args.repo_id, + image_mean=args.image_mean, # override image mean/std if trained w/ non defaults + image_std=args.image_std, + ) + + print(f'{args.model} saved.') diff --git a/third_party/open_clip/timm_model.py b/third_party/open_clip/timm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..dc71a693f9a42ec01fd88d307661bc382b4d05bc --- /dev/null +++ b/third_party/open_clip/timm_model.py @@ -0,0 +1,127 @@ +""" timm model adapter + +Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. +""" +import logging +from collections import OrderedDict + +import torch +import torch.nn as nn + +try: + import timm + from timm.models.layers import Mlp, to_2tuple + try: + # old timm imports < 0.8.1 + from timm.models.layers.attention_pool2d import RotAttentionPool2d + from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d + except ImportError: + # new timm imports >= 0.8.1 + from timm.layers import RotAttentionPool2d + from timm.layers import AttentionPool2d as AbsAttentionPool2d +except ImportError: + timm = None + +from .utils import freeze_batch_norm_2d + + +class TimmModel(nn.Module): + """ timm model adapter + # FIXME this adapter is a work in progress, may change in ways that break weight compat + """ + + def __init__( + self, + model_name, + embed_dim, + image_size=224, + pool='avg', + proj='linear', + proj_bias=False, + drop=0., + drop_path=None, + pretrained=False, + ): + super().__init__() + if timm is None: + raise RuntimeError("Please `pip install timm` to use timm models.") + + self.image_size = to_2tuple(image_size) + timm_kwargs = {} + if drop_path is not None: + timm_kwargs['drop_path_rate'] = drop_path + self.trunk = timm.create_model(model_name, pretrained=pretrained, **timm_kwargs) + feat_size = self.trunk.default_cfg.get('pool_size', None) + feature_ndim = 1 if not feat_size else 2 + if pool in ('abs_attn', 'rot_attn'): + assert feature_ndim == 2 + # if attn pooling used, remove both classifier and default pool + self.trunk.reset_classifier(0, global_pool='') + else: + # reset global pool if pool config set, otherwise leave as network default + reset_kwargs = dict(global_pool=pool) if pool else {} + self.trunk.reset_classifier(0, **reset_kwargs) + prev_chs = self.trunk.num_features + + head_layers = OrderedDict() + if pool == 'abs_attn': + head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) + prev_chs = embed_dim + elif pool == 'rot_attn': + head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) + prev_chs = embed_dim + else: + assert proj, 'projection layer needed if non-attention pooling is used.' + + # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used + if proj == 'linear': + head_layers['drop'] = nn.Dropout(drop) + head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) + elif proj == 'mlp': + head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) + + self.head = nn.Sequential(head_layers) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + """ lock modules + Args: + unlocked_groups (int): leave last n layer groups unlocked (default: 0) + """ + if not unlocked_groups: + # lock full model + for param in self.trunk.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self.trunk) + else: + # NOTE: partial freeze requires latest timm (master) branch and is subject to change + try: + # FIXME import here until API stable and in an official release + from timm.models.helpers import group_parameters, group_modules + except ImportError: + raise RuntimeError( + 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') + matcher = self.trunk.group_matcher() + gparams = group_parameters(self.trunk, matcher) + max_layer_id = max(gparams.keys()) + max_layer_id = max_layer_id - unlocked_groups + for group_idx in range(max_layer_id + 1): + group = gparams[group_idx] + for param in group: + self.trunk.get_parameter(param).requires_grad = False + if freeze_bn_stats: + gmodules = group_modules(self.trunk, matcher, reverse=True) + gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} + freeze_batch_norm_2d(self.trunk, gmodules) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + try: + self.trunk.set_grad_checkpointing(enable) + except Exception as e: + logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') + + def forward(self, x): + x = self.trunk(x) + x = self.head(x) + return x diff --git a/third_party/open_clip/tokenizer.py b/third_party/open_clip/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..23fcfcbcb4ca051ba5bba7520918693001999282 --- /dev/null +++ b/third_party/open_clip/tokenizer.py @@ -0,0 +1,214 @@ +""" CLIP tokenizer + +Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +import gzip +import html +import os +from functools import lru_cache +from typing import Union, List + +import ftfy +import regex as re +import torch + +# https://stackoverflow.com/q/62691279 +import os +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + if not special_tokens: + special_tokens = ['', ''] + else: + special_tokens = ['', ''] + special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t:t for t in special_tokens} + special = "|".join(special_tokens) + self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text + + +_tokenizer = SimpleTokenizer() + +def decode(output_ids: torch.Tensor): + output_ids = output_ids.cpu().numpy() + return _tokenizer.decode(output_ids) + +def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder[""] + eot_token = _tokenizer.encoder[""] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + tokens[-1] = eot_token + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + +class HFTokenizer: + """HuggingFace tokenizer wrapper""" + + def __init__(self, tokenizer_name: str): + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + def save_pretrained(self, dest): + self.tokenizer.save_pretrained(dest) + + def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor: + # same cleaning as for default tokenizer, except lowercasing + # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance + if isinstance(texts, str): + texts = [texts] + texts = [whitespace_clean(basic_clean(text)) for text in texts] + input_ids = self.tokenizer( + texts, + return_tensors='pt', + max_length=context_length, + padding='max_length', + truncation=True, + ).input_ids + return input_ids diff --git a/third_party/open_clip/transform.py b/third_party/open_clip/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..fe4e21fa5b515f2412049f9274bd06fbe77fb9b9 --- /dev/null +++ b/third_party/open_clip/transform.py @@ -0,0 +1,216 @@ +import warnings +from dataclasses import dataclass, asdict +from typing import Any, Dict, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import torchvision.transforms.functional as F +from functools import partial +from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ + CenterCrop + +from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD + + +@dataclass +class AugmentationCfg: + scale: Tuple[float, float] = (0.9, 1.0) + ratio: Optional[Tuple[float, float]] = None + color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None + interpolation: Optional[str] = None + re_prob: Optional[float] = None + re_count: Optional[int] = None + use_timm: bool = False + + +class ResizeMaxSize(nn.Module): + + def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): + super().__init__() + if not isinstance(max_size, int): + raise TypeError(f"Size should be int. Got {type(max_size)}") + self.max_size = max_size + self.interpolation = interpolation + self.fn = min if fn == 'min' else min + self.fill = fill + + def forward(self, img): + if isinstance(img, torch.Tensor): + height, width = img.shape[1:] + else: + width, height = img.size + scale = self.max_size / float(max(height, width)) + if scale != 1.0: + new_size = tuple(round(dim * scale) for dim in (height, width)) + img = F.resize(img, new_size, self.interpolation) + pad_h = self.max_size - new_size[0] + pad_w = self.max_size - new_size[1] + img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) + return img + + +def _convert_to_rgb_or_rgba(image): + if image.mode == 'RGBA': + return image + else: + return image.convert('RGB') + +# def transform_and_split(merged, transform_fn, normalize_fn): +# transformed = transform_fn(merged) +# crop_img, crop_label = torch.split(transformed, [3,1], dim=0) + +# # crop_img = _convert_to_rgb(crop_img) +# crop_img = normalize_fn(ToTensor()(crop_img)) +# return crop_img, crop_label + +class MaskAwareNormalize(nn.Module): + def __init__(self, mean, std): + super().__init__() + self.normalize = Normalize(mean=mean, std=std) + + def forward(self, tensor): + if tensor.shape[0] == 4: + return torch.cat([self.normalize(tensor[:3]), tensor[3:]], dim=0) + else: + return self.normalize(tensor) + +def image_transform( + image_size: int, + is_train: bool, + mean: Optional[Tuple[float, ...]] = None, + std: Optional[Tuple[float, ...]] = None, + resize_longest_max: bool = False, + fill_color: int = 0, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, +): + mean = mean or OPENAI_DATASET_MEAN + if not isinstance(mean, (list, tuple)): + mean = (mean,) * 3 + + std = std or OPENAI_DATASET_STD + if not isinstance(std, (list, tuple)): + std = (std,) * 3 + + if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: + # for square size, pass size as int so that Resize() uses aspect preserving shortest edge + image_size = image_size[0] + + if isinstance(aug_cfg, dict): + aug_cfg = AugmentationCfg(**aug_cfg) + else: + aug_cfg = aug_cfg or AugmentationCfg() + normalize = MaskAwareNormalize(mean=mean, std=std) + if is_train: + aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} + use_timm = aug_cfg_dict.pop('use_timm', False) + if use_timm: + assert False, "not tested for augmentation with mask" + from timm.data import create_transform # timm can still be optional + if isinstance(image_size, (tuple, list)): + assert len(image_size) >= 2 + input_size = (3,) + image_size[-2:] + else: + input_size = (3, image_size, image_size) + # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time + aug_cfg_dict.setdefault('interpolation', 'random') + aug_cfg_dict.setdefault('color_jitter', None) # disable by default + train_transform = create_transform( + input_size=input_size, + is_training=True, + hflip=0., + mean=mean, + std=std, + re_mode='pixel', + **aug_cfg_dict, + ) + else: + train_transform = Compose([ + _convert_to_rgb_or_rgba, + ToTensor(), + RandomResizedCrop( + image_size, + scale=aug_cfg_dict.pop('scale'), + interpolation=InterpolationMode.BICUBIC, + ), + normalize, + ]) + if aug_cfg_dict: + warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') + return train_transform + else: + transforms = [ + _convert_to_rgb_or_rgba, + ToTensor(), + ] + if resize_longest_max: + transforms.extend([ + ResizeMaxSize(image_size, fill=fill_color) + ]) + else: + transforms.extend([ + Resize(image_size, interpolation=InterpolationMode.BICUBIC), + CenterCrop(image_size), + ]) + transforms.extend([ + normalize, + ]) + return Compose(transforms) + + +# def image_transform_region( +# image_size: int, +# is_train: bool, +# mean: Optional[Tuple[float, ...]] = None, +# std: Optional[Tuple[float, ...]] = None, +# resize_longest_max: bool = False, +# fill_color: int = 0, +# aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, +# ): +# mean = mean or OPENAI_DATASET_MEAN +# if not isinstance(mean, (list, tuple)): +# mean = (mean,) * 3 + +# std = std or OPENAI_DATASET_STD +# if not isinstance(std, (list, tuple)): +# std = (std,) * 3 + +# if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: +# # for square size, pass size as int so that Resize() uses aspect preserving shortest edge +# image_size = image_size[0] + +# if isinstance(aug_cfg, dict): +# aug_cfg = AugmentationCfg(**aug_cfg) +# else: +# aug_cfg = aug_cfg or AugmentationCfg() +# normalize = Normalize(mean=mean, std=std) +# if is_train: +# aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} + +# transform = Compose([ +# RandomResizedCrop( +# image_size, +# scale=aug_cfg_dict.pop('scale'), +# interpolation=InterpolationMode.BICUBIC, +# ), +# ]) +# train_transform = Compose([ +# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize) +# ]) +# return train_transform +# else: +# if resize_longest_max: +# transform = [ +# ResizeMaxSize(image_size, fill=fill_color) +# ] +# val_transform = Compose([ +# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize), +# ]) +# else: +# transform = [ +# Resize(image_size, interpolation=InterpolationMode.BICUBIC), +# CenterCrop(image_size), +# ] +# val_transform = Compose([ +# partial(transform_and_split, transform_fn=transform,normalize_fn=normalize), +# ]) +# return val_transform \ No newline at end of file diff --git a/third_party/open_clip/transformer.py b/third_party/open_clip/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..7465c1b20bf388a17e0f4f80f7b8eee3b564af92 --- /dev/null +++ b/third_party/open_clip/transformer.py @@ -0,0 +1,727 @@ +from collections import OrderedDict +import math +from typing import Callable, Optional, Sequence, Tuple + +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils.checkpoint import checkpoint + +from .utils import to_2tuple + + +class LayerNormFp32(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm (with cast back to input dtype).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1. + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + + def forward(self, x): + if not self.training or self.prob == 0.: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + scaled_cosine=False, + scale_heads=False, + logit_scale_max=math.log(1. / 0.01), + attn_drop=0., + proj_drop=0. + ): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None + + if self.scaled_cosine: + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) + self.out_drop = nn.Dropout(proj_drop) + + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + L, N, C = x.shape + q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) + q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + + if self.logit_scale is not None: + attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(N, self.num_heads, L, L) * logit_scale + attn = attn.view(-1, L, L) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = torch.bmm(attn, v) + if self.head_scale is not None: + x = x.view(N, self.num_heads, L, C) * self.head_scale + x = x.view(-1, L, C) + x = x.transpose(0, 1).reshape(L, N, C) + x = self.out_proj(x) + x = self.out_drop(x) + return x + + +class AttentionalPooler(nn.Module): + def __init__( + self, + d_model: int, + context_dim: int, + n_head: int = 8, + n_queries: int = 256, + norm_layer: Callable = LayerNorm + ): + super().__init__() + self.query = nn.Parameter(torch.randn(n_queries, d_model)) + self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim) + self.ln_q = norm_layer(d_model) + self.ln_k = norm_layer(context_dim) + + def forward(self, x: torch.Tensor): + x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0] + return out.permute(1, 0, 2) # LND -> NLD + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + is_cross_attention: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + if is_cross_attention: + self.ln_1_kv = norm_layer(d_model) + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def attention( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x + + attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None + return self.attn( + q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask + )[0] + + def forward( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None + v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None + + x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class CustomResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + scale_cosine_attn: bool = False, + scale_heads: bool = False, + scale_attn: bool = False, + scale_fc: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = Attention( + d_model, n_head, + scaled_cosine=scale_cosine_attn, + scale_heads=scale_heads, + ) + self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class Transformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) + for _ in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + x = checkpoint(r, x, None, None, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + +class VisionTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + ls_init_value: float = None, + global_average_pool: bool = False, + attentional_pool: bool = False, + n_queries: int = 256, + attn_pooler_heads: int = 8, + output_dim: int = 512, + patch_dropout: float = 0., + input_patchnorm: bool = False, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_tokens: bool = False + ): + super().__init__() + self.output_tokens = output_tokens + image_height, image_width = self.image_size = to_2tuple(image_size) + patch_height, patch_width = self.patch_size = to_2tuple(patch_size) + self.grid_size = (image_height // patch_height, image_width // patch_width) + self.output_dim = output_dim + + # whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1 + self.input_patchnorm = input_patchnorm + + if input_patchnorm: + patch_input_dim = patch_height * patch_width * 3 + self.patchnorm_pre_ln = LayerNorm(patch_input_dim) + self.conv1 = nn.Linear(patch_input_dim, width) + else: + self.patchnorm_pre_ln = nn.Identity() + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + # class embeddings and positional embeddings + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() + + self.ln_pre = norm_layer(width) + self.transformer = Transformer( + width, + layers, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + self.global_average_pool = global_average_pool + if attentional_pool: + self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries) + self.ln_post = norm_layer(output_dim) + self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim)) + else: + self.attn_pool = None + self.ln_post = norm_layer(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + self.init_parameters() + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + for param in self.parameters(): + param.requires_grad = False + + if unlocked_groups != 0: + groups = [ + [ + self.conv1, + self.class_embedding, + self.positional_embedding, + self.ln_pre, + ], + *self.transformer.resblocks[:-1], + [ + self.transformer.resblocks[-1], + self.ln_post, + ], + self.proj, + ] + + def _unlock(x): + if isinstance(x, Sequence): + for g in x: + _unlock(g) + else: + if isinstance(x, torch.nn.Parameter): + x.requires_grad = True + else: + for p in x.parameters(): + p.requires_grad = True + + _unlock(groups[-unlocked_groups:]) + + def init_parameters(self): + # FIXME OpenAI CLIP did not define an init for the VisualTransformer + # TODO experiment if default PyTorch init, below, or alternate init is best. + + # nn.init.normal_(self.class_embedding, std=self.scale) + # nn.init.normal_(self.positional_embedding, std=self.scale) + # + # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + # attn_std = self.transformer.width ** -0.5 + # fc_std = (2 * self.transformer.width) ** -0.5 + # for block in self.transformer.resblocks: + # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + # + # if self.text_projection is not None: + # nn.init.normal_(self.text_projection, std=self.scale) + pass + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.global_average_pool: + return x.mean(dim=1), x + else: + return x[:, 0], x[:, 1:] + + def forward(self, x: torch.Tensor, skip_pool: bool = False): + + # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 + if self.input_patchnorm: + # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') + x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], self.patch_size[1]) + x = x.permute(0, 2, 4, 1, 3, 5) + x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1) + x = self.patchnorm_pre_ln(x) + x = self.conv1(x) + else: + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + # class embeddings and positional embeddings + x = torch.cat( + [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + + # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in + x = self.patch_dropout(x) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + if skip_pool: + return x + + if self.attn_pool is not None: + x = self.attn_pool(x) + x = self.ln_post(x) + pooled, tokens = self._global_pool(x) + else: + pooled, tokens = self._global_pool(x) + pooled = self.ln_post(pooled) + + if self.proj is not None: + pooled = pooled @ self.proj + + if self.output_tokens: + return pooled, tokens + + return pooled + + +class TextTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + ls_init_value: float = None, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + embed_cls: bool = False, + pad_id: int = 0, + output_tokens: bool = False, + ): + super().__init__() + self.output_tokens = output_tokens + self.num_pos = self.context_length = context_length + self.vocab_size = vocab_size + self.width = width + self.output_dim = output_dim + self.heads = heads + self.pad_id = pad_id + + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + if embed_cls: + self.cls_emb = nn.Parameter(torch.empty(width)) + self.num_pos += 1 + else: + self.cls_emb = None + + self.token_embedding = nn.Embedding(vocab_size, width) + self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) + self.transformer = Transformer( + width=width, + layers=layers, + heads=heads, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.ln_final = norm_layer(width) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.init_parameters() + + def init_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + if self.cls_emb is not None: + nn.init.normal_(self.cls_emb, std=0.01) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.num_pos, self.num_pos) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def build_cls_mask(self, text, cast_dtype: torch.dtype): + cls_mask = (text != self.pad_id).unsqueeze(1) + cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0) + additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) + additive_mask.fill_(0) + additive_mask.masked_fill_(~cls_mask, float("-inf")) + additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) + return additive_mask + + def _repeat(self, t, N: int): + return t.reshape(1, 1, -1).repeat(N, 1, 1) + + def forward(self, text): + cast_dtype = self.transformer.get_cast_dtype() + seq_len = text.shape[1] + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + attn_mask = self.attn_mask + if self.cls_emb is not None: + seq_len += 1 + x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1) + cls_mask = self.build_cls_mask(text, cast_dtype) + attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] + + x = x + self.positional_embedding[:seq_len].to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + if self.cls_emb is not None: + pooled, tokens = x[:, -1], x[:, :-1] + pooled = self.ln_final(pooled) + else: + x = self.ln_final(x) + pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x + + if self.text_projection is not None: + pooled = pooled @ self.text_projection + + if self.output_tokens: + return pooled, tokens + + return pooled + + +class MultimodalTransformer(Transformer): + def __init__( + self, + width: int, + layers: int, + heads: int, + context_length: int = 77, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_dim: int = 512, + ): + + super().__init__( + width=width, + layers=layers, + heads=heads, + mlp_ratio=mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.context_length = context_length + self.cross_attn = nn.ModuleList([ + ResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + is_cross_attention=True, + ) + for _ in range(layers) + ]) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.ln_final = norm_layer(width) + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + def init_parameters(self): + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + for block in self.transformer.cross_attn: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, image_embs, text_embs): + text_embs = text_embs.permute(1, 0, 2) # NLD -> LNDsq + image_embs = image_embs.permute(1, 0, 2) # NLD -> LND + seq_len = text_embs.shape[0] + + for resblock, cross_attn in zip(self.resblocks, self.cross_attn): + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len]) + text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None) + else: + text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len]) + text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs) + + x = text_embs.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + if self.text_projection is not None: + x = x @ self.text_projection + + return x + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable diff --git a/third_party/open_clip/utils.py b/third_party/open_clip/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..51e80c5e296b24cae130ab0459baf268e0db7673 --- /dev/null +++ b/third_party/open_clip/utils.py @@ -0,0 +1,60 @@ +from itertools import repeat +import collections.abc + +from torch import nn as nn +from torchvision.ops.misc import FrozenBatchNorm2d + + +def freeze_batch_norm_2d(module, module_match={}, name=''): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + module_match (dict): Dictionary of full module names to freeze (all if empty) + name (str): Full module name (prefix) + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + is_match = True + if module_match: + is_match = name in module_match + if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for child_name, child in module.named_children(): + full_child_name = '.'.join([name, child_name]) if name else child_name + new_child = freeze_batch_norm_2d(child, module_match, full_child_name) + if new_child is not child: + res.add_module(child_name, new_child) + return res + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = lambda n, x: _ntuple(n)(x) diff --git a/third_party/open_clip/version.py b/third_party/open_clip/version.py new file mode 100644 index 0000000000000000000000000000000000000000..48aa744fb053599044caf0253b889b5cfe5b78e7 --- /dev/null +++ b/third_party/open_clip/version.py @@ -0,0 +1 @@ +__version__ = '2.16.0' diff --git a/third_party/score_models/HPS_v2_compressed.pt b/third_party/score_models/HPS_v2_compressed.pt new file mode 100644 index 0000000000000000000000000000000000000000..5939b86baafdad30c8a8972d372d368b9520f4b7 --- /dev/null +++ b/third_party/score_models/HPS_v2_compressed.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:07cb0d20626232204ca3a6ec53ff5984e7e53212a8a7b0e0ec2bac1bdaf125a1 +size 1972484501 diff --git a/third_party/score_models/sac+logos+ava1-l14-linearMSE.pth b/third_party/score_models/sac+logos+ava1-l14-linearMSE.pth new file mode 100644 index 0000000000000000000000000000000000000000..aae5780851125baf1a30834c3a715d3866858a4d --- /dev/null +++ b/third_party/score_models/sac+logos+ava1-l14-linearMSE.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:21dd590f3ccdc646f0d53120778b296013b096a035a2718c9cb0d511bff0f1e0 +size 3714759 diff --git a/third_party/src/flux_ch/__init__.py b/third_party/src/flux_ch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..43c365a49d6980e88acba10ef3069f110a59644a --- /dev/null +++ b/third_party/src/flux_ch/__init__.py @@ -0,0 +1,11 @@ +try: + from ._version import version as __version__ # type: ignore + from ._version import version_tuple +except ImportError: + __version__ = "unknown (no version information available)" + version_tuple = (0, 0, "unknown", "noinfo") + +from pathlib import Path + +PACKAGE = __package__.replace("_", "-") +PACKAGE_ROOT = Path(__file__).parent diff --git a/third_party/src/flux_ch/__main__.py b/third_party/src/flux_ch/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..d5cf0fd2444d4cda4053fa74dad3371556b886e5 --- /dev/null +++ b/third_party/src/flux_ch/__main__.py @@ -0,0 +1,4 @@ +from .cli import app + +if __name__ == "__main__": + app() diff --git a/third_party/src/flux_ch/__pycache__/__init__.cpython-310.pyc b/third_party/src/flux_ch/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f8ec9e59923d28eb3123e9ee57aac4f536bccbc Binary files /dev/null and b/third_party/src/flux_ch/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/__pycache__/controlnet.cpython-310.pyc b/third_party/src/flux_ch/__pycache__/controlnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42fb692ae30c36766b585bb9e55ceee06b05e94e Binary files /dev/null and b/third_party/src/flux_ch/__pycache__/controlnet.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/__pycache__/math.cpython-310.pyc b/third_party/src/flux_ch/__pycache__/math.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8118ee1f0846a931117c935f46b8a43775107fa Binary files /dev/null and b/third_party/src/flux_ch/__pycache__/math.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/__pycache__/model.cpython-310.pyc b/third_party/src/flux_ch/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5526dc9727c00946981c87d7f66b1b0dd9cef27a Binary files /dev/null and b/third_party/src/flux_ch/__pycache__/model.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/__pycache__/sampling.cpython-310.pyc b/third_party/src/flux_ch/__pycache__/sampling.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb1c8617cc9af05bbf4b4d01d2c84d06d95e49b7 Binary files /dev/null and b/third_party/src/flux_ch/__pycache__/sampling.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/__pycache__/util.cpython-310.pyc b/third_party/src/flux_ch/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b48206b669032eaebe11c6c80432fd02a9cc87aa Binary files /dev/null and b/third_party/src/flux_ch/__pycache__/util.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/__pycache__/xflux_pipeline.cpython-310.pyc b/third_party/src/flux_ch/__pycache__/xflux_pipeline.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a271bfeae8a8d70a2fae760822bf3831c55eb1be Binary files /dev/null and b/third_party/src/flux_ch/__pycache__/xflux_pipeline.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/__pycache__/util.cpython-310.pyc b/third_party/src/flux_ch/annotator/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..492b9259de00df19be781cd252f7586cff99c30d Binary files /dev/null and b/third_party/src/flux_ch/annotator/__pycache__/util.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/canny/__init__.py b/third_party/src/flux_ch/annotator/canny/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb0da951dc838ec9dec2131007e036113281800b --- /dev/null +++ b/third_party/src/flux_ch/annotator/canny/__init__.py @@ -0,0 +1,6 @@ +import cv2 + + +class CannyDetector: + def __call__(self, img, low_threshold, high_threshold): + return cv2.Canny(img, low_threshold, high_threshold) diff --git a/third_party/src/flux_ch/annotator/canny/__pycache__/__init__.cpython-310.pyc b/third_party/src/flux_ch/annotator/canny/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3114fb9f40b5ee2802cb79a37abc045b8af7b12 Binary files /dev/null and b/third_party/src/flux_ch/annotator/canny/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/ckpts/ckpts.txt b/third_party/src/flux_ch/annotator/ckpts/ckpts.txt new file mode 100644 index 0000000000000000000000000000000000000000..1978551fb2a9226814eaf58459f414fcfac4e69b --- /dev/null +++ b/third_party/src/flux_ch/annotator/ckpts/ckpts.txt @@ -0,0 +1 @@ +Weights here. \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/dwpose/__init__.py b/third_party/src/flux_ch/annotator/dwpose/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb6e172d05c9de3f1cdd61e330ad8d6dde46dfdd --- /dev/null +++ b/third_party/src/flux_ch/annotator/dwpose/__init__.py @@ -0,0 +1,68 @@ +# Openpose +# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose +# 2nd Edited by https://github.com/Hzzone/pytorch-openpose +# 3rd Edited by ControlNet +# 4th Edited by ControlNet (added face and correct hands) + +import os +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + +import torch +import numpy as np +from . import util +from .wholebody import Wholebody + +def draw_pose(pose, H, W): + bodies = pose['bodies'] + faces = pose['faces'] + hands = pose['hands'] + candidate = bodies['candidate'] + subset = bodies['subset'] + canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) + + canvas = util.draw_bodypose(canvas, candidate, subset) + + canvas = util.draw_handpose(canvas, hands) + + canvas = util.draw_facepose(canvas, faces) + + return canvas + + +class DWposeDetector: + def __init__(self, device): + + self.pose_estimation = Wholebody(device) + + def __call__(self, oriImg): + oriImg = oriImg.copy() + H, W, C = oriImg.shape + with torch.no_grad(): + candidate, subset = self.pose_estimation(oriImg) + nums, keys, locs = candidate.shape + candidate[..., 0] /= float(W) + candidate[..., 1] /= float(H) + body = candidate[:,:18].copy() + body = body.reshape(nums*18, locs) + score = subset[:,:18] + for i in range(len(score)): + for j in range(len(score[i])): + if score[i][j] > 0.3: + score[i][j] = int(18*i+j) + else: + score[i][j] = -1 + + un_visible = subset<0.3 + candidate[un_visible] = -1 + + foot = candidate[:,18:24] + + faces = candidate[:,24:92] + + hands = candidate[:,92:113] + hands = np.vstack([hands, candidate[:,113:]]) + + bodies = dict(candidate=body, subset=score) + pose = dict(bodies=bodies, hands=hands, faces=faces) + + return draw_pose(pose, H, W) diff --git a/third_party/src/flux_ch/annotator/dwpose/__pycache__/__init__.cpython-310.pyc b/third_party/src/flux_ch/annotator/dwpose/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dec435d916837442480ddd3f8dc1865ba5f4d22b Binary files /dev/null and b/third_party/src/flux_ch/annotator/dwpose/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/dwpose/__pycache__/onnxdet.cpython-310.pyc b/third_party/src/flux_ch/annotator/dwpose/__pycache__/onnxdet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85ab8070e80878df80428d1c06f9c46b6fcf4c14 Binary files /dev/null and b/third_party/src/flux_ch/annotator/dwpose/__pycache__/onnxdet.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/dwpose/__pycache__/onnxpose.cpython-310.pyc b/third_party/src/flux_ch/annotator/dwpose/__pycache__/onnxpose.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e3059ca8d4e27e050f9524218f3ddb266cab05e Binary files /dev/null and b/third_party/src/flux_ch/annotator/dwpose/__pycache__/onnxpose.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/dwpose/__pycache__/util.cpython-310.pyc b/third_party/src/flux_ch/annotator/dwpose/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c64fa6bbb8a3b155fa9345f7c349d5fa0070cc0 Binary files /dev/null and b/third_party/src/flux_ch/annotator/dwpose/__pycache__/util.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/dwpose/__pycache__/wholebody.cpython-310.pyc b/third_party/src/flux_ch/annotator/dwpose/__pycache__/wholebody.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76cb4c55129df85d673572b9ed15900b415365e1 Binary files /dev/null and b/third_party/src/flux_ch/annotator/dwpose/__pycache__/wholebody.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/dwpose/onnxdet.py b/third_party/src/flux_ch/annotator/dwpose/onnxdet.py new file mode 100644 index 0000000000000000000000000000000000000000..e0411c96a5eef41e981bde5481ef7786b242f1fa --- /dev/null +++ b/third_party/src/flux_ch/annotator/dwpose/onnxdet.py @@ -0,0 +1,125 @@ +import cv2 +import numpy as np + +import onnxruntime + +def nms(boxes, scores, nms_thr): + """Single class NMS implemented in Numpy.""" + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= nms_thr)[0] + order = order[inds + 1] + + return keep + +def multiclass_nms(boxes, scores, nms_thr, score_thr): + """Multiclass NMS implemented in Numpy. Class-aware version.""" + final_dets = [] + num_classes = scores.shape[1] + for cls_ind in range(num_classes): + cls_scores = scores[:, cls_ind] + valid_score_mask = cls_scores > score_thr + if valid_score_mask.sum() == 0: + continue + else: + valid_scores = cls_scores[valid_score_mask] + valid_boxes = boxes[valid_score_mask] + keep = nms(valid_boxes, valid_scores, nms_thr) + if len(keep) > 0: + cls_inds = np.ones((len(keep), 1)) * cls_ind + dets = np.concatenate( + [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1 + ) + final_dets.append(dets) + if len(final_dets) == 0: + return None + return np.concatenate(final_dets, 0) + +def demo_postprocess(outputs, img_size, p6=False): + grids = [] + expanded_strides = [] + strides = [8, 16, 32] if not p6 else [8, 16, 32, 64] + + hsizes = [img_size[0] // stride for stride in strides] + wsizes = [img_size[1] // stride for stride in strides] + + for hsize, wsize, stride in zip(hsizes, wsizes, strides): + xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) + grid = np.stack((xv, yv), 2).reshape(1, -1, 2) + grids.append(grid) + shape = grid.shape[:2] + expanded_strides.append(np.full((*shape, 1), stride)) + + grids = np.concatenate(grids, 1) + expanded_strides = np.concatenate(expanded_strides, 1) + outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides + outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides + + return outputs + +def preprocess(img, input_size, swap=(2, 0, 1)): + if len(img.shape) == 3: + padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 + else: + padded_img = np.ones(input_size, dtype=np.uint8) * 114 + + r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) + resized_img = cv2.resize( + img, + (int(img.shape[1] * r), int(img.shape[0] * r)), + interpolation=cv2.INTER_LINEAR, + ).astype(np.uint8) + padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img + + padded_img = padded_img.transpose(swap) + padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) + return padded_img, r + +def inference_detector(session, oriImg): + input_shape = (640,640) + img, ratio = preprocess(oriImg, input_shape) + + ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]} + output = session.run(None, ort_inputs) + predictions = demo_postprocess(output[0], input_shape)[0] + + boxes = predictions[:, :4] + scores = predictions[:, 4:5] * predictions[:, 5:] + + boxes_xyxy = np.ones_like(boxes) + boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2. + boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2. + boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2. + boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2. + boxes_xyxy /= ratio + dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1) + if dets is not None: + final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5] + isscore = final_scores>0.3 + iscat = final_cls_inds == 0 + isbbox = [ i and j for (i, j) in zip(isscore, iscat)] + final_boxes = final_boxes[isbbox] + else: + final_boxes = np.array([]) + + return final_boxes diff --git a/third_party/src/flux_ch/annotator/dwpose/onnxpose.py b/third_party/src/flux_ch/annotator/dwpose/onnxpose.py new file mode 100644 index 0000000000000000000000000000000000000000..79cd4a06241123af81ea22446a4ca8816716443f --- /dev/null +++ b/third_party/src/flux_ch/annotator/dwpose/onnxpose.py @@ -0,0 +1,360 @@ +from typing import List, Tuple + +import cv2 +import numpy as np +import onnxruntime as ort + +def preprocess( + img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256) +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Do preprocessing for RTMPose model inference. + + Args: + img (np.ndarray): Input image in shape. + input_size (tuple): Input image size in shape (w, h). + + Returns: + tuple: + - resized_img (np.ndarray): Preprocessed image. + - center (np.ndarray): Center of image. + - scale (np.ndarray): Scale of image. + """ + # get shape of image + img_shape = img.shape[:2] + out_img, out_center, out_scale = [], [], [] + if len(out_bbox) == 0: + out_bbox = [[0, 0, img_shape[1], img_shape[0]]] + for i in range(len(out_bbox)): + x0 = out_bbox[i][0] + y0 = out_bbox[i][1] + x1 = out_bbox[i][2] + y1 = out_bbox[i][3] + bbox = np.array([x0, y0, x1, y1]) + + # get center and scale + center, scale = bbox_xyxy2cs(bbox, padding=1.25) + + # do affine transformation + resized_img, scale = top_down_affine(input_size, scale, center, img) + + # normalize image + mean = np.array([123.675, 116.28, 103.53]) + std = np.array([58.395, 57.12, 57.375]) + resized_img = (resized_img - mean) / std + + out_img.append(resized_img) + out_center.append(center) + out_scale.append(scale) + + return out_img, out_center, out_scale + + +def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray: + """Inference RTMPose model. + + Args: + sess (ort.InferenceSession): ONNXRuntime session. + img (np.ndarray): Input image in shape. + + Returns: + outputs (np.ndarray): Output of RTMPose model. + """ + all_out = [] + # build input + for i in range(len(img)): + input = [img[i].transpose(2, 0, 1)] + + # build output + sess_input = {sess.get_inputs()[0].name: input} + sess_output = [] + for out in sess.get_outputs(): + sess_output.append(out.name) + + # run model + outputs = sess.run(sess_output, sess_input) + all_out.append(outputs) + + return all_out + + +def postprocess(outputs: List[np.ndarray], + model_input_size: Tuple[int, int], + center: Tuple[int, int], + scale: Tuple[int, int], + simcc_split_ratio: float = 2.0 + ) -> Tuple[np.ndarray, np.ndarray]: + """Postprocess for RTMPose model output. + + Args: + outputs (np.ndarray): Output of RTMPose model. + model_input_size (tuple): RTMPose model Input image size. + center (tuple): Center of bbox in shape (x, y). + scale (tuple): Scale of bbox in shape (w, h). + simcc_split_ratio (float): Split ratio of simcc. + + Returns: + tuple: + - keypoints (np.ndarray): Rescaled keypoints. + - scores (np.ndarray): Model predict scores. + """ + all_key = [] + all_score = [] + for i in range(len(outputs)): + # use simcc to decode + simcc_x, simcc_y = outputs[i] + keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio) + + # rescale keypoints + keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2 + all_key.append(keypoints[0]) + all_score.append(scores[0]) + + return np.array(all_key), np.array(all_score) + + +def bbox_xyxy2cs(bbox: np.ndarray, + padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]: + """Transform the bbox format from (x,y,w,h) into (center, scale) + + Args: + bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted + as (left, top, right, bottom) + padding (float): BBox padding factor that will be multilied to scale. + Default: 1.0 + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or + (n, 2) + - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or + (n, 2) + """ + # convert single bbox from (4, ) to (1, 4) + dim = bbox.ndim + if dim == 1: + bbox = bbox[None, :] + + # get bbox center and scale + x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3]) + center = np.hstack([x1 + x2, y1 + y2]) * 0.5 + scale = np.hstack([x2 - x1, y2 - y1]) * padding + + if dim == 1: + center = center[0] + scale = scale[0] + + return center, scale + + +def _fix_aspect_ratio(bbox_scale: np.ndarray, + aspect_ratio: float) -> np.ndarray: + """Extend the scale to match the given aspect ratio. + + Args: + scale (np.ndarray): The image scale (w, h) in shape (2, ) + aspect_ratio (float): The ratio of ``w/h`` + + Returns: + np.ndarray: The reshaped image scale in (2, ) + """ + w, h = np.hsplit(bbox_scale, [1]) + bbox_scale = np.where(w > h * aspect_ratio, + np.hstack([w, w / aspect_ratio]), + np.hstack([h * aspect_ratio, h])) + return bbox_scale + + +def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray: + """Rotate a point by an angle. + + Args: + pt (np.ndarray): 2D point coordinates (x, y) in shape (2, ) + angle_rad (float): rotation angle in radian + + Returns: + np.ndarray: Rotated point in shape (2, ) + """ + sn, cs = np.sin(angle_rad), np.cos(angle_rad) + rot_mat = np.array([[cs, -sn], [sn, cs]]) + return rot_mat @ pt + + +def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray: + """To calculate the affine matrix, three pairs of points are required. This + function is used to get the 3rd point, given 2D points a & b. + + The 3rd point is defined by rotating vector `a - b` by 90 degrees + anticlockwise, using b as the rotation center. + + Args: + a (np.ndarray): The 1st point (x,y) in shape (2, ) + b (np.ndarray): The 2nd point (x,y) in shape (2, ) + + Returns: + np.ndarray: The 3rd point. + """ + direction = a - b + c = b + np.r_[-direction[1], direction[0]] + return c + + +def get_warp_matrix(center: np.ndarray, + scale: np.ndarray, + rot: float, + output_size: Tuple[int, int], + shift: Tuple[float, float] = (0., 0.), + inv: bool = False) -> np.ndarray: + """Calculate the affine transformation matrix that can warp the bbox area + in the input image to the output size. + + Args: + center (np.ndarray[2, ]): Center of the bounding box (x, y). + scale (np.ndarray[2, ]): Scale of the bounding box + wrt [width, height]. + rot (float): Rotation angle (degree). + output_size (np.ndarray[2, ] | list(2,)): Size of the + destination heatmaps. + shift (0-100%): Shift translation ratio wrt the width/height. + Default (0., 0.). + inv (bool): Option to inverse the affine transform direction. + (inv=False: src->dst or inv=True: dst->src) + + Returns: + np.ndarray: A 2x3 transformation matrix + """ + shift = np.array(shift) + src_w = scale[0] + dst_w = output_size[0] + dst_h = output_size[1] + + # compute transformation matrix + rot_rad = np.deg2rad(rot) + src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad) + dst_dir = np.array([0., dst_w * -0.5]) + + # get four corners of the src rectangle in the original image + src = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + scale * shift + src[1, :] = center + src_dir + scale * shift + src[2, :] = _get_3rd_point(src[0, :], src[1, :]) + + # get four corners of the dst rectangle in the input image + dst = np.zeros((3, 2), dtype=np.float32) + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir + dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return warp_mat + + +def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict, + img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get the bbox image as the model input by affine transform. + + Args: + input_size (dict): The input size of the model. + bbox_scale (dict): The bbox scale of the img. + bbox_center (dict): The bbox center of the img. + img (np.ndarray): The original image. + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: img after affine transform. + - np.ndarray[float32]: bbox scale after affine transform. + """ + w, h = input_size + warp_size = (int(w), int(h)) + + # reshape bbox to fixed aspect ratio + bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h) + + # get the affine matrix + center = bbox_center + scale = bbox_scale + rot = 0 + warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h)) + + # do affine transform + img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR) + + return img, bbox_scale + + +def get_simcc_maximum(simcc_x: np.ndarray, + simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get maximum response location and value from simcc representations. + + Note: + instance number: N + num_keypoints: K + heatmap height: H + heatmap width: W + + Args: + simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx) + simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy) + + Returns: + tuple: + - locs (np.ndarray): locations of maximum heatmap responses in shape + (K, 2) or (N, K, 2) + - vals (np.ndarray): values of maximum heatmap responses in shape + (K,) or (N, K) + """ + N, K, Wx = simcc_x.shape + simcc_x = simcc_x.reshape(N * K, -1) + simcc_y = simcc_y.reshape(N * K, -1) + + # get maximum value locations + x_locs = np.argmax(simcc_x, axis=1) + y_locs = np.argmax(simcc_y, axis=1) + locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32) + max_val_x = np.amax(simcc_x, axis=1) + max_val_y = np.amax(simcc_y, axis=1) + + # get maximum value across x and y axis + mask = max_val_x > max_val_y + max_val_x[mask] = max_val_y[mask] + vals = max_val_x + locs[vals <= 0.] = -1 + + # reshape + locs = locs.reshape(N, K, 2) + vals = vals.reshape(N, K) + + return locs, vals + + +def decode(simcc_x: np.ndarray, simcc_y: np.ndarray, + simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]: + """Modulate simcc distribution with Gaussian. + + Args: + simcc_x (np.ndarray[K, Wx]): model predicted simcc in x. + simcc_y (np.ndarray[K, Wy]): model predicted simcc in y. + simcc_split_ratio (int): The split ratio of simcc. + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2) + - np.ndarray[float32]: scores in shape (K,) or (n, K) + """ + keypoints, scores = get_simcc_maximum(simcc_x, simcc_y) + keypoints /= simcc_split_ratio + + return keypoints, scores + + +def inference_pose(session, out_bbox, oriImg): + h, w = session.get_inputs()[0].shape[2:] + model_input_size = (w, h) + resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size) + outputs = inference(session, resized_img) + keypoints, scores = postprocess(outputs, model_input_size, center, scale) + + return keypoints, scores \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/dwpose/util.py b/third_party/src/flux_ch/annotator/dwpose/util.py new file mode 100644 index 0000000000000000000000000000000000000000..73d7d0153b38d143eb8090e07a9784a274b619ed --- /dev/null +++ b/third_party/src/flux_ch/annotator/dwpose/util.py @@ -0,0 +1,297 @@ +import math +import numpy as np +import matplotlib +import cv2 + + +eps = 0.01 + + +def smart_resize(x, s): + Ht, Wt = s + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2) + + +def smart_resize_k(x, fx, fy): + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + Ht, Wt = Ho * fy, Wo * fx + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2) + + +def padRightDownCorner(img, stride, padValue): + h = img.shape[0] + w = img.shape[1] + + pad = 4 * [None] + pad[0] = 0 # up + pad[1] = 0 # left + pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down + pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right + + img_padded = img + pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1)) + img_padded = np.concatenate((pad_up, img_padded), axis=0) + pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1)) + img_padded = np.concatenate((pad_left, img_padded), axis=1) + pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1)) + img_padded = np.concatenate((img_padded, pad_down), axis=0) + pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1)) + img_padded = np.concatenate((img_padded, pad_right), axis=1) + + return img_padded, pad + + +def transfer(model, model_weights): + transfered_model_weights = {} + for weights_name in model.state_dict().keys(): + transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])] + return transfered_model_weights + + +def draw_bodypose(canvas, candidate, subset): + H, W, C = canvas.shape + candidate = np.array(candidate) + subset = np.array(subset) + + stickwidth = 4 + + limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ + [1, 16], [16, 18], [3, 17], [6, 18]] + + colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + + for i in range(17): + for n in range(len(subset)): + index = subset[n][np.array(limbSeq[i]) - 1] + if -1 in index: + continue + Y = candidate[index.astype(int), 0] * float(W) + X = candidate[index.astype(int), 1] * float(H) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + cv2.fillConvexPoly(canvas, polygon, colors[i]) + + canvas = (canvas * 0.6).astype(np.uint8) + + for i in range(18): + for n in range(len(subset)): + index = int(subset[n][i]) + if index == -1: + continue + x, y = candidate[index][0:2] + x = int(x * W) + y = int(y * H) + cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) + + return canvas + + +def draw_handpose(canvas, all_hand_peaks): + H, W, C = canvas.shape + + edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ + [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] + + for peaks in all_hand_peaks: + peaks = np.array(peaks) + + for ie, e in enumerate(edges): + x1, y1 = peaks[e[0]] + x2, y2 = peaks[e[1]] + x1 = int(x1 * W) + y1 = int(y1 * H) + x2 = int(x2 * W) + y2 = int(y2 * H) + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2) + + for i, keyponit in enumerate(peaks): + x, y = keyponit + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) + return canvas + + +def draw_facepose(canvas, all_lmks): + H, W, C = canvas.shape + for lmks in all_lmks: + lmks = np.array(lmks) + for lmk in lmks: + x, y = lmk + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1) + return canvas + + +# detect hand according to body pose keypoints +# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp +def handDetect(candidate, subset, oriImg): + # right hand: wrist 4, elbow 3, shoulder 2 + # left hand: wrist 7, elbow 6, shoulder 5 + ratioWristElbow = 0.33 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + for person in subset.astype(int): + # if any of three not detected + has_left = np.sum(person[[5, 6, 7]] == -1) == 0 + has_right = np.sum(person[[2, 3, 4]] == -1) == 0 + if not (has_left or has_right): + continue + hands = [] + #left hand + if has_left: + left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]] + x1, y1 = candidate[left_shoulder_index][:2] + x2, y2 = candidate[left_elbow_index][:2] + x3, y3 = candidate[left_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, True]) + # right hand + if has_right: + right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]] + x1, y1 = candidate[right_shoulder_index][:2] + x2, y2 = candidate[right_elbow_index][:2] + x3, y3 = candidate[right_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, False]) + + for x1, y1, x2, y2, x3, y3, is_left in hands: + # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox + # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]); + # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]); + # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow); + # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder); + # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder); + x = x3 + ratioWristElbow * (x3 - x2) + y = y3 + ratioWristElbow * (y3 - y2) + distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) + distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) + width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) + # x-y refers to the center --> offset to topLeft point + # handRectangle.x -= handRectangle.width / 2.f; + # handRectangle.y -= handRectangle.height / 2.f; + x -= width / 2 + y -= width / 2 # width = height + # overflow the image + if x < 0: x = 0 + if y < 0: y = 0 + width1 = width + width2 = width + if x + width > image_width: width1 = image_width - x + if y + width > image_height: width2 = image_height - y + width = min(width1, width2) + # the max hand box value is 20 pixels + if width >= 20: + detect_result.append([int(x), int(y), int(width), is_left]) + + ''' + return value: [[x, y, w, True if left hand else False]]. + width=height since the network require squared input. + x, y is the coordinate of top left + ''' + return detect_result + + +# Written by Lvmin +def faceDetect(candidate, subset, oriImg): + # left right eye ear 14 15 16 17 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + for person in subset.astype(int): + has_head = person[0] > -1 + if not has_head: + continue + + has_left_eye = person[14] > -1 + has_right_eye = person[15] > -1 + has_left_ear = person[16] > -1 + has_right_ear = person[17] > -1 + + if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear): + continue + + head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]] + + width = 0.0 + x0, y0 = candidate[head][:2] + + if has_left_eye: + x1, y1 = candidate[left_eye][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if has_right_eye: + x1, y1 = candidate[right_eye][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if has_left_ear: + x1, y1 = candidate[left_ear][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + if has_right_ear: + x1, y1 = candidate[right_ear][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + x, y = x0, y0 + + x -= width + y -= width + + if x < 0: + x = 0 + + if y < 0: + y = 0 + + width1 = width * 2 + width2 = width * 2 + + if x + width > image_width: + width1 = image_width - x + + if y + width > image_height: + width2 = image_height - y + + width = min(width1, width2) + + if width >= 20: + detect_result.append([int(x), int(y), int(width)]) + + return detect_result + + +# get max index of 2d array +def npmax(array): + arrayindex = array.argmax(1) + arrayvalue = array.max(1) + i = arrayvalue.argmax() + j = arrayindex[i] + return i, j diff --git a/third_party/src/flux_ch/annotator/dwpose/wholebody.py b/third_party/src/flux_ch/annotator/dwpose/wholebody.py new file mode 100644 index 0000000000000000000000000000000000000000..d73f19d61c238c47cf7de98d01385b2150a5361f --- /dev/null +++ b/third_party/src/flux_ch/annotator/dwpose/wholebody.py @@ -0,0 +1,48 @@ +import cv2 +import numpy as np + +import onnxruntime as ort +from huggingface_hub import hf_hub_download +from .onnxdet import inference_detector +from .onnxpose import inference_pose + + +class Wholebody: + def __init__(self, device="cuda:0"): + providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider'] + onnx_det = hf_hub_download("yzd-v/DWPose", "yolox_l.onnx") + onnx_pose = hf_hub_download("yzd-v/DWPose", "dw-ll_ucoco_384.onnx") + + self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers) + self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers) + + def __call__(self, oriImg): + det_result = inference_detector(self.session_det, oriImg) + keypoints, scores = inference_pose(self.session_pose, det_result, oriImg) + + keypoints_info = np.concatenate( + (keypoints, scores[..., None]), axis=-1) + # compute neck joint + neck = np.mean(keypoints_info[:, [5, 6]], axis=1) + # neck score when visualizing pred + neck[:, 2:4] = np.logical_and( + keypoints_info[:, 5, 2:4] > 0.3, + keypoints_info[:, 6, 2:4] > 0.3).astype(int) + new_keypoints_info = np.insert( + keypoints_info, 17, neck, axis=1) + mmpose_idx = [ + 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 + ] + openpose_idx = [ + 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 + ] + new_keypoints_info[:, openpose_idx] = \ + new_keypoints_info[:, mmpose_idx] + keypoints_info = new_keypoints_info + + keypoints, scores = keypoints_info[ + ..., :2], keypoints_info[..., 2] + + return keypoints, scores + + diff --git a/third_party/src/flux_ch/annotator/hed/__init__.py b/third_party/src/flux_ch/annotator/hed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..70d11c9e62133149d38091a597a1b6691ff8f1b6 --- /dev/null +++ b/third_party/src/flux_ch/annotator/hed/__init__.py @@ -0,0 +1,95 @@ +# This is an improved version and model of HED edge detection with Apache License, Version 2.0. +# Please use this implementation in your products +# This implementation may produce slightly different results from Saining Xie's official implementations, +# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations. +# Different from official models and other implementations, this is an RGB-input model (rather than BGR) +# and in this way it works better for gradio's RGB protocol + +import os +import cv2 +import torch +import numpy as np + +from huggingface_hub import hf_hub_download +from einops import rearrange +from ...annotator.util import annotator_ckpts_path + + +class DoubleConvBlock(torch.nn.Module): + def __init__(self, input_channel, output_channel, layer_number): + super().__init__() + self.convs = torch.nn.Sequential() + self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) + for i in range(1, layer_number): + self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) + self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0) + + def __call__(self, x, down_sampling=False): + h = x + if down_sampling: + h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2)) + for conv in self.convs: + h = conv(h) + h = torch.nn.functional.relu(h) + return h, self.projection(h) + + +class ControlNetHED_Apache2(torch.nn.Module): + def __init__(self): + super().__init__() + self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1))) + self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2) + self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2) + self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3) + self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3) + self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3) + + def __call__(self, x): + h = x - self.norm + h, projection1 = self.block1(h) + h, projection2 = self.block2(h, down_sampling=True) + h, projection3 = self.block3(h, down_sampling=True) + h, projection4 = self.block4(h, down_sampling=True) + h, projection5 = self.block5(h, down_sampling=True) + return projection1, projection2, projection3, projection4, projection5 + + +class HEDdetector: + def __init__(self): + modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth") + if not os.path.exists(modelpath): + modelpath = hf_hub_download("lllyasviel/Annotators", "ControlNetHED.pth") + self.netNetwork = ControlNetHED_Apache2().float().cuda().eval() + self.netNetwork.load_state_dict(torch.load(modelpath)) + + def __call__(self, input_image): + assert input_image.ndim == 3 + H, W, C = input_image.shape + with torch.no_grad(): + image_hed = torch.from_numpy(input_image.copy()).float().cuda() + image_hed = rearrange(image_hed, 'h w c -> 1 c h w') + edges = self.netNetwork(image_hed) + edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges] + edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges] + edges = np.stack(edges, axis=2) + edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64))) + edge = (edge * 255.0).clip(0, 255).astype(np.uint8) + return edge + + +def nms(x, t, s): + x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) + + f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) + f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) + f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) + f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) + + y = np.zeros_like(x) + + for f in [f1, f2, f3, f4]: + np.putmask(y, cv2.dilate(x, kernel=f) == x, x) + + z = np.zeros_like(y, dtype=np.uint8) + z[y > t] = 255 + return z diff --git a/third_party/src/flux_ch/annotator/hed/__pycache__/__init__.cpython-310.pyc b/third_party/src/flux_ch/annotator/hed/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56a8ab644ce9e74dfcffa2a72ff6ed4b4e4f0d91 Binary files /dev/null and b/third_party/src/flux_ch/annotator/hed/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/midas/LICENSE b/third_party/src/flux_ch/annotator/midas/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..277b5c11be103f028a8d10985139f1da10c2f08e --- /dev/null +++ b/third_party/src/flux_ch/annotator/midas/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/third_party/src/flux_ch/annotator/midas/__init__.py b/third_party/src/flux_ch/annotator/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36789767f35bcc169c2cbf096e2747539df4f14d --- /dev/null +++ b/third_party/src/flux_ch/annotator/midas/__init__.py @@ -0,0 +1,42 @@ +# Midas Depth Estimation +# From https://github.com/isl-org/MiDaS +# MIT LICENSE + +import cv2 +import numpy as np +import torch + +from einops import rearrange +from .api import MiDaSInference + + +class MidasDetector: + def __init__(self): + self.model = MiDaSInference(model_type="dpt_hybrid").cuda() + + def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1): + assert input_image.ndim == 3 + image_depth = input_image + with torch.no_grad(): + image_depth = torch.from_numpy(image_depth).float().cuda() + image_depth = image_depth / 127.5 - 1.0 + image_depth = rearrange(image_depth, 'h w c -> 1 c h w') + depth = self.model(image_depth)[0] + + depth_pt = depth.clone() + depth_pt -= torch.min(depth_pt) + depth_pt /= torch.max(depth_pt) + depth_pt = depth_pt.cpu().numpy() + depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8) + + depth_np = depth.cpu().numpy() + x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3) + y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3) + z = np.ones_like(x) * a + x[depth_pt < bg_th] = 0 + y[depth_pt < bg_th] = 0 + normal = np.stack([x, y, z], axis=2) + normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5 + normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8) + + return depth_image, normal_image diff --git a/third_party/src/flux_ch/annotator/midas/__pycache__/__init__.cpython-310.pyc b/third_party/src/flux_ch/annotator/midas/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..861a0bc2a88ac7f5732325675985e2066a687469 Binary files /dev/null and b/third_party/src/flux_ch/annotator/midas/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/midas/__pycache__/api.cpython-310.pyc b/third_party/src/flux_ch/annotator/midas/__pycache__/api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22de81bcaeb7c36fa28d6047ca2672a648710c5f Binary files /dev/null and b/third_party/src/flux_ch/annotator/midas/__pycache__/api.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/midas/api.py b/third_party/src/flux_ch/annotator/midas/api.py new file mode 100644 index 0000000000000000000000000000000000000000..6226a39d80de978162a7238cec1c4d4a64bacbe9 --- /dev/null +++ b/third_party/src/flux_ch/annotator/midas/api.py @@ -0,0 +1,168 @@ +# based on https://github.com/isl-org/MiDaS + +import cv2 +import os +import torch +import torch.nn as nn +from torchvision.transforms import Compose + +from huggingface_hub import hf_hub_download + +from .midas.dpt_depth import DPTDepthModel +from .midas.midas_net import MidasNet +from .midas.midas_net_custom import MidasNet_small +from .midas.transforms import Resize, NormalizeImage, PrepareForNet +from ...annotator.util import annotator_ckpts_path + + +ISL_PATHS = { + "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"), + "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"), + "midas_v21": "", + "midas_v21_small": "", +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def load_midas_transform(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load transform only + if model_type == "dpt_large": # DPT-Large + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + elif model_type == "midas_v21_small": + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + else: + assert False, f"model_type '{model_type}' not implemented, use: --model_type large" + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return transform + + +def load_model(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load network + model_path = ISL_PATHS[model_type] + if model_type == "dpt_large": # DPT-Large + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + if not os.path.exists(model_path): + model_path = hf_hub_download("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt") + + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return model.eval(), transform + + +class MiDaSInference(nn.Module): + MODEL_TYPES_TORCH_HUB = [ + "DPT_Large", + "DPT_Hybrid", + "MiDaS_small" + ] + MODEL_TYPES_ISL = [ + "dpt_large", + "dpt_hybrid", + "midas_v21", + "midas_v21_small", + ] + + def __init__(self, model_type): + super().__init__() + assert (model_type in self.MODEL_TYPES_ISL) + model, _ = load_model(model_type) + self.model = model + self.model.train = disabled_train + + def forward(self, x): + with torch.no_grad(): + prediction = self.model(x) + return prediction + diff --git a/third_party/src/flux_ch/annotator/midas/midas/__init__.py b/third_party/src/flux_ch/annotator/midas/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/src/flux_ch/annotator/midas/midas/__pycache__/__init__.cpython-310.pyc b/third_party/src/flux_ch/annotator/midas/midas/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..372e21559c6810f7b63eca8416b28c220404fa25 Binary files /dev/null and b/third_party/src/flux_ch/annotator/midas/midas/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/midas/midas/__pycache__/base_model.cpython-310.pyc b/third_party/src/flux_ch/annotator/midas/midas/__pycache__/base_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22b4720fd56b5921fa5855d92f6e02d2f92533c2 Binary files /dev/null and b/third_party/src/flux_ch/annotator/midas/midas/__pycache__/base_model.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/midas/midas/__pycache__/blocks.cpython-310.pyc b/third_party/src/flux_ch/annotator/midas/midas/__pycache__/blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ad239636a93f8b0aaf7e867f2593b9a0d783dc5 Binary files /dev/null and b/third_party/src/flux_ch/annotator/midas/midas/__pycache__/blocks.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/midas/midas/__pycache__/dpt_depth.cpython-310.pyc b/third_party/src/flux_ch/annotator/midas/midas/__pycache__/dpt_depth.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07e14e31c0f8b0fc3806f8447fad505d39588a39 Binary files /dev/null and b/third_party/src/flux_ch/annotator/midas/midas/__pycache__/dpt_depth.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/midas/midas/__pycache__/midas_net.cpython-310.pyc b/third_party/src/flux_ch/annotator/midas/midas/__pycache__/midas_net.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f63b6661494dfba5ffc8ad82885a2eb611199879 Binary files /dev/null and b/third_party/src/flux_ch/annotator/midas/midas/__pycache__/midas_net.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/midas/midas/__pycache__/midas_net_custom.cpython-310.pyc b/third_party/src/flux_ch/annotator/midas/midas/__pycache__/midas_net_custom.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c12459fd2444965733c0c6f6f315eef815671027 Binary files /dev/null and b/third_party/src/flux_ch/annotator/midas/midas/__pycache__/midas_net_custom.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/midas/midas/__pycache__/transforms.cpython-310.pyc b/third_party/src/flux_ch/annotator/midas/midas/__pycache__/transforms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52e3d48221cd0ab1baedf21e14cfad755394e20f Binary files /dev/null and b/third_party/src/flux_ch/annotator/midas/midas/__pycache__/transforms.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/midas/midas/__pycache__/vit.cpython-310.pyc b/third_party/src/flux_ch/annotator/midas/midas/__pycache__/vit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..451f5f59fbf2612069b45d4849f56edbc1e12b07 Binary files /dev/null and b/third_party/src/flux_ch/annotator/midas/midas/__pycache__/vit.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/midas/midas/base_model.py b/third_party/src/flux_ch/annotator/midas/midas/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd --- /dev/null +++ b/third_party/src/flux_ch/annotator/midas/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/third_party/src/flux_ch/annotator/midas/midas/blocks.py b/third_party/src/flux_ch/annotator/midas/midas/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78 --- /dev/null +++ b/third_party/src/flux_ch/annotator/midas/midas/blocks.py @@ -0,0 +1,342 @@ +import torch +import torch.nn as nn + +from .vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand==True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/third_party/src/flux_ch/annotator/midas/midas/dpt_depth.py b/third_party/src/flux_ch/annotator/midas/midas/dpt_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1 --- /dev/null +++ b/third_party/src/flux_ch/annotator/midas/midas/dpt_depth.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) + diff --git a/third_party/src/flux_ch/annotator/midas/midas/midas_net.py b/third_party/src/flux_ch/annotator/midas/midas/midas_net.py new file mode 100644 index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f --- /dev/null +++ b/third_party/src/flux_ch/annotator/midas/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/third_party/src/flux_ch/annotator/midas/midas/midas_net_custom.py b/third_party/src/flux_ch/annotator/midas/midas/midas_net_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c --- /dev/null +++ b/third_party/src/flux_ch/annotator/midas/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/midas/midas/transforms.py b/third_party/src/flux_ch/annotator/midas/midas/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- /dev/null +++ b/third_party/src/flux_ch/annotator/midas/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/third_party/src/flux_ch/annotator/midas/midas/vit.py b/third_party/src/flux_ch/annotator/midas/midas/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74 --- /dev/null +++ b/third_party/src/flux_ch/annotator/midas/midas/vit.py @@ -0,0 +1,491 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/third_party/src/flux_ch/annotator/midas/utils.py b/third_party/src/flux_ch/annotator/midas/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59 --- /dev/null +++ b/third_party/src/flux_ch/annotator/midas/utils.py @@ -0,0 +1,189 @@ +"""Utils for monoDepth.""" +import sys +import re +import numpy as np +import cv2 +import torch + + +def read_pfm(path): + """Read pfm file. + + Args: + path (str): path to file + + Returns: + tuple: (data, scale) + """ + with open(path, "rb") as file: + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file: " + path) + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + # little-endian + endian = "<" + scale = -scale + else: + # big-endian + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + + return data, scale + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + + +def resize_image(img): + """Resize image and make it fit for network. + + Args: + img (array): image + + Returns: + tensor: data ready for network + """ + height_orig = img.shape[0] + width_orig = img.shape[1] + + if width_orig > height_orig: + scale = width_orig / 384 + else: + scale = height_orig / 384 + + height = (np.ceil(height_orig / scale / 32) * 32).astype(int) + width = (np.ceil(width_orig / scale / 32) * 32).astype(int) + + img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) + + img_resized = ( + torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() + ) + img_resized = img_resized.unsqueeze(0) + + return img_resized + + +def resize_depth(depth, width, height): + """Resize depth map and bring to CPU (numpy). + + Args: + depth (tensor): depth + width (int): image width + height (int): image height + + Returns: + array: processed depth + """ + depth = torch.squeeze(depth[0, :, :, :]).to("cpu") + + depth_resized = cv2.resize( + depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC + ) + + return depth_resized + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.type) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return diff --git a/third_party/src/flux_ch/annotator/mlsd/LICENSE b/third_party/src/flux_ch/annotator/mlsd/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..d855c6db44b4e873eedd750d34fa2eaf22e22363 --- /dev/null +++ b/third_party/src/flux_ch/annotator/mlsd/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2021-present NAVER Corp. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/mlsd/__init__.py b/third_party/src/flux_ch/annotator/mlsd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5028aef051a1e67caae1f0d23a0b0dbca883a7f8 --- /dev/null +++ b/third_party/src/flux_ch/annotator/mlsd/__init__.py @@ -0,0 +1,40 @@ +# MLSD Line Detection +# From https://github.com/navervision/mlsd +# Apache-2.0 license + +import cv2 +import numpy as np +import torch +import os + +from einops import rearrange +from huggingface_hub import hf_hub_download +from .models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny +from .models.mbv2_mlsd_large import MobileV2_MLSD_Large +from .utils import pred_lines + +from ...annotator.util import annotator_ckpts_path + + +class MLSDdetector: + def __init__(self): + model_path = os.path.join(annotator_ckpts_path, "mlsd_large_512_fp32.pth") + if not os.path.exists(model_path): + model_path = hf_hub_download("lllyasviel/Annotators", "mlsd_large_512_fp32.pth") + model = MobileV2_MLSD_Large() + model.load_state_dict(torch.load(model_path), strict=True) + self.model = model.cuda().eval() + + def __call__(self, input_image, thr_v, thr_d): + assert input_image.ndim == 3 + img = input_image + img_output = np.zeros_like(img) + try: + with torch.no_grad(): + lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d) + for line in lines: + x_start, y_start, x_end, y_end = [int(val) for val in line] + cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1) + except Exception as e: + pass + return img_output[:, :, 0] diff --git a/third_party/src/flux_ch/annotator/mlsd/__pycache__/__init__.cpython-310.pyc b/third_party/src/flux_ch/annotator/mlsd/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e56d3f15e53f62e62def1fa1b54c7e3ba55ff269 Binary files /dev/null and b/third_party/src/flux_ch/annotator/mlsd/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/mlsd/__pycache__/utils.cpython-310.pyc b/third_party/src/flux_ch/annotator/mlsd/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f133e374ca390d0e2304c0b709bcecf5bf7f3ab8 Binary files /dev/null and b/third_party/src/flux_ch/annotator/mlsd/__pycache__/utils.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/mlsd/models/__pycache__/mbv2_mlsd_large.cpython-310.pyc b/third_party/src/flux_ch/annotator/mlsd/models/__pycache__/mbv2_mlsd_large.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd7589678a59463e19dd671bf87eaa99671903ca Binary files /dev/null and b/third_party/src/flux_ch/annotator/mlsd/models/__pycache__/mbv2_mlsd_large.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/mlsd/models/__pycache__/mbv2_mlsd_tiny.cpython-310.pyc b/third_party/src/flux_ch/annotator/mlsd/models/__pycache__/mbv2_mlsd_tiny.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c764f896674a004698e12679ad401b6527ab5d4 Binary files /dev/null and b/third_party/src/flux_ch/annotator/mlsd/models/__pycache__/mbv2_mlsd_tiny.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/mlsd/models/mbv2_mlsd_large.py b/third_party/src/flux_ch/annotator/mlsd/models/mbv2_mlsd_large.py new file mode 100644 index 0000000000000000000000000000000000000000..5b9799e7573ca41549b3c3b13ac47b906b369603 --- /dev/null +++ b/third_party/src/flux_ch/annotator/mlsd/models/mbv2_mlsd_large.py @@ -0,0 +1,292 @@ +import os +import sys +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +from torch.nn import functional as F + + +class BlockTypeA(nn.Module): + def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True): + super(BlockTypeA, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c2, out_c2, kernel_size=1), + nn.BatchNorm2d(out_c2), + nn.ReLU(inplace=True) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c1, out_c1, kernel_size=1), + nn.BatchNorm2d(out_c1), + nn.ReLU(inplace=True) + ) + self.upscale = upscale + + def forward(self, a, b): + b = self.conv1(b) + a = self.conv2(a) + if self.upscale: + b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True) + return torch.cat((a, b), dim=1) + + +class BlockTypeB(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeB, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), + nn.BatchNorm2d(out_c), + nn.ReLU() + ) + + def forward(self, x): + x = self.conv1(x) + x + x = self.conv2(x) + return x + +class BlockTypeC(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeC, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + return x + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNReLU(nn.Sequential): + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + self.channel_pad = out_planes - in_planes + self.stride = stride + #padding = (kernel_size - 1) // 2 + + # TFLite uses slightly different padding than PyTorch + if stride == 2: + padding = 0 + else: + padding = (kernel_size - 1) // 2 + + super(ConvBNReLU, self).__init__( + nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), + nn.BatchNorm2d(out_planes), + nn.ReLU6(inplace=True) + ) + self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) + + + def forward(self, x): + # TFLite uses different padding + if self.stride == 2: + x = F.pad(x, (0, 1, 0, 1), "constant", 0) + #print(x.shape) + + for module in self: + if not isinstance(module, nn.MaxPool2d): + x = module(x) + return x + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + # pw + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__(self, pretrained=True): + """ + MobileNet V2 main class + Args: + num_classes (int): Number of classes + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for mobilenet + """ + super(MobileNetV2, self).__init__() + + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + width_mult = 1.0 + round_nearest = 8 + + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + #[6, 160, 3, 2], + #[6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: + raise ValueError("inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting)) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features = [ConvBNReLU(4, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + + self.features = nn.Sequential(*features) + self.fpn_selected = [1, 3, 6, 10, 13] + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + if pretrained: + self._load_pretrained_model() + + def _forward_impl(self, x): + # This exists since TorchScript doesn't support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass + fpn_features = [] + for i, f in enumerate(self.features): + if i > self.fpn_selected[-1]: + break + x = f(x) + if i in self.fpn_selected: + fpn_features.append(x) + + c1, c2, c3, c4, c5 = fpn_features + return c1, c2, c3, c4, c5 + + + def forward(self, x): + return self._forward_impl(x) + + def _load_pretrained_model(self): + pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth') + model_dict = {} + state_dict = self.state_dict() + for k, v in pretrain_dict.items(): + if k in state_dict: + model_dict[k] = v + state_dict.update(model_dict) + self.load_state_dict(state_dict) + + +class MobileV2_MLSD_Large(nn.Module): + def __init__(self): + super(MobileV2_MLSD_Large, self).__init__() + + self.backbone = MobileNetV2(pretrained=False) + ## A, B + self.block15 = BlockTypeA(in_c1= 64, in_c2= 96, + out_c1= 64, out_c2=64, + upscale=False) + self.block16 = BlockTypeB(128, 64) + + ## A, B + self.block17 = BlockTypeA(in_c1 = 32, in_c2 = 64, + out_c1= 64, out_c2= 64) + self.block18 = BlockTypeB(128, 64) + + ## A, B + self.block19 = BlockTypeA(in_c1=24, in_c2=64, + out_c1=64, out_c2=64) + self.block20 = BlockTypeB(128, 64) + + ## A, B, C + self.block21 = BlockTypeA(in_c1=16, in_c2=64, + out_c1=64, out_c2=64) + self.block22 = BlockTypeB(128, 64) + + self.block23 = BlockTypeC(64, 16) + + def forward(self, x): + c1, c2, c3, c4, c5 = self.backbone(x) + + x = self.block15(c4, c5) + x = self.block16(x) + + x = self.block17(c3, x) + x = self.block18(x) + + x = self.block19(c2, x) + x = self.block20(x) + + x = self.block21(c1, x) + x = self.block22(x) + x = self.block23(x) + x = x[:, 7:, :, :] + + return x \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/mlsd/models/mbv2_mlsd_tiny.py b/third_party/src/flux_ch/annotator/mlsd/models/mbv2_mlsd_tiny.py new file mode 100644 index 0000000000000000000000000000000000000000..e3ed633f2cc23ea1829a627fdb879ab39f641f83 --- /dev/null +++ b/third_party/src/flux_ch/annotator/mlsd/models/mbv2_mlsd_tiny.py @@ -0,0 +1,275 @@ +import os +import sys +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +from torch.nn import functional as F + + +class BlockTypeA(nn.Module): + def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True): + super(BlockTypeA, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c2, out_c2, kernel_size=1), + nn.BatchNorm2d(out_c2), + nn.ReLU(inplace=True) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c1, out_c1, kernel_size=1), + nn.BatchNorm2d(out_c1), + nn.ReLU(inplace=True) + ) + self.upscale = upscale + + def forward(self, a, b): + b = self.conv1(b) + a = self.conv2(a) + b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True) + return torch.cat((a, b), dim=1) + + +class BlockTypeB(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeB, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), + nn.BatchNorm2d(out_c), + nn.ReLU() + ) + + def forward(self, x): + x = self.conv1(x) + x + x = self.conv2(x) + return x + +class BlockTypeC(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeC, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + return x + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNReLU(nn.Sequential): + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + self.channel_pad = out_planes - in_planes + self.stride = stride + #padding = (kernel_size - 1) // 2 + + # TFLite uses slightly different padding than PyTorch + if stride == 2: + padding = 0 + else: + padding = (kernel_size - 1) // 2 + + super(ConvBNReLU, self).__init__( + nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), + nn.BatchNorm2d(out_planes), + nn.ReLU6(inplace=True) + ) + self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) + + + def forward(self, x): + # TFLite uses different padding + if self.stride == 2: + x = F.pad(x, (0, 1, 0, 1), "constant", 0) + #print(x.shape) + + for module in self: + if not isinstance(module, nn.MaxPool2d): + x = module(x) + return x + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + # pw + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__(self, pretrained=True): + """ + MobileNet V2 main class + Args: + num_classes (int): Number of classes + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for mobilenet + """ + super(MobileNetV2, self).__init__() + + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + width_mult = 1.0 + round_nearest = 8 + + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + #[6, 96, 3, 1], + #[6, 160, 3, 2], + #[6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: + raise ValueError("inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting)) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features = [ConvBNReLU(4, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + self.features = nn.Sequential(*features) + + self.fpn_selected = [3, 6, 10] + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + + #if pretrained: + # self._load_pretrained_model() + + def _forward_impl(self, x): + # This exists since TorchScript doesn't support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass + fpn_features = [] + for i, f in enumerate(self.features): + if i > self.fpn_selected[-1]: + break + x = f(x) + if i in self.fpn_selected: + fpn_features.append(x) + + c2, c3, c4 = fpn_features + return c2, c3, c4 + + + def forward(self, x): + return self._forward_impl(x) + + def _load_pretrained_model(self): + pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth') + model_dict = {} + state_dict = self.state_dict() + for k, v in pretrain_dict.items(): + if k in state_dict: + model_dict[k] = v + state_dict.update(model_dict) + self.load_state_dict(state_dict) + + +class MobileV2_MLSD_Tiny(nn.Module): + def __init__(self): + super(MobileV2_MLSD_Tiny, self).__init__() + + self.backbone = MobileNetV2(pretrained=True) + + self.block12 = BlockTypeA(in_c1= 32, in_c2= 64, + out_c1= 64, out_c2=64) + self.block13 = BlockTypeB(128, 64) + + self.block14 = BlockTypeA(in_c1 = 24, in_c2 = 64, + out_c1= 32, out_c2= 32) + self.block15 = BlockTypeB(64, 64) + + self.block16 = BlockTypeC(64, 16) + + def forward(self, x): + c2, c3, c4 = self.backbone(x) + + x = self.block12(c3, c4) + x = self.block13(x) + x = self.block14(c2, x) + x = self.block15(x) + x = self.block16(x) + x = x[:, 7:, :, :] + #print(x.shape) + x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True) + + return x \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/mlsd/utils.py b/third_party/src/flux_ch/annotator/mlsd/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..848a9fd7f9ff9d909f18c5d3ff55786c5a4b547a --- /dev/null +++ b/third_party/src/flux_ch/annotator/mlsd/utils.py @@ -0,0 +1,580 @@ +''' +modified by lihaoweicv +pytorch version +''' + +''' +M-LSD +Copyright 2021-present NAVER Corp. +Apache License v2.0 +''' + +import os +import numpy as np +import cv2 +import torch +from torch.nn import functional as F + + +def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5): + ''' + tpMap: + center: tpMap[1, 0, :, :] + displacement: tpMap[1, 1:5, :, :] + ''' + b, c, h, w = tpMap.shape + assert b==1, 'only support bsize==1' + displacement = tpMap[:, 1:5, :, :][0] + center = tpMap[:, 0, :, :] + heat = torch.sigmoid(center) + hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2) + keep = (hmax == heat).float() + heat = heat * keep + heat = heat.reshape(-1, ) + + scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True) + yy = torch.floor_divide(indices, w).unsqueeze(-1) + xx = torch.fmod(indices, w).unsqueeze(-1) + ptss = torch.cat((yy, xx),dim=-1) + + ptss = ptss.detach().cpu().numpy() + scores = scores.detach().cpu().numpy() + displacement = displacement.detach().cpu().numpy() + displacement = displacement.transpose((1,2,0)) + return ptss, scores, displacement + + +def pred_lines(image, model, + input_shape=[512, 512], + score_thr=0.10, + dist_thr=20.0): + h, w, _ = image.shape + h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]] + + resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA), + np.ones([input_shape[0], input_shape[1], 1])], axis=-1) + + resized_image = resized_image.transpose((2,0,1)) + batch_image = np.expand_dims(resized_image, axis=0).astype('float32') + batch_image = (batch_image / 127.5) - 1.0 + + batch_image = torch.from_numpy(batch_image).float().to("cuda:4") + outputs = model(batch_image) + pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3) + start = vmap[:, :, :2] + end = vmap[:, :, 2:] + dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1)) + + segments_list = [] + for center, score in zip(pts, pts_score): + y, x = center + distance = dist_map[y, x] + if score > score_thr and distance > dist_thr: + disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :] + x_start = x + disp_x_start + y_start = y + disp_y_start + x_end = x + disp_x_end + y_end = y + disp_y_end + segments_list.append([x_start, y_start, x_end, y_end]) + + lines = 2 * np.array(segments_list) # 256 > 512 + lines[:, 0] = lines[:, 0] * w_ratio + lines[:, 1] = lines[:, 1] * h_ratio + lines[:, 2] = lines[:, 2] * w_ratio + lines[:, 3] = lines[:, 3] * h_ratio + + return lines + + +def pred_squares(image, + model, + input_shape=[512, 512], + params={'score': 0.06, + 'outside_ratio': 0.28, + 'inside_ratio': 0.45, + 'w_overlap': 0.0, + 'w_degree': 1.95, + 'w_length': 0.0, + 'w_area': 1.86, + 'w_center': 0.14}): + ''' + shape = [height, width] + ''' + h, w, _ = image.shape + original_shape = [h, w] + + resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA), + np.ones([input_shape[0], input_shape[1], 1])], axis=-1) + resized_image = resized_image.transpose((2, 0, 1)) + batch_image = np.expand_dims(resized_image, axis=0).astype('float32') + batch_image = (batch_image / 127.5) - 1.0 + + batch_image = torch.from_numpy(batch_image).float().cuda() + outputs = model(batch_image) + + pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3) + start = vmap[:, :, :2] # (x, y) + end = vmap[:, :, 2:] # (x, y) + dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1)) + + junc_list = [] + segments_list = [] + for junc, score in zip(pts, pts_score): + y, x = junc + distance = dist_map[y, x] + if score > params['score'] and distance > 20.0: + junc_list.append([x, y]) + disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :] + d_arrow = 1.0 + x_start = x + d_arrow * disp_x_start + y_start = y + d_arrow * disp_y_start + x_end = x + d_arrow * disp_x_end + y_end = y + d_arrow * disp_y_end + segments_list.append([x_start, y_start, x_end, y_end]) + + segments = np.array(segments_list) + + ####### post processing for squares + # 1. get unique lines + point = np.array([[0, 0]]) + point = point[0] + start = segments[:, :2] + end = segments[:, 2:] + diff = start - end + a = diff[:, 1] + b = -diff[:, 0] + c = a * start[:, 0] + b * start[:, 1] + + d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10) + theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi + theta[theta < 0.0] += 180 + hough = np.concatenate([d[:, None], theta[:, None]], axis=-1) + + d_quant = 1 + theta_quant = 2 + hough[:, 0] //= d_quant + hough[:, 1] //= theta_quant + _, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True) + + acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32') + idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1 + yx_indices = hough[indices, :].astype('int32') + acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts + idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices + + acc_map_np = acc_map + # acc_map = acc_map[None, :, :, None] + # + # ### fast suppression using tensorflow op + # acc_map = tf.constant(acc_map, dtype=tf.float32) + # max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map) + # acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32) + # flatten_acc_map = tf.reshape(acc_map, [1, -1]) + # topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts)) + # _, h, w, _ = acc_map.shape + # y = tf.expand_dims(topk_indices // w, axis=-1) + # x = tf.expand_dims(topk_indices % w, axis=-1) + # yx = tf.concat([y, x], axis=-1) + + ### fast suppression using pytorch op + acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0) + _,_, h, w = acc_map.shape + max_acc_map = F.max_pool2d(acc_map,kernel_size=5, stride=1, padding=2) + acc_map = acc_map * ( (acc_map == max_acc_map).float() ) + flatten_acc_map = acc_map.reshape([-1, ]) + + scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True) + yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1) + xx = torch.fmod(indices, w).unsqueeze(-1) + yx = torch.cat((yy, xx), dim=-1) + + yx = yx.detach().cpu().numpy() + + topk_values = scores.detach().cpu().numpy() + indices = idx_map[yx[:, 0], yx[:, 1]] + basis = 5 // 2 + + merged_segments = [] + for yx_pt, max_indice, value in zip(yx, indices, topk_values): + y, x = yx_pt + if max_indice == -1 or value == 0: + continue + segment_list = [] + for y_offset in range(-basis, basis + 1): + for x_offset in range(-basis, basis + 1): + indice = idx_map[y + y_offset, x + x_offset] + cnt = int(acc_map_np[y + y_offset, x + x_offset]) + if indice != -1: + segment_list.append(segments[indice]) + if cnt > 1: + check_cnt = 1 + current_hough = hough[indice] + for new_indice, new_hough in enumerate(hough): + if (current_hough == new_hough).all() and indice != new_indice: + segment_list.append(segments[new_indice]) + check_cnt += 1 + if check_cnt == cnt: + break + group_segments = np.array(segment_list).reshape([-1, 2]) + sorted_group_segments = np.sort(group_segments, axis=0) + x_min, y_min = sorted_group_segments[0, :] + x_max, y_max = sorted_group_segments[-1, :] + + deg = theta[max_indice] + if deg >= 90: + merged_segments.append([x_min, y_max, x_max, y_min]) + else: + merged_segments.append([x_min, y_min, x_max, y_max]) + + # 2. get intersections + new_segments = np.array(merged_segments) # (x1, y1, x2, y2) + start = new_segments[:, :2] # (x1, y1) + end = new_segments[:, 2:] # (x2, y2) + new_centers = (start + end) / 2.0 + diff = start - end + dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1)) + + # ax + by = c + a = diff[:, 1] + b = -diff[:, 0] + c = a * start[:, 0] + b * start[:, 1] + pre_det = a[:, None] * b[None, :] + det = pre_det - np.transpose(pre_det) + + pre_inter_y = a[:, None] * c[None, :] + inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10) + pre_inter_x = c[:, None] * b[None, :] + inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10) + inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32') + + # 3. get corner information + # 3.1 get distance + ''' + dist_segments: + | dist(0), dist(1), dist(2), ...| + dist_inter_to_segment1: + | dist(inter,0), dist(inter,0), dist(inter,0), ... | + | dist(inter,1), dist(inter,1), dist(inter,1), ... | + ... + dist_inter_to_semgnet2: + | dist(inter,0), dist(inter,1), dist(inter,2), ... | + | dist(inter,0), dist(inter,1), dist(inter,2), ... | + ... + ''' + + dist_inter_to_segment1_start = np.sqrt( + np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment1_end = np.sqrt( + np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment2_start = np.sqrt( + np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment2_end = np.sqrt( + np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + + # sort ascending + dist_inter_to_segment1 = np.sort( + np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1), + axis=-1) # [n_batch, n_batch, 2] + dist_inter_to_segment2 = np.sort( + np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1), + axis=-1) # [n_batch, n_batch, 2] + + # 3.2 get degree + inter_to_start = new_centers[:, None, :] - inter_pts + deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi + deg_inter_to_start[deg_inter_to_start < 0.0] += 360 + inter_to_end = new_centers[None, :, :] - inter_pts + deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi + deg_inter_to_end[deg_inter_to_end < 0.0] += 360 + + ''' + B -- G + | | + C -- R + B : blue / G: green / C: cyan / R: red + + 0 -- 1 + | | + 3 -- 2 + ''' + # rename variables + deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end + # sort deg ascending + deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1) + + deg_diff_map = np.abs(deg1_map - deg2_map) + # we only consider the smallest degree of intersect + deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180] + + # define available degree range + deg_range = [60, 120] + + corner_dict = {corner_info: [] for corner_info in range(4)} + inter_points = [] + for i in range(inter_pts.shape[0]): + for j in range(i + 1, inter_pts.shape[1]): + # i, j > line index, always i < j + x, y = inter_pts[i, j, :] + deg1, deg2 = deg_sort[i, j, :] + deg_diff = deg_diff_map[i, j] + + check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1] + + outside_ratio = params['outside_ratio'] # over ratio >>> drop it! + inside_ratio = params['inside_ratio'] # over ratio >>> drop it! + check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \ + dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \ + (dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \ + dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \ + ((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \ + dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \ + (dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \ + dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio)) + + if check_degree and check_distance: + corner_info = None + + if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \ + (deg2 >= 315 and deg1 >= 45 and deg1 <= 120): + corner_info, color_info = 0, 'blue' + elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225): + corner_info, color_info = 1, 'green' + elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315): + corner_info, color_info = 2, 'black' + elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \ + (deg2 >= 315 and deg1 >= 225 and deg1 <= 315): + corner_info, color_info = 3, 'cyan' + else: + corner_info, color_info = 4, 'red' # we don't use it + continue + + corner_dict[corner_info].append([x, y, i, j]) + inter_points.append([x, y]) + + square_list = [] + connect_list = [] + segments_list = [] + for corner0 in corner_dict[0]: + for corner1 in corner_dict[1]: + connect01 = False + for corner0_line in corner0[2:]: + if corner0_line in corner1[2:]: + connect01 = True + break + if connect01: + for corner2 in corner_dict[2]: + connect12 = False + for corner1_line in corner1[2:]: + if corner1_line in corner2[2:]: + connect12 = True + break + if connect12: + for corner3 in corner_dict[3]: + connect23 = False + for corner2_line in corner2[2:]: + if corner2_line in corner3[2:]: + connect23 = True + break + if connect23: + for corner3_line in corner3[2:]: + if corner3_line in corner0[2:]: + # SQUARE!!! + ''' + 0 -- 1 + | | + 3 -- 2 + square_list: + order: 0 > 1 > 2 > 3 + | x0, y0, x1, y1, x2, y2, x3, y3 | + | x0, y0, x1, y1, x2, y2, x3, y3 | + ... + connect_list: + order: 01 > 12 > 23 > 30 + | line_idx01, line_idx12, line_idx23, line_idx30 | + | line_idx01, line_idx12, line_idx23, line_idx30 | + ... + segments_list: + order: 0 > 1 > 2 > 3 + | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j | + | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j | + ... + ''' + square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2]) + connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line]) + segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:]) + + def check_outside_inside(segments_info, connect_idx): + # return 'outside or inside', min distance, cover_param, peri_param + if connect_idx == segments_info[0]: + check_dist_mat = dist_inter_to_segment1 + else: + check_dist_mat = dist_inter_to_segment2 + + i, j = segments_info + min_dist, max_dist = check_dist_mat[i, j, :] + connect_dist = dist_segments[connect_idx] + if max_dist > connect_dist: + return 'outside', min_dist, 0, 1 + else: + return 'inside', min_dist, -1, -1 + + top_square = None + + try: + map_size = input_shape[0] / 2 + squares = np.array(square_list).reshape([-1, 4, 2]) + score_array = [] + connect_array = np.array(connect_list) + segments_array = np.array(segments_list).reshape([-1, 4, 2]) + + # get degree of corners: + squares_rollup = np.roll(squares, 1, axis=1) + squares_rolldown = np.roll(squares, -1, axis=1) + vec1 = squares_rollup - squares + normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10) + vec2 = squares_rolldown - squares + normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10) + inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1) # [n_squares, 4] + squares_degree = np.arccos(inner_products) * 180 / np.pi # [n_squares, 4] + + # get square score + overlap_scores = [] + degree_scores = [] + length_scores = [] + + for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree): + ''' + 0 -- 1 + | | + 3 -- 2 + + # segments: [4, 2] + # connects: [4] + ''' + + ###################################### OVERLAP SCORES + cover = 0 + perimeter = 0 + # check 0 > 1 > 2 > 3 + square_length = [] + + for start_idx in range(4): + end_idx = (start_idx + 1) % 4 + + connect_idx = connects[start_idx] # segment idx of segment01 + start_segments = segments[start_idx] + end_segments = segments[end_idx] + + start_point = square[start_idx] + end_point = square[end_idx] + + # check whether outside or inside + start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments, + connect_idx) + end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx) + + cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min + perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min + + square_length.append( + dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min) + + overlap_scores.append(cover / perimeter) + ###################################### + ###################################### DEGREE SCORES + ''' + deg0 vs deg2 + deg1 vs deg3 + ''' + deg0, deg1, deg2, deg3 = degree + deg_ratio1 = deg0 / deg2 + if deg_ratio1 > 1.0: + deg_ratio1 = 1 / deg_ratio1 + deg_ratio2 = deg1 / deg3 + if deg_ratio2 > 1.0: + deg_ratio2 = 1 / deg_ratio2 + degree_scores.append((deg_ratio1 + deg_ratio2) / 2) + ###################################### + ###################################### LENGTH SCORES + ''' + len0 vs len2 + len1 vs len3 + ''' + len0, len1, len2, len3 = square_length + len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0 + len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1 + length_scores.append((len_ratio1 + len_ratio2) / 2) + + ###################################### + + overlap_scores = np.array(overlap_scores) + overlap_scores /= np.max(overlap_scores) + + degree_scores = np.array(degree_scores) + # degree_scores /= np.max(degree_scores) + + length_scores = np.array(length_scores) + + ###################################### AREA SCORES + area_scores = np.reshape(squares, [-1, 4, 2]) + area_x = area_scores[:, :, 0] + area_y = area_scores[:, :, 1] + correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0] + area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1) + area_scores = 0.5 * np.abs(area_scores + correction) + area_scores /= (map_size * map_size) # np.max(area_scores) + ###################################### + + ###################################### CENTER SCORES + centers = np.array([[256 // 2, 256 // 2]], dtype='float32') # [1, 2] + # squares: [n, 4, 2] + square_centers = np.mean(squares, axis=1) # [n, 2] + center2center = np.sqrt(np.sum((centers - square_centers) ** 2)) + center_scores = center2center / (map_size / np.sqrt(2.0)) + + ''' + score_w = [overlap, degree, area, center, length] + ''' + score_w = [0.0, 1.0, 10.0, 0.5, 1.0] + score_array = params['w_overlap'] * overlap_scores \ + + params['w_degree'] * degree_scores \ + + params['w_area'] * area_scores \ + - params['w_center'] * center_scores \ + + params['w_length'] * length_scores + + best_square = [] + + sorted_idx = np.argsort(score_array)[::-1] + score_array = score_array[sorted_idx] + squares = squares[sorted_idx] + + except Exception as e: + pass + + '''return list + merged_lines, squares, scores + ''' + + try: + new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1] + new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0] + new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1] + new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0] + except: + new_segments = [] + + try: + squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1] + squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0] + except: + squares = [] + score_array = [] + + try: + inter_points = np.array(inter_points) + inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1] + inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0] + except: + inter_points = [] + + return new_segments, squares, score_array, inter_points diff --git a/third_party/src/flux_ch/annotator/tile/__init__.py b/third_party/src/flux_ch/annotator/tile/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c96899289d6e04796140cd7eff9c08e5f693af02 --- /dev/null +++ b/third_party/src/flux_ch/annotator/tile/__init__.py @@ -0,0 +1,26 @@ +import random +import cv2 +from .guided_filter import FastGuidedFilter + + +class TileDetector: + # https://huggingface.co/xinsir/controlnet-tile-sdxl-1.0 + def __init__(self): + pass + + def __call__(self, image): + blur_strength = random.sample([i / 10. for i in range(10, 201, 2)], k=1)[0] + radius = random.sample([i for i in range(1, 40, 2)], k=1)[0] + eps = random.sample([i / 1000. for i in range(1, 101, 2)], k=1)[0] + scale_factor = random.sample([i / 10. for i in range(10, 181, 5)], k=1)[0] + + ksize = int(blur_strength) + if ksize % 2 == 0: + ksize += 1 + + if random.random() > 0.5: + image = cv2.GaussianBlur(image, (ksize, ksize), blur_strength / 2) + if random.random() > 0.5: + filter = FastGuidedFilter(image, radius, eps, scale_factor) + image = filter.filter(image) + return image diff --git a/third_party/src/flux_ch/annotator/tile/__pycache__/__init__.cpython-310.pyc b/third_party/src/flux_ch/annotator/tile/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfbb2d0000009dc2810a8743c6c4f9cf45b1f2a3 Binary files /dev/null and b/third_party/src/flux_ch/annotator/tile/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/tile/__pycache__/guided_filter.cpython-310.pyc b/third_party/src/flux_ch/annotator/tile/__pycache__/guided_filter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26086595060bcc7b6c4888f77d6207c53dc444cb Binary files /dev/null and b/third_party/src/flux_ch/annotator/tile/__pycache__/guided_filter.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/tile/guided_filter.py b/third_party/src/flux_ch/annotator/tile/guided_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..3d7172a5e144672eea26551ef75f70b90a2f96d6 --- /dev/null +++ b/third_party/src/flux_ch/annotator/tile/guided_filter.py @@ -0,0 +1,280 @@ +# -*- coding: utf-8 -*- +## @package guided_filter.core.filters +# +# Implementation of guided filter. +# * GuidedFilter: Original guided filter. +# * FastGuidedFilter: Fast version of the guided filter. +# @author tody +# @date 2015/08/26 + +import numpy as np +import cv2 + +## Convert image into float32 type. +def to32F(img): + if img.dtype == np.float32: + return img + return (1.0 / 255.0) * np.float32(img) + +## Convert image into uint8 type. +def to8U(img): + if img.dtype == np.uint8: + return img + return np.clip(np.uint8(255.0 * img), 0, 255) + +## Return if the input image is gray or not. +def _isGray(I): + return len(I.shape) == 2 + + +## Return down sampled image. +# @param scale (w/s, h/s) image will be created. +# @param shape I.shape[:2]=(h, w). numpy friendly size parameter. +def _downSample(I, scale=4, shape=None): + if shape is not None: + h, w = shape + return cv2.resize(I, (w, h), interpolation=cv2.INTER_NEAREST) + + h, w = I.shape[:2] + return cv2.resize(I, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_NEAREST) + + +## Return up sampled image. +# @param scale (w*s, h*s) image will be created. +# @param shape I.shape[:2]=(h, w). numpy friendly size parameter. +def _upSample(I, scale=2, shape=None): + if shape is not None: + h, w = shape + return cv2.resize(I, (w, h), interpolation=cv2.INTER_LINEAR) + + h, w = I.shape[:2] + return cv2.resize(I, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR) + +## Fast guide filter. +class FastGuidedFilter: + ## Constructor. + # @param I Input guidance image. Color or gray. + # @param radius Radius of Guided Filter. + # @param epsilon Regularization term of Guided Filter. + # @param scale Down sampled scale. + def __init__(self, I, radius=5, epsilon=0.4, scale=4): + I_32F = to32F(I) + self._I = I_32F + h, w = I.shape[:2] + + I_sub = _downSample(I_32F, scale) + + self._I_sub = I_sub + radius = int(radius / scale) + + if _isGray(I): + self._guided_filter = GuidedFilterGray(I_sub, radius, epsilon) + else: + self._guided_filter = GuidedFilterColor(I_sub, radius, epsilon) + + ## Apply filter for the input image. + # @param p Input image for the filtering. + def filter(self, p): + p_32F = to32F(p) + shape_original = p.shape[:2] + + p_sub = _downSample(p_32F, shape=self._I_sub.shape[:2]) + + if _isGray(p_sub): + return self._filterGray(p_sub, shape_original) + + cs = p.shape[2] + q = np.array(p_32F) + + for ci in range(cs): + q[:, :, ci] = self._filterGray(p_sub[:, :, ci], shape_original) + return to8U(q) + + def _filterGray(self, p_sub, shape_original): + ab_sub = self._guided_filter._computeCoefficients(p_sub) + ab = [_upSample(abi, shape=shape_original) for abi in ab_sub] + return self._guided_filter._computeOutput(ab, self._I) + + +## Guide filter. +class GuidedFilter: + ## Constructor. + # @param I Input guidance image. Color or gray. + # @param radius Radius of Guided Filter. + # @param epsilon Regularization term of Guided Filter. + def __init__(self, I, radius=5, epsilon=0.4): + I_32F = to32F(I) + + if _isGray(I): + self._guided_filter = GuidedFilterGray(I_32F, radius, epsilon) + else: + self._guided_filter = GuidedFilterColor(I_32F, radius, epsilon) + + ## Apply filter for the input image. + # @param p Input image for the filtering. + def filter(self, p): + return to8U(self._guided_filter.filter(p)) + + +## Common parts of guided filter. +# +# This class is used by guided_filter class. GuidedFilterGray and GuidedFilterColor. +# Based on guided_filter._computeCoefficients, guided_filter._computeOutput, +# GuidedFilterCommon.filter computes filtered image for color and gray. +class GuidedFilterCommon: + def __init__(self, guided_filter): + self._guided_filter = guided_filter + + ## Apply filter for the input image. + # @param p Input image for the filtering. + def filter(self, p): + p_32F = to32F(p) + if _isGray(p_32F): + return self._filterGray(p_32F) + + cs = p.shape[2] + q = np.array(p_32F) + + for ci in range(cs): + q[:, :, ci] = self._filterGray(p_32F[:, :, ci]) + return q + + def _filterGray(self, p): + ab = self._guided_filter._computeCoefficients(p) + return self._guided_filter._computeOutput(ab, self._guided_filter._I) + + +## Guided filter for gray guidance image. +class GuidedFilterGray: + # @param I Input gray guidance image. + # @param radius Radius of Guided Filter. + # @param epsilon Regularization term of Guided Filter. + def __init__(self, I, radius=5, epsilon=0.4): + self._radius = 2 * radius + 1 + self._epsilon = epsilon + self._I = to32F(I) + self._initFilter() + self._filter_common = GuidedFilterCommon(self) + + ## Apply filter for the input image. + # @param p Input image for the filtering. + def filter(self, p): + return self._filter_common.filter(p) + + def _initFilter(self): + I = self._I + r = self._radius + self._I_mean = cv2.blur(I, (r, r)) + I_mean_sq = cv2.blur(I ** 2, (r, r)) + self._I_var = I_mean_sq - self._I_mean ** 2 + + def _computeCoefficients(self, p): + r = self._radius + p_mean = cv2.blur(p, (r, r)) + p_cov = p_mean - self._I_mean * p_mean + a = p_cov / (self._I_var + self._epsilon) + b = p_mean - a * self._I_mean + a_mean = cv2.blur(a, (r, r)) + b_mean = cv2.blur(b, (r, r)) + return a_mean, b_mean + + def _computeOutput(self, ab, I): + a_mean, b_mean = ab + return a_mean * I + b_mean + + +## Guided filter for color guidance image. +class GuidedFilterColor: + # @param I Input color guidance image. + # @param radius Radius of Guided Filter. + # @param epsilon Regularization term of Guided Filter. + def __init__(self, I, radius=5, epsilon=0.2): + self._radius = 2 * radius + 1 + self._epsilon = epsilon + self._I = to32F(I) + self._initFilter() + self._filter_common = GuidedFilterCommon(self) + + ## Apply filter for the input image. + # @param p Input image for the filtering. + def filter(self, p): + return self._filter_common.filter(p) + + def _initFilter(self): + I = self._I + r = self._radius + eps = self._epsilon + + Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2] + + self._Ir_mean = cv2.blur(Ir, (r, r)) + self._Ig_mean = cv2.blur(Ig, (r, r)) + self._Ib_mean = cv2.blur(Ib, (r, r)) + + Irr_var = cv2.blur(Ir ** 2, (r, r)) - self._Ir_mean ** 2 + eps + Irg_var = cv2.blur(Ir * Ig, (r, r)) - self._Ir_mean * self._Ig_mean + Irb_var = cv2.blur(Ir * Ib, (r, r)) - self._Ir_mean * self._Ib_mean + Igg_var = cv2.blur(Ig * Ig, (r, r)) - self._Ig_mean * self._Ig_mean + eps + Igb_var = cv2.blur(Ig * Ib, (r, r)) - self._Ig_mean * self._Ib_mean + Ibb_var = cv2.blur(Ib * Ib, (r, r)) - self._Ib_mean * self._Ib_mean + eps + + Irr_inv = Igg_var * Ibb_var - Igb_var * Igb_var + Irg_inv = Igb_var * Irb_var - Irg_var * Ibb_var + Irb_inv = Irg_var * Igb_var - Igg_var * Irb_var + Igg_inv = Irr_var * Ibb_var - Irb_var * Irb_var + Igb_inv = Irb_var * Irg_var - Irr_var * Igb_var + Ibb_inv = Irr_var * Igg_var - Irg_var * Irg_var + + I_cov = Irr_inv * Irr_var + Irg_inv * Irg_var + Irb_inv * Irb_var + Irr_inv /= I_cov + Irg_inv /= I_cov + Irb_inv /= I_cov + Igg_inv /= I_cov + Igb_inv /= I_cov + Ibb_inv /= I_cov + + self._Irr_inv = Irr_inv + self._Irg_inv = Irg_inv + self._Irb_inv = Irb_inv + self._Igg_inv = Igg_inv + self._Igb_inv = Igb_inv + self._Ibb_inv = Ibb_inv + + def _computeCoefficients(self, p): + r = self._radius + I = self._I + Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2] + + p_mean = cv2.blur(p, (r, r)) + + Ipr_mean = cv2.blur(Ir * p, (r, r)) + Ipg_mean = cv2.blur(Ig * p, (r, r)) + Ipb_mean = cv2.blur(Ib * p, (r, r)) + + Ipr_cov = Ipr_mean - self._Ir_mean * p_mean + Ipg_cov = Ipg_mean - self._Ig_mean * p_mean + Ipb_cov = Ipb_mean - self._Ib_mean * p_mean + + ar = self._Irr_inv * Ipr_cov + self._Irg_inv * Ipg_cov + self._Irb_inv * Ipb_cov + ag = self._Irg_inv * Ipr_cov + self._Igg_inv * Ipg_cov + self._Igb_inv * Ipb_cov + ab = self._Irb_inv * Ipr_cov + self._Igb_inv * Ipg_cov + self._Ibb_inv * Ipb_cov + b = p_mean - ar * self._Ir_mean - ag * self._Ig_mean - ab * self._Ib_mean + + ar_mean = cv2.blur(ar, (r, r)) + ag_mean = cv2.blur(ag, (r, r)) + ab_mean = cv2.blur(ab, (r, r)) + b_mean = cv2.blur(b, (r, r)) + + return ar_mean, ag_mean, ab_mean, b_mean + + def _computeOutput(self, ab, I): + ar_mean, ag_mean, ab_mean, b_mean = ab + + Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2] + + q = (ar_mean * Ir + + ag_mean * Ig + + ab_mean * Ib + + b_mean) + + return q diff --git a/third_party/src/flux_ch/annotator/util.py b/third_party/src/flux_ch/annotator/util.py new file mode 100644 index 0000000000000000000000000000000000000000..90831643d19cc1b9b0940df3d4fd4d846ba74a05 --- /dev/null +++ b/third_party/src/flux_ch/annotator/util.py @@ -0,0 +1,38 @@ +import numpy as np +import cv2 +import os + + +annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') + + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + + +def resize_image(input_image, resolution): + H, W, C = input_image.shape + H = float(H) + W = float(W) + k = float(resolution) / min(H, W) + H *= k + W *= k + H = int(np.round(H / 64.0)) * 64 + W = int(np.round(W / 64.0)) * 64 + img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) + return img diff --git a/third_party/src/flux_ch/annotator/zoe/LICENSE b/third_party/src/flux_ch/annotator/zoe/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..7a1e90d007836c327846ce8e5151013b115042ab --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Intelligent Systems Lab Org + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/__init__.py b/third_party/src/flux_ch/annotator/zoe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7628090932e35bdd71d041069ae62f6a731f60d4 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/__init__.py @@ -0,0 +1,48 @@ +# ZoeDepth +# https://github.com/isl-org/ZoeDepth + +import os +import cv2 +import numpy as np +import torch + +from einops import rearrange +from .zoedepth.models.zoedepth.zoedepth_v1 import ZoeDepth +from .zoedepth.utils.config import get_config +from ...annotator.util import annotator_ckpts_path +from huggingface_hub import hf_hub_download + + +class ZoeDetector: + def __init__(self): + model_path = os.path.join(annotator_ckpts_path, "ZoeD_M12_N.pt") + if not os.path.exists(model_path): + model_path = hf_hub_download("lllyasviel/Annotators", "ZoeD_M12_N.pt") + conf = get_config("zoedepth", "infer") + model = ZoeDepth.build_from_config(conf) + model.load_state_dict(torch.load(model_path)['model'], strict=False) + model = model.cuda() + model.device = 'cuda' + model.eval() + self.model = model + + def __call__(self, input_image): + assert input_image.ndim == 3 + image_depth = input_image + with torch.no_grad(): + image_depth = torch.from_numpy(image_depth).float().cuda() + image_depth = image_depth / 255.0 + image_depth = rearrange(image_depth, 'h w c -> 1 c h w') + depth = self.model.infer(image_depth) + + depth = depth[0, 0].cpu().numpy() + + vmin = np.percentile(depth, 2) + vmax = np.percentile(depth, 85) + + depth -= vmin + depth /= vmax - vmin + depth = 1.0 - depth + depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8) + + return depth_image diff --git a/third_party/src/flux_ch/annotator/zoe/__pycache__/__init__.cpython-310.pyc b/third_party/src/flux_ch/annotator/zoe/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1251c0eb0602a71aa11f49dc6650e1c27ea2d5fa Binary files /dev/null and b/third_party/src/flux_ch/annotator/zoe/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/data/__init__.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/data/data_mono.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/data_mono.py new file mode 100644 index 0000000000000000000000000000000000000000..80a8486f239a35331df553f490e213f9bf71e735 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/data_mono.py @@ -0,0 +1,573 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +# This file is partly inspired from BTS (https://github.com/cleinc/bts/blob/master/pytorch/bts_dataloader.py); author: Jin Han Lee + +import itertools +import os +import random + +import numpy as np +import cv2 +import torch +import torch.nn as nn +import torch.utils.data.distributed +from zoedepth.utils.easydict import EasyDict as edict +from PIL import Image, ImageOps +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + +from zoedepth.utils.config import change_dataset + +from .ddad import get_ddad_loader +from .diml_indoor_test import get_diml_indoor_loader +from .diml_outdoor_test import get_diml_outdoor_loader +from .diode import get_diode_loader +from .hypersim import get_hypersim_loader +from .ibims import get_ibims_loader +from .sun_rgbd_loader import get_sunrgbd_loader +from .vkitti import get_vkitti_loader +from .vkitti2 import get_vkitti2_loader + +from .preprocess import CropParams, get_white_border, get_black_border + + +def _is_pil_image(img): + return isinstance(img, Image.Image) + + +def _is_numpy_image(img): + return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) + + +def preprocessing_transforms(mode, **kwargs): + return transforms.Compose([ + ToTensor(mode=mode, **kwargs) + ]) + + +class DepthDataLoader(object): + def __init__(self, config, mode, device='cpu', transform=None, **kwargs): + """ + Data loader for depth datasets + + Args: + config (dict): Config dictionary. Refer to utils/config.py + mode (str): "train" or "online_eval" + device (str, optional): Device to load the data on. Defaults to 'cpu'. + transform (torchvision.transforms, optional): Transform to apply to the data. Defaults to None. + """ + + self.config = config + + if config.dataset == 'ibims': + self.data = get_ibims_loader(config, batch_size=1, num_workers=1) + return + + if config.dataset == 'sunrgbd': + self.data = get_sunrgbd_loader( + data_dir_root=config.sunrgbd_root, batch_size=1, num_workers=1) + return + + if config.dataset == 'diml_indoor': + self.data = get_diml_indoor_loader( + data_dir_root=config.diml_indoor_root, batch_size=1, num_workers=1) + return + + if config.dataset == 'diml_outdoor': + self.data = get_diml_outdoor_loader( + data_dir_root=config.diml_outdoor_root, batch_size=1, num_workers=1) + return + + if "diode" in config.dataset: + self.data = get_diode_loader( + config[config.dataset+"_root"], batch_size=1, num_workers=1) + return + + if config.dataset == 'hypersim_test': + self.data = get_hypersim_loader( + config.hypersim_test_root, batch_size=1, num_workers=1) + return + + if config.dataset == 'vkitti': + self.data = get_vkitti_loader( + config.vkitti_root, batch_size=1, num_workers=1) + return + + if config.dataset == 'vkitti2': + self.data = get_vkitti2_loader( + config.vkitti2_root, batch_size=1, num_workers=1) + return + + if config.dataset == 'ddad': + self.data = get_ddad_loader(config.ddad_root, resize_shape=( + 352, 1216), batch_size=1, num_workers=1) + return + + img_size = self.config.get("img_size", None) + img_size = img_size if self.config.get( + "do_input_resize", False) else None + + if transform is None: + transform = preprocessing_transforms(mode, size=img_size) + + if mode == 'train': + + Dataset = DataLoadPreprocess + self.training_samples = Dataset( + config, mode, transform=transform, device=device) + + if config.distributed: + self.train_sampler = torch.utils.data.distributed.DistributedSampler( + self.training_samples) + else: + self.train_sampler = None + + self.data = DataLoader(self.training_samples, + batch_size=config.batch_size, + shuffle=(self.train_sampler is None), + num_workers=config.workers, + pin_memory=True, + persistent_workers=True, + # prefetch_factor=2, + sampler=self.train_sampler) + + elif mode == 'online_eval': + self.testing_samples = DataLoadPreprocess( + config, mode, transform=transform) + if config.distributed: # redundant. here only for readability and to be more explicit + # Give whole test set to all processes (and report evaluation only on one) regardless + self.eval_sampler = None + else: + self.eval_sampler = None + self.data = DataLoader(self.testing_samples, 1, + shuffle=kwargs.get("shuffle_test", False), + num_workers=1, + pin_memory=False, + sampler=self.eval_sampler) + + elif mode == 'test': + self.testing_samples = DataLoadPreprocess( + config, mode, transform=transform) + self.data = DataLoader(self.testing_samples, + 1, shuffle=False, num_workers=1) + + else: + print( + 'mode should be one of \'train, test, online_eval\'. Got {}'.format(mode)) + + +def repetitive_roundrobin(*iterables): + """ + cycles through iterables but sample wise + first yield first sample from first iterable then first sample from second iterable and so on + then second sample from first iterable then second sample from second iterable and so on + + If one iterable is shorter than the others, it is repeated until all iterables are exhausted + repetitive_roundrobin('ABC', 'D', 'EF') --> A D E B D F C D E + """ + # Repetitive roundrobin + iterables_ = [iter(it) for it in iterables] + exhausted = [False] * len(iterables) + while not all(exhausted): + for i, it in enumerate(iterables_): + try: + yield next(it) + except StopIteration: + exhausted[i] = True + iterables_[i] = itertools.cycle(iterables[i]) + # First elements may get repeated if one iterable is shorter than the others + yield next(iterables_[i]) + + +class RepetitiveRoundRobinDataLoader(object): + def __init__(self, *dataloaders): + self.dataloaders = dataloaders + + def __iter__(self): + return repetitive_roundrobin(*self.dataloaders) + + def __len__(self): + # First samples get repeated, thats why the plus one + return len(self.dataloaders) * (max(len(dl) for dl in self.dataloaders) + 1) + + +class MixedNYUKITTI(object): + def __init__(self, config, mode, device='cpu', **kwargs): + config = edict(config) + config.workers = config.workers // 2 + self.config = config + nyu_conf = change_dataset(edict(config), 'nyu') + kitti_conf = change_dataset(edict(config), 'kitti') + + # make nyu default for testing + self.config = config = nyu_conf + img_size = self.config.get("img_size", None) + img_size = img_size if self.config.get( + "do_input_resize", False) else None + if mode == 'train': + nyu_loader = DepthDataLoader( + nyu_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data + kitti_loader = DepthDataLoader( + kitti_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data + # It has been changed to repetitive roundrobin + self.data = RepetitiveRoundRobinDataLoader( + nyu_loader, kitti_loader) + else: + self.data = DepthDataLoader(nyu_conf, mode, device=device).data + + +def remove_leading_slash(s): + if s[0] == '/' or s[0] == '\\': + return s[1:] + return s + + +class CachedReader: + def __init__(self, shared_dict=None): + if shared_dict: + self._cache = shared_dict + else: + self._cache = {} + + def open(self, fpath): + im = self._cache.get(fpath, None) + if im is None: + im = self._cache[fpath] = Image.open(fpath) + return im + + +class ImReader: + def __init__(self): + pass + + # @cache + def open(self, fpath): + return Image.open(fpath) + + +class DataLoadPreprocess(Dataset): + def __init__(self, config, mode, transform=None, is_for_online_eval=False, **kwargs): + self.config = config + if mode == 'online_eval': + with open(config.filenames_file_eval, 'r') as f: + self.filenames = f.readlines() + else: + with open(config.filenames_file, 'r') as f: + self.filenames = f.readlines() + + self.mode = mode + self.transform = transform + self.to_tensor = ToTensor(mode) + self.is_for_online_eval = is_for_online_eval + if config.use_shared_dict: + self.reader = CachedReader(config.shared_dict) + else: + self.reader = ImReader() + + def postprocess(self, sample): + return sample + + def __getitem__(self, idx): + sample_path = self.filenames[idx] + focal = float(sample_path.split()[2]) + sample = {} + + if self.mode == 'train': + if self.config.dataset == 'kitti' and self.config.use_right and random.random() > 0.5: + image_path = os.path.join( + self.config.data_path, remove_leading_slash(sample_path.split()[3])) + depth_path = os.path.join( + self.config.gt_path, remove_leading_slash(sample_path.split()[4])) + else: + image_path = os.path.join( + self.config.data_path, remove_leading_slash(sample_path.split()[0])) + depth_path = os.path.join( + self.config.gt_path, remove_leading_slash(sample_path.split()[1])) + + image = self.reader.open(image_path) + depth_gt = self.reader.open(depth_path) + w, h = image.size + + if self.config.do_kb_crop: + height = image.height + width = image.width + top_margin = int(height - 352) + left_margin = int((width - 1216) / 2) + depth_gt = depth_gt.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + image = image.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + + # Avoid blank boundaries due to pixel registration? + # Train images have white border. Test images have black border. + if self.config.dataset == 'nyu' and self.config.avoid_boundary: + # print("Avoiding Blank Boundaries!") + # We just crop and pad again with reflect padding to original size + # original_size = image.size + crop_params = get_white_border(np.array(image, dtype=np.uint8)) + image = image.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom)) + depth_gt = depth_gt.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom)) + + # Use reflect padding to fill the blank + image = np.array(image) + image = np.pad(image, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right), (0, 0)), mode='reflect') + image = Image.fromarray(image) + + depth_gt = np.array(depth_gt) + depth_gt = np.pad(depth_gt, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right)), 'constant', constant_values=0) + depth_gt = Image.fromarray(depth_gt) + + + if self.config.do_random_rotate and (self.config.aug): + random_angle = (random.random() - 0.5) * 2 * self.config.degree + image = self.rotate_image(image, random_angle) + depth_gt = self.rotate_image( + depth_gt, random_angle, flag=Image.NEAREST) + + image = np.asarray(image, dtype=np.float32) / 255.0 + depth_gt = np.asarray(depth_gt, dtype=np.float32) + depth_gt = np.expand_dims(depth_gt, axis=2) + + if self.config.dataset == 'nyu': + depth_gt = depth_gt / 1000.0 + else: + depth_gt = depth_gt / 256.0 + + if self.config.aug and (self.config.random_crop): + image, depth_gt = self.random_crop( + image, depth_gt, self.config.input_height, self.config.input_width) + + if self.config.aug and self.config.random_translate: + # print("Random Translation!") + image, depth_gt = self.random_translate(image, depth_gt, self.config.max_translation) + + image, depth_gt = self.train_preprocess(image, depth_gt) + mask = np.logical_and(depth_gt > self.config.min_depth, + depth_gt < self.config.max_depth).squeeze()[None, ...] + sample = {'image': image, 'depth': depth_gt, 'focal': focal, + 'mask': mask, **sample} + + else: + if self.mode == 'online_eval': + data_path = self.config.data_path_eval + else: + data_path = self.config.data_path + + image_path = os.path.join( + data_path, remove_leading_slash(sample_path.split()[0])) + image = np.asarray(self.reader.open(image_path), + dtype=np.float32) / 255.0 + + if self.mode == 'online_eval': + gt_path = self.config.gt_path_eval + depth_path = os.path.join( + gt_path, remove_leading_slash(sample_path.split()[1])) + has_valid_depth = False + try: + depth_gt = self.reader.open(depth_path) + has_valid_depth = True + except IOError: + depth_gt = False + # print('Missing gt for {}'.format(image_path)) + + if has_valid_depth: + depth_gt = np.asarray(depth_gt, dtype=np.float32) + depth_gt = np.expand_dims(depth_gt, axis=2) + if self.config.dataset == 'nyu': + depth_gt = depth_gt / 1000.0 + else: + depth_gt = depth_gt / 256.0 + + mask = np.logical_and( + depth_gt >= self.config.min_depth, depth_gt <= self.config.max_depth).squeeze()[None, ...] + else: + mask = False + + if self.config.do_kb_crop: + height = image.shape[0] + width = image.shape[1] + top_margin = int(height - 352) + left_margin = int((width - 1216) / 2) + image = image[top_margin:top_margin + 352, + left_margin:left_margin + 1216, :] + if self.mode == 'online_eval' and has_valid_depth: + depth_gt = depth_gt[top_margin:top_margin + + 352, left_margin:left_margin + 1216, :] + + if self.mode == 'online_eval': + sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth, + 'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1], + 'mask': mask} + else: + sample = {'image': image, 'focal': focal} + + if (self.mode == 'train') or ('has_valid_depth' in sample and sample['has_valid_depth']): + mask = np.logical_and(depth_gt > self.config.min_depth, + depth_gt < self.config.max_depth).squeeze()[None, ...] + sample['mask'] = mask + + if self.transform: + sample = self.transform(sample) + + sample = self.postprocess(sample) + sample['dataset'] = self.config.dataset + sample = {**sample, 'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1]} + + return sample + + def rotate_image(self, image, angle, flag=Image.BILINEAR): + result = image.rotate(angle, resample=flag) + return result + + def random_crop(self, img, depth, height, width): + assert img.shape[0] >= height + assert img.shape[1] >= width + assert img.shape[0] == depth.shape[0] + assert img.shape[1] == depth.shape[1] + x = random.randint(0, img.shape[1] - width) + y = random.randint(0, img.shape[0] - height) + img = img[y:y + height, x:x + width, :] + depth = depth[y:y + height, x:x + width, :] + + return img, depth + + def random_translate(self, img, depth, max_t=20): + assert img.shape[0] == depth.shape[0] + assert img.shape[1] == depth.shape[1] + p = self.config.translate_prob + do_translate = random.random() + if do_translate > p: + return img, depth + x = random.randint(-max_t, max_t) + y = random.randint(-max_t, max_t) + M = np.float32([[1, 0, x], [0, 1, y]]) + # print(img.shape, depth.shape) + img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0])) + depth = cv2.warpAffine(depth, M, (depth.shape[1], depth.shape[0])) + depth = depth.squeeze()[..., None] # add channel dim back. Affine warp removes it + # print("after", img.shape, depth.shape) + return img, depth + + def train_preprocess(self, image, depth_gt): + if self.config.aug: + # Random flipping + do_flip = random.random() + if do_flip > 0.5: + image = (image[:, ::-1, :]).copy() + depth_gt = (depth_gt[:, ::-1, :]).copy() + + # Random gamma, brightness, color augmentation + do_augment = random.random() + if do_augment > 0.5: + image = self.augment_image(image) + + return image, depth_gt + + def augment_image(self, image): + # gamma augmentation + gamma = random.uniform(0.9, 1.1) + image_aug = image ** gamma + + # brightness augmentation + if self.config.dataset == 'nyu': + brightness = random.uniform(0.75, 1.25) + else: + brightness = random.uniform(0.9, 1.1) + image_aug = image_aug * brightness + + # color augmentation + colors = np.random.uniform(0.9, 1.1, size=3) + white = np.ones((image.shape[0], image.shape[1])) + color_image = np.stack([white * colors[i] for i in range(3)], axis=2) + image_aug *= color_image + image_aug = np.clip(image_aug, 0, 1) + + return image_aug + + def __len__(self): + return len(self.filenames) + + +class ToTensor(object): + def __init__(self, mode, do_normalize=False, size=None): + self.mode = mode + self.normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) if do_normalize else nn.Identity() + self.size = size + if size is not None: + self.resize = transforms.Resize(size=size) + else: + self.resize = nn.Identity() + + def __call__(self, sample): + image, focal = sample['image'], sample['focal'] + image = self.to_tensor(image) + image = self.normalize(image) + image = self.resize(image) + + if self.mode == 'test': + return {'image': image, 'focal': focal} + + depth = sample['depth'] + if self.mode == 'train': + depth = self.to_tensor(depth) + return {**sample, 'image': image, 'depth': depth, 'focal': focal} + else: + has_valid_depth = sample['has_valid_depth'] + image = self.resize(image) + return {**sample, 'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth, + 'image_path': sample['image_path'], 'depth_path': sample['depth_path']} + + def to_tensor(self, pic): + if not (_is_pil_image(pic) or _is_numpy_image(pic)): + raise TypeError( + 'pic should be PIL Image or ndarray. Got {}'.format(type(pic))) + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/data/ddad.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/ddad.py new file mode 100644 index 0000000000000000000000000000000000000000..4bd0492bdec767685d3a21992b4a26e62d002d97 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/ddad.py @@ -0,0 +1,117 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self, resize_shape): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + self.resize = transforms.Resize(resize_shape) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "ddad"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class DDAD(Dataset): + def __init__(self, data_dir_root, resize_shape): + import glob + + # image paths are of the form /{outleft, depthmap}/*.png + self.image_files = glob.glob(os.path.join(data_dir_root, '*.png')) + self.depth_files = [r.replace("_rgb.png", "_depth.npy") + for r in self.image_files] + self.transform = ToTensor(resize_shape) + + def __getitem__(self, idx): + + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + depth = np.load(depth_path) # meters + + # depth[depth > 8] = -1 + depth = depth[..., None] + + sample = dict(image=image, depth=depth) + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_ddad_loader(data_dir_root, resize_shape, batch_size=1, **kwargs): + dataset = DDAD(data_dir_root, resize_shape) + return DataLoader(dataset, batch_size, **kwargs) diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/data/diml_indoor_test.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/diml_indoor_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f720ad9aefaee78ef4ec363dfef0f82ace850a6d --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/diml_indoor_test.py @@ -0,0 +1,125 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + self.resize = transforms.Resize((480, 640)) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "diml_indoor"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class DIML_Indoor(Dataset): + def __init__(self, data_dir_root): + import glob + + # image paths are of the form /{HR, LR}//{color, depth_filled}/*.png + self.image_files = glob.glob(os.path.join( + data_dir_root, "LR", '*', 'color', '*.png')) + self.depth_files = [r.replace("color", "depth_filled").replace( + "_c.png", "_depth_filled.png") for r in self.image_files] + self.transform = ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + depth = np.asarray(Image.open(depth_path), + dtype='uint16') / 1000.0 # mm to meters + + # print(np.shape(image)) + # print(np.shape(depth)) + + # depth[depth > 8] = -1 + depth = depth[..., None] + + sample = dict(image=image, depth=depth) + + # return sample + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_diml_indoor_loader(data_dir_root, batch_size=1, **kwargs): + dataset = DIML_Indoor(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) + +# get_diml_indoor_loader(data_dir_root="datasets/diml/indoor/test/HR") +# get_diml_indoor_loader(data_dir_root="datasets/diml/indoor/test/LR") diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/data/diml_outdoor_test.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/diml_outdoor_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8670b48f5febafb819dac22848ad79ccb5dd5ae4 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/diml_outdoor_test.py @@ -0,0 +1,114 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + return {'image': image, 'depth': depth, 'dataset': "diml_outdoor"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class DIML_Outdoor(Dataset): + def __init__(self, data_dir_root): + import glob + + # image paths are of the form /{outleft, depthmap}/*.png + self.image_files = glob.glob(os.path.join( + data_dir_root, "*", 'outleft', '*.png')) + self.depth_files = [r.replace("outleft", "depthmap") + for r in self.image_files] + self.transform = ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + depth = np.asarray(Image.open(depth_path), + dtype='uint16') / 1000.0 # mm to meters + + # depth[depth > 8] = -1 + depth = depth[..., None] + + sample = dict(image=image, depth=depth, dataset="diml_outdoor") + + # return sample + return self.transform(sample) + + def __len__(self): + return len(self.image_files) + + +def get_diml_outdoor_loader(data_dir_root, batch_size=1, **kwargs): + dataset = DIML_Outdoor(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) + +# get_diml_outdoor_loader(data_dir_root="datasets/diml/outdoor/test/HR") +# get_diml_outdoor_loader(data_dir_root="datasets/diml/outdoor/test/LR") diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/data/diode.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/diode.py new file mode 100644 index 0000000000000000000000000000000000000000..1510c87116b8f70ce2e1428873a8e4da042bee23 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/diode.py @@ -0,0 +1,125 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + self.resize = transforms.Resize(480) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "diode"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class DIODE(Dataset): + def __init__(self, data_dir_root): + import glob + + # image paths are of the form /scene_#/scan_#/*.png + self.image_files = glob.glob( + os.path.join(data_dir_root, '*', '*', '*.png')) + self.depth_files = [r.replace(".png", "_depth.npy") + for r in self.image_files] + self.depth_mask_files = [ + r.replace(".png", "_depth_mask.npy") for r in self.image_files] + self.transform = ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + depth_mask_path = self.depth_mask_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + depth = np.load(depth_path) # in meters + valid = np.load(depth_mask_path) # binary + + # depth[depth > 8] = -1 + # depth = depth[..., None] + + sample = dict(image=image, depth=depth, valid=valid) + + # return sample + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_diode_loader(data_dir_root, batch_size=1, **kwargs): + dataset = DIODE(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) + +# get_diode_loader(data_dir_root="datasets/diode/val/outdoor") diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/data/hypersim.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/hypersim.py new file mode 100644 index 0000000000000000000000000000000000000000..4334198971830200f72ea2910d03f4c7d6a43334 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/hypersim.py @@ -0,0 +1,138 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import glob +import os + +import h5py +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +def hypersim_distance_to_depth(npyDistance): + intWidth, intHeight, fltFocal = 1024, 768, 886.81 + + npyImageplaneX = np.linspace((-0.5 * intWidth) + 0.5, (0.5 * intWidth) - 0.5, intWidth).reshape( + 1, intWidth).repeat(intHeight, 0).astype(np.float32)[:, :, None] + npyImageplaneY = np.linspace((-0.5 * intHeight) + 0.5, (0.5 * intHeight) - 0.5, + intHeight).reshape(intHeight, 1).repeat(intWidth, 1).astype(np.float32)[:, :, None] + npyImageplaneZ = np.full([intHeight, intWidth, 1], fltFocal, np.float32) + npyImageplane = np.concatenate( + [npyImageplaneX, npyImageplaneY, npyImageplaneZ], 2) + + npyDepth = npyDistance / np.linalg.norm(npyImageplane, 2, 2) * fltFocal + return npyDepth + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x: x + self.resize = transforms.Resize((480, 640)) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "hypersim"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class HyperSim(Dataset): + def __init__(self, data_dir_root): + # image paths are of the form //images/scene_cam_#_final_preview/*.tonemap.jpg + # depth paths are of the form //images/scene_cam_#_final_preview/*.depth_meters.hdf5 + self.image_files = glob.glob(os.path.join( + data_dir_root, '*', 'images', 'scene_cam_*_final_preview', '*.tonemap.jpg')) + self.depth_files = [r.replace("_final_preview", "_geometry_hdf5").replace( + ".tonemap.jpg", ".depth_meters.hdf5") for r in self.image_files] + self.transform = ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + + # depth from hdf5 + depth_fd = h5py.File(depth_path, "r") + # in meters (Euclidean distance) + distance_meters = np.array(depth_fd['dataset']) + depth = hypersim_distance_to_depth( + distance_meters) # in meters (planar depth) + + # depth[depth > 8] = -1 + depth = depth[..., None] + + sample = dict(image=image, depth=depth) + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_hypersim_loader(data_dir_root, batch_size=1, **kwargs): + dataset = HyperSim(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/data/ibims.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/ibims.py new file mode 100644 index 0000000000000000000000000000000000000000..b66abfabcf4cfc617d4a60ec818780c3548d9920 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/ibims.py @@ -0,0 +1,81 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms as T + + +class iBims(Dataset): + def __init__(self, config): + root_folder = config.ibims_root + with open(os.path.join(root_folder, "imagelist.txt"), 'r') as f: + imglist = f.read().split() + + samples = [] + for basename in imglist: + img_path = os.path.join(root_folder, 'rgb', basename + ".png") + depth_path = os.path.join(root_folder, 'depth', basename + ".png") + valid_mask_path = os.path.join( + root_folder, 'mask_invalid', basename+".png") + transp_mask_path = os.path.join( + root_folder, 'mask_transp', basename+".png") + + samples.append( + (img_path, depth_path, valid_mask_path, transp_mask_path)) + + self.samples = samples + # self.normalize = T.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + + def __getitem__(self, idx): + img_path, depth_path, valid_mask_path, transp_mask_path = self.samples[idx] + + img = np.asarray(Image.open(img_path), dtype=np.float32) / 255.0 + depth = np.asarray(Image.open(depth_path), + dtype=np.uint16).astype('float')*50.0/65535 + + mask_valid = np.asarray(Image.open(valid_mask_path)) + mask_transp = np.asarray(Image.open(transp_mask_path)) + + # depth = depth * mask_valid * mask_transp + depth = np.where(mask_valid * mask_transp, depth, -1) + + img = torch.from_numpy(img).permute(2, 0, 1) + img = self.normalize(img) + depth = torch.from_numpy(depth).unsqueeze(0) + return dict(image=img, depth=depth, image_path=img_path, depth_path=depth_path, dataset='ibims') + + def __len__(self): + return len(self.samples) + + +def get_ibims_loader(config, batch_size=1, **kwargs): + dataloader = DataLoader(iBims(config), batch_size=batch_size, **kwargs) + return dataloader diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/data/preprocess.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..e08cc309dc823ae6efd7cda8db9eb37130dc5499 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/preprocess.py @@ -0,0 +1,154 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import numpy as np +from dataclasses import dataclass +from typing import Tuple, List + +# dataclass to store the crop parameters +@dataclass +class CropParams: + top: int + bottom: int + left: int + right: int + + + +def get_border_params(rgb_image, tolerance=0.1, cut_off=20, value=0, level_diff_threshold=5, channel_axis=-1, min_border=5) -> CropParams: + gray_image = np.mean(rgb_image, axis=channel_axis) + h, w = gray_image.shape + + + def num_value_pixels(arr): + return np.sum(np.abs(arr - value) < level_diff_threshold) + + def is_above_tolerance(arr, total_pixels): + return (num_value_pixels(arr) / total_pixels) > tolerance + + # Crop top border until number of value pixels become below tolerance + top = min_border + while is_above_tolerance(gray_image[top, :], w) and top < h-1: + top += 1 + if top > cut_off: + break + + # Crop bottom border until number of value pixels become below tolerance + bottom = h - min_border + while is_above_tolerance(gray_image[bottom, :], w) and bottom > 0: + bottom -= 1 + if h - bottom > cut_off: + break + + # Crop left border until number of value pixels become below tolerance + left = min_border + while is_above_tolerance(gray_image[:, left], h) and left < w-1: + left += 1 + if left > cut_off: + break + + # Crop right border until number of value pixels become below tolerance + right = w - min_border + while is_above_tolerance(gray_image[:, right], h) and right > 0: + right -= 1 + if w - right > cut_off: + break + + + return CropParams(top, bottom, left, right) + + +def get_white_border(rgb_image, value=255, **kwargs) -> CropParams: + """Crops the white border of the RGB. + + Args: + rgb: RGB image, shape (H, W, 3). + Returns: + Crop parameters. + """ + if value == 255: + # assert range of values in rgb image is [0, 255] + assert np.max(rgb_image) <= 255 and np.min(rgb_image) >= 0, "RGB image values are not in range [0, 255]." + assert rgb_image.max() > 1, "RGB image values are not in range [0, 255]." + elif value == 1: + # assert range of values in rgb image is [0, 1] + assert np.max(rgb_image) <= 1 and np.min(rgb_image) >= 0, "RGB image values are not in range [0, 1]." + + return get_border_params(rgb_image, value=value, **kwargs) + +def get_black_border(rgb_image, **kwargs) -> CropParams: + """Crops the black border of the RGB. + + Args: + rgb: RGB image, shape (H, W, 3). + + Returns: + Crop parameters. + """ + + return get_border_params(rgb_image, value=0, **kwargs) + +def crop_image(image: np.ndarray, crop_params: CropParams) -> np.ndarray: + """Crops the image according to the crop parameters. + + Args: + image: RGB or depth image, shape (H, W, 3) or (H, W). + crop_params: Crop parameters. + + Returns: + Cropped image. + """ + return image[crop_params.top:crop_params.bottom, crop_params.left:crop_params.right] + +def crop_images(*images: np.ndarray, crop_params: CropParams) -> Tuple[np.ndarray]: + """Crops the images according to the crop parameters. + + Args: + images: RGB or depth images, shape (H, W, 3) or (H, W). + crop_params: Crop parameters. + + Returns: + Cropped images. + """ + return tuple(crop_image(image, crop_params) for image in images) + +def crop_black_or_white_border(rgb_image, *other_images: np.ndarray, tolerance=0.1, cut_off=20, level_diff_threshold=5) -> Tuple[np.ndarray]: + """Crops the white and black border of the RGB and depth images. + + Args: + rgb: RGB image, shape (H, W, 3). This image is used to determine the border. + other_images: The other images to crop according to the border of the RGB image. + Returns: + Cropped RGB and other images. + """ + # crop black border + crop_params = get_black_border(rgb_image, tolerance=tolerance, cut_off=cut_off, level_diff_threshold=level_diff_threshold) + cropped_images = crop_images(rgb_image, *other_images, crop_params=crop_params) + + # crop white border + crop_params = get_white_border(cropped_images[0], tolerance=tolerance, cut_off=cut_off, level_diff_threshold=level_diff_threshold) + cropped_images = crop_images(*cropped_images, crop_params=crop_params) + + return cropped_images + \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/data/sun_rgbd_loader.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/sun_rgbd_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..9e2bdb9aefe68ca4439f41eff3bba722c49fb976 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/sun_rgbd_loader.py @@ -0,0 +1,106 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x : x + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + return {'image': image, 'depth': depth, 'dataset': "sunrgbd"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class SunRGBD(Dataset): + def __init__(self, data_dir_root): + # test_file_dirs = loadmat(train_test_file)['alltest'].squeeze() + # all_test = [t[0].replace("/n/fs/sun3d/data/", "") for t in test_file_dirs] + # self.all_test = [os.path.join(data_dir_root, t) for t in all_test] + import glob + self.image_files = glob.glob( + os.path.join(data_dir_root, 'rgb', 'rgb', '*')) + self.depth_files = [ + r.replace("rgb/rgb", "gt/gt").replace("jpg", "png") for r in self.image_files] + self.transform = ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0 + depth = np.asarray(Image.open(depth_path), dtype='uint16') / 1000.0 + depth[depth > 8] = -1 + depth = depth[..., None] + return self.transform(dict(image=image, depth=depth)) + + def __len__(self): + return len(self.image_files) + + +def get_sunrgbd_loader(data_dir_root, batch_size=1, **kwargs): + dataset = SunRGBD(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/data/transforms.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..374416dff24fb4fd55598f3946d6d6b091ddefc9 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/transforms.py @@ -0,0 +1,481 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import math +import random + +import cv2 +import numpy as np + + +class RandomFliplr(object): + """Horizontal flip of the sample with given probability. + """ + + def __init__(self, probability=0.5): + """Init. + + Args: + probability (float, optional): Flip probability. Defaults to 0.5. + """ + self.__probability = probability + + def __call__(self, sample): + prob = random.random() + + if prob < self.__probability: + for k, v in sample.items(): + if len(v.shape) >= 2: + sample[k] = np.fliplr(v).copy() + + return sample + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class RandomCrop(object): + """Get a random crop of the sample with the given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_if_needed=False, + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): output width + height (int): output height + resize_if_needed (bool, optional): If True, sample might be upsampled to ensure + that a crop of size (width, height) is possbile. Defaults to False. + """ + self.__size = (height, width) + self.__resize_if_needed = resize_if_needed + self.__image_interpolation_method = image_interpolation_method + + def __call__(self, sample): + + shape = sample["disparity"].shape + + if self.__size[0] > shape[0] or self.__size[1] > shape[1]: + if self.__resize_if_needed: + shape = apply_min_size( + sample, self.__size, self.__image_interpolation_method + ) + else: + raise Exception( + "Output size {} bigger than input size {}.".format( + self.__size, shape + ) + ) + + offset = ( + np.random.randint(shape[0] - self.__size[0] + 1), + np.random.randint(shape[1] - self.__size[1] + 1), + ) + + for k, v in sample.items(): + if k == "code" or k == "basis": + continue + + if len(sample[k].shape) >= 2: + sample[k] = v[ + offset[0]: offset[0] + self.__size[0], + offset[1]: offset[1] + self.__size[1], + ] + + return sample + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + letter_box=False, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + self.__letter_box = letter_box + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def make_letter_box(self, sample): + top = bottom = (self.__height - sample.shape[0]) // 2 + left = right = (self.__width - sample.shape[1]) // 2 + sample = cv2.copyMakeBorder( + sample, top, bottom, left, right, cv2.BORDER_CONSTANT, None, 0) + return sample + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__letter_box: + sample["image"] = self.make_letter_box(sample["image"]) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if self.__letter_box: + sample["disparity"] = self.make_letter_box( + sample["disparity"]) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, + height), interpolation=cv2.INTER_NEAREST + ) + + if self.__letter_box: + sample["depth"] = self.make_letter_box(sample["depth"]) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if self.__letter_box: + sample["mask"] = self.make_letter_box(sample["mask"]) + + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class ResizeFixed(object): + def __init__(self, size): + self.__size = size + + def __call__(self, sample): + sample["image"] = cv2.resize( + sample["image"], self.__size[::-1], interpolation=cv2.INTER_LINEAR + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], self.__size[::- + 1], interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + self.__size[::-1], + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class Rescale(object): + """Rescale target values to the interval [0, max_val]. + If input is constant, values are set to max_val / 2. + """ + + def __init__(self, max_val=1.0, use_mask=True): + """Init. + + Args: + max_val (float, optional): Max output value. Defaults to 1.0. + use_mask (bool, optional): Only operate on valid pixels (mask == True). Defaults to True. + """ + self.__max_val = max_val + self.__use_mask = use_mask + + def __call__(self, sample): + disp = sample["disparity"] + + if self.__use_mask: + mask = sample["mask"] + else: + mask = np.ones_like(disp, dtype=np.bool) + + if np.sum(mask) == 0: + return sample + + min_val = np.min(disp[mask]) + max_val = np.max(disp[mask]) + + if max_val > min_val: + sample["disparity"][mask] = ( + (disp[mask] - min_val) / (max_val - min_val) * self.__max_val + ) + else: + sample["disparity"][mask] = np.ones_like( + disp[mask]) * self.__max_val / 2.0 + + return sample + + +# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class DepthToDisparity(object): + """Convert depth to disparity. Removes depth from sample. + """ + + def __init__(self, eps=1e-4): + self.__eps = eps + + def __call__(self, sample): + assert "depth" in sample + + sample["mask"][sample["depth"] < self.__eps] = False + + sample["disparity"] = np.zeros_like(sample["depth"]) + sample["disparity"][sample["depth"] >= self.__eps] = ( + 1.0 / sample["depth"][sample["depth"] >= self.__eps] + ) + + del sample["depth"] + + return sample + + +class DisparityToDepth(object): + """Convert disparity to depth. Removes disparity from sample. + """ + + def __init__(self, eps=1e-4): + self.__eps = eps + + def __call__(self, sample): + assert "disparity" in sample + + disp = np.abs(sample["disparity"]) + sample["mask"][disp < self.__eps] = False + + # print(sample["disparity"]) + # print(sample["mask"].sum()) + # exit() + + sample["depth"] = np.zeros_like(disp) + sample["depth"][disp >= self.__eps] = ( + 1.0 / disp[disp >= self.__eps] + ) + + del sample["disparity"] + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/data/vkitti.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/vkitti.py new file mode 100644 index 0000000000000000000000000000000000000000..72a2e5a8346f6e630ede0e28d6959725af8d7e72 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/vkitti.py @@ -0,0 +1,151 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms +import os + +from PIL import Image +import numpy as np +import cv2 + + +class ToTensor(object): + def __init__(self): + self.normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + # self.resize = transforms.Resize((375, 1242)) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + # image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "vkitti"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class VKITTI(Dataset): + def __init__(self, data_dir_root, do_kb_crop=True): + import glob + # image paths are of the form /{HR, LR}//{color, depth_filled}/*.png + self.image_files = glob.glob(os.path.join( + data_dir_root, "test_color", '*.png')) + self.depth_files = [r.replace("test_color", "test_depth") + for r in self.image_files] + self.do_kb_crop = True + self.transform = ToTensor() + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = Image.open(image_path) + depth = Image.open(depth_path) + depth = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR | + cv2.IMREAD_ANYDEPTH) + print("dpeth min max", depth.min(), depth.max()) + + # print(np.shape(image)) + # print(np.shape(depth)) + + # depth[depth > 8] = -1 + + if self.do_kb_crop and False: + height = image.height + width = image.width + top_margin = int(height - 352) + left_margin = int((width - 1216) / 2) + depth = depth.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + image = image.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + # uv = uv[:, top_margin:top_margin + 352, left_margin:left_margin + 1216] + + image = np.asarray(image, dtype=np.float32) / 255.0 + # depth = np.asarray(depth, dtype=np.uint16) /1. + depth = depth[..., None] + sample = dict(image=image, depth=depth) + + # return sample + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_vkitti_loader(data_dir_root, batch_size=1, **kwargs): + dataset = VKITTI(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) + + +if __name__ == "__main__": + loader = get_vkitti_loader( + data_dir_root="/home/bhatsf/shortcuts/datasets/vkitti_test") + print("Total files", len(loader.dataset)) + for i, sample in enumerate(loader): + print(sample["image"].shape) + print(sample["depth"].shape) + print(sample["dataset"]) + print(sample['depth'].min(), sample['depth'].max()) + if i > 5: + break diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/data/vkitti2.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/vkitti2.py new file mode 100644 index 0000000000000000000000000000000000000000..9bcfb0414b7f3f21859f30ae34bd71689516a3e7 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/data/vkitti2.py @@ -0,0 +1,187 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os + +import cv2 +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class ToTensor(object): + def __init__(self): + # self.normalize = transforms.Normalize( + # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.normalize = lambda x: x + # self.resize = transforms.Resize((375, 1242)) + + def __call__(self, sample): + image, depth = sample['image'], sample['depth'] + + image = self.to_tensor(image) + image = self.normalize(image) + depth = self.to_tensor(depth) + + # image = self.resize(image) + + return {'image': image, 'depth': depth, 'dataset': "vkitti"} + + def to_tensor(self, pic): + + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic.transpose((2, 0, 1))) + return img + + # # handle PIL Image + if pic.mode == 'I': + img = torch.from_numpy(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = torch.from_numpy(np.array(pic, np.int16, copy=False)) + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + elif pic.mode == 'I;16': + nchannel = 1 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + + img = img.transpose(0, 1).transpose(0, 2).contiguous() + if isinstance(img, torch.ByteTensor): + return img.float() + else: + return img + + +class VKITTI2(Dataset): + def __init__(self, data_dir_root, do_kb_crop=True, split="test"): + import glob + + # image paths are of the form /rgb///frames//Camera<0,1>/rgb_{}.jpg + self.image_files = glob.glob(os.path.join( + data_dir_root, "rgb", "**", "frames", "rgb", "Camera_0", '*.jpg'), recursive=True) + self.depth_files = [r.replace("/rgb/", "/depth/").replace( + "rgb_", "depth_").replace(".jpg", ".png") for r in self.image_files] + self.do_kb_crop = True + self.transform = ToTensor() + + # If train test split is not created, then create one. + # Split is such that 8% of the frames from each scene are used for testing. + if not os.path.exists(os.path.join(data_dir_root, "train.txt")): + import random + scenes = set([os.path.basename(os.path.dirname( + os.path.dirname(os.path.dirname(f)))) for f in self.image_files]) + train_files = [] + test_files = [] + for scene in scenes: + scene_files = [f for f in self.image_files if os.path.basename( + os.path.dirname(os.path.dirname(os.path.dirname(f)))) == scene] + random.shuffle(scene_files) + train_files.extend(scene_files[:int(len(scene_files) * 0.92)]) + test_files.extend(scene_files[int(len(scene_files) * 0.92):]) + with open(os.path.join(data_dir_root, "train.txt"), "w") as f: + f.write("\n".join(train_files)) + with open(os.path.join(data_dir_root, "test.txt"), "w") as f: + f.write("\n".join(test_files)) + + if split == "train": + with open(os.path.join(data_dir_root, "train.txt"), "r") as f: + self.image_files = f.read().splitlines() + self.depth_files = [r.replace("/rgb/", "/depth/").replace( + "rgb_", "depth_").replace(".jpg", ".png") for r in self.image_files] + elif split == "test": + with open(os.path.join(data_dir_root, "test.txt"), "r") as f: + self.image_files = f.read().splitlines() + self.depth_files = [r.replace("/rgb/", "/depth/").replace( + "rgb_", "depth_").replace(".jpg", ".png") for r in self.image_files] + + def __getitem__(self, idx): + image_path = self.image_files[idx] + depth_path = self.depth_files[idx] + + image = Image.open(image_path) + # depth = Image.open(depth_path) + depth = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR | + cv2.IMREAD_ANYDEPTH) / 100.0 # cm to m + depth = Image.fromarray(depth) + # print("dpeth min max", depth.min(), depth.max()) + + # print(np.shape(image)) + # print(np.shape(depth)) + + if self.do_kb_crop: + if idx == 0: + print("Using KB input crop") + height = image.height + width = image.width + top_margin = int(height - 352) + left_margin = int((width - 1216) / 2) + depth = depth.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + image = image.crop( + (left_margin, top_margin, left_margin + 1216, top_margin + 352)) + # uv = uv[:, top_margin:top_margin + 352, left_margin:left_margin + 1216] + + image = np.asarray(image, dtype=np.float32) / 255.0 + # depth = np.asarray(depth, dtype=np.uint16) /1. + depth = np.asarray(depth, dtype=np.float32) / 1. + depth[depth > 80] = -1 + + depth = depth[..., None] + sample = dict(image=image, depth=depth) + + # return sample + sample = self.transform(sample) + + if idx == 0: + print(sample["image"].shape) + + return sample + + def __len__(self): + return len(self.image_files) + + +def get_vkitti2_loader(data_dir_root, batch_size=1, **kwargs): + dataset = VKITTI2(data_dir_root) + return DataLoader(dataset, batch_size, **kwargs) + + +if __name__ == "__main__": + loader = get_vkitti2_loader( + data_dir_root="/home/bhatsf/shortcuts/datasets/vkitti2") + print("Total files", len(loader.dataset)) + for i, sample in enumerate(loader): + print(sample["image"].shape) + print(sample["depth"].shape) + print(sample["dataset"]) + print(sample['depth'].min(), sample['depth'].max()) + if i > 5: + break diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/__init__.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/__pycache__/__init__.cpython-310.pyc b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a3165c45c0aad14dcba9760e58c971e867fa26b Binary files /dev/null and b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/__pycache__/depth_model.cpython-310.pyc b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/__pycache__/depth_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2591da4fa93411b85fb311dc38143e8193958857 Binary files /dev/null and b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/__pycache__/depth_model.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/__pycache__/model_io.cpython-310.pyc b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/__pycache__/model_io.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0da4e5e40340b0c9ae868956d48be318e1630508 Binary files /dev/null and b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/__pycache__/model_io.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/__init__.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/__pycache__/__init__.cpython-310.pyc b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a83d40b9a0605cabfe31cdc2220a6e6ab36e958 Binary files /dev/null and b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/__pycache__/midas.cpython-310.pyc b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/__pycache__/midas.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3cfe0f0f2dfdc356e3ed2cb1fe0d0a5a6c80f48 Binary files /dev/null and b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/__pycache__/midas.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas.py new file mode 100644 index 0000000000000000000000000000000000000000..ee660bc93d44c28efe8d8c674e715ea2ecb4c183 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas.py @@ -0,0 +1,379 @@ +# MIT License +import os + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn +import numpy as np +from torchvision.transforms import Normalize + + +def denormalize(x): + """Reverses the imagenet normalization applied to the input. + + Args: + x (torch.Tensor - shape(N,3,H,W)): input tensor + + Returns: + torch.Tensor - shape(N,3,H,W): Denormalized input + """ + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device) + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device) + return x * std + mean + +def get_activation(name, bank): + def hook(model, input, output): + bank[name] = output + return hook + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + ): + """Init. + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + print("Params passed to Resize transform:") + print("\twidth: ", width) + print("\theight: ", height) + print("\tresize_target: ", resize_target) + print("\tkeep_aspect_ratio: ", keep_aspect_ratio) + print("\tensure_multiple_of: ", ensure_multiple_of) + print("\tresize_method: ", resize_method) + + self.__width = width + self.__height = height + + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, x): + width, height = self.get_size(*x.shape[-2:][::-1]) + return nn.functional.interpolate(x, (height, width), mode='bilinear', align_corners=True) + +class PrepForMidas(object): + def __init__(self, resize_mode="minimal", keep_aspect_ratio=True, img_size=384, do_resize=True): + if isinstance(img_size, int): + img_size = (img_size, img_size) + net_h, net_w = img_size + self.normalization = Normalize( + mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + self.resizer = Resize(net_w, net_h, keep_aspect_ratio=keep_aspect_ratio, ensure_multiple_of=32, resize_method=resize_mode) \ + if do_resize else nn.Identity() + + def __call__(self, x): + return self.normalization(self.resizer(x)) + + +class MidasCore(nn.Module): + def __init__(self, midas, trainable=False, fetch_features=True, layer_names=('out_conv', 'l4_rn', 'r4', 'r3', 'r2', 'r1'), freeze_bn=False, keep_aspect_ratio=True, + img_size=384, **kwargs): + """Midas Base model used for multi-scale feature extraction. + + Args: + midas (torch.nn.Module): Midas model. + trainable (bool, optional): Train midas model. Defaults to False. + fetch_features (bool, optional): Extract multi-scale features. Defaults to True. + layer_names (tuple, optional): Layers used for feature extraction. Order = (head output features, last layer features, ...decoder features). Defaults to ('out_conv', 'l4_rn', 'r4', 'r3', 'r2', 'r1'). + freeze_bn (bool, optional): Freeze BatchNorm. Generally results in better finetuning performance. Defaults to False. + keep_aspect_ratio (bool, optional): Keep the aspect ratio of input images while resizing. Defaults to True. + img_size (int, tuple, optional): Input resolution. Defaults to 384. + """ + super().__init__() + self.core = midas + self.output_channels = None + self.core_out = {} + self.trainable = trainable + self.fetch_features = fetch_features + # midas.scratch.output_conv = nn.Identity() + self.handles = [] + # self.layer_names = ['out_conv','l4_rn', 'r4', 'r3', 'r2', 'r1'] + self.layer_names = layer_names + + self.set_trainable(trainable) + self.set_fetch_features(fetch_features) + + self.prep = PrepForMidas(keep_aspect_ratio=keep_aspect_ratio, + img_size=img_size, do_resize=kwargs.get('do_resize', True)) + + if freeze_bn: + self.freeze_bn() + + def set_trainable(self, trainable): + self.trainable = trainable + if trainable: + self.unfreeze() + else: + self.freeze() + return self + + def set_fetch_features(self, fetch_features): + self.fetch_features = fetch_features + if fetch_features: + if len(self.handles) == 0: + self.attach_hooks(self.core) + else: + self.remove_hooks() + return self + + def freeze(self): + for p in self.parameters(): + p.requires_grad = False + self.trainable = False + return self + + def unfreeze(self): + for p in self.parameters(): + p.requires_grad = True + self.trainable = True + return self + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + return self + + def forward(self, x, denorm=False, return_rel_depth=False): + with torch.no_grad(): + if denorm: + x = denormalize(x) + x = self.prep(x) + # print("Shape after prep: ", x.shape) + + with torch.set_grad_enabled(self.trainable): + + # print("Input size to Midascore", x.shape) + rel_depth = self.core(x) + # print("Output from midas shape", rel_depth.shape) + if not self.fetch_features: + return rel_depth + out = [self.core_out[k] for k in self.layer_names] + + if return_rel_depth: + return rel_depth, out + return out + + def get_rel_pos_params(self): + for name, p in self.core.pretrained.named_parameters(): + if "relative_position" in name: + yield p + + def get_enc_params_except_rel_pos(self): + for name, p in self.core.pretrained.named_parameters(): + if "relative_position" not in name: + yield p + + def freeze_encoder(self, freeze_rel_pos=False): + if freeze_rel_pos: + for p in self.core.pretrained.parameters(): + p.requires_grad = False + else: + for p in self.get_enc_params_except_rel_pos(): + p.requires_grad = False + return self + + def attach_hooks(self, midas): + if len(self.handles) > 0: + self.remove_hooks() + if "out_conv" in self.layer_names: + self.handles.append(list(midas.scratch.output_conv.children())[ + 3].register_forward_hook(get_activation("out_conv", self.core_out))) + if "r4" in self.layer_names: + self.handles.append(midas.scratch.refinenet4.register_forward_hook( + get_activation("r4", self.core_out))) + if "r3" in self.layer_names: + self.handles.append(midas.scratch.refinenet3.register_forward_hook( + get_activation("r3", self.core_out))) + if "r2" in self.layer_names: + self.handles.append(midas.scratch.refinenet2.register_forward_hook( + get_activation("r2", self.core_out))) + if "r1" in self.layer_names: + self.handles.append(midas.scratch.refinenet1.register_forward_hook( + get_activation("r1", self.core_out))) + if "l4_rn" in self.layer_names: + self.handles.append(midas.scratch.layer4_rn.register_forward_hook( + get_activation("l4_rn", self.core_out))) + + return self + + def remove_hooks(self): + for h in self.handles: + h.remove() + return self + + def __del__(self): + self.remove_hooks() + + def set_output_channels(self, model_type): + self.output_channels = MIDAS_SETTINGS[model_type] + + @staticmethod + def build(midas_model_type="DPT_BEiT_L_384", train_midas=False, use_pretrained_midas=True, fetch_features=False, freeze_bn=True, force_keep_ar=False, force_reload=False, **kwargs): + if midas_model_type not in MIDAS_SETTINGS: + raise ValueError( + f"Invalid model type: {midas_model_type}. Must be one of {list(MIDAS_SETTINGS.keys())}") + if "img_size" in kwargs: + kwargs = MidasCore.parse_img_size(kwargs) + img_size = kwargs.pop("img_size", [384, 384]) + print("img_size", img_size) + midas_path = os.path.join(os.path.dirname(__file__), 'midas_repo') + midas = torch.hub.load(midas_path, midas_model_type, + pretrained=use_pretrained_midas, force_reload=force_reload, source='local') + kwargs.update({'keep_aspect_ratio': force_keep_ar}) + midas_core = MidasCore(midas, trainable=train_midas, fetch_features=fetch_features, + freeze_bn=freeze_bn, img_size=img_size, **kwargs) + midas_core.set_output_channels(midas_model_type) + return midas_core + + @staticmethod + def build_from_config(config): + return MidasCore.build(**config) + + @staticmethod + def parse_img_size(config): + assert 'img_size' in config + if isinstance(config['img_size'], str): + assert "," in config['img_size'], "img_size should be a string with comma separated img_size=H,W" + config['img_size'] = list(map(int, config['img_size'].split(","))) + assert len( + config['img_size']) == 2, "img_size should be a string with comma separated img_size=H,W" + elif isinstance(config['img_size'], int): + config['img_size'] = [config['img_size'], config['img_size']] + else: + assert isinstance(config['img_size'], list) and len( + config['img_size']) == 2, "img_size should be a list of H,W" + return config + + +nchannels2models = { + tuple([256]*5): ["DPT_BEiT_L_384", "DPT_BEiT_L_512", "DPT_BEiT_B_384", "DPT_SwinV2_L_384", "DPT_SwinV2_B_384", "DPT_SwinV2_T_256", "DPT_Large", "DPT_Hybrid"], + (512, 256, 128, 64, 64): ["MiDaS_small"] +} + +# Model name to number of output channels +MIDAS_SETTINGS = {m: k for k, v in nchannels2models.items() + for m in v + } diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/.gitignore b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..a13c80028de3d297de4a3f09cee1b20759acc006 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/.gitignore @@ -0,0 +1,110 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +*.png +*.pfm +*.jpg +*.jpeg +*.pt \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/Dockerfile b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..466bc94ba3128ea9cbe4bde82bd2fd1fc9daa8af --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/Dockerfile @@ -0,0 +1,29 @@ +# enables cuda support in docker +FROM nvidia/cuda:10.2-cudnn7-runtime-ubuntu18.04 + +# install python 3.6, pip and requirements for opencv-python +# (see https://github.com/NVIDIA/nvidia-docker/issues/864) +RUN apt-get update && apt-get -y install \ + python3 \ + python3-pip \ + libsm6 \ + libxext6 \ + libxrender-dev \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# install python dependencies +RUN pip3 install --upgrade pip +RUN pip3 install torch~=1.8 torchvision opencv-python-headless~=3.4 timm + +# copy inference code +WORKDIR /opt/MiDaS +COPY ./midas ./midas +COPY ./*.py ./ + +# download model weights so the docker image can be used offline +RUN cd weights && {curl -OL https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt; cd -; } +RUN python3 run.py --model_type dpt_hybrid; exit 0 + +# entrypoint (dont forget to mount input and output directories) +CMD python3 run.py --model_type dpt_hybrid diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/LICENSE b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..277b5c11be103f028a8d10985139f1da10c2f08e --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/README.md b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9568ea71c755b6938ee5482ba9f09be722e75943 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/README.md @@ -0,0 +1,259 @@ +## Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer + +This repository contains code to compute depth from a single image. It accompanies our [paper](https://arxiv.org/abs/1907.01341v3): + +>Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer +René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun + + +and our [preprint](https://arxiv.org/abs/2103.13413): + +> Vision Transformers for Dense Prediction +> René Ranftl, Alexey Bochkovskiy, Vladlen Koltun + + +MiDaS was trained on up to 12 datasets (ReDWeb, DIML, Movies, MegaDepth, WSVD, TartanAir, HRWSI, ApolloScape, BlendedMVS, IRS, KITTI, NYU Depth V2) with +multi-objective optimization. +The original model that was trained on 5 datasets (`MIX 5` in the paper) can be found [here](https://github.com/isl-org/MiDaS/releases/tag/v2). +The figure below shows an overview of the different MiDaS models; the bubble size scales with number of parameters. + +![](figures/Improvement_vs_FPS.png) + +### Setup + +1) Pick one or more models and download the corresponding weights to the `weights` folder: + +MiDaS 3.1 +- For highest quality: [dpt_beit_large_512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) +- For moderately less quality, but better speed-performance trade-off: [dpt_swin2_large_384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt) +- For embedded devices: [dpt_swin2_tiny_256](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt), [dpt_levit_224](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt) +- For inference on Intel CPUs, OpenVINO may be used for the small legacy model: openvino_midas_v21_small [.xml](https://github.com/isl-org/MiDaS/releases/download/v3_1/openvino_midas_v21_small_256.xml), [.bin](https://github.com/isl-org/MiDaS/releases/download/v3_1/openvino_midas_v21_small_256.bin) + +MiDaS 3.0: Legacy transformer models [dpt_large_384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt) and [dpt_hybrid_384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt) + +MiDaS 2.1: Legacy convolutional models [midas_v21_384](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt) and [midas_v21_small_256](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt) + +1) Set up dependencies: + + ```shell + conda env create -f environment.yaml + conda activate midas-py310 + ``` + +#### optional + +For the Next-ViT model, execute + +```shell +git submodule add https://github.com/isl-org/Next-ViT midas/external/next_vit +``` + +For the OpenVINO model, install + +```shell +pip install openvino +``` + +### Usage + +1) Place one or more input images in the folder `input`. + +2) Run the model with + + ```shell + python run.py --model_type --input_path input --output_path output + ``` + where `````` is chosen from [dpt_beit_large_512](#model_type), [dpt_beit_large_384](#model_type), + [dpt_beit_base_384](#model_type), [dpt_swin2_large_384](#model_type), [dpt_swin2_base_384](#model_type), + [dpt_swin2_tiny_256](#model_type), [dpt_swin_large_384](#model_type), [dpt_next_vit_large_384](#model_type), + [dpt_levit_224](#model_type), [dpt_large_384](#model_type), [dpt_hybrid_384](#model_type), + [midas_v21_384](#model_type), [midas_v21_small_256](#model_type), [openvino_midas_v21_small_256](#model_type). + +3) The resulting depth maps are written to the `output` folder. + +#### optional + +1) By default, the inference resizes the height of input images to the size of a model to fit into the encoder. This + size is given by the numbers in the model names of the [accuracy table](#accuracy). Some models do not only support a single + inference height but a range of different heights. Feel free to explore different heights by appending the extra + command line argument `--height`. Unsupported height values will throw an error. Note that using this argument may + decrease the model accuracy. +2) By default, the inference keeps the aspect ratio of input images when feeding them into the encoder if this is + supported by a model (all models except for Swin, Swin2, LeViT). In order to resize to a square resolution, + disregarding the aspect ratio while preserving the height, use the command line argument `--square`. + +#### via Camera + + If you want the input images to be grabbed from the camera and shown in a window, leave the input and output paths + away and choose a model type as shown above: + + ```shell + python run.py --model_type --side + ``` + + The argument `--side` is optional and causes both the input RGB image and the output depth map to be shown + side-by-side for comparison. + +#### via Docker + +1) Make sure you have installed Docker and the + [NVIDIA Docker runtime](https://github.com/NVIDIA/nvidia-docker/wiki/Installation-\(Native-GPU-Support\)). + +2) Build the Docker image: + + ```shell + docker build -t midas . + ``` + +3) Run inference: + + ```shell + docker run --rm --gpus all -v $PWD/input:/opt/MiDaS/input -v $PWD/output:/opt/MiDaS/output -v $PWD/weights:/opt/MiDaS/weights midas + ``` + + This command passes through all of your NVIDIA GPUs to the container, mounts the + `input` and `output` directories and then runs the inference. + +#### via PyTorch Hub + +The pretrained model is also available on [PyTorch Hub](https://pytorch.org/hub/intelisl_midas_v2/) + +#### via TensorFlow or ONNX + +See [README](https://github.com/isl-org/MiDaS/tree/master/tf) in the `tf` subdirectory. + +Currently only supports MiDaS v2.1. + + +#### via Mobile (iOS / Android) + +See [README](https://github.com/isl-org/MiDaS/tree/master/mobile) in the `mobile` subdirectory. + +#### via ROS1 (Robot Operating System) + +See [README](https://github.com/isl-org/MiDaS/tree/master/ros) in the `ros` subdirectory. + +Currently only supports MiDaS v2.1. DPT-based models to be added. + + +### Accuracy + +We provide a **zero-shot error** $\epsilon_d$ which is evaluated for 6 different datasets +(see [paper](https://arxiv.org/abs/1907.01341v3)). **Lower error values are better**. +$\color{green}{\textsf{Overall model quality is represented by the improvement}}$ ([Imp.](#improvement)) with respect to +MiDaS 3.0 DPTL-384. The models are grouped by the height used for inference, whereas the square training resolution is given by +the numbers in the model names. The table also shows the **number of parameters** (in millions) and the +**frames per second** for inference at the training resolution (for GPU RTX 3090): + +| MiDaS Model | DIW
WHDR | Eth3d
AbsRel | Sintel
AbsRel | TUM
δ1 | KITTI
δ1 | NYUv2
δ1 | $\color{green}{\textsf{Imp.}}$
% | Par.
M | FPS
  | +|-----------------------------------------------------------------------------------------------------------------------|-------------------------:|-----------------------------:|------------------------------:|-------------------------:|-------------------------:|-------------------------:|-------------------------------------------------:|----------------------:|--------------------------:| +| **Inference height 512** | | | | | | | | | | +| [v3.1 BEiTL-512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) | 0.1137 | 0.0659 | 0.2366 | **6.13** | 11.56* | **1.86*** | $\color{green}{\textsf{19}}$ | **345** | **5.7** | +| [v3.1 BEiTL-512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt)$\tiny{\square}$ | **0.1121** | **0.0614** | **0.2090** | 6.46 | **5.00*** | 1.90* | $\color{green}{\textsf{34}}$ | **345** | **5.7** | +| | | | | | | | | | | +| **Inference height 384** | | | | | | | | | | +| [v3.1 BEiTL-512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) | 0.1245 | 0.0681 | **0.2176** | **6.13** | 6.28* | **2.16*** | $\color{green}{\textsf{28}}$ | 345 | 12 | +| [v3.1 Swin2L-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt)$\tiny{\square}$ | 0.1106 | 0.0732 | 0.2442 | 8.87 | **5.84*** | 2.92* | $\color{green}{\textsf{22}}$ | 213 | 41 | +| [v3.1 Swin2B-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt)$\tiny{\square}$ | 0.1095 | 0.0790 | 0.2404 | 8.93 | 5.97* | 3.28* | $\color{green}{\textsf{22}}$ | 102 | 39 | +| [v3.1 SwinL-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin_large_384.pt)$\tiny{\square}$ | 0.1126 | 0.0853 | 0.2428 | 8.74 | 6.60* | 3.34* | $\color{green}{\textsf{17}}$ | 213 | 49 | +| [v3.1 BEiTL-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt) | 0.1239 | **0.0667** | 0.2545 | 7.17 | 9.84* | 2.21* | $\color{green}{\textsf{17}}$ | 344 | 13 | +| [v3.1 Next-ViTL-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_next_vit_large_384.pt) | **0.1031** | 0.0954 | 0.2295 | 9.21 | 6.89* | 3.47* | $\color{green}{\textsf{16}}$ | **72** | 30 | +| [v3.1 BEiTB-384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt) | 0.1159 | 0.0967 | 0.2901 | 9.88 | 26.60* | 3.91* | $\color{green}{\textsf{-31}}$ | 112 | 31 | +| [v3.0 DPTL-384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt) | 0.1082 | 0.0888 | 0.2697 | 9.97 | 8.46 | 8.32 | $\color{green}{\textsf{0}}$ | 344 | **61** | +| [v3.0 DPTH-384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt) | 0.1106 | 0.0934 | 0.2741 | 10.89 | 11.56 | 8.69 | $\color{green}{\textsf{-10}}$ | 123 | 50 | +| [v2.1 Large384](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt) | 0.1295 | 0.1155 | 0.3285 | 12.51 | 16.08 | 8.71 | $\color{green}{\textsf{-32}}$ | 105 | 47 | +| | | | | | | | | | | +| **Inference height 256** | | | | | | | | | | +| [v3.1 Swin2T-256](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt)$\tiny{\square}$ | **0.1211** | **0.1106** | **0.2868** | **13.43** | **10.13*** | **5.55*** | $\color{green}{\textsf{-11}}$ | 42 | 64 | +| [v2.1 Small256](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt) | 0.1344 | 0.1344 | 0.3370 | 14.53 | 29.27 | 13.43 | $\color{green}{\textsf{-76}}$ | **21** | **90** | +| | | | | | | | | | | +| **Inference height 224** | | | | | | | | | | +| [v3.1 LeViT224](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt)$\tiny{\square}$ | **0.1314** | **0.1206** | **0.3148** | **18.21** | **15.27*** | **8.64*** | $\color{green}{\textsf{-40}}$ | **51** | **73** | + +* No zero-shot error, because models are also trained on KITTI and NYU Depth V2\ +$\square$ Validation performed at **square resolution**, either because the transformer encoder backbone of a model +does not support non-square resolutions (Swin, Swin2, LeViT) or for comparison with these models. All other +validations keep the aspect ratio. A difference in resolution limits the comparability of the zero-shot error and the +improvement, because these quantities are averages over the pixels of an image and do not take into account the +advantage of more details due to a higher resolution.\ +Best values per column and same validation height in bold + +#### Improvement + +The improvement in the above table is defined as the relative zero-shot error with respect to MiDaS v3.0 +DPTL-384 and averaging over the datasets. So, if $\epsilon_d$ is the zero-shot error for dataset $d$, then +the $\color{green}{\textsf{improvement}}$ is given by $100(1-(1/6)\sum_d\epsilon_d/\epsilon_{d,\rm{DPT_{L-384}}})$%. + +Note that the improvements of 10% for MiDaS v2.0 → v2.1 and 21% for MiDaS v2.1 → v3.0 are not visible from the +improvement column (Imp.) in the table but would require an evaluation with respect to MiDaS v2.1 Large384 +and v2.0 Large384 respectively instead of v3.0 DPTL-384. + +### Depth map comparison + +Zoom in for better visibility +![](figures/Comparison.png) + +### Speed on Camera Feed + +Test configuration +- Windows 10 +- 11th Gen Intel Core i7-1185G7 3.00GHz +- 16GB RAM +- Camera resolution 640x480 +- openvino_midas_v21_small_256 + +Speed: 22 FPS + +### Changelog + +* [Dec 2022] Released MiDaS v3.1: + - New models based on 5 different types of transformers ([BEiT](https://arxiv.org/pdf/2106.08254.pdf), [Swin2](https://arxiv.org/pdf/2111.09883.pdf), [Swin](https://arxiv.org/pdf/2103.14030.pdf), [Next-ViT](https://arxiv.org/pdf/2207.05501.pdf), [LeViT](https://arxiv.org/pdf/2104.01136.pdf)) + - Training datasets extended from 10 to 12, including also KITTI and NYU Depth V2 using [BTS](https://github.com/cleinc/bts) split + - Best model, BEiTLarge 512, with resolution 512x512, is on average about [28% more accurate](#Accuracy) than MiDaS v3.0 + - Integrated live depth estimation from camera feed +* [Sep 2021] Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/DPT-Large). +* [Apr 2021] Released MiDaS v3.0: + - New models based on [Dense Prediction Transformers](https://arxiv.org/abs/2103.13413) are on average [21% more accurate](#Accuracy) than MiDaS v2.1 + - Additional models can be found [here](https://github.com/isl-org/DPT) +* [Nov 2020] Released MiDaS v2.1: + - New model that was trained on 10 datasets and is on average about [10% more accurate](#Accuracy) than [MiDaS v2.0](https://github.com/isl-org/MiDaS/releases/tag/v2) + - New light-weight model that achieves [real-time performance](https://github.com/isl-org/MiDaS/tree/master/mobile) on mobile platforms. + - Sample applications for [iOS](https://github.com/isl-org/MiDaS/tree/master/mobile/ios) and [Android](https://github.com/isl-org/MiDaS/tree/master/mobile/android) + - [ROS package](https://github.com/isl-org/MiDaS/tree/master/ros) for easy deployment on robots +* [Jul 2020] Added TensorFlow and ONNX code. Added [online demo](http://35.202.76.57/). +* [Dec 2019] Released new version of MiDaS - the new model is significantly more accurate and robust +* [Jul 2019] Initial release of MiDaS ([Link](https://github.com/isl-org/MiDaS/releases/tag/v1)) + +### Citation + +Please cite our paper if you use this code or any of the models: +``` +@ARTICLE {Ranftl2022, + author = "Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun", + title = "Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-Shot Cross-Dataset Transfer", + journal = "IEEE Transactions on Pattern Analysis and Machine Intelligence", + year = "2022", + volume = "44", + number = "3" +} +``` + +If you use a DPT-based model, please also cite: + +``` +@article{Ranftl2021, + author = {Ren\'{e} Ranftl and Alexey Bochkovskiy and Vladlen Koltun}, + title = {Vision Transformers for Dense Prediction}, + journal = {ICCV}, + year = {2021}, +} +``` + +### Acknowledgements + +Our work builds on and uses code from [timm](https://github.com/rwightman/pytorch-image-models) and [Next-ViT](https://github.com/bytedance/Next-ViT). +We'd like to thank the authors for making these libraries available. + +### License + +MIT License diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/environment.yaml b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b9abe5693b9e0de56b7d20728f4d0e6333c5822d --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/environment.yaml @@ -0,0 +1,16 @@ +name: midas-py310 +channels: + - pytorch + - defaults +dependencies: + - nvidia::cudatoolkit=11.7 + - python=3.10.8 + - pytorch::pytorch=1.13.0 + - torchvision=0.14.0 + - pip=22.3.1 + - numpy=1.23.4 + - pip: + - opencv-python==4.6.0.66 + - imutils==0.5.4 + - timm==0.6.12 + - einops==0.6.0 \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/hubconf.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/hubconf.py new file mode 100644 index 0000000000000000000000000000000000000000..0d638be5151c4e305daff0c47d1ea3fc8066377d --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/hubconf.py @@ -0,0 +1,435 @@ +dependencies = ["torch"] + +import torch + +from midas.dpt_depth import DPTDepthModel +from midas.midas_net import MidasNet +from midas.midas_net_custom import MidasNet_small + +def DPT_BEiT_L_512(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_BEiT_L_512 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="beitl16_512", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_BEiT_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_BEiT_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="beitl16_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_BEiT_B_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_BEiT_B_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="beitb16_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_SwinV2_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_SwinV2_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swin2l24_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_SwinV2_B_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_SwinV2_B_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swin2b24_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_SwinV2_T_256(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_SwinV2_T_256 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swin2t16_256", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Swin_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_Swin_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="swinl12_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Next_ViT_L_384(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_Next_ViT_L_384 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="next_vit_large_6m", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_next_vit_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_LeViT_224(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT_LeViT_224 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="levit_384", + non_negative=True, + head_features_1=64, + head_features_2=8, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Large(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT-Large model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="vitl16_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def DPT_Hybrid(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS DPT-Hybrid model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = DPTDepthModel( + path=None, + backbone="vitb_rn50_384", + non_negative=True, + ) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def MiDaS(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS v2.1 model for monocular depth estimation + pretrained (bool): load pretrained weights into model + """ + + model = MidasNet() + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + +def MiDaS_small(pretrained=True, **kwargs): + """ # This docstring shows up in hub.help() + MiDaS v2.1 small model for monocular depth estimation on resource-constrained devices + pretrained (bool): load pretrained weights into model + """ + + model = MidasNet_small(None, features=64, backbone="efficientnet_lite3", exportable=True, non_negative=True, blocks={'expand': True}) + + if pretrained: + checkpoint = ( + "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt" + ) + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True + ) + model.load_state_dict(state_dict) + + return model + + +def transforms(): + import cv2 + from torchvision.transforms import Compose + from midas.transforms import Resize, NormalizeImage, PrepareForNet + from midas import transforms + + transforms.default_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 384, + 384, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="upper_bound", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.small_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 256, + 256, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="upper_bound", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.dpt_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 384, + 384, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.beit512_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 512, + 512, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.swin384_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 384, + 384, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.swin256_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 256, + 256, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + transforms.levit_transform = Compose( + [ + lambda img: {"image": img / 255.0}, + Resize( + 224, + 224, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + PrepareForNet(), + lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0), + ] + ) + + return transforms diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/input/.placeholder b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/input/.placeholder new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/beit.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/beit.py new file mode 100644 index 0000000000000000000000000000000000000000..e616dfd4026f448f9e22d35c6ad8b0028732acb9 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/beit.py @@ -0,0 +1,196 @@ +import timm +import torch +import types + +import numpy as np +import torch.nn.functional as F + +from .utils import forward_adapted_unflatten, make_backbone_default +from timm.models.beit import gen_relative_position_index +from torch.utils.checkpoint import checkpoint +from typing import Optional + + +def forward_beit(pretrained, x): + return forward_adapted_unflatten(pretrained, x, "forward_features") + + +def patch_embed_forward(self, x): + """ + Modification of timm.models.layers.patch_embed.py: PatchEmbed.forward to support arbitrary window sizes. + """ + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + return x + + +def _get_rel_pos_bias(self, window_size): + """ + Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes. + """ + old_height = 2 * self.window_size[0] - 1 + old_width = 2 * self.window_size[1] - 1 + + new_height = 2 * window_size[0] - 1 + new_width = 2 * window_size[1] - 1 + + old_relative_position_bias_table = self.relative_position_bias_table + + old_num_relative_distance = self.num_relative_distance + new_num_relative_distance = new_height * new_width + 3 + + old_sub_table = old_relative_position_bias_table[:old_num_relative_distance - 3] + + old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2) + new_sub_table = F.interpolate(old_sub_table, size=(new_height, new_width), mode="bilinear") + new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1) + + new_relative_position_bias_table = torch.cat( + [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3:]]) + + key = str(window_size[1]) + "," + str(window_size[0]) + if key not in self.relative_position_indices.keys(): + self.relative_position_indices[key] = gen_relative_position_index(window_size) + + relative_position_bias = new_relative_position_bias_table[ + self.relative_position_indices[key].view(-1)].view( + window_size[0] * window_size[1] + 1, + window_size[0] * window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + return relative_position_bias.unsqueeze(0) + + +def attention_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None): + """ + Modification of timm.models.beit.py: Attention.forward to support arbitrary window sizes. + """ + B, N, C = x.shape + + qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + window_size = tuple(np.array(resolution) // 16) + attn = attn + self._get_rel_pos_bias(window_size) + if shared_rel_pos_bias is not None: + attn = attn + shared_rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def block_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None): + """ + Modification of timm.models.beit.py: Block.forward to support arbitrary window sizes. + """ + if self.gamma_1 is None: + x = x + self.drop_path1(self.attn(self.norm1(x), resolution, shared_rel_pos_bias=shared_rel_pos_bias)) + x = x + self.drop_path2(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), resolution, + shared_rel_pos_bias=shared_rel_pos_bias)) + x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +def beit_forward_features(self, x): + """ + Modification of timm.models.beit.py: Beit.forward_features to support arbitrary window sizes. + """ + resolution = x.shape[2:] + + x = self.patch_embed(x) + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias) + else: + x = blk(x, resolution, shared_rel_pos_bias=rel_pos_bias) + x = self.norm(x) + return x + + +def _make_beit_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[0, 4, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, + start_index_readout=1, +): + backbone = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index, + start_index_readout) + + backbone.model.patch_embed.forward = types.MethodType(patch_embed_forward, backbone.model.patch_embed) + backbone.model.forward_features = types.MethodType(beit_forward_features, backbone.model) + + for block in backbone.model.blocks: + attn = block.attn + attn._get_rel_pos_bias = types.MethodType(_get_rel_pos_bias, attn) + attn.forward = types.MethodType(attention_forward, attn) + attn.relative_position_indices = {} + + block.forward = types.MethodType(block_forward, block) + + return backbone + + +def _make_pretrained_beitl16_512(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("beit_large_patch16_512", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks is None else hooks + + features = [256, 512, 1024, 1024] + + return _make_beit_backbone( + model, + features=features, + size=[512, 512], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_beitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("beit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks is None else hooks + return _make_beit_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_beitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("beit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks is None else hooks + return _make_beit_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + ) diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/levit.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/levit.py new file mode 100644 index 0000000000000000000000000000000000000000..6d023a98702a0451806d26f33f8bccf931814f10 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/levit.py @@ -0,0 +1,106 @@ +import timm +import torch +import torch.nn as nn +import numpy as np + +from .utils import activations, get_activation, Transpose + + +def forward_levit(pretrained, x): + pretrained.model.forward_features(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + + layer_1 = pretrained.act_postprocess1(layer_1) + layer_2 = pretrained.act_postprocess2(layer_2) + layer_3 = pretrained.act_postprocess3(layer_3) + + return layer_1, layer_2, layer_3 + + +def _make_levit_backbone( + model, + hooks=[3, 11, 21], + patch_grid=[14, 14] +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + + pretrained.activations = activations + + patch_grid_size = np.array(patch_grid, dtype=int) + + pretrained.act_postprocess1 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) + ) + pretrained.act_postprocess2 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist())) + ) + pretrained.act_postprocess3 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist())) + ) + + return pretrained + + +class ConvTransposeNorm(nn.Sequential): + """ + Modification of + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm + such that ConvTranspose2d is used instead of Conv2d. + """ + + def __init__( + self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1, + groups=1, bn_weight_init=1): + super().__init__() + self.add_module('c', + nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False)) + self.add_module('bn', nn.BatchNorm2d(out_chs)) + + nn.init.constant_(self.bn.weight, bn_weight_init) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 + m = nn.ConvTranspose2d( + w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, + padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +def stem_b4_transpose(in_chs, out_chs, activation): + """ + Modification of + https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: stem_b16 + such that ConvTranspose2d is used instead of Conv2d and stem is also reduced to the half. + """ + return nn.Sequential( + ConvTransposeNorm(in_chs, out_chs, 3, 2, 1), + activation(), + ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1), + activation()) + + +def _make_pretrained_levit_384(pretrained, hooks=None): + model = timm.create_model("levit_384", pretrained=pretrained) + + hooks = [3, 11, 21] if hooks == None else hooks + return _make_levit_backbone( + model, + hooks=hooks + ) diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/next_vit.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/next_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..8afdd8b743b5ab023a359dc3b721e601b1a40d11 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/next_vit.py @@ -0,0 +1,39 @@ +import timm + +import torch.nn as nn + +from pathlib import Path +from .utils import activations, forward_default, get_activation + +from ..external.next_vit.classification.nextvit import * + + +def forward_next_vit(pretrained, x): + return forward_default(pretrained, x, "forward") + + +def _make_next_vit_backbone( + model, + hooks=[2, 6, 36, 39], +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.features[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.features[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.features[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.features[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + return pretrained + + +def _make_pretrained_next_vit_large_6m(hooks=None): + model = timm.create_model("nextvit_large") + + hooks = [2, 6, 36, 39] if hooks == None else hooks + return _make_next_vit_backbone( + model, + hooks=hooks, + ) diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin.py new file mode 100644 index 0000000000000000000000000000000000000000..f8c71367e3e78b087f80b2ab3e2f495a9c372f1a --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin.py @@ -0,0 +1,13 @@ +import timm + +from .swin_common import _make_swin_backbone + + +def _make_pretrained_swinl12_384(pretrained, hooks=None): + model = timm.create_model("swin_large_patch4_window12_384", pretrained=pretrained) + + hooks = [1, 1, 17, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks + ) diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin2.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin2.py new file mode 100644 index 0000000000000000000000000000000000000000..ce4c8f1d6fc1807a207dc6b9a261c6f7b14a87a3 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin2.py @@ -0,0 +1,34 @@ +import timm + +from .swin_common import _make_swin_backbone + + +def _make_pretrained_swin2l24_384(pretrained, hooks=None): + model = timm.create_model("swinv2_large_window12to24_192to384_22kft1k", pretrained=pretrained) + + hooks = [1, 1, 17, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks + ) + + +def _make_pretrained_swin2b24_384(pretrained, hooks=None): + model = timm.create_model("swinv2_base_window12to24_192to384_22kft1k", pretrained=pretrained) + + hooks = [1, 1, 17, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks + ) + + +def _make_pretrained_swin2t16_256(pretrained, hooks=None): + model = timm.create_model("swinv2_tiny_window16_256", pretrained=pretrained) + + hooks = [1, 1, 5, 1] if hooks == None else hooks + return _make_swin_backbone( + model, + hooks=hooks, + patch_grid=[64, 64] + ) diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin_common.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin_common.py new file mode 100644 index 0000000000000000000000000000000000000000..94d63d408f18511179d90b3ac6f697385d1e556d --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/swin_common.py @@ -0,0 +1,52 @@ +import torch + +import torch.nn as nn +import numpy as np + +from .utils import activations, forward_default, get_activation, Transpose + + +def forward_swin(pretrained, x): + return forward_default(pretrained, x) + + +def _make_swin_backbone( + model, + hooks=[1, 1, 17, 1], + patch_grid=[96, 96] +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.layers[0].blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.layers[1].blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.layers[2].blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.layers[3].blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + if hasattr(model, "patch_grid"): + used_patch_grid = model.patch_grid + else: + used_patch_grid = patch_grid + + patch_grid_size = np.array(used_patch_grid, dtype=int) + + pretrained.act_postprocess1 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) + ) + pretrained.act_postprocess2 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((patch_grid_size // 2).tolist())) + ) + pretrained.act_postprocess3 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((patch_grid_size // 4).tolist())) + ) + pretrained.act_postprocess4 = nn.Sequential( + Transpose(1, 2), + nn.Unflatten(2, torch.Size((patch_grid_size // 8).tolist())) + ) + + return pretrained diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/utils.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0558899dddcfccec5f01a764d4f21738eb612149 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/utils.py @@ -0,0 +1,249 @@ +import torch + +import torch.nn as nn + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index:] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index:] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:]) + features = torch.cat((x[:, self.start_index:], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def forward_default(pretrained, x, function_name="forward_features"): + exec(f"pretrained.model.{function_name}(x)") + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + if hasattr(pretrained, "act_postprocess1"): + layer_1 = pretrained.act_postprocess1(layer_1) + if hasattr(pretrained, "act_postprocess2"): + layer_2 = pretrained.act_postprocess2(layer_2) + if hasattr(pretrained, "act_postprocess3"): + layer_3 = pretrained.act_postprocess3(layer_3) + if hasattr(pretrained, "act_postprocess4"): + layer_4 = pretrained.act_postprocess4(layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def forward_adapted_unflatten(pretrained, x, function_name="forward_features"): + b, c, h, w = x.shape + + exec(f"glob = pretrained.model.{function_name}(x)") + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3: len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3: len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3: len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3: len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def make_backbone_default( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, + start_index_readout=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index_readout) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + return pretrained diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/vit.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..413f9693bd4548342280e329c9128c1a52cea920 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/backbones/vit.py @@ -0,0 +1,221 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + +from .utils import (activations, forward_adapted_unflatten, get_activation, get_readout_oper, + make_backbone_default, Transpose) + + +def forward_vit(pretrained, x): + return forward_adapted_unflatten(pretrained, x, "forward_flex") + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index:], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + if self.no_embed_class: + x = x + pos_embed + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + if not self.no_embed_class: + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, + start_index_readout=1, +): + pretrained = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index, + start_index_readout) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + patch_size=[16, 16], + number_stages=2, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + used_number_stages = 0 if use_vit_only else number_stages + for s in range(used_number_stages): + pretrained.model.patch_embed.backbone.stages[s].register_forward_hook( + get_activation(str(s + 1)) + ) + for s in range(used_number_stages, 4): + pretrained.model.blocks[hooks[s]].register_forward_hook(get_activation(str(s + 1))) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + for s in range(used_number_stages): + value = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity()) + exec(f"pretrained.act_postprocess{s + 1}=value") + for s in range(used_number_stages, 4): + if s < number_stages: + final_layer = nn.ConvTranspose2d( + in_channels=features[s], + out_channels=features[s], + kernel_size=4 // (2 ** s), + stride=4 // (2 ** s), + padding=0, + bias=True, + dilation=1, + groups=1, + ) + elif s > number_stages: + final_layer = nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ) + else: + final_layer = None + + layers = [ + readout_oper[s], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[s], + kernel_size=1, + stride=1, + padding=0, + ), + ] + if final_layer is not None: + layers.append(final_layer) + + value = nn.Sequential(*layers) + exec(f"pretrained.act_postprocess{s + 1}=value") + + pretrained.model.start_index = start_index + pretrained.model.patch_size = patch_size + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/base_model.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/blocks.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..6d87a00680bb6ed9a6d7c3043ea30a1e90361794 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/blocks.py @@ -0,0 +1,439 @@ +import torch +import torch.nn as nn + +from .backbones.beit import ( + _make_pretrained_beitl16_512, + _make_pretrained_beitl16_384, + _make_pretrained_beitb16_384, + forward_beit, +) +from .backbones.swin_common import ( + forward_swin, +) +from .backbones.swin2 import ( + _make_pretrained_swin2l24_384, + _make_pretrained_swin2b24_384, + _make_pretrained_swin2t16_256, +) +from .backbones.swin import ( + _make_pretrained_swinl12_384, +) +from .backbones.levit import ( + _make_pretrained_levit_384, + forward_levit, +) +from .backbones.vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, + use_vit_only=False, use_readout="ignore", in_features=[96, 256, 512, 1024]): + if backbone == "beitl16_512": + pretrained = _make_pretrained_beitl16_512( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # BEiT_512-L (backbone) + elif backbone == "beitl16_384": + pretrained = _make_pretrained_beitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # BEiT_384-L (backbone) + elif backbone == "beitb16_384": + pretrained = _make_pretrained_beitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # BEiT_384-B (backbone) + elif backbone == "swin2l24_384": + pretrained = _make_pretrained_swin2l24_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [192, 384, 768, 1536], features, groups=groups, expand=expand + ) # Swin2-L/12to24 (backbone) + elif backbone == "swin2b24_384": + pretrained = _make_pretrained_swin2b24_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [128, 256, 512, 1024], features, groups=groups, expand=expand + ) # Swin2-B/12to24 (backbone) + elif backbone == "swin2t16_256": + pretrained = _make_pretrained_swin2t16_256( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # Swin2-T/16 (backbone) + elif backbone == "swinl12_384": + pretrained = _make_pretrained_swinl12_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [192, 384, 768, 1536], features, groups=groups, expand=expand + ) # Swin-L/12 (backbone) + elif backbone == "next_vit_large_6m": + from .backbones.next_vit import _make_pretrained_next_vit_large_6m + pretrained = _make_pretrained_next_vit_large_6m(hooks=hooks) + scratch = _make_scratch( + in_features, features, groups=groups, expand=expand + ) # Next-ViT-L on ImageNet-1K-6M (backbone) + elif backbone == "levit_384": + pretrained = _make_pretrained_levit_384( + use_pretrained, hooks=hooks + ) + scratch = _make_scratch( + [384, 512, 768], features, groups=groups, expand=expand + ) # LeViT 384 (backbone) + elif backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + if len(in_shape) >= 4: + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size=size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate( + output, **modifier, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/dpt_depth.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/dpt_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..3129d09cb43a7c79b23916236991fabbedb78f55 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/dpt_depth.py @@ -0,0 +1,166 @@ +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_beit, + forward_swin, + forward_levit, + forward_vit, +) +from .backbones.levit import stem_b4_transpose +from timm.models.layers import get_act_layer + + +def _make_fusion_block(features, use_bn, size = None): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + **kwargs + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + # For the Swin, Swin 2, LeViT and Next-ViT Transformers, the hierarchical architectures prevent setting the + # hooks freely. Instead, the hooks have to be chosen according to the ranges specified in the comments. + hooks = { + "beitl16_512": [5, 11, 17, 23], + "beitl16_384": [5, 11, 17, 23], + "beitb16_384": [2, 5, 8, 11], + "swin2l24_384": [1, 1, 17, 1], # Allowed ranges: [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "swin2b24_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "swin2t16_256": [1, 1, 5, 1], # [0, 1], [0, 1], [ 0, 5], [ 0, 1] + "swinl12_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1] + "next_vit_large_6m": [2, 6, 36, 39], # [0, 2], [3, 6], [ 7, 36], [37, 39] + "levit_384": [3, 11, 21], # [0, 3], [6, 11], [14, 21] + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + }[backbone] + + if "next_vit" in backbone: + in_features = { + "next_vit_large_6m": [96, 256, 512, 1024], + }[backbone] + else: + in_features = None + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks, + use_readout=readout, + in_features=in_features, + ) + + self.number_layers = len(hooks) if hooks is not None else 4 + size_refinenet3 = None + self.scratch.stem_transpose = None + + if "beit" in backbone: + self.forward_transformer = forward_beit + elif "swin" in backbone: + self.forward_transformer = forward_swin + elif "next_vit" in backbone: + from .backbones.next_vit import forward_next_vit + self.forward_transformer = forward_next_vit + elif "levit" in backbone: + self.forward_transformer = forward_levit + size_refinenet3 = 7 + self.scratch.stem_transpose = stem_b4_transpose(256, 128, get_act_layer("hard_swish")) + else: + self.forward_transformer = forward_vit + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn, size_refinenet3) + if self.number_layers >= 4: + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layers = self.forward_transformer(self.pretrained, x) + if self.number_layers == 3: + layer_1, layer_2, layer_3 = layers + else: + layer_1, layer_2, layer_3, layer_4 = layers + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + if self.number_layers >= 4: + layer_4_rn = self.scratch.layer4_rn(layer_4) + + if self.number_layers == 3: + path_3 = self.scratch.refinenet3(layer_3_rn, size=layer_2_rn.shape[2:]) + else: + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + if self.scratch.stem_transpose is not None: + path_1 = self.scratch.stem_transpose(path_1) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + head_features_1 = kwargs["head_features_1"] if "head_features_1" in kwargs else features + head_features_2 = kwargs["head_features_2"] if "head_features_2" in kwargs else 32 + kwargs.pop("head_features_1", None) + kwargs.pop("head_features_2", None) + + head = nn.Sequential( + nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net.py new file mode 100644 index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net_custom.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/model_loader.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/model_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..f1cd1f2d43054bfd3d650587c7b2ed35f1347c9e --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/model_loader.py @@ -0,0 +1,242 @@ +import cv2 +import torch + +from midas.dpt_depth import DPTDepthModel +from midas.midas_net import MidasNet +from midas.midas_net_custom import MidasNet_small +from midas.transforms import Resize, NormalizeImage, PrepareForNet + +from torchvision.transforms import Compose + +default_models = { + "dpt_beit_large_512": "weights/dpt_beit_large_512.pt", + "dpt_beit_large_384": "weights/dpt_beit_large_384.pt", + "dpt_beit_base_384": "weights/dpt_beit_base_384.pt", + "dpt_swin2_large_384": "weights/dpt_swin2_large_384.pt", + "dpt_swin2_base_384": "weights/dpt_swin2_base_384.pt", + "dpt_swin2_tiny_256": "weights/dpt_swin2_tiny_256.pt", + "dpt_swin_large_384": "weights/dpt_swin_large_384.pt", + "dpt_next_vit_large_384": "weights/dpt_next_vit_large_384.pt", + "dpt_levit_224": "weights/dpt_levit_224.pt", + "dpt_large_384": "weights/dpt_large_384.pt", + "dpt_hybrid_384": "weights/dpt_hybrid_384.pt", + "midas_v21_384": "weights/midas_v21_384.pt", + "midas_v21_small_256": "weights/midas_v21_small_256.pt", + "openvino_midas_v21_small_256": "weights/openvino_midas_v21_small_256.xml", +} + + +def load_model(device, model_path, model_type="dpt_large_384", optimize=True, height=None, square=False): + """Load the specified network. + + Args: + device (device): the torch device used + model_path (str): path to saved model + model_type (str): the type of the model to be loaded + optimize (bool): optimize the model to half-integer on CUDA? + height (int): inference encoder image height + square (bool): resize to a square resolution? + + Returns: + The loaded network, the transform which prepares images as input to the network and the dimensions of the + network input + """ + if "openvino" in model_type: + from openvino.runtime import Core + + keep_aspect_ratio = not square + + if model_type == "dpt_beit_large_512": + model = DPTDepthModel( + path=model_path, + backbone="beitl16_512", + non_negative=True, + ) + net_w, net_h = 512, 512 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_beit_large_384": + model = DPTDepthModel( + path=model_path, + backbone="beitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_beit_base_384": + model = DPTDepthModel( + path=model_path, + backbone="beitb16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin2_large_384": + model = DPTDepthModel( + path=model_path, + backbone="swin2l24_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin2_base_384": + model = DPTDepthModel( + path=model_path, + backbone="swin2b24_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin2_tiny_256": + model = DPTDepthModel( + path=model_path, + backbone="swin2t16_256", + non_negative=True, + ) + net_w, net_h = 256, 256 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_swin_large_384": + model = DPTDepthModel( + path=model_path, + backbone="swinl12_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_next_vit_large_384": + model = DPTDepthModel( + path=model_path, + backbone="next_vit_large_6m", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + # We change the notation from dpt_levit_224 (MiDaS notation) to levit_384 (timm notation) here, where the 224 refers + # to the resolution 224x224 used by LeViT and 384 is the first entry of the embed_dim, see _cfg and model_cfgs of + # https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/levit.py + # (commit id: 927f031293a30afb940fff0bee34b85d9c059b0e) + elif model_type == "dpt_levit_224": + model = DPTDepthModel( + path=model_path, + backbone="levit_384", + non_negative=True, + head_features_1=64, + head_features_2=8, + ) + net_w, net_h = 224, 224 + keep_aspect_ratio = False + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_large_384": + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid_384": + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21_384": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small_256": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "openvino_midas_v21_small_256": + ie = Core() + uncompiled_model = ie.read_model(model=model_path) + model = ie.compile_model(uncompiled_model, "CPU") + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + if not "openvino" in model_type: + print("Model loaded, number of parameters = {:.0f}M".format(sum(p.numel() for p in model.parameters()) / 1e6)) + else: + print("Model loaded, optimized with OpenVINO") + + if "openvino" in model_type: + keep_aspect_ratio = False + + if height is not None: + net_w, net_h = height, height + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=keep_aspect_ratio, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + if not "openvino" in model_type: + model.eval() + + if optimize and (device == torch.device("cuda")): + if not "openvino" in model_type: + model = model.to(memory_format=torch.channels_last) + model = model.half() + else: + print("Error: OpenVINO models are already optimized. No optimization to half-float possible.") + exit() + + if not "openvino" in model_type: + model.to(device) + + return model, transform, net_w, net_h diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/transforms.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/README.md b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/README.md new file mode 100644 index 0000000000000000000000000000000000000000..45c18f7f0bfe40c0db373e8a94716867705f5827 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/README.md @@ -0,0 +1,70 @@ +## Mobile version of MiDaS for iOS / Android - Monocular Depth Estimation + +### Accuracy + +* Old small model - ResNet50 default-decoder 384x384 +* New small model - EfficientNet-Lite3 small-decoder 256x256 + +**Zero-shot error** (the lower - the better): + +| Model | DIW WHDR | Eth3d AbsRel | Sintel AbsRel | Kitti δ>1.25 | NyuDepthV2 δ>1.25 | TUM δ>1.25 | +|---|---|---|---|---|---|---| +| Old small model 384x384 | **0.1248** | 0.1550 | **0.3300** | **21.81** | 15.73 | 17.00 | +| New small model 256x256 | 0.1344 | **0.1344** | 0.3370 | 29.27 | **13.43** | **14.53** | +| Relative improvement, % | -8 % | **+13 %** | -2 % | -34 % | **+15 %** | **+15 %** | + +None of Train/Valid/Test subsets of datasets (DIW, Eth3d, Sintel, Kitti, NyuDepthV2, TUM) were not involved in Training or Fine Tuning. + +### Inference speed (FPS) on iOS / Android + +**Frames Per Second** (the higher - the better): + +| Model | iPhone CPU | iPhone GPU | iPhone NPU | OnePlus8 CPU | OnePlus8 GPU | OnePlus8 NNAPI | +|---|---|---|---|---|---|---| +| Old small model 384x384 | 0.6 | N/A | N/A | 0.45 | 0.50 | 0.50 | +| New small model 256x256 | 8 | 22 | **30** | 6 | **22** | 4 | +| SpeedUp, X times | **12.8x** | - | - | **13.2x** | **44x** | **8x** | + +N/A - run-time error (no data available) + + +#### Models: + +* Old small model - ResNet50 default-decoder 1x384x384x3, batch=1 FP32 (converters: Pytorch -> ONNX - [onnx_tf](https://github.com/onnx/onnx-tensorflow) -> (saved model) PB -> TFlite) + + (Trained on datasets: RedWeb, MegaDepth, WSVD, 3D Movies, DIML indoor) + +* New small model - EfficientNet-Lite3 small-decoder 1x256x256x3, batch=1 FP32 (custom converter: Pytorch -> TFlite) + + (Trained on datasets: RedWeb, MegaDepth, WSVD, 3D Movies, DIML indoor, HRWSI, IRS, TartanAir, BlendedMVS, ApolloScape) + +#### Frameworks for training and conversions: +``` +pip install torch==1.6.0 torchvision==0.7.0 +pip install tf-nightly-gpu==2.5.0.dev20201031 tensorflow-addons==0.11.2 numpy==1.18.0 +git clone --depth 1 --branch v1.6.0 https://github.com/onnx/onnx-tensorflow +``` + +#### SoC - OS - Library: + +* iPhone 11 (A13 Bionic) - iOS 13.7 - TensorFlowLiteSwift 0.0.1-nightly +* OnePlus 8 (Snapdragon 865) - Andoird 10 - org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly + + +### Citation + +This repository contains code to compute depth from a single image. It accompanies our [paper](https://arxiv.org/abs/1907.01341v3): + +>Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer +René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun + +Please cite our paper if you use this code or any of the models: +``` +@article{Ranftl2020, + author = {Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun}, + title = {Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer}, + journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)}, + year = {2020}, +} +``` + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/.gitignore b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..2fbe357549c64ae2966d5c3013a9179427b7b396 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/.gitignore @@ -0,0 +1,13 @@ +*.iml +.gradle +/local.properties +/.idea/libraries +/.idea/modules.xml +/.idea/workspace.xml +.DS_Store +/build +/captures +.externalNativeBuild + +/.gradle/ +/.idea/ \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/EXPLORE_THE_CODE.md b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/EXPLORE_THE_CODE.md new file mode 100644 index 0000000000000000000000000000000000000000..72014bdfa2cd701a6453debbc8e53fcc15c0a5dc --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/EXPLORE_THE_CODE.md @@ -0,0 +1,414 @@ +# TensorFlow Lite Android image classification example + +This document walks through the code of a simple Android mobile application that +demonstrates +[image classification](https://www.tensorflow.org/lite/models/image_classification/overview) +using the device camera. + +## Explore the code + +We're now going to walk through the most important parts of the sample code. + +### Get camera input + +This mobile application gets the camera input using the functions defined in the +file +[`CameraActivity.java`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java). +This file depends on +[`AndroidManifest.xml`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/app/src/main/AndroidManifest.xml) +to set the camera orientation. + +`CameraActivity` also contains code to capture user preferences from the UI and +make them available to other classes via convenience methods. + +```java +model = Model.valueOf(modelSpinner.getSelectedItem().toString().toUpperCase()); +device = Device.valueOf(deviceSpinner.getSelectedItem().toString()); +numThreads = Integer.parseInt(threadsTextView.getText().toString().trim()); +``` + +### Classifier + +This Image Classification Android reference app demonstrates two implementation +solutions, +[`lib_task_api`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_task_api) +that leverages the out-of-box API from the +[TensorFlow Lite Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/image_classifier), +and +[`lib_support`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_support) +that creates the custom inference pipleline using the +[TensorFlow Lite Support Library](https://www.tensorflow.org/lite/inference_with_metadata/lite_support). + +Both solutions implement the file `Classifier.java` (see +[the one in lib_task_api](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java) +and +[the one in lib_support](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java)) +that contains most of the complex logic for processing the camera input and +running inference. + +Two subclasses of the `Classifier` exist, as in `ClassifierFloatMobileNet.java` +and `ClassifierQuantizedMobileNet.java`, which contain settings for both +floating point and +[quantized](https://www.tensorflow.org/lite/performance/post_training_quantization) +models. + +The `Classifier` class implements a static method, `create`, which is used to +instantiate the appropriate subclass based on the supplied model type (quantized +vs floating point). + +#### Using the TensorFlow Lite Task Library + +Inference can be done using just a few lines of code with the +[`ImageClassifier`](https://www.tensorflow.org/lite/inference_with_metadata/task_library/image_classifier) +in the TensorFlow Lite Task Library. + +##### Load model and create ImageClassifier + +`ImageClassifier` expects a model populated with the +[model metadata](https://www.tensorflow.org/lite/convert/metadata) and the label +file. See the +[model compatibility requirements](https://www.tensorflow.org/lite/inference_with_metadata/task_library/image_classifier#model_compatibility_requirements) +for more details. + +`ImageClassifierOptions` allows manipulation on various inference options, such +as setting the maximum number of top scored results to return using +`setMaxResults(MAX_RESULTS)`, and setting the score threshold using +`setScoreThreshold(scoreThreshold)`. + +```java +// Create the ImageClassifier instance. +ImageClassifierOptions options = + ImageClassifierOptions.builder().setMaxResults(MAX_RESULTS).build(); +imageClassifier = ImageClassifier.createFromFileAndOptions(activity, + getModelPath(), options); +``` + +`ImageClassifier` currently does not support configuring delegates and +multithread, but those are on our roadmap. Please stay tuned! + +##### Run inference + +`ImageClassifier` contains builtin logic to preprocess the input image, such as +rotating and resizing an image. Processing options can be configured through +`ImageProcessingOptions`. In the following example, input images are rotated to +the up-right angle and cropped to the center as the model expects a square input +(`224x224`). See the +[Java doc of `ImageClassifier`](https://github.com/tensorflow/tflite-support/blob/195b574f0aa9856c618b3f1ad87bd185cddeb657/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java#L22) +for more details about how the underlying image processing is performed. + +```java +TensorImage inputImage = TensorImage.fromBitmap(bitmap); +int width = bitmap.getWidth(); +int height = bitmap.getHeight(); +int cropSize = min(width, height); +ImageProcessingOptions imageOptions = + ImageProcessingOptions.builder() + .setOrientation(getOrientation(sensorOrientation)) + // Set the ROI to the center of the image. + .setRoi( + new Rect( + /*left=*/ (width - cropSize) / 2, + /*top=*/ (height - cropSize) / 2, + /*right=*/ (width + cropSize) / 2, + /*bottom=*/ (height + cropSize) / 2)) + .build(); + +List results = imageClassifier.classify(inputImage, + imageOptions); +``` + +The output of `ImageClassifier` is a list of `Classifications` instance, where +each `Classifications` element is a single head classification result. All the +demo models are single head models, therefore, `results` only contains one +`Classifications` object. Use `Classifications.getCategories()` to get a list of +top-k categories as specified with `MAX_RESULTS`. Each `Category` object +contains the srting label and the score of that category. + +To match the implementation of +[`lib_support`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/lib_support), +`results` is converted into `List` in the method, +`getRecognitions`. + +#### Using the TensorFlow Lite Support Library + +##### Load model and create interpreter + +To perform inference, we need to load a model file and instantiate an +`Interpreter`. This happens in the constructor of the `Classifier` class, along +with loading the list of class labels. Information about the device type and +number of threads is used to configure the `Interpreter` via the +`Interpreter.Options` instance passed into its constructor. Note that if a GPU, +DSP (Digital Signal Processor) or NPU (Neural Processing Unit) is available, a +[`Delegate`](https://www.tensorflow.org/lite/performance/delegates) can be used +to take full advantage of these hardware. + +Please note that there are performance edge cases and developers are adviced to +test with a representative set of devices prior to production. + +```java +protected Classifier(Activity activity, Device device, int numThreads) throws + IOException { + tfliteModel = FileUtil.loadMappedFile(activity, getModelPath()); + switch (device) { + case NNAPI: + nnApiDelegate = new NnApiDelegate(); + tfliteOptions.addDelegate(nnApiDelegate); + break; + case GPU: + gpuDelegate = new GpuDelegate(); + tfliteOptions.addDelegate(gpuDelegate); + break; + case CPU: + break; + } + tfliteOptions.setNumThreads(numThreads); + tflite = new Interpreter(tfliteModel, tfliteOptions); + labels = FileUtil.loadLabels(activity, getLabelPath()); +... +``` + +For Android devices, we recommend pre-loading and memory mapping the model file +to offer faster load times and reduce the dirty pages in memory. The method +`FileUtil.loadMappedFile` does this, returning a `MappedByteBuffer` containing +the model. + +The `MappedByteBuffer` is passed into the `Interpreter` constructor, along with +an `Interpreter.Options` object. This object can be used to configure the +interpreter, for example by setting the number of threads (`.setNumThreads(1)`) +or enabling [NNAPI](https://developer.android.com/ndk/guides/neuralnetworks) +(`.addDelegate(nnApiDelegate)`). + +##### Pre-process bitmap image + +Next in the `Classifier` constructor, we take the input camera bitmap image, +convert it to a `TensorImage` format for efficient processing and pre-process +it. The steps are shown in the private 'loadImage' method: + +```java +/** Loads input image, and applys preprocessing. */ +private TensorImage loadImage(final Bitmap bitmap, int sensorOrientation) { + // Loads bitmap into a TensorImage. + image.load(bitmap); + + // Creates processor for the TensorImage. + int cropSize = Math.min(bitmap.getWidth(), bitmap.getHeight()); + int numRoration = sensorOrientation / 90; + ImageProcessor imageProcessor = + new ImageProcessor.Builder() + .add(new ResizeWithCropOrPadOp(cropSize, cropSize)) + .add(new ResizeOp(imageSizeX, imageSizeY, ResizeMethod.BILINEAR)) + .add(new Rot90Op(numRoration)) + .add(getPreprocessNormalizeOp()) + .build(); + return imageProcessor.process(inputImageBuffer); +} +``` + +The pre-processing is largely the same for quantized and float models with one +exception: Normalization. + +In `ClassifierFloatMobileNet`, the normalization parameters are defined as: + +```java +private static final float IMAGE_MEAN = 127.5f; +private static final float IMAGE_STD = 127.5f; +``` + +In `ClassifierQuantizedMobileNet`, normalization is not required. Thus the +nomalization parameters are defined as: + +```java +private static final float IMAGE_MEAN = 0.0f; +private static final float IMAGE_STD = 1.0f; +``` + +##### Allocate output object + +Initiate the output `TensorBuffer` for the output of the model. + +```java +/** Output probability TensorBuffer. */ +private final TensorBuffer outputProbabilityBuffer; + +//... +// Get the array size for the output buffer from the TensorFlow Lite model file +int probabilityTensorIndex = 0; +int[] probabilityShape = + tflite.getOutputTensor(probabilityTensorIndex).shape(); // {1, 1001} +DataType probabilityDataType = + tflite.getOutputTensor(probabilityTensorIndex).dataType(); + +// Creates the output tensor and its processor. +outputProbabilityBuffer = + TensorBuffer.createFixedSize(probabilityShape, probabilityDataType); + +// Creates the post processor for the output probability. +probabilityProcessor = + new TensorProcessor.Builder().add(getPostprocessNormalizeOp()).build(); +``` + +For quantized models, we need to de-quantize the prediction with the NormalizeOp +(as they are all essentially linear transformation). For float model, +de-quantize is not required. But to uniform the API, de-quantize is added to +float model too. Mean and std are set to 0.0f and 1.0f, respectively. To be more +specific, + +In `ClassifierQuantizedMobileNet`, the normalized parameters are defined as: + +```java +private static final float PROBABILITY_MEAN = 0.0f; +private static final float PROBABILITY_STD = 255.0f; +``` + +In `ClassifierFloatMobileNet`, the normalized parameters are defined as: + +```java +private static final float PROBABILITY_MEAN = 0.0f; +private static final float PROBABILITY_STD = 1.0f; +``` + +##### Run inference + +Inference is performed using the following in `Classifier` class: + +```java +tflite.run(inputImageBuffer.getBuffer(), + outputProbabilityBuffer.getBuffer().rewind()); +``` + +##### Recognize image + +Rather than call `run` directly, the method `recognizeImage` is used. It accepts +a bitmap and sensor orientation, runs inference, and returns a sorted `List` of +`Recognition` instances, each corresponding to a label. The method will return a +number of results bounded by `MAX_RESULTS`, which is 3 by default. + +`Recognition` is a simple class that contains information about a specific +recognition result, including its `title` and `confidence`. Using the +post-processing normalization method specified, the confidence is converted to +between 0 and 1 of a given class being represented by the image. + +```java +/** Gets the label to probability map. */ +Map labeledProbability = + new TensorLabel(labels, + probabilityProcessor.process(outputProbabilityBuffer)) + .getMapWithFloatValue(); +``` + +A `PriorityQueue` is used for sorting. + +```java +/** Gets the top-k results. */ +private static List getTopKProbability( + Map labelProb) { + // Find the best classifications. + PriorityQueue pq = + new PriorityQueue<>( + MAX_RESULTS, + new Comparator() { + @Override + public int compare(Recognition lhs, Recognition rhs) { + // Intentionally reversed to put high confidence at the head of + // the queue. + return Float.compare(rhs.getConfidence(), lhs.getConfidence()); + } + }); + + for (Map.Entry entry : labelProb.entrySet()) { + pq.add(new Recognition("" + entry.getKey(), entry.getKey(), + entry.getValue(), null)); + } + + final ArrayList recognitions = new ArrayList<>(); + int recognitionsSize = Math.min(pq.size(), MAX_RESULTS); + for (int i = 0; i < recognitionsSize; ++i) { + recognitions.add(pq.poll()); + } + return recognitions; +} +``` + +### Display results + +The classifier is invoked and inference results are displayed by the +`processImage()` function in +[`ClassifierActivity.java`](https://github.com/tensorflow/examples/tree/master/lite/examples/image_classification/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java). + +`ClassifierActivity` is a subclass of `CameraActivity` that contains method +implementations that render the camera image, run classification, and display +the results. The method `processImage()` runs classification on a background +thread as fast as possible, rendering information on the UI thread to avoid +blocking inference and creating latency. + +```java +@Override +protected void processImage() { + rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, + previewHeight); + final int imageSizeX = classifier.getImageSizeX(); + final int imageSizeY = classifier.getImageSizeY(); + + runInBackground( + new Runnable() { + @Override + public void run() { + if (classifier != null) { + final long startTime = SystemClock.uptimeMillis(); + final List results = + classifier.recognizeImage(rgbFrameBitmap, sensorOrientation); + lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime; + LOGGER.v("Detect: %s", results); + + runOnUiThread( + new Runnable() { + @Override + public void run() { + showResultsInBottomSheet(results); + showFrameInfo(previewWidth + "x" + previewHeight); + showCropInfo(imageSizeX + "x" + imageSizeY); + showCameraResolution(imageSizeX + "x" + imageSizeY); + showRotationInfo(String.valueOf(sensorOrientation)); + showInference(lastProcessingTimeMs + "ms"); + } + }); + } + readyForNextImage(); + } + }); +} +``` + +Another important role of `ClassifierActivity` is to determine user preferences +(by interrogating `CameraActivity`), and instantiate the appropriately +configured `Classifier` subclass. This happens when the video feed begins (via +`onPreviewSizeChosen()`) and when options are changed in the UI (via +`onInferenceConfigurationChanged()`). + +```java +private void recreateClassifier(Model model, Device device, int numThreads) { + if (classifier != null) { + LOGGER.d("Closing classifier."); + classifier.close(); + classifier = null; + } + if (device == Device.GPU && model == Model.QUANTIZED) { + LOGGER.d("Not creating classifier: GPU doesn't support quantized models."); + runOnUiThread( + () -> { + Toast.makeText(this, "GPU does not yet supported quantized models.", + Toast.LENGTH_LONG) + .show(); + }); + return; + } + try { + LOGGER.d( + "Creating classifier (model=%s, device=%s, numThreads=%d)", model, + device, numThreads); + classifier = Classifier.create(this, model, device, numThreads); + } catch (IOException e) { + LOGGER.e(e, "Failed to create classifier."); + } +} +``` diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/LICENSE b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..6606ec028d1c629986e7019fe3564f5b4bfe425d --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Alexey + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/README.md b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/README.md new file mode 100644 index 0000000000000000000000000000000000000000..faf415eb27ccc1a62357718d1e0a9b8c746de4e8 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/README.md @@ -0,0 +1,21 @@ +# MiDaS on Android smartphone by using TensorFlow-lite (TFLite) + + +* Either use Android Studio for compilation. + +* Or use ready to install apk-file: + * Or use URL: https://i.diawi.com/CVb8a9 + * Or use QR-code: + +Scan QR-code or open URL -> Press `Install application` -> Press `Download` and wait for download -> Open -> Install -> Open -> Press: Allow MiDaS to take photo and video from the camera While using the APP + +![CVb8a9](https://user-images.githubusercontent.com/4096485/97727213-38552500-1ae1-11eb-8b76-4ea11216f76d.png) + +---- + +To use another model, you should convert it to `model_opt.tflite` and place it to the directory: `models\src\main\assets` + + +---- + +Original repository: https://github.com/isl-org/MiDaS diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/.gitignore b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..1ae74c6780c277d75fedfb7511ff51f69941b48b --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/.gitignore @@ -0,0 +1,3 @@ +/build + +/build/ \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/build.gradle b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..94e9886a55c7d54f71b424bb246c849dd6bd795d --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/build.gradle @@ -0,0 +1,56 @@ +apply plugin: 'com.android.application' + +android { + compileSdkVersion 28 + defaultConfig { + applicationId "org.tensorflow.lite.examples.classification" + minSdkVersion 21 + targetSdkVersion 28 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + } + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' + } + } + aaptOptions { + noCompress "tflite" + } + compileOptions { + sourceCompatibility = '1.8' + targetCompatibility = '1.8' + } + lintOptions { + abortOnError false + } + flavorDimensions "tfliteInference" + productFlavors { + // The TFLite inference is built using the TFLite Support library. + support { + dimension "tfliteInference" + } + // The TFLite inference is built using the TFLite Task library. + taskApi { + dimension "tfliteInference" + } + } + +} + +dependencies { + implementation fileTree(dir: 'libs', include: ['*.jar']) + supportImplementation project(":lib_support") + taskApiImplementation project(":lib_task_api") + implementation 'androidx.appcompat:appcompat:1.0.0' + implementation 'androidx.coordinatorlayout:coordinatorlayout:1.0.0' + implementation 'com.google.android.material:material:1.0.0' + + androidTestImplementation 'androidx.test.ext:junit:1.1.1' + androidTestImplementation 'com.google.truth:truth:1.0.1' + androidTestImplementation 'androidx.test:runner:1.2.0' + androidTestImplementation 'androidx.test:rules:1.1.0' +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/proguard-rules.pro b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/proguard-rules.pro new file mode 100644 index 0000000000000000000000000000000000000000..f1b424510da51fd82143bc74a0a801ae5a1e2fcd --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/assets/fox-mobilenet_v1_1.0_224_support.txt b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/assets/fox-mobilenet_v1_1.0_224_support.txt new file mode 100644 index 0000000000000000000000000000000000000000..bdfad31f9b3e694817025d8b8f2ca0b40aa436bb --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/assets/fox-mobilenet_v1_1.0_224_support.txt @@ -0,0 +1,3 @@ +red_fox 0.79403335 +kit_fox 0.16753247 +grey_fox 0.03619214 diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/assets/fox-mobilenet_v1_1.0_224_task_api.txt b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/assets/fox-mobilenet_v1_1.0_224_task_api.txt new file mode 100644 index 0000000000000000000000000000000000000000..3668ce54df0d1e57e31c58281d6085b83928f991 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/assets/fox-mobilenet_v1_1.0_224_task_api.txt @@ -0,0 +1,3 @@ +red_fox 0.85 +kit_fox 0.13 +grey_fox 0.02 diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/AndroidManifest.xml b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..3653d8799092492ebbb16c7c956eb50e3d404aa4 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/AndroidManifest.xml @@ -0,0 +1,5 @@ + + + + \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/org/tensorflow/lite/examples/classification/ClassifierTest.java b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/org/tensorflow/lite/examples/classification/ClassifierTest.java new file mode 100644 index 0000000000000000000000000000000000000000..0194132890aae659c2a70d33106306ed665b22e8 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/androidTest/java/org/tensorflow/lite/examples/classification/ClassifierTest.java @@ -0,0 +1,121 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.lite.examples.classification; + +import static com.google.common.truth.Truth.assertThat; + +import android.content.res.AssetManager; +import android.graphics.Bitmap; +import android.graphics.BitmapFactory; +import android.util.Log; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import androidx.test.platform.app.InstrumentationRegistry; +import androidx.test.rule.ActivityTestRule; +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Scanner; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.tensorflow.lite.examples.classification.tflite.Classifier; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Model; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Recognition; + +/** Golden test for Image Classification Reference app. */ +@RunWith(AndroidJUnit4.class) +public class ClassifierTest { + + @Rule + public ActivityTestRule rule = + new ActivityTestRule<>(ClassifierActivity.class); + + private static final String[] INPUTS = {"fox.jpg"}; + private static final String[] GOLDEN_OUTPUTS_SUPPORT = {"fox-mobilenet_v1_1.0_224_support.txt"}; + private static final String[] GOLDEN_OUTPUTS_TASK = {"fox-mobilenet_v1_1.0_224_task_api.txt"}; + + @Test + public void classificationResultsShouldNotChange() throws IOException { + ClassifierActivity activity = rule.getActivity(); + Classifier classifier = Classifier.create(activity, Model.FLOAT_MOBILENET, Device.CPU, 1); + for (int i = 0; i < INPUTS.length; i++) { + String imageFileName = INPUTS[i]; + String goldenOutputFileName; + // TODO(b/169379396): investigate the impact of the resize algorithm on accuracy. + // This is a temporary workaround to set different golden rest results as the preprocessing + // of lib_support and lib_task_api are different. Will merge them once the above TODO is + // resolved. + if (Classifier.TAG.equals("ClassifierWithSupport")) { + goldenOutputFileName = GOLDEN_OUTPUTS_SUPPORT[i]; + } else { + goldenOutputFileName = GOLDEN_OUTPUTS_TASK[i]; + } + Bitmap input = loadImage(imageFileName); + List goldenOutput = loadRecognitions(goldenOutputFileName); + + List result = classifier.recognizeImage(input, 0); + Iterator goldenOutputIterator = goldenOutput.iterator(); + + for (Recognition actual : result) { + Assert.assertTrue(goldenOutputIterator.hasNext()); + Recognition expected = goldenOutputIterator.next(); + assertThat(actual.getTitle()).isEqualTo(expected.getTitle()); + assertThat(actual.getConfidence()).isWithin(0.01f).of(expected.getConfidence()); + } + } + } + + private static Bitmap loadImage(String fileName) { + AssetManager assetManager = + InstrumentationRegistry.getInstrumentation().getContext().getAssets(); + InputStream inputStream = null; + try { + inputStream = assetManager.open(fileName); + } catch (IOException e) { + Log.e("Test", "Cannot load image from assets"); + } + return BitmapFactory.decodeStream(inputStream); + } + + private static List loadRecognitions(String fileName) { + AssetManager assetManager = + InstrumentationRegistry.getInstrumentation().getContext().getAssets(); + InputStream inputStream = null; + try { + inputStream = assetManager.open(fileName); + } catch (IOException e) { + Log.e("Test", "Cannot load probability results from assets"); + } + Scanner scanner = new Scanner(inputStream); + List result = new ArrayList<>(); + while (scanner.hasNext()) { + String category = scanner.next(); + category = category.replace('_', ' '); + if (!scanner.hasNextFloat()) { + break; + } + float probability = scanner.nextFloat(); + Recognition recognition = new Recognition(null, category, probability, null); + result.add(recognition); + } + return result; + } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/AndroidManifest.xml b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..7a414d5176a117262dce56c2220e6b71791287de --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/AndroidManifest.xml @@ -0,0 +1,28 @@ + + + + + + + + + + + + + + + + + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java new file mode 100644 index 0000000000000000000000000000000000000000..d1eb26c862c04bf573ecc4eb127e7460f0b100fc --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraActivity.java @@ -0,0 +1,717 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.lite.examples.classification; + +import android.Manifest; +import android.app.Fragment; +import android.content.Context; +import android.content.pm.PackageManager; +import android.graphics.Bitmap; +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Paint; +import android.graphics.RectF; +import android.hardware.Camera; +import android.hardware.camera2.CameraAccessException; +import android.hardware.camera2.CameraCharacteristics; +import android.hardware.camera2.CameraManager; +import android.hardware.camera2.params.StreamConfigurationMap; +import android.media.Image; +import android.media.Image.Plane; +import android.media.ImageReader; +import android.media.ImageReader.OnImageAvailableListener; +import android.os.Build; +import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; +import android.os.Trace; +import androidx.annotation.NonNull; +import androidx.annotation.UiThread; +import androidx.appcompat.app.AppCompatActivity; +import android.util.Size; +import android.view.Surface; +import android.view.TextureView; +import android.view.View; +import android.view.ViewTreeObserver; +import android.view.WindowManager; +import android.widget.AdapterView; +import android.widget.ImageView; +import android.widget.LinearLayout; +import android.widget.Spinner; +import android.widget.TextView; +import android.widget.Toast; +import com.google.android.material.bottomsheet.BottomSheetBehavior; +import java.nio.ByteBuffer; +import java.util.List; +import org.tensorflow.lite.examples.classification.env.ImageUtils; +import org.tensorflow.lite.examples.classification.env.Logger; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Model; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Recognition; + +public abstract class CameraActivity extends AppCompatActivity + implements OnImageAvailableListener, + Camera.PreviewCallback, + View.OnClickListener, + AdapterView.OnItemSelectedListener { + private static final Logger LOGGER = new Logger(); + + private static final int PERMISSIONS_REQUEST = 1; + + private static final String PERMISSION_CAMERA = Manifest.permission.CAMERA; + protected int previewWidth = 0; + protected int previewHeight = 0; + private Handler handler; + private HandlerThread handlerThread; + private boolean useCamera2API; + private boolean isProcessingFrame = false; + private byte[][] yuvBytes = new byte[3][]; + private int[] rgbBytes = null; + private int yRowStride; + private Runnable postInferenceCallback; + private Runnable imageConverter; + private LinearLayout bottomSheetLayout; + private LinearLayout gestureLayout; + private BottomSheetBehavior sheetBehavior; + protected TextView recognitionTextView, + recognition1TextView, + recognition2TextView, + recognitionValueTextView, + recognition1ValueTextView, + recognition2ValueTextView; + protected TextView frameValueTextView, + cropValueTextView, + cameraResolutionTextView, + rotationTextView, + inferenceTimeTextView; + protected ImageView bottomSheetArrowImageView; + private ImageView plusImageView, minusImageView; + private Spinner modelSpinner; + private Spinner deviceSpinner; + private TextView threadsTextView; + + //private Model model = Model.QUANTIZED_EFFICIENTNET; + //private Device device = Device.CPU; + private Model model = Model.FLOAT_EFFICIENTNET; + private Device device = Device.GPU; + private int numThreads = -1; + + @Override + protected void onCreate(final Bundle savedInstanceState) { + LOGGER.d("onCreate " + this); + super.onCreate(null); + getWindow().addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON); + + setContentView(R.layout.tfe_ic_activity_camera); + + if (hasPermission()) { + setFragment(); + } else { + requestPermission(); + } + + threadsTextView = findViewById(R.id.threads); + plusImageView = findViewById(R.id.plus); + minusImageView = findViewById(R.id.minus); + modelSpinner = findViewById(R.id.model_spinner); + deviceSpinner = findViewById(R.id.device_spinner); + bottomSheetLayout = findViewById(R.id.bottom_sheet_layout); + gestureLayout = findViewById(R.id.gesture_layout); + sheetBehavior = BottomSheetBehavior.from(bottomSheetLayout); + bottomSheetArrowImageView = findViewById(R.id.bottom_sheet_arrow); + + ViewTreeObserver vto = gestureLayout.getViewTreeObserver(); + vto.addOnGlobalLayoutListener( + new ViewTreeObserver.OnGlobalLayoutListener() { + @Override + public void onGlobalLayout() { + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.JELLY_BEAN) { + gestureLayout.getViewTreeObserver().removeGlobalOnLayoutListener(this); + } else { + gestureLayout.getViewTreeObserver().removeOnGlobalLayoutListener(this); + } + // int width = bottomSheetLayout.getMeasuredWidth(); + int height = gestureLayout.getMeasuredHeight(); + + sheetBehavior.setPeekHeight(height); + } + }); + sheetBehavior.setHideable(false); + + sheetBehavior.setBottomSheetCallback( + new BottomSheetBehavior.BottomSheetCallback() { + @Override + public void onStateChanged(@NonNull View bottomSheet, int newState) { + switch (newState) { + case BottomSheetBehavior.STATE_HIDDEN: + break; + case BottomSheetBehavior.STATE_EXPANDED: + { + bottomSheetArrowImageView.setImageResource(R.drawable.icn_chevron_down); + } + break; + case BottomSheetBehavior.STATE_COLLAPSED: + { + bottomSheetArrowImageView.setImageResource(R.drawable.icn_chevron_up); + } + break; + case BottomSheetBehavior.STATE_DRAGGING: + break; + case BottomSheetBehavior.STATE_SETTLING: + bottomSheetArrowImageView.setImageResource(R.drawable.icn_chevron_up); + break; + } + } + + @Override + public void onSlide(@NonNull View bottomSheet, float slideOffset) {} + }); + + recognitionTextView = findViewById(R.id.detected_item); + recognitionValueTextView = findViewById(R.id.detected_item_value); + recognition1TextView = findViewById(R.id.detected_item1); + recognition1ValueTextView = findViewById(R.id.detected_item1_value); + recognition2TextView = findViewById(R.id.detected_item2); + recognition2ValueTextView = findViewById(R.id.detected_item2_value); + + frameValueTextView = findViewById(R.id.frame_info); + cropValueTextView = findViewById(R.id.crop_info); + cameraResolutionTextView = findViewById(R.id.view_info); + rotationTextView = findViewById(R.id.rotation_info); + inferenceTimeTextView = findViewById(R.id.inference_info); + + modelSpinner.setOnItemSelectedListener(this); + deviceSpinner.setOnItemSelectedListener(this); + + plusImageView.setOnClickListener(this); + minusImageView.setOnClickListener(this); + + model = Model.valueOf(modelSpinner.getSelectedItem().toString().toUpperCase()); + device = Device.valueOf(deviceSpinner.getSelectedItem().toString()); + numThreads = Integer.parseInt(threadsTextView.getText().toString().trim()); + } + + protected int[] getRgbBytes() { + imageConverter.run(); + return rgbBytes; + } + + protected int getLuminanceStride() { + return yRowStride; + } + + protected byte[] getLuminance() { + return yuvBytes[0]; + } + + /** Callback for android.hardware.Camera API */ + @Override + public void onPreviewFrame(final byte[] bytes, final Camera camera) { + if (isProcessingFrame) { + LOGGER.w("Dropping frame!"); + return; + } + + try { + // Initialize the storage bitmaps once when the resolution is known. + if (rgbBytes == null) { + Camera.Size previewSize = camera.getParameters().getPreviewSize(); + previewHeight = previewSize.height; + previewWidth = previewSize.width; + rgbBytes = new int[previewWidth * previewHeight]; + onPreviewSizeChosen(new Size(previewSize.width, previewSize.height), 90); + } + } catch (final Exception e) { + LOGGER.e(e, "Exception!"); + return; + } + + isProcessingFrame = true; + yuvBytes[0] = bytes; + yRowStride = previewWidth; + + imageConverter = + new Runnable() { + @Override + public void run() { + ImageUtils.convertYUV420SPToARGB8888(bytes, previewWidth, previewHeight, rgbBytes); + } + }; + + postInferenceCallback = + new Runnable() { + @Override + public void run() { + camera.addCallbackBuffer(bytes); + isProcessingFrame = false; + } + }; + processImage(); + } + + /** Callback for Camera2 API */ + @Override + public void onImageAvailable(final ImageReader reader) { + // We need wait until we have some size from onPreviewSizeChosen + if (previewWidth == 0 || previewHeight == 0) { + return; + } + if (rgbBytes == null) { + rgbBytes = new int[previewWidth * previewHeight]; + } + try { + final Image image = reader.acquireLatestImage(); + + if (image == null) { + return; + } + + if (isProcessingFrame) { + image.close(); + return; + } + isProcessingFrame = true; + Trace.beginSection("imageAvailable"); + final Plane[] planes = image.getPlanes(); + fillBytes(planes, yuvBytes); + yRowStride = planes[0].getRowStride(); + final int uvRowStride = planes[1].getRowStride(); + final int uvPixelStride = planes[1].getPixelStride(); + + imageConverter = + new Runnable() { + @Override + public void run() { + ImageUtils.convertYUV420ToARGB8888( + yuvBytes[0], + yuvBytes[1], + yuvBytes[2], + previewWidth, + previewHeight, + yRowStride, + uvRowStride, + uvPixelStride, + rgbBytes); + } + }; + + postInferenceCallback = + new Runnable() { + @Override + public void run() { + image.close(); + isProcessingFrame = false; + } + }; + + processImage(); + } catch (final Exception e) { + LOGGER.e(e, "Exception!"); + Trace.endSection(); + return; + } + Trace.endSection(); + } + + @Override + public synchronized void onStart() { + LOGGER.d("onStart " + this); + super.onStart(); + } + + @Override + public synchronized void onResume() { + LOGGER.d("onResume " + this); + super.onResume(); + + handlerThread = new HandlerThread("inference"); + handlerThread.start(); + handler = new Handler(handlerThread.getLooper()); + } + + @Override + public synchronized void onPause() { + LOGGER.d("onPause " + this); + + handlerThread.quitSafely(); + try { + handlerThread.join(); + handlerThread = null; + handler = null; + } catch (final InterruptedException e) { + LOGGER.e(e, "Exception!"); + } + + super.onPause(); + } + + @Override + public synchronized void onStop() { + LOGGER.d("onStop " + this); + super.onStop(); + } + + @Override + public synchronized void onDestroy() { + LOGGER.d("onDestroy " + this); + super.onDestroy(); + } + + protected synchronized void runInBackground(final Runnable r) { + if (handler != null) { + handler.post(r); + } + } + + @Override + public void onRequestPermissionsResult( + final int requestCode, final String[] permissions, final int[] grantResults) { + super.onRequestPermissionsResult(requestCode, permissions, grantResults); + if (requestCode == PERMISSIONS_REQUEST) { + if (allPermissionsGranted(grantResults)) { + setFragment(); + } else { + requestPermission(); + } + } + } + + private static boolean allPermissionsGranted(final int[] grantResults) { + for (int result : grantResults) { + if (result != PackageManager.PERMISSION_GRANTED) { + return false; + } + } + return true; + } + + private boolean hasPermission() { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { + return checkSelfPermission(PERMISSION_CAMERA) == PackageManager.PERMISSION_GRANTED; + } else { + return true; + } + } + + private void requestPermission() { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { + if (shouldShowRequestPermissionRationale(PERMISSION_CAMERA)) { + Toast.makeText( + CameraActivity.this, + "Camera permission is required for this demo", + Toast.LENGTH_LONG) + .show(); + } + requestPermissions(new String[] {PERMISSION_CAMERA}, PERMISSIONS_REQUEST); + } + } + + // Returns true if the device supports the required hardware level, or better. + private boolean isHardwareLevelSupported( + CameraCharacteristics characteristics, int requiredLevel) { + int deviceLevel = characteristics.get(CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL); + if (deviceLevel == CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL_LEGACY) { + return requiredLevel == deviceLevel; + } + // deviceLevel is not LEGACY, can use numerical sort + return requiredLevel <= deviceLevel; + } + + private String chooseCamera() { + final CameraManager manager = (CameraManager) getSystemService(Context.CAMERA_SERVICE); + try { + for (final String cameraId : manager.getCameraIdList()) { + final CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId); + + // We don't use a front facing camera in this sample. + final Integer facing = characteristics.get(CameraCharacteristics.LENS_FACING); + if (facing != null && facing == CameraCharacteristics.LENS_FACING_FRONT) { + continue; + } + + final StreamConfigurationMap map = + characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP); + + if (map == null) { + continue; + } + + // Fallback to camera1 API for internal cameras that don't have full support. + // This should help with legacy situations where using the camera2 API causes + // distorted or otherwise broken previews. + useCamera2API = + (facing == CameraCharacteristics.LENS_FACING_EXTERNAL) + || isHardwareLevelSupported( + characteristics, CameraCharacteristics.INFO_SUPPORTED_HARDWARE_LEVEL_FULL); + LOGGER.i("Camera API lv2?: %s", useCamera2API); + return cameraId; + } + } catch (CameraAccessException e) { + LOGGER.e(e, "Not allowed to access camera"); + } + + return null; + } + + protected void setFragment() { + String cameraId = chooseCamera(); + + Fragment fragment; + if (useCamera2API) { + CameraConnectionFragment camera2Fragment = + CameraConnectionFragment.newInstance( + new CameraConnectionFragment.ConnectionCallback() { + @Override + public void onPreviewSizeChosen(final Size size, final int rotation) { + previewHeight = size.getHeight(); + previewWidth = size.getWidth(); + CameraActivity.this.onPreviewSizeChosen(size, rotation); + } + }, + this, + getLayoutId(), + getDesiredPreviewFrameSize()); + + camera2Fragment.setCamera(cameraId); + fragment = camera2Fragment; + } else { + fragment = + new LegacyCameraConnectionFragment(this, getLayoutId(), getDesiredPreviewFrameSize()); + } + + getFragmentManager().beginTransaction().replace(R.id.container, fragment).commit(); + } + + protected void fillBytes(final Plane[] planes, final byte[][] yuvBytes) { + // Because of the variable row stride it's not possible to know in + // advance the actual necessary dimensions of the yuv planes. + for (int i = 0; i < planes.length; ++i) { + final ByteBuffer buffer = planes[i].getBuffer(); + if (yuvBytes[i] == null) { + LOGGER.d("Initializing buffer %d at size %d", i, buffer.capacity()); + yuvBytes[i] = new byte[buffer.capacity()]; + } + buffer.get(yuvBytes[i]); + } + } + + protected void readyForNextImage() { + if (postInferenceCallback != null) { + postInferenceCallback.run(); + } + } + + protected int getScreenOrientation() { + switch (getWindowManager().getDefaultDisplay().getRotation()) { + case Surface.ROTATION_270: + return 270; + case Surface.ROTATION_180: + return 180; + case Surface.ROTATION_90: + return 90; + default: + return 0; + } + } + + @UiThread + protected void showResultsInTexture(float[] img_array, int imageSizeX, int imageSizeY) { + float maxval = Float.NEGATIVE_INFINITY; + float minval = Float.POSITIVE_INFINITY; + for (float cur : img_array) { + maxval = Math.max(maxval, cur); + minval = Math.min(minval, cur); + } + float multiplier = 0; + if ((maxval - minval) > 0) multiplier = 255 / (maxval - minval); + + int[] img_normalized = new int[img_array.length]; + for (int i = 0; i < img_array.length; ++i) { + float val = (float) (multiplier * (img_array[i] - minval)); + img_normalized[i] = (int) val; + } + + + + TextureView textureView = findViewById(R.id.textureView3); + //AutoFitTextureView textureView = (AutoFitTextureView) findViewById(R.id.texture); + + if(textureView.isAvailable()) { + int width = imageSizeX; + int height = imageSizeY; + + Canvas canvas = textureView.lockCanvas(); + canvas.drawColor(Color.BLUE); + Paint paint = new Paint(); + paint.setStyle(Paint.Style.FILL); + paint.setARGB(255, 150, 150, 150); + + int canvas_size = Math.min(canvas.getWidth(), canvas.getHeight()); + + Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.RGB_565); + + for (int ii = 0; ii < width; ii++) //pass the screen pixels in 2 directions + { + for (int jj = 0; jj < height; jj++) { + //int val = img_normalized[ii + jj * width]; + int index = (width - ii - 1) + (height - jj - 1) * width; + if(index < img_array.length) { + int val = img_normalized[index]; + bitmap.setPixel(ii, jj, Color.rgb(val, val, val)); + } + } + } + + canvas.drawBitmap(bitmap, null, new RectF(0, 0, canvas_size, canvas_size), null); + + textureView.unlockCanvasAndPost(canvas); + + } + + } + + protected void showResultsInBottomSheet(List results) { + if (results != null && results.size() >= 3) { + Recognition recognition = results.get(0); + if (recognition != null) { + if (recognition.getTitle() != null) recognitionTextView.setText(recognition.getTitle()); + if (recognition.getConfidence() != null) + recognitionValueTextView.setText( + String.format("%.2f", (100 * recognition.getConfidence())) + "%"); + } + + Recognition recognition1 = results.get(1); + if (recognition1 != null) { + if (recognition1.getTitle() != null) recognition1TextView.setText(recognition1.getTitle()); + if (recognition1.getConfidence() != null) + recognition1ValueTextView.setText( + String.format("%.2f", (100 * recognition1.getConfidence())) + "%"); + } + + Recognition recognition2 = results.get(2); + if (recognition2 != null) { + if (recognition2.getTitle() != null) recognition2TextView.setText(recognition2.getTitle()); + if (recognition2.getConfidence() != null) + recognition2ValueTextView.setText( + String.format("%.2f", (100 * recognition2.getConfidence())) + "%"); + } + } + } + + protected void showFrameInfo(String frameInfo) { + frameValueTextView.setText(frameInfo); + } + + protected void showCropInfo(String cropInfo) { + cropValueTextView.setText(cropInfo); + } + + protected void showCameraResolution(String cameraInfo) { + cameraResolutionTextView.setText(cameraInfo); + } + + protected void showRotationInfo(String rotation) { + rotationTextView.setText(rotation); + } + + protected void showInference(String inferenceTime) { + inferenceTimeTextView.setText(inferenceTime); + } + + protected Model getModel() { + return model; + } + + private void setModel(Model model) { + if (this.model != model) { + LOGGER.d("Updating model: " + model); + this.model = model; + onInferenceConfigurationChanged(); + } + } + + protected Device getDevice() { + return device; + } + + private void setDevice(Device device) { + if (this.device != device) { + LOGGER.d("Updating device: " + device); + this.device = device; + final boolean threadsEnabled = device == Device.CPU; + plusImageView.setEnabled(threadsEnabled); + minusImageView.setEnabled(threadsEnabled); + threadsTextView.setText(threadsEnabled ? String.valueOf(numThreads) : "N/A"); + onInferenceConfigurationChanged(); + } + } + + protected int getNumThreads() { + return numThreads; + } + + private void setNumThreads(int numThreads) { + if (this.numThreads != numThreads) { + LOGGER.d("Updating numThreads: " + numThreads); + this.numThreads = numThreads; + onInferenceConfigurationChanged(); + } + } + + protected abstract void processImage(); + + protected abstract void onPreviewSizeChosen(final Size size, final int rotation); + + protected abstract int getLayoutId(); + + protected abstract Size getDesiredPreviewFrameSize(); + + protected abstract void onInferenceConfigurationChanged(); + + @Override + public void onClick(View v) { + if (v.getId() == R.id.plus) { + String threads = threadsTextView.getText().toString().trim(); + int numThreads = Integer.parseInt(threads); + if (numThreads >= 9) return; + setNumThreads(++numThreads); + threadsTextView.setText(String.valueOf(numThreads)); + } else if (v.getId() == R.id.minus) { + String threads = threadsTextView.getText().toString().trim(); + int numThreads = Integer.parseInt(threads); + if (numThreads == 1) { + return; + } + setNumThreads(--numThreads); + threadsTextView.setText(String.valueOf(numThreads)); + } + } + + @Override + public void onItemSelected(AdapterView parent, View view, int pos, long id) { + if (parent == modelSpinner) { + setModel(Model.valueOf(parent.getItemAtPosition(pos).toString().toUpperCase())); + } else if (parent == deviceSpinner) { + setDevice(Device.valueOf(parent.getItemAtPosition(pos).toString())); + } + } + + @Override + public void onNothingSelected(AdapterView parent) { + // Do nothing. + } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraConnectionFragment.java b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraConnectionFragment.java new file mode 100644 index 0000000000000000000000000000000000000000..13e5c0dc341a86b1ddd66c4b562e0bf767641b42 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/CameraConnectionFragment.java @@ -0,0 +1,575 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.lite.examples.classification; + +import android.annotation.SuppressLint; +import android.app.Activity; +import android.app.AlertDialog; +import android.app.Dialog; +import android.app.DialogFragment; +import android.app.Fragment; +import android.content.Context; +import android.content.DialogInterface; +import android.content.res.Configuration; +import android.graphics.ImageFormat; +import android.graphics.Matrix; +import android.graphics.RectF; +import android.graphics.SurfaceTexture; +import android.hardware.camera2.CameraAccessException; +import android.hardware.camera2.CameraCaptureSession; +import android.hardware.camera2.CameraCharacteristics; +import android.hardware.camera2.CameraDevice; +import android.hardware.camera2.CameraManager; +import android.hardware.camera2.CaptureRequest; +import android.hardware.camera2.CaptureResult; +import android.hardware.camera2.TotalCaptureResult; +import android.hardware.camera2.params.StreamConfigurationMap; +import android.media.ImageReader; +import android.media.ImageReader.OnImageAvailableListener; +import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; +import android.text.TextUtils; +import android.util.Size; +import android.util.SparseIntArray; +import android.view.LayoutInflater; +import android.view.Surface; +import android.view.TextureView; +import android.view.View; +import android.view.ViewGroup; +import android.widget.Toast; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import org.tensorflow.lite.examples.classification.customview.AutoFitTextureView; +import org.tensorflow.lite.examples.classification.env.Logger; + +/** + * Camera Connection Fragment that captures images from camera. + * + *

Instantiated by newInstance.

+ */ +@SuppressWarnings("FragmentNotInstantiable") +public class CameraConnectionFragment extends Fragment { + private static final Logger LOGGER = new Logger(); + + /** + * The camera preview size will be chosen to be the smallest frame by pixel size capable of + * containing a DESIRED_SIZE x DESIRED_SIZE square. + */ + private static final int MINIMUM_PREVIEW_SIZE = 320; + + /** Conversion from screen rotation to JPEG orientation. */ + private static final SparseIntArray ORIENTATIONS = new SparseIntArray(); + + private static final String FRAGMENT_DIALOG = "dialog"; + + static { + ORIENTATIONS.append(Surface.ROTATION_0, 90); + ORIENTATIONS.append(Surface.ROTATION_90, 0); + ORIENTATIONS.append(Surface.ROTATION_180, 270); + ORIENTATIONS.append(Surface.ROTATION_270, 180); + } + + /** A {@link Semaphore} to prevent the app from exiting before closing the camera. */ + private final Semaphore cameraOpenCloseLock = new Semaphore(1); + /** A {@link OnImageAvailableListener} to receive frames as they are available. */ + private final OnImageAvailableListener imageListener; + /** The input size in pixels desired by TensorFlow (width and height of a square bitmap). */ + private final Size inputSize; + /** The layout identifier to inflate for this Fragment. */ + private final int layout; + + private final ConnectionCallback cameraConnectionCallback; + private final CameraCaptureSession.CaptureCallback captureCallback = + new CameraCaptureSession.CaptureCallback() { + @Override + public void onCaptureProgressed( + final CameraCaptureSession session, + final CaptureRequest request, + final CaptureResult partialResult) {} + + @Override + public void onCaptureCompleted( + final CameraCaptureSession session, + final CaptureRequest request, + final TotalCaptureResult result) {} + }; + /** ID of the current {@link CameraDevice}. */ + private String cameraId; + /** An {@link AutoFitTextureView} for camera preview. */ + private AutoFitTextureView textureView; + /** A {@link CameraCaptureSession } for camera preview. */ + private CameraCaptureSession captureSession; + /** A reference to the opened {@link CameraDevice}. */ + private CameraDevice cameraDevice; + /** The rotation in degrees of the camera sensor from the display. */ + private Integer sensorOrientation; + /** The {@link Size} of camera preview. */ + private Size previewSize; + /** An additional thread for running tasks that shouldn't block the UI. */ + private HandlerThread backgroundThread; + /** A {@link Handler} for running tasks in the background. */ + private Handler backgroundHandler; + /** + * {@link TextureView.SurfaceTextureListener} handles several lifecycle events on a {@link + * TextureView}. + */ + private final TextureView.SurfaceTextureListener surfaceTextureListener = + new TextureView.SurfaceTextureListener() { + @Override + public void onSurfaceTextureAvailable( + final SurfaceTexture texture, final int width, final int height) { + openCamera(width, height); + } + + @Override + public void onSurfaceTextureSizeChanged( + final SurfaceTexture texture, final int width, final int height) { + configureTransform(width, height); + } + + @Override + public boolean onSurfaceTextureDestroyed(final SurfaceTexture texture) { + return true; + } + + @Override + public void onSurfaceTextureUpdated(final SurfaceTexture texture) {} + }; + /** An {@link ImageReader} that handles preview frame capture. */ + private ImageReader previewReader; + /** {@link CaptureRequest.Builder} for the camera preview */ + private CaptureRequest.Builder previewRequestBuilder; + /** {@link CaptureRequest} generated by {@link #previewRequestBuilder} */ + private CaptureRequest previewRequest; + /** {@link CameraDevice.StateCallback} is called when {@link CameraDevice} changes its state. */ + private final CameraDevice.StateCallback stateCallback = + new CameraDevice.StateCallback() { + @Override + public void onOpened(final CameraDevice cd) { + // This method is called when the camera is opened. We start camera preview here. + cameraOpenCloseLock.release(); + cameraDevice = cd; + createCameraPreviewSession(); + } + + @Override + public void onDisconnected(final CameraDevice cd) { + cameraOpenCloseLock.release(); + cd.close(); + cameraDevice = null; + } + + @Override + public void onError(final CameraDevice cd, final int error) { + cameraOpenCloseLock.release(); + cd.close(); + cameraDevice = null; + final Activity activity = getActivity(); + if (null != activity) { + activity.finish(); + } + } + }; + + @SuppressLint("ValidFragment") + private CameraConnectionFragment( + final ConnectionCallback connectionCallback, + final OnImageAvailableListener imageListener, + final int layout, + final Size inputSize) { + this.cameraConnectionCallback = connectionCallback; + this.imageListener = imageListener; + this.layout = layout; + this.inputSize = inputSize; + } + + /** + * Given {@code choices} of {@code Size}s supported by a camera, chooses the smallest one whose + * width and height are at least as large as the minimum of both, or an exact match if possible. + * + * @param choices The list of sizes that the camera supports for the intended output class + * @param width The minimum desired width + * @param height The minimum desired height + * @return The optimal {@code Size}, or an arbitrary one if none were big enough + */ + protected static Size chooseOptimalSize(final Size[] choices, final int width, final int height) { + final int minSize = Math.max(Math.min(width, height), MINIMUM_PREVIEW_SIZE); + final Size desiredSize = new Size(width, height); + + // Collect the supported resolutions that are at least as big as the preview Surface + boolean exactSizeFound = false; + final List bigEnough = new ArrayList(); + final List tooSmall = new ArrayList(); + for (final Size option : choices) { + if (option.equals(desiredSize)) { + // Set the size but don't return yet so that remaining sizes will still be logged. + exactSizeFound = true; + } + + if (option.getHeight() >= minSize && option.getWidth() >= minSize) { + bigEnough.add(option); + } else { + tooSmall.add(option); + } + } + + LOGGER.i("Desired size: " + desiredSize + ", min size: " + minSize + "x" + minSize); + LOGGER.i("Valid preview sizes: [" + TextUtils.join(", ", bigEnough) + "]"); + LOGGER.i("Rejected preview sizes: [" + TextUtils.join(", ", tooSmall) + "]"); + + if (exactSizeFound) { + LOGGER.i("Exact size match found."); + return desiredSize; + } + + // Pick the smallest of those, assuming we found any + if (bigEnough.size() > 0) { + final Size chosenSize = Collections.min(bigEnough, new CompareSizesByArea()); + LOGGER.i("Chosen size: " + chosenSize.getWidth() + "x" + chosenSize.getHeight()); + return chosenSize; + } else { + LOGGER.e("Couldn't find any suitable preview size"); + return choices[0]; + } + } + + public static CameraConnectionFragment newInstance( + final ConnectionCallback callback, + final OnImageAvailableListener imageListener, + final int layout, + final Size inputSize) { + return new CameraConnectionFragment(callback, imageListener, layout, inputSize); + } + + /** + * Shows a {@link Toast} on the UI thread. + * + * @param text The message to show + */ + private void showToast(final String text) { + final Activity activity = getActivity(); + if (activity != null) { + activity.runOnUiThread( + new Runnable() { + @Override + public void run() { + Toast.makeText(activity, text, Toast.LENGTH_SHORT).show(); + } + }); + } + } + + @Override + public View onCreateView( + final LayoutInflater inflater, final ViewGroup container, final Bundle savedInstanceState) { + return inflater.inflate(layout, container, false); + } + + @Override + public void onViewCreated(final View view, final Bundle savedInstanceState) { + textureView = (AutoFitTextureView) view.findViewById(R.id.texture); + } + + @Override + public void onActivityCreated(final Bundle savedInstanceState) { + super.onActivityCreated(savedInstanceState); + } + + @Override + public void onResume() { + super.onResume(); + startBackgroundThread(); + + // When the screen is turned off and turned back on, the SurfaceTexture is already + // available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open + // a camera and start preview from here (otherwise, we wait until the surface is ready in + // the SurfaceTextureListener). + if (textureView.isAvailable()) { + openCamera(textureView.getWidth(), textureView.getHeight()); + } else { + textureView.setSurfaceTextureListener(surfaceTextureListener); + } + } + + @Override + public void onPause() { + closeCamera(); + stopBackgroundThread(); + super.onPause(); + } + + public void setCamera(String cameraId) { + this.cameraId = cameraId; + } + + /** Sets up member variables related to camera. */ + private void setUpCameraOutputs() { + final Activity activity = getActivity(); + final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE); + try { + final CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId); + + final StreamConfigurationMap map = + characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP); + + sensorOrientation = characteristics.get(CameraCharacteristics.SENSOR_ORIENTATION); + + // Danger, W.R.! Attempting to use too large a preview size could exceed the camera + // bus' bandwidth limitation, resulting in gorgeous previews but the storage of + // garbage capture data. + previewSize = + chooseOptimalSize( + map.getOutputSizes(SurfaceTexture.class), + inputSize.getWidth(), + inputSize.getHeight()); + + // We fit the aspect ratio of TextureView to the size of preview we picked. + final int orientation = getResources().getConfiguration().orientation; + if (orientation == Configuration.ORIENTATION_LANDSCAPE) { + textureView.setAspectRatio(previewSize.getWidth(), previewSize.getHeight()); + } else { + textureView.setAspectRatio(previewSize.getHeight(), previewSize.getWidth()); + } + } catch (final CameraAccessException e) { + LOGGER.e(e, "Exception!"); + } catch (final NullPointerException e) { + // Currently an NPE is thrown when the Camera2API is used but not supported on the + // device this code runs. + ErrorDialog.newInstance(getString(R.string.tfe_ic_camera_error)) + .show(getChildFragmentManager(), FRAGMENT_DIALOG); + throw new IllegalStateException(getString(R.string.tfe_ic_camera_error)); + } + + cameraConnectionCallback.onPreviewSizeChosen(previewSize, sensorOrientation); + } + + /** Opens the camera specified by {@link CameraConnectionFragment#cameraId}. */ + private void openCamera(final int width, final int height) { + setUpCameraOutputs(); + configureTransform(width, height); + final Activity activity = getActivity(); + final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE); + try { + if (!cameraOpenCloseLock.tryAcquire(2500, TimeUnit.MILLISECONDS)) { + throw new RuntimeException("Time out waiting to lock camera opening."); + } + manager.openCamera(cameraId, stateCallback, backgroundHandler); + } catch (final CameraAccessException e) { + LOGGER.e(e, "Exception!"); + } catch (final InterruptedException e) { + throw new RuntimeException("Interrupted while trying to lock camera opening.", e); + } + } + + /** Closes the current {@link CameraDevice}. */ + private void closeCamera() { + try { + cameraOpenCloseLock.acquire(); + if (null != captureSession) { + captureSession.close(); + captureSession = null; + } + if (null != cameraDevice) { + cameraDevice.close(); + cameraDevice = null; + } + if (null != previewReader) { + previewReader.close(); + previewReader = null; + } + } catch (final InterruptedException e) { + throw new RuntimeException("Interrupted while trying to lock camera closing.", e); + } finally { + cameraOpenCloseLock.release(); + } + } + + /** Starts a background thread and its {@link Handler}. */ + private void startBackgroundThread() { + backgroundThread = new HandlerThread("ImageListener"); + backgroundThread.start(); + backgroundHandler = new Handler(backgroundThread.getLooper()); + } + + /** Stops the background thread and its {@link Handler}. */ + private void stopBackgroundThread() { + backgroundThread.quitSafely(); + try { + backgroundThread.join(); + backgroundThread = null; + backgroundHandler = null; + } catch (final InterruptedException e) { + LOGGER.e(e, "Exception!"); + } + } + + /** Creates a new {@link CameraCaptureSession} for camera preview. */ + private void createCameraPreviewSession() { + try { + final SurfaceTexture texture = textureView.getSurfaceTexture(); + assert texture != null; + + // We configure the size of default buffer to be the size of camera preview we want. + texture.setDefaultBufferSize(previewSize.getWidth(), previewSize.getHeight()); + + // This is the output Surface we need to start preview. + final Surface surface = new Surface(texture); + + // We set up a CaptureRequest.Builder with the output Surface. + previewRequestBuilder = cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW); + previewRequestBuilder.addTarget(surface); + + LOGGER.i("Opening camera preview: " + previewSize.getWidth() + "x" + previewSize.getHeight()); + + // Create the reader for the preview frames. + previewReader = + ImageReader.newInstance( + previewSize.getWidth(), previewSize.getHeight(), ImageFormat.YUV_420_888, 2); + + previewReader.setOnImageAvailableListener(imageListener, backgroundHandler); + previewRequestBuilder.addTarget(previewReader.getSurface()); + + // Here, we create a CameraCaptureSession for camera preview. + cameraDevice.createCaptureSession( + Arrays.asList(surface, previewReader.getSurface()), + new CameraCaptureSession.StateCallback() { + + @Override + public void onConfigured(final CameraCaptureSession cameraCaptureSession) { + // The camera is already closed + if (null == cameraDevice) { + return; + } + + // When the session is ready, we start displaying the preview. + captureSession = cameraCaptureSession; + try { + // Auto focus should be continuous for camera preview. + previewRequestBuilder.set( + CaptureRequest.CONTROL_AF_MODE, + CaptureRequest.CONTROL_AF_MODE_CONTINUOUS_PICTURE); + // Flash is automatically enabled when necessary. + previewRequestBuilder.set( + CaptureRequest.CONTROL_AE_MODE, CaptureRequest.CONTROL_AE_MODE_ON_AUTO_FLASH); + + // Finally, we start displaying the camera preview. + previewRequest = previewRequestBuilder.build(); + captureSession.setRepeatingRequest( + previewRequest, captureCallback, backgroundHandler); + } catch (final CameraAccessException e) { + LOGGER.e(e, "Exception!"); + } + } + + @Override + public void onConfigureFailed(final CameraCaptureSession cameraCaptureSession) { + showToast("Failed"); + } + }, + null); + } catch (final CameraAccessException e) { + LOGGER.e(e, "Exception!"); + } + } + + /** + * Configures the necessary {@link Matrix} transformation to `mTextureView`. This method should be + * called after the camera preview size is determined in setUpCameraOutputs and also the size of + * `mTextureView` is fixed. + * + * @param viewWidth The width of `mTextureView` + * @param viewHeight The height of `mTextureView` + */ + private void configureTransform(final int viewWidth, final int viewHeight) { + final Activity activity = getActivity(); + if (null == textureView || null == previewSize || null == activity) { + return; + } + final int rotation = activity.getWindowManager().getDefaultDisplay().getRotation(); + final Matrix matrix = new Matrix(); + final RectF viewRect = new RectF(0, 0, viewWidth, viewHeight); + final RectF bufferRect = new RectF(0, 0, previewSize.getHeight(), previewSize.getWidth()); + final float centerX = viewRect.centerX(); + final float centerY = viewRect.centerY(); + if (Surface.ROTATION_90 == rotation || Surface.ROTATION_270 == rotation) { + bufferRect.offset(centerX - bufferRect.centerX(), centerY - bufferRect.centerY()); + matrix.setRectToRect(viewRect, bufferRect, Matrix.ScaleToFit.FILL); + final float scale = + Math.max( + (float) viewHeight / previewSize.getHeight(), + (float) viewWidth / previewSize.getWidth()); + matrix.postScale(scale, scale, centerX, centerY); + matrix.postRotate(90 * (rotation - 2), centerX, centerY); + } else if (Surface.ROTATION_180 == rotation) { + matrix.postRotate(180, centerX, centerY); + } + textureView.setTransform(matrix); + } + + /** + * Callback for Activities to use to initialize their data once the selected preview size is + * known. + */ + public interface ConnectionCallback { + void onPreviewSizeChosen(Size size, int cameraRotation); + } + + /** Compares two {@code Size}s based on their areas. */ + static class CompareSizesByArea implements Comparator { + @Override + public int compare(final Size lhs, final Size rhs) { + // We cast here to ensure the multiplications won't overflow + return Long.signum( + (long) lhs.getWidth() * lhs.getHeight() - (long) rhs.getWidth() * rhs.getHeight()); + } + } + + /** Shows an error message dialog. */ + public static class ErrorDialog extends DialogFragment { + private static final String ARG_MESSAGE = "message"; + + public static ErrorDialog newInstance(final String message) { + final ErrorDialog dialog = new ErrorDialog(); + final Bundle args = new Bundle(); + args.putString(ARG_MESSAGE, message); + dialog.setArguments(args); + return dialog; + } + + @Override + public Dialog onCreateDialog(final Bundle savedInstanceState) { + final Activity activity = getActivity(); + return new AlertDialog.Builder(activity) + .setMessage(getArguments().getString(ARG_MESSAGE)) + .setPositiveButton( + android.R.string.ok, + new DialogInterface.OnClickListener() { + @Override + public void onClick(final DialogInterface dialogInterface, final int i) { + activity.finish(); + } + }) + .create(); + } + } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java new file mode 100644 index 0000000000000000000000000000000000000000..24b5d72fdb42d47e5d2c87e3f944b71105748c1b --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/ClassifierActivity.java @@ -0,0 +1,238 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.lite.examples.classification; + +import android.graphics.Bitmap; +import android.graphics.Bitmap.Config; +import android.graphics.Typeface; +import android.media.ImageReader.OnImageAvailableListener; +import android.os.SystemClock; +import android.util.Size; +import android.util.TypedValue; +import android.view.TextureView; +import android.view.ViewStub; +import android.widget.TextView; +import android.widget.Toast; +import java.io.IOException; +import java.util.List; +import java.util.ArrayList; + +import org.tensorflow.lite.examples.classification.customview.AutoFitTextureView; +import org.tensorflow.lite.examples.classification.env.BorderedText; +import org.tensorflow.lite.examples.classification.env.Logger; +import org.tensorflow.lite.examples.classification.tflite.Classifier; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Model; + +import android.widget.ImageView; +import android.graphics.Bitmap; +import android.graphics.BitmapFactory; +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Paint; +import android.graphics.Rect; +import android.graphics.RectF; +import android.graphics.PixelFormat; +import java.nio.ByteBuffer; + +public class ClassifierActivity extends CameraActivity implements OnImageAvailableListener { + private static final Logger LOGGER = new Logger(); + private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480); + private static final float TEXT_SIZE_DIP = 10; + private Bitmap rgbFrameBitmap = null; + private long lastProcessingTimeMs; + private Integer sensorOrientation; + private Classifier classifier; + private BorderedText borderedText; + /** Input image size of the model along x axis. */ + private int imageSizeX; + /** Input image size of the model along y axis. */ + private int imageSizeY; + + @Override + protected int getLayoutId() { + return R.layout.tfe_ic_camera_connection_fragment; + } + + @Override + protected Size getDesiredPreviewFrameSize() { + return DESIRED_PREVIEW_SIZE; + } + + @Override + public void onPreviewSizeChosen(final Size size, final int rotation) { + final float textSizePx = + TypedValue.applyDimension( + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); + borderedText = new BorderedText(textSizePx); + borderedText.setTypeface(Typeface.MONOSPACE); + + recreateClassifier(getModel(), getDevice(), getNumThreads()); + if (classifier == null) { + LOGGER.e("No classifier on preview!"); + return; + } + + previewWidth = size.getWidth(); + previewHeight = size.getHeight(); + + sensorOrientation = rotation - getScreenOrientation(); + LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation); + + LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight); + rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888); + } + + @Override + protected void processImage() { + rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight); + final int cropSize = Math.min(previewWidth, previewHeight); + + runInBackground( + new Runnable() { + @Override + public void run() { + if (classifier != null) { + final long startTime = SystemClock.uptimeMillis(); + //final List results = + // classifier.recognizeImage(rgbFrameBitmap, sensorOrientation); + final List results = new ArrayList<>(); + + float[] img_array = classifier.recognizeImage(rgbFrameBitmap, sensorOrientation); + + + /* + float maxval = Float.NEGATIVE_INFINITY; + float minval = Float.POSITIVE_INFINITY; + for (float cur : img_array) { + maxval = Math.max(maxval, cur); + minval = Math.min(minval, cur); + } + float multiplier = 0; + if ((maxval - minval) > 0) multiplier = 255 / (maxval - minval); + + int[] img_normalized = new int[img_array.length]; + for (int i = 0; i < img_array.length; ++i) { + float val = (float) (multiplier * (img_array[i] - minval)); + img_normalized[i] = (int) val; + } + + + + TextureView textureView = findViewById(R.id.textureView3); + //AutoFitTextureView textureView = (AutoFitTextureView) findViewById(R.id.texture); + + if(textureView.isAvailable()) { + int width = imageSizeX; + int height = imageSizeY; + + Canvas canvas = textureView.lockCanvas(); + canvas.drawColor(Color.BLUE); + Paint paint = new Paint(); + paint.setStyle(Paint.Style.FILL); + paint.setARGB(255, 150, 150, 150); + + int canvas_size = Math.min(canvas.getWidth(), canvas.getHeight()); + + Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.RGB_565); + + for (int ii = 0; ii < width; ii++) //pass the screen pixels in 2 directions + { + for (int jj = 0; jj < height; jj++) { + //int val = img_normalized[ii + jj * width]; + int index = (width - ii - 1) + (height - jj - 1) * width; + if(index < img_array.length) { + int val = img_normalized[index]; + bitmap.setPixel(ii, jj, Color.rgb(val, val, val)); + } + } + } + + canvas.drawBitmap(bitmap, null, new RectF(0, 0, canvas_size, canvas_size), null); + + textureView.unlockCanvasAndPost(canvas); + + } + */ + + lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime; + LOGGER.v("Detect: %s", results); + + runOnUiThread( + new Runnable() { + @Override + public void run() { + //showResultsInBottomSheet(results); + showResultsInTexture(img_array, imageSizeX, imageSizeY); + showFrameInfo(previewWidth + "x" + previewHeight); + showCropInfo(imageSizeX + "x" + imageSizeY); + showCameraResolution(cropSize + "x" + cropSize); + showRotationInfo(String.valueOf(sensorOrientation)); + showInference(lastProcessingTimeMs + "ms"); + } + }); + } + readyForNextImage(); + } + }); + } + + @Override + protected void onInferenceConfigurationChanged() { + if (rgbFrameBitmap == null) { + // Defer creation until we're getting camera frames. + return; + } + final Device device = getDevice(); + final Model model = getModel(); + final int numThreads = getNumThreads(); + runInBackground(() -> recreateClassifier(model, device, numThreads)); + } + + private void recreateClassifier(Model model, Device device, int numThreads) { + if (classifier != null) { + LOGGER.d("Closing classifier."); + classifier.close(); + classifier = null; + } + if (device == Device.GPU + && (model == Model.QUANTIZED_MOBILENET || model == Model.QUANTIZED_EFFICIENTNET)) { + LOGGER.d("Not creating classifier: GPU doesn't support quantized models."); + runOnUiThread( + () -> { + Toast.makeText(this, R.string.tfe_ic_gpu_quant_error, Toast.LENGTH_LONG).show(); + }); + return; + } + try { + LOGGER.d( + "Creating classifier (model=%s, device=%s, numThreads=%d)", model, device, numThreads); + classifier = Classifier.create(this, model, device, numThreads); + } catch (IOException | IllegalArgumentException e) { + LOGGER.e(e, "Failed to create classifier."); + runOnUiThread( + () -> { + Toast.makeText(this, e.getMessage(), Toast.LENGTH_LONG).show(); + }); + return; + } + + // Updates the input image size. + imageSizeX = classifier.getImageSizeX(); + imageSizeY = classifier.getImageSizeY(); + } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/LegacyCameraConnectionFragment.java b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/LegacyCameraConnectionFragment.java new file mode 100644 index 0000000000000000000000000000000000000000..760fe90375450c7b1356603c83fb37a68548ca13 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/LegacyCameraConnectionFragment.java @@ -0,0 +1,203 @@ +package org.tensorflow.lite.examples.classification; + +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import android.annotation.SuppressLint; +import android.app.Fragment; +import android.graphics.SurfaceTexture; +import android.hardware.Camera; +import android.hardware.Camera.CameraInfo; +import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; +import android.util.Size; +import android.util.SparseIntArray; +import android.view.LayoutInflater; +import android.view.Surface; +import android.view.TextureView; +import android.view.View; +import android.view.ViewGroup; +import java.io.IOException; +import java.util.List; +import org.tensorflow.lite.examples.classification.customview.AutoFitTextureView; +import org.tensorflow.lite.examples.classification.env.ImageUtils; +import org.tensorflow.lite.examples.classification.env.Logger; + +public class LegacyCameraConnectionFragment extends Fragment { + private static final Logger LOGGER = new Logger(); + /** Conversion from screen rotation to JPEG orientation. */ + private static final SparseIntArray ORIENTATIONS = new SparseIntArray(); + + static { + ORIENTATIONS.append(Surface.ROTATION_0, 90); + ORIENTATIONS.append(Surface.ROTATION_90, 0); + ORIENTATIONS.append(Surface.ROTATION_180, 270); + ORIENTATIONS.append(Surface.ROTATION_270, 180); + } + + private Camera camera; + private Camera.PreviewCallback imageListener; + private Size desiredSize; + /** The layout identifier to inflate for this Fragment. */ + private int layout; + /** An {@link AutoFitTextureView} for camera preview. */ + private AutoFitTextureView textureView; + /** + * {@link TextureView.SurfaceTextureListener} handles several lifecycle events on a {@link + * TextureView}. + */ + private final TextureView.SurfaceTextureListener surfaceTextureListener = + new TextureView.SurfaceTextureListener() { + @Override + public void onSurfaceTextureAvailable( + final SurfaceTexture texture, final int width, final int height) { + + int index = getCameraId(); + camera = Camera.open(index); + + try { + Camera.Parameters parameters = camera.getParameters(); + List focusModes = parameters.getSupportedFocusModes(); + if (focusModes != null + && focusModes.contains(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE)) { + parameters.setFocusMode(Camera.Parameters.FOCUS_MODE_CONTINUOUS_PICTURE); + } + List cameraSizes = parameters.getSupportedPreviewSizes(); + Size[] sizes = new Size[cameraSizes.size()]; + int i = 0; + for (Camera.Size size : cameraSizes) { + sizes[i++] = new Size(size.width, size.height); + } + Size previewSize = + CameraConnectionFragment.chooseOptimalSize( + sizes, desiredSize.getWidth(), desiredSize.getHeight()); + parameters.setPreviewSize(previewSize.getWidth(), previewSize.getHeight()); + camera.setDisplayOrientation(90); + camera.setParameters(parameters); + camera.setPreviewTexture(texture); + } catch (IOException exception) { + camera.release(); + } + + camera.setPreviewCallbackWithBuffer(imageListener); + Camera.Size s = camera.getParameters().getPreviewSize(); + camera.addCallbackBuffer(new byte[ImageUtils.getYUVByteSize(s.height, s.width)]); + + textureView.setAspectRatio(s.height, s.width); + + camera.startPreview(); + } + + @Override + public void onSurfaceTextureSizeChanged( + final SurfaceTexture texture, final int width, final int height) {} + + @Override + public boolean onSurfaceTextureDestroyed(final SurfaceTexture texture) { + return true; + } + + @Override + public void onSurfaceTextureUpdated(final SurfaceTexture texture) {} + }; + /** An additional thread for running tasks that shouldn't block the UI. */ + private HandlerThread backgroundThread; + + @SuppressLint("ValidFragment") + public LegacyCameraConnectionFragment( + final Camera.PreviewCallback imageListener, final int layout, final Size desiredSize) { + this.imageListener = imageListener; + this.layout = layout; + this.desiredSize = desiredSize; + } + + @Override + public View onCreateView( + final LayoutInflater inflater, final ViewGroup container, final Bundle savedInstanceState) { + return inflater.inflate(layout, container, false); + } + + @Override + public void onViewCreated(final View view, final Bundle savedInstanceState) { + textureView = (AutoFitTextureView) view.findViewById(R.id.texture); + } + + @Override + public void onActivityCreated(final Bundle savedInstanceState) { + super.onActivityCreated(savedInstanceState); + } + + @Override + public void onResume() { + super.onResume(); + startBackgroundThread(); + // When the screen is turned off and turned back on, the SurfaceTexture is already + // available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open + // a camera and start preview from here (otherwise, we wait until the surface is ready in + // the SurfaceTextureListener). + + if (textureView.isAvailable()) { + if (camera != null) { + camera.startPreview(); + } + } else { + textureView.setSurfaceTextureListener(surfaceTextureListener); + } + } + + @Override + public void onPause() { + stopCamera(); + stopBackgroundThread(); + super.onPause(); + } + + /** Starts a background thread and its {@link Handler}. */ + private void startBackgroundThread() { + backgroundThread = new HandlerThread("CameraBackground"); + backgroundThread.start(); + } + + /** Stops the background thread and its {@link Handler}. */ + private void stopBackgroundThread() { + backgroundThread.quitSafely(); + try { + backgroundThread.join(); + backgroundThread = null; + } catch (final InterruptedException e) { + LOGGER.e(e, "Exception!"); + } + } + + protected void stopCamera() { + if (camera != null) { + camera.stopPreview(); + camera.setPreviewCallback(null); + camera.release(); + camera = null; + } + } + + private int getCameraId() { + CameraInfo ci = new CameraInfo(); + for (int i = 0; i < Camera.getNumberOfCameras(); i++) { + Camera.getCameraInfo(i, ci); + if (ci.facing == CameraInfo.CAMERA_FACING_BACK) return i; + } + return -1; // No camera found + } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/AutoFitTextureView.java b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/AutoFitTextureView.java new file mode 100644 index 0000000000000000000000000000000000000000..62e99ae70c2a7c4c60a776e7490742c5339e85f3 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/AutoFitTextureView.java @@ -0,0 +1,72 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.lite.examples.classification.customview; + +import android.content.Context; +import android.util.AttributeSet; +import android.view.TextureView; + +/** A {@link TextureView} that can be adjusted to a specified aspect ratio. */ +public class AutoFitTextureView extends TextureView { + private int ratioWidth = 0; + private int ratioHeight = 0; + + public AutoFitTextureView(final Context context) { + this(context, null); + } + + public AutoFitTextureView(final Context context, final AttributeSet attrs) { + this(context, attrs, 0); + } + + public AutoFitTextureView(final Context context, final AttributeSet attrs, final int defStyle) { + super(context, attrs, defStyle); + } + + /** + * Sets the aspect ratio for this view. The size of the view will be measured based on the ratio + * calculated from the parameters. Note that the actual sizes of parameters don't matter, that is, + * calling setAspectRatio(2, 3) and setAspectRatio(4, 6) make the same result. + * + * @param width Relative horizontal size + * @param height Relative vertical size + */ + public void setAspectRatio(final int width, final int height) { + if (width < 0 || height < 0) { + throw new IllegalArgumentException("Size cannot be negative."); + } + ratioWidth = width; + ratioHeight = height; + requestLayout(); + } + + @Override + protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) { + super.onMeasure(widthMeasureSpec, heightMeasureSpec); + final int width = MeasureSpec.getSize(widthMeasureSpec); + final int height = MeasureSpec.getSize(heightMeasureSpec); + if (0 == ratioWidth || 0 == ratioHeight) { + setMeasuredDimension(width, height); + } else { + if (width < height * ratioWidth / ratioHeight) { + setMeasuredDimension(width, width * ratioHeight / ratioWidth); + } else { + setMeasuredDimension(height * ratioWidth / ratioHeight, height); + } + } + } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/OverlayView.java b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/OverlayView.java new file mode 100644 index 0000000000000000000000000000000000000000..dc302ac04f145c9a1673a2d7e630a94a05ab1b1a --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/OverlayView.java @@ -0,0 +1,48 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.customview; + +import android.content.Context; +import android.graphics.Canvas; +import android.util.AttributeSet; +import android.view.View; +import java.util.LinkedList; +import java.util.List; + +/** A simple View providing a render callback to other classes. */ +public class OverlayView extends View { + private final List callbacks = new LinkedList(); + + public OverlayView(final Context context, final AttributeSet attrs) { + super(context, attrs); + } + + public void addCallback(final DrawCallback callback) { + callbacks.add(callback); + } + + @Override + public synchronized void draw(final Canvas canvas) { + for (final DrawCallback callback : callbacks) { + callback.drawCallback(canvas); + } + } + + /** Interface defining the callback for client classes. */ + public interface DrawCallback { + public void drawCallback(final Canvas canvas); + } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/RecognitionScoreView.java b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/RecognitionScoreView.java new file mode 100644 index 0000000000000000000000000000000000000000..2c57f603f12200079c888793cfa40d9b10dabde3 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/RecognitionScoreView.java @@ -0,0 +1,67 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.customview; + +import android.content.Context; +import android.graphics.Canvas; +import android.graphics.Paint; +import android.util.AttributeSet; +import android.util.TypedValue; +import android.view.View; +import java.util.List; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Recognition; + +public class RecognitionScoreView extends View implements ResultsView { + private static final float TEXT_SIZE_DIP = 16; + private final float textSizePx; + private final Paint fgPaint; + private final Paint bgPaint; + private List results; + + public RecognitionScoreView(final Context context, final AttributeSet set) { + super(context, set); + + textSizePx = + TypedValue.applyDimension( + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); + fgPaint = new Paint(); + fgPaint.setTextSize(textSizePx); + + bgPaint = new Paint(); + bgPaint.setColor(0xcc4285f4); + } + + @Override + public void setResults(final List results) { + this.results = results; + postInvalidate(); + } + + @Override + public void onDraw(final Canvas canvas) { + final int x = 10; + int y = (int) (fgPaint.getTextSize() * 1.5f); + + canvas.drawPaint(bgPaint); + + if (results != null) { + for (final Recognition recog : results) { + canvas.drawText(recog.getTitle() + ": " + recog.getConfidence(), x, y, fgPaint); + y += (int) (fgPaint.getTextSize() * 1.5f); + } + } + } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/ResultsView.java b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/ResultsView.java new file mode 100644 index 0000000000000000000000000000000000000000..d055eb5f161a57fc439716efe6d49b7e45ef3fc7 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/ResultsView.java @@ -0,0 +1,23 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.customview; + +import java.util.List; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Recognition; + +public interface ResultsView { + public void setResults(final List results); +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable-v24/ic_launcher_foreground.xml b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable-v24/ic_launcher_foreground.xml new file mode 100644 index 0000000000000000000000000000000000000000..b1517edf496ef5800b97d046b92012a9f94a34d0 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable-v24/ic_launcher_foreground.xml @@ -0,0 +1,34 @@ + + + + + + + + + + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/bottom_sheet_bg.xml b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/bottom_sheet_bg.xml new file mode 100644 index 0000000000000000000000000000000000000000..70f4b24e35039e6bfc35989bcbe570a4bdc2ae07 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/bottom_sheet_bg.xml @@ -0,0 +1,9 @@ + + + + + + \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_add.xml b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_add.xml new file mode 100644 index 0000000000000000000000000000000000000000..757f4503314fb9e5837f68ac515f4487d9b5fc2c --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_add.xml @@ -0,0 +1,9 @@ + + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_remove.xml b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_remove.xml new file mode 100644 index 0000000000000000000000000000000000000000..a64b853e79137f0fd95f9d5fa6e0552cc255c7ae --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_baseline_remove.xml @@ -0,0 +1,9 @@ + + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_launcher_background.xml b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_launcher_background.xml new file mode 100644 index 0000000000000000000000000000000000000000..d5fccc538c179838bfdce779c26eebb4fa0b5ce9 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/ic_launcher_background.xml @@ -0,0 +1,170 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/rectangle.xml b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/rectangle.xml new file mode 100644 index 0000000000000000000000000000000000000000..b8f5d3559c4e83072d5d73a3241d240aa68daccf --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/drawable/rectangle.xml @@ -0,0 +1,13 @@ + + + + + + \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_activity_camera.xml b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_activity_camera.xml new file mode 100644 index 0000000000000000000000000000000000000000..f0e1dae7afa15cf4a832de708f345482a6dfeff6 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_activity_camera.xml @@ -0,0 +1,56 @@ + + + + + + + + + + + + + + + + + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_camera_connection_fragment.xml b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_camera_connection_fragment.xml new file mode 100644 index 0000000000000000000000000000000000000000..97e5e7c6df25da48977f9064a888fd3735e4986f --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_camera_connection_fragment.xml @@ -0,0 +1,32 @@ + + + + + + + + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_layout_bottom_sheet.xml b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_layout_bottom_sheet.xml new file mode 100644 index 0000000000000000000000000000000000000000..77a348af90e2ed995ff106cd209cbf304c6b9153 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/layout/tfe_ic_layout_bottom_sheet.xml @@ -0,0 +1,321 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml new file mode 100644 index 0000000000000000000000000000000000000000..0c2a915e91af65a077d2e01db4ca21acd42906f3 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml new file mode 100644 index 0000000000000000000000000000000000000000..0c2a915e91af65a077d2e01db4ca21acd42906f3 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/colors.xml b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/colors.xml new file mode 100644 index 0000000000000000000000000000000000000000..ed82bafb536474c6a88c996b439a2781f31f3d3e --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/colors.xml @@ -0,0 +1,8 @@ + + + #ffa800 + #ff6f00 + #425066 + + #66000000 + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/dimens.xml b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/dimens.xml new file mode 100644 index 0000000000000000000000000000000000000000..5d3609029ca66b612c88b4f395e4e2e3cfc1f0e6 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/dimens.xml @@ -0,0 +1,5 @@ + + + 15dp + 8dp + \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/strings.xml b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/strings.xml new file mode 100644 index 0000000000000000000000000000000000000000..7d763d85efc49879c8d3c0641484f5f472bfaca0 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/strings.xml @@ -0,0 +1,21 @@ + + Midas + This device doesn\'t support Camera2 API. + GPU does not yet supported quantized models. + Model: + + Float_EfficientNet + + + + Device: + + GPU + CPU + NNAPI + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/styles.xml b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/styles.xml new file mode 100644 index 0000000000000000000000000000000000000000..ad09a13ec6b2de8920a7441c9992f3cc0eedcfda --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/app/src/main/res/values/styles.xml @@ -0,0 +1,11 @@ + + + + + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/build.gradle b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..14492756847191ca3beff4c2e012d378c4e44be6 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/build.gradle @@ -0,0 +1,27 @@ +// Top-level build file where you can add configuration options common to all sub-projects/modules. + +buildscript { + + repositories { + google() + jcenter() + } + dependencies { + classpath 'com.android.tools.build:gradle:4.0.0' + classpath 'de.undercouch:gradle-download-task:4.0.2' + // NOTE: Do not place your application dependencies here; they belong + // in the individual module build.gradle files + } +} + +allprojects { + repositories { + google() + jcenter() + } +} + +task clean(type: Delete) { + delete rootProject.buildDir +} + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle.properties b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle.properties new file mode 100644 index 0000000000000000000000000000000000000000..9592636c07d9d5e6f61c0cfce1311d3e1ffcf34d --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle.properties @@ -0,0 +1,15 @@ +# Project-wide Gradle settings. +# IDE (e.g. Android Studio) users: +# Gradle settings configured through the IDE *will override* +# any settings specified in this file. +# For more details on how to configure your build environment visit +# http://www.gradle.org/docs/current/userguide/build_environment.html +# Specifies the JVM arguments used for the daemon process. +# The setting is particularly useful for tweaking memory settings. +org.gradle.jvmargs=-Xmx1536m +# When configured, Gradle will run in incubating parallel mode. +# This option should only be used with decoupled projects. More details, visit +# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects +# org.gradle.parallel=true +android.useAndroidX=true +android.enableJetifier=true diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.jar b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 0000000000000000000000000000000000000000..f3d88b1c2faf2fc91d853cd5d4242b5547257070 Binary files /dev/null and b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.jar differ diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.properties b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000000000000000000000000000000000000..1b16c34a71cf212ed0cfb883d14d1b8511903eb2 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,5 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-6.1.1-bin.zip +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew new file mode 100644 index 0000000000000000000000000000000000000000..2fe81a7d95e4f9ad2c9b2a046707d36ceb3980b3 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew @@ -0,0 +1,183 @@ +#!/usr/bin/env sh + +# +# Copyright 2015 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD="maximum" + +warn () { + echo "$*" +} + +die () { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; + NONSTOP* ) + nonstop=true + ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? -eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? -ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin or MSYS, switch paths to Windows format before running java +if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + JAVACMD=`cygpath --unix "$JAVACMD"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=`expr $i + 1` + done + case $i in + 0) set -- ;; + 1) set -- "$args0" ;; + 2) set -- "$args0" "$args1" ;; + 3) set -- "$args0" "$args1" "$args2" ;; + 4) set -- "$args0" "$args1" "$args2" "$args3" ;; + 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + esac +fi + +# Escape application args +save () { + for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done + echo " " +} +APP_ARGS=`save "$@"` + +# Collect all arguments for the java command, following the shell quoting and substitution rules +eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" + +exec "$JAVACMD" "$@" diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew.bat b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew.bat new file mode 100644 index 0000000000000000000000000000000000000000..9618d8d9607cd91a0efb866bcac4810064ba6fac --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/gradlew.bat @@ -0,0 +1,100 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto init + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto init + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:init +@rem Get command-line arguments, handling Windows variants + +if not "%OS%" == "Windows_NT" goto win9xME_args + +:win9xME_args +@rem Slurp the command line arguments. +set CMD_LINE_ARGS= +set _SKIP=2 + +:win9xME_args_slurp +if "x%~1" == "x" goto execute + +set CMD_LINE_ARGS=%* + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/build.gradle b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..5d463975293264765a941795601cddb6cfc84f00 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/build.gradle @@ -0,0 +1,47 @@ +apply plugin: 'com.android.library' + +android { + compileSdkVersion 28 + buildToolsVersion "28.0.0" + + defaultConfig { + minSdkVersion 21 + targetSdkVersion 28 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + + aaptOptions { + noCompress "tflite" + } + + lintOptions { + checkReleaseBuilds false + // Or, if you prefer, you can continue to check for errors in release builds, + // but continue the build even when errors are found: + abortOnError false + } +} + +dependencies { + implementation fileTree(dir: 'libs', include: ['*.jar']) + implementation project(":models") + implementation 'androidx.appcompat:appcompat:1.1.0' + + // Build off of nightly TensorFlow Lite + implementation('org.tensorflow:tensorflow-lite:0.0.0-nightly') { changing = true } + implementation('org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly') { changing = true } + implementation('org.tensorflow:tensorflow-lite-support:0.0.0-nightly') { changing = true } + // Use local TensorFlow library + // implementation 'org.tensorflow:tensorflow-lite-local:0.0.0' +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/proguard-rules.pro b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/proguard-rules.pro new file mode 100644 index 0000000000000000000000000000000000000000..f1b424510da51fd82143bc74a0a801ae5a1e2fcd --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/AndroidManifest.xml b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..ebe3c56c60a9b67eec218d969aecfdf5311d7b49 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/AndroidManifest.xml @@ -0,0 +1,3 @@ + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java new file mode 100644 index 0000000000000000000000000000000000000000..24ec573e7d184e7d64118a723d6645fd92d6e6d9 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java @@ -0,0 +1,376 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import static java.lang.Math.min; + +import android.app.Activity; +import android.graphics.Bitmap; +import android.graphics.RectF; +import android.os.SystemClock; +import android.os.Trace; +import android.util.Log; +import android.view.TextureView; +import android.view.ViewStub; + +import java.io.IOException; +import java.nio.MappedByteBuffer; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.Interpreter; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.gpu.GpuDelegate; +import org.tensorflow.lite.nnapi.NnApiDelegate; +import org.tensorflow.lite.support.common.FileUtil; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.TensorProcessor; +import org.tensorflow.lite.support.image.ImageProcessor; +import org.tensorflow.lite.support.image.TensorImage; +import org.tensorflow.lite.support.image.ops.ResizeOp; +import org.tensorflow.lite.support.image.ops.ResizeOp.ResizeMethod; +import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp; +import org.tensorflow.lite.support.image.ops.Rot90Op; +import org.tensorflow.lite.support.label.TensorLabel; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** A classifier specialized to label images using TensorFlow Lite. */ +public abstract class Classifier { + public static final String TAG = "ClassifierWithSupport"; + + /** The model type used for classification. */ + public enum Model { + FLOAT_MOBILENET, + QUANTIZED_MOBILENET, + QUANTIZED_EFFICIENTNET, + FLOAT_EFFICIENTNET + } + + /** The runtime device type used for executing classification. */ + public enum Device { + CPU, + NNAPI, + GPU + } + + /** Number of results to show in the UI. */ + private static final int MAX_RESULTS = 3; + + /** The loaded TensorFlow Lite model. */ + + /** Image size along the x axis. */ + private final int imageSizeX; + + /** Image size along the y axis. */ + private final int imageSizeY; + + /** Optional GPU delegate for accleration. */ + private GpuDelegate gpuDelegate = null; + + /** Optional NNAPI delegate for accleration. */ + private NnApiDelegate nnApiDelegate = null; + + /** An instance of the driver class to run model inference with Tensorflow Lite. */ + protected Interpreter tflite; + + /** Options for configuring the Interpreter. */ + private final Interpreter.Options tfliteOptions = new Interpreter.Options(); + + /** Labels corresponding to the output of the vision model. */ + private final List labels; + + /** Input image TensorBuffer. */ + private TensorImage inputImageBuffer; + + /** Output probability TensorBuffer. */ + private final TensorBuffer outputProbabilityBuffer; + + /** Processer to apply post processing of the output probability. */ + private final TensorProcessor probabilityProcessor; + + /** + * Creates a classifier with the provided configuration. + * + * @param activity The current Activity. + * @param model The model to use for classification. + * @param device The device to use for classification. + * @param numThreads The number of threads to use for classification. + * @return A classifier with the desired configuration. + */ + public static Classifier create(Activity activity, Model model, Device device, int numThreads) + throws IOException { + if (model == Model.QUANTIZED_MOBILENET) { + return new ClassifierQuantizedMobileNet(activity, device, numThreads); + } else if (model == Model.FLOAT_MOBILENET) { + return new ClassifierFloatMobileNet(activity, device, numThreads); + } else if (model == Model.FLOAT_EFFICIENTNET) { + return new ClassifierFloatEfficientNet(activity, device, numThreads); + } else if (model == Model.QUANTIZED_EFFICIENTNET) { + return new ClassifierQuantizedEfficientNet(activity, device, numThreads); + } else { + throw new UnsupportedOperationException(); + } + } + + /** An immutable result returned by a Classifier describing what was recognized. */ + public static class Recognition { + /** + * A unique identifier for what has been recognized. Specific to the class, not the instance of + * the object. + */ + private final String id; + + /** Display name for the recognition. */ + private final String title; + + /** + * A sortable score for how good the recognition is relative to others. Higher should be better. + */ + private final Float confidence; + + /** Optional location within the source image for the location of the recognized object. */ + private RectF location; + + public Recognition( + final String id, final String title, final Float confidence, final RectF location) { + this.id = id; + this.title = title; + this.confidence = confidence; + this.location = location; + } + + public String getId() { + return id; + } + + public String getTitle() { + return title; + } + + public Float getConfidence() { + return confidence; + } + + public RectF getLocation() { + return new RectF(location); + } + + public void setLocation(RectF location) { + this.location = location; + } + + @Override + public String toString() { + String resultString = ""; + if (id != null) { + resultString += "[" + id + "] "; + } + + if (title != null) { + resultString += title + " "; + } + + if (confidence != null) { + resultString += String.format("(%.1f%%) ", confidence * 100.0f); + } + + if (location != null) { + resultString += location + " "; + } + + return resultString.trim(); + } + } + + /** Initializes a {@code Classifier}. */ + protected Classifier(Activity activity, Device device, int numThreads) throws IOException { + MappedByteBuffer tfliteModel = FileUtil.loadMappedFile(activity, getModelPath()); + switch (device) { + case NNAPI: + nnApiDelegate = new NnApiDelegate(); + tfliteOptions.addDelegate(nnApiDelegate); + break; + case GPU: + gpuDelegate = new GpuDelegate(); + tfliteOptions.addDelegate(gpuDelegate); + break; + case CPU: + break; + } + tfliteOptions.setNumThreads(numThreads); + tflite = new Interpreter(tfliteModel, tfliteOptions); + + // Loads labels out from the label file. + labels = FileUtil.loadLabels(activity, getLabelPath()); + + // Reads type and shape of input and output tensors, respectively. + int imageTensorIndex = 0; + int[] imageShape = tflite.getInputTensor(imageTensorIndex).shape(); // {1, height, width, 3} + if(imageShape[1] != imageShape[2]) { + imageSizeY = imageShape[2]; + imageSizeX = imageShape[3]; + } else { + imageSizeY = imageShape[1]; + imageSizeX = imageShape[2]; + } + DataType imageDataType = tflite.getInputTensor(imageTensorIndex).dataType(); + int probabilityTensorIndex = 0; + int[] probabilityShape = + tflite.getOutputTensor(probabilityTensorIndex).shape(); // {1, NUM_CLASSES} + DataType probabilityDataType = tflite.getOutputTensor(probabilityTensorIndex).dataType(); + + // Creates the input tensor. + inputImageBuffer = new TensorImage(imageDataType); + + // Creates the output tensor and its processor. + outputProbabilityBuffer = TensorBuffer.createFixedSize(probabilityShape, probabilityDataType); + + // Creates the post processor for the output probability. + probabilityProcessor = new TensorProcessor.Builder().add(getPostprocessNormalizeOp()).build(); + + Log.d(TAG, "Created a Tensorflow Lite Image Classifier."); + } + + /** Runs inference and returns the classification results. */ + //public List recognizeImage(final Bitmap bitmap, int sensorOrientation) { + public float[] recognizeImage(final Bitmap bitmap, int sensorOrientation) { + // Logs this method so that it can be analyzed with systrace. + Trace.beginSection("recognizeImage"); + + Trace.beginSection("loadImage"); + long startTimeForLoadImage = SystemClock.uptimeMillis(); + inputImageBuffer = loadImage(bitmap, sensorOrientation); + long endTimeForLoadImage = SystemClock.uptimeMillis(); + Trace.endSection(); + Log.v(TAG, "Timecost to load the image: " + (endTimeForLoadImage - startTimeForLoadImage)); + + // Runs the inference call. + Trace.beginSection("runInference"); + long startTimeForReference = SystemClock.uptimeMillis(); + tflite.run(inputImageBuffer.getBuffer(), outputProbabilityBuffer.getBuffer().rewind()); + long endTimeForReference = SystemClock.uptimeMillis(); + Trace.endSection(); + Log.v(TAG, "Timecost to run model inference: " + (endTimeForReference - startTimeForReference)); + + float[] img_array = outputProbabilityBuffer.getFloatArray(); + + // Gets the map of label and probability. + //Map labeledProbability = + // new TensorLabel(labels, probabilityProcessor.process(outputProbabilityBuffer)) + // .getMapWithFloatValue(); + Trace.endSection(); + + // Gets top-k results. + return img_array;//getTopKProbability(labeledProbability); + } + + /** Closes the interpreter and model to release resources. */ + public void close() { + if (tflite != null) { + tflite.close(); + tflite = null; + } + if (gpuDelegate != null) { + gpuDelegate.close(); + gpuDelegate = null; + } + if (nnApiDelegate != null) { + nnApiDelegate.close(); + nnApiDelegate = null; + } + } + + /** Get the image size along the x axis. */ + public int getImageSizeX() { + return imageSizeX; + } + + /** Get the image size along the y axis. */ + public int getImageSizeY() { + return imageSizeY; + } + + /** Loads input image, and applies preprocessing. */ + private TensorImage loadImage(final Bitmap bitmap, int sensorOrientation) { + // Loads bitmap into a TensorImage. + inputImageBuffer.load(bitmap); + + // Creates processor for the TensorImage. + int cropSize = min(bitmap.getWidth(), bitmap.getHeight()); + int numRotation = sensorOrientation / 90; + // TODO(b/143564309): Fuse ops inside ImageProcessor. + ImageProcessor imageProcessor = + new ImageProcessor.Builder() + .add(new ResizeWithCropOrPadOp(cropSize, cropSize)) + // TODO(b/169379396): investigate the impact of the resize algorithm on accuracy. + // To get the same inference results as lib_task_api, which is built on top of the Task + // Library, use ResizeMethod.BILINEAR. + .add(new ResizeOp(imageSizeX, imageSizeY, ResizeMethod.NEAREST_NEIGHBOR)) + //.add(new ResizeOp(224, 224, ResizeMethod.NEAREST_NEIGHBOR)) + .add(new Rot90Op(numRotation)) + .add(getPreprocessNormalizeOp()) + .build(); + return imageProcessor.process(inputImageBuffer); + } + + /** Gets the top-k results. */ + private static List getTopKProbability(Map labelProb) { + // Find the best classifications. + PriorityQueue pq = + new PriorityQueue<>( + MAX_RESULTS, + new Comparator() { + @Override + public int compare(Recognition lhs, Recognition rhs) { + // Intentionally reversed to put high confidence at the head of the queue. + return Float.compare(rhs.getConfidence(), lhs.getConfidence()); + } + }); + + for (Map.Entry entry : labelProb.entrySet()) { + pq.add(new Recognition("" + entry.getKey(), entry.getKey(), entry.getValue(), null)); + } + + final ArrayList recognitions = new ArrayList<>(); + int recognitionsSize = min(pq.size(), MAX_RESULTS); + for (int i = 0; i < recognitionsSize; ++i) { + recognitions.add(pq.poll()); + } + return recognitions; + } + + /** Gets the name of the model file stored in Assets. */ + protected abstract String getModelPath(); + + /** Gets the name of the label file stored in Assets. */ + protected abstract String getLabelPath(); + + /** Gets the TensorOperator to nomalize the input image in preprocessing. */ + protected abstract TensorOperator getPreprocessNormalizeOp(); + + /** + * Gets the TensorOperator to dequantize the output probability in post processing. + * + *

For quantized model, we need de-quantize the prediction with NormalizeOp (as they are all + * essentially linear transformation). For float model, de-quantize is not required. But to + * uniform the API, de-quantize is added to float model too. Mean and std are set to 0.0f and + * 1.0f, respectively. + */ + protected abstract TensorOperator getPostprocessNormalizeOp(); +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java new file mode 100644 index 0000000000000000000000000000000000000000..14dd027b26baefaedd979a8ac37f0bf984210ed4 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java @@ -0,0 +1,71 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.ops.NormalizeOp; + +/** This TensorFlowLite classifier works with the float EfficientNet model. */ +public class ClassifierFloatEfficientNet extends Classifier { + + private static final float IMAGE_MEAN = 115.0f; //127.0f; + private static final float IMAGE_STD = 58.0f; //128.0f; + + /** + * Float model does not need dequantization in the post-processing. Setting mean and std as 0.0f + * and 1.0f, repectively, to bypass the normalization. + */ + private static final float PROBABILITY_MEAN = 0.0f; + + private static final float PROBABILITY_STD = 1.0f; + + /** + * Initializes a {@code ClassifierFloatMobileNet}. + * + * @param activity + */ + public ClassifierFloatEfficientNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. + //return "efficientnet-lite0-fp32.tflite"; + return "model_opt.tflite"; + } + + @Override + protected String getLabelPath() { + return "labels_without_background.txt"; + } + + @Override + protected TensorOperator getPreprocessNormalizeOp() { + return new NormalizeOp(IMAGE_MEAN, IMAGE_STD); + } + + @Override + protected TensorOperator getPostprocessNormalizeOp() { + return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD); + } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java new file mode 100644 index 0000000000000000000000000000000000000000..40519de07cf5e887773250a4609a832b6060d684 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java @@ -0,0 +1,72 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.ops.NormalizeOp; + +/** This TensorFlowLite classifier works with the float MobileNet model. */ +public class ClassifierFloatMobileNet extends Classifier { + + /** Float MobileNet requires additional normalization of the used input. */ + private static final float IMAGE_MEAN = 127.5f; + + private static final float IMAGE_STD = 127.5f; + + /** + * Float model does not need dequantization in the post-processing. Setting mean and std as 0.0f + * and 1.0f, repectively, to bypass the normalization. + */ + private static final float PROBABILITY_MEAN = 0.0f; + + private static final float PROBABILITY_STD = 1.0f; + + /** + * Initializes a {@code ClassifierFloatMobileNet}. + * + * @param activity + */ + public ClassifierFloatMobileNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. + return "model_0.tflite"; + } + + @Override + protected String getLabelPath() { + return "labels.txt"; + } + + @Override + protected TensorOperator getPreprocessNormalizeOp() { + return new NormalizeOp(IMAGE_MEAN, IMAGE_STD); + } + + @Override + protected TensorOperator getPostprocessNormalizeOp() { + return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD); + } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java new file mode 100644 index 0000000000000000000000000000000000000000..d0d62f58d18333b6360ec30a4c85c9f1d38955ce --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java @@ -0,0 +1,71 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.ops.NormalizeOp; + +/** This TensorFlow Lite classifier works with the quantized EfficientNet model. */ +public class ClassifierQuantizedEfficientNet extends Classifier { + + /** + * The quantized model does not require normalization, thus set mean as 0.0f, and std as 1.0f to + * bypass the normalization. + */ + private static final float IMAGE_MEAN = 0.0f; + + private static final float IMAGE_STD = 1.0f; + + /** Quantized MobileNet requires additional dequantization to the output probability. */ + private static final float PROBABILITY_MEAN = 0.0f; + + private static final float PROBABILITY_STD = 255.0f; + + /** + * Initializes a {@code ClassifierQuantizedMobileNet}. + * + * @param activity + */ + public ClassifierQuantizedEfficientNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. + return "model_quant.tflite"; + } + + @Override + protected String getLabelPath() { + return "labels_without_background.txt"; + } + + @Override + protected TensorOperator getPreprocessNormalizeOp() { + return new NormalizeOp(IMAGE_MEAN, IMAGE_STD); + } + + @Override + protected TensorOperator getPostprocessNormalizeOp() { + return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD); + } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java new file mode 100644 index 0000000000000000000000000000000000000000..94b06e3df659005c287733a8a37672863fdadd71 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.common.ops.NormalizeOp; + +/** This TensorFlow Lite classifier works with the quantized MobileNet model. */ +public class ClassifierQuantizedMobileNet extends Classifier { + + /** + * The quantized model does not require normalization, thus set mean as 0.0f, and std as 1.0f to + * bypass the normalization. + */ + private static final float IMAGE_MEAN = 0.0f; + + private static final float IMAGE_STD = 1.0f; + + /** Quantized MobileNet requires additional dequantization to the output probability. */ + private static final float PROBABILITY_MEAN = 0.0f; + + private static final float PROBABILITY_STD = 255.0f; + + /** + * Initializes a {@code ClassifierQuantizedMobileNet}. + * + * @param activity + */ + public ClassifierQuantizedMobileNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. + return "model_quant_0.tflite"; + } + + @Override + protected String getLabelPath() { + return "labels.txt"; + } + + @Override + protected TensorOperator getPreprocessNormalizeOp() { + return new NormalizeOp(IMAGE_MEAN, IMAGE_STD); + } + + @Override + protected TensorOperator getPostprocessNormalizeOp() { + return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD); + } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/build.gradle b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..b5983986e3d56a77a41676b9195b0d0882b5fb96 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/build.gradle @@ -0,0 +1,47 @@ +apply plugin: 'com.android.library' + +android { + compileSdkVersion 28 + buildToolsVersion "28.0.0" + + defaultConfig { + minSdkVersion 21 + targetSdkVersion 28 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + compileOptions { + sourceCompatibility = '1.8' + targetCompatibility = '1.8' + } + aaptOptions { + noCompress "tflite" + } + + lintOptions { + checkReleaseBuilds false + // Or, if you prefer, you can continue to check for errors in release builds, + // but continue the build even when errors are found: + abortOnError false + } +} + +dependencies { + implementation fileTree(dir: 'libs', include: ['*.jar']) + implementation project(":models") + implementation 'androidx.appcompat:appcompat:1.1.0' + + // Build off of nightly TensorFlow Lite Task Library + implementation('org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly') { changing = true } + implementation('org.tensorflow:tensorflow-lite-metadata:0.0.0-nightly') { changing = true } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/proguard-rules.pro b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/proguard-rules.pro new file mode 100644 index 0000000000000000000000000000000000000000..f1b424510da51fd82143bc74a0a801ae5a1e2fcd --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/AndroidManifest.xml b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..ebe3c56c60a9b67eec218d969aecfdf5311d7b49 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/AndroidManifest.xml @@ -0,0 +1,3 @@ + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java new file mode 100644 index 0000000000000000000000000000000000000000..45da52a0d0dfa203255e0f2d44901ee0618e739f --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java @@ -0,0 +1,278 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import static java.lang.Math.min; + +import android.app.Activity; +import android.graphics.Bitmap; +import android.graphics.Rect; +import android.graphics.RectF; +import android.os.SystemClock; +import android.os.Trace; +import android.util.Log; +import java.io.IOException; +import java.nio.MappedByteBuffer; +import java.util.ArrayList; +import java.util.List; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; +import org.tensorflow.lite.support.common.FileUtil; +import org.tensorflow.lite.support.image.TensorImage; +import org.tensorflow.lite.support.label.Category; +import org.tensorflow.lite.support.metadata.MetadataExtractor; +import org.tensorflow.lite.task.core.vision.ImageProcessingOptions; +import org.tensorflow.lite.task.core.vision.ImageProcessingOptions.Orientation; +import org.tensorflow.lite.task.vision.classifier.Classifications; +import org.tensorflow.lite.task.vision.classifier.ImageClassifier; +import org.tensorflow.lite.task.vision.classifier.ImageClassifier.ImageClassifierOptions; + +/** A classifier specialized to label images using TensorFlow Lite. */ +public abstract class Classifier { + public static final String TAG = "ClassifierWithTaskApi"; + + /** The model type used for classification. */ + public enum Model { + FLOAT_MOBILENET, + QUANTIZED_MOBILENET, + FLOAT_EFFICIENTNET, + QUANTIZED_EFFICIENTNET + } + + /** The runtime device type used for executing classification. */ + public enum Device { + CPU, + NNAPI, + GPU + } + + /** Number of results to show in the UI. */ + private static final int MAX_RESULTS = 3; + + /** Image size along the x axis. */ + private final int imageSizeX; + + /** Image size along the y axis. */ + private final int imageSizeY; + /** An instance of the driver class to run model inference with Tensorflow Lite. */ + protected final ImageClassifier imageClassifier; + + /** + * Creates a classifier with the provided configuration. + * + * @param activity The current Activity. + * @param model The model to use for classification. + * @param device The device to use for classification. + * @param numThreads The number of threads to use for classification. + * @return A classifier with the desired configuration. + */ + public static Classifier create(Activity activity, Model model, Device device, int numThreads) + throws IOException { + if (model == Model.QUANTIZED_MOBILENET) { + return new ClassifierQuantizedMobileNet(activity, device, numThreads); + } else if (model == Model.FLOAT_MOBILENET) { + return new ClassifierFloatMobileNet(activity, device, numThreads); + } else if (model == Model.FLOAT_EFFICIENTNET) { + return new ClassifierFloatEfficientNet(activity, device, numThreads); + } else if (model == Model.QUANTIZED_EFFICIENTNET) { + return new ClassifierQuantizedEfficientNet(activity, device, numThreads); + } else { + throw new UnsupportedOperationException(); + } + } + + /** An immutable result returned by a Classifier describing what was recognized. */ + public static class Recognition { + /** + * A unique identifier for what has been recognized. Specific to the class, not the instance of + * the object. + */ + private final String id; + + /** Display name for the recognition. */ + private final String title; + + /** + * A sortable score for how good the recognition is relative to others. Higher should be better. + */ + private final Float confidence; + + /** Optional location within the source image for the location of the recognized object. */ + private RectF location; + + public Recognition( + final String id, final String title, final Float confidence, final RectF location) { + this.id = id; + this.title = title; + this.confidence = confidence; + this.location = location; + } + + public String getId() { + return id; + } + + public String getTitle() { + return title; + } + + public Float getConfidence() { + return confidence; + } + + public RectF getLocation() { + return new RectF(location); + } + + public void setLocation(RectF location) { + this.location = location; + } + + @Override + public String toString() { + String resultString = ""; + if (id != null) { + resultString += "[" + id + "] "; + } + + if (title != null) { + resultString += title + " "; + } + + if (confidence != null) { + resultString += String.format("(%.1f%%) ", confidence * 100.0f); + } + + if (location != null) { + resultString += location + " "; + } + + return resultString.trim(); + } + } + + /** Initializes a {@code Classifier}. */ + protected Classifier(Activity activity, Device device, int numThreads) throws IOException { + if (device != Device.CPU || numThreads != 1) { + throw new IllegalArgumentException( + "Manipulating the hardware accelerators and numbers of threads is not allowed in the Task" + + " library currently. Only CPU + single thread is allowed."); + } + + // Create the ImageClassifier instance. + ImageClassifierOptions options = + ImageClassifierOptions.builder().setMaxResults(MAX_RESULTS).build(); + imageClassifier = ImageClassifier.createFromFileAndOptions(activity, getModelPath(), options); + Log.d(TAG, "Created a Tensorflow Lite Image Classifier."); + + // Get the input image size information of the underlying tflite model. + MappedByteBuffer tfliteModel = FileUtil.loadMappedFile(activity, getModelPath()); + MetadataExtractor metadataExtractor = new MetadataExtractor(tfliteModel); + // Image shape is in the format of {1, height, width, 3}. + int[] imageShape = metadataExtractor.getInputTensorShape(/*inputIndex=*/ 0); + imageSizeY = imageShape[1]; + imageSizeX = imageShape[2]; + } + + /** Runs inference and returns the classification results. */ + public List recognizeImage(final Bitmap bitmap, int sensorOrientation) { + // Logs this method so that it can be analyzed with systrace. + Trace.beginSection("recognizeImage"); + + TensorImage inputImage = TensorImage.fromBitmap(bitmap); + int width = bitmap.getWidth(); + int height = bitmap.getHeight(); + int cropSize = min(width, height); + // TODO(b/169379396): investigate the impact of the resize algorithm on accuracy. + // Task Library resize the images using bilinear interpolation, which is slightly different from + // the nearest neighbor sampling algorithm used in lib_support. See + // https://github.com/tensorflow/examples/blob/0ef3d93e2af95d325c70ef3bcbbd6844d0631e07/lite/examples/image_classification/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/Classifier.java#L310. + ImageProcessingOptions imageOptions = + ImageProcessingOptions.builder() + .setOrientation(getOrientation(sensorOrientation)) + // Set the ROI to the center of the image. + .setRoi( + new Rect( + /*left=*/ (width - cropSize) / 2, + /*top=*/ (height - cropSize) / 2, + /*right=*/ (width + cropSize) / 2, + /*bottom=*/ (height + cropSize) / 2)) + .build(); + + // Runs the inference call. + Trace.beginSection("runInference"); + long startTimeForReference = SystemClock.uptimeMillis(); + List results = imageClassifier.classify(inputImage, imageOptions); + long endTimeForReference = SystemClock.uptimeMillis(); + Trace.endSection(); + Log.v(TAG, "Timecost to run model inference: " + (endTimeForReference - startTimeForReference)); + + Trace.endSection(); + + return getRecognitions(results); + } + + /** Closes the interpreter and model to release resources. */ + public void close() { + if (imageClassifier != null) { + imageClassifier.close(); + } + } + + /** Get the image size along the x axis. */ + public int getImageSizeX() { + return imageSizeX; + } + + /** Get the image size along the y axis. */ + public int getImageSizeY() { + return imageSizeY; + } + + /** + * Converts a list of {@link Classifications} objects into a list of {@link Recognition} objects + * to match the interface of other inference method, such as using the TFLite + * Support Library.. + */ + private static List getRecognitions(List classifications) { + + final ArrayList recognitions = new ArrayList<>(); + // All the demo models are single head models. Get the first Classifications in the results. + for (Category category : classifications.get(0).getCategories()) { + recognitions.add( + new Recognition( + "" + category.getLabel(), category.getLabel(), category.getScore(), null)); + } + return recognitions; + } + + /* Convert the camera orientation in degree into {@link ImageProcessingOptions#Orientation}.*/ + private static Orientation getOrientation(int cameraOrientation) { + switch (cameraOrientation / 90) { + case 3: + return Orientation.BOTTOM_LEFT; + case 2: + return Orientation.BOTTOM_RIGHT; + case 1: + return Orientation.TOP_RIGHT; + default: + return Orientation.TOP_LEFT; + } + } + + /** Gets the name of the model file stored in Assets. */ + protected abstract String getModelPath(); +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java new file mode 100644 index 0000000000000000000000000000000000000000..250794cc12d0e603aa47502322dc646d50689848 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java @@ -0,0 +1,45 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; + +/** This TensorFlowLite classifier works with the float EfficientNet model. */ +public class ClassifierFloatEfficientNet extends Classifier { + + /** + * Initializes a {@code ClassifierFloatMobileNet}. + * + * @param device a {@link Device} object to configure the hardware accelerator + * @param numThreads the number of threads during the inference + * @throws IOException if the model is not loaded correctly + */ + public ClassifierFloatEfficientNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. + //return "efficientnet-lite0-fp32.tflite"; + return "model.tflite"; + } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java new file mode 100644 index 0000000000000000000000000000000000000000..0707de98de41395eaf3ddcfd74d6e36229a63760 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java @@ -0,0 +1,43 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; + +/** This TensorFlowLite classifier works with the float MobileNet model. */ +public class ClassifierFloatMobileNet extends Classifier { + /** + * Initializes a {@code ClassifierFloatMobileNet}. + * + * @param device a {@link Device} object to configure the hardware accelerator + * @param numThreads the number of threads during the inference + * @throws IOException if the model is not loaded correctly + */ + public ClassifierFloatMobileNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. + return "mobilenet_v1_1.0_224.tflite"; + } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java new file mode 100644 index 0000000000000000000000000000000000000000..05ca4fa6c409d0274a396c9b26c3c39ca8a8194e --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; + +/** This TensorFlow Lite classifier works with the quantized EfficientNet model. */ +public class ClassifierQuantizedEfficientNet extends Classifier { + + /** + * Initializes a {@code ClassifierQuantizedMobileNet}. + * + * @param device a {@link Device} object to configure the hardware accelerator + * @param numThreads the number of threads during the inference + * @throws IOException if the model is not loaded correctly + */ + public ClassifierQuantizedEfficientNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. + return "efficientnet-lite0-int8.tflite"; + } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java new file mode 100644 index 0000000000000000000000000000000000000000..978b08eeaf52a23eede437d61045db08d1dff163 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.examples.classification.tflite; + +import android.app.Activity; +import java.io.IOException; +import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; + +/** This TensorFlow Lite classifier works with the quantized MobileNet model. */ +public class ClassifierQuantizedMobileNet extends Classifier { + + /** + * Initializes a {@code ClassifierQuantizedMobileNet}. + * + * @param device a {@link Device} object to configure the hardware accelerator + * @param numThreads the number of threads during the inference + * @throws IOException if the model is not loaded correctly + */ + public ClassifierQuantizedMobileNet(Activity activity, Device device, int numThreads) + throws IOException { + super(activity, device, numThreads); + } + + @Override + protected String getModelPath() { + // you can download this file from + // see build.gradle for where to obtain this file. It should be auto + // downloaded into assets. + return "mobilenet_v1_1.0_224_quant.tflite"; + } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/build.gradle b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..8d825707af20cbbead6c4599f075599148e3511c --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/build.gradle @@ -0,0 +1,40 @@ +apply plugin: 'com.android.library' +apply plugin: 'de.undercouch.download' + +android { + compileSdkVersion 28 + buildToolsVersion "28.0.0" + + defaultConfig { + minSdkVersion 21 + targetSdkVersion 28 + versionCode 1 + versionName "1.0" + + testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + + aaptOptions { + noCompress "tflite" + } + + lintOptions { + checkReleaseBuilds false + // Or, if you prefer, you can continue to check for errors in release builds, + // but continue the build even when errors are found: + abortOnError false + } +} + +// Download default models; if you wish to use your own models then +// place them in the "assets" directory and comment out this line. +project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets' +apply from:'download.gradle' diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/download.gradle b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/download.gradle new file mode 100644 index 0000000000000000000000000000000000000000..ce76974a2c3bc6f8214461028e0dfa6ebc25d588 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/download.gradle @@ -0,0 +1,10 @@ +def modelFloatDownloadUrl = "https://github.com/isl-org/MiDaS/releases/download/v2_1/model_opt.tflite" +def modelFloatFile = "model_opt.tflite" + +task downloadModelFloat(type: Download) { + src "${modelFloatDownloadUrl}" + dest project.ext.ASSET_DIR + "/${modelFloatFile}" + overwrite false +} + +preBuild.dependsOn downloadModelFloat diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/proguard-rules.pro b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/proguard-rules.pro new file mode 100644 index 0000000000000000000000000000000000000000..f1b424510da51fd82143bc74a0a801ae5a1e2fcd --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/AndroidManifest.xml b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..42951a56497c5f947efe4aea6a07462019fb152c --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/AndroidManifest.xml @@ -0,0 +1,3 @@ + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/labels.txt b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/labels.txt new file mode 100644 index 0000000000000000000000000000000000000000..fe811239d8e2989de19fecabb1ebb0c9dddac514 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/labels.txt @@ -0,0 +1,1001 @@ +background +tench +goldfish +great white shark +tiger shark +hammerhead +electric ray +stingray +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +African crocodile +American alligator +triceratops +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +tusker +echidna +platypus +wallaby +koala +wombat +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +chambered nautilus +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +isopod +white stork +black stork +spoonbill +flamingo +little blue heron +American egret +bittern +crane +limpkin +European gallinule +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +pelican +king penguin +albatross +grey whale +killer whale +dugong +sea lion +Chihuahua +Japanese spaniel +Maltese dog +Pekinese +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound +basset +beagle +bloodhound +bluetick +black-and-tan coonhound +Walker hound +English foxhound +redbone +borzoi +Irish wolfhound +Italian greyhound +whippet +Ibizan hound +Norwegian elkhound +otterhound +Saluki +Scottish deerhound +Weimaraner +Staffordshire bullterrier +American Staffordshire terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier +Airedale +cairn +Australian terrier +Dandie Dinmont +Boston bull +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier +Tibetan terrier +silky terrier +soft-coated wheaten terrier +West Highland white terrier +Lhasa +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla +English setter +Irish setter +Gordon setter +Brittany spaniel +clumber +English springer +Welsh springer spaniel +cocker spaniel +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog +Shetland sheepdog +collie +Border collie +Bouvier des Flandres +Rottweiler +German shepherd +Doberman +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard +Eskimo dog +malamute +Siberian husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great Pyrenees +Samoyed +Pomeranian +chow +keeshond +Brabancon griffon +Pembroke +Cardigan +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf +white wolf +red wolf +coyote +dingo +dhole +African hunting dog +hyena +red fox +kit fox +Arctic fox +grey fox +tabby +tiger cat +Persian cat +Siamese cat +Egyptian cat +cougar +lynx +leopard +snow leopard +jaguar +lion +tiger +cheetah +brown bear +American black bear +ice bear +sloth bear +mongoose +meerkat +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +ant +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +starfish +sea urchin +sea cucumber +wood rabbit +hare +Angora +hamster +porcupine +fox squirrel +marmot +beaver +guinea pig +sorrel +zebra +hog +wild boar +warthog +hippopotamus +ox +water buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian camel +llama +weasel +mink +polecat +black-footed ferret +otter +skunk +badger +armadillo +three-toed sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis monkey +marmoset +capuchin +howler monkey +titi +spider monkey +squirrel monkey +Madagascar cat +indri +Indian elephant +African elephant +lesser panda +giant panda +barracouta +eel +coho +rock beauty +anemone fish +sturgeon +gar +lionfish +puffer +abacus +abaya +academic gown +accordion +acoustic guitar +aircraft carrier +airliner +airship +altar +ambulance +amphibian +analog clock +apiary +apron +ashcan +assault rifle +backpack +bakery +balance beam +balloon +ballpoint +Band Aid +banjo +bannister +barbell +barber chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing cap +bath towel +bathtub +beach wagon +beacon +beaker +bearskin +beer bottle +beer glass +bell cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo tie +bonnet +bookcase +bookshop +bottlecap +bow +bow tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof vest +bullet train +butcher shop +cab +caldron +candle +cannon +canoe +can opener +cardigan +car mirror +carousel +carpenter's kit +carton +car wheel +cash machine +cassette +cassette player +castle +catamaran +CD player +cello +cellular telephone +chain +chainlink fence +chain mail +chain saw +chest +chiffonier +chime +china cabinet +Christmas stocking +church +cinema +cleaver +cliff dwelling +cloak +clog +cocktail shaker +coffee mug +coffeepot +coil +combination lock +computer keyboard +confectionery +container ship +convertible +corkscrew +cornet +cowboy boot +cowboy hat +cradle +crane +crash helmet +crate +crib +Crock Pot +croquet ball +crutch +cuirass +dam +desk +desktop computer +dial telephone +diaper +digital clock +digital watch +dining table +dishrag +dishwasher +disk brake +dock +dogsled +dome +doormat +drilling platform +drum +drumstick +dumbbell +Dutch oven +electric fan +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa +file +fireboat +fire engine +fire screen +flagpole +flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn +frying pan +fur coat +garbage truck +gasmask +gas pump +goblet +go-kart +golf ball +golfcart +gondola +gong +gown +grand piano +greenhouse +grille +grocery store +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower +hand-held computer +handkerchief +hard disc +harmonica +harp +harvester +hatchet +holster +home theater +honeycomb +hook +hoopskirt +horizontal bar +horse cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw puzzle +jinrikisha +joystick +kimono +knee pad +knot +lab coat +ladle +lampshade +laptop +lawn mower +lens cap +letter opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic compass +mailbag +mailbox +maillot +maillot +manhole cover +maraca +marimba +mask +matchstick +maypole +maze +measuring cup +medicine chest +megalith +microphone +microwave +military uniform +milk can +minibus +miniskirt +minivan +missile +mitten +mixing bowl +mobile home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter +mountain bike +mountain tent +mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil filter +organ +oscilloscope +overskirt +oxcart +oxygen mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper towel +parachute +parallel bars +park bench +parking meter +passenger car +patio +pay-phone +pedestal +pencil box +pencil sharpener +perfume +Petri dish +photocopier +pick +pickelhaube +picket fence +pickup +pier +piggy bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate +pitcher +plane +planetarium +plastic bag +plate rack +plow +plunger +Polaroid camera +pole +police van +poncho +pool table +pop bottle +pot +potter's wheel +power drill +prayer rug +printer +prison +projectile +projector +puck +punching bag +purse +quill +quilt +racer +racket +radiator +radio +radio telescope +rain barrel +recreational vehicle +reel +reflex camera +refrigerator +remote control +restaurant +revolver +rifle +rocking chair +rotisserie +rubber eraser +rugby ball +rule +running shoe +safe +safety pin +saltshaker +sandal +sarong +sax +scabbard +scale +school bus +schooner +scoreboard +screen +screw +screwdriver +seat belt +sewing machine +shield +shoe shop +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule +sliding door +slot +snorkel +snowmobile +snowplow +soap dispenser +soccer ball +sock +solar dish +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web +spindle +sports car +spotlight +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch +stove +strainer +streetcar +stretcher +studio couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension bridge +swab +sweatshirt +swimming trunks +swing +switch +syringe +table lamp +tank +tape player +teapot +teddy +television +tennis ball +thatch +theater curtain +thimble +thresher +throne +tile roof +toaster +tobacco shop +toilet seat +torch +totem pole +tow truck +toyshop +tractor +trailer truck +tray +trench coat +tricycle +trimaran +tripod +triumphal arch +trolleybus +trombone +tub +turnstile +typewriter keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending machine +vestment +viaduct +violin +volleyball +waffle iron +wall clock +wallet +wardrobe +warplane +washbasin +washer +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool +worm fence +wreck +yawl +yurt +web site +comic book +crossword puzzle +street sign +traffic light +book jacket +menu +plate +guacamole +consomme +hot pot +trifle +ice cream +ice lolly +French loaf +bagel +pretzel +cheeseburger +hotdog +mashed potato +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +hay +carbonara +chocolate sauce +dough +meat loaf +pizza +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff +coral reef +geyser +lakeside +promontory +sandbar +seashore +valley +volcano +ballplayer +groom +scuba diver +rapeseed +daisy +yellow lady's slipper +corn +acorn +hip +buckeye +coral fungus +agaric +gyromitra +stinkhorn +earthstar +hen-of-the-woods +bolete +ear +toilet tissue diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/labels_without_background.txt b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/labels_without_background.txt new file mode 100644 index 0000000000000000000000000000000000000000..f40829ed0fc318c673860fae4be6c48529da116e --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/labels_without_background.txt @@ -0,0 +1,1000 @@ +tench +goldfish +great white shark +tiger shark +hammerhead +electric ray +stingray +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +African crocodile +American alligator +triceratops +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +tusker +echidna +platypus +wallaby +koala +wombat +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +chambered nautilus +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +isopod +white stork +black stork +spoonbill +flamingo +little blue heron +American egret +bittern +crane +limpkin +European gallinule +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +pelican +king penguin +albatross +grey whale +killer whale +dugong +sea lion +Chihuahua +Japanese spaniel +Maltese dog +Pekinese +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound +basset +beagle +bloodhound +bluetick +black-and-tan coonhound +Walker hound +English foxhound +redbone +borzoi +Irish wolfhound +Italian greyhound +whippet +Ibizan hound +Norwegian elkhound +otterhound +Saluki +Scottish deerhound +Weimaraner +Staffordshire bullterrier +American Staffordshire terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier +Airedale +cairn +Australian terrier +Dandie Dinmont +Boston bull +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier +Tibetan terrier +silky terrier +soft-coated wheaten terrier +West Highland white terrier +Lhasa +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla +English setter +Irish setter +Gordon setter +Brittany spaniel +clumber +English springer +Welsh springer spaniel +cocker spaniel +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog +Shetland sheepdog +collie +Border collie +Bouvier des Flandres +Rottweiler +German shepherd +Doberman +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard +Eskimo dog +malamute +Siberian husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great Pyrenees +Samoyed +Pomeranian +chow +keeshond +Brabancon griffon +Pembroke +Cardigan +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf +white wolf +red wolf +coyote +dingo +dhole +African hunting dog +hyena +red fox +kit fox +Arctic fox +grey fox +tabby +tiger cat +Persian cat +Siamese cat +Egyptian cat +cougar +lynx +leopard +snow leopard +jaguar +lion +tiger +cheetah +brown bear +American black bear +ice bear +sloth bear +mongoose +meerkat +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +ant +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +starfish +sea urchin +sea cucumber +wood rabbit +hare +Angora +hamster +porcupine +fox squirrel +marmot +beaver +guinea pig +sorrel +zebra +hog +wild boar +warthog +hippopotamus +ox +water buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian camel +llama +weasel +mink +polecat +black-footed ferret +otter +skunk +badger +armadillo +three-toed sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis monkey +marmoset +capuchin +howler monkey +titi +spider monkey +squirrel monkey +Madagascar cat +indri +Indian elephant +African elephant +lesser panda +giant panda +barracouta +eel +coho +rock beauty +anemone fish +sturgeon +gar +lionfish +puffer +abacus +abaya +academic gown +accordion +acoustic guitar +aircraft carrier +airliner +airship +altar +ambulance +amphibian +analog clock +apiary +apron +ashcan +assault rifle +backpack +bakery +balance beam +balloon +ballpoint +Band Aid +banjo +bannister +barbell +barber chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing cap +bath towel +bathtub +beach wagon +beacon +beaker +bearskin +beer bottle +beer glass +bell cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo tie +bonnet +bookcase +bookshop +bottlecap +bow +bow tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof vest +bullet train +butcher shop +cab +caldron +candle +cannon +canoe +can opener +cardigan +car mirror +carousel +carpenter's kit +carton +car wheel +cash machine +cassette +cassette player +castle +catamaran +CD player +cello +cellular telephone +chain +chainlink fence +chain mail +chain saw +chest +chiffonier +chime +china cabinet +Christmas stocking +church +cinema +cleaver +cliff dwelling +cloak +clog +cocktail shaker +coffee mug +coffeepot +coil +combination lock +computer keyboard +confectionery +container ship +convertible +corkscrew +cornet +cowboy boot +cowboy hat +cradle +crane +crash helmet +crate +crib +Crock Pot +croquet ball +crutch +cuirass +dam +desk +desktop computer +dial telephone +diaper +digital clock +digital watch +dining table +dishrag +dishwasher +disk brake +dock +dogsled +dome +doormat +drilling platform +drum +drumstick +dumbbell +Dutch oven +electric fan +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa +file +fireboat +fire engine +fire screen +flagpole +flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn +frying pan +fur coat +garbage truck +gasmask +gas pump +goblet +go-kart +golf ball +golfcart +gondola +gong +gown +grand piano +greenhouse +grille +grocery store +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower +hand-held computer +handkerchief +hard disc +harmonica +harp +harvester +hatchet +holster +home theater +honeycomb +hook +hoopskirt +horizontal bar +horse cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw puzzle +jinrikisha +joystick +kimono +knee pad +knot +lab coat +ladle +lampshade +laptop +lawn mower +lens cap +letter opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic compass +mailbag +mailbox +maillot +maillot +manhole cover +maraca +marimba +mask +matchstick +maypole +maze +measuring cup +medicine chest +megalith +microphone +microwave +military uniform +milk can +minibus +miniskirt +minivan +missile +mitten +mixing bowl +mobile home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter +mountain bike +mountain tent +mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil filter +organ +oscilloscope +overskirt +oxcart +oxygen mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper towel +parachute +parallel bars +park bench +parking meter +passenger car +patio +pay-phone +pedestal +pencil box +pencil sharpener +perfume +Petri dish +photocopier +pick +pickelhaube +picket fence +pickup +pier +piggy bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate +pitcher +plane +planetarium +plastic bag +plate rack +plow +plunger +Polaroid camera +pole +police van +poncho +pool table +pop bottle +pot +potter's wheel +power drill +prayer rug +printer +prison +projectile +projector +puck +punching bag +purse +quill +quilt +racer +racket +radiator +radio +radio telescope +rain barrel +recreational vehicle +reel +reflex camera +refrigerator +remote control +restaurant +revolver +rifle +rocking chair +rotisserie +rubber eraser +rugby ball +rule +running shoe +safe +safety pin +saltshaker +sandal +sarong +sax +scabbard +scale +school bus +schooner +scoreboard +screen +screw +screwdriver +seat belt +sewing machine +shield +shoe shop +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule +sliding door +slot +snorkel +snowmobile +snowplow +soap dispenser +soccer ball +sock +solar dish +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web +spindle +sports car +spotlight +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch +stove +strainer +streetcar +stretcher +studio couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension bridge +swab +sweatshirt +swimming trunks +swing +switch +syringe +table lamp +tank +tape player +teapot +teddy +television +tennis ball +thatch +theater curtain +thimble +thresher +throne +tile roof +toaster +tobacco shop +toilet seat +torch +totem pole +tow truck +toyshop +tractor +trailer truck +tray +trench coat +tricycle +trimaran +tripod +triumphal arch +trolleybus +trombone +tub +turnstile +typewriter keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending machine +vestment +viaduct +violin +volleyball +waffle iron +wall clock +wallet +wardrobe +warplane +washbasin +washer +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool +worm fence +wreck +yawl +yurt +web site +comic book +crossword puzzle +street sign +traffic light +book jacket +menu +plate +guacamole +consomme +hot pot +trifle +ice cream +ice lolly +French loaf +bagel +pretzel +cheeseburger +hotdog +mashed potato +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +hay +carbonara +chocolate sauce +dough +meat loaf +pizza +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff +coral reef +geyser +lakeside +promontory +sandbar +seashore +valley +volcano +ballplayer +groom +scuba diver +rapeseed +daisy +yellow lady's slipper +corn +acorn +hip +buckeye +coral fungus +agaric +gyromitra +stinkhorn +earthstar +hen-of-the-woods +bolete +ear +toilet tissue diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/run_tflite.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/run_tflite.py new file mode 100644 index 0000000000000000000000000000000000000000..4b8ebe235758d3d0f3d357c51ed54d78ac7eea8e --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/models/src/main/assets/run_tflite.py @@ -0,0 +1,75 @@ +# Flex ops are included in the nightly build of the TensorFlow Python package. You can use TFLite models containing Flex ops by the same Python API as normal TFLite models. The nightly TensorFlow build can be installed with this command: +# Flex ops will be added to the TensorFlow Python package's and the tflite_runtime package from version 2.3 for Linux and 2.4 for other environments. +# https://www.tensorflow.org/lite/guide/ops_select#running_the_model + +# You must use: tf-nightly +# pip install tf-nightly + +import os +import glob +import cv2 +import numpy as np + +import tensorflow as tf + +width=256 +height=256 +model_name="model.tflite" +#model_name="model_quant.tflite" +image_name="dog.jpg" + +# input +img = cv2.imread(image_name) +img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + +mean=[0.485, 0.456, 0.406] +std=[0.229, 0.224, 0.225] +img = (img - mean) / std + +img_resized = tf.image.resize(img, [width,height], method='bicubic', preserve_aspect_ratio=False) +#img_resized = tf.transpose(img_resized, [2, 0, 1]) +img_input = img_resized.numpy() +reshape_img = img_input.reshape(1,width,height,3) +tensor = tf.convert_to_tensor(reshape_img, dtype=tf.float32) + +# load model +print("Load model...") +interpreter = tf.lite.Interpreter(model_path=model_name) +print("Allocate tensor...") +interpreter.allocate_tensors() +print("Get input/output details...") +input_details = interpreter.get_input_details() +output_details = interpreter.get_output_details() +print("Get input shape...") +input_shape = input_details[0]['shape'] +print(input_shape) +print(input_details) +print(output_details) +#input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32) +print("Set input tensor...") +interpreter.set_tensor(input_details[0]['index'], tensor) + +print("invoke()...") +interpreter.invoke() + +# The function `get_tensor()` returns a copy of the tensor data. +# Use `tensor()` in order to get a pointer to the tensor. +print("get output tensor...") +output = interpreter.get_tensor(output_details[0]['index']) +#output = np.squeeze(output) +output = output.reshape(width, height) +#print(output) +prediction = np.array(output) +print("reshape prediction...") +prediction = prediction.reshape(width, height) + +# output file +#prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC) +print(" Write image to: output.png") +depth_min = prediction.min() +depth_max = prediction.max() +img_out = (255 * (prediction - depth_min) / (depth_max - depth_min)).astype("uint8") +print("save output image...") +cv2.imwrite("output.png", img_out) + +print("finished") \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/settings.gradle b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/settings.gradle new file mode 100644 index 0000000000000000000000000000000000000000..e86d89d2483f92b7e778589011fad60fbba3a318 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/android/settings.gradle @@ -0,0 +1,2 @@ +rootProject.name = 'TFLite Image Classification Demo App' +include ':app', ':lib_support', ':lib_task_api', ':models' \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/.gitignore b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f1150e3379e4a38d31ca7bb46dc4f31d79f482c2 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/.gitignore @@ -0,0 +1,2 @@ +# ignore model file +#*.tflite diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/LICENSE b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.pbxproj b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.pbxproj new file mode 100644 index 0000000000000000000000000000000000000000..4917371aa33a65fdfc66c02d914f05489c446430 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.pbxproj @@ -0,0 +1,538 @@ +// !$*UTF8*$! +{ + archiveVersion = 1; + classes = { + }; + objectVersion = 50; + objects = { + +/* Begin PBXBuildFile section */ + 0CDA8C85042ADF65D0787629 /* Pods_Midas.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = A1CE41C09920CCEC31985547 /* Pods_Midas.framework */; }; + 8402440123D9834600704ABD /* README.md in Resources */ = {isa = PBXBuildFile; fileRef = 8402440023D9834600704ABD /* README.md */; }; + 840ECB20238BAA2300C7D88A /* InfoCell.swift in Sources */ = {isa = PBXBuildFile; fileRef = 840ECB1F238BAA2300C7D88A /* InfoCell.swift */; }; + 840EDCFD2341DDD30017ED42 /* Launch Screen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 840EDCFB2341DDD30017ED42 /* Launch Screen.storyboard */; }; + 840EDD022341DE380017ED42 /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 840EDD002341DE380017ED42 /* Main.storyboard */; }; + 842DDB6E2372A82000F6BB94 /* OverlayView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 842DDB6D2372A82000F6BB94 /* OverlayView.swift */; }; + 846499C2235DAB0D009CBBC7 /* ModelDataHandler.swift in Sources */ = {isa = PBXBuildFile; fileRef = 846499C1235DAB0D009CBBC7 /* ModelDataHandler.swift */; }; + 846BAF7623E7FE13006FC136 /* Constants.swift in Sources */ = {isa = PBXBuildFile; fileRef = 846BAF7523E7FE13006FC136 /* Constants.swift */; }; + 8474FEC92341D36E00377D34 /* PreviewView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8474FEC82341D36E00377D34 /* PreviewView.swift */; }; + 8474FECB2341D39800377D34 /* CameraFeedManager.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8474FECA2341D39800377D34 /* CameraFeedManager.swift */; }; + 84952CB5236186BE0052C104 /* CVPixelBufferExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84952CB4236186BE0052C104 /* CVPixelBufferExtension.swift */; }; + 84952CB92361874A0052C104 /* TFLiteExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84952CB82361874A0052C104 /* TFLiteExtension.swift */; }; + 84B67CEF2326338300A11A08 /* AppDelegate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84B67CEE2326338300A11A08 /* AppDelegate.swift */; }; + 84B67CF12326338300A11A08 /* ViewController.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84B67CF02326338300A11A08 /* ViewController.swift */; }; + 84B67CF62326338400A11A08 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 84B67CF52326338400A11A08 /* Assets.xcassets */; }; + 84D6576D2387BB7E0048171E /* CGSizeExtension.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84D6576C2387BB7E0048171E /* CGSizeExtension.swift */; }; + 84F232D5254C831E0011862E /* model_opt.tflite in Resources */ = {isa = PBXBuildFile; fileRef = 84F232D4254C831E0011862E /* model_opt.tflite */; }; + 84FCF5922387BD7900663812 /* tfl_logo.png in Resources */ = {isa = PBXBuildFile; fileRef = 84FCF5912387BD7900663812 /* tfl_logo.png */; }; +/* End PBXBuildFile section */ + +/* Begin PBXFileReference section */ + 8402440023D9834600704ABD /* README.md */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = ""; }; + 840ECB1F238BAA2300C7D88A /* InfoCell.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = InfoCell.swift; sourceTree = ""; }; + 840EDCFC2341DDD30017ED42 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = "Base.lproj/Launch Screen.storyboard"; sourceTree = ""; }; + 840EDD012341DE380017ED42 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/Main.storyboard; sourceTree = ""; }; + 842DDB6D2372A82000F6BB94 /* OverlayView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OverlayView.swift; sourceTree = ""; }; + 846499C1235DAB0D009CBBC7 /* ModelDataHandler.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelDataHandler.swift; sourceTree = ""; }; + 846BAF7523E7FE13006FC136 /* Constants.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Constants.swift; sourceTree = ""; }; + 8474FEC82341D36E00377D34 /* PreviewView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PreviewView.swift; sourceTree = ""; }; + 8474FECA2341D39800377D34 /* CameraFeedManager.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CameraFeedManager.swift; sourceTree = ""; }; + 84884291236FF0A30043FC4C /* download_models.sh */ = {isa = PBXFileReference; lastKnownFileType = text.script.sh; path = download_models.sh; sourceTree = ""; }; + 84952CB4236186BE0052C104 /* CVPixelBufferExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CVPixelBufferExtension.swift; sourceTree = ""; }; + 84952CB82361874A0052C104 /* TFLiteExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TFLiteExtension.swift; sourceTree = ""; }; + 84B67CEB2326338300A11A08 /* Midas.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = Midas.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 84B67CEE2326338300A11A08 /* AppDelegate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AppDelegate.swift; sourceTree = ""; }; + 84B67CF02326338300A11A08 /* ViewController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ViewController.swift; sourceTree = ""; }; + 84B67CF52326338400A11A08 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; + 84B67CFA2326338400A11A08 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = ""; }; + 84D6576C2387BB7E0048171E /* CGSizeExtension.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CGSizeExtension.swift; sourceTree = ""; }; + 84F232D4254C831E0011862E /* model_opt.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = model_opt.tflite; sourceTree = ""; }; + 84FCF5912387BD7900663812 /* tfl_logo.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; name = tfl_logo.png; path = Assets.xcassets/tfl_logo.png; sourceTree = ""; }; + A1CE41C09920CCEC31985547 /* Pods_Midas.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_Midas.framework; sourceTree = BUILT_PRODUCTS_DIR; }; + D2BFF06D0AE9137D332447F3 /* Pods-Midas.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-Midas.release.xcconfig"; path = "Target Support Files/Pods-Midas/Pods-Midas.release.xcconfig"; sourceTree = ""; }; + FCA88463911267B1001A596F /* Pods-Midas.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-Midas.debug.xcconfig"; path = "Target Support Files/Pods-Midas/Pods-Midas.debug.xcconfig"; sourceTree = ""; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + 84B67CE82326338300A11A08 /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + 0CDA8C85042ADF65D0787629 /* Pods_Midas.framework in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + 840ECB1E238BAA0D00C7D88A /* Cells */ = { + isa = PBXGroup; + children = ( + 840ECB1F238BAA2300C7D88A /* InfoCell.swift */, + ); + path = Cells; + sourceTree = ""; + }; + 842DDB6C2372A80E00F6BB94 /* Views */ = { + isa = PBXGroup; + children = ( + 842DDB6D2372A82000F6BB94 /* OverlayView.swift */, + ); + path = Views; + sourceTree = ""; + }; + 846499C0235DAAE7009CBBC7 /* ModelDataHandler */ = { + isa = PBXGroup; + children = ( + 846499C1235DAB0D009CBBC7 /* ModelDataHandler.swift */, + ); + path = ModelDataHandler; + sourceTree = ""; + }; + 8474FEC62341D2BE00377D34 /* ViewControllers */ = { + isa = PBXGroup; + children = ( + 84B67CF02326338300A11A08 /* ViewController.swift */, + ); + path = ViewControllers; + sourceTree = ""; + }; + 8474FEC72341D35800377D34 /* Camera Feed */ = { + isa = PBXGroup; + children = ( + 8474FEC82341D36E00377D34 /* PreviewView.swift */, + 8474FECA2341D39800377D34 /* CameraFeedManager.swift */, + ); + path = "Camera Feed"; + sourceTree = ""; + }; + 84884290236FF07F0043FC4C /* RunScripts */ = { + isa = PBXGroup; + children = ( + 84884291236FF0A30043FC4C /* download_models.sh */, + ); + path = RunScripts; + sourceTree = ""; + }; + 848842A22370180C0043FC4C /* Model */ = { + isa = PBXGroup; + children = ( + 84F232D4254C831E0011862E /* model_opt.tflite */, + ); + path = Model; + sourceTree = ""; + }; + 84952CB3236186A20052C104 /* Extensions */ = { + isa = PBXGroup; + children = ( + 84952CB4236186BE0052C104 /* CVPixelBufferExtension.swift */, + 84952CB82361874A0052C104 /* TFLiteExtension.swift */, + 84D6576C2387BB7E0048171E /* CGSizeExtension.swift */, + ); + path = Extensions; + sourceTree = ""; + }; + 84B67CE22326338300A11A08 = { + isa = PBXGroup; + children = ( + 8402440023D9834600704ABD /* README.md */, + 84884290236FF07F0043FC4C /* RunScripts */, + 84B67CED2326338300A11A08 /* Midas */, + 84B67CEC2326338300A11A08 /* Products */, + B4DFDCC28443B641BC36251D /* Pods */, + A3DA804B8D3F6891E3A02852 /* Frameworks */, + ); + sourceTree = ""; + }; + 84B67CEC2326338300A11A08 /* Products */ = { + isa = PBXGroup; + children = ( + 84B67CEB2326338300A11A08 /* Midas.app */, + ); + name = Products; + sourceTree = ""; + }; + 84B67CED2326338300A11A08 /* Midas */ = { + isa = PBXGroup; + children = ( + 840ECB1E238BAA0D00C7D88A /* Cells */, + 842DDB6C2372A80E00F6BB94 /* Views */, + 848842A22370180C0043FC4C /* Model */, + 84952CB3236186A20052C104 /* Extensions */, + 846499C0235DAAE7009CBBC7 /* ModelDataHandler */, + 8474FEC72341D35800377D34 /* Camera Feed */, + 8474FEC62341D2BE00377D34 /* ViewControllers */, + 84B67D002326339000A11A08 /* Storyboards */, + 84B67CEE2326338300A11A08 /* AppDelegate.swift */, + 846BAF7523E7FE13006FC136 /* Constants.swift */, + 84B67CF52326338400A11A08 /* Assets.xcassets */, + 84FCF5912387BD7900663812 /* tfl_logo.png */, + 84B67CFA2326338400A11A08 /* Info.plist */, + ); + path = Midas; + sourceTree = ""; + }; + 84B67D002326339000A11A08 /* Storyboards */ = { + isa = PBXGroup; + children = ( + 840EDCFB2341DDD30017ED42 /* Launch Screen.storyboard */, + 840EDD002341DE380017ED42 /* Main.storyboard */, + ); + path = Storyboards; + sourceTree = ""; + }; + A3DA804B8D3F6891E3A02852 /* Frameworks */ = { + isa = PBXGroup; + children = ( + A1CE41C09920CCEC31985547 /* Pods_Midas.framework */, + ); + name = Frameworks; + sourceTree = ""; + }; + B4DFDCC28443B641BC36251D /* Pods */ = { + isa = PBXGroup; + children = ( + FCA88463911267B1001A596F /* Pods-Midas.debug.xcconfig */, + D2BFF06D0AE9137D332447F3 /* Pods-Midas.release.xcconfig */, + ); + path = Pods; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + 84B67CEA2326338300A11A08 /* Midas */ = { + isa = PBXNativeTarget; + buildConfigurationList = 84B67CFD2326338400A11A08 /* Build configuration list for PBXNativeTarget "Midas" */; + buildPhases = ( + 14067F3CF309C9DB723C9F6F /* [CP] Check Pods Manifest.lock */, + 84884298237010B90043FC4C /* Download TensorFlow Lite model */, + 84B67CE72326338300A11A08 /* Sources */, + 84B67CE82326338300A11A08 /* Frameworks */, + 84B67CE92326338300A11A08 /* Resources */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = Midas; + productName = Midas; + productReference = 84B67CEB2326338300A11A08 /* Midas.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + 84B67CE32326338300A11A08 /* Project object */ = { + isa = PBXProject; + attributes = { + LastSwiftUpdateCheck = 1030; + LastUpgradeCheck = 1030; + ORGANIZATIONNAME = tensorflow; + TargetAttributes = { + 84B67CEA2326338300A11A08 = { + CreatedOnToolsVersion = 10.3; + }; + }; + }; + buildConfigurationList = 84B67CE62326338300A11A08 /* Build configuration list for PBXProject "Midas" */; + compatibilityVersion = "Xcode 9.3"; + developmentRegion = en; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = 84B67CE22326338300A11A08; + productRefGroup = 84B67CEC2326338300A11A08 /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + 84B67CEA2326338300A11A08 /* Midas */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + 84B67CE92326338300A11A08 /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 8402440123D9834600704ABD /* README.md in Resources */, + 84F232D5254C831E0011862E /* model_opt.tflite in Resources */, + 840EDD022341DE380017ED42 /* Main.storyboard in Resources */, + 840EDCFD2341DDD30017ED42 /* Launch Screen.storyboard in Resources */, + 84FCF5922387BD7900663812 /* tfl_logo.png in Resources */, + 84B67CF62326338400A11A08 /* Assets.xcassets in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXShellScriptBuildPhase section */ + 14067F3CF309C9DB723C9F6F /* [CP] Check Pods Manifest.lock */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputFileListPaths = ( + ); + inputPaths = ( + "${PODS_PODFILE_DIR_PATH}/Podfile.lock", + "${PODS_ROOT}/Manifest.lock", + ); + name = "[CP] Check Pods Manifest.lock"; + outputFileListPaths = ( + ); + outputPaths = ( + "$(DERIVED_FILE_DIR)/Pods-Midas-checkManifestLockResult.txt", + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n"; + showEnvVarsInLog = 0; + }; + 84884298237010B90043FC4C /* Download TensorFlow Lite model */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputFileListPaths = ( + ); + inputPaths = ( + ); + name = "Download TensorFlow Lite model"; + outputFileListPaths = ( + ); + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/bash; + shellScript = "\"$SRCROOT/RunScripts/download_models.sh\"\n"; + }; +/* End PBXShellScriptBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + 84B67CE72326338300A11A08 /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 842DDB6E2372A82000F6BB94 /* OverlayView.swift in Sources */, + 846BAF7623E7FE13006FC136 /* Constants.swift in Sources */, + 84952CB92361874A0052C104 /* TFLiteExtension.swift in Sources */, + 84D6576D2387BB7E0048171E /* CGSizeExtension.swift in Sources */, + 84B67CF12326338300A11A08 /* ViewController.swift in Sources */, + 84B67CEF2326338300A11A08 /* AppDelegate.swift in Sources */, + 8474FECB2341D39800377D34 /* CameraFeedManager.swift in Sources */, + 846499C2235DAB0D009CBBC7 /* ModelDataHandler.swift in Sources */, + 8474FEC92341D36E00377D34 /* PreviewView.swift in Sources */, + 84952CB5236186BE0052C104 /* CVPixelBufferExtension.swift in Sources */, + 840ECB20238BAA2300C7D88A /* InfoCell.swift in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin PBXVariantGroup section */ + 840EDCFB2341DDD30017ED42 /* Launch Screen.storyboard */ = { + isa = PBXVariantGroup; + children = ( + 840EDCFC2341DDD30017ED42 /* Base */, + ); + name = "Launch Screen.storyboard"; + sourceTree = ""; + }; + 840EDD002341DE380017ED42 /* Main.storyboard */ = { + isa = PBXVariantGroup; + children = ( + 840EDD012341DE380017ED42 /* Base */, + ); + name = Main.storyboard; + sourceTree = ""; + }; +/* End PBXVariantGroup section */ + +/* Begin XCBuildConfiguration section */ + 84B67CFB2326338400A11A08 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_IDENTITY = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 12.4; + MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; + MTL_FAST_MATH = YES; + ONLY_ACTIVE_ARCH = YES; + SDKROOT = iphoneos; + SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + }; + name = Debug; + }; + 84B67CFC2326338400A11A08 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_IDENTITY = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 12.4; + MTL_ENABLE_DEBUG_INFO = NO; + MTL_FAST_MATH = YES; + SDKROOT = iphoneos; + SWIFT_COMPILATION_MODE = wholemodule; + SWIFT_OPTIMIZATION_LEVEL = "-O"; + VALIDATE_PRODUCT = YES; + }; + name = Release; + }; + 84B67CFE2326338400A11A08 /* Debug */ = { + isa = XCBuildConfiguration; + baseConfigurationReference = FCA88463911267B1001A596F /* Pods-Midas.debug.xcconfig */; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CODE_SIGN_IDENTITY = "iPhone Developer"; + CODE_SIGN_STYLE = Automatic; + DEVELOPMENT_TEAM = BV6M48J3RX; + INFOPLIST_FILE = Midas/Info.plist; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + PRODUCT_BUNDLE_IDENTIFIER = "org.midas.midas-tflite-npu"; + PRODUCT_NAME = Midas; + PROVISIONING_PROFILE_SPECIFIER = ""; + SWIFT_VERSION = 5.0; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Debug; + }; + 84B67CFF2326338400A11A08 /* Release */ = { + isa = XCBuildConfiguration; + baseConfigurationReference = D2BFF06D0AE9137D332447F3 /* Pods-Midas.release.xcconfig */; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CODE_SIGN_IDENTITY = "iPhone Developer"; + CODE_SIGN_STYLE = Automatic; + DEVELOPMENT_TEAM = BV6M48J3RX; + INFOPLIST_FILE = Midas/Info.plist; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + PRODUCT_BUNDLE_IDENTIFIER = "org.midas.midas-tflite-npu"; + PRODUCT_NAME = Midas; + PROVISIONING_PROFILE_SPECIFIER = ""; + SWIFT_VERSION = 5.0; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + 84B67CE62326338300A11A08 /* Build configuration list for PBXProject "Midas" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 84B67CFB2326338400A11A08 /* Debug */, + 84B67CFC2326338400A11A08 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + 84B67CFD2326338400A11A08 /* Build configuration list for PBXNativeTarget "Midas" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 84B67CFE2326338400A11A08 /* Debug */, + 84B67CFF2326338400A11A08 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + }; + rootObject = 84B67CE32326338300A11A08 /* Project object */; +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/contents.xcworkspacedata b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/contents.xcworkspacedata new file mode 100644 index 0000000000000000000000000000000000000000..919434a6254f0e9651f402737811be6634a03e9c --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/contents.xcworkspacedata @@ -0,0 +1,7 @@ + + + + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist new file mode 100644 index 0000000000000000000000000000000000000000..18d981003d68d0546c4804ac2ff47dd97c6e7921 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist @@ -0,0 +1,8 @@ + + + + + IDEDidComputeMac32BitWarning + + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcuserdata/admin.xcuserdatad/UserInterfaceState.xcuserstate b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcuserdata/admin.xcuserdatad/UserInterfaceState.xcuserstate new file mode 100644 index 0000000000000000000000000000000000000000..1d20756ee57b79e9f9f886453bdb7997ca2ee2d4 Binary files /dev/null and b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcuserdata/admin.xcuserdatad/UserInterfaceState.xcuserstate differ diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/xcuserdata/admin.xcuserdatad/xcschemes/xcschememanagement.plist b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/xcuserdata/admin.xcuserdatad/xcschemes/xcschememanagement.plist new file mode 100644 index 0000000000000000000000000000000000000000..6093f6160eedfdfc20e96396247a7dbc9247cc55 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas.xcodeproj/xcuserdata/admin.xcuserdatad/xcschemes/xcschememanagement.plist @@ -0,0 +1,14 @@ + + + + + SchemeUserState + + PoseNet.xcscheme_^#shared#^_ + + orderHint + 3 + + + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/AppDelegate.swift b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/AppDelegate.swift new file mode 100644 index 0000000000000000000000000000000000000000..233f0291ab4f379067543bdad3cc198a2dc3ab0f --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/AppDelegate.swift @@ -0,0 +1,41 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import UIKit + +@UIApplicationMain +class AppDelegate: UIResponder, UIApplicationDelegate { + + var window: UIWindow? + + func application(_ application: UIApplication, didFinishLaunchingWithOptions launchOptions: [UIApplication.LaunchOptionsKey: Any]?) -> Bool { + return true + } + + func applicationWillResignActive(_ application: UIApplication) { + } + + func applicationDidEnterBackground(_ application: UIApplication) { + } + + func applicationWillEnterForeground(_ application: UIApplication) { + } + + func applicationDidBecomeActive(_ application: UIApplication) { + } + + func applicationWillTerminate(_ application: UIApplication) { + } +} + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/Contents.json b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/Contents.json new file mode 100644 index 0000000000000000000000000000000000000000..65b74d7ef11fa59fafa829e681ac90906f3ac8b2 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -0,0 +1 @@ +{"images":[{"size":"60x60","expected-size":"180","filename":"180.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"40x40","expected-size":"80","filename":"80.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"40x40","expected-size":"120","filename":"120.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"60x60","expected-size":"120","filename":"120.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"57x57","expected-size":"57","filename":"57.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"1x"},{"size":"29x29","expected-size":"58","filename":"58.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"29x29","expected-size":"29","filename":"29.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"1x"},{"size":"29x29","expected-size":"87","filename":"87.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"57x57","expected-size":"114","filename":"114.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"20x20","expected-size":"40","filename":"40.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"20x20","expected-size":"60","filename":"60.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"1024x1024","filename":"1024.png","expected-size":"1024","idiom":"ios-marketing","folder":"Assets.xcassets/AppIcon.appiconset/","scale":"1x"},{"size":"40x40","expected-size":"80","filename":"80.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"72x72","expected-size":"72","filename":"72.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"76x76","expected-size":"152","filename":"152.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"50x50","expected-size":"100","filename":"100.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"29x29","expected-size":"58","filename":"58.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"76x76","expected-size":"76","filename":"76.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"29x29","expected-size":"29","filename":"29.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"50x50","expected-size":"50","filename":"50.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"72x72","expected-size":"144","filename":"144.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"40x40","expected-size":"40","filename":"40.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"83.5x83.5","expected-size":"167","filename":"167.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"20x20","expected-size":"20","filename":"20.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"20x20","expected-size":"40","filename":"40.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"}]} \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/Contents.json b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/Contents.json new file mode 100644 index 0000000000000000000000000000000000000000..da4a164c918651cdd1e11dca5cc62c333f097601 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Assets.xcassets/Contents.json @@ -0,0 +1,6 @@ +{ + "info" : { + "version" : 1, + "author" : "xcode" + } +} \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/CameraFeedManager.swift b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/CameraFeedManager.swift new file mode 100644 index 0000000000000000000000000000000000000000..48d65b88ee220e722fbad2570c8e879a431cd0f5 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/CameraFeedManager.swift @@ -0,0 +1,316 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import AVFoundation +import UIKit +import os + +// MARK: - CameraFeedManagerDelegate Declaration +@objc protocol CameraFeedManagerDelegate: class { + /// This method delivers the pixel buffer of the current frame seen by the device's camera. + @objc optional func cameraFeedManager( + _ manager: CameraFeedManager, didOutput pixelBuffer: CVPixelBuffer + ) + + /// This method initimates that a session runtime error occured. + func cameraFeedManagerDidEncounterSessionRunTimeError(_ manager: CameraFeedManager) + + /// This method initimates that the session was interrupted. + func cameraFeedManager( + _ manager: CameraFeedManager, sessionWasInterrupted canResumeManually: Bool + ) + + /// This method initimates that the session interruption has ended. + func cameraFeedManagerDidEndSessionInterruption(_ manager: CameraFeedManager) + + /// This method initimates that there was an error in video configurtion. + func presentVideoConfigurationErrorAlert(_ manager: CameraFeedManager) + + /// This method initimates that the camera permissions have been denied. + func presentCameraPermissionsDeniedAlert(_ manager: CameraFeedManager) +} + +/// This enum holds the state of the camera initialization. +// MARK: - Camera Initialization State Enum +enum CameraConfiguration { + case success + case failed + case permissionDenied +} + +/// This class manages all camera related functionalities. +// MARK: - Camera Related Functionalies Manager +class CameraFeedManager: NSObject { + // MARK: Camera Related Instance Variables + private let session: AVCaptureSession = AVCaptureSession() + + private let previewView: PreviewView + private let sessionQueue = DispatchQueue(label: "sessionQueue") + private var cameraConfiguration: CameraConfiguration = .failed + private lazy var videoDataOutput = AVCaptureVideoDataOutput() + private var isSessionRunning = false + + // MARK: CameraFeedManagerDelegate + weak var delegate: CameraFeedManagerDelegate? + + // MARK: Initializer + init(previewView: PreviewView) { + self.previewView = previewView + super.init() + + // Initializes the session + session.sessionPreset = .high + self.previewView.session = session + self.previewView.previewLayer.connection?.videoOrientation = .portrait + self.previewView.previewLayer.videoGravity = .resizeAspectFill + self.attemptToConfigureSession() + } + + // MARK: Session Start and End methods + + /// This method starts an AVCaptureSession based on whether the camera configuration was successful. + func checkCameraConfigurationAndStartSession() { + sessionQueue.async { + switch self.cameraConfiguration { + case .success: + self.addObservers() + self.startSession() + case .failed: + DispatchQueue.main.async { + self.delegate?.presentVideoConfigurationErrorAlert(self) + } + case .permissionDenied: + DispatchQueue.main.async { + self.delegate?.presentCameraPermissionsDeniedAlert(self) + } + } + } + } + + /// This method stops a running an AVCaptureSession. + func stopSession() { + self.removeObservers() + sessionQueue.async { + if self.session.isRunning { + self.session.stopRunning() + self.isSessionRunning = self.session.isRunning + } + } + + } + + /// This method resumes an interrupted AVCaptureSession. + func resumeInterruptedSession(withCompletion completion: @escaping (Bool) -> Void) { + sessionQueue.async { + self.startSession() + + DispatchQueue.main.async { + completion(self.isSessionRunning) + } + } + } + + /// This method starts the AVCaptureSession + private func startSession() { + self.session.startRunning() + self.isSessionRunning = self.session.isRunning + } + + // MARK: Session Configuration Methods. + /// This method requests for camera permissions and handles the configuration of the session and stores the result of configuration. + private func attemptToConfigureSession() { + switch AVCaptureDevice.authorizationStatus(for: .video) { + case .authorized: + self.cameraConfiguration = .success + case .notDetermined: + self.sessionQueue.suspend() + self.requestCameraAccess(completion: { granted in + self.sessionQueue.resume() + }) + case .denied: + self.cameraConfiguration = .permissionDenied + default: + break + } + + self.sessionQueue.async { + self.configureSession() + } + } + + /// This method requests for camera permissions. + private func requestCameraAccess(completion: @escaping (Bool) -> Void) { + AVCaptureDevice.requestAccess(for: .video) { (granted) in + if !granted { + self.cameraConfiguration = .permissionDenied + } else { + self.cameraConfiguration = .success + } + completion(granted) + } + } + + /// This method handles all the steps to configure an AVCaptureSession. + private func configureSession() { + guard cameraConfiguration == .success else { + return + } + session.beginConfiguration() + + // Tries to add an AVCaptureDeviceInput. + guard addVideoDeviceInput() == true else { + self.session.commitConfiguration() + self.cameraConfiguration = .failed + return + } + + // Tries to add an AVCaptureVideoDataOutput. + guard addVideoDataOutput() else { + self.session.commitConfiguration() + self.cameraConfiguration = .failed + return + } + + session.commitConfiguration() + self.cameraConfiguration = .success + } + + /// This method tries to an AVCaptureDeviceInput to the current AVCaptureSession. + private func addVideoDeviceInput() -> Bool { + /// Tries to get the default back camera. + guard + let camera = AVCaptureDevice.default(.builtInWideAngleCamera, for: .video, position: .back) + else { + fatalError("Cannot find camera") + } + + do { + let videoDeviceInput = try AVCaptureDeviceInput(device: camera) + if session.canAddInput(videoDeviceInput) { + session.addInput(videoDeviceInput) + return true + } else { + return false + } + } catch { + fatalError("Cannot create video device input") + } + } + + /// This method tries to an AVCaptureVideoDataOutput to the current AVCaptureSession. + private func addVideoDataOutput() -> Bool { + let sampleBufferQueue = DispatchQueue(label: "sampleBufferQueue") + videoDataOutput.setSampleBufferDelegate(self, queue: sampleBufferQueue) + videoDataOutput.alwaysDiscardsLateVideoFrames = true + videoDataOutput.videoSettings = [ + String(kCVPixelBufferPixelFormatTypeKey): kCMPixelFormat_32BGRA + ] + + if session.canAddOutput(videoDataOutput) { + session.addOutput(videoDataOutput) + videoDataOutput.connection(with: .video)?.videoOrientation = .portrait + return true + } + return false + } + + // MARK: Notification Observer Handling + private func addObservers() { + NotificationCenter.default.addObserver( + self, selector: #selector(CameraFeedManager.sessionRuntimeErrorOccured(notification:)), + name: NSNotification.Name.AVCaptureSessionRuntimeError, object: session) + NotificationCenter.default.addObserver( + self, selector: #selector(CameraFeedManager.sessionWasInterrupted(notification:)), + name: NSNotification.Name.AVCaptureSessionWasInterrupted, object: session) + NotificationCenter.default.addObserver( + self, selector: #selector(CameraFeedManager.sessionInterruptionEnded), + name: NSNotification.Name.AVCaptureSessionInterruptionEnded, object: session) + } + + private func removeObservers() { + NotificationCenter.default.removeObserver( + self, name: NSNotification.Name.AVCaptureSessionRuntimeError, object: session) + NotificationCenter.default.removeObserver( + self, name: NSNotification.Name.AVCaptureSessionWasInterrupted, object: session) + NotificationCenter.default.removeObserver( + self, name: NSNotification.Name.AVCaptureSessionInterruptionEnded, object: session) + } + + // MARK: Notification Observers + @objc func sessionWasInterrupted(notification: Notification) { + if let userInfoValue = notification.userInfo?[AVCaptureSessionInterruptionReasonKey] + as AnyObject?, + let reasonIntegerValue = userInfoValue.integerValue, + let reason = AVCaptureSession.InterruptionReason(rawValue: reasonIntegerValue) + { + os_log("Capture session was interrupted with reason: %s", type: .error, reason.rawValue) + + var canResumeManually = false + if reason == .videoDeviceInUseByAnotherClient { + canResumeManually = true + } else if reason == .videoDeviceNotAvailableWithMultipleForegroundApps { + canResumeManually = false + } + + delegate?.cameraFeedManager(self, sessionWasInterrupted: canResumeManually) + + } + } + + @objc func sessionInterruptionEnded(notification: Notification) { + delegate?.cameraFeedManagerDidEndSessionInterruption(self) + } + + @objc func sessionRuntimeErrorOccured(notification: Notification) { + guard let error = notification.userInfo?[AVCaptureSessionErrorKey] as? AVError else { + return + } + + os_log("Capture session runtime error: %s", type: .error, error.localizedDescription) + + if error.code == .mediaServicesWereReset { + sessionQueue.async { + if self.isSessionRunning { + self.startSession() + } else { + DispatchQueue.main.async { + self.delegate?.cameraFeedManagerDidEncounterSessionRunTimeError(self) + } + } + } + } else { + delegate?.cameraFeedManagerDidEncounterSessionRunTimeError(self) + } + } +} + +/// AVCaptureVideoDataOutputSampleBufferDelegate +extension CameraFeedManager: AVCaptureVideoDataOutputSampleBufferDelegate { + /// This method delegates the CVPixelBuffer of the frame seen by the camera currently. + func captureOutput( + _ output: AVCaptureOutput, didOutput sampleBuffer: CMSampleBuffer, + from connection: AVCaptureConnection + ) { + + // Converts the CMSampleBuffer to a CVPixelBuffer. + let pixelBuffer: CVPixelBuffer? = CMSampleBufferGetImageBuffer(sampleBuffer) + + guard let imagePixelBuffer = pixelBuffer else { + return + } + + // Delegates the pixel buffer to the ViewController. + delegate?.cameraFeedManager?(self, didOutput: imagePixelBuffer) + } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/PreviewView.swift b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/PreviewView.swift new file mode 100644 index 0000000000000000000000000000000000000000..308c7ec54308af5c152ff6038670b26501a8e82c --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Camera Feed/PreviewView.swift @@ -0,0 +1,39 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import UIKit +import AVFoundation + + /// The camera frame is displayed on this view. +class PreviewView: UIView { + var previewLayer: AVCaptureVideoPreviewLayer { + guard let layer = layer as? AVCaptureVideoPreviewLayer else { + fatalError("Layer expected is of type VideoPreviewLayer") + } + return layer + } + + var session: AVCaptureSession? { + get { + return previewLayer.session + } + set { + previewLayer.session = newValue + } + } + + override class var layerClass: AnyClass { + return AVCaptureVideoPreviewLayer.self + } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Cells/InfoCell.swift b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Cells/InfoCell.swift new file mode 100644 index 0000000000000000000000000000000000000000..c6be64af5678541ec09fc367b03c80155876f0ba --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Cells/InfoCell.swift @@ -0,0 +1,21 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import UIKit + +/// Table cell for inference result in bottom view. +class InfoCell: UITableViewCell { + @IBOutlet weak var fieldNameLabel: UILabel! + @IBOutlet weak var infoLabel: UILabel! +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Constants.swift b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Constants.swift new file mode 100644 index 0000000000000000000000000000000000000000..b0789ee58a1ea373d441f05333d8ce8914adadb7 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Constants.swift @@ -0,0 +1,25 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +enum Constants { + // MARK: - Constants related to the image processing + static let bgraPixel = (channels: 4, alphaComponent: 3, lastBgrComponent: 2) + static let rgbPixelChannels = 3 + static let maxRGBValue: Float32 = 255.0 + + // MARK: - Constants related to the model interperter + static let defaultThreadCount = 2 + static let defaultDelegate: Delegates = .CPU +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CGSizeExtension.swift b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CGSizeExtension.swift new file mode 100644 index 0000000000000000000000000000000000000000..031550ea0081963d18b5b83712854babaf7c0a34 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CGSizeExtension.swift @@ -0,0 +1,45 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +import Accelerate +import Foundation + +extension CGSize { + /// Returns `CGAfineTransform` to resize `self` to fit in destination size, keeping aspect ratio + /// of `self`. `self` image is resized to be inscribe to destination size and located in center of + /// destination. + /// + /// - Parameter toFitIn: destination size to be filled. + /// - Returns: `CGAffineTransform` to transform `self` image to `dest` image. + func transformKeepAspect(toFitIn dest: CGSize) -> CGAffineTransform { + let sourceRatio = self.height / self.width + let destRatio = dest.height / dest.width + + // Calculates ratio `self` to `dest`. + var ratio: CGFloat + var x: CGFloat = 0 + var y: CGFloat = 0 + if sourceRatio > destRatio { + // Source size is taller than destination. Resized to fit in destination height, and find + // horizontal starting point to be centered. + ratio = dest.height / self.height + x = (dest.width - self.width * ratio) / 2 + } else { + ratio = dest.width / self.width + y = (dest.height - self.height * ratio) / 2 + } + return CGAffineTransform(a: ratio, b: 0, c: 0, d: ratio, tx: x, ty: y) + } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CVPixelBufferExtension.swift b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CVPixelBufferExtension.swift new file mode 100644 index 0000000000000000000000000000000000000000..4899c76562a546c513736fbf4556629b08d2c929 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/CVPixelBufferExtension.swift @@ -0,0 +1,172 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +import Accelerate +import Foundation + +extension CVPixelBuffer { + var size: CGSize { + return CGSize(width: CVPixelBufferGetWidth(self), height: CVPixelBufferGetHeight(self)) + } + + /// Returns a new `CVPixelBuffer` created by taking the self area and resizing it to the + /// specified target size. Aspect ratios of source image and destination image are expected to be + /// same. + /// + /// - Parameters: + /// - from: Source area of image to be cropped and resized. + /// - to: Size to scale the image to(i.e. image size used while training the model). + /// - Returns: The cropped and resized image of itself. + func resize(from source: CGRect, to size: CGSize) -> CVPixelBuffer? { + let rect = CGRect(origin: CGPoint(x: 0, y: 0), size: self.size) + guard rect.contains(source) else { + os_log("Resizing Error: source area is out of index", type: .error) + return nil + } + guard rect.size.width / rect.size.height - source.size.width / source.size.height < 1e-5 + else { + os_log( + "Resizing Error: source image ratio and destination image ratio is different", + type: .error) + return nil + } + + let inputImageRowBytes = CVPixelBufferGetBytesPerRow(self) + let imageChannels = 4 + + CVPixelBufferLockBaseAddress(self, CVPixelBufferLockFlags(rawValue: 0)) + defer { CVPixelBufferUnlockBaseAddress(self, CVPixelBufferLockFlags(rawValue: 0)) } + + // Finds the address of the upper leftmost pixel of the source area. + guard + let inputBaseAddress = CVPixelBufferGetBaseAddress(self)?.advanced( + by: Int(source.minY) * inputImageRowBytes + Int(source.minX) * imageChannels) + else { + return nil + } + + // Crops given area as vImage Buffer. + var croppedImage = vImage_Buffer( + data: inputBaseAddress, height: UInt(source.height), width: UInt(source.width), + rowBytes: inputImageRowBytes) + + let resultRowBytes = Int(size.width) * imageChannels + guard let resultAddress = malloc(Int(size.height) * resultRowBytes) else { + return nil + } + + // Allocates a vacant vImage buffer for resized image. + var resizedImage = vImage_Buffer( + data: resultAddress, + height: UInt(size.height), width: UInt(size.width), + rowBytes: resultRowBytes + ) + + // Performs the scale operation on cropped image and stores it in result image buffer. + guard vImageScale_ARGB8888(&croppedImage, &resizedImage, nil, vImage_Flags(0)) == kvImageNoError + else { + return nil + } + + let releaseCallBack: CVPixelBufferReleaseBytesCallback = { mutablePointer, pointer in + if let pointer = pointer { + free(UnsafeMutableRawPointer(mutating: pointer)) + } + } + + var result: CVPixelBuffer? + + // Converts the thumbnail vImage buffer to CVPixelBuffer + let conversionStatus = CVPixelBufferCreateWithBytes( + nil, + Int(size.width), Int(size.height), + CVPixelBufferGetPixelFormatType(self), + resultAddress, + resultRowBytes, + releaseCallBack, + nil, + nil, + &result + ) + + guard conversionStatus == kCVReturnSuccess else { + free(resultAddress) + return nil + } + + return result + } + + /// Returns the RGB `Data` representation of the given image buffer. + /// + /// - Parameters: + /// - isModelQuantized: Whether the model is quantized (i.e. fixed point values rather than + /// floating point values). + /// - Returns: The RGB data representation of the image buffer or `nil` if the buffer could not be + /// converted. + func rgbData( + isModelQuantized: Bool + ) -> Data? { + CVPixelBufferLockBaseAddress(self, .readOnly) + defer { CVPixelBufferUnlockBaseAddress(self, .readOnly) } + guard let sourceData = CVPixelBufferGetBaseAddress(self) else { + return nil + } + + let width = CVPixelBufferGetWidth(self) + let height = CVPixelBufferGetHeight(self) + let sourceBytesPerRow = CVPixelBufferGetBytesPerRow(self) + let destinationBytesPerRow = Constants.rgbPixelChannels * width + + // Assign input image to `sourceBuffer` to convert it. + var sourceBuffer = vImage_Buffer( + data: sourceData, + height: vImagePixelCount(height), + width: vImagePixelCount(width), + rowBytes: sourceBytesPerRow) + + // Make `destinationBuffer` and `destinationData` for its data to be assigned. + guard let destinationData = malloc(height * destinationBytesPerRow) else { + os_log("Error: out of memory", type: .error) + return nil + } + defer { free(destinationData) } + var destinationBuffer = vImage_Buffer( + data: destinationData, + height: vImagePixelCount(height), + width: vImagePixelCount(width), + rowBytes: destinationBytesPerRow) + + // Convert image type. + switch CVPixelBufferGetPixelFormatType(self) { + case kCVPixelFormatType_32BGRA: + vImageConvert_BGRA8888toRGB888(&sourceBuffer, &destinationBuffer, UInt32(kvImageNoFlags)) + case kCVPixelFormatType_32ARGB: + vImageConvert_BGRA8888toRGB888(&sourceBuffer, &destinationBuffer, UInt32(kvImageNoFlags)) + default: + os_log("The type of this image is not supported.", type: .error) + return nil + } + + // Make `Data` with converted image. + let imageByteData = Data( + bytes: destinationBuffer.data, count: destinationBuffer.rowBytes * height) + + if isModelQuantized { return imageByteData } + + let imageBytes = [UInt8](imageByteData) + return Data(copyingBufferOf: imageBytes.map { Float($0) / Constants.maxRGBValue }) + } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/TFLiteExtension.swift b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/TFLiteExtension.swift new file mode 100644 index 0000000000000000000000000000000000000000..63f7ced786e2b550391c77af534d1d3c431522c6 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Extensions/TFLiteExtension.swift @@ -0,0 +1,75 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +import Accelerate +import CoreImage +import Foundation +import TensorFlowLite + +// MARK: - Data +extension Data { + /// Creates a new buffer by copying the buffer pointer of the given array. + /// + /// - Warning: The given array's element type `T` must be trivial in that it can be copied bit + /// for bit with no indirection or reference-counting operations; otherwise, reinterpreting + /// data from the resulting buffer has undefined behavior. + /// - Parameter array: An array with elements of type `T`. + init(copyingBufferOf array: [T]) { + self = array.withUnsafeBufferPointer(Data.init) + } + + /// Convert a Data instance to Array representation. + func toArray(type: T.Type) -> [T] where T: AdditiveArithmetic { + var array = [T](repeating: T.zero, count: self.count / MemoryLayout.stride) + _ = array.withUnsafeMutableBytes { self.copyBytes(to: $0) } + return array + } +} + +// MARK: - Wrappers +/// Struct for handling multidimension `Data` in flat `Array`. +struct FlatArray { + private var array: [Element] + var dimensions: [Int] + + init(tensor: Tensor) { + dimensions = tensor.shape.dimensions + array = tensor.data.toArray(type: Element.self) + } + + private func flatIndex(_ index: [Int]) -> Int { + guard index.count == dimensions.count else { + fatalError("Invalid index: got \(index.count) index(es) for \(dimensions.count) index(es).") + } + + var result = 0 + for i in 0.. index[i] else { + fatalError("Invalid index: \(index[i]) is bigger than \(dimensions[i])") + } + result = dimensions[i] * result + index[i] + } + return result + } + + subscript(_ index: Int...) -> Element { + get { + return array[flatIndex(index)] + } + set(newValue) { + array[flatIndex(index)] = newValue + } + } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Info.plist b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Info.plist new file mode 100644 index 0000000000000000000000000000000000000000..4330d9b33f31010549802febc6f6f2bc9fd9b950 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Info.plist @@ -0,0 +1,42 @@ + + + + + CFBundleDevelopmentRegion + $(DEVELOPMENT_LANGUAGE) + CFBundleExecutable + $(EXECUTABLE_NAME) + CFBundleIdentifier + $(PRODUCT_BUNDLE_IDENTIFIER) + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + $(PRODUCT_NAME) + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleVersion + 1 + LSRequiresIPhoneOS + + NSCameraUsageDescription + This app will use camera to continuously estimate the depth map. + UILaunchStoryboardName + LaunchScreen + UIMainStoryboardFile + Main + UIRequiredDeviceCapabilities + + armv7 + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ModelDataHandler/ModelDataHandler.swift b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ModelDataHandler/ModelDataHandler.swift new file mode 100644 index 0000000000000000000000000000000000000000..144cfe1fa3a65af5adcb572237f2bf9718e570ae --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ModelDataHandler/ModelDataHandler.swift @@ -0,0 +1,464 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Accelerate +import CoreImage +import Foundation +import TensorFlowLite +import UIKit + +/// This class handles all data preprocessing and makes calls to run inference on a given frame +/// by invoking the `Interpreter`. It then formats the inferences obtained. +class ModelDataHandler { + // MARK: - Private Properties + + /// TensorFlow Lite `Interpreter` object for performing inference on a given model. + private var interpreter: Interpreter + + /// TensorFlow lite `Tensor` of model input and output. + private var inputTensor: Tensor + + //private var heatsTensor: Tensor + //private var offsetsTensor: Tensor + private var outputTensor: Tensor + // MARK: - Initialization + + /// A failable initializer for `ModelDataHandler`. A new instance is created if the model is + /// successfully loaded from the app's main bundle. Default `threadCount` is 2. + init( + threadCount: Int = Constants.defaultThreadCount, + delegate: Delegates = Constants.defaultDelegate + ) throws { + // Construct the path to the model file. + guard + let modelPath = Bundle.main.path( + forResource: Model.file.name, + ofType: Model.file.extension + ) + else { + fatalError("Failed to load the model file with name: \(Model.file.name).") + } + + // Specify the options for the `Interpreter`. + var options = Interpreter.Options() + options.threadCount = threadCount + + // Specify the delegates for the `Interpreter`. + var delegates: [Delegate]? + switch delegate { + case .Metal: + delegates = [MetalDelegate()] + case .CoreML: + if let coreMLDelegate = CoreMLDelegate() { + delegates = [coreMLDelegate] + } else { + delegates = nil + } + default: + delegates = nil + } + + // Create the `Interpreter`. + interpreter = try Interpreter(modelPath: modelPath, options: options, delegates: delegates) + + // Initialize input and output `Tensor`s. + // Allocate memory for the model's input `Tensor`s. + try interpreter.allocateTensors() + + // Get allocated input and output `Tensor`s. + inputTensor = try interpreter.input(at: 0) + outputTensor = try interpreter.output(at: 0) + //heatsTensor = try interpreter.output(at: 0) + //offsetsTensor = try interpreter.output(at: 1) + + /* + // Check if input and output `Tensor`s are in the expected formats. + guard (inputTensor.dataType == .uInt8) == Model.isQuantized else { + fatalError("Unexpected Model: quantization is \(!Model.isQuantized)") + } + + guard inputTensor.shape.dimensions[0] == Model.input.batchSize, + inputTensor.shape.dimensions[1] == Model.input.height, + inputTensor.shape.dimensions[2] == Model.input.width, + inputTensor.shape.dimensions[3] == Model.input.channelSize + else { + fatalError("Unexpected Model: input shape") + } + + + guard heatsTensor.shape.dimensions[0] == Model.output.batchSize, + heatsTensor.shape.dimensions[1] == Model.output.height, + heatsTensor.shape.dimensions[2] == Model.output.width, + heatsTensor.shape.dimensions[3] == Model.output.keypointSize + else { + fatalError("Unexpected Model: heat tensor") + } + + guard offsetsTensor.shape.dimensions[0] == Model.output.batchSize, + offsetsTensor.shape.dimensions[1] == Model.output.height, + offsetsTensor.shape.dimensions[2] == Model.output.width, + offsetsTensor.shape.dimensions[3] == Model.output.offsetSize + else { + fatalError("Unexpected Model: offset tensor") + } + */ + + } + + /// Runs Midas model with given image with given source area to destination area. + /// + /// - Parameters: + /// - on: Input image to run the model. + /// - from: Range of input image to run the model. + /// - to: Size of view to render the result. + /// - Returns: Result of the inference and the times consumed in every steps. + func runMidas(on pixelbuffer: CVPixelBuffer, from source: CGRect, to dest: CGSize) + //-> (Result, Times)? + //-> (FlatArray, Times)? + -> ([Float], Int, Int, Times)? + { + // Start times of each process. + let preprocessingStartTime: Date + let inferenceStartTime: Date + let postprocessingStartTime: Date + + // Processing times in miliseconds. + let preprocessingTime: TimeInterval + let inferenceTime: TimeInterval + let postprocessingTime: TimeInterval + + preprocessingStartTime = Date() + guard let data = preprocess(of: pixelbuffer, from: source) else { + os_log("Preprocessing failed", type: .error) + return nil + } + preprocessingTime = Date().timeIntervalSince(preprocessingStartTime) * 1000 + + inferenceStartTime = Date() + inference(from: data) + inferenceTime = Date().timeIntervalSince(inferenceStartTime) * 1000 + + postprocessingStartTime = Date() + //guard let result = postprocess(to: dest) else { + // os_log("Postprocessing failed", type: .error) + // return nil + //} + postprocessingTime = Date().timeIntervalSince(postprocessingStartTime) * 1000 + + + let results: [Float] + switch outputTensor.dataType { + case .uInt8: + guard let quantization = outputTensor.quantizationParameters else { + print("No results returned because the quantization values for the output tensor are nil.") + return nil + } + let quantizedResults = [UInt8](outputTensor.data) + results = quantizedResults.map { + quantization.scale * Float(Int($0) - quantization.zeroPoint) + } + case .float32: + results = [Float32](unsafeData: outputTensor.data) ?? [] + default: + print("Output tensor data type \(outputTensor.dataType) is unsupported for this example app.") + return nil + } + + + let times = Times( + preprocessing: preprocessingTime, + inference: inferenceTime, + postprocessing: postprocessingTime) + + return (results, Model.input.width, Model.input.height, times) + } + + // MARK: - Private functions to run model + /// Preprocesses given rectangle image to be `Data` of disired size by croping and resizing it. + /// + /// - Parameters: + /// - of: Input image to crop and resize. + /// - from: Target area to be cropped and resized. + /// - Returns: The cropped and resized image. `nil` if it can not be processed. + private func preprocess(of pixelBuffer: CVPixelBuffer, from targetSquare: CGRect) -> Data? { + let sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer) + assert(sourcePixelFormat == kCVPixelFormatType_32BGRA) + + // Resize `targetSquare` of input image to `modelSize`. + let modelSize = CGSize(width: Model.input.width, height: Model.input.height) + guard let thumbnail = pixelBuffer.resize(from: targetSquare, to: modelSize) + else { + return nil + } + + // Remove the alpha component from the image buffer to get the initialized `Data`. + let byteCount = + Model.input.batchSize + * Model.input.height * Model.input.width + * Model.input.channelSize + guard + let inputData = thumbnail.rgbData( + isModelQuantized: Model.isQuantized + ) + else { + os_log("Failed to convert the image buffer to RGB data.", type: .error) + return nil + } + + return inputData + } + + + + /* + /// Postprocesses output `Tensor`s to `Result` with size of view to render the result. + /// + /// - Parameters: + /// - to: Size of view to be displaied. + /// - Returns: Postprocessed `Result`. `nil` if it can not be processed. + private func postprocess(to viewSize: CGSize) -> Result? { + // MARK: Formats output tensors + // Convert `Tensor` to `FlatArray`. As Midas is not quantized, convert them to Float type + // `FlatArray`. + let heats = FlatArray(tensor: heatsTensor) + let offsets = FlatArray(tensor: offsetsTensor) + + // MARK: Find position of each key point + // Finds the (row, col) locations of where the keypoints are most likely to be. The highest + // `heats[0, row, col, keypoint]` value, the more likely `keypoint` being located in (`row`, + // `col`). + let keypointPositions = (0.. (Int, Int) in + var maxValue = heats[0, 0, 0, keypoint] + var maxRow = 0 + var maxCol = 0 + for row in 0.. maxValue { + maxValue = heats[0, row, col, keypoint] + maxRow = row + maxCol = col + } + } + } + return (maxRow, maxCol) + } + + // MARK: Calculates total confidence score + // Calculates total confidence score of each key position. + let totalScoreSum = keypointPositions.enumerated().reduce(0.0) { accumulator, elem -> Float32 in + accumulator + sigmoid(heats[0, elem.element.0, elem.element.1, elem.offset]) + } + let totalScore = totalScoreSum / Float32(Model.output.keypointSize) + + // MARK: Calculate key point position on model input + // Calculates `KeyPoint` coordination model input image with `offsets` adjustment. + let coords = keypointPositions.enumerated().map { index, elem -> (y: Float32, x: Float32) in + let (y, x) = elem + let yCoord = + Float32(y) / Float32(Model.output.height - 1) * Float32(Model.input.height) + + offsets[0, y, x, index] + let xCoord = + Float32(x) / Float32(Model.output.width - 1) * Float32(Model.input.width) + + offsets[0, y, x, index + Model.output.keypointSize] + return (y: yCoord, x: xCoord) + } + + // MARK: Transform key point position and make lines + // Make `Result` from `keypointPosition'. Each point is adjusted to `ViewSize` to be drawn. + var result = Result(dots: [], lines: [], score: totalScore) + var bodyPartToDotMap = [BodyPart: CGPoint]() + for (index, part) in BodyPart.allCases.enumerated() { + let position = CGPoint( + x: CGFloat(coords[index].x) * viewSize.width / CGFloat(Model.input.width), + y: CGFloat(coords[index].y) * viewSize.height / CGFloat(Model.input.height) + ) + bodyPartToDotMap[part] = position + result.dots.append(position) + } + + do { + try result.lines = BodyPart.lines.map { map throws -> Line in + guard let from = bodyPartToDotMap[map.from] else { + throw PostprocessError.missingBodyPart(of: map.from) + } + guard let to = bodyPartToDotMap[map.to] else { + throw PostprocessError.missingBodyPart(of: map.to) + } + return Line(from: from, to: to) + } + } catch PostprocessError.missingBodyPart(let missingPart) { + os_log("Postprocessing error: %s is missing.", type: .error, missingPart.rawValue) + return nil + } catch { + os_log("Postprocessing error: %s", type: .error, error.localizedDescription) + return nil + } + + return result + } +*/ + + + + /// Run inference with given `Data` + /// + /// Parameter `from`: `Data` of input image to run model. + private func inference(from data: Data) { + // Copy the initialized `Data` to the input `Tensor`. + do { + try interpreter.copy(data, toInputAt: 0) + + // Run inference by invoking the `Interpreter`. + try interpreter.invoke() + + // Get the output `Tensor` to process the inference results. + outputTensor = try interpreter.output(at: 0) + //heatsTensor = try interpreter.output(at: 0) + //offsetsTensor = try interpreter.output(at: 1) + + + } catch let error { + os_log( + "Failed to invoke the interpreter with error: %s", type: .error, + error.localizedDescription) + return + } + } + + /// Returns value within [0,1]. + private func sigmoid(_ x: Float32) -> Float32 { + return (1.0 / (1.0 + exp(-x))) + } +} + +// MARK: - Data types for inference result +struct KeyPoint { + var bodyPart: BodyPart = BodyPart.NOSE + var position: CGPoint = CGPoint() + var score: Float = 0.0 +} + +struct Line { + let from: CGPoint + let to: CGPoint +} + +struct Times { + var preprocessing: Double + var inference: Double + var postprocessing: Double +} + +struct Result { + var dots: [CGPoint] + var lines: [Line] + var score: Float +} + +enum BodyPart: String, CaseIterable { + case NOSE = "nose" + case LEFT_EYE = "left eye" + case RIGHT_EYE = "right eye" + case LEFT_EAR = "left ear" + case RIGHT_EAR = "right ear" + case LEFT_SHOULDER = "left shoulder" + case RIGHT_SHOULDER = "right shoulder" + case LEFT_ELBOW = "left elbow" + case RIGHT_ELBOW = "right elbow" + case LEFT_WRIST = "left wrist" + case RIGHT_WRIST = "right wrist" + case LEFT_HIP = "left hip" + case RIGHT_HIP = "right hip" + case LEFT_KNEE = "left knee" + case RIGHT_KNEE = "right knee" + case LEFT_ANKLE = "left ankle" + case RIGHT_ANKLE = "right ankle" + + /// List of lines connecting each part. + static let lines = [ + (from: BodyPart.LEFT_WRIST, to: BodyPart.LEFT_ELBOW), + (from: BodyPart.LEFT_ELBOW, to: BodyPart.LEFT_SHOULDER), + (from: BodyPart.LEFT_SHOULDER, to: BodyPart.RIGHT_SHOULDER), + (from: BodyPart.RIGHT_SHOULDER, to: BodyPart.RIGHT_ELBOW), + (from: BodyPart.RIGHT_ELBOW, to: BodyPart.RIGHT_WRIST), + (from: BodyPart.LEFT_SHOULDER, to: BodyPart.LEFT_HIP), + (from: BodyPart.LEFT_HIP, to: BodyPart.RIGHT_HIP), + (from: BodyPart.RIGHT_HIP, to: BodyPart.RIGHT_SHOULDER), + (from: BodyPart.LEFT_HIP, to: BodyPart.LEFT_KNEE), + (from: BodyPart.LEFT_KNEE, to: BodyPart.LEFT_ANKLE), + (from: BodyPart.RIGHT_HIP, to: BodyPart.RIGHT_KNEE), + (from: BodyPart.RIGHT_KNEE, to: BodyPart.RIGHT_ANKLE), + ] +} + +// MARK: - Delegates Enum +enum Delegates: Int, CaseIterable { + case CPU + case Metal + case CoreML + + var description: String { + switch self { + case .CPU: + return "CPU" + case .Metal: + return "GPU" + case .CoreML: + return "NPU" + } + } +} + +// MARK: - Custom Errors +enum PostprocessError: Error { + case missingBodyPart(of: BodyPart) +} + +// MARK: - Information about the model file. +typealias FileInfo = (name: String, extension: String) + +enum Model { + static let file: FileInfo = ( + name: "model_opt", extension: "tflite" + ) + + static let input = (batchSize: 1, height: 256, width: 256, channelSize: 3) + static let output = (batchSize: 1, height: 256, width: 256, channelSize: 1) + static let isQuantized = false +} + + +extension Array { + /// Creates a new array from the bytes of the given unsafe data. + /// + /// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit + /// with no indirection or reference-counting operations; otherwise, copying the raw bytes in + /// the `unsafeData`'s buffer to a new array returns an unsafe copy. + /// - Note: Returns `nil` if `unsafeData.count` is not a multiple of + /// `MemoryLayout.stride`. + /// - Parameter unsafeData: The data containing the bytes to turn into an array. + init?(unsafeData: Data) { + guard unsafeData.count % MemoryLayout.stride == 0 else { return nil } + #if swift(>=5.0) + self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) } + #else + self = unsafeData.withUnsafeBytes { + .init(UnsafeBufferPointer( + start: $0, + count: unsafeData.count / MemoryLayout.stride + )) + } + #endif // swift(>=5.0) + } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Launch Screen.storyboard b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Launch Screen.storyboard new file mode 100644 index 0000000000000000000000000000000000000000..a04c79f554777863bd0dc8287bfd60704ce28bf2 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Launch Screen.storyboard @@ -0,0 +1,48 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Main.storyboard b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Main.storyboard new file mode 100644 index 0000000000000000000000000000000000000000..5f5623794bd35b9bb75efd7b7e249fd7357fdfbd --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Storyboards/Base.lproj/Main.storyboard @@ -0,0 +1,236 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ViewControllers/ViewController.swift b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ViewControllers/ViewController.swift new file mode 100644 index 0000000000000000000000000000000000000000..fbb51b5a303412c0bbd158d76d025cf88fee6f8f --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/ViewControllers/ViewController.swift @@ -0,0 +1,489 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import AVFoundation +import UIKit +import os + + +public struct PixelData { + var a: UInt8 + var r: UInt8 + var g: UInt8 + var b: UInt8 +} + +extension UIImage { + convenience init?(pixels: [PixelData], width: Int, height: Int) { + guard width > 0 && height > 0, pixels.count == width * height else { return nil } + var data = pixels + guard let providerRef = CGDataProvider(data: Data(bytes: &data, count: data.count * MemoryLayout.size) as CFData) + else { return nil } + guard let cgim = CGImage( + width: width, + height: height, + bitsPerComponent: 8, + bitsPerPixel: 32, + bytesPerRow: width * MemoryLayout.size, + space: CGColorSpaceCreateDeviceRGB(), + bitmapInfo: CGBitmapInfo(rawValue: CGImageAlphaInfo.premultipliedFirst.rawValue), + provider: providerRef, + decode: nil, + shouldInterpolate: false, + intent: .defaultIntent) + else { return nil } + self.init(cgImage: cgim) + } +} + + +class ViewController: UIViewController { + // MARK: Storyboards Connections + @IBOutlet weak var previewView: PreviewView! + + //@IBOutlet weak var overlayView: OverlayView! + @IBOutlet weak var overlayView: UIImageView! + + private var imageView : UIImageView = UIImageView(frame:CGRect(x:0, y:0, width:400, height:400)) + + private var imageViewInitialized: Bool = false + + @IBOutlet weak var resumeButton: UIButton! + @IBOutlet weak var cameraUnavailableLabel: UILabel! + + @IBOutlet weak var tableView: UITableView! + + @IBOutlet weak var threadCountLabel: UILabel! + @IBOutlet weak var threadCountStepper: UIStepper! + + @IBOutlet weak var delegatesControl: UISegmentedControl! + + // MARK: ModelDataHandler traits + var threadCount: Int = Constants.defaultThreadCount + var delegate: Delegates = Constants.defaultDelegate + + // MARK: Result Variables + // Inferenced data to render. + private var inferencedData: InferencedData? + + // Minimum score to render the result. + private let minimumScore: Float = 0.5 + + private var avg_latency: Double = 0.0 + + // Relative location of `overlayView` to `previewView`. + private var overlayViewFrame: CGRect? + + private var previewViewFrame: CGRect? + + // MARK: Controllers that manage functionality + // Handles all the camera related functionality + private lazy var cameraCapture = CameraFeedManager(previewView: previewView) + + // Handles all data preprocessing and makes calls to run inference. + private var modelDataHandler: ModelDataHandler? + + // MARK: View Handling Methods + override func viewDidLoad() { + super.viewDidLoad() + + do { + modelDataHandler = try ModelDataHandler() + } catch let error { + fatalError(error.localizedDescription) + } + + cameraCapture.delegate = self + tableView.delegate = self + tableView.dataSource = self + + // MARK: UI Initialization + // Setup thread count stepper with white color. + // https://forums.developer.apple.com/thread/121495 + threadCountStepper.setDecrementImage( + threadCountStepper.decrementImage(for: .normal), for: .normal) + threadCountStepper.setIncrementImage( + threadCountStepper.incrementImage(for: .normal), for: .normal) + // Setup initial stepper value and its label. + threadCountStepper.value = Double(Constants.defaultThreadCount) + threadCountLabel.text = Constants.defaultThreadCount.description + + // Setup segmented controller's color. + delegatesControl.setTitleTextAttributes( + [NSAttributedString.Key.foregroundColor: UIColor.lightGray], + for: .normal) + delegatesControl.setTitleTextAttributes( + [NSAttributedString.Key.foregroundColor: UIColor.black], + for: .selected) + // Remove existing segments to initialize it with `Delegates` entries. + delegatesControl.removeAllSegments() + Delegates.allCases.forEach { delegate in + delegatesControl.insertSegment( + withTitle: delegate.description, + at: delegate.rawValue, + animated: false) + } + delegatesControl.selectedSegmentIndex = 0 + } + + override func viewWillAppear(_ animated: Bool) { + super.viewWillAppear(animated) + + cameraCapture.checkCameraConfigurationAndStartSession() + } + + override func viewWillDisappear(_ animated: Bool) { + cameraCapture.stopSession() + } + + override func viewDidLayoutSubviews() { + overlayViewFrame = overlayView.frame + previewViewFrame = previewView.frame + } + + // MARK: Button Actions + @IBAction func didChangeThreadCount(_ sender: UIStepper) { + let changedCount = Int(sender.value) + if threadCountLabel.text == changedCount.description { + return + } + + do { + modelDataHandler = try ModelDataHandler(threadCount: changedCount, delegate: delegate) + } catch let error { + fatalError(error.localizedDescription) + } + threadCount = changedCount + threadCountLabel.text = changedCount.description + os_log("Thread count is changed to: %d", threadCount) + } + + @IBAction func didChangeDelegate(_ sender: UISegmentedControl) { + guard let changedDelegate = Delegates(rawValue: delegatesControl.selectedSegmentIndex) else { + fatalError("Unexpected value from delegates segemented controller.") + } + do { + modelDataHandler = try ModelDataHandler(threadCount: threadCount, delegate: changedDelegate) + } catch let error { + fatalError(error.localizedDescription) + } + delegate = changedDelegate + os_log("Delegate is changed to: %s", delegate.description) + } + + @IBAction func didTapResumeButton(_ sender: Any) { + cameraCapture.resumeInterruptedSession { complete in + + if complete { + self.resumeButton.isHidden = true + self.cameraUnavailableLabel.isHidden = true + } else { + self.presentUnableToResumeSessionAlert() + } + } + } + + func presentUnableToResumeSessionAlert() { + let alert = UIAlertController( + title: "Unable to Resume Session", + message: "There was an error while attempting to resume session.", + preferredStyle: .alert + ) + alert.addAction(UIAlertAction(title: "OK", style: .default, handler: nil)) + + self.present(alert, animated: true) + } +} + +// MARK: - CameraFeedManagerDelegate Methods +extension ViewController: CameraFeedManagerDelegate { + func cameraFeedManager(_ manager: CameraFeedManager, didOutput pixelBuffer: CVPixelBuffer) { + runModel(on: pixelBuffer) + } + + // MARK: Session Handling Alerts + func cameraFeedManagerDidEncounterSessionRunTimeError(_ manager: CameraFeedManager) { + // Handles session run time error by updating the UI and providing a button if session can be + // manually resumed. + self.resumeButton.isHidden = false + } + + func cameraFeedManager( + _ manager: CameraFeedManager, sessionWasInterrupted canResumeManually: Bool + ) { + // Updates the UI when session is interupted. + if canResumeManually { + self.resumeButton.isHidden = false + } else { + self.cameraUnavailableLabel.isHidden = false + } + } + + func cameraFeedManagerDidEndSessionInterruption(_ manager: CameraFeedManager) { + // Updates UI once session interruption has ended. + self.cameraUnavailableLabel.isHidden = true + self.resumeButton.isHidden = true + } + + func presentVideoConfigurationErrorAlert(_ manager: CameraFeedManager) { + let alertController = UIAlertController( + title: "Confirguration Failed", message: "Configuration of camera has failed.", + preferredStyle: .alert) + let okAction = UIAlertAction(title: "OK", style: .cancel, handler: nil) + alertController.addAction(okAction) + + present(alertController, animated: true, completion: nil) + } + + func presentCameraPermissionsDeniedAlert(_ manager: CameraFeedManager) { + let alertController = UIAlertController( + title: "Camera Permissions Denied", + message: + "Camera permissions have been denied for this app. You can change this by going to Settings", + preferredStyle: .alert) + + let cancelAction = UIAlertAction(title: "Cancel", style: .cancel, handler: nil) + let settingsAction = UIAlertAction(title: "Settings", style: .default) { action in + if let url = URL.init(string: UIApplication.openSettingsURLString) { + UIApplication.shared.open(url, options: [:], completionHandler: nil) + } + } + + alertController.addAction(cancelAction) + alertController.addAction(settingsAction) + + present(alertController, animated: true, completion: nil) + } + + @objc func runModel(on pixelBuffer: CVPixelBuffer) { + guard let overlayViewFrame = overlayViewFrame, let previewViewFrame = previewViewFrame + else { + return + } + // To put `overlayView` area as model input, transform `overlayViewFrame` following transform + // from `previewView` to `pixelBuffer`. `previewView` area is transformed to fit in + // `pixelBuffer`, because `pixelBuffer` as a camera output is resized to fill `previewView`. + // https://developer.apple.com/documentation/avfoundation/avlayervideogravity/1385607-resizeaspectfill + let modelInputRange = overlayViewFrame.applying( + previewViewFrame.size.transformKeepAspect(toFitIn: pixelBuffer.size)) + + // Run Midas model. + guard + let (result, width, height, times) = self.modelDataHandler?.runMidas( + on: pixelBuffer, + from: modelInputRange, + to: overlayViewFrame.size) + else { + os_log("Cannot get inference result.", type: .error) + return + } + + if avg_latency == 0 { + avg_latency = times.inference + } else { + avg_latency = times.inference*0.1 + avg_latency*0.9 + } + + // Udpate `inferencedData` to render data in `tableView`. + inferencedData = InferencedData(score: Float(avg_latency), times: times) + + //let height = 256 + //let width = 256 + + let outputs = result + let outputs_size = width * height; + + var multiplier : Float = 1.0; + + let max_val : Float = outputs.max() ?? 0 + let min_val : Float = outputs.min() ?? 0 + + if((max_val - min_val) > 0) { + multiplier = 255 / (max_val - min_val); + } + + // Draw result. + DispatchQueue.main.async { + self.tableView.reloadData() + + var pixels: [PixelData] = .init(repeating: .init(a: 255, r: 0, g: 0, b: 0), count: width * height) + + for i in pixels.indices { + //if(i < 1000) + //{ + let val = UInt8((outputs[i] - min_val) * multiplier) + + pixels[i].r = val + pixels[i].g = val + pixels[i].b = val + //} + } + + + /* + pixels[i].a = 255 + pixels[i].r = .random(in: 0...255) + pixels[i].g = .random(in: 0...255) + pixels[i].b = .random(in: 0...255) + } + */ + + DispatchQueue.main.async { + let image = UIImage(pixels: pixels, width: width, height: height) + + self.imageView.image = image + + if (self.imageViewInitialized == false) { + self.imageViewInitialized = true + self.overlayView.addSubview(self.imageView) + self.overlayView.setNeedsDisplay() + } + } + + /* + let image = UIImage(pixels: pixels, width: width, height: height) + + var imageView : UIImageView + imageView = UIImageView(frame:CGRect(x:0, y:0, width:400, height:400)); + imageView.image = image + self.overlayView.addSubview(imageView) + self.overlayView.setNeedsDisplay() + */ + } + } +/* + func drawResult(of result: Result) { + self.overlayView.dots = result.dots + self.overlayView.lines = result.lines + self.overlayView.setNeedsDisplay() + } + + func clearResult() { + self.overlayView.clear() + self.overlayView.setNeedsDisplay() + } + */ + +} + + +// MARK: - TableViewDelegate, TableViewDataSource Methods +extension ViewController: UITableViewDelegate, UITableViewDataSource { + func numberOfSections(in tableView: UITableView) -> Int { + return InferenceSections.allCases.count + } + + func tableView(_ tableView: UITableView, numberOfRowsInSection section: Int) -> Int { + guard let section = InferenceSections(rawValue: section) else { + return 0 + } + + return section.subcaseCount + } + + func tableView(_ tableView: UITableView, cellForRowAt indexPath: IndexPath) -> UITableViewCell { + let cell = tableView.dequeueReusableCell(withIdentifier: "InfoCell") as! InfoCell + guard let section = InferenceSections(rawValue: indexPath.section) else { + return cell + } + guard let data = inferencedData else { return cell } + + var fieldName: String + var info: String + + switch section { + case .Score: + fieldName = section.description + info = String(format: "%.3f", data.score) + case .Time: + guard let row = ProcessingTimes(rawValue: indexPath.row) else { + return cell + } + var time: Double + switch row { + case .InferenceTime: + time = data.times.inference + } + fieldName = row.description + info = String(format: "%.2fms", time) + } + + cell.fieldNameLabel.text = fieldName + cell.infoLabel.text = info + + return cell + } + + func tableView(_ tableView: UITableView, heightForRowAt indexPath: IndexPath) -> CGFloat { + guard let section = InferenceSections(rawValue: indexPath.section) else { + return 0 + } + + var height = Traits.normalCellHeight + if indexPath.row == section.subcaseCount - 1 { + height = Traits.separatorCellHeight + Traits.bottomSpacing + } + return height + } + +} + +// MARK: - Private enums +/// UI coinstraint values +fileprivate enum Traits { + static let normalCellHeight: CGFloat = 35.0 + static let separatorCellHeight: CGFloat = 25.0 + static let bottomSpacing: CGFloat = 30.0 +} + +fileprivate struct InferencedData { + var score: Float + var times: Times +} + +/// Type of sections in Info Cell +fileprivate enum InferenceSections: Int, CaseIterable { + case Score + case Time + + var description: String { + switch self { + case .Score: + return "Average" + case .Time: + return "Processing Time" + } + } + + var subcaseCount: Int { + switch self { + case .Score: + return 1 + case .Time: + return ProcessingTimes.allCases.count + } + } +} + +/// Type of processing times in Time section in Info Cell +fileprivate enum ProcessingTimes: Int, CaseIterable { + case InferenceTime + + var description: String { + switch self { + case .InferenceTime: + return "Inference Time" + } + } +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Views/OverlayView.swift b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Views/OverlayView.swift new file mode 100644 index 0000000000000000000000000000000000000000..3b53910b57563b6a195fd53321fa2a24ebaf3d3f --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Midas/Views/OverlayView.swift @@ -0,0 +1,63 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import UIKit + +/// UIView for rendering inference output. +class OverlayView: UIView { + + var dots = [CGPoint]() + var lines = [Line]() + + override func draw(_ rect: CGRect) { + for dot in dots { + drawDot(of: dot) + } + for line in lines { + drawLine(of: line) + } + } + + func drawDot(of dot: CGPoint) { + let dotRect = CGRect( + x: dot.x - Traits.dot.radius / 2, y: dot.y - Traits.dot.radius / 2, + width: Traits.dot.radius, height: Traits.dot.radius) + let dotPath = UIBezierPath(ovalIn: dotRect) + + Traits.dot.color.setFill() + dotPath.fill() + } + + func drawLine(of line: Line) { + let linePath = UIBezierPath() + linePath.move(to: CGPoint(x: line.from.x, y: line.from.y)) + linePath.addLine(to: CGPoint(x: line.to.x, y: line.to.y)) + linePath.close() + + linePath.lineWidth = Traits.line.width + Traits.line.color.setStroke() + + linePath.stroke() + } + + func clear() { + self.dots = [] + self.lines = [] + } +} + +private enum Traits { + static let dot = (radius: CGFloat(5), color: UIColor.orange) + static let line = (width: CGFloat(1.0), color: UIColor.orange) +} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Podfile b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Podfile new file mode 100644 index 0000000000000000000000000000000000000000..5e9461fc96dbbe3c22ca6bbf2bfd7df3981b9462 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/Podfile @@ -0,0 +1,12 @@ +# Uncomment the next line to define a global platform for your project + platform :ios, '12.0' + +target 'Midas' do + # Comment the next line if you're not using Swift and don't want to use dynamic frameworks + use_frameworks! + + # Pods for Midas + pod 'TensorFlowLiteSwift', '~> 0.0.1-nightly' + pod 'TensorFlowLiteSwift/CoreML', '~> 0.0.1-nightly' + pod 'TensorFlowLiteSwift/Metal', '~> 0.0.1-nightly' +end diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/README.md b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7b8eb29feaa21e67814b035dbd5c5fb2c62a4151 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/README.md @@ -0,0 +1,105 @@ +# Tensorflow Lite MiDaS iOS Example + +### Requirements + +- XCode 11.0 or above +- iOS 12.0 or above, [iOS 14 breaks the NPU Delegate](https://github.com/tensorflow/tensorflow/issues/43339) +- TensorFlow 2.4.0, TensorFlowLiteSwift -> 0.0.1-nightly + +## Quick Start with a MiDaS Example + +MiDaS is a neural network to compute depth from a single image. It uses TensorFlowLiteSwift / C++ libraries on iOS. The code is written in Swift. + +Paper: https://arxiv.org/abs/1907.01341 + +> Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer +> René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun + +### Install TensorFlow + +Set default python version to python3: + +``` +echo 'export PATH=/usr/local/opt/python/libexec/bin:$PATH' >> ~/.zshenv +echo 'alias python=python3' >> ~/.zshenv +echo 'alias pip=pip3' >> ~/.zshenv +``` + +Install TensorFlow + +```shell +pip install tensorflow +``` + +### Install TensorFlowLiteSwift via Cocoapods + +Set required TensorFlowLiteSwift version in the file (`0.0.1-nightly` is recommended): https://github.com/isl-org/MiDaS/blob/master/mobile/ios/Podfile#L9 + +Install: brew, ruby, cocoapods + +``` +ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)" +brew install mc rbenv ruby-build +sudo gem install cocoapods +``` + + +The TensorFlowLiteSwift library is available in [Cocoapods](https://cocoapods.org/), to integrate it to our project, we can run in the root directory of the project: + +```ruby +pod install +``` + +Now open the `Midas.xcworkspace` file in XCode, select your iPhone device (XCode->Product->Destination->iPhone) and launch it (cmd + R). If everything works well, you should see a real-time depth map from your camera. + +### Model + +The TensorFlow (TFlite) model `midas.tflite` is in the folder `/Midas/Model` + + +To use another model, you should convert it from TensorFlow saved-model to TFlite model (so that it can be deployed): + +```python +saved_model_export_dir = "./saved_model" +converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_export_dir) +tflite_model = converter.convert() +open(model_tflite_name, "wb").write("model.tflite") +``` + +### Setup XCode + +* Open directory `.xcworkspace` from the XCode + +* Press on your ProjectName (left-top corner) -> change Bundle Identifier to `com.midas.tflite-npu` or something like this (it should be unique) + +* select your Developer Team (your should be signed-in by using your AppleID) + +* Connect your iPhone (if you want to run it on real device instead of simulator), select your iPhone device (XCode->Product->Destination->iPhone) + +* Click in the XCode: Product -> Run + +* On your iPhone device go to the: Settings -> General -> Device Management (or Profiles) -> Apple Development -> Trust Apple Development + +---- + +Original repository: https://github.com/isl-org/MiDaS + + +### Examples: + +| ![photo_2020-09-27_17-43-20](https://user-images.githubusercontent.com/4096485/94367804-9610de80-00e9-11eb-8a23-8b32a6f52d41.jpg) | ![photo_2020-09-27_17-49-22](https://user-images.githubusercontent.com/4096485/94367974-7201cd00-00ea-11eb-8e0a-68eb9ea10f63.jpg) | ![photo_2020-09-27_17-52-30](https://user-images.githubusercontent.com/4096485/94367976-729a6380-00ea-11eb-8ce0-39d3e26dd550.jpg) | ![photo_2020-09-27_17-43-21](https://user-images.githubusercontent.com/4096485/94367807-97420b80-00e9-11eb-9dcd-848ad9e89e03.jpg) | +|---|---|---|---| + +## LICENSE + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/RunScripts/download_models.sh b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/RunScripts/download_models.sh new file mode 100644 index 0000000000000000000000000000000000000000..d737b39d966278f5c6bc29802526ab86f8473de4 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/mobile/ios/RunScripts/download_models.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# Download TF Lite model from the internet if it does not exist. + +TFLITE_MODEL="model_opt.tflite" +TFLITE_FILE="Midas/Model/${TFLITE_MODEL}" +MODEL_SRC="https://github.com/isl-org/MiDaS/releases/download/v2/${TFLITE_MODEL}" + +if test -f "${TFLITE_FILE}"; then + echo "INFO: TF Lite model already exists. Skip downloading and use the local model." +else + curl --create-dirs -o "${TFLITE_FILE}" -LJO "${MODEL_SRC}" + echo "INFO: Downloaded TensorFlow Lite model to ${TFLITE_FILE}." +fi + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/output/.placeholder b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/output/.placeholder new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/LICENSE b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..6606ec028d1c629986e7019fe3564f5b4bfe425d --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Alexey + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/README.md b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1d43c2606767798ee46b34292e0483197424ec23 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/README.md @@ -0,0 +1,131 @@ +# MiDaS for ROS1 by using LibTorch in C++ + +### Requirements + +- Ubuntu 17.10 / 18.04 / 20.04, Debian Stretch +- ROS Melodic for Ubuntu (17.10 / 18.04) / Debian Stretch, ROS Noetic for Ubuntu 20.04 +- C++11 +- LibTorch >= 1.6 + +## Quick Start with a MiDaS Example + +MiDaS is a neural network to compute depth from a single image. + +* input from `image_topic`: `sensor_msgs/Image` - `RGB8` image with any shape +* output to `midas_topic`: `sensor_msgs/Image` - `TYPE_32FC1` inverse relative depth maps in range [0 - 255] with original size and channels=1 + +### Install Dependecies + +* install ROS Melodic for Ubuntu 17.10 / 18.04: +```bash +wget https://raw.githubusercontent.com/isl-org/MiDaS/master/ros/additions/install_ros_melodic_ubuntu_17_18.sh +./install_ros_melodic_ubuntu_17_18.sh +``` + +or Noetic for Ubuntu 20.04: + +```bash +wget https://raw.githubusercontent.com/isl-org/MiDaS/master/ros/additions/install_ros_noetic_ubuntu_20.sh +./install_ros_noetic_ubuntu_20.sh +``` + + +* install LibTorch 1.7 with CUDA 11.0: + +On **Jetson (ARM)**: +```bash +wget https://nvidia.box.com/shared/static/wa34qwrwtk9njtyarwt5nvo6imenfy26.whl -O torch-1.7.0-cp36-cp36m-linux_aarch64.whl +sudo apt-get install python3-pip libopenblas-base libopenmpi-dev +pip3 install Cython +pip3 install numpy torch-1.7.0-cp36-cp36m-linux_aarch64.whl +``` +Or compile LibTorch from source: https://github.com/pytorch/pytorch#from-source + +On **Linux (x86_64)**: +```bash +cd ~/ +wget https://download.pytorch.org/libtorch/cu110/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu110.zip +unzip libtorch-cxx11-abi-shared-with-deps-1.7.0+cu110.zip +``` + +* create symlink for OpenCV: + +```bash +sudo ln -s /usr/include/opencv4 /usr/include/opencv +``` + +* download and install MiDaS: + +```bash +source ~/.bashrc +cd ~/ +mkdir catkin_ws +cd catkin_ws +git clone https://github.com/isl-org/MiDaS +mkdir src +cp -r MiDaS/ros/* src + +chmod +x src/additions/*.sh +chmod +x src/*.sh +chmod +x src/midas_cpp/scripts/*.py +cp src/additions/do_catkin_make.sh ./do_catkin_make.sh +./do_catkin_make.sh +./src/additions/downloads.sh +``` + +### Usage + +* run only `midas` node: `~/catkin_ws/src/launch_midas_cpp.sh` + +#### Test + +* Test - capture video and show result in the window: + * place any `test.mp4` video file to the directory `~/catkin_ws/src/` + * run `midas` node: `~/catkin_ws/src/launch_midas_cpp.sh` + * run test nodes in another terminal: `cd ~/catkin_ws/src && ./run_talker_listener_test.sh` and wait 30 seconds + + (to use Python 2, run command `sed -i 's/python3/python2/' ~/catkin_ws/src/midas_cpp/scripts/*.py` ) + +## Mobile version of MiDaS - Monocular Depth Estimation + +### Accuracy + +* MiDaS v2 small - ResNet50 default-decoder 384x384 +* MiDaS v2.1 small - EfficientNet-Lite3 small-decoder 256x256 + +**Zero-shot error** (the lower - the better): + +| Model | DIW WHDR | Eth3d AbsRel | Sintel AbsRel | Kitti δ>1.25 | NyuDepthV2 δ>1.25 | TUM δ>1.25 | +|---|---|---|---|---|---|---| +| MiDaS v2 small 384x384 | **0.1248** | 0.1550 | **0.3300** | **21.81** | 15.73 | 17.00 | +| MiDaS v2.1 small 256x256 | 0.1344 | **0.1344** | 0.3370 | 29.27 | **13.43** | **14.53** | +| Relative improvement, % | -8 % | **+13 %** | -2 % | -34 % | **+15 %** | **+15 %** | + +None of Train/Valid/Test subsets of datasets (DIW, Eth3d, Sintel, Kitti, NyuDepthV2, TUM) were not involved in Training or Fine Tuning. + +### Inference speed (FPS) on nVidia GPU + +Inference speed excluding pre and post processing, batch=1, **Frames Per Second** (the higher - the better): + +| Model | Jetson Nano, FPS | RTX 2080Ti, FPS | +|---|---|---| +| MiDaS v2 small 384x384 | 1.6 | 117 | +| MiDaS v2.1 small 256x256 | 8.1 | 232 | +| SpeedUp, X times | **5x** | **2x** | + +### Citation + +This repository contains code to compute depth from a single image. It accompanies our [paper](https://arxiv.org/abs/1907.01341v3): + +>Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer +René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun + +Please cite our paper if you use this code or any of the models: +``` +@article{Ranftl2020, + author = {Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun}, + title = {Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer}, + journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)}, + year = {2020}, +} +``` diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/do_catkin_make.sh b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/do_catkin_make.sh new file mode 100644 index 0000000000000000000000000000000000000000..0d416fc00282aab146326bbba12a9274e1ba29b8 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/do_catkin_make.sh @@ -0,0 +1,5 @@ +mkdir src +catkin_make +source devel/setup.bash +echo $ROS_PACKAGE_PATH +chmod +x ./devel/setup.bash diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/downloads.sh b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/downloads.sh new file mode 100644 index 0000000000000000000000000000000000000000..9c967d4e2dc7997da26399a063b5a54ecc314eb1 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/downloads.sh @@ -0,0 +1,5 @@ +mkdir ~/.ros +wget https://github.com/isl-org/MiDaS/releases/download/v2_1/model-small-traced.pt +cp ./model-small-traced.pt ~/.ros/model-small-traced.pt + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_melodic_ubuntu_17_18.sh b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_melodic_ubuntu_17_18.sh new file mode 100644 index 0000000000000000000000000000000000000000..b868112631e9d9bc7bccb601407dfc857b8a99d5 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_melodic_ubuntu_17_18.sh @@ -0,0 +1,34 @@ +#@title { display-mode: "code" } + +#from http://wiki.ros.org/indigo/Installation/Ubuntu + +#1.2 Setup sources.list +sudo sh -c 'echo "deb http://packages.ros.org/ros/ubuntu $(lsb_release -sc) main" > /etc/apt/sources.list.d/ros-latest.list' + +# 1.3 Setup keys +sudo apt-key adv --keyserver 'hkp://keyserver.ubuntu.com:80' --recv-key C1CF6E31E6BADE8868B172B4F42ED6FBAB17C654 +sudo apt-key adv --keyserver 'hkp://ha.pool.sks-keyservers.net:80' --recv-key 421C365BD9FF1F717815A3895523BAEEB01FA116 + +curl -sSL 'http://keyserver.ubuntu.com/pks/lookup?op=get&search=0xC1CF6E31E6BADE8868B172B4F42ED6FBAB17C654' | sudo apt-key add - + +# 1.4 Installation +sudo apt-get update +sudo apt-get upgrade + +# Desktop-Full Install: +sudo apt-get install ros-melodic-desktop-full + +printf "\nsource /opt/ros/melodic/setup.bash\n" >> ~/.bashrc + +# 1.5 Initialize rosdep +sudo rosdep init +rosdep update + + +# 1.7 Getting rosinstall (python) +sudo apt-get install python-rosinstall +sudo apt-get install python-catkin-tools +sudo apt-get install python-rospy +sudo apt-get install python-rosdep +sudo apt-get install python-roscd +sudo apt-get install python-pip \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_noetic_ubuntu_20.sh b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_noetic_ubuntu_20.sh new file mode 100644 index 0000000000000000000000000000000000000000..d73ea1a3d92359819167d735a92d2a650b9bc245 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/install_ros_noetic_ubuntu_20.sh @@ -0,0 +1,33 @@ +#@title { display-mode: "code" } + +#from http://wiki.ros.org/indigo/Installation/Ubuntu + +#1.2 Setup sources.list +sudo sh -c 'echo "deb http://packages.ros.org/ros/ubuntu $(lsb_release -sc) main" > /etc/apt/sources.list.d/ros-latest.list' + +# 1.3 Setup keys +sudo apt-key adv --keyserver 'hkp://keyserver.ubuntu.com:80' --recv-key C1CF6E31E6BADE8868B172B4F42ED6FBAB17C654 + +curl -sSL 'http://keyserver.ubuntu.com/pks/lookup?op=get&search=0xC1CF6E31E6BADE8868B172B4F42ED6FBAB17C654' | sudo apt-key add - + +# 1.4 Installation +sudo apt-get update +sudo apt-get upgrade + +# Desktop-Full Install: +sudo apt-get install ros-noetic-desktop-full + +printf "\nsource /opt/ros/noetic/setup.bash\n" >> ~/.bashrc + +# 1.5 Initialize rosdep +sudo rosdep init +rosdep update + + +# 1.7 Getting rosinstall (python) +sudo apt-get install python3-rosinstall +sudo apt-get install python3-catkin-tools +sudo apt-get install python3-rospy +sudo apt-get install python3-rosdep +sudo apt-get install python3-roscd +sudo apt-get install python3-pip \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/make_package_cpp.sh b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/make_package_cpp.sh new file mode 100644 index 0000000000000000000000000000000000000000..d0ef6073a9c9ce40744e1c81d557c1c68255b95e --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/additions/make_package_cpp.sh @@ -0,0 +1,16 @@ +cd ~/catkin_ws/src +catkin_create_pkg midas_cpp std_msgs roscpp cv_bridge sensor_msgs image_transport +cd ~/catkin_ws +catkin_make + +chmod +x ~/catkin_ws/devel/setup.bash +printf "\nsource ~/catkin_ws/devel/setup.bash" >> ~/.bashrc +source ~/catkin_ws/devel/setup.bash + + +sudo rosdep init +rosdep update +#rospack depends1 midas_cpp +roscd midas_cpp +#cat package.xml +#rospack depends midas_cpp \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/launch_midas_cpp.sh b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/launch_midas_cpp.sh new file mode 100644 index 0000000000000000000000000000000000000000..5a0d1583fffdc49216c625dfd07af2ae3b01a7a0 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/launch_midas_cpp.sh @@ -0,0 +1,2 @@ +source ~/catkin_ws/devel/setup.bash +roslaunch midas_cpp midas_cpp.launch model_name:="model-small-traced.pt" input_topic:="image_topic" output_topic:="midas_topic" out_orig_size:="true" \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/CMakeLists.txt b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..885341691d217f9c4c8fcb1e4ff568d87788c7b8 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/CMakeLists.txt @@ -0,0 +1,189 @@ +cmake_minimum_required(VERSION 3.0.2) +project(midas_cpp) + +## Compile as C++11, supported in ROS Kinetic and newer +# add_compile_options(-std=c++11) + +## Find catkin macros and libraries +## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz) +## is used, also find other catkin packages +find_package(catkin REQUIRED COMPONENTS + cv_bridge + image_transport + roscpp + rospy + sensor_msgs + std_msgs +) + +## System dependencies are found with CMake's conventions +# find_package(Boost REQUIRED COMPONENTS system) + +list(APPEND CMAKE_PREFIX_PATH "~/libtorch") +list(APPEND CMAKE_PREFIX_PATH "/usr/local/lib/python3.6/dist-packages/torch/lib") +list(APPEND CMAKE_PREFIX_PATH "/usr/local/lib/python2.7/dist-packages/torch/lib") + +if(NOT EXISTS "~/libtorch") + if (EXISTS "/usr/local/lib/python3.6/dist-packages/torch") + include_directories(/usr/local/include) + include_directories(/usr/local/lib/python3.6/dist-packages/torch/include/torch/csrc/api/include) + include_directories(/usr/local/lib/python3.6/dist-packages/torch/include) + + link_directories(/usr/local/lib) + link_directories(/usr/local/lib/python3.6/dist-packages/torch/lib) + + set(CMAKE_PREFIX_PATH /usr/local/lib/python3.6/dist-packages/torch) + set(Boost_USE_MULTITHREADED ON) + set(Torch_DIR /usr/local/lib/python3.6/dist-packages/torch) + + elseif (EXISTS "/usr/local/lib/python2.7/dist-packages/torch") + + include_directories(/usr/local/include) + include_directories(/usr/local/lib/python2.7/dist-packages/torch/include/torch/csrc/api/include) + include_directories(/usr/local/lib/python2.7/dist-packages/torch/include) + + link_directories(/usr/local/lib) + link_directories(/usr/local/lib/python2.7/dist-packages/torch/lib) + + set(CMAKE_PREFIX_PATH /usr/local/lib/python2.7/dist-packages/torch) + set(Boost_USE_MULTITHREADED ON) + set(Torch_DIR /usr/local/lib/python2.7/dist-packages/torch) + endif() +endif() + + + +find_package(Torch REQUIRED) +find_package(OpenCV REQUIRED) +include_directories( ${OpenCV_INCLUDE_DIRS} ) + +add_executable(midas_cpp src/main.cpp) +target_link_libraries(midas_cpp "${TORCH_LIBRARIES}" "${OpenCV_LIBS} ${catkin_LIBRARIES}") +set_property(TARGET midas_cpp PROPERTY CXX_STANDARD 14) + + + +################################### +## catkin specific configuration ## +################################### +## The catkin_package macro generates cmake config files for your package +## Declare things to be passed to dependent projects +## INCLUDE_DIRS: uncomment this if your package contains header files +## LIBRARIES: libraries you create in this project that dependent projects also need +## CATKIN_DEPENDS: catkin_packages dependent projects also need +## DEPENDS: system dependencies of this project that dependent projects also need +catkin_package( +# INCLUDE_DIRS include +# LIBRARIES midas_cpp +# CATKIN_DEPENDS cv_bridge image_transport roscpp sensor_msgs std_msgs +# DEPENDS system_lib +) + +########### +## Build ## +########### + +## Specify additional locations of header files +## Your package locations should be listed before other locations +include_directories( +# include + ${catkin_INCLUDE_DIRS} +) + +## Declare a C++ library +# add_library(${PROJECT_NAME} +# src/${PROJECT_NAME}/midas_cpp.cpp +# ) + +## Add cmake target dependencies of the library +## as an example, code may need to be generated before libraries +## either from message generation or dynamic reconfigure +# add_dependencies(${PROJECT_NAME} ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS}) + +## Declare a C++ executable +## With catkin_make all packages are built within a single CMake context +## The recommended prefix ensures that target names across packages don't collide +# add_executable(${PROJECT_NAME}_node src/midas_cpp_node.cpp) + +## Rename C++ executable without prefix +## The above recommended prefix causes long target names, the following renames the +## target back to the shorter version for ease of user use +## e.g. "rosrun someones_pkg node" instead of "rosrun someones_pkg someones_pkg_node" +# set_target_properties(${PROJECT_NAME}_node PROPERTIES OUTPUT_NAME node PREFIX "") + +## Add cmake target dependencies of the executable +## same as for the library above +# add_dependencies(${PROJECT_NAME}_node ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS}) + +## Specify libraries to link a library or executable target against +# target_link_libraries(${PROJECT_NAME}_node +# ${catkin_LIBRARIES} +# ) + +############# +## Install ## +############# + +# all install targets should use catkin DESTINATION variables +# See http://ros.org/doc/api/catkin/html/adv_user_guide/variables.html + +## Mark executable scripts (Python etc.) for installation +## in contrast to setup.py, you can choose the destination +# catkin_install_python(PROGRAMS +# scripts/my_python_script +# DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +# ) + +## Mark executables for installation +## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_executables.html +# install(TARGETS ${PROJECT_NAME}_node +# RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +# ) + +## Mark libraries for installation +## See http://docs.ros.org/melodic/api/catkin/html/howto/format1/building_libraries.html +# install(TARGETS ${PROJECT_NAME} +# ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} +# LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} +# RUNTIME DESTINATION ${CATKIN_GLOBAL_BIN_DESTINATION} +# ) + +## Mark cpp header files for installation +# install(DIRECTORY include/${PROJECT_NAME}/ +# DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION} +# FILES_MATCHING PATTERN "*.h" +# PATTERN ".svn" EXCLUDE +# ) + +## Mark other files for installation (e.g. launch and bag files, etc.) +# install(FILES +# # myfile1 +# # myfile2 +# DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} +# ) + +############# +## Testing ## +############# + +## Add gtest based cpp test target and link libraries +# catkin_add_gtest(${PROJECT_NAME}-test test/test_midas_cpp.cpp) +# if(TARGET ${PROJECT_NAME}-test) +# target_link_libraries(${PROJECT_NAME}-test ${PROJECT_NAME}) +# endif() + +## Add folders to be run by python nosetests +# catkin_add_nosetests(test) + +install(TARGETS ${PROJECT_NAME} + ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} + LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} + RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +) + +add_custom_command( + TARGET midas_cpp POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + ${CMAKE_CURRENT_BINARY_DIR}/midas_cpp + ${CMAKE_SOURCE_DIR}/midas_cpp +) \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_cpp.launch b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_cpp.launch new file mode 100644 index 0000000000000000000000000000000000000000..88e86f42f668e76ad4976ec6794a8cb0f20cac65 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_cpp.launch @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_talker_listener.launch b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_talker_listener.launch new file mode 100644 index 0000000000000000000000000000000000000000..8817a4f4933c56986fe0edc0886b2fded3d3406d --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/launch/midas_talker_listener.launch @@ -0,0 +1,23 @@ + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/package.xml b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/package.xml new file mode 100644 index 0000000000000000000000000000000000000000..9cac90eba75409bd170f73531c54c83c52ff047a --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/package.xml @@ -0,0 +1,77 @@ + + + midas_cpp + 0.1.0 + The midas_cpp package + + Alexey Bochkovskiy + MIT + https://github.com/isl-org/MiDaS/tree/master/ros + + + + + + + TODO + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + catkin + cv_bridge + image_transport + roscpp + rospy + sensor_msgs + std_msgs + cv_bridge + image_transport + roscpp + rospy + sensor_msgs + std_msgs + cv_bridge + image_transport + roscpp + rospy + sensor_msgs + std_msgs + + + + + + + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener.py new file mode 100644 index 0000000000000000000000000000000000000000..6927ea7a83ac9309e5f883ee974a5dcfa8a2aa3b --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +from __future__ import print_function + +import roslib +#roslib.load_manifest('my_package') +import sys +import rospy +import cv2 +import numpy as np +from std_msgs.msg import String +from sensor_msgs.msg import Image +from cv_bridge import CvBridge, CvBridgeError + +class video_show: + + def __init__(self): + self.show_output = rospy.get_param('~show_output', True) + self.save_output = rospy.get_param('~save_output', False) + self.output_video_file = rospy.get_param('~output_video_file','result.mp4') + # rospy.loginfo(f"Listener - params: show_output={self.show_output}, save_output={self.save_output}, output_video_file={self.output_video_file}") + + self.bridge = CvBridge() + self.image_sub = rospy.Subscriber("midas_topic", Image, self.callback) + + def callback(self, data): + try: + cv_image = self.bridge.imgmsg_to_cv2(data) + except CvBridgeError as e: + print(e) + return + + if cv_image.size == 0: + return + + rospy.loginfo("Listener: Received new frame") + cv_image = cv_image.astype("uint8") + + if self.show_output==True: + cv2.imshow("video_show", cv_image) + cv2.waitKey(10) + + if self.save_output==True: + if self.video_writer_init==False: + fourcc = cv2.VideoWriter_fourcc(*'XVID') + self.out = cv2.VideoWriter(self.output_video_file, fourcc, 25, (cv_image.shape[1], cv_image.shape[0])) + + self.out.write(cv_image) + + + +def main(args): + rospy.init_node('listener', anonymous=True) + ic = video_show() + try: + rospy.spin() + except KeyboardInterrupt: + print("Shutting down") + cv2.destroyAllWindows() + +if __name__ == '__main__': + main(sys.argv) \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener_original.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener_original.py new file mode 100644 index 0000000000000000000000000000000000000000..20e235f6958d644b89383752ab18e9e2275f55e5 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/listener_original.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +from __future__ import print_function + +import roslib +#roslib.load_manifest('my_package') +import sys +import rospy +import cv2 +import numpy as np +from std_msgs.msg import String +from sensor_msgs.msg import Image +from cv_bridge import CvBridge, CvBridgeError + +class video_show: + + def __init__(self): + self.show_output = rospy.get_param('~show_output', True) + self.save_output = rospy.get_param('~save_output', False) + self.output_video_file = rospy.get_param('~output_video_file','result.mp4') + # rospy.loginfo(f"Listener original - params: show_output={self.show_output}, save_output={self.save_output}, output_video_file={self.output_video_file}") + + self.bridge = CvBridge() + self.image_sub = rospy.Subscriber("image_topic", Image, self.callback) + + def callback(self, data): + try: + cv_image = self.bridge.imgmsg_to_cv2(data) + except CvBridgeError as e: + print(e) + return + + if cv_image.size == 0: + return + + rospy.loginfo("Listener_original: Received new frame") + cv_image = cv_image.astype("uint8") + + if self.show_output==True: + cv2.imshow("video_show_orig", cv_image) + cv2.waitKey(10) + + if self.save_output==True: + if self.video_writer_init==False: + fourcc = cv2.VideoWriter_fourcc(*'XVID') + self.out = cv2.VideoWriter(self.output_video_file, fourcc, 25, (cv_image.shape[1], cv_image.shape[0])) + + self.out.write(cv_image) + + + +def main(args): + rospy.init_node('listener_original', anonymous=True) + ic = video_show() + try: + rospy.spin() + except KeyboardInterrupt: + print("Shutting down") + cv2.destroyAllWindows() + +if __name__ == '__main__': + main(sys.argv) \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/talker.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/talker.py new file mode 100644 index 0000000000000000000000000000000000000000..8219cc8632484a2efd02984347c615efad6b78b2 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/scripts/talker.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 + + +import roslib +#roslib.load_manifest('my_package') +import sys +import rospy +import cv2 +from std_msgs.msg import String +from sensor_msgs.msg import Image +from cv_bridge import CvBridge, CvBridgeError + + +def talker(): + rospy.init_node('talker', anonymous=True) + + use_camera = rospy.get_param('~use_camera', False) + input_video_file = rospy.get_param('~input_video_file','test.mp4') + # rospy.loginfo(f"Talker - params: use_camera={use_camera}, input_video_file={input_video_file}") + + # rospy.loginfo("Talker: Trying to open a video stream") + if use_camera == True: + cap = cv2.VideoCapture(0) + else: + cap = cv2.VideoCapture(input_video_file) + + pub = rospy.Publisher('image_topic', Image, queue_size=1) + rate = rospy.Rate(30) # 30hz + bridge = CvBridge() + + while not rospy.is_shutdown(): + ret, cv_image = cap.read() + if ret==False: + print("Talker: Video is over") + rospy.loginfo("Video is over") + return + + try: + image = bridge.cv2_to_imgmsg(cv_image, "bgr8") + except CvBridgeError as e: + rospy.logerr("Talker: cv2image conversion failed: ", e) + print(e) + continue + + rospy.loginfo("Talker: Publishing frame") + pub.publish(image) + rate.sleep() + +if __name__ == '__main__': + try: + talker() + except rospy.ROSInterruptException: + pass diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/src/main.cpp b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/src/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e4fc72c6955f66af71c9cb1fc7a7b1f643129685 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/midas_cpp/src/main.cpp @@ -0,0 +1,285 @@ +#include +#include +#include +#include + +#include + +#include // One-stop header. + +#include +#include +#include +#include + +#include +#include + +// includes for OpenCV >= 3.x +#ifndef CV_VERSION_EPOCH +#include +#include +#include +#endif + +// OpenCV includes for OpenCV 2.x +#ifdef CV_VERSION_EPOCH +#include +#include +#include +#include +#endif + +static const std::string OPENCV_WINDOW = "Image window"; + +class Midas +{ + ros::NodeHandle nh_; + image_transport::ImageTransport it_; + image_transport::Subscriber image_sub_; + image_transport::Publisher image_pub_; + + torch::jit::script::Module module; + torch::Device device; + + auto ToTensor(cv::Mat img, bool show_output = false, bool unsqueeze = false, int unsqueeze_dim = 0) + { + //std::cout << "image shape: " << img.size() << std::endl; + at::Tensor tensor_image = torch::from_blob(img.data, { img.rows, img.cols, 3 }, at::kByte); + + if (unsqueeze) + { + tensor_image.unsqueeze_(unsqueeze_dim); + //std::cout << "tensors new shape: " << tensor_image.sizes() << std::endl; + } + + if (show_output) + { + std::cout << tensor_image.slice(2, 0, 1) << std::endl; + } + //std::cout << "tenor shape: " << tensor_image.sizes() << std::endl; + return tensor_image; + } + + auto ToInput(at::Tensor tensor_image) + { + // Create a vector of inputs. + return std::vector{tensor_image}; + } + + auto ToCvImage(at::Tensor tensor, int cv_type = CV_8UC3) + { + int width = tensor.sizes()[0]; + int height = tensor.sizes()[1]; + try + { + cv::Mat output_mat; + if (cv_type == CV_8UC4 || cv_type == CV_8UC3 || cv_type == CV_8UC2 || cv_type == CV_8UC1) { + cv::Mat cv_image(cv::Size{ height, width }, cv_type, tensor.data_ptr()); + output_mat = cv_image; + } + else if (cv_type == CV_32FC4 || cv_type == CV_32FC3 || cv_type == CV_32FC2 || cv_type == CV_32FC1) { + cv::Mat cv_image(cv::Size{ height, width }, cv_type, tensor.data_ptr()); + output_mat = cv_image; + } + else if (cv_type == CV_64FC4 || cv_type == CV_64FC3 || cv_type == CV_64FC2 || cv_type == CV_64FC1) { + cv::Mat cv_image(cv::Size{ height, width }, cv_type, tensor.data_ptr()); + output_mat = cv_image; + } + + //show_image(output_mat, "converted image from tensor"); + return output_mat.clone(); + } + catch (const c10::Error& e) + { + std::cout << "an error has occured : " << e.msg() << std::endl; + } + return cv::Mat(height, width, CV_8UC3); + } + + std::string input_topic, output_topic, model_name; + bool out_orig_size; + int net_width, net_height; + torch::NoGradGuard guard; + at::Tensor mean, std; + at::Tensor output, tensor; + +public: + Midas() + : nh_(), it_(nh_), device(torch::Device(torch::kCPU)) + { + ros::param::param("~input_topic", input_topic, "image_topic"); + ros::param::param("~output_topic", output_topic, "midas_topic"); + ros::param::param("~model_name", model_name, "model-small-traced.pt"); + ros::param::param("~out_orig_size", out_orig_size, true); + ros::param::param("~net_width", net_width, 256); + ros::param::param("~net_height", net_height, 256); + + std::cout << ", input_topic = " << input_topic << + ", output_topic = " << output_topic << + ", model_name = " << model_name << + ", out_orig_size = " << out_orig_size << + ", net_width = " << net_width << + ", net_height = " << net_height << + std::endl; + + // Subscrive to input video feed and publish output video feed + image_sub_ = it_.subscribe(input_topic, 1, &Midas::imageCb, this); + image_pub_ = it_.advertise(output_topic, 1); + + std::cout << "Try to load torchscript model \n"; + + try { + // Deserialize the ScriptModule from a file using torch::jit::load(). + module = torch::jit::load(model_name); + } + catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + exit(0); + } + + std::cout << "ok\n"; + + try { + module.eval(); + torch::jit::getProfilingMode() = false; + torch::jit::setGraphExecutorOptimize(true); + + mean = torch::tensor({ 0.485, 0.456, 0.406 }); + std = torch::tensor({ 0.229, 0.224, 0.225 }); + + if (torch::hasCUDA()) { + std::cout << "cuda is available" << std::endl; + at::globalContext().setBenchmarkCuDNN(true); + device = torch::Device(torch::kCUDA); + module.to(device); + mean = mean.to(device); + std = std.to(device); + } + } + catch (const c10::Error& e) + { + std::cerr << " module initialization: " << e.msg() << std::endl; + } + } + + ~Midas() + { + } + + void imageCb(const sensor_msgs::ImageConstPtr& msg) + { + cv_bridge::CvImagePtr cv_ptr; + try + { + // sensor_msgs::Image to cv::Mat + cv_ptr = cv_bridge::toCvCopy(msg, sensor_msgs::image_encodings::RGB8); + } + catch (cv_bridge::Exception& e) + { + ROS_ERROR("cv_bridge exception: %s", e.what()); + return; + } + + // pre-processing + auto tensor_cpu = ToTensor(cv_ptr->image); // OpenCV-image -> Libtorch-tensor + + try { + tensor = tensor_cpu.to(device); // move to device (CPU or GPU) + + tensor = tensor.toType(c10::kFloat); + tensor = tensor.permute({ 2, 0, 1 }); // HWC -> CHW + tensor = tensor.unsqueeze(0); + tensor = at::upsample_bilinear2d(tensor, { net_height, net_width }, true); // resize + tensor = tensor.squeeze(0); + tensor = tensor.permute({ 1, 2, 0 }); // CHW -> HWC + + tensor = tensor.div(255).sub(mean).div(std); // normalization + tensor = tensor.permute({ 2, 0, 1 }); // HWC -> CHW + tensor.unsqueeze_(0); // CHW -> NCHW + } + catch (const c10::Error& e) + { + std::cerr << " pre-processing exception: " << e.msg() << std::endl; + return; + } + + auto input_to_net = ToInput(tensor); // input to the network + + // inference + output; + try { + output = module.forward(input_to_net).toTensor(); // run inference + } + catch (const c10::Error& e) + { + std::cerr << " module.forward() exception: " << e.msg() << std::endl; + return; + } + + output = output.detach().to(torch::kF32); + + // move to CPU temporary + at::Tensor output_tmp = output; + output_tmp = output_tmp.to(torch::kCPU); + + // normalization + float min_val = std::numeric_limits::max(); + float max_val = std::numeric_limits::min(); + + for (int i = 0; i < net_width * net_height; ++i) { + float val = output_tmp.data_ptr()[i]; + if (min_val > val) min_val = val; + if (max_val < val) max_val = val; + } + float range_val = max_val - min_val; + + output = output.sub(min_val).div(range_val).mul(255.0F).clamp(0, 255).to(torch::kF32); // .to(torch::kU8); + + // resize to the original size if required + if (out_orig_size) { + try { + output = at::upsample_bilinear2d(output.unsqueeze(0), { cv_ptr->image.size().height, cv_ptr->image.size().width }, true); + output = output.squeeze(0); + } + catch (const c10::Error& e) + { + std::cout << " upsample_bilinear2d() exception: " << e.msg() << std::endl; + return; + } + } + output = output.permute({ 1, 2, 0 }).to(torch::kCPU); + + int cv_type = CV_32FC1; // CV_8UC1; + auto cv_img = ToCvImage(output, cv_type); + + sensor_msgs::Image img_msg; + + try { + // cv::Mat -> sensor_msgs::Image + std_msgs::Header header; // empty header + header.seq = 0; // user defined counter + header.stamp = ros::Time::now();// time + //cv_bridge::CvImage img_bridge = cv_bridge::CvImage(header, sensor_msgs::image_encodings::MONO8, cv_img); + cv_bridge::CvImage img_bridge = cv_bridge::CvImage(header, sensor_msgs::image_encodings::TYPE_32FC1, cv_img); + + img_bridge.toImageMsg(img_msg); // cv_bridge -> sensor_msgs::Image + } + catch (cv_bridge::Exception& e) + { + ROS_ERROR("cv_bridge exception: %s", e.what()); + return; + } + + // Output modified video stream + image_pub_.publish(img_msg); + } +}; + +int main(int argc, char** argv) +{ + ros::init(argc, argv, "midas", ros::init_options::AnonymousName); + Midas ic; + ros::spin(); + return 0; +} \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/run_talker_listener_test.sh b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/run_talker_listener_test.sh new file mode 100644 index 0000000000000000000000000000000000000000..a997c4261072d0d627598fe06a723fcc7522d347 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/ros/run_talker_listener_test.sh @@ -0,0 +1,16 @@ +# place any test.mp4 file near with this file + +# roscore +# rosnode kill -a + +source ~/catkin_ws/devel/setup.bash + +roscore & +P1=$! +rosrun midas_cpp talker.py & +P2=$! +rosrun midas_cpp listener_original.py & +P3=$! +rosrun midas_cpp listener.py & +P4=$! +wait $P1 $P2 $P3 $P4 \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/run.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/run.py new file mode 100644 index 0000000000000000000000000000000000000000..5696ef0547af093713ea416d18edd77d11879d0a --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/run.py @@ -0,0 +1,277 @@ +"""Compute depth maps for images in the input folder. +""" +import os +import glob +import torch +import utils +import cv2 +import argparse +import time + +import numpy as np + +from imutils.video import VideoStream +from midas.model_loader import default_models, load_model + +first_execution = True +def process(device, model, model_type, image, input_size, target_size, optimize, use_camera): + """ + Run the inference and interpolate. + + Args: + device (torch.device): the torch device used + model: the model used for inference + model_type: the type of the model + image: the image fed into the neural network + input_size: the size (width, height) of the neural network input (for OpenVINO) + target_size: the size (width, height) the neural network output is interpolated to + optimize: optimize the model to half-floats on CUDA? + use_camera: is the camera used? + + Returns: + the prediction + """ + global first_execution + + if "openvino" in model_type: + if first_execution or not use_camera: + print(f" Input resized to {input_size[0]}x{input_size[1]} before entering the encoder") + first_execution = False + + sample = [np.reshape(image, (1, 3, *input_size))] + prediction = model(sample)[model.output(0)][0] + prediction = cv2.resize(prediction, dsize=target_size, + interpolation=cv2.INTER_CUBIC) + else: + sample = torch.from_numpy(image).to(device).unsqueeze(0) + + if optimize and device == torch.device("cuda"): + if first_execution: + print(" Optimization to half-floats activated. Use with caution, because models like Swin require\n" + " float precision to work properly and may yield non-finite depth values to some extent for\n" + " half-floats.") + sample = sample.to(memory_format=torch.channels_last) + sample = sample.half() + + if first_execution or not use_camera: + height, width = sample.shape[2:] + print(f" Input resized to {width}x{height} before entering the encoder") + first_execution = False + + prediction = model.forward(sample) + prediction = ( + torch.nn.functional.interpolate( + prediction.unsqueeze(1), + size=target_size[::-1], + mode="bicubic", + align_corners=False, + ) + .squeeze() + .cpu() + .numpy() + ) + + return prediction + + +def create_side_by_side(image, depth, grayscale): + """ + Take an RGB image and depth map and place them side by side. This includes a proper normalization of the depth map + for better visibility. + + Args: + image: the RGB image + depth: the depth map + grayscale: use a grayscale colormap? + + Returns: + the image and depth map place side by side + """ + depth_min = depth.min() + depth_max = depth.max() + normalized_depth = 255 * (depth - depth_min) / (depth_max - depth_min) + normalized_depth *= 3 + + right_side = np.repeat(np.expand_dims(normalized_depth, 2), 3, axis=2) / 3 + if not grayscale: + right_side = cv2.applyColorMap(np.uint8(right_side), cv2.COLORMAP_INFERNO) + + if image is None: + return right_side + else: + return np.concatenate((image, right_side), axis=1) + + +def run(input_path, output_path, model_path, model_type="dpt_beit_large_512", optimize=False, side=False, height=None, + square=False, grayscale=False): + """Run MonoDepthNN to compute depth maps. + + Args: + input_path (str): path to input folder + output_path (str): path to output folder + model_path (str): path to saved model + model_type (str): the model type + optimize (bool): optimize the model to half-floats on CUDA? + side (bool): RGB and depth side by side in output images? + height (int): inference encoder image height + square (bool): resize to a square resolution? + grayscale (bool): use a grayscale colormap? + """ + print("Initialize") + + # select device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print("Device: %s" % device) + + model, transform, net_w, net_h = load_model(device, model_path, model_type, optimize, height, square) + + # get input + if input_path is not None: + image_names = glob.glob(os.path.join(input_path, "*")) + num_images = len(image_names) + else: + print("No input path specified. Grabbing images from camera.") + + # create output folder + if output_path is not None: + os.makedirs(output_path, exist_ok=True) + + print("Start processing") + + if input_path is not None: + if output_path is None: + print("Warning: No output path specified. Images will be processed but not shown or stored anywhere.") + for index, image_name in enumerate(image_names): + + print(" Processing {} ({}/{})".format(image_name, index + 1, num_images)) + + # input + original_image_rgb = utils.read_image(image_name) # in [0, 1] + image = transform({"image": original_image_rgb})["image"] + + # compute + with torch.no_grad(): + prediction = process(device, model, model_type, image, (net_w, net_h), original_image_rgb.shape[1::-1], + optimize, False) + + # output + if output_path is not None: + filename = os.path.join( + output_path, os.path.splitext(os.path.basename(image_name))[0] + '-' + model_type + ) + if not side: + utils.write_depth(filename, prediction, grayscale, bits=2) + else: + original_image_bgr = np.flip(original_image_rgb, 2) + content = create_side_by_side(original_image_bgr*255, prediction, grayscale) + cv2.imwrite(filename + ".png", content) + utils.write_pfm(filename + ".pfm", prediction.astype(np.float32)) + + else: + with torch.no_grad(): + fps = 1 + video = VideoStream(0).start() + time_start = time.time() + frame_index = 0 + while True: + frame = video.read() + if frame is not None: + original_image_rgb = np.flip(frame, 2) # in [0, 255] (flip required to get RGB) + image = transform({"image": original_image_rgb/255})["image"] + + prediction = process(device, model, model_type, image, (net_w, net_h), + original_image_rgb.shape[1::-1], optimize, True) + + original_image_bgr = np.flip(original_image_rgb, 2) if side else None + content = create_side_by_side(original_image_bgr, prediction, grayscale) + cv2.imshow('MiDaS Depth Estimation - Press Escape to close window ', content/255) + + if output_path is not None: + filename = os.path.join(output_path, 'Camera' + '-' + model_type + '_' + str(frame_index)) + cv2.imwrite(filename + ".png", content) + + alpha = 0.1 + if time.time()-time_start > 0: + fps = (1 - alpha) * fps + alpha * 1 / (time.time()-time_start) # exponential moving average + time_start = time.time() + print(f"\rFPS: {round(fps,2)}", end="") + + if cv2.waitKey(1) == 27: # Escape key + break + + frame_index += 1 + print() + + print("Finished") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument('-i', '--input_path', + default=None, + help='Folder with input images (if no input path is specified, images are tried to be grabbed ' + 'from camera)' + ) + + parser.add_argument('-o', '--output_path', + default=None, + help='Folder for output images' + ) + + parser.add_argument('-m', '--model_weights', + default=None, + help='Path to the trained weights of model' + ) + + parser.add_argument('-t', '--model_type', + default='dpt_beit_large_512', + help='Model type: ' + 'dpt_beit_large_512, dpt_beit_large_384, dpt_beit_base_384, dpt_swin2_large_384, ' + 'dpt_swin2_base_384, dpt_swin2_tiny_256, dpt_swin_large_384, dpt_next_vit_large_384, ' + 'dpt_levit_224, dpt_large_384, dpt_hybrid_384, midas_v21_384, midas_v21_small_256 or ' + 'openvino_midas_v21_small_256' + ) + + parser.add_argument('-s', '--side', + action='store_true', + help='Output images contain RGB and depth images side by side' + ) + + parser.add_argument('--optimize', dest='optimize', action='store_true', help='Use half-float optimization') + parser.set_defaults(optimize=False) + + parser.add_argument('--height', + type=int, default=None, + help='Preferred height of images feed into the encoder during inference. Note that the ' + 'preferred height may differ from the actual height, because an alignment to multiples of ' + '32 takes place. Many models support only the height chosen during training, which is ' + 'used automatically if this parameter is not set.' + ) + parser.add_argument('--square', + action='store_true', + help='Option to resize images to a square resolution by changing their widths when images are ' + 'fed into the encoder during inference. If this parameter is not set, the aspect ratio of ' + 'images is tried to be preserved if supported by the model.' + ) + parser.add_argument('--grayscale', + action='store_true', + help='Use a grayscale colormap instead of the inferno one. Although the inferno colormap, ' + 'which is used by default, is better for visibility, it does not allow storing 16-bit ' + 'depth values in PNGs but only 8-bit ones due to the precision limitation of this ' + 'colormap.' + ) + + args = parser.parse_args() + + + if args.model_weights is None: + args.model_weights = default_models[args.model_type] + + # set torch options + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + + # compute depth maps + run(args.input_path, args.output_path, args.model_weights, args.model_type, args.optimize, args.side, args.height, + args.square, args.grayscale) diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/README.md b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5b5fe0e63668eab45a55b140826cb3762862b17c --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/README.md @@ -0,0 +1,147 @@ +## Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer + +### TensorFlow inference using `.pb` and `.onnx` models + +1. [Run inference on TensorFlow-model by using TensorFlow](#run-inference-on-tensorflow-model-by-using-tensorFlow) + +2. [Run inference on ONNX-model by using TensorFlow](#run-inference-on-onnx-model-by-using-tensorflow) + +3. [Make ONNX model from downloaded Pytorch model file](#make-onnx-model-from-downloaded-pytorch-model-file) + + +### Run inference on TensorFlow-model by using TensorFlow + +1) Download the model weights [model-f6b98070.pb](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-f6b98070.pb) +and [model-small.pb](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-small.pb) and place the +file in the `/tf/` folder. + +2) Set up dependencies: + +```shell +# install OpenCV +pip install --upgrade pip +pip install opencv-python + +# install TensorFlow +pip install -I grpcio tensorflow==2.3.0 tensorflow-addons==0.11.2 numpy==1.18.0 +``` + +#### Usage + +1) Place one or more input images in the folder `tf/input`. + +2) Run the model: + + ```shell + python tf/run_pb.py + ``` + + Or run the small model: + + ```shell + python tf/run_pb.py --model_weights model-small.pb --model_type small + ``` + +3) The resulting inverse depth maps are written to the `tf/output` folder. + + +### Run inference on ONNX-model by using ONNX-Runtime + +1) Download the model weights [model-f6b98070.onnx](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-f6b98070.onnx) +and [model-small.onnx](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-small.onnx) and place the +file in the `/tf/` folder. + +2) Set up dependencies: + +```shell +# install OpenCV +pip install --upgrade pip +pip install opencv-python + +# install ONNX +pip install onnx==1.7.0 + +# install ONNX Runtime +pip install onnxruntime==1.5.2 +``` + +#### Usage + +1) Place one or more input images in the folder `tf/input`. + +2) Run the model: + + ```shell + python tf/run_onnx.py + ``` + + Or run the small model: + + ```shell + python tf/run_onnx.py --model_weights model-small.onnx --model_type small + ``` + +3) The resulting inverse depth maps are written to the `tf/output` folder. + + + +### Make ONNX model from downloaded Pytorch model file + +1) Download the model weights [model-f6b98070.pt](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-f6b98070.pt) and place the +file in the root folder. + +2) Set up dependencies: + +```shell +# install OpenCV +pip install --upgrade pip +pip install opencv-python + +# install PyTorch TorchVision +pip install -I torch==1.7.0 torchvision==0.8.0 + +# install TensorFlow +pip install -I grpcio tensorflow==2.3.0 tensorflow-addons==0.11.2 numpy==1.18.0 + +# install ONNX +pip install onnx==1.7.0 + +# install ONNX-TensorFlow +git clone https://github.com/onnx/onnx-tensorflow.git +cd onnx-tensorflow +git checkout 095b51b88e35c4001d70f15f80f31014b592b81e +pip install -e . +``` + +#### Usage + +1) Run the converter: + + ```shell + python tf/make_onnx_model.py + ``` + +2) The resulting `model-f6b98070.onnx` file is written to the `/tf/` folder. + + +### Requirements + + The code was tested with Python 3.6.9, PyTorch 1.5.1, TensorFlow 2.2.0, TensorFlow-addons 0.8.3, ONNX 1.7.0, ONNX-TensorFlow (GitHub-master-17.07.2020) and OpenCV 4.3.0. + +### Citation + +Please cite our paper if you use this code or any of the models: +``` +@article{Ranftl2019, + author = {Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun}, + title = {Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer}, + journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)}, + year = {2020}, +} +``` + +### License + +MIT License + + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/input/.placeholder b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/input/.placeholder new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/make_onnx_model.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/make_onnx_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d14b0e4e1d2ea70fa315fd7ca7dfd72440a19376 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/make_onnx_model.py @@ -0,0 +1,112 @@ +"""Compute depth maps for images in the input folder. +""" +import os +import ntpath +import glob +import torch +import utils +import cv2 +import numpy as np +from torchvision.transforms import Compose, Normalize +from torchvision import transforms + +from shutil import copyfile +import fileinput +import sys +sys.path.append(os.getcwd() + '/..') + +def modify_file(): + modify_filename = '../midas/blocks.py' + copyfile(modify_filename, modify_filename+'.bak') + + with open(modify_filename, 'r') as file : + filedata = file.read() + + filedata = filedata.replace('align_corners=True', 'align_corners=False') + filedata = filedata.replace('import torch.nn as nn', 'import torch.nn as nn\nimport torchvision.models as models') + filedata = filedata.replace('torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")', 'models.resnext101_32x8d()') + + with open(modify_filename, 'w') as file: + file.write(filedata) + +def restore_file(): + modify_filename = '../midas/blocks.py' + copyfile(modify_filename+'.bak', modify_filename) + +modify_file() + +from midas.midas_net import MidasNet +from midas.transforms import Resize, NormalizeImage, PrepareForNet + +restore_file() + + +class MidasNet_preprocessing(MidasNet): + """Network for monocular depth estimation. + """ + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + mean = torch.tensor([0.485, 0.456, 0.406]) + std = torch.tensor([0.229, 0.224, 0.225]) + x.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) + + return MidasNet.forward(self, x) + + +def run(model_path): + """Run MonoDepthNN to compute depth maps. + + Args: + model_path (str): path to saved model + """ + print("initialize") + + # select device + + # load network + #model = MidasNet(model_path, non_negative=True) + model = MidasNet_preprocessing(model_path, non_negative=True) + + model.eval() + + print("start processing") + + # input + img_input = np.zeros((3, 384, 384), np.float32) + + # compute + with torch.no_grad(): + sample = torch.from_numpy(img_input).unsqueeze(0) + prediction = model.forward(sample) + prediction = ( + torch.nn.functional.interpolate( + prediction.unsqueeze(1), + size=img_input.shape[:2], + mode="bicubic", + align_corners=False, + ) + .squeeze() + .cpu() + .numpy() + ) + + torch.onnx.export(model, sample, ntpath.basename(model_path).rsplit('.', 1)[0]+'.onnx', opset_version=9) + + print("finished") + + +if __name__ == "__main__": + # set paths + # MODEL_PATH = "model.pt" + MODEL_PATH = "../model-f6b98070.pt" + + # compute depth maps + run(MODEL_PATH) diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/output/.placeholder b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/output/.placeholder new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_onnx.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..7107b99969a127f951814f743d5c562a436b2430 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_onnx.py @@ -0,0 +1,119 @@ +"""Compute depth maps for images in the input folder. +""" +import os +import glob +import utils +import cv2 +import sys +import numpy as np +import argparse + +import onnx +import onnxruntime as rt + +from transforms import Resize, NormalizeImage, PrepareForNet + + +def run(input_path, output_path, model_path, model_type="large"): + """Run MonoDepthNN to compute depth maps. + + Args: + input_path (str): path to input folder + output_path (str): path to output folder + model_path (str): path to saved model + """ + print("initialize") + + # select device + device = "CUDA:0" + #device = "CPU" + print("device: %s" % device) + + # network resolution + if model_type == "large": + net_w, net_h = 384, 384 + elif model_type == "small": + net_w, net_h = 256, 256 + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + # load network + print("loading model...") + model = rt.InferenceSession(model_path) + input_name = model.get_inputs()[0].name + output_name = model.get_outputs()[0].name + + resize_image = Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="upper_bound", + image_interpolation_method=cv2.INTER_CUBIC, + ) + + def compose2(f1, f2): + return lambda x: f2(f1(x)) + + transform = compose2(resize_image, PrepareForNet()) + + # get input + img_names = glob.glob(os.path.join(input_path, "*")) + num_images = len(img_names) + + # create output folder + os.makedirs(output_path, exist_ok=True) + + print("start processing") + + for ind, img_name in enumerate(img_names): + + print(" processing {} ({}/{})".format(img_name, ind + 1, num_images)) + + # input + img = utils.read_image(img_name) + img_input = transform({"image": img})["image"] + + # compute + output = model.run([output_name], {input_name: img_input.reshape(1, 3, net_h, net_w).astype(np.float32)})[0] + prediction = np.array(output).reshape(net_h, net_w) + prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC) + + # output + filename = os.path.join( + output_path, os.path.splitext(os.path.basename(img_name))[0] + ) + utils.write_depth(filename, prediction, bits=2) + + print("finished") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument('-i', '--input_path', + default='input', + help='folder with input images' + ) + + parser.add_argument('-o', '--output_path', + default='output', + help='folder for output images' + ) + + parser.add_argument('-m', '--model_weights', + default='model-f6b98070.onnx', + help='path to the trained weights of model' + ) + + parser.add_argument('-t', '--model_type', + default='large', + help='model type: large or small' + ) + + args = parser.parse_args() + + # compute depth maps + run(args.input_path, args.output_path, args.model_weights, args.model_type) diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_pb.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_pb.py new file mode 100644 index 0000000000000000000000000000000000000000..e46254f7b37f72e7d87672d70fd4b2f393ad7658 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/run_pb.py @@ -0,0 +1,135 @@ +"""Compute depth maps for images in the input folder. +""" +import os +import glob +import utils +import cv2 +import argparse + +import tensorflow as tf + +from transforms import Resize, NormalizeImage, PrepareForNet + +def run(input_path, output_path, model_path, model_type="large"): + """Run MonoDepthNN to compute depth maps. + + Args: + input_path (str): path to input folder + output_path (str): path to output folder + model_path (str): path to saved model + """ + print("initialize") + + # the runtime initialization will not allocate all memory on the device to avoid out of GPU memory + gpus = tf.config.experimental.list_physical_devices('GPU') + if gpus: + try: + for gpu in gpus: + #tf.config.experimental.set_memory_growth(gpu, True) + tf.config.experimental.set_virtual_device_configuration(gpu, + [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4000)]) + except RuntimeError as e: + print(e) + + # network resolution + if model_type == "large": + net_w, net_h = 384, 384 + elif model_type == "small": + net_w, net_h = 256, 256 + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + # load network + graph_def = tf.compat.v1.GraphDef() + with tf.io.gfile.GFile(model_path, 'rb') as f: + graph_def.ParseFromString(f.read()) + tf.import_graph_def(graph_def, name='') + + + model_operations = tf.compat.v1.get_default_graph().get_operations() + input_node = '0:0' + output_layer = model_operations[len(model_operations) - 1].name + ':0' + print("Last layer name: ", output_layer) + + resize_image = Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=False, + ensure_multiple_of=32, + resize_method="upper_bound", + image_interpolation_method=cv2.INTER_CUBIC, + ) + + def compose2(f1, f2): + return lambda x: f2(f1(x)) + + transform = compose2(resize_image, PrepareForNet()) + + # get input + img_names = glob.glob(os.path.join(input_path, "*")) + num_images = len(img_names) + + # create output folder + os.makedirs(output_path, exist_ok=True) + + print("start processing") + + with tf.compat.v1.Session() as sess: + try: + # load images + for ind, img_name in enumerate(img_names): + + print(" processing {} ({}/{})".format(img_name, ind + 1, num_images)) + + # input + img = utils.read_image(img_name) + img_input = transform({"image": img})["image"] + + # compute + prob_tensor = sess.graph.get_tensor_by_name(output_layer) + prediction, = sess.run(prob_tensor, {input_node: [img_input] }) + prediction = prediction.reshape(net_h, net_w) + prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC) + + # output + filename = os.path.join( + output_path, os.path.splitext(os.path.basename(img_name))[0] + ) + utils.write_depth(filename, prediction, bits=2) + + except KeyError: + print ("Couldn't find input node: ' + input_node + ' or output layer: " + output_layer + ".") + exit(-1) + + print("finished") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument('-i', '--input_path', + default='input', + help='folder with input images' + ) + + parser.add_argument('-o', '--output_path', + default='output', + help='folder for output images' + ) + + parser.add_argument('-m', '--model_weights', + default='model-f6b98070.pb', + help='path to the trained weights of model' + ) + + parser.add_argument('-t', '--model_type', + default='large', + help='model type: large or small' + ) + + args = parser.parse_args() + + # compute depth maps + run(args.input_path, args.output_path, args.model_weights, args.model_type) diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/transforms.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/utils.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ff9a54bd55f5e31a90fad21242efbfda5a6cc1a7 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/tf/utils.py @@ -0,0 +1,82 @@ +import numpy as np +import sys +import cv2 + + +def write_pfm(path, image, scale=1): + """Write pfm file. + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + +def read_image(path): + """Read image and output RGB image (0-1). + Args: + path (str): path to file + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = 0 + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/utils.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3976fd97dfe6a9dc7d4fa144be8fcb0b18b2db --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/base_models/midas_repo/utils.py @@ -0,0 +1,199 @@ +"""Utils for monoDepth. +""" +import sys +import re +import numpy as np +import cv2 +import torch + + +def read_pfm(path): + """Read pfm file. + + Args: + path (str): path to file + + Returns: + tuple: (data, scale) + """ + with open(path, "rb") as file: + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file: " + path) + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + # little-endian + endian = "<" + scale = -scale + else: + # big-endian + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + + return data, scale + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + + +def resize_image(img): + """Resize image and make it fit for network. + + Args: + img (array): image + + Returns: + tensor: data ready for network + """ + height_orig = img.shape[0] + width_orig = img.shape[1] + + if width_orig > height_orig: + scale = width_orig / 384 + else: + scale = height_orig / 384 + + height = (np.ceil(height_orig / scale / 32) * 32).astype(int) + width = (np.ceil(width_orig / scale / 32) * 32).astype(int) + + img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) + + img_resized = ( + torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() + ) + img_resized = img_resized.unsqueeze(0) + + return img_resized + + +def resize_depth(depth, width, height): + """Resize depth map and bring to CPU (numpy). + + Args: + depth (tensor): depth + width (int): image width + height (int): image height + + Returns: + array: processed depth + """ + depth = torch.squeeze(depth[0, :, :, :]).to("cpu") + + depth_resized = cv2.resize( + depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC + ) + + return depth_resized + +def write_depth(path, depth, grayscale, bits=1): + """Write depth map to png file. + + Args: + path (str): filepath without extension + depth (array): depth + grayscale (bool): use a grayscale colormap? + """ + if not grayscale: + bits = 1 + + if not np.isfinite(depth).all(): + depth=np.nan_to_num(depth, nan=0.0, posinf=0.0, neginf=0.0) + print("WARNING: Non-finite depth values present") + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.dtype) + + if not grayscale: + out = cv2.applyColorMap(np.uint8(out), cv2.COLORMAP_INFERNO) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/builder.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..0818311b642561712a03a66655c638ce09a04cca --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/builder.py @@ -0,0 +1,51 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +from importlib import import_module +from .depth_model import DepthModel + +def build_model(config) -> DepthModel: + """Builds a model from a config. The model is specified by the model name and version in the config. The model is then constructed using the build_from_config function of the model interface. + This function should be used to construct models for training and evaluation. + + Args: + config (dict): Config dict. Config is constructed in utils/config.py. Each model has its own config file(s) saved in its root model folder. + + Returns: + torch.nn.Module: Model corresponding to name and version as specified in config + """ + module_name = f"zoedepth.models.{config.model}" + try: + module = import_module(module_name) + except ModuleNotFoundError as e: + # print the original error message + print(e) + raise ValueError( + f"Model {config.model} not found. Refer above error for details.") from e + try: + get_version = getattr(module, "get_version") + except AttributeError as e: + raise ValueError( + f"Model {config.model} has no get_version function.") from e + return get_version(config.version_name).build_from_config(config) diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/depth_model.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/depth_model.py new file mode 100644 index 0000000000000000000000000000000000000000..fc421c108ea3928c9add62b4c190500d9bd4eda1 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/depth_model.py @@ -0,0 +1,152 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import transforms +import PIL.Image +from PIL import Image +from typing import Union + + +class DepthModel(nn.Module): + def __init__(self): + super().__init__() + self.device = 'cpu' + + def to(self, device) -> nn.Module: + self.device = device + return super().to(device) + + def forward(self, x, *args, **kwargs): + raise NotImplementedError + + def _infer(self, x: torch.Tensor): + """ + Inference interface for the model + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + return self(x)['metric_depth'] + + def _infer_with_pad_aug(self, x: torch.Tensor, pad_input: bool=True, fh: float=3, fw: float=3, upsampling_mode: str='bicubic', padding_mode="reflect", **kwargs) -> torch.Tensor: + """ + Inference interface for the model with padding augmentation + Padding augmentation fixes the boundary artifacts in the output depth map. + Boundary artifacts are sometimes caused by the fact that the model is trained on NYU raw dataset which has a black or white border around the image. + This augmentation pads the input image and crops the prediction back to the original size / view. + + Note: This augmentation is not required for the models trained with 'avoid_boundary'=True. + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + pad_input (bool, optional): whether to pad the input or not. Defaults to True. + fh (float, optional): height padding factor. The padding is calculated as sqrt(h/2) * fh. Defaults to 3. + fw (float, optional): width padding factor. The padding is calculated as sqrt(w/2) * fw. Defaults to 3. + upsampling_mode (str, optional): upsampling mode. Defaults to 'bicubic'. + padding_mode (str, optional): padding mode. Defaults to "reflect". + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + # assert x is nchw and c = 3 + assert x.dim() == 4, "x must be 4 dimensional, got {}".format(x.dim()) + assert x.shape[1] == 3, "x must have 3 channels, got {}".format(x.shape[1]) + + if pad_input: + assert fh > 0 or fw > 0, "atlease one of fh and fw must be greater than 0" + pad_h = int(np.sqrt(x.shape[2]/2) * fh) + pad_w = int(np.sqrt(x.shape[3]/2) * fw) + padding = [pad_w, pad_w] + if pad_h > 0: + padding += [pad_h, pad_h] + + x = F.pad(x, padding, mode=padding_mode, **kwargs) + out = self._infer(x) + if out.shape[-2:] != x.shape[-2:]: + out = F.interpolate(out, size=(x.shape[2], x.shape[3]), mode=upsampling_mode, align_corners=False) + if pad_input: + # crop to the original size, handling the case where pad_h and pad_w is 0 + if pad_h > 0: + out = out[:, :, pad_h:-pad_h,:] + if pad_w > 0: + out = out[:, :, :, pad_w:-pad_w] + return out + + def infer_with_flip_aug(self, x, pad_input: bool=True, **kwargs) -> torch.Tensor: + """ + Inference interface for the model with horizontal flip augmentation + Horizontal flip augmentation improves the accuracy of the model by averaging the output of the model with and without horizontal flip. + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + pad_input (bool, optional): whether to use padding augmentation. Defaults to True. + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + # infer with horizontal flip and average + out = self._infer_with_pad_aug(x, pad_input=pad_input, **kwargs) + out_flip = self._infer_with_pad_aug(torch.flip(x, dims=[3]), pad_input=pad_input, **kwargs) + out = (out + torch.flip(out_flip, dims=[3])) / 2 + return out + + def infer(self, x, pad_input: bool=True, with_flip_aug: bool=True, **kwargs) -> torch.Tensor: + """ + Inference interface for the model + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + pad_input (bool, optional): whether to use padding augmentation. Defaults to True. + with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True. + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + if with_flip_aug: + return self.infer_with_flip_aug(x, pad_input=pad_input, **kwargs) + else: + return self._infer_with_pad_aug(x, pad_input=pad_input, **kwargs) + + @torch.no_grad() + def infer_pil(self, pil_img, pad_input: bool=True, with_flip_aug: bool=True, output_type: str="numpy", **kwargs) -> Union[np.ndarray, PIL.Image.Image, torch.Tensor]: + """ + Inference interface for the model for PIL image + Args: + pil_img (PIL.Image.Image): input PIL image + pad_input (bool, optional): whether to use padding augmentation. Defaults to True. + with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True. + output_type (str, optional): output type. Supported values are 'numpy', 'pil' and 'tensor'. Defaults to "numpy". + """ + x = transforms.ToTensor()(pil_img).unsqueeze(0).to(self.device) + out_tensor = self.infer(x, pad_input=pad_input, with_flip_aug=with_flip_aug, **kwargs) + if output_type == "numpy": + return out_tensor.squeeze().cpu().numpy() + elif output_type == "pil": + # uint16 is required for depth pil image + out_16bit_numpy = (out_tensor.squeeze().cpu().numpy()*256).astype(np.uint16) + return Image.fromarray(out_16bit_numpy) + elif output_type == "tensor": + return out_tensor.squeeze().cpu() + else: + raise ValueError(f"output_type {output_type} not supported. Supported values are 'numpy', 'pil' and 'tensor'") + \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/layers/__pycache__/attractor.cpython-310.pyc b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/layers/__pycache__/attractor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ffcb227ad160a17bfa55750aef900d7d9a10968 Binary files /dev/null and b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/layers/__pycache__/attractor.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/layers/__pycache__/dist_layers.cpython-310.pyc b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/layers/__pycache__/dist_layers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5979e0b63e29d3b9df370a28f07ae739d89e43b3 Binary files /dev/null and b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/layers/__pycache__/dist_layers.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/layers/__pycache__/localbins_layers.cpython-310.pyc b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/layers/__pycache__/localbins_layers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17942f2f99651635c122a368c0382374caddfea8 Binary files /dev/null and b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/layers/__pycache__/localbins_layers.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/layers/attractor.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/layers/attractor.py new file mode 100644 index 0000000000000000000000000000000000000000..2a8efe645adea1d88a12e2ac5cc6bb2a251eef9d --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/layers/attractor.py @@ -0,0 +1,208 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +@torch.jit.script +def exp_attractor(dx, alpha: float = 300, gamma: int = 2): + """Exponential attractor: dc = exp(-alpha*|dx|^gamma) * dx , where dx = a - c, a = attractor point, c = bin center, dc = shift in bin centermmary for exp_attractor + + Args: + dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center. + alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300. + gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2. + + Returns: + torch.Tensor : Delta shifts - dc; New bin centers = Old bin centers + dc + """ + return torch.exp(-alpha*(torch.abs(dx)**gamma)) * (dx) + + +@torch.jit.script +def inv_attractor(dx, alpha: float = 300, gamma: int = 2): + """Inverse attractor: dc = dx / (1 + alpha*dx^gamma), where dx = a - c, a = attractor point, c = bin center, dc = shift in bin center + This is the default one according to the accompanying paper. + + Args: + dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center. + alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300. + gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2. + + Returns: + torch.Tensor: Delta shifts - dc; New bin centers = Old bin centers + dc + """ + return dx.div(1+alpha*dx.pow(gamma)) + + +class AttractorLayer(nn.Module): + def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10, + alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False): + """ + Attractor layer for bin centers. Bin centers are bounded on the interval (min_depth, max_depth) + """ + super().__init__() + + self.n_attractors = n_attractors + self.n_bins = n_bins + self.min_depth = min_depth + self.max_depth = max_depth + self.alpha = alpha + self.gamma = gamma + self.kind = kind + self.attractor_type = attractor_type + self.memory_efficient = memory_efficient + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_attractors*2, 1, 1, 0), # x2 for linear norm + nn.ReLU(inplace=True) + ) + + def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): + """ + Args: + x (torch.Tensor) : feature block; shape - n, c, h, w + b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w + + Returns: + tuple(torch.Tensor,torch.Tensor) : new bin centers normed and scaled; shape - n, nbins, h, w + """ + if prev_b_embedding is not None: + if interpolate: + prev_b_embedding = nn.functional.interpolate( + prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) + x = x + prev_b_embedding + + A = self._net(x) + eps = 1e-3 + A = A + eps + n, c, h, w = A.shape + A = A.view(n, self.n_attractors, 2, h, w) + A_normed = A / A.sum(dim=2, keepdim=True) # n, a, 2, h, w + A_normed = A[:, :, 0, ...] # n, na, h, w + + b_prev = nn.functional.interpolate( + b_prev, (h, w), mode='bilinear', align_corners=True) + b_centers = b_prev + + if self.attractor_type == 'exp': + dist = exp_attractor + else: + dist = inv_attractor + + if not self.memory_efficient: + func = {'mean': torch.mean, 'sum': torch.sum}[self.kind] + # .shape N, nbins, h, w + delta_c = func(dist(A_normed.unsqueeze( + 2) - b_centers.unsqueeze(1)), dim=1) + else: + delta_c = torch.zeros_like(b_centers, device=b_centers.device) + for i in range(self.n_attractors): + # .shape N, nbins, h, w + delta_c += dist(A_normed[:, i, ...].unsqueeze(1) - b_centers) + + if self.kind == 'mean': + delta_c = delta_c / self.n_attractors + + b_new_centers = b_centers + delta_c + B_centers = (self.max_depth - self.min_depth) * \ + b_new_centers + self.min_depth + B_centers, _ = torch.sort(B_centers, dim=1) + B_centers = torch.clip(B_centers, self.min_depth, self.max_depth) + return b_new_centers, B_centers + + +class AttractorLayerUnnormed(nn.Module): + def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10, + alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False): + """ + Attractor layer for bin centers. Bin centers are unbounded + """ + super().__init__() + + self.n_attractors = n_attractors + self.n_bins = n_bins + self.min_depth = min_depth + self.max_depth = max_depth + self.alpha = alpha + self.gamma = gamma + self.kind = kind + self.attractor_type = attractor_type + self.memory_efficient = memory_efficient + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_attractors, 1, 1, 0), + nn.Softplus() + ) + + def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): + """ + Args: + x (torch.Tensor) : feature block; shape - n, c, h, w + b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w + + Returns: + tuple(torch.Tensor,torch.Tensor) : new bin centers unbounded; shape - n, nbins, h, w. Two outputs just to keep the API consistent with the normed version + """ + if prev_b_embedding is not None: + if interpolate: + prev_b_embedding = nn.functional.interpolate( + prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) + x = x + prev_b_embedding + + A = self._net(x) + n, c, h, w = A.shape + + b_prev = nn.functional.interpolate( + b_prev, (h, w), mode='bilinear', align_corners=True) + b_centers = b_prev + + if self.attractor_type == 'exp': + dist = exp_attractor + else: + dist = inv_attractor + + if not self.memory_efficient: + func = {'mean': torch.mean, 'sum': torch.sum}[self.kind] + # .shape N, nbins, h, w + delta_c = func( + dist(A.unsqueeze(2) - b_centers.unsqueeze(1)), dim=1) + else: + delta_c = torch.zeros_like(b_centers, device=b_centers.device) + for i in range(self.n_attractors): + delta_c += dist(A[:, i, ...].unsqueeze(1) - + b_centers) # .shape N, nbins, h, w + + if self.kind == 'mean': + delta_c = delta_c / self.n_attractors + + b_new_centers = b_centers + delta_c + B_centers = b_new_centers + + return b_new_centers, B_centers diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/layers/dist_layers.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/layers/dist_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..3208405dfb78fdfc28d5765e5a6d5dbe31967a23 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/layers/dist_layers.py @@ -0,0 +1,121 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +def log_binom(n, k, eps=1e-7): + """ log(nCk) using stirling approximation """ + n = n + eps + k = k + eps + return n * torch.log(n) - k * torch.log(k) - (n-k) * torch.log(n-k+eps) + + +class LogBinomial(nn.Module): + def __init__(self, n_classes=256, act=torch.softmax): + """Compute log binomial distribution for n_classes + + Args: + n_classes (int, optional): number of output classes. Defaults to 256. + """ + super().__init__() + self.K = n_classes + self.act = act + self.register_buffer('k_idx', torch.arange( + 0, n_classes).view(1, -1, 1, 1)) + self.register_buffer('K_minus_1', torch.Tensor( + [self.K-1]).view(1, -1, 1, 1)) + + def forward(self, x, t=1., eps=1e-4): + """Compute log binomial distribution for x + + Args: + x (torch.Tensor - NCHW): probabilities + t (float, torch.Tensor - NCHW, optional): Temperature of distribution. Defaults to 1.. + eps (float, optional): Small number for numerical stability. Defaults to 1e-4. + + Returns: + torch.Tensor -NCHW: log binomial distribution logbinomial(p;t) + """ + if x.ndim == 3: + x = x.unsqueeze(1) # make it nchw + + one_minus_x = torch.clamp(1 - x, eps, 1) + x = torch.clamp(x, eps, 1) + y = log_binom(self.K_minus_1, self.k_idx) + self.k_idx * \ + torch.log(x) + (self.K - 1 - self.k_idx) * torch.log(one_minus_x) + return self.act(y/t, dim=1) + + +class ConditionalLogBinomial(nn.Module): + def __init__(self, in_features, condition_dim, n_classes=256, bottleneck_factor=2, p_eps=1e-4, max_temp=50, min_temp=1e-7, act=torch.softmax): + """Conditional Log Binomial distribution + + Args: + in_features (int): number of input channels in main feature + condition_dim (int): number of input channels in condition feature + n_classes (int, optional): Number of classes. Defaults to 256. + bottleneck_factor (int, optional): Hidden dim factor. Defaults to 2. + p_eps (float, optional): small eps value. Defaults to 1e-4. + max_temp (float, optional): Maximum temperature of output distribution. Defaults to 50. + min_temp (float, optional): Minimum temperature of output distribution. Defaults to 1e-7. + """ + super().__init__() + self.p_eps = p_eps + self.max_temp = max_temp + self.min_temp = min_temp + self.log_binomial_transform = LogBinomial(n_classes, act=act) + bottleneck = (in_features + condition_dim) // bottleneck_factor + self.mlp = nn.Sequential( + nn.Conv2d(in_features + condition_dim, bottleneck, + kernel_size=1, stride=1, padding=0), + nn.GELU(), + # 2 for p linear norm, 2 for t linear norm + nn.Conv2d(bottleneck, 2+2, kernel_size=1, stride=1, padding=0), + nn.Softplus() + ) + + def forward(self, x, cond): + """Forward pass + + Args: + x (torch.Tensor - NCHW): Main feature + cond (torch.Tensor - NCHW): condition feature + + Returns: + torch.Tensor: Output log binomial distribution + """ + pt = self.mlp(torch.concat((x, cond), dim=1)) + p, t = pt[:, :2, ...], pt[:, 2:, ...] + + p = p + self.p_eps + p = p[:, 0, ...] / (p[:, 0, ...] + p[:, 1, ...]) + + t = t + self.p_eps + t = t[:, 0, ...] / (t[:, 0, ...] + t[:, 1, ...]) + t = t.unsqueeze(1) + t = (self.max_temp - self.min_temp) * t + self.min_temp + + return self.log_binomial_transform(p, t) diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/layers/localbins_layers.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/layers/localbins_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..f94481605c3e6958ce50e73b2eb31d9f0c07dc67 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/layers/localbins_layers.py @@ -0,0 +1,169 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +class SeedBinRegressor(nn.Module): + def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10): + """Bin center regressor network. Bin centers are bounded on (min_depth, max_depth) interval. + + Args: + in_features (int): input channels + n_bins (int, optional): Number of bin centers. Defaults to 16. + mlp_dim (int, optional): Hidden dimension. Defaults to 256. + min_depth (float, optional): Min depth value. Defaults to 1e-3. + max_depth (float, optional): Max depth value. Defaults to 10. + """ + super().__init__() + self.version = "1_1" + self.min_depth = min_depth + self.max_depth = max_depth + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_bins, 1, 1, 0), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + """ + Returns tensor of bin_width vectors (centers). One vector b for every pixel + """ + B = self._net(x) + eps = 1e-3 + B = B + eps + B_widths_normed = B / B.sum(dim=1, keepdim=True) + B_widths = (self.max_depth - self.min_depth) * \ + B_widths_normed # .shape NCHW + # pad has the form (left, right, top, bottom, front, back) + B_widths = nn.functional.pad( + B_widths, (0, 0, 0, 0, 1, 0), mode='constant', value=self.min_depth) + B_edges = torch.cumsum(B_widths, dim=1) # .shape NCHW + + B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:, 1:, ...]) + return B_widths_normed, B_centers + + +class SeedBinRegressorUnnormed(nn.Module): + def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10): + """Bin center regressor network. Bin centers are unbounded + + Args: + in_features (int): input channels + n_bins (int, optional): Number of bin centers. Defaults to 16. + mlp_dim (int, optional): Hidden dimension. Defaults to 256. + min_depth (float, optional): Not used. (for compatibility with SeedBinRegressor) + max_depth (float, optional): Not used. (for compatibility with SeedBinRegressor) + """ + super().__init__() + self.version = "1_1" + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_bins, 1, 1, 0), + nn.Softplus() + ) + + def forward(self, x): + """ + Returns tensor of bin_width vectors (centers). One vector b for every pixel + """ + B_centers = self._net(x) + return B_centers, B_centers + + +class Projector(nn.Module): + def __init__(self, in_features, out_features, mlp_dim=128): + """Projector MLP + + Args: + in_features (int): input channels + out_features (int): output channels + mlp_dim (int, optional): hidden dimension. Defaults to 128. + """ + super().__init__() + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, out_features, 1, 1, 0), + ) + + def forward(self, x): + return self._net(x) + + + +class LinearSplitter(nn.Module): + def __init__(self, in_features, prev_nbins, split_factor=2, mlp_dim=128, min_depth=1e-3, max_depth=10): + super().__init__() + + self.prev_nbins = prev_nbins + self.split_factor = split_factor + self.min_depth = min_depth + self.max_depth = max_depth + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.GELU(), + nn.Conv2d(mlp_dim, prev_nbins * split_factor, 1, 1, 0), + nn.ReLU() + ) + + def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): + """ + x : feature block; shape - n, c, h, w + b_prev : previous bin widths normed; shape - n, prev_nbins, h, w + """ + if prev_b_embedding is not None: + if interpolate: + prev_b_embedding = nn.functional.interpolate(prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) + x = x + prev_b_embedding + S = self._net(x) + eps = 1e-3 + S = S + eps + n, c, h, w = S.shape + S = S.view(n, self.prev_nbins, self.split_factor, h, w) + S_normed = S / S.sum(dim=2, keepdim=True) # fractional splits + + b_prev = nn.functional.interpolate(b_prev, (h,w), mode='bilinear', align_corners=True) + + + b_prev = b_prev / b_prev.sum(dim=1, keepdim=True) # renormalize for gurantees + # print(b_prev.shape, S_normed.shape) + # if is_for_query:(1).expand(-1, b_prev.size(0)//n, -1, -1, -1, -1).flatten(0,1) # TODO ? can replace all this with a single torch.repeat? + b = b_prev.unsqueeze(2) * S_normed + b = b.flatten(1,2) # .shape n, prev_nbins * split_factor, h, w + + # calculate bin centers for loss calculation + B_widths = (self.max_depth - self.min_depth) * b # .shape N, nprev * splitfactor, H, W + # pad has the form (left, right, top, bottom, front, back) + B_widths = nn.functional.pad(B_widths, (0,0,0,0,1,0), mode='constant', value=self.min_depth) + B_edges = torch.cumsum(B_widths, dim=1) # .shape NCHW + + B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:,1:,...]) + return b, B_centers \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/layers/patch_transformer.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/layers/patch_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..99d9e51a06b981bae45ce7dd64eaef19a4121991 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/layers/patch_transformer.py @@ -0,0 +1,91 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +class PatchTransformerEncoder(nn.Module): + def __init__(self, in_channels, patch_size=10, embedding_dim=128, num_heads=4, use_class_token=False): + """ViT-like transformer block + + Args: + in_channels (int): Input channels + patch_size (int, optional): patch size. Defaults to 10. + embedding_dim (int, optional): Embedding dimension in transformer model. Defaults to 128. + num_heads (int, optional): number of attention heads. Defaults to 4. + use_class_token (bool, optional): Whether to use extra token at the start for global accumulation (called as "class token"). Defaults to False. + """ + super(PatchTransformerEncoder, self).__init__() + self.use_class_token = use_class_token + encoder_layers = nn.TransformerEncoderLayer( + embedding_dim, num_heads, dim_feedforward=1024) + self.transformer_encoder = nn.TransformerEncoder( + encoder_layers, num_layers=4) # takes shape S,N,E + + self.embedding_convPxP = nn.Conv2d(in_channels, embedding_dim, + kernel_size=patch_size, stride=patch_size, padding=0) + + def positional_encoding_1d(self, sequence_length, batch_size, embedding_dim, device='cpu'): + """Generate positional encodings + + Args: + sequence_length (int): Sequence length + embedding_dim (int): Embedding dimension + + Returns: + torch.Tensor SBE: Positional encodings + """ + position = torch.arange( + 0, sequence_length, dtype=torch.float32, device=device).unsqueeze(1) + index = torch.arange( + 0, embedding_dim, 2, dtype=torch.float32, device=device).unsqueeze(0) + div_term = torch.exp(index * (-torch.log(torch.tensor(10000.0, device=device)) / embedding_dim)) + pos_encoding = position * div_term + pos_encoding = torch.cat([torch.sin(pos_encoding), torch.cos(pos_encoding)], dim=1) + pos_encoding = pos_encoding.unsqueeze(1).repeat(1, batch_size, 1) + return pos_encoding + + + def forward(self, x): + """Forward pass + + Args: + x (torch.Tensor - NCHW): Input feature tensor + + Returns: + torch.Tensor - SNE: Transformer output embeddings. S - sequence length (=HW/patch_size^2), N - batch size, E - embedding dim + """ + embeddings = self.embedding_convPxP(x).flatten( + 2) # .shape = n,c,s = n, embedding_dim, s + if self.use_class_token: + # extra special token at start ? + embeddings = nn.functional.pad(embeddings, (1, 0)) + + # change to S,N,E format required by transformer + embeddings = embeddings.permute(2, 0, 1) + S, N, E = embeddings.shape + embeddings = embeddings + self.positional_encoding_1d(S, N, E, device=embeddings.device) + x = self.transformer_encoder(embeddings) # .shape = S, N, E + return x diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/model_io.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/model_io.py new file mode 100644 index 0000000000000000000000000000000000000000..78b6579631dd847ac76651238cb5a948b5a66286 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/model_io.py @@ -0,0 +1,92 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch + +def load_state_dict(model, state_dict): + """Load state_dict into model, handling DataParallel and DistributedDataParallel. Also checks for "model" key in state_dict. + + DataParallel prefixes state_dict keys with 'module.' when saving. + If the model is not a DataParallel model but the state_dict is, then prefixes are removed. + If the model is a DataParallel model but the state_dict is not, then prefixes are added. + """ + state_dict = state_dict.get('model', state_dict) + # if model is a DataParallel model, then state_dict keys are prefixed with 'module.' + + do_prefix = isinstance( + model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)) + state = {} + for k, v in state_dict.items(): + if k.startswith('module.') and not do_prefix: + k = k[7:] + + if not k.startswith('module.') and do_prefix: + k = 'module.' + k + + state[k] = v + + model.load_state_dict(state) + print("Loaded successfully") + return model + + +def load_wts(model, checkpoint_path): + ckpt = torch.load(checkpoint_path, map_location='cpu') + return load_state_dict(model, ckpt) + + +def load_state_dict_from_url(model, url, **kwargs): + state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu', **kwargs) + return load_state_dict(model, state_dict) + + +def load_state_from_resource(model, resource: str): + """Loads weights to the model from a given resource. A resource can be of following types: + 1. URL. Prefixed with "url::" + e.g. url::http(s)://url.resource.com/ckpt.pt + + 2. Local path. Prefixed with "local::" + e.g. local::/path/to/ckpt.pt + + + Args: + model (torch.nn.Module): Model + resource (str): resource string + + Returns: + torch.nn.Module: Model with loaded weights + """ + print(f"Using pretrained resource {resource}") + + if resource.startswith('url::'): + url = resource.split('url::')[1] + return load_state_dict_from_url(model, url, progress=True) + + elif resource.startswith('local::'): + path = resource.split('local::')[1] + return load_wts(model, path) + + else: + raise ValueError("Invalid resource type, only url:: and local:: are supported") + \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth/__init__.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cc33f737d238766559f0e3a8def3c0b568f23b7f --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth/__init__.py @@ -0,0 +1,31 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +from .zoedepth_v1 import ZoeDepth + +all_versions = { + "v1": ZoeDepth, +} + +get_version = lambda v : all_versions[v] \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth/__pycache__/__init__.cpython-310.pyc b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c24ae3285d3a721d7825ed4390fdc6be9a35e3b5 Binary files /dev/null and b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth/__pycache__/zoedepth_v1.cpython-310.pyc b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth/__pycache__/zoedepth_v1.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8845513acac40cec23a691d80cb8177d012d9dab Binary files /dev/null and b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth/__pycache__/zoedepth_v1.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth.json b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth.json new file mode 100644 index 0000000000000000000000000000000000000000..3112ed78c89f00e1d13f5d6e5be87cd3216b6dc7 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth.json @@ -0,0 +1,58 @@ +{ + "model": { + "name": "ZoeDepth", + "version_name": "v1", + "n_bins": 64, + "bin_embedding_dim": 128, + "bin_centers_type": "softplus", + "n_attractors":[16, 8, 4, 1], + "attractor_alpha": 1000, + "attractor_gamma": 2, + "attractor_kind" : "mean", + "attractor_type" : "inv", + "midas_model_type" : "DPT_BEiT_L_384", + "min_temp": 0.0212, + "max_temp": 50.0, + "output_distribution": "logbinomial", + "memory_efficient": true, + "inverse_midas": false, + "img_size": [384, 512] + }, + + "train": { + "train_midas": true, + "use_pretrained_midas": true, + "trainer": "zoedepth", + "epochs": 5, + "bs": 16, + "optim_kwargs": {"lr": 0.000161, "wd": 0.01}, + "sched_kwargs": {"div_factor": 1, "final_div_factor": 10000, "pct_start": 0.7, "three_phase":false, "cycle_momentum": true}, + "same_lr": false, + "w_si": 1, + "w_domain": 0.2, + "w_reg": 0, + "w_grad": 0, + "avoid_boundary": false, + "random_crop": false, + "input_width": 640, + "input_height": 480, + "midas_lr_factor": 1, + "encoder_lr_factor":10, + "pos_enc_lr_factor":10, + "freeze_midas_bn": true + + }, + + "infer":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : null, + "force_keep_ar": true + }, + + "eval":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : null + } +} \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json new file mode 100644 index 0000000000000000000000000000000000000000..b51802aa44b91c39e15aacaac4b5ab6bec884414 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json @@ -0,0 +1,22 @@ +{ + "model": { + "bin_centers_type": "normed", + "img_size": [384, 768] + }, + + "train": { + }, + + "infer":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt", + "force_keep_ar": true + }, + + "eval":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt" + } +} \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth/zoedepth_v1.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth/zoedepth_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..bc931b059d6165c84e8ff4f09d5c62d19930cee9 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth/zoedepth_v1.py @@ -0,0 +1,250 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import itertools + +import torch +import torch.nn as nn +from ..depth_model import DepthModel +from ..base_models.midas import MidasCore +from ..layers.attractor import AttractorLayer, AttractorLayerUnnormed +from ..layers.dist_layers import ConditionalLogBinomial +from ..layers.localbins_layers import (Projector, SeedBinRegressor, + SeedBinRegressorUnnormed) +from ..model_io import load_state_from_resource + + +class ZoeDepth(DepthModel): + def __init__(self, core, n_bins=64, bin_centers_type="softplus", bin_embedding_dim=128, min_depth=1e-3, max_depth=10, + n_attractors=[16, 8, 4, 1], attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp', min_temp=5, max_temp=50, train_midas=True, + midas_lr_factor=10, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, **kwargs): + """ZoeDepth model. This is the version of ZoeDepth that has a single metric head + + Args: + core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features + n_bins (int, optional): Number of bin centers. Defaults to 64. + bin_centers_type (str, optional): "normed" or "softplus". Activation type used for bin centers. For "normed" bin centers, linear normalization trick is applied. This results in bounded bin centers. + For "softplus", softplus activation is used and thus are unbounded. Defaults to "softplus". + bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128. + min_depth (float, optional): Lower bound for normed bin centers. Defaults to 1e-3. + max_depth (float, optional): Upper bound for normed bin centers. Defaults to 10. + n_attractors (List[int], optional): Number of bin attractors at decoder layers. Defaults to [16, 8, 4, 1]. + attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300. + attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2. + attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'. + attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'. + min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5. + max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50. + train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True. + midas_lr_factor (int, optional): Learning rate reduction factor for base midas model except its encoder and positional encodings. Defaults to 10. + encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10. + pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10. + """ + super().__init__() + + self.core = core + self.max_depth = max_depth + self.min_depth = min_depth + self.min_temp = min_temp + self.bin_centers_type = bin_centers_type + + self.midas_lr_factor = midas_lr_factor + self.encoder_lr_factor = encoder_lr_factor + self.pos_enc_lr_factor = pos_enc_lr_factor + self.train_midas = train_midas + self.inverse_midas = inverse_midas + + if self.encoder_lr_factor <= 0: + self.core.freeze_encoder( + freeze_rel_pos=self.pos_enc_lr_factor <= 0) + + N_MIDAS_OUT = 32 + btlnck_features = self.core.output_channels[0] + num_out_features = self.core.output_channels[1:] + + self.conv2 = nn.Conv2d(btlnck_features, btlnck_features, + kernel_size=1, stride=1, padding=0) # btlnck conv + + if bin_centers_type == "normed": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayer + elif bin_centers_type == "softplus": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid1": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid2": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayer + else: + raise ValueError( + "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'") + + self.seed_bin_regressor = SeedBinRegressorLayer( + btlnck_features, n_bins=n_bins, min_depth=min_depth, max_depth=max_depth) + self.seed_projector = Projector(btlnck_features, bin_embedding_dim) + self.projectors = nn.ModuleList([ + Projector(num_out, bin_embedding_dim) + for num_out in num_out_features + ]) + self.attractors = nn.ModuleList([ + Attractor(bin_embedding_dim, n_bins, n_attractors=n_attractors[i], min_depth=min_depth, max_depth=max_depth, + alpha=attractor_alpha, gamma=attractor_gamma, kind=attractor_kind, attractor_type=attractor_type) + for i in range(len(num_out_features)) + ]) + + last_in = N_MIDAS_OUT + 1 # +1 for relative depth + + # use log binomial instead of softmax + self.conditional_log_binomial = ConditionalLogBinomial( + last_in, bin_embedding_dim, n_classes=n_bins, min_temp=min_temp, max_temp=max_temp) + + def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs): + """ + Args: + x (torch.Tensor): Input image tensor of shape (B, C, H, W) + return_final_centers (bool, optional): Whether to return the final bin centers. Defaults to False. + denorm (bool, optional): Whether to denormalize the input image. This reverses ImageNet normalization as midas normalization is different. Defaults to False. + return_probs (bool, optional): Whether to return the output probability distribution. Defaults to False. + + Returns: + dict: Dictionary containing the following keys: + - rel_depth (torch.Tensor): Relative depth map of shape (B, H, W) + - metric_depth (torch.Tensor): Metric depth map of shape (B, 1, H, W) + - bin_centers (torch.Tensor): Bin centers of shape (B, n_bins). Present only if return_final_centers is True + - probs (torch.Tensor): Output probability distribution of shape (B, n_bins, H, W). Present only if return_probs is True + + """ + b, c, h, w = x.shape + # print("input shape ", x.shape) + self.orig_input_width = w + self.orig_input_height = h + rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True) + # print("output shapes", rel_depth.shape, out.shape) + + outconv_activation = out[0] + btlnck = out[1] + x_blocks = out[2:] + + x_d0 = self.conv2(btlnck) + x = x_d0 + _, seed_b_centers = self.seed_bin_regressor(x) + + if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2': + b_prev = (seed_b_centers - self.min_depth) / \ + (self.max_depth - self.min_depth) + else: + b_prev = seed_b_centers + + prev_b_embedding = self.seed_projector(x) + + # unroll this loop for better performance + for projector, attractor, x in zip(self.projectors, self.attractors, x_blocks): + b_embedding = projector(x) + b, b_centers = attractor( + b_embedding, b_prev, prev_b_embedding, interpolate=True) + b_prev = b.clone() + prev_b_embedding = b_embedding.clone() + + last = outconv_activation + + if self.inverse_midas: + # invert depth followed by normalization + rel_depth = 1.0 / (rel_depth + 1e-6) + rel_depth = (rel_depth - rel_depth.min()) / \ + (rel_depth.max() - rel_depth.min()) + # concat rel depth with last. First interpolate rel depth to last size + rel_cond = rel_depth.unsqueeze(1) + rel_cond = nn.functional.interpolate( + rel_cond, size=last.shape[2:], mode='bilinear', align_corners=True) + last = torch.cat([last, rel_cond], dim=1) + + b_embedding = nn.functional.interpolate( + b_embedding, last.shape[-2:], mode='bilinear', align_corners=True) + x = self.conditional_log_binomial(last, b_embedding) + + # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor + # print(x.shape, b_centers.shape) + b_centers = nn.functional.interpolate( + b_centers, x.shape[-2:], mode='bilinear', align_corners=True) + out = torch.sum(x * b_centers, dim=1, keepdim=True) + + # Structure output dict + output = dict(metric_depth=out) + if return_final_centers or return_probs: + output['bin_centers'] = b_centers + + if return_probs: + output['probs'] = x + + return output + + def get_lr_params(self, lr): + """ + Learning rate configuration for different layers of the model + Args: + lr (float) : Base learning rate + Returns: + list : list of parameters to optimize and their learning rates, in the format required by torch optimizers. + """ + param_conf = [] + if self.train_midas: + if self.encoder_lr_factor > 0: + param_conf.append({'params': self.core.get_enc_params_except_rel_pos( + ), 'lr': lr / self.encoder_lr_factor}) + + if self.pos_enc_lr_factor > 0: + param_conf.append( + {'params': self.core.get_rel_pos_params(), 'lr': lr / self.pos_enc_lr_factor}) + + midas_params = self.core.core.scratch.parameters() + midas_lr_factor = self.midas_lr_factor + param_conf.append( + {'params': midas_params, 'lr': lr / midas_lr_factor}) + + remaining_modules = [] + for name, child in self.named_children(): + if name != 'core': + remaining_modules.append(child) + remaining_params = itertools.chain( + *[child.parameters() for child in remaining_modules]) + + param_conf.append({'params': remaining_params, 'lr': lr}) + + return param_conf + + @staticmethod + def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs): + core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas, + train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs) + model = ZoeDepth(core, **kwargs) + if pretrained_resource: + assert isinstance(pretrained_resource, str), "pretrained_resource must be a string" + model = load_state_from_resource(model, pretrained_resource) + return model + + @staticmethod + def build_from_config(config): + return ZoeDepth.build(**config) diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth_nk/__init__.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth_nk/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..513a278b939c10c010e3c0250ec73544d5663886 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth_nk/__init__.py @@ -0,0 +1,31 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +from .zoedepth_nk_v1 import ZoeDepthNK + +all_versions = { + "v1": ZoeDepthNK, +} + +get_version = lambda v : all_versions[v] \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json new file mode 100644 index 0000000000000000000000000000000000000000..42bab2a3ad159a09599a5aba270c491021a3cf1a --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json @@ -0,0 +1,67 @@ +{ + "model": { + "name": "ZoeDepthNK", + "version_name": "v1", + "bin_conf" : [ + { + "name": "nyu", + "n_bins": 64, + "min_depth": 1e-3, + "max_depth": 10.0 + }, + { + "name": "kitti", + "n_bins": 64, + "min_depth": 1e-3, + "max_depth": 80.0 + } + ], + "bin_embedding_dim": 128, + "bin_centers_type": "softplus", + "n_attractors":[16, 8, 4, 1], + "attractor_alpha": 1000, + "attractor_gamma": 2, + "attractor_kind" : "mean", + "attractor_type" : "inv", + "min_temp": 0.0212, + "max_temp": 50.0, + "memory_efficient": true, + "midas_model_type" : "DPT_BEiT_L_384", + "img_size": [384, 512] + }, + + "train": { + "train_midas": true, + "use_pretrained_midas": true, + "trainer": "zoedepth_nk", + "epochs": 5, + "bs": 16, + "optim_kwargs": {"lr": 0.0002512, "wd": 0.01}, + "sched_kwargs": {"div_factor": 1, "final_div_factor": 10000, "pct_start": 0.7, "three_phase":false, "cycle_momentum": true}, + "same_lr": false, + "w_si": 1, + "w_domain": 100, + "avoid_boundary": false, + "random_crop": false, + "input_width": 640, + "input_height": 480, + "w_grad": 0, + "w_reg": 0, + "midas_lr_factor": 10, + "encoder_lr_factor":10, + "pos_enc_lr_factor":10 + }, + + "infer": { + "train_midas": false, + "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt", + "use_pretrained_midas": false, + "force_keep_ar": true + }, + + "eval": { + "train_midas": false, + "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt", + "use_pretrained_midas": false + } +} \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..7368ae8031188a9f946d9d3f29633c96e791e68e --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py @@ -0,0 +1,333 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import itertools + +import torch +import torch.nn as nn + +from zoedepth.models.depth_model import DepthModel +from zoedepth.models.base_models.midas import MidasCore +from zoedepth.models.layers.attractor import AttractorLayer, AttractorLayerUnnormed +from zoedepth.models.layers.dist_layers import ConditionalLogBinomial +from zoedepth.models.layers.localbins_layers import (Projector, SeedBinRegressor, + SeedBinRegressorUnnormed) +from zoedepth.models.layers.patch_transformer import PatchTransformerEncoder +from zoedepth.models.model_io import load_state_from_resource + + +class ZoeDepthNK(DepthModel): + def __init__(self, core, bin_conf, bin_centers_type="softplus", bin_embedding_dim=128, + n_attractors=[16, 8, 4, 1], attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp', + min_temp=5, max_temp=50, + memory_efficient=False, train_midas=True, + is_midas_pretrained=True, midas_lr_factor=1, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, **kwargs): + """ZoeDepthNK model. This is the version of ZoeDepth that has two metric heads and uses a learned router to route to experts. + + Args: + core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features + + bin_conf (List[dict]): A list of dictionaries that contain the bin configuration for each metric head. Each dictionary should contain the following keys: + "name" (str, typically same as the dataset name), "n_bins" (int), "min_depth" (float), "max_depth" (float) + + The length of this list determines the number of metric heads. + bin_centers_type (str, optional): "normed" or "softplus". Activation type used for bin centers. For "normed" bin centers, linear normalization trick is applied. This results in bounded bin centers. + For "softplus", softplus activation is used and thus are unbounded. Defaults to "normed". + bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128. + + n_attractors (List[int], optional): Number of bin attractors at decoder layers. Defaults to [16, 8, 4, 1]. + attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300. + attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2. + attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'. + attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'. + + min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5. + max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50. + + memory_efficient (bool, optional): Whether to use memory efficient version of attractor layers. Memory efficient version is slower but is recommended incase of multiple metric heads in order save GPU memory. Defaults to False. + + train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True. + is_midas_pretrained (bool, optional): Is "core" pretrained? Defaults to True. + midas_lr_factor (int, optional): Learning rate reduction factor for base midas model except its encoder and positional encodings. Defaults to 10. + encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10. + pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10. + + """ + + super().__init__() + + self.core = core + self.bin_conf = bin_conf + self.min_temp = min_temp + self.max_temp = max_temp + self.memory_efficient = memory_efficient + self.train_midas = train_midas + self.is_midas_pretrained = is_midas_pretrained + self.midas_lr_factor = midas_lr_factor + self.encoder_lr_factor = encoder_lr_factor + self.pos_enc_lr_factor = pos_enc_lr_factor + self.inverse_midas = inverse_midas + + N_MIDAS_OUT = 32 + btlnck_features = self.core.output_channels[0] + num_out_features = self.core.output_channels[1:] + # self.scales = [16, 8, 4, 2] # spatial scale factors + + self.conv2 = nn.Conv2d( + btlnck_features, btlnck_features, kernel_size=1, stride=1, padding=0) + + # Transformer classifier on the bottleneck + self.patch_transformer = PatchTransformerEncoder( + btlnck_features, 1, 128, use_class_token=True) + self.mlp_classifier = nn.Sequential( + nn.Linear(128, 128), + nn.ReLU(), + nn.Linear(128, 2) + ) + + if bin_centers_type == "normed": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayer + elif bin_centers_type == "softplus": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid1": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid2": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayer + else: + raise ValueError( + "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'") + self.bin_centers_type = bin_centers_type + # We have bins for each bin conf. + # Create a map (ModuleDict) of 'name' -> seed_bin_regressor + self.seed_bin_regressors = nn.ModuleDict( + {conf['name']: SeedBinRegressorLayer(btlnck_features, conf["n_bins"], mlp_dim=bin_embedding_dim//2, min_depth=conf["min_depth"], max_depth=conf["max_depth"]) + for conf in bin_conf} + ) + + self.seed_projector = Projector( + btlnck_features, bin_embedding_dim, mlp_dim=bin_embedding_dim//2) + self.projectors = nn.ModuleList([ + Projector(num_out, bin_embedding_dim, mlp_dim=bin_embedding_dim//2) + for num_out in num_out_features + ]) + + # Create a map (ModuleDict) of 'name' -> attractors (ModuleList) + self.attractors = nn.ModuleDict( + {conf['name']: nn.ModuleList([ + Attractor(bin_embedding_dim, n_attractors[i], + mlp_dim=bin_embedding_dim, alpha=attractor_alpha, + gamma=attractor_gamma, kind=attractor_kind, + attractor_type=attractor_type, memory_efficient=memory_efficient, + min_depth=conf["min_depth"], max_depth=conf["max_depth"]) + for i in range(len(n_attractors)) + ]) + for conf in bin_conf} + ) + + last_in = N_MIDAS_OUT + # conditional log binomial for each bin conf + self.conditional_log_binomial = nn.ModuleDict( + {conf['name']: ConditionalLogBinomial(last_in, bin_embedding_dim, conf['n_bins'], bottleneck_factor=4, min_temp=self.min_temp, max_temp=self.max_temp) + for conf in bin_conf} + ) + + def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs): + """ + Args: + x (torch.Tensor): Input image tensor of shape (B, C, H, W). Assumes all images are from the same domain. + return_final_centers (bool, optional): Whether to return the final centers of the attractors. Defaults to False. + denorm (bool, optional): Whether to denormalize the input image. Defaults to False. + return_probs (bool, optional): Whether to return the probabilities of the bins. Defaults to False. + + Returns: + dict: Dictionary of outputs with keys: + - "rel_depth": Relative depth map of shape (B, 1, H, W) + - "metric_depth": Metric depth map of shape (B, 1, H, W) + - "domain_logits": Domain logits of shape (B, 2) + - "bin_centers": Bin centers of shape (B, N, H, W). Present only if return_final_centers is True + - "probs": Bin probabilities of shape (B, N, H, W). Present only if return_probs is True + """ + b, c, h, w = x.shape + self.orig_input_width = w + self.orig_input_height = h + rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True) + + outconv_activation = out[0] + btlnck = out[1] + x_blocks = out[2:] + + x_d0 = self.conv2(btlnck) + x = x_d0 + + # Predict which path to take + embedding = self.patch_transformer(x)[0] # N, E + domain_logits = self.mlp_classifier(embedding) # N, 2 + domain_vote = torch.softmax(domain_logits.sum( + dim=0, keepdim=True), dim=-1) # 1, 2 + + # Get the path + bin_conf_name = ["nyu", "kitti"][torch.argmax( + domain_vote, dim=-1).squeeze().item()] + + try: + conf = [c for c in self.bin_conf if c.name == bin_conf_name][0] + except IndexError: + raise ValueError( + f"bin_conf_name {bin_conf_name} not found in bin_confs") + + min_depth = conf['min_depth'] + max_depth = conf['max_depth'] + + seed_bin_regressor = self.seed_bin_regressors[bin_conf_name] + _, seed_b_centers = seed_bin_regressor(x) + if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2': + b_prev = (seed_b_centers - min_depth)/(max_depth - min_depth) + else: + b_prev = seed_b_centers + prev_b_embedding = self.seed_projector(x) + + attractors = self.attractors[bin_conf_name] + for projector, attractor, x in zip(self.projectors, attractors, x_blocks): + b_embedding = projector(x) + b, b_centers = attractor( + b_embedding, b_prev, prev_b_embedding, interpolate=True) + b_prev = b + prev_b_embedding = b_embedding + + last = outconv_activation + + b_centers = nn.functional.interpolate( + b_centers, last.shape[-2:], mode='bilinear', align_corners=True) + b_embedding = nn.functional.interpolate( + b_embedding, last.shape[-2:], mode='bilinear', align_corners=True) + + clb = self.conditional_log_binomial[bin_conf_name] + x = clb(last, b_embedding) + + # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor + # print(x.shape, b_centers.shape) + # b_centers = nn.functional.interpolate(b_centers, x.shape[-2:], mode='bilinear', align_corners=True) + out = torch.sum(x * b_centers, dim=1, keepdim=True) + + output = dict(domain_logits=domain_logits, metric_depth=out) + if return_final_centers or return_probs: + output['bin_centers'] = b_centers + + if return_probs: + output['probs'] = x + return output + + def get_lr_params(self, lr): + """ + Learning rate configuration for different layers of the model + + Args: + lr (float) : Base learning rate + Returns: + list : list of parameters to optimize and their learning rates, in the format required by torch optimizers. + """ + param_conf = [] + if self.train_midas: + def get_rel_pos_params(): + for name, p in self.core.core.pretrained.named_parameters(): + if "relative_position" in name: + yield p + + def get_enc_params_except_rel_pos(): + for name, p in self.core.core.pretrained.named_parameters(): + if "relative_position" not in name: + yield p + + encoder_params = get_enc_params_except_rel_pos() + rel_pos_params = get_rel_pos_params() + midas_params = self.core.core.scratch.parameters() + midas_lr_factor = self.midas_lr_factor if self.is_midas_pretrained else 1.0 + param_conf.extend([ + {'params': encoder_params, 'lr': lr / self.encoder_lr_factor}, + {'params': rel_pos_params, 'lr': lr / self.pos_enc_lr_factor}, + {'params': midas_params, 'lr': lr / midas_lr_factor} + ]) + + remaining_modules = [] + for name, child in self.named_children(): + if name != 'core': + remaining_modules.append(child) + remaining_params = itertools.chain( + *[child.parameters() for child in remaining_modules]) + param_conf.append({'params': remaining_params, 'lr': lr}) + return param_conf + + def get_conf_parameters(self, conf_name): + """ + Returns parameters of all the ModuleDicts children that are exclusively used for the given bin configuration + """ + params = [] + for name, child in self.named_children(): + if isinstance(child, nn.ModuleDict): + for bin_conf_name, module in child.items(): + if bin_conf_name == conf_name: + params += list(module.parameters()) + return params + + def freeze_conf(self, conf_name): + """ + Freezes all the parameters of all the ModuleDicts children that are exclusively used for the given bin configuration + """ + for p in self.get_conf_parameters(conf_name): + p.requires_grad = False + + def unfreeze_conf(self, conf_name): + """ + Unfreezes all the parameters of all the ModuleDicts children that are exclusively used for the given bin configuration + """ + for p in self.get_conf_parameters(conf_name): + p.requires_grad = True + + def freeze_all_confs(self): + """ + Freezes all the parameters of all the ModuleDicts children + """ + for name, child in self.named_children(): + if isinstance(child, nn.ModuleDict): + for bin_conf_name, module in child.items(): + for p in module.parameters(): + p.requires_grad = False + + @staticmethod + def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs): + core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas, + train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs) + model = ZoeDepthNK(core, **kwargs) + if pretrained_resource: + assert isinstance(pretrained_resource, str), "pretrained_resource must be a string" + model = load_state_from_resource(model, pretrained_resource) + return model + + @staticmethod + def build_from_config(config): + return ZoeDepthNK.build(**config) diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/trainers/base_trainer.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/trainers/base_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..33fbbea3a7d49efe11b005adb5127f441eabfaf6 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/trainers/base_trainer.py @@ -0,0 +1,326 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import os +import uuid +import warnings +from datetime import datetime as dt +from typing import Dict + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.optim as optim +import wandb +from tqdm import tqdm + +from zoedepth.utils.config import flatten +from zoedepth.utils.misc import RunningAverageDict, colorize, colors + + +def is_rank_zero(args): + return args.rank == 0 + + +class BaseTrainer: + def __init__(self, config, model, train_loader, test_loader=None, device=None): + """ Base Trainer class for training a model.""" + + self.config = config + self.metric_criterion = "abs_rel" + if device is None: + device = torch.device( + 'cuda') if torch.cuda.is_available() else torch.device('cpu') + self.device = device + self.model = model + self.train_loader = train_loader + self.test_loader = test_loader + self.optimizer = self.init_optimizer() + self.scheduler = self.init_scheduler() + + def resize_to_target(self, prediction, target): + if prediction.shape[2:] != target.shape[-2:]: + prediction = nn.functional.interpolate( + prediction, size=target.shape[-2:], mode="bilinear", align_corners=True + ) + return prediction + + def load_ckpt(self, checkpoint_dir="./checkpoints", ckpt_type="best"): + import glob + import os + + from zoedepth.models.model_io import load_wts + + if hasattr(self.config, "checkpoint"): + checkpoint = self.config.checkpoint + elif hasattr(self.config, "ckpt_pattern"): + pattern = self.config.ckpt_pattern + matches = glob.glob(os.path.join( + checkpoint_dir, f"*{pattern}*{ckpt_type}*")) + if not (len(matches) > 0): + raise ValueError(f"No matches found for the pattern {pattern}") + checkpoint = matches[0] + else: + return + model = load_wts(self.model, checkpoint) + # TODO : Resuming training is not properly supported in this repo. Implement loading / saving of optimizer and scheduler to support it. + print("Loaded weights from {0}".format(checkpoint)) + warnings.warn( + "Resuming training is not properly supported in this repo. Implement loading / saving of optimizer and scheduler to support it.") + self.model = model + + def init_optimizer(self): + m = self.model.module if self.config.multigpu else self.model + + if self.config.same_lr: + print("Using same LR") + if hasattr(m, 'core'): + m.core.unfreeze() + params = self.model.parameters() + else: + print("Using diff LR") + if not hasattr(m, 'get_lr_params'): + raise NotImplementedError( + f"Model {m.__class__.__name__} does not implement get_lr_params. Please implement it or use the same LR for all parameters.") + + params = m.get_lr_params(self.config.lr) + + return optim.AdamW(params, lr=self.config.lr, weight_decay=self.config.wd) + + def init_scheduler(self): + lrs = [l['lr'] for l in self.optimizer.param_groups] + return optim.lr_scheduler.OneCycleLR(self.optimizer, lrs, epochs=self.config.epochs, steps_per_epoch=len(self.train_loader), + cycle_momentum=self.config.cycle_momentum, + base_momentum=0.85, max_momentum=0.95, div_factor=self.config.div_factor, final_div_factor=self.config.final_div_factor, pct_start=self.config.pct_start, three_phase=self.config.three_phase) + + def train_on_batch(self, batch, train_step): + raise NotImplementedError + + def validate_on_batch(self, batch, val_step): + raise NotImplementedError + + def raise_if_nan(self, losses): + for key, value in losses.items(): + if torch.isnan(value): + raise ValueError(f"{key} is NaN, Stopping training") + + @property + def iters_per_epoch(self): + return len(self.train_loader) + + @property + def total_iters(self): + return self.config.epochs * self.iters_per_epoch + + def should_early_stop(self): + if self.config.get('early_stop', False) and self.step > self.config.early_stop: + return True + + def train(self): + print(f"Training {self.config.name}") + if self.config.uid is None: + self.config.uid = str(uuid.uuid4()).split('-')[-1] + run_id = f"{dt.now().strftime('%d-%h_%H-%M')}-{self.config.uid}" + self.config.run_id = run_id + self.config.experiment_id = f"{self.config.name}{self.config.version_name}_{run_id}" + self.should_write = ((not self.config.distributed) + or self.config.rank == 0) + self.should_log = self.should_write # and logging + if self.should_log: + tags = self.config.tags.split( + ',') if self.config.tags != '' else None + wandb.init(project=self.config.project, name=self.config.experiment_id, config=flatten(self.config), dir=self.config.root, + tags=tags, notes=self.config.notes, settings=wandb.Settings(start_method="fork")) + + self.model.train() + self.step = 0 + best_loss = np.inf + validate_every = int(self.config.validate_every * self.iters_per_epoch) + + + if self.config.prefetch: + + for i, batch in tqdm(enumerate(self.train_loader), desc=f"Prefetching...", + total=self.iters_per_epoch) if is_rank_zero(self.config) else enumerate(self.train_loader): + pass + + losses = {} + def stringify_losses(L): return "; ".join(map( + lambda kv: f"{colors.fg.purple}{kv[0]}{colors.reset}: {round(kv[1].item(),3):.4e}", L.items())) + for epoch in range(self.config.epochs): + if self.should_early_stop(): + break + + self.epoch = epoch + ################################# Train loop ########################################################## + if self.should_log: + wandb.log({"Epoch": epoch}, step=self.step) + pbar = tqdm(enumerate(self.train_loader), desc=f"Epoch: {epoch + 1}/{self.config.epochs}. Loop: Train", + total=self.iters_per_epoch) if is_rank_zero(self.config) else enumerate(self.train_loader) + for i, batch in pbar: + if self.should_early_stop(): + print("Early stopping") + break + # print(f"Batch {self.step+1} on rank {self.config.rank}") + losses = self.train_on_batch(batch, i) + # print(f"trained batch {self.step+1} on rank {self.config.rank}") + + self.raise_if_nan(losses) + if is_rank_zero(self.config) and self.config.print_losses: + pbar.set_description( + f"Epoch: {epoch + 1}/{self.config.epochs}. Loop: Train. Losses: {stringify_losses(losses)}") + self.scheduler.step() + + if self.should_log and self.step % 50 == 0: + wandb.log({f"Train/{name}": loss.item() + for name, loss in losses.items()}, step=self.step) + + self.step += 1 + + ######################################################################################################## + + if self.test_loader: + if (self.step % validate_every) == 0: + self.model.eval() + if self.should_write: + self.save_checkpoint( + f"{self.config.experiment_id}_latest.pt") + + ################################# Validation loop ################################################## + # validate on the entire validation set in every process but save only from rank 0, I know, inefficient, but avoids divergence of processes + metrics, test_losses = self.validate() + # print("Validated: {}".format(metrics)) + if self.should_log: + wandb.log( + {f"Test/{name}": tloss for name, tloss in test_losses.items()}, step=self.step) + + wandb.log({f"Metrics/{k}": v for k, + v in metrics.items()}, step=self.step) + + if (metrics[self.metric_criterion] < best_loss) and self.should_write: + self.save_checkpoint( + f"{self.config.experiment_id}_best.pt") + best_loss = metrics[self.metric_criterion] + + self.model.train() + + if self.config.distributed: + dist.barrier() + # print(f"Validated: {metrics} on device {self.config.rank}") + + # print(f"Finished step {self.step} on device {self.config.rank}") + ################################################################################################# + + # Save / validate at the end + self.step += 1 # log as final point + self.model.eval() + self.save_checkpoint(f"{self.config.experiment_id}_latest.pt") + if self.test_loader: + + ################################# Validation loop ################################################## + metrics, test_losses = self.validate() + # print("Validated: {}".format(metrics)) + if self.should_log: + wandb.log({f"Test/{name}": tloss for name, + tloss in test_losses.items()}, step=self.step) + wandb.log({f"Metrics/{k}": v for k, + v in metrics.items()}, step=self.step) + + if (metrics[self.metric_criterion] < best_loss) and self.should_write: + self.save_checkpoint( + f"{self.config.experiment_id}_best.pt") + best_loss = metrics[self.metric_criterion] + + self.model.train() + + def validate(self): + with torch.no_grad(): + losses_avg = RunningAverageDict() + metrics_avg = RunningAverageDict() + for i, batch in tqdm(enumerate(self.test_loader), desc=f"Epoch: {self.epoch + 1}/{self.config.epochs}. Loop: Validation", total=len(self.test_loader), disable=not is_rank_zero(self.config)): + metrics, losses = self.validate_on_batch(batch, val_step=i) + + if losses: + losses_avg.update(losses) + if metrics: + metrics_avg.update(metrics) + + return metrics_avg.get_value(), losses_avg.get_value() + + def save_checkpoint(self, filename): + if not self.should_write: + return + root = self.config.save_dir + if not os.path.isdir(root): + os.makedirs(root) + + fpath = os.path.join(root, filename) + m = self.model.module if self.config.multigpu else self.model + torch.save( + { + "model": m.state_dict(), + "optimizer": None, # TODO : Change to self.optimizer.state_dict() if resume support is needed, currently None to reduce file size + "epoch": self.epoch + }, fpath) + + def log_images(self, rgb: Dict[str, list] = {}, depth: Dict[str, list] = {}, scalar_field: Dict[str, list] = {}, prefix="", scalar_cmap="jet", min_depth=None, max_depth=None): + if not self.should_log: + return + + if min_depth is None: + try: + min_depth = self.config.min_depth + max_depth = self.config.max_depth + except AttributeError: + min_depth = None + max_depth = None + + depth = {k: colorize(v, vmin=min_depth, vmax=max_depth) + for k, v in depth.items()} + scalar_field = {k: colorize( + v, vmin=None, vmax=None, cmap=scalar_cmap) for k, v in scalar_field.items()} + images = {**rgb, **depth, **scalar_field} + wimages = { + prefix+"Predictions": [wandb.Image(v, caption=k) for k, v in images.items()]} + wandb.log(wimages, step=self.step) + + def log_line_plot(self, data): + if not self.should_log: + return + + plt.plot(data) + plt.ylabel("Scale factors") + wandb.log({"Scale factors": wandb.Image(plt)}, step=self.step) + plt.close() + + def log_bar_plot(self, title, labels, values): + if not self.should_log: + return + + data = [[label, val] for (label, val) in zip(labels, values)] + table = wandb.Table(data=data, columns=["label", "value"]) + wandb.log({title: wandb.plot.bar(table, "label", + "value", title=title)}, step=self.step) diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/trainers/builder.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/trainers/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..a663541b08912ebedce21a68c7599ce4c06e85d0 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/trainers/builder.py @@ -0,0 +1,48 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +from importlib import import_module + + +def get_trainer(config): + """Builds and returns a trainer based on the config. + + Args: + config (dict): the config dict (typically constructed using utils.config.get_config) + config.trainer (str): the name of the trainer to use. The module named "{config.trainer}_trainer" must exist in trainers root module + + Raises: + ValueError: If the specified trainer does not exist under trainers/ folder + + Returns: + Trainer (inherited from zoedepth.trainers.BaseTrainer): The Trainer object + """ + assert "trainer" in config and config.trainer is not None and config.trainer != '', "Trainer not specified. Config: {0}".format( + config) + try: + Trainer = getattr(import_module( + f"zoedepth.trainers.{config.trainer}_trainer"), 'Trainer') + except ModuleNotFoundError as e: + raise ValueError(f"Trainer {config.trainer}_trainer not found.") from e + return Trainer diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/trainers/loss.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/trainers/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..0c5a1c15cdf5628c1474c566fdc6e58159d7f5ab --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/trainers/loss.py @@ -0,0 +1,316 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.cuda.amp as amp +import numpy as np + + +KEY_OUTPUT = 'metric_depth' + + +def extract_key(prediction, key): + if isinstance(prediction, dict): + return prediction[key] + return prediction + + +# Main loss function used for ZoeDepth. Copy/paste from AdaBins repo (https://github.com/shariqfarooq123/AdaBins/blob/0952d91e9e762be310bb4cd055cbfe2448c0ce20/loss.py#L7) +class SILogLoss(nn.Module): + """SILog loss (pixel-wise)""" + def __init__(self, beta=0.15): + super(SILogLoss, self).__init__() + self.name = 'SILog' + self.beta = beta + + def forward(self, input, target, mask=None, interpolate=True, return_interpolated=False): + input = extract_key(input, KEY_OUTPUT) + if input.shape[-1] != target.shape[-1] and interpolate: + input = nn.functional.interpolate( + input, target.shape[-2:], mode='bilinear', align_corners=True) + intr_input = input + else: + intr_input = input + + if target.ndim == 3: + target = target.unsqueeze(1) + + if mask is not None: + if mask.ndim == 3: + mask = mask.unsqueeze(1) + + input = input[mask] + target = target[mask] + + with amp.autocast(enabled=False): # amp causes NaNs in this loss function + alpha = 1e-7 + g = torch.log(input + alpha) - torch.log(target + alpha) + + # n, c, h, w = g.shape + # norm = 1/(h*w) + # Dg = norm * torch.sum(g**2) - (0.85/(norm**2)) * (torch.sum(g))**2 + + Dg = torch.var(g) + self.beta * torch.pow(torch.mean(g), 2) + + loss = 10 * torch.sqrt(Dg) + + if torch.isnan(loss): + print("Nan SILog loss") + print("input:", input.shape) + print("target:", target.shape) + print("G", torch.sum(torch.isnan(g))) + print("Input min max", torch.min(input), torch.max(input)) + print("Target min max", torch.min(target), torch.max(target)) + print("Dg", torch.isnan(Dg)) + print("loss", torch.isnan(loss)) + + if not return_interpolated: + return loss + + return loss, intr_input + + +def grad(x): + # x.shape : n, c, h, w + diff_x = x[..., 1:, 1:] - x[..., 1:, :-1] + diff_y = x[..., 1:, 1:] - x[..., :-1, 1:] + mag = diff_x**2 + diff_y**2 + # angle_ratio + angle = torch.atan(diff_y / (diff_x + 1e-10)) + return mag, angle + + +def grad_mask(mask): + return mask[..., 1:, 1:] & mask[..., 1:, :-1] & mask[..., :-1, 1:] + + +class GradL1Loss(nn.Module): + """Gradient loss""" + def __init__(self): + super(GradL1Loss, self).__init__() + self.name = 'GradL1' + + def forward(self, input, target, mask=None, interpolate=True, return_interpolated=False): + input = extract_key(input, KEY_OUTPUT) + if input.shape[-1] != target.shape[-1] and interpolate: + input = nn.functional.interpolate( + input, target.shape[-2:], mode='bilinear', align_corners=True) + intr_input = input + else: + intr_input = input + + grad_gt = grad(target) + grad_pred = grad(input) + mask_g = grad_mask(mask) + + loss = nn.functional.l1_loss(grad_pred[0][mask_g], grad_gt[0][mask_g]) + loss = loss + \ + nn.functional.l1_loss(grad_pred[1][mask_g], grad_gt[1][mask_g]) + if not return_interpolated: + return loss + return loss, intr_input + + +class OrdinalRegressionLoss(object): + + def __init__(self, ord_num, beta, discretization="SID"): + self.ord_num = ord_num + self.beta = beta + self.discretization = discretization + + def _create_ord_label(self, gt): + N,one, H, W = gt.shape + # print("gt shape:", gt.shape) + + ord_c0 = torch.ones(N, self.ord_num, H, W).to(gt.device) + if self.discretization == "SID": + label = self.ord_num * torch.log(gt) / np.log(self.beta) + else: + label = self.ord_num * (gt - 1.0) / (self.beta - 1.0) + label = label.long() + mask = torch.linspace(0, self.ord_num - 1, self.ord_num, requires_grad=False) \ + .view(1, self.ord_num, 1, 1).to(gt.device) + mask = mask.repeat(N, 1, H, W).contiguous().long() + mask = (mask > label) + ord_c0[mask] = 0 + ord_c1 = 1 - ord_c0 + # implementation according to the paper. + # ord_label = torch.ones(N, self.ord_num * 2, H, W).to(gt.device) + # ord_label[:, 0::2, :, :] = ord_c0 + # ord_label[:, 1::2, :, :] = ord_c1 + # reimplementation for fast speed. + ord_label = torch.cat((ord_c0, ord_c1), dim=1) + return ord_label, mask + + def __call__(self, prob, gt): + """ + :param prob: ordinal regression probability, N x 2*Ord Num x H x W, torch.Tensor + :param gt: depth ground truth, NXHxW, torch.Tensor + :return: loss: loss value, torch.float + """ + # N, C, H, W = prob.shape + valid_mask = gt > 0. + ord_label, mask = self._create_ord_label(gt) + # print("prob shape: {}, ord label shape: {}".format(prob.shape, ord_label.shape)) + entropy = -prob * ord_label + loss = torch.sum(entropy, dim=1)[valid_mask.squeeze(1)] + return loss.mean() + + +class DiscreteNLLLoss(nn.Module): + """Cross entropy loss""" + def __init__(self, min_depth=1e-3, max_depth=10, depth_bins=64): + super(DiscreteNLLLoss, self).__init__() + self.name = 'CrossEntropy' + self.ignore_index = -(depth_bins + 1) + # self._loss_func = nn.NLLLoss(ignore_index=self.ignore_index) + self._loss_func = nn.CrossEntropyLoss(ignore_index=self.ignore_index) + self.min_depth = min_depth + self.max_depth = max_depth + self.depth_bins = depth_bins + self.alpha = 1 + self.zeta = 1 - min_depth + self.beta = max_depth + self.zeta + + def quantize_depth(self, depth): + # depth : N1HW + # output : NCHW + + # Quantize depth log-uniformly on [1, self.beta] into self.depth_bins bins + depth = torch.log(depth / self.alpha) / np.log(self.beta / self.alpha) + depth = depth * (self.depth_bins - 1) + depth = torch.round(depth) + depth = depth.long() + return depth + + + + def _dequantize_depth(self, depth): + """ + Inverse of quantization + depth : NCHW -> N1HW + """ + # Get the center of the bin + + + + + def forward(self, input, target, mask=None, interpolate=True, return_interpolated=False): + input = extract_key(input, KEY_OUTPUT) + # assert torch.all(input <= 0), "Input should be negative" + + if input.shape[-1] != target.shape[-1] and interpolate: + input = nn.functional.interpolate( + input, target.shape[-2:], mode='bilinear', align_corners=True) + intr_input = input + else: + intr_input = input + + # assert torch.all(input)<=1) + if target.ndim == 3: + target = target.unsqueeze(1) + + target = self.quantize_depth(target) + if mask is not None: + if mask.ndim == 3: + mask = mask.unsqueeze(1) + + # Set the mask to ignore_index + mask = mask.long() + input = input * mask + (1 - mask) * self.ignore_index + target = target * mask + (1 - mask) * self.ignore_index + + + + input = input.flatten(2) # N, nbins, H*W + target = target.flatten(1) # N, H*W + loss = self._loss_func(input, target) + + if not return_interpolated: + return loss + return loss, intr_input + + + + +def compute_scale_and_shift(prediction, target, mask): + # system matrix: A = [[a_00, a_01], [a_10, a_11]] + a_00 = torch.sum(mask * prediction * prediction, (1, 2)) + a_01 = torch.sum(mask * prediction, (1, 2)) + a_11 = torch.sum(mask, (1, 2)) + + # right hand side: b = [b_0, b_1] + b_0 = torch.sum(mask * prediction * target, (1, 2)) + b_1 = torch.sum(mask * target, (1, 2)) + + # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b + x_0 = torch.zeros_like(b_0) + x_1 = torch.zeros_like(b_1) + + det = a_00 * a_11 - a_01 * a_01 + # A needs to be a positive definite matrix. + valid = det > 0 + + x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid] + x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid] + + return x_0, x_1 +class ScaleAndShiftInvariantLoss(nn.Module): + def __init__(self): + super().__init__() + self.name = "SSILoss" + + def forward(self, prediction, target, mask, interpolate=True, return_interpolated=False): + + if prediction.shape[-1] != target.shape[-1] and interpolate: + prediction = nn.functional.interpolate(prediction, target.shape[-2:], mode='bilinear', align_corners=True) + intr_input = prediction + else: + intr_input = prediction + + + prediction, target, mask = prediction.squeeze(), target.squeeze(), mask.squeeze() + assert prediction.shape == target.shape, f"Shape mismatch: Expected same shape but got {prediction.shape} and {target.shape}." + + scale, shift = compute_scale_and_shift(prediction, target, mask) + + scaled_prediction = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1) + + loss = nn.functional.l1_loss(scaled_prediction[mask], target[mask]) + if not return_interpolated: + return loss + return loss, intr_input + + + + +if __name__ == '__main__': + # Tests for DiscreteNLLLoss + celoss = DiscreteNLLLoss() + print(celoss(torch.rand(4, 64, 26, 32)*10, torch.rand(4, 1, 26, 32)*10, )) + + d = torch.Tensor([6.59, 3.8, 10.0]) + print(celoss.dequantize_depth(celoss.quantize_depth(d))) diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/trainers/zoedepth_nk_trainer.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/trainers/zoedepth_nk_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..d528ae126f1c51b2f25fd31f94a39591ceb2f43a --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/trainers/zoedepth_nk_trainer.py @@ -0,0 +1,143 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.cuda.amp as amp +import torch.nn as nn + +from zoedepth.trainers.loss import GradL1Loss, SILogLoss +from zoedepth.utils.config import DATASETS_CONFIG +from zoedepth.utils.misc import compute_metrics + +from .base_trainer import BaseTrainer + + +class Trainer(BaseTrainer): + def __init__(self, config, model, train_loader, test_loader=None, device=None): + super().__init__(config, model, train_loader, + test_loader=test_loader, device=device) + self.device = device + self.silog_loss = SILogLoss() + self.grad_loss = GradL1Loss() + self.domain_classifier_loss = nn.CrossEntropyLoss() + + self.scaler = amp.GradScaler(enabled=self.config.use_amp) + + def train_on_batch(self, batch, train_step): + """ + Expects a batch of images and depth as input + batch["image"].shape : batch_size, c, h, w + batch["depth"].shape : batch_size, 1, h, w + + Assumes all images in a batch are from the same dataset + """ + + images, depths_gt = batch['image'].to( + self.device), batch['depth'].to(self.device) + # batch['dataset'] is a tensor strings all valued either 'nyu' or 'kitti'. labels nyu -> 0, kitti -> 1 + dataset = batch['dataset'][0] + # Convert to 0s or 1s + domain_labels = torch.Tensor([dataset == 'kitti' for _ in range( + images.size(0))]).to(torch.long).to(self.device) + + # m = self.model.module if self.config.multigpu else self.model + + b, c, h, w = images.size() + mask = batch["mask"].to(self.device).to(torch.bool) + + losses = {} + + with amp.autocast(enabled=self.config.use_amp): + output = self.model(images) + pred_depths = output['metric_depth'] + domain_logits = output['domain_logits'] + + l_si, pred = self.silog_loss( + pred_depths, depths_gt, mask=mask, interpolate=True, return_interpolated=True) + loss = self.config.w_si * l_si + losses[self.silog_loss.name] = l_si + + if self.config.w_grad > 0: + l_grad = self.grad_loss(pred, depths_gt, mask=mask) + loss = loss + self.config.w_grad * l_grad + losses[self.grad_loss.name] = l_grad + else: + l_grad = torch.Tensor([0]) + + if self.config.w_domain > 0: + l_domain = self.domain_classifier_loss( + domain_logits, domain_labels) + loss = loss + self.config.w_domain * l_domain + losses["DomainLoss"] = l_domain + else: + l_domain = torch.Tensor([0.]) + + self.scaler.scale(loss).backward() + + if self.config.clip_grad > 0: + self.scaler.unscale_(self.optimizer) + nn.utils.clip_grad_norm_( + self.model.parameters(), self.config.clip_grad) + + self.scaler.step(self.optimizer) + + if self.should_log and self.step > 1 and (self.step % int(self.config.log_images_every * self.iters_per_epoch)) == 0: + depths_gt[torch.logical_not(mask)] = -99 + self.log_images(rgb={"Input": images[0, ...]}, depth={"GT": depths_gt[0], "PredictedMono": pred[0]}, prefix="Train", + min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth']) + + self.scaler.update() + self.optimizer.zero_grad(set_to_none=True) + + return losses + + def validate_on_batch(self, batch, val_step): + images = batch['image'].to(self.device) + depths_gt = batch['depth'].to(self.device) + dataset = batch['dataset'][0] + if 'has_valid_depth' in batch: + if not batch['has_valid_depth']: + return None, None + + depths_gt = depths_gt.squeeze().unsqueeze(0).unsqueeze(0) + with amp.autocast(enabled=self.config.use_amp): + m = self.model.module if self.config.multigpu else self.model + pred_depths = m(images)["metric_depth"] + pred_depths = pred_depths.squeeze().unsqueeze(0).unsqueeze(0) + + mask = torch.logical_and( + depths_gt > self.config.min_depth, depths_gt < self.config.max_depth) + with amp.autocast(enabled=self.config.use_amp): + l_depth = self.silog_loss( + pred_depths, depths_gt, mask=mask.to(torch.bool), interpolate=True) + + metrics = compute_metrics(depths_gt, pred_depths, **self.config) + losses = {f"{self.silog_loss.name}": l_depth.item()} + + if val_step == 1 and self.should_log: + depths_gt[torch.logical_not(mask)] = -99 + self.log_images(rgb={"Input": images[0]}, depth={"GT": depths_gt[0], "PredictedMono": pred_depths[0]}, prefix="Test", + min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth']) + + return metrics, losses diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/trainers/zoedepth_trainer.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/trainers/zoedepth_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..3ac1c24c0512c1c1b191670a7c24abb4fca47ba1 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/trainers/zoedepth_trainer.py @@ -0,0 +1,177 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.cuda.amp as amp +import torch.nn as nn + +from zoedepth.trainers.loss import GradL1Loss, SILogLoss +from zoedepth.utils.config import DATASETS_CONFIG +from zoedepth.utils.misc import compute_metrics +from zoedepth.data.preprocess import get_black_border + +from .base_trainer import BaseTrainer +from torchvision import transforms +from PIL import Image +import numpy as np + +class Trainer(BaseTrainer): + def __init__(self, config, model, train_loader, test_loader=None, device=None): + super().__init__(config, model, train_loader, + test_loader=test_loader, device=device) + self.device = device + self.silog_loss = SILogLoss() + self.grad_loss = GradL1Loss() + self.scaler = amp.GradScaler(enabled=self.config.use_amp) + + def train_on_batch(self, batch, train_step): + """ + Expects a batch of images and depth as input + batch["image"].shape : batch_size, c, h, w + batch["depth"].shape : batch_size, 1, h, w + """ + + images, depths_gt = batch['image'].to( + self.device), batch['depth'].to(self.device) + dataset = batch['dataset'][0] + + b, c, h, w = images.size() + mask = batch["mask"].to(self.device).to(torch.bool) + + losses = {} + + with amp.autocast(enabled=self.config.use_amp): + + output = self.model(images) + pred_depths = output['metric_depth'] + + l_si, pred = self.silog_loss( + pred_depths, depths_gt, mask=mask, interpolate=True, return_interpolated=True) + loss = self.config.w_si * l_si + losses[self.silog_loss.name] = l_si + + if self.config.w_grad > 0: + l_grad = self.grad_loss(pred, depths_gt, mask=mask) + loss = loss + self.config.w_grad * l_grad + losses[self.grad_loss.name] = l_grad + else: + l_grad = torch.Tensor([0]) + + self.scaler.scale(loss).backward() + + if self.config.clip_grad > 0: + self.scaler.unscale_(self.optimizer) + nn.utils.clip_grad_norm_( + self.model.parameters(), self.config.clip_grad) + + self.scaler.step(self.optimizer) + + if self.should_log and (self.step % int(self.config.log_images_every * self.iters_per_epoch)) == 0: + # -99 is treated as invalid depth in the log_images function and is colored grey. + depths_gt[torch.logical_not(mask)] = -99 + + self.log_images(rgb={"Input": images[0, ...]}, depth={"GT": depths_gt[0], "PredictedMono": pred[0]}, prefix="Train", + min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth']) + + if self.config.get("log_rel", False): + self.log_images( + scalar_field={"RelPred": output["relative_depth"][0]}, prefix="TrainRel") + + self.scaler.update() + self.optimizer.zero_grad() + + return losses + + @torch.no_grad() + def eval_infer(self, x): + with amp.autocast(enabled=self.config.use_amp): + m = self.model.module if self.config.multigpu else self.model + pred_depths = m(x)['metric_depth'] + return pred_depths + + @torch.no_grad() + def crop_aware_infer(self, x): + # if we are not avoiding the black border, we can just use the normal inference + if not self.config.get("avoid_boundary", False): + return self.eval_infer(x) + + # otherwise, we need to crop the image to avoid the black border + # For now, this may be a bit slow due to converting to numpy and back + # We assume no normalization is done on the input image + + # get the black border + assert x.shape[0] == 1, "Only batch size 1 is supported for now" + x_pil = transforms.ToPILImage()(x[0].cpu()) + x_np = np.array(x_pil, dtype=np.uint8) + black_border_params = get_black_border(x_np) + top, bottom, left, right = black_border_params.top, black_border_params.bottom, black_border_params.left, black_border_params.right + x_np_cropped = x_np[top:bottom, left:right, :] + x_cropped = transforms.ToTensor()(Image.fromarray(x_np_cropped)) + + # run inference on the cropped image + pred_depths_cropped = self.eval_infer(x_cropped.unsqueeze(0).to(self.device)) + + # resize the prediction to x_np_cropped's size + pred_depths_cropped = nn.functional.interpolate( + pred_depths_cropped, size=(x_np_cropped.shape[0], x_np_cropped.shape[1]), mode="bilinear", align_corners=False) + + + # pad the prediction back to the original size + pred_depths = torch.zeros((1, 1, x_np.shape[0], x_np.shape[1]), device=pred_depths_cropped.device, dtype=pred_depths_cropped.dtype) + pred_depths[:, :, top:bottom, left:right] = pred_depths_cropped + + return pred_depths + + + + def validate_on_batch(self, batch, val_step): + images = batch['image'].to(self.device) + depths_gt = batch['depth'].to(self.device) + dataset = batch['dataset'][0] + mask = batch["mask"].to(self.device) + if 'has_valid_depth' in batch: + if not batch['has_valid_depth']: + return None, None + + depths_gt = depths_gt.squeeze().unsqueeze(0).unsqueeze(0) + mask = mask.squeeze().unsqueeze(0).unsqueeze(0) + if dataset == 'nyu': + pred_depths = self.crop_aware_infer(images) + else: + pred_depths = self.eval_infer(images) + pred_depths = pred_depths.squeeze().unsqueeze(0).unsqueeze(0) + + with amp.autocast(enabled=self.config.use_amp): + l_depth = self.silog_loss( + pred_depths, depths_gt, mask=mask.to(torch.bool), interpolate=True) + + metrics = compute_metrics(depths_gt, pred_depths, **self.config) + losses = {f"{self.silog_loss.name}": l_depth.item()} + + if val_step == 1 and self.should_log: + depths_gt[torch.logical_not(mask)] = -99 + self.log_images(rgb={"Input": images[0]}, depth={"GT": depths_gt[0], "PredictedMono": pred_depths[0]}, prefix="Test", + min_depth=DATASETS_CONFIG[dataset]['min_depth'], max_depth=DATASETS_CONFIG[dataset]['max_depth']) + + return metrics, losses diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/__init__.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/__pycache__/__init__.cpython-310.pyc b/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0e2aed9fc748b7c75eeca46e897d8be3b3bf638 Binary files /dev/null and b/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/__pycache__/arg_utils.cpython-310.pyc b/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/__pycache__/arg_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e95fb3e1d2c71ea343d457ea27d6b2a24b94222 Binary files /dev/null and b/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/__pycache__/arg_utils.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/__pycache__/config.cpython-310.pyc b/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cec642eaa198b55340cb9b85e89b2063608576ed Binary files /dev/null and b/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/__pycache__/config.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/arg_utils.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/arg_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8a3004ec3679c0a40fd8961253733fb4343ad545 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/arg_utils.py @@ -0,0 +1,33 @@ + + +def infer_type(x): # hacky way to infer type from string args + if not isinstance(x, str): + return x + + try: + x = int(x) + return x + except ValueError: + pass + + try: + x = float(x) + return x + except ValueError: + pass + + return x + + +def parse_unknown(unknown_args): + clean = [] + for a in unknown_args: + if "=" in a: + k, v = a.split("=") + clean.extend([k, v]) + else: + clean.append(a) + + keys = clean[::2] + values = clean[1::2] + return {k.replace("--", ""): infer_type(v) for k, v in zip(keys, values)} diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/config.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..84996564663dadf0e720de2a68ef8c53106ed666 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/config.py @@ -0,0 +1,437 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import json +import os + +from .easydict import EasyDict as edict +from .arg_utils import infer_type + +import pathlib +import platform + +ROOT = pathlib.Path(__file__).parent.parent.resolve() + +HOME_DIR = os.path.expanduser("~") + +COMMON_CONFIG = { + "save_dir": os.path.expanduser("~/shortcuts/monodepth3_checkpoints"), + "project": "ZoeDepth", + "tags": '', + "notes": "", + "gpu": None, + "root": ".", + "uid": None, + "print_losses": False +} + +DATASETS_CONFIG = { + "kitti": { + "dataset": "kitti", + "min_depth": 0.001, + "max_depth": 80, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt", + "input_height": 352, + "input_width": 1216, # 704 + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt", + + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + + "do_random_rotate": True, + "degree": 1.0, + "do_kb_crop": True, + "garg_crop": True, + "eigen_crop": False, + "use_right": False + }, + "kitti_test": { + "dataset": "kitti", + "min_depth": 0.001, + "max_depth": 80, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt", + "input_height": 352, + "input_width": 1216, + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt", + + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + + "do_random_rotate": False, + "degree": 1.0, + "do_kb_crop": True, + "garg_crop": True, + "eigen_crop": False, + "use_right": False + }, + "nyu": { + "dataset": "nyu", + "avoid_boundary": False, + "min_depth": 1e-3, # originally 0.1 + "max_depth": 10, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"), + "filenames_file": "./train_test_inputs/nyudepthv2_train_files_with_gt.txt", + "input_height": 480, + "input_width": 640, + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"), + "filenames_file_eval": "./train_test_inputs/nyudepthv2_test_files_with_gt.txt", + "min_depth_eval": 1e-3, + "max_depth_eval": 10, + "min_depth_diff": -10, + "max_depth_diff": 10, + + "do_random_rotate": True, + "degree": 1.0, + "do_kb_crop": False, + "garg_crop": False, + "eigen_crop": True + }, + "ibims": { + "dataset": "ibims", + "ibims_root": os.path.join(HOME_DIR, "shortcuts/datasets/ibims/ibims1_core_raw/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "sunrgbd": { + "dataset": "sunrgbd", + "sunrgbd_root": os.path.join(HOME_DIR, "shortcuts/datasets/SUNRGBD/test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 8, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diml_indoor": { + "dataset": "diml_indoor", + "diml_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_indoor_test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diml_outdoor": { + "dataset": "diml_outdoor", + "diml_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_outdoor_test/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": False, + "min_depth_eval": 2, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "diode_indoor": { + "dataset": "diode_indoor", + "diode_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_indoor/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diode_outdoor": { + "dataset": "diode_outdoor", + "diode_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_outdoor/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "hypersim_test": { + "dataset": "hypersim_test", + "hypersim_test_root": os.path.join(HOME_DIR, "shortcuts/datasets/hypersim_test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 10 + }, + "vkitti": { + "dataset": "vkitti", + "vkitti_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti_test/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "vkitti2": { + "dataset": "vkitti2", + "vkitti2_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti2/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80, + }, + "ddad": { + "dataset": "ddad", + "ddad_root": os.path.join(HOME_DIR, "shortcuts/datasets/ddad/ddad_val/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80, + }, +} + +ALL_INDOOR = ["nyu", "ibims", "sunrgbd", "diode_indoor", "hypersim_test"] +ALL_OUTDOOR = ["kitti", "diml_outdoor", "diode_outdoor", "vkitti2", "ddad"] +ALL_EVAL_DATASETS = ALL_INDOOR + ALL_OUTDOOR + +COMMON_TRAINING_CONFIG = { + "dataset": "nyu", + "distributed": True, + "workers": 16, + "clip_grad": 0.1, + "use_shared_dict": False, + "shared_dict": None, + "use_amp": False, + + "aug": True, + "random_crop": False, + "random_translate": False, + "translate_prob": 0.2, + "max_translation": 100, + + "validate_every": 0.25, + "log_images_every": 0.1, + "prefetch": False, +} + + +def flatten(config, except_keys=('bin_conf')): + def recurse(inp): + if isinstance(inp, dict): + for key, value in inp.items(): + if key in except_keys: + yield (key, value) + if isinstance(value, dict): + yield from recurse(value) + else: + yield (key, value) + + return dict(list(recurse(config))) + + +def split_combined_args(kwargs): + """Splits the arguments that are combined with '__' into multiple arguments. + Combined arguments should have equal number of keys and values. + Keys are separated by '__' and Values are separated with ';'. + For example, '__n_bins__lr=256;0.001' + + Args: + kwargs (dict): key-value pairs of arguments where key-value is optionally combined according to the above format. + + Returns: + dict: Parsed dict with the combined arguments split into individual key-value pairs. + """ + new_kwargs = dict(kwargs) + for key, value in kwargs.items(): + if key.startswith("__"): + keys = key.split("__")[1:] + values = value.split(";") + assert len(keys) == len( + values), f"Combined arguments should have equal number of keys and values. Keys are separated by '__' and Values are separated with ';'. For example, '__n_bins__lr=256;0.001. Given (keys,values) is ({keys}, {values})" + for k, v in zip(keys, values): + new_kwargs[k] = v + return new_kwargs + + +def parse_list(config, key, dtype=int): + """Parse a list of values for the key if the value is a string. The values are separated by a comma. + Modifies the config in place. + """ + if key in config: + if isinstance(config[key], str): + config[key] = list(map(dtype, config[key].split(','))) + assert isinstance(config[key], list) and all([isinstance(e, dtype) for e in config[key]] + ), f"{key} should be a list of values dtype {dtype}. Given {config[key]} of type {type(config[key])} with values of type {[type(e) for e in config[key]]}." + + +def get_model_config(model_name, model_version=None): + """Find and parse the .json config file for the model. + + Args: + model_name (str): name of the model. The config file should be named config_{model_name}[_{model_version}].json under the models/{model_name} directory. + model_version (str, optional): Specific config version. If specified config_{model_name}_{model_version}.json is searched for and used. Otherwise config_{model_name}.json is used. Defaults to None. + + Returns: + easydict: the config dictionary for the model. + """ + config_fname = f"config_{model_name}_{model_version}.json" if model_version is not None else f"config_{model_name}.json" + config_file = os.path.join(ROOT, "models", model_name, config_fname) + if not os.path.exists(config_file): + return None + + with open(config_file, "r") as f: + config = edict(json.load(f)) + + # handle dictionary inheritance + # only training config is supported for inheritance + if "inherit" in config.train and config.train.inherit is not None: + inherit_config = get_model_config(config.train["inherit"]).train + for key, value in inherit_config.items(): + if key not in config.train: + config.train[key] = value + return edict(config) + + +def update_model_config(config, mode, model_name, model_version=None, strict=False): + model_config = get_model_config(model_name, model_version) + if model_config is not None: + config = {**config, ** + flatten({**model_config.model, **model_config[mode]})} + elif strict: + raise ValueError(f"Config file for model {model_name} not found.") + return config + + +def check_choices(name, value, choices): + # return # No checks in dev branch + if value not in choices: + raise ValueError(f"{name} {value} not in supported choices {choices}") + + +KEYS_TYPE_BOOL = ["use_amp", "distributed", "use_shared_dict", "same_lr", "aug", "three_phase", + "prefetch", "cycle_momentum"] # Casting is not necessary as their int casted values in config are 0 or 1 + + +def get_config(model_name, mode='train', dataset=None, **overwrite_kwargs): + """Main entry point to get the config for the model. + + Args: + model_name (str): name of the desired model. + mode (str, optional): "train" or "infer". Defaults to 'train'. + dataset (str, optional): If specified, the corresponding dataset configuration is loaded as well. Defaults to None. + + Keyword Args: key-value pairs of arguments to overwrite the default config. + + The order of precedence for overwriting the config is (Higher precedence first): + # 1. overwrite_kwargs + # 2. "config_version": Config file version if specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{config_version}.json + # 3. "version_name": Default Model version specific config specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{version_name}.json + # 4. common_config: Default config for all models specified in COMMON_CONFIG + + Returns: + easydict: The config dictionary for the model. + """ + + + check_choices("Model", model_name, ["zoedepth", "zoedepth_nk"]) + check_choices("Mode", mode, ["train", "infer", "eval"]) + if mode == "train": + check_choices("Dataset", dataset, ["nyu", "kitti", "mix", None]) + + config = flatten({**COMMON_CONFIG, **COMMON_TRAINING_CONFIG}) + config = update_model_config(config, mode, model_name) + + # update with model version specific config + version_name = overwrite_kwargs.get("version_name", config["version_name"]) + config = update_model_config(config, mode, model_name, version_name) + + # update with config version if specified + config_version = overwrite_kwargs.get("config_version", None) + if config_version is not None: + print("Overwriting config with config_version", config_version) + config = update_model_config(config, mode, model_name, config_version) + + # update with overwrite_kwargs + # Combined args are useful for hyperparameter search + overwrite_kwargs = split_combined_args(overwrite_kwargs) + config = {**config, **overwrite_kwargs} + + # Casting to bool # TODO: Not necessary. Remove and test + for key in KEYS_TYPE_BOOL: + if key in config: + config[key] = bool(config[key]) + + # Model specific post processing of config + parse_list(config, "n_attractors") + + # adjust n_bins for each bin configuration if bin_conf is given and n_bins is passed in overwrite_kwargs + if 'bin_conf' in config and 'n_bins' in overwrite_kwargs: + bin_conf = config['bin_conf'] # list of dicts + n_bins = overwrite_kwargs['n_bins'] + new_bin_conf = [] + for conf in bin_conf: + conf['n_bins'] = n_bins + new_bin_conf.append(conf) + config['bin_conf'] = new_bin_conf + + if mode == "train": + orig_dataset = dataset + if dataset == "mix": + dataset = 'nyu' # Use nyu as default for mix. Dataset config is changed accordingly while loading the dataloader + if dataset is not None: + config['project'] = f"MonoDepth3-{orig_dataset}" # Set project for wandb + + if dataset is not None: + config['dataset'] = dataset + config = {**DATASETS_CONFIG[dataset], **config} + + + config['model'] = model_name + typed_config = {k: infer_type(v) for k, v in config.items()} + # add hostname to config + config['hostname'] = platform.node() + return edict(typed_config) + + +def change_dataset(config, new_dataset): + config.update(DATASETS_CONFIG[new_dataset]) + return config diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/easydict/__init__.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/easydict/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15928179b0182c6045d98bc0a7be1c6ca45f675e --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/easydict/__init__.py @@ -0,0 +1,158 @@ +""" +EasyDict +Copy/pasted from https://github.com/makinacorpus/easydict +Original author: Mathieu Leplatre +""" + +class EasyDict(dict): + """ + Get attributes + + >>> d = EasyDict({'foo':3}) + >>> d['foo'] + 3 + >>> d.foo + 3 + >>> d.bar + Traceback (most recent call last): + ... + AttributeError: 'EasyDict' object has no attribute 'bar' + + Works recursively + + >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}}) + >>> isinstance(d.bar, dict) + True + >>> d.bar.x + 1 + + Bullet-proof + + >>> EasyDict({}) + {} + >>> EasyDict(d={}) + {} + >>> EasyDict(None) + {} + >>> d = {'a': 1} + >>> EasyDict(**d) + {'a': 1} + >>> EasyDict((('a', 1), ('b', 2))) + {'a': 1, 'b': 2} + + Set attributes + + >>> d = EasyDict() + >>> d.foo = 3 + >>> d.foo + 3 + >>> d.bar = {'prop': 'value'} + >>> d.bar.prop + 'value' + >>> d + {'foo': 3, 'bar': {'prop': 'value'}} + >>> d.bar.prop = 'newer' + >>> d.bar.prop + 'newer' + + + Values extraction + + >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]}) + >>> isinstance(d.bar, list) + True + >>> from operator import attrgetter + >>> list(map(attrgetter('x'), d.bar)) + [1, 3] + >>> list(map(attrgetter('y'), d.bar)) + [2, 4] + >>> d = EasyDict() + >>> list(d.keys()) + [] + >>> d = EasyDict(foo=3, bar=dict(x=1, y=2)) + >>> d.foo + 3 + >>> d.bar.x + 1 + + Still like a dict though + + >>> o = EasyDict({'clean':True}) + >>> list(o.items()) + [('clean', True)] + + And like a class + + >>> class Flower(EasyDict): + ... power = 1 + ... + >>> f = Flower() + >>> f.power + 1 + >>> f = Flower({'height': 12}) + >>> f.height + 12 + >>> f['power'] + 1 + >>> sorted(f.keys()) + ['height', 'power'] + + update and pop items + >>> d = EasyDict(a=1, b='2') + >>> e = EasyDict(c=3.0, a=9.0) + >>> d.update(e) + >>> d.c + 3.0 + >>> d['c'] + 3.0 + >>> d.get('c') + 3.0 + >>> d.update(a=4, b=4) + >>> d.b + 4 + >>> d.pop('a') + 4 + >>> d.a + Traceback (most recent call last): + ... + AttributeError: 'EasyDict' object has no attribute 'a' + """ + def __init__(self, d=None, **kwargs): + if d is None: + d = {} + else: + d = dict(d) + if kwargs: + d.update(**kwargs) + for k, v in d.items(): + setattr(self, k, v) + # Class attributes + for k in self.__class__.__dict__.keys(): + if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'): + setattr(self, k, getattr(self, k)) + + def __setattr__(self, name, value): + if isinstance(value, (list, tuple)): + value = [self.__class__(x) + if isinstance(x, dict) else x for x in value] + elif isinstance(value, dict) and not isinstance(value, self.__class__): + value = self.__class__(value) + super(EasyDict, self).__setattr__(name, value) + super(EasyDict, self).__setitem__(name, value) + + __setitem__ = __setattr__ + + def update(self, e=None, **f): + d = e or dict() + d.update(f) + for k in d: + setattr(self, k, d[k]) + + def pop(self, k, d=None): + delattr(self, k) + return super(EasyDict, self).pop(k, d) + + +if __name__ == "__main__": + import doctest + doctest.testmod() \ No newline at end of file diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/easydict/__pycache__/__init__.cpython-310.pyc b/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/easydict/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e491cacd57951f7977f134b49c64b61fb6daa32b Binary files /dev/null and b/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/easydict/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/geometry.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..e3da8c75b5a8e39b4b58a4dcd827b84d79b9115c --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/geometry.py @@ -0,0 +1,98 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import numpy as np + +def get_intrinsics(H,W): + """ + Intrinsics for a pinhole camera model. + Assume fov of 55 degrees and central principal point. + """ + f = 0.5 * W / np.tan(0.5 * 55 * np.pi / 180.0) + cx = 0.5 * W + cy = 0.5 * H + return np.array([[f, 0, cx], + [0, f, cy], + [0, 0, 1]]) + +def depth_to_points(depth, R=None, t=None): + + K = get_intrinsics(depth.shape[1], depth.shape[2]) + Kinv = np.linalg.inv(K) + if R is None: + R = np.eye(3) + if t is None: + t = np.zeros(3) + + # M converts from your coordinate to PyTorch3D's coordinate system + M = np.eye(3) + M[0, 0] = -1.0 + M[1, 1] = -1.0 + + height, width = depth.shape[1:3] + + x = np.arange(width) + y = np.arange(height) + coord = np.stack(np.meshgrid(x, y), -1) + coord = np.concatenate((coord, np.ones_like(coord)[:, :, [0]]), -1) # z=1 + coord = coord.astype(np.float32) + # coord = torch.as_tensor(coord, dtype=torch.float32, device=device) + coord = coord[None] # bs, h, w, 3 + + D = depth[:, :, :, None, None] + # print(D.shape, Kinv[None, None, None, ...].shape, coord[:, :, :, :, None].shape ) + pts3D_1 = D * Kinv[None, None, None, ...] @ coord[:, :, :, :, None] + # pts3D_1 live in your coordinate system. Convert them to Py3D's + pts3D_1 = M[None, None, None, ...] @ pts3D_1 + # from reference to targe tviewpoint + pts3D_2 = R[None, None, None, ...] @ pts3D_1 + t[None, None, None, :, None] + # pts3D_2 = pts3D_1 + # depth_2 = pts3D_2[:, :, :, 2, :] # b,1,h,w + return pts3D_2[:, :, :, :3, 0][0] + + +def create_triangles(h, w, mask=None): + """ + Reference: https://github.com/google-research/google-research/blob/e96197de06613f1b027d20328e06d69829fa5a89/infinite_nature/render_utils.py#L68 + Creates mesh triangle indices from a given pixel grid size. + This function is not and need not be differentiable as triangle indices are + fixed. + Args: + h: (int) denoting the height of the image. + w: (int) denoting the width of the image. + Returns: + triangles: 2D numpy array of indices (int) with shape (2(W-1)(H-1) x 3) + """ + x, y = np.meshgrid(range(w - 1), range(h - 1)) + tl = y * w + x + tr = y * w + x + 1 + bl = (y + 1) * w + x + br = (y + 1) * w + x + 1 + triangles = np.array([tl, bl, tr, br, tr, bl]) + triangles = np.transpose(triangles, (1, 2, 0)).reshape( + ((w - 1) * (h - 1) * 2, 3)) + if mask is not None: + mask = mask.reshape(-1) + triangles = triangles[mask[triangles].all(1)] + return triangles diff --git a/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/misc.py b/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..4bbe403d3669829eecdf658458c76aa5e87e2b33 --- /dev/null +++ b/third_party/src/flux_ch/annotator/zoe/zoedepth/utils/misc.py @@ -0,0 +1,368 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +"""Miscellaneous utility functions.""" + +from scipy import ndimage + +import base64 +import math +import re +from io import BytesIO + +import matplotlib +import matplotlib.cm +import numpy as np +import requests +import torch +import torch.distributed as dist +import torch.nn +import torch.nn as nn +import torch.utils.data.distributed +from PIL import Image +from torchvision.transforms import ToTensor + + +class RunningAverage: + def __init__(self): + self.avg = 0 + self.count = 0 + + def append(self, value): + self.avg = (value + self.count * self.avg) / (self.count + 1) + self.count += 1 + + def get_value(self): + return self.avg + + +def denormalize(x): + """Reverses the imagenet normalization applied to the input. + + Args: + x (torch.Tensor - shape(N,3,H,W)): input tensor + + Returns: + torch.Tensor - shape(N,3,H,W): Denormalized input + """ + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device) + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device) + return x * std + mean + + +class RunningAverageDict: + """A dictionary of running averages.""" + def __init__(self): + self._dict = None + + def update(self, new_dict): + if new_dict is None: + return + + if self._dict is None: + self._dict = dict() + for key, value in new_dict.items(): + self._dict[key] = RunningAverage() + + for key, value in new_dict.items(): + self._dict[key].append(value) + + def get_value(self): + if self._dict is None: + return None + return {key: value.get_value() for key, value in self._dict.items()} + + +def colorize(value, vmin=None, vmax=None, cmap='gray_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None): + """Converts a depth map to a color image. + + Args: + value (torch.Tensor, numpy.ndarry): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed + vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None. + vmax (float, optional): vmax-valued entries are mapped to end color of cmap. If None, value.max() is used. Defaults to None. + cmap (str, optional): matplotlib colormap to use. Defaults to 'magma_r'. + invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99. + invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None. + background_color (tuple[int], optional): 4-tuple RGB color to give to invalid pixels. Defaults to (128, 128, 128, 255). + gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False. + value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None. + + Returns: + numpy.ndarray, dtype - uint8: Colored depth map. Shape: (H, W, 4) + """ + if isinstance(value, torch.Tensor): + value = value.detach().cpu().numpy() + + value = value.squeeze() + if invalid_mask is None: + invalid_mask = value == invalid_val + mask = np.logical_not(invalid_mask) + + # normalize + vmin = np.percentile(value[mask],2) if vmin is None else vmin + vmax = np.percentile(value[mask],85) if vmax is None else vmax + if vmin != vmax: + value = (value - vmin) / (vmax - vmin) # vmin..vmax + else: + # Avoid 0-division + value = value * 0. + + # squeeze last dim if it exists + # grey out the invalid values + + value[invalid_mask] = np.nan + cmapper = matplotlib.cm.get_cmap(cmap) + if value_transform: + value = value_transform(value) + # value = value / value.max() + value = cmapper(value, bytes=True) # (nxmx4) + + # img = value[:, :, :] + img = value[...] + img[invalid_mask] = background_color + + # return img.transpose((2, 0, 1)) + if gamma_corrected: + # gamma correction + img = img / 255 + img = np.power(img, 2.2) + img = img * 255 + img = img.astype(np.uint8) + return img + + +def count_parameters(model, include_all=False): + return sum(p.numel() for p in model.parameters() if p.requires_grad or include_all) + + +def compute_errors(gt, pred): + """Compute metrics for 'pred' compared to 'gt' + + Args: + gt (numpy.ndarray): Ground truth values + pred (numpy.ndarray): Predicted values + + gt.shape should be equal to pred.shape + + Returns: + dict: Dictionary containing the following metrics: + 'a1': Delta1 accuracy: Fraction of pixels that are within a scale factor of 1.25 + 'a2': Delta2 accuracy: Fraction of pixels that are within a scale factor of 1.25^2 + 'a3': Delta3 accuracy: Fraction of pixels that are within a scale factor of 1.25^3 + 'abs_rel': Absolute relative error + 'rmse': Root mean squared error + 'log_10': Absolute log10 error + 'sq_rel': Squared relative error + 'rmse_log': Root mean squared error on the log scale + 'silog': Scale invariant log error + """ + thresh = np.maximum((gt / pred), (pred / gt)) + a1 = (thresh < 1.25).mean() + a2 = (thresh < 1.25 ** 2).mean() + a3 = (thresh < 1.25 ** 3).mean() + + abs_rel = np.mean(np.abs(gt - pred) / gt) + sq_rel = np.mean(((gt - pred) ** 2) / gt) + + rmse = (gt - pred) ** 2 + rmse = np.sqrt(rmse.mean()) + + rmse_log = (np.log(gt) - np.log(pred)) ** 2 + rmse_log = np.sqrt(rmse_log.mean()) + + err = np.log(pred) - np.log(gt) + silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100 + + log_10 = (np.abs(np.log10(gt) - np.log10(pred))).mean() + return dict(a1=a1, a2=a2, a3=a3, abs_rel=abs_rel, rmse=rmse, log_10=log_10, rmse_log=rmse_log, + silog=silog, sq_rel=sq_rel) + + +def compute_metrics(gt, pred, interpolate=True, garg_crop=False, eigen_crop=True, dataset='nyu', min_depth_eval=0.1, max_depth_eval=10, **kwargs): + """Compute metrics of predicted depth maps. Applies cropping and masking as necessary or specified via arguments. Refer to compute_errors for more details on metrics. + """ + if 'config' in kwargs: + config = kwargs['config'] + garg_crop = config.garg_crop + eigen_crop = config.eigen_crop + min_depth_eval = config.min_depth_eval + max_depth_eval = config.max_depth_eval + + if gt.shape[-2:] != pred.shape[-2:] and interpolate: + pred = nn.functional.interpolate( + pred, gt.shape[-2:], mode='bilinear', align_corners=True) + + pred = pred.squeeze().cpu().numpy() + pred[pred < min_depth_eval] = min_depth_eval + pred[pred > max_depth_eval] = max_depth_eval + pred[np.isinf(pred)] = max_depth_eval + pred[np.isnan(pred)] = min_depth_eval + + gt_depth = gt.squeeze().cpu().numpy() + valid_mask = np.logical_and( + gt_depth > min_depth_eval, gt_depth < max_depth_eval) + + if garg_crop or eigen_crop: + gt_height, gt_width = gt_depth.shape + eval_mask = np.zeros(valid_mask.shape) + + if garg_crop: + eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), + int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1 + + elif eigen_crop: + # print("-"*10, " EIGEN CROP ", "-"*10) + if dataset == 'kitti': + eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), + int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1 + else: + # assert gt_depth.shape == (480, 640), "Error: Eigen crop is currently only valid for (480, 640) images" + eval_mask[45:471, 41:601] = 1 + else: + eval_mask = np.ones(valid_mask.shape) + valid_mask = np.logical_and(valid_mask, eval_mask) + return compute_errors(gt_depth[valid_mask], pred[valid_mask]) + + +#################################### Model uilts ################################################ + + +def parallelize(config, model, find_unused_parameters=True): + + if config.gpu is not None: + torch.cuda.set_device(config.gpu) + model = model.cuda(config.gpu) + + config.multigpu = False + if config.distributed: + # Use DDP + config.multigpu = True + config.rank = config.rank * config.ngpus_per_node + config.gpu + dist.init_process_group(backend=config.dist_backend, init_method=config.dist_url, + world_size=config.world_size, rank=config.rank) + config.batch_size = int(config.batch_size / config.ngpus_per_node) + # config.batch_size = 8 + config.workers = int( + (config.num_workers + config.ngpus_per_node - 1) / config.ngpus_per_node) + print("Device", config.gpu, "Rank", config.rank, "batch size", + config.batch_size, "Workers", config.workers) + torch.cuda.set_device(config.gpu) + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = model.cuda(config.gpu) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.gpu], output_device=config.gpu, + find_unused_parameters=find_unused_parameters) + + elif config.gpu is None: + # Use DP + config.multigpu = True + model = model.cuda() + model = torch.nn.DataParallel(model) + + return model + + +################################################################################################# + + +##################################################################################################### + + +class colors: + '''Colors class: + Reset all colors with colors.reset + Two subclasses fg for foreground and bg for background. + Use as colors.subclass.colorname. + i.e. colors.fg.red or colors.bg.green + Also, the generic bold, disable, underline, reverse, strikethrough, + and invisible work with the main class + i.e. colors.bold + ''' + reset = '\033[0m' + bold = '\033[01m' + disable = '\033[02m' + underline = '\033[04m' + reverse = '\033[07m' + strikethrough = '\033[09m' + invisible = '\033[08m' + + class fg: + black = '\033[30m' + red = '\033[31m' + green = '\033[32m' + orange = '\033[33m' + blue = '\033[34m' + purple = '\033[35m' + cyan = '\033[36m' + lightgrey = '\033[37m' + darkgrey = '\033[90m' + lightred = '\033[91m' + lightgreen = '\033[92m' + yellow = '\033[93m' + lightblue = '\033[94m' + pink = '\033[95m' + lightcyan = '\033[96m' + + class bg: + black = '\033[40m' + red = '\033[41m' + green = '\033[42m' + orange = '\033[43m' + blue = '\033[44m' + purple = '\033[45m' + cyan = '\033[46m' + lightgrey = '\033[47m' + + +def printc(text, color): + print(f"{color}{text}{colors.reset}") + +############################################ + +def get_image_from_url(url): + response = requests.get(url) + img = Image.open(BytesIO(response.content)).convert("RGB") + return img + +def url_to_torch(url, size=(384, 384)): + img = get_image_from_url(url) + img = img.resize(size, Image.ANTIALIAS) + img = torch.from_numpy(np.asarray(img)).float() + img = img.permute(2, 0, 1) + img.div_(255) + return img + +def pil_to_batched_tensor(img): + return ToTensor()(img).unsqueeze(0) + +def save_raw_16bit(depth, fpath="raw.png"): + if isinstance(depth, torch.Tensor): + depth = depth.squeeze().cpu().numpy() + + assert isinstance(depth, np.ndarray), "Depth must be a torch tensor or numpy array" + assert depth.ndim == 2, "Depth must be 2D" + depth = depth * 256 # scale for 16-bit png + depth = depth.astype(np.uint16) + depth = Image.fromarray(depth) + depth.save(fpath) + print("Saved raw depth to", fpath) \ No newline at end of file diff --git a/third_party/src/flux_ch/api.py b/third_party/src/flux_ch/api.py new file mode 100644 index 0000000000000000000000000000000000000000..b08202adb35d2ffae320bb9b47f567e538837836 --- /dev/null +++ b/third_party/src/flux_ch/api.py @@ -0,0 +1,194 @@ +import io +import os +import time +from pathlib import Path + +import requests +from PIL import Image + +API_ENDPOINT = "https://api.bfl.ml" + + +class ApiException(Exception): + def __init__(self, status_code: int, detail: str | list[dict] | None = None): + super().__init__() + self.detail = detail + self.status_code = status_code + + def __str__(self) -> str: + return self.__repr__() + + def __repr__(self) -> str: + if self.detail is None: + message = None + elif isinstance(self.detail, str): + message = self.detail + else: + message = "[" + ",".join(d["msg"] for d in self.detail) + "]" + return f"ApiException({self.status_code=}, {message=}, detail={self.detail})" + + +class ImageRequest: + def __init__( + self, + prompt: str, + width: int = 1024, + height: int = 1024, + name: str = "flux.1-pro", + num_steps: int = 50, + prompt_upsampling: bool = False, + seed: int | None = None, + validate: bool = True, + launch: bool = True, + api_key: str | None = None, + ): + """ + Manages an image generation request to the API. + + Args: + prompt: Prompt to sample + width: Width of the image in pixel + height: Height of the image in pixel + name: Name of the model + num_steps: Number of network evaluations + prompt_upsampling: Use prompt upsampling + seed: Fix the generation seed + validate: Run input validation + launch: Directly launches request + api_key: Your API key if not provided by the environment + + Raises: + ValueError: For invalid input + ApiException: For errors raised from the API + """ + if validate: + if name not in ["flux.1-pro"]: + raise ValueError(f"Invalid model {name}") + elif width % 32 != 0: + raise ValueError(f"width must be divisible by 32, got {width}") + elif not (256 <= width <= 1440): + raise ValueError(f"width must be between 256 and 1440, got {width}") + elif height % 32 != 0: + raise ValueError(f"height must be divisible by 32, got {height}") + elif not (256 <= height <= 1440): + raise ValueError(f"height must be between 256 and 1440, got {height}") + elif not (1 <= num_steps <= 50): + raise ValueError(f"steps must be between 1 and 50, got {num_steps}") + + self.request_json = { + "prompt": prompt, + "width": width, + "height": height, + "variant": name, + "steps": num_steps, + "prompt_upsampling": prompt_upsampling, + } + if seed is not None: + self.request_json["seed"] = seed + + self.request_id: str | None = None + self.result: dict | None = None + self._image_bytes: bytes | None = None + self._url: str | None = None + if api_key is None: + self.api_key = os.environ.get("BFL_API_KEY") + else: + self.api_key = api_key + + if launch: + self.request() + + def request(self): + """ + Request to generate the image. + """ + if self.request_id is not None: + return + response = requests.post( + f"{API_ENDPOINT}/v1/image", + headers={ + "accept": "application/json", + "x-key": self.api_key, + "Content-Type": "application/json", + }, + json=self.request_json, + ) + result = response.json() + if response.status_code != 200: + raise ApiException(status_code=response.status_code, detail=result.get("detail")) + self.request_id = response.json()["id"] + + def retrieve(self) -> dict: + """ + Wait for the generation to finish and retrieve response. + """ + if self.request_id is None: + self.request() + while self.result is None: + response = requests.get( + f"{API_ENDPOINT}/v1/get_result", + headers={ + "accept": "application/json", + "x-key": self.api_key, + }, + params={ + "id": self.request_id, + }, + ) + result = response.json() + if "status" not in result: + raise ApiException(status_code=response.status_code, detail=result.get("detail")) + elif result["status"] == "Ready": + self.result = result["result"] + elif result["status"] == "Pending": + time.sleep(0.5) + else: + raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'") + return self.result + + @property + def bytes(self) -> bytes: + """ + Generated image as bytes. + """ + if self._image_bytes is None: + response = requests.get(self.url) + if response.status_code == 200: + self._image_bytes = response.content + else: + raise ApiException(status_code=response.status_code) + return self._image_bytes + + @property + def url(self) -> str: + """ + Public url to retrieve the image from + """ + if self._url is None: + result = self.retrieve() + self._url = result["sample"] + return self._url + + @property + def image(self) -> Image.Image: + """ + Load the image as a PIL Image + """ + return Image.open(io.BytesIO(self.bytes)) + + def save(self, path: str): + """ + Save the generated image to a local path + """ + suffix = Path(self.url).suffix + if not path.endswith(suffix): + path = path + suffix + Path(path).resolve().parent.mkdir(parents=True, exist_ok=True) + with open(path, "wb") as file: + file.write(self.bytes) + + +if __name__ == "__main__": + from fire import Fire + + Fire(ImageRequest) diff --git a/third_party/src/flux_ch/cli.py b/third_party/src/flux_ch/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..f3624bc6c387f359162e68f46995b12ce341970a --- /dev/null +++ b/third_party/src/flux_ch/cli.py @@ -0,0 +1,254 @@ +import os +import re +import time +from dataclasses import dataclass +from glob import iglob + +import torch +from einops import rearrange +from fire import Fire +from PIL import ExifTags, Image + +from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack +from flux.util import (configs, embed_watermark, load_ae, load_clip, + load_flow_model, load_t5) +from transformers import pipeline + +NSFW_THRESHOLD = 0.85 + +@dataclass +class SamplingOptions: + prompt: str + width: int + height: int + num_steps: int + guidance: float + seed: int | None + + +def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: + user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n" + usage = ( + "Usage: Either write your prompt directly, leave this field empty " + "to repeat the prompt or write a command starting with a slash:\n" + "- '/w ' will set the width of the generated image\n" + "- '/h ' will set the height of the generated image\n" + "- '/s ' sets the next seed\n" + "- '/g ' sets the guidance (flux-dev only)\n" + "- '/n ' sets the number of steps\n" + "- '/q' to quit" + ) + + while (prompt := input(user_question)).startswith("/"): + if prompt.startswith("/w"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, width = prompt.split() + options.width = 16 * (int(width) // 16) + print( + f"Setting resolution to {options.width} x {options.height} " + f"({options.height *options.width/1e6:.2f}MP)" + ) + elif prompt.startswith("/h"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, height = prompt.split() + options.height = 16 * (int(height) // 16) + print( + f"Setting resolution to {options.width} x {options.height} " + f"({options.height *options.width/1e6:.2f}MP)" + ) + elif prompt.startswith("/g"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, guidance = prompt.split() + options.guidance = float(guidance) + print(f"Setting guidance to {options.guidance}") + elif prompt.startswith("/s"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, seed = prompt.split() + options.seed = int(seed) + print(f"Setting seed to {options.seed}") + elif prompt.startswith("/n"): + if prompt.count(" ") != 1: + print(f"Got invalid command '{prompt}'\n{usage}") + continue + _, steps = prompt.split() + options.num_steps = int(steps) + print(f"Setting seed to {options.num_steps}") + elif prompt.startswith("/q"): + print("Quitting") + return None + else: + if not prompt.startswith("/h"): + print(f"Got invalid command '{prompt}'\n{usage}") + print(usage) + if prompt != "": + options.prompt = prompt + return options + + +@torch.inference_mode() +def main( + name: str = "flux-schnell", + width: int = 1360, + height: int = 768, + seed: int | None = None, + prompt: str = ( + "a photo of a forest with mist swirling around the tree trunks. The word " + '"FLUX" is painted over it in big, red brush strokes with visible texture' + ), + device: str = "cuda" if torch.cuda.is_available() else "cpu", + num_steps: int | None = None, + loop: bool = False, + guidance: float = 3.5, + offload: bool = False, + output_dir: str = "output", + add_sampling_metadata: bool = True, +): + """ + Sample the flux model. Either interactively (set `--loop`) or run for a + single image. + + Args: + name: Name of the model to load + height: height of the sample in pixels (should be a multiple of 16) + width: width of the sample in pixels (should be a multiple of 16) + seed: Set a seed for sampling + output_name: where to save the output image, `{idx}` will be replaced + by the index of the sample + prompt: Prompt used for sampling + device: Pytorch device + num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) + loop: start an interactive session and sample multiple times + guidance: guidance value used for guidance distillation + add_sampling_metadata: Add the prompt to the image Exif metadata + """ + nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection") + + if name not in configs: + available = ", ".join(configs.keys()) + raise ValueError(f"Got unknown model name: {name}, chose from {available}") + + torch_device = torch.device(device) + if num_steps is None: + num_steps = 4 if name == "flux-schnell" else 50 + + # allow for packing and conversion to latent space + height = 16 * (height // 16) + width = 16 * (width // 16) + + output_name = os.path.join(output_dir, "img_{idx}.jpg") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + idx = 0 + else: + fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]\.jpg$", fn)] + if len(fns) > 0: + idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 + else: + idx = 0 + + # init all components + t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512) + clip = load_clip(torch_device) + model = load_flow_model(name, device="cpu" if offload else torch_device) + ae = load_ae(name, device="cpu" if offload else torch_device) + + rng = torch.Generator(device="cpu") + opts = SamplingOptions( + prompt=prompt, + width=width, + height=height, + num_steps=num_steps, + guidance=guidance, + seed=seed, + ) + + if loop: + opts = parse_prompt(opts) + + while opts is not None: + if opts.seed is None: + opts.seed = rng.seed() + print(f"Generating with seed {opts.seed}:\n{opts.prompt}") + t0 = time.perf_counter() + + # prepare input + x = get_noise( + 1, + opts.height, + opts.width, + device=torch_device, + dtype=torch.bfloat16, + seed=opts.seed, + ) + opts.seed = None + if offload: + ae = ae.cpu() + torch.cuda.empty_cache() + t5, clip = t5.to(torch_device), clip.to(torch_device) + inp = prepare(t5, clip, x, prompt=opts.prompt) + timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) + + # offload TEs to CPU, load model to gpu + if offload: + t5, clip = t5.cpu(), clip.cpu() + torch.cuda.empty_cache() + model = model.to(torch_device) + + # denoise initial noise + x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) + + # offload model, load autoencoder to gpu + if offload: + model.cpu() + torch.cuda.empty_cache() + ae.decoder.to(x.device) + + # decode latents to pixel space + x = unpack(x.float(), opts.height, opts.width) + with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): + x = ae.decode(x) + t1 = time.perf_counter() + + fn = output_name.format(idx=idx) + print(f"Done in {t1 - t0:.1f}s. Saving {fn}") + # bring into PIL format and save + x = x.clamp(-1, 1) + x = embed_watermark(x.float()) + x = rearrange(x[0], "c h w -> h w c") + + img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) + nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0] + + if nsfw_score < NSFW_THRESHOLD: + exif_data = Image.Exif() + exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux" + exif_data[ExifTags.Base.Make] = "Black Forest Labs" + exif_data[ExifTags.Base.Model] = name + if add_sampling_metadata: + exif_data[ExifTags.Base.ImageDescription] = prompt + img.save(fn, exif=exif_data, quality=95, subsampling=0) + idx += 1 + else: + print("Your generated image may contain NSFW content.") + + if loop: + print("-" * 80) + opts = parse_prompt(opts) + else: + opts = None + + +def app(): + Fire(main) + + +if __name__ == "__main__": + app() diff --git a/third_party/src/flux_ch/controlnet.py b/third_party/src/flux_ch/controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c5a04cc0234b2b726a550cbe62d027943f6bbcbb --- /dev/null +++ b/third_party/src/flux_ch/controlnet.py @@ -0,0 +1,222 @@ +from dataclasses import dataclass + +import torch +from torch import Tensor, nn +from einops import rearrange + +from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, + MLPEmbedder, SingleStreamBlock, + timestep_embedding) + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class ControlNetFlux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + _supports_gradient_checkpointing = True + + def __init__(self, params: FluxParams, controlnet_depth=2): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(controlnet_depth) + ] + ) + + # add ControlNet blocks + self.controlnet_blocks = nn.ModuleList([]) + for _ in range(controlnet_depth): + controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks.append(controlnet_block) + self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.gradient_checkpointing = False + self.input_hint_block = nn.Sequential( + nn.Conv2d(3, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + zero_module(nn.Conv2d(16, 16, 3, padding=1)) + ) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + + @property + def attn_processors(self): + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + controlnet_cond: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + controlnet_cond = self.input_hint_block(controlnet_cond) + controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_cond = self.pos_embed_input(controlnet_cond) + img = img + controlnet_cond + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + block_res_samples = () + + for block in self.double_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + img, + txt, + vec, + pe, + ) + else: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + block_res_samples = block_res_samples + (img,) + + controlnet_block_res_samples = () + for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks): + block_res_sample = controlnet_block(block_res_sample) + controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) + + return controlnet_block_res_samples diff --git a/third_party/src/flux_ch/math.py b/third_party/src/flux_ch/math.py new file mode 100644 index 0000000000000000000000000000000000000000..0156bb6a205dec340e029f0c87cf70ae8709ae12 --- /dev/null +++ b/third_party/src/flux_ch/math.py @@ -0,0 +1,30 @@ +import torch +from einops import rearrange +from torch import Tensor + + +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: + q, k = apply_rope(q, k, pe) + + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "B H L D -> B L (H D)") + + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) diff --git a/third_party/src/flux_ch/model.py b/third_party/src/flux_ch/model.py new file mode 100644 index 0000000000000000000000000000000000000000..51531c114babcea3b7a365ca44ee458bfce9a673 --- /dev/null +++ b/third_party/src/flux_ch/model.py @@ -0,0 +1,228 @@ +from dataclasses import dataclass + +import torch +from torch import Tensor, nn +from einops import rearrange + +from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, + MLPEmbedder, SingleStreamBlock, + timestep_embedding) + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + + +class Flux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + _supports_gradient_checkpointing = True + + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + @property + def attn_processors(self): + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + block_controlnet_hidden_states=None, + guidance: Tensor | None = None, + image_proj: Tensor | None = None, + ip_scale: Tensor | float = 1.0, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + if block_controlnet_hidden_states is not None: + controlnet_depth = len(block_controlnet_hidden_states) + for index_block, block in enumerate(self.double_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + img, + txt, + vec, + pe, + image_proj, + ip_scale, + ) + else: + img, txt = block( + img=img, + txt=txt, + vec=vec, + pe=pe, + image_proj=image_proj, + ip_scale=ip_scale, + ) + # controlnet residual + if block_controlnet_hidden_states is not None: + img = img + block_controlnet_hidden_states[index_block % 2] + + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + img, + vec, + pe, + ) + else: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img diff --git a/third_party/src/flux_ch/modules/__pycache__/autoencoder.cpython-310.pyc b/third_party/src/flux_ch/modules/__pycache__/autoencoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..370fbac5fdea1dd8b795baaeba6601e237dd11bf Binary files /dev/null and b/third_party/src/flux_ch/modules/__pycache__/autoencoder.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/modules/__pycache__/conditioner.cpython-310.pyc b/third_party/src/flux_ch/modules/__pycache__/conditioner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad24761b7006b426a2877376de387420bf3386fd Binary files /dev/null and b/third_party/src/flux_ch/modules/__pycache__/conditioner.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/modules/__pycache__/layers.cpython-310.pyc b/third_party/src/flux_ch/modules/__pycache__/layers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b29a1e637bc09c92f1d3207020cc66ba13d7cfb Binary files /dev/null and b/third_party/src/flux_ch/modules/__pycache__/layers.cpython-310.pyc differ diff --git a/third_party/src/flux_ch/modules/autoencoder.py b/third_party/src/flux_ch/modules/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..75159f711f65f064107a1a1b9be6f09fc9872028 --- /dev/null +++ b/third_party/src/flux_ch/modules/autoencoder.py @@ -0,0 +1,312 @@ +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + ch: int + out_ch: int + ch_mult: list[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) diff --git a/third_party/src/flux_ch/modules/conditioner.py b/third_party/src/flux_ch/modules/conditioner.py new file mode 100644 index 0000000000000000000000000000000000000000..7cdd881878ace848745da7d723c60f03392916ab --- /dev/null +++ b/third_party/src/flux_ch/modules/conditioner.py @@ -0,0 +1,38 @@ +from torch import Tensor, nn +from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel, + T5Tokenizer) + + +class HFEmbedder(nn.Module): + def __init__(self, version: str, max_length: int, **hf_kwargs): + super().__init__() + self.is_clip = version.startswith("openai") + self.max_length = max_length + self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" + + if self.is_clip: + self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length) + self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs) + else: + self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length) + self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs) + + self.hf_module = self.hf_module.eval().requires_grad_(False) + + def forward(self, text: list[str]) -> Tensor: + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + + outputs = self.hf_module( + input_ids=batch_encoding["input_ids"].to(self.hf_module.device), + attention_mask=None, + output_hidden_states=False, + ) + return outputs[self.output_key] diff --git a/third_party/src/flux_ch/modules/layers.py b/third_party/src/flux_ch/modules/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..c5489698671c6ed32dcb790a2f83d682d898b872 --- /dev/null +++ b/third_party/src/flux_ch/modules/layers.py @@ -0,0 +1,595 @@ +import math +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn + +from ..math import attention, rope +import torch.nn.functional as F + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + t.device + ) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return (x * rrms).to(dtype=x_dtype) * self.scale + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) + self.key_norm = RMSNorm(dim) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + +class LoRALinearLayer(nn.Module): + def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): + super().__init__() + + self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) + self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. + # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + self.network_alpha = network_alpha + self.rank = rank + + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + + return up_hidden_states.to(orig_dtype) + +class FLuxSelfAttnProcessor: + def __call__(self, attn, x, pe, **attention_kwargs): + print('2' * 30) + + qkv = attn.qkv(x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = attn.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = attn.proj(x) + return x + +class LoraFluxAttnProcessor(nn.Module): + + def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1): + super().__init__() + self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha) + self.lora_weight = lora_weight + + + def __call__(self, attn, x, pe, **attention_kwargs): + qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + q, k = attn.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = attn.proj(x) + self.proj_lora(x) * self.lora_weight + print('1' * 30) + print(x.norm(), (self.proj_lora(x) * self.lora_weight).norm(), 'norm') + return x + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + def forward(): + pass + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + +class DoubleStreamBlockLoraProcessor(nn.Module): + def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1): + super().__init__() + self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha) + self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha) + self.lora_weight = lora_weight + + def forward(self, attn, img, txt, vec, pe, **attention_kwargs): + img_mod1, img_mod2 = attn.img_mod(vec) + txt_mod1, txt_mod2 = attn.txt_mod(vec) + + # prepare image for attention + img_modulated = attn.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = attn.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn1 = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] + + # calculate the img bloks + img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img_mod1.gate * self.proj_lora1(img_attn) * self.lora_weight + img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt_mod1.gate * self.proj_lora2(txt_attn) * self.lora_weight + txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) + return img, txt + +class IPDoubleStreamBlockProcessor(nn.Module): + """Attention processor for handling IP-adapter with double stream block.""" + + def __init__(self, context_dim, hidden_dim): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "IPDoubleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch." + ) + + # Ensure context_dim matches the dimension of image_proj + self.context_dim = context_dim + self.hidden_dim = hidden_dim + + # Initialize projections for IP-adapter + self.ip_adapter_double_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=True) + self.ip_adapter_double_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=True) + + nn.init.zeros_(self.ip_adapter_double_stream_k_proj.weight) + nn.init.zeros_(self.ip_adapter_double_stream_k_proj.bias) + + nn.init.zeros_(self.ip_adapter_double_stream_v_proj.weight) + nn.init.zeros_(self.ip_adapter_double_stream_v_proj.bias) + + def __call__(self, attn, img, txt, vec, pe, image_proj, ip_scale=1.0, **attention_kwargs): + + # Prepare image for attention + img_mod1, img_mod2 = attn.img_mod(vec) + txt_mod1, txt_mod2 = attn.txt_mod(vec) + + img_modulated = attn.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = attn.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) + img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) + + txt_modulated = attn.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = attn.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) + txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) + + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn1 = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn1[:, :txt.shape[1]], attn1[:, txt.shape[1]:] + + # print(f"txt_attn shape: {txt_attn.size()}") + # print(f"img_attn shape: {img_attn.size()}") + + img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) + + txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) + + + # IP-adapter processing + ip_query = img_q # latent sample query + ip_key = self.ip_adapter_double_stream_k_proj(image_proj) + ip_value = self.ip_adapter_double_stream_v_proj(image_proj) + + # Reshape projections for multi-head attention + ip_key = rearrange(ip_key, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim) + ip_value = rearrange(ip_value, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim) + + # Compute attention between IP projections and the latent query + ip_attention = F.scaled_dot_product_attention( + ip_query, + ip_key, + ip_value, + dropout_p=0.0, + is_causal=False + ) + ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)", H=attn.num_heads, D=attn.head_dim) + + img = img + ip_scale * ip_attention + + return img, txt + +class DoubleStreamBlockProcessor: + def __call__(self, attn, img, txt, vec, pe, **attention_kwargs): + img_mod1, img_mod2 = attn.img_mod(vec) + txt_mod1, txt_mod2 = attn.txt_mod(vec) + + # prepare image for attention + img_modulated = attn.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = attn.img_attn.qkv(img_modulated) + img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) + img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = attn.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = attn.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) + txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn1 = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :] + + # calculate the img bloks + img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift) + return img, txt + +class DoubleStreamBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): + super().__init__() + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.head_dim = hidden_size // num_heads + + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = Modulation(hidden_size, double=True) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + processor = DoubleStreamBlockProcessor() + self.set_processor(processor) + + def set_processor(self, processor) -> None: + self.processor = processor + + def get_processor(self): + return self.processor + + def forward( + self, + img: Tensor, + txt: Tensor, + vec: Tensor, + pe: Tensor, + image_proj: Tensor = None, + ip_scale: float =1.0, + ) -> tuple[Tensor, Tensor]: + if image_proj is None: + return self.processor(self, img, txt, vec, pe) + else: + return self.processor(self, img, txt, vec, pe, image_proj, ip_scale) + +class IPSingleStreamBlockProcessor(nn.Module): + """Attention processor for handling IP-adapter with single stream block.""" + def __init__(self, context_dim, hidden_dim): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "IPSingleStreamBlockProcessor requires PyTorch 2.0 or higher. Please upgrade PyTorch." + ) + + # Ensure context_dim matches the dimension of image_proj + self.context_dim = context_dim + self.hidden_dim = hidden_dim + + # Initialize projections for IP-adapter + self.ip_adapter_single_stream_k_proj = nn.Linear(context_dim, hidden_dim, bias=False) + self.ip_adapter_single_stream_v_proj = nn.Linear(context_dim, hidden_dim, bias=False) + + nn.init.zeros_(self.ip_adapter_single_stream_k_proj.weight) + nn.init.zeros_(self.ip_adapter_single_stream_v_proj.weight) + + def __call__( + self, + attn: nn.Module, + x: Tensor, + vec: Tensor, + pe: Tensor, + image_proj: Tensor | None = None, + ip_scale: float = 1.0 + ) -> Tensor: + + mod, _ = attn.modulation(vec) + x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift + qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim) + q, k = attn.norm(q, k, v) + + # compute attention + attn_1 = attention(q, k, v, pe=pe) + + # IP-adapter processing + ip_query = q + ip_key = self.ip_adapter_single_stream_k_proj(image_proj) + ip_value = self.ip_adapter_single_stream_v_proj(image_proj) + + # Reshape projections for multi-head attention + ip_key = rearrange(ip_key, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim) + ip_value = rearrange(ip_value, 'B L (H D) -> B H L D', H=attn.num_heads, D=attn.head_dim) + + + # Compute attention between IP projections and the latent query + ip_attention = F.scaled_dot_product_attention( + ip_query, + ip_key, + ip_value + ) + ip_attention = rearrange(ip_attention, "B H L D -> B L (H D)") + + attn_out = attn_1 + ip_scale * ip_attention + + # compute activation in mlp stream, cat again and run second linear layer + output = attn.linear2(torch.cat((attn_out, attn.mlp_act(mlp)), 2)) + out = x + mod.gate * output + + return out + + +class SingleStreamBlockLoraProcessor(nn.Module): + def __init__(self, dim: int, rank: int = 4, network_alpha = None, lora_weight: float = 1): + super().__init__() + self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha) + self.proj_lora = LoRALinearLayer(15360, dim, rank, network_alpha) + self.lora_weight = lora_weight + + def forward(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + + mod, _ = attn.modulation(vec) + x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift + qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1) + qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + q, k = attn.norm(q, k, v) + + # compute attention + attn_1 = attention(q, k, v, pe=pe) + + # compute activation in mlp stream, cat again and run second linear layer + output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) + output = output + self.proj_lora(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) * self.lora_weight + output = x + mod.gate * output + return output + + +class SingleStreamBlockProcessor: + def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + + mod, _ = attn.modulation(vec) + x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift + qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1) + + q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads) + q, k = attn.norm(q, k, v) + + # compute attention + attn_1 = attention(q, k, v, pe=pe) + + # compute activation in mlp stream, cat again and run second linear layer + output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) + output = x + mod.gate * output + return output + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scale = qk_scale or self.head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(self.head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = Modulation(hidden_size, double=False) + + processor = SingleStreamBlockProcessor() + self.set_processor(processor) + + + def set_processor(self, processor) -> None: + self.processor = processor + + def get_processor(self): + return self.processor + + def forward( + self, + x: Tensor, + vec: Tensor, + pe: Tensor, + image_proj: Tensor | None = None, + ip_scale: float = 1.0 + ) -> Tensor: + if image_proj is None: + return self.processor(self, x, vec, pe) + else: + return self.processor(self, x, vec, pe, image_proj, ip_scale) + + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x + +class ImageProjModel(torch.nn.Module): + """Projection Model + https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter.py#L28 + """ + + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): + super().__init__() + + self.generator = None + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + embeds = image_embeds + clip_extra_context_tokens = self.proj(embeds).reshape( + -1, self.clip_extra_context_tokens, self.cross_attention_dim + ) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens + diff --git a/third_party/src/flux_ch/sampling.py b/third_party/src/flux_ch/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..216a544b1e2bf9316a91b1fabee67dff4c88d89a --- /dev/null +++ b/third_party/src/flux_ch/sampling.py @@ -0,0 +1,247 @@ +import math +from typing import Callable + +import torch +from einops import rearrange, repeat +from torch import Tensor + +from .model import Flux +from .modules.conditioner import HFEmbedder + + +def get_noise( + num_samples: int, + height: int, + width: int, + device: torch.device, + dtype: torch.dtype, + seed: int, +): + return torch.randn( + num_samples, + 16, + # allow for packing + 2 * math.ceil(height / 16), + 2 * math.ceil(width / 16), + device=device, + dtype=dtype, + generator=torch.Generator(device=device).manual_seed(seed), + ) + + +def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]: + bs, c, h, w = img.shape + if bs == 1 and not isinstance(prompt, str): + bs = len(prompt) + + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img.shape[0] == 1 and bs > 1: + img = repeat(img, "1 ... -> bs ...", bs=bs) + + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + if isinstance(prompt, str): + prompt = [prompt] + txt = t5(prompt) + if txt.shape[0] == 1 and bs > 1: + txt = repeat(txt, "1 ... -> bs ...", bs=bs) + txt_ids = torch.zeros(bs, txt.shape[1], 3) + + vec = clip(prompt) + if vec.shape[0] == 1 and bs > 1: + vec = repeat(vec, "1 ... -> bs ...", bs=bs) + + return { + "img": img, + "img_ids": img_ids.to(img.device), + "txt": txt.to(img.device), + "txt_ids": txt_ids.to(img.device), + "vec": vec.to(img.device), + } + + +def time_shift(mu: float, sigma: float, t: Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function( + x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 +) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model: Flux, + # model input + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + vec: Tensor, + timesteps: list[float], + neg_txt: Tensor = None, + neg_txt_ids: Tensor = None, + neg_vec: Tensor = None, + # sampling parameters + + guidance: float = 4.0, + true_gs = 1, + timestep_to_start_cfg=0, + # ip-adapter parameters + image_proj: Tensor=None, + neg_image_proj: Tensor=None, + ip_scale: Tensor | float = 1.0, + neg_ip_scale: Tensor | float = 1.0, + +): + + i = 0 + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + image_proj=image_proj, + ip_scale=ip_scale, + ) + if i >= timestep_to_start_cfg: + neg_pred = model( + img=img, + img_ids=img_ids, + txt=neg_txt, + txt_ids=neg_txt_ids, + y=neg_vec, + timesteps=t_vec, + guidance=guidance_vec, + image_proj=neg_image_proj, + ip_scale=neg_ip_scale, + ) + pred = neg_pred + true_gs * (pred - neg_pred) + img = img + (t_prev - t_curr) * pred + i += 1 + return img + + + +def denoise_controlnet( + model: Flux, + controlnet:None, + # model input + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + vec: Tensor, + neg_txt: Tensor, + neg_txt_ids: Tensor, + neg_vec: Tensor, + controlnet_cond, + # sampling parameters + timesteps: list[float], + guidance: float = 4.0, + true_gs = 1, + controlnet_gs=0.7, + timestep_to_start_cfg=0, + # ip-adapter parameters + image_proj: Tensor=None, + neg_image_proj: Tensor=None, + ip_scale: Tensor | float = 1, + neg_ip_scale: Tensor | float = 1, +): + # this is ignored for schnell + i = 0 + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + block_res_samples = controlnet( + img=img, + img_ids=img_ids, + controlnet_cond=controlnet_cond, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + ) + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + block_controlnet_hidden_states=[i * controlnet_gs for i in block_res_samples], + image_proj=image_proj, + ip_scale=ip_scale, + ) + if i >= timestep_to_start_cfg: + neg_block_res_samples = controlnet( + img=img, + img_ids=img_ids, + controlnet_cond=controlnet_cond, + txt=neg_txt, + txt_ids=neg_txt_ids, + y=neg_vec, + timesteps=t_vec, + guidance=guidance_vec, + ) + neg_pred = model( + img=img, + img_ids=img_ids, + txt=neg_txt, + txt_ids=neg_txt_ids, + y=neg_vec, + timesteps=t_vec, + guidance=guidance_vec, + block_controlnet_hidden_states=[i * controlnet_gs for i in neg_block_res_samples], + image_proj=neg_image_proj, + ip_scale=neg_ip_scale, + ) + pred = neg_pred + true_gs * (pred - neg_pred) + + img = img + (t_prev - t_curr) * pred + + i += 1 + return img + +def unpack(x: Tensor, height: int, width: int) -> Tensor: + return rearrange( + x, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(height / 16), + w=math.ceil(width / 16), + ph=2, + pw=2, + ) diff --git a/third_party/src/flux_ch/util.py b/third_party/src/flux_ch/util.py new file mode 100644 index 0000000000000000000000000000000000000000..9dd496213f0acc9cd05ddc621d504dd9e87373e9 --- /dev/null +++ b/third_party/src/flux_ch/util.py @@ -0,0 +1,433 @@ +import os +from dataclasses import dataclass + +import torch +import json +import cv2 +import numpy as np +from PIL import Image +from huggingface_hub import hf_hub_download +from safetensors import safe_open +from safetensors.torch import load_file as load_sft + +from optimum.quanto import requantize + +from .model import Flux, FluxParams +from .controlnet import ControlNetFlux +from .modules.autoencoder import AutoEncoder, AutoEncoderParams +from .modules.conditioner import HFEmbedder +from .annotator.dwpose import DWposeDetector +from .annotator.mlsd import MLSDdetector +from .annotator.canny import CannyDetector +from .annotator.midas import MidasDetector +from .annotator.hed import HEDdetector +from .annotator.tile import TileDetector +from .annotator.zoe import ZoeDetector + + +def load_safetensors(path): + tensors = {} + with safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + tensors[key] = f.get_tensor(key) + return tensors + +def get_lora_rank(checkpoint): + for k in checkpoint.keys(): + if k.endswith(".down.weight"): + return checkpoint[k].shape[0] + +def load_checkpoint(local_path, repo_id, name): + if local_path is not None: + if '.safetensors' in local_path: + print(f"Loading .safetensors checkpoint from {local_path}") + checkpoint = load_safetensors(local_path) + else: + print(f"Loading checkpoint from {local_path}") + checkpoint = torch.load(local_path, map_location='cpu') + elif repo_id is not None and name is not None: + print(f"Loading checkpoint {name} from repo id {repo_id}") + checkpoint = load_from_repo_id(repo_id, name) + else: + raise ValueError( + "LOADING ERROR: you must specify local_path or repo_id with name in HF to download" + ) + return checkpoint + + +def c_crop(image): + width, height = image.size + new_size = min(width, height) + left = (width - new_size) / 2 + top = (height - new_size) / 2 + right = (width + new_size) / 2 + bottom = (height + new_size) / 2 + return image.crop((left, top, right, bottom)) + +def pad64(x): + return int(np.ceil(float(x) / 64.0) * 64 - x) + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + +def safer_memory(x): + # Fix many MAC/AMD problems + return np.ascontiguousarray(x.copy()).copy() + +#https://github.com/Mikubill/sd-webui-controlnet/blob/main/scripts/processor.py#L17 +#Added upscale_method, mode params +def resize_image_with_pad(input_image, resolution, skip_hwc3=False, mode='edge'): + if skip_hwc3: + img = input_image + else: + img = HWC3(input_image) + H_raw, W_raw, _ = img.shape + if resolution == 0: + return img, lambda x: x + k = float(resolution) / float(min(H_raw, W_raw)) + H_target = int(np.round(float(H_raw) * k)) + W_target = int(np.round(float(W_raw) * k)) + img = cv2.resize(img, (W_target, H_target), interpolation=cv2.INTER_AREA) + H_pad, W_pad = pad64(H_target), pad64(W_target) + img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode) + + def remove_pad(x): + return safer_memory(x[:H_target, :W_target, ...]) + + return safer_memory(img_padded), remove_pad + +class Annotator: + def __init__(self, name: str, device: str): + if name == "canny": + processor = CannyDetector() + elif name == "openpose": + processor = DWposeDetector(device) + elif name == "depth": + processor = MidasDetector() + elif name == "hed": + processor = HEDdetector() + elif name == "hough": + processor = MLSDdetector() + elif name == "tile": + processor = TileDetector() + elif name == "zoe": + processor = ZoeDetector() + self.name = name + self.processor = processor + + def __call__(self, image: Image, width: int, height: int): + image = np.array(image) + detect_resolution = max(width, height) + image, remove_pad = resize_image_with_pad(image, detect_resolution) + + image = np.array(image) + if self.name == "canny": + result = self.processor(image, low_threshold=100, high_threshold=200) + elif self.name == "hough": + result = self.processor(image, thr_v=0.05, thr_d=5) + elif self.name == "depth": + result = self.processor(image) + result, _ = result + else: + result = self.processor(image) + + result = HWC3(remove_pad(result)) + result = cv2.resize(result, (width, height)) + return result + + +@dataclass +class ModelSpec: + params: FluxParams + ae_params: AutoEncoderParams + ckpt_path: str | None + ae_path: str | None + repo_id: str | None + repo_flow: str | None + repo_ae: str | None + repo_id_ae: str | None + + +configs = { + "flux-dev": ModelSpec( + repo_id="black-forest-labs/FLUX.1-dev", + repo_id_ae="black-forest-labs/FLUX.1-dev", + repo_flow="flux1-dev.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-dev-fp8": ModelSpec( + repo_id="XLabs-AI/flux-dev-fp8", + repo_id_ae="black-forest-labs/FLUX.1-dev", + repo_flow="flux-dev-fp8.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV_FP8"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-schnell": ModelSpec( + repo_id="black-forest-labs/FLUX.1-schnell", + repo_id_ae="black-forest-labs/FLUX.1-dev", + repo_flow="flux1-schnell.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_SCHNELL"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), +} + + +def print_load_warning(missing: list[str], unexpected: list[str]) -> None: + if len(missing) > 0 and len(unexpected) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + print("\n" + "-" * 79 + "\n") + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + elif len(missing) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + elif len(unexpected) > 0: + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + +def load_from_repo_id(repo_id, checkpoint_name): + ckpt_path = hf_hub_download(repo_id, checkpoint_name) + sd = load_sft(ckpt_path, device='cpu') + return sd + +def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True): + # Loading Flux + print("Init model") + ckpt_path = configs[name].ckpt_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) + + with torch.device("meta" if ckpt_path is not None else device): + model = Flux(configs[name].params).to(torch.bfloat16) + + if ckpt_path is not None: + print("Loading checkpoint") + # load_sft doesn't support torch.device + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return model + +def load_flow_model2(name: str, device: str | torch.device = "cuda", hf_download: bool = True): + # Loading Flux + print("Init model") + ckpt_path = configs[name].ckpt_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors")) + + with torch.device("meta" if ckpt_path is not None else device): + model = Flux(configs[name].params) + + if ckpt_path is not None: + print("Loading checkpoint") + # load_sft doesn't support torch.device + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return model + +def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf_download: bool = True): + # Loading Flux + print("Init model") + ckpt_path = configs[name].ckpt_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) + json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json') + + + model = Flux(configs[name].params).to(torch.bfloat16) + + print("Loading checkpoint") + # load_sft doesn't support torch.device + sd = load_sft(ckpt_path, device='cpu') + with open(json_path, "r") as f: + quantization_map = json.load(f) + print("Start a quantization process...") + requantize(model, sd, quantization_map, device=device) + print("Model is quantized!") + return model + +def load_controlnet(name, device, transformer=None): + with torch.device(device): + controlnet = ControlNetFlux(configs[name].params) + if transformer is not None: + controlnet.load_state_dict(transformer.state_dict(), strict=False) + return controlnet + +def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder: + # max length 64, 128, 256 and 512 should work (if your sequence is short enough) + return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device) + +def load_clip(device: str | torch.device = "cuda") -> HFEmbedder: + return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device) + + +def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder: + ckpt_path = configs[name].ae_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_ae is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae) + + # Loading the autoencoder + print("Init AE") + with torch.device("meta" if ckpt_path is not None else device): + ae = AutoEncoder(configs[name].ae_params) + + if ckpt_path is not None: + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return ae + + +class WatermarkEmbedder: + def __init__(self, watermark): + self.watermark = watermark + self.num_bits = len(WATERMARK_BITS) + self.encoder = WatermarkEncoder() + self.encoder.set_watermark("bits", self.watermark) + + def __call__(self, image: torch.Tensor) -> torch.Tensor: + """ + Adds a predefined watermark to the input image + + Args: + image: ([N,] B, RGB, H, W) in range [-1, 1] + + Returns: + same as input but watermarked + """ + image = 0.5 * image + 0.5 + squeeze = len(image.shape) == 4 + if squeeze: + image = image[None, ...] + n = image.shape[0] + image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1] + # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] + # watermarking libary expects input as cv2 BGR format + for k in range(image_np.shape[0]): + image_np[k] = self.encoder.encode(image_np[k], "dwtDct") + image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to( + image.device + ) + image = torch.clamp(image / 255, min=0.0, max=1.0) + if squeeze: + image = image[0] + image = 2 * image - 1 + return image + + +# A fixed 48-bit message that was choosen at random +WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110 +# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 +WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] diff --git a/third_party/src/flux_ch/xflux_pipeline.py b/third_party/src/flux_ch/xflux_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..579e46a890bfbc7b9c62d6effb04189f3f26fb0c --- /dev/null +++ b/third_party/src/flux_ch/xflux_pipeline.py @@ -0,0 +1,421 @@ +from PIL import Image, ExifTags +import numpy as np +import torch +from torch import Tensor + +from einops import rearrange +import uuid +import os + +from .modules.layers import ( + SingleStreamBlockProcessor, + DoubleStreamBlockProcessor, + SingleStreamBlockLoraProcessor, + DoubleStreamBlockLoraProcessor, + IPDoubleStreamBlockProcessor, + ImageProjModel, +) +from .sampling import denoise, denoise_controlnet, get_noise, get_schedule, prepare, unpack +from .util import ( + load_ae, + load_clip, + load_flow_model, + load_t5, + load_controlnet, + load_flow_model_quintized, + Annotator, + get_lora_rank, + load_checkpoint +) + +from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor + +class XFluxPipeline: + def __init__(self, model_type, device, offload: bool = False): + self.device = torch.device(device) + self.offload = offload + self.model_type = model_type + + self.clip = load_clip(self.device) + self.t5 = load_t5(self.device, max_length=512) + self.ae = load_ae(model_type, device="cpu" if offload else self.device) + if "fp8" in model_type: + self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.device) + else: + self.model = load_flow_model(model_type, device="cpu" if offload else self.device) + + self.image_encoder_path = "openai/clip-vit-large-patch14" + self.hf_lora_collection = "XLabs-AI/flux-lora-collection" + self.lora_types_to_names = { + "realism": "lora.safetensors", + } + self.controlnet_loaded = False + self.ip_loaded = False + + def set_ip(self, local_path: str = None, repo_id = None, name: str = None): + self.model.to(self.device) + + # unpack checkpoint + checkpoint = load_checkpoint(local_path, repo_id, name) + prefix = "double_blocks." + blocks = {} + proj = {} + + for key, value in checkpoint.items(): + if key.startswith(prefix): + blocks[key[len(prefix):].replace('.processor.', '.')] = value + if key.startswith("ip_adapter_proj_model"): + proj[key[len("ip_adapter_proj_model."):]] = value + + # load image encoder + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( + self.device, dtype=torch.float16 + ) + self.clip_image_processor = CLIPImageProcessor() + + # setup image embedding projection model + self.improj = ImageProjModel(4096, 768, 4) + self.improj.load_state_dict(proj) + self.improj = self.improj.to(self.device, dtype=torch.bfloat16) + + ip_attn_procs = {} + + for name, _ in self.model.attn_processors.items(): + ip_state_dict = {} + for k in checkpoint.keys(): + if name in k: + ip_state_dict[k.replace(f'{name}.', '')] = checkpoint[k] + if ip_state_dict: + ip_attn_procs[name] = IPDoubleStreamBlockProcessor(4096, 3072) + ip_attn_procs[name].load_state_dict(ip_state_dict) + ip_attn_procs[name].to(self.device, dtype=torch.bfloat16) + else: + ip_attn_procs[name] = self.model.attn_processors[name] + + self.model.set_attn_processor(ip_attn_procs) + self.ip_loaded = True + + def set_lora(self, local_path: str = None, repo_id: str = None, + name: str = None, lora_weight: int = 0.7): + checkpoint = load_checkpoint(local_path, repo_id, name) + self.update_model_with_lora(checkpoint, lora_weight) + + def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: int = 0.7): + checkpoint = load_checkpoint( + None, self.hf_lora_collection, self.lora_types_to_names[lora_type] + ) + self.update_model_with_lora(checkpoint, lora_weight) + + def update_model_with_lora(self, checkpoint, lora_weight): + rank = get_lora_rank(checkpoint) + lora_attn_procs = {} + + for name, _ in self.model.attn_processors.items(): + lora_state_dict = {} + for k in checkpoint.keys(): + if name in k: + lora_state_dict[k[len(name) + 1:]] = checkpoint[k] * lora_weight + + if len(lora_state_dict): + if name.startswith("single_blocks"): + lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=3072, rank=rank) + else: + lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank) + lora_attn_procs[name].load_state_dict(lora_state_dict) + lora_attn_procs[name].to(self.device) + else: + if name.startswith("single_blocks"): + lora_attn_procs[name] = SingleStreamBlockProcessor() + else: + lora_attn_procs[name] = DoubleStreamBlockProcessor() + + self.model.set_attn_processor(lora_attn_procs) + + def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None, name: str = None): + self.model.to(self.device) + self.controlnet = load_controlnet(self.model_type, self.device).to(torch.bfloat16) + + checkpoint = load_checkpoint(local_path, repo_id, name) + self.controlnet.load_state_dict(checkpoint, strict=False) + self.annotator = Annotator(control_type, self.device) + self.controlnet_loaded = True + self.control_type = control_type + + def get_image_proj( + self, + image_prompt: Tensor, + ): + # encode image-prompt embeds + image_prompt = self.clip_image_processor( + images=image_prompt, + return_tensors="pt" + ).pixel_values + image_prompt = image_prompt.to(self.image_encoder.device) + image_prompt_embeds = self.image_encoder( + image_prompt + ).image_embeds.to( + device=self.device, dtype=torch.bfloat16, + ) + # encode image + image_proj = self.improj(image_prompt_embeds) + return image_proj + + def __call__(self, + prompt: str, + image_prompt: Image = None, + controlnet_image: Image = None, + width: int = 512, + height: int = 512, + guidance: float = 4, + num_steps: int = 50, + seed: int = 123456789, + true_gs: float = 3, + control_weight: float = 0.9, + ip_scale: float = 1.0, + neg_ip_scale: float = 1.0, + neg_prompt: str = '', + neg_image_prompt: Image = None, + timestep_to_start_cfg: int = 0, + ): + width = 16 * (width // 16) + height = 16 * (height // 16) + image_proj = None + neg_image_proj = None + if not (image_prompt is None and neg_image_prompt is None) : + assert self.ip_loaded, 'You must setup IP-Adapter to add image prompt as input' + + if image_prompt is None: + image_prompt = np.zeros((width, height, 3), dtype=np.uint8) + if neg_image_prompt is None: + neg_image_prompt = np.zeros((width, height, 3), dtype=np.uint8) + + image_proj = self.get_image_proj(image_prompt) + neg_image_proj = self.get_image_proj(neg_image_prompt) + + if self.controlnet_loaded: + controlnet_image = self.annotator(controlnet_image, width, height) + controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1) + controlnet_image = controlnet_image.permute( + 2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.device) + + return self.forward( + prompt, + width, + height, + guidance, + num_steps, + seed, + controlnet_image, + timestep_to_start_cfg=timestep_to_start_cfg, + true_gs=true_gs, + control_weight=control_weight, + neg_prompt=neg_prompt, + image_proj=image_proj, + neg_image_proj=neg_image_proj, + ip_scale=ip_scale, + neg_ip_scale=neg_ip_scale, + ) + + def generate_from_feat(self, + prompt: str, + image_proj: Tensor = None, + controlnet_image: Image = None, + width: int = 512, + height: int = 512, + guidance: float = 4, + num_steps: int = 50, + seed: int = 123456789, + true_gs: float = 3, + control_weight: float = 0.9, + ip_scale: float = 1.0, + neg_ip_scale: float = 1.0, + neg_prompt: str = '', + neg_image_proj: Tensor = None, + timestep_to_start_cfg: int = 0, + ): + width = 16 * (width // 16) + height = 16 * (height // 16) + # image_proj = None + # neg_image_proj = None + # if not (image_prompt is None and neg_image_prompt is None) : + # assert self.ip_loaded, 'You must setup IP-Adapter to add image prompt as input' + + # if image_prompt is None: + # image_prompt = np.zeros((width, height, 3), dtype=np.uint8) + # if neg_image_prompt is None: + # neg_image_prompt = np.zeros((width, height, 3), dtype=np.uint8) + + # image_proj = self.get_image_proj(image_prompt) + # neg_image_proj = self.get_image_proj(neg_image_prompt) + + # if self.controlnet_loaded: + # controlnet_image = self.annotator(controlnet_image, width, height) + # controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1) + # controlnet_image = controlnet_image.permute( + # 2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.device) + + return self.forward( + prompt, + width, + height, + guidance, + num_steps, + seed, + controlnet_image, + timestep_to_start_cfg=timestep_to_start_cfg, + true_gs=true_gs, + control_weight=control_weight, + neg_prompt=neg_prompt, + image_proj=image_proj, + neg_image_proj=neg_image_proj, + ip_scale=ip_scale, + neg_ip_scale=neg_ip_scale, + ) + + + @torch.inference_mode() + def gradio_generate(self, prompt, image_prompt, controlnet_image, width, height, guidance, + num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt, + neg_image_prompt, timestep_to_start_cfg, control_type, control_weight, + lora_weight, local_path, lora_local_path, ip_local_path): + if controlnet_image is not None: + controlnet_image = Image.fromarray(controlnet_image) + if ((self.controlnet_loaded and control_type != self.control_type) + or not self.controlnet_loaded): + if local_path is not None: + self.set_controlnet(control_type, local_path=local_path) + else: + self.set_controlnet(control_type, local_path=None, + repo_id=f"xlabs-ai/flux-controlnet-{control_type}-v3", + name=f"flux-{control_type}-controlnet-v3.safetensors") + if lora_local_path is not None: + self.set_lora(local_path=lora_local_path, lora_weight=lora_weight) + if image_prompt is not None: + image_prompt = Image.fromarray(image_prompt) + if neg_image_prompt is not None: + neg_image_prompt = Image.fromarray(neg_image_prompt) + if not self.ip_loaded: + if ip_local_path is not None: + self.set_ip(local_path=ip_local_path) + else: + self.set_ip(repo_id="xlabs-ai/flux-ip-adapter", + name="flux-ip-adapter.safetensors") + seed = int(seed) + if seed == -1: + seed = torch.Generator(device="cpu").seed() + + img = self(prompt, image_prompt, controlnet_image, width, height, guidance, + num_steps, seed, true_gs, control_weight, ip_scale, neg_ip_scale, neg_prompt, + neg_image_prompt, timestep_to_start_cfg) + + filename = f"output/gradio/{uuid.uuid4()}.jpg" + os.makedirs(os.path.dirname(filename), exist_ok=True) + exif_data = Image.Exif() + exif_data[ExifTags.Base.Make] = "XLabs AI" + exif_data[ExifTags.Base.Model] = self.model_type + img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0) + return img, filename + + def forward( + self, + prompt, + width, + height, + guidance, + num_steps, + seed, + controlnet_image = None, + timestep_to_start_cfg = 0, + true_gs = 3.5, + control_weight = 0.9, + neg_prompt="", + image_proj=None, + neg_image_proj=None, + ip_scale=1.0, + neg_ip_scale=1.0, + ): + x = get_noise( + 1, height, width, device=self.device, + dtype=torch.bfloat16, seed=seed + ) + timesteps = get_schedule( + num_steps, + (width // 8) * (height // 8) // (16 * 16), + shift=True, + ) + torch.manual_seed(seed) + with torch.no_grad(): + if self.offload: + self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device) + inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt) + neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt) + + if self.offload: + self.offload_model_to_cpu(self.t5, self.clip) + self.model = self.model.to(self.device) + if self.controlnet_loaded: + x = denoise_controlnet( + self.model, + **inp_cond, + controlnet=self.controlnet, + timesteps=timesteps, + guidance=guidance, + controlnet_cond=controlnet_image, + timestep_to_start_cfg=timestep_to_start_cfg, + neg_txt=neg_inp_cond['txt'], + neg_txt_ids=neg_inp_cond['txt_ids'], + neg_vec=neg_inp_cond['vec'], + true_gs=true_gs, + controlnet_gs=control_weight, + image_proj=image_proj, + neg_image_proj=neg_image_proj, + ip_scale=ip_scale, + neg_ip_scale=neg_ip_scale, + ) + else: + x = denoise( + self.model, + **inp_cond, + timesteps=timesteps, + guidance=guidance, + timestep_to_start_cfg=timestep_to_start_cfg, + neg_txt=neg_inp_cond['txt'], + neg_txt_ids=neg_inp_cond['txt_ids'], + neg_vec=neg_inp_cond['vec'], + true_gs=true_gs, + image_proj=image_proj, + neg_image_proj=neg_image_proj, + ip_scale=ip_scale, + neg_ip_scale=neg_ip_scale, + ) + + if self.offload: + self.offload_model_to_cpu(self.model) + self.ae.decoder.to(x.device) + x = unpack(x.float(), height, width) + x = self.ae.decode(x) + self.offload_model_to_cpu(self.ae.decoder) + + x1 = x.clamp(-1, 1) + x1 = rearrange(x1[-1], "c h w -> h w c") + output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy()) + return output_img + + def offload_model_to_cpu(self, *models): + if not self.offload: return + for model in models: + model.cpu() + torch.cuda.empty_cache() + + +class XFluxSampler(XFluxPipeline): + def __init__(self, clip, t5, ae, model, device): + self.clip = clip + self.t5 = t5 + self.ae = ae + self.model = model + self.model.eval() + self.device = device + self.controlnet_loaded = False + self.ip_loaded = False + self.offload = False diff --git a/third_party/src/flux_edit/__init__.py b/third_party/src/flux_edit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..43c365a49d6980e88acba10ef3069f110a59644a --- /dev/null +++ b/third_party/src/flux_edit/__init__.py @@ -0,0 +1,11 @@ +try: + from ._version import version as __version__ # type: ignore + from ._version import version_tuple +except ImportError: + __version__ = "unknown (no version information available)" + version_tuple = (0, 0, "unknown", "noinfo") + +from pathlib import Path + +PACKAGE = __package__.replace("_", "-") +PACKAGE_ROOT = Path(__file__).parent diff --git a/third_party/src/flux_edit/__main__.py b/third_party/src/flux_edit/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..d5cf0fd2444d4cda4053fa74dad3371556b886e5 --- /dev/null +++ b/third_party/src/flux_edit/__main__.py @@ -0,0 +1,4 @@ +from .cli import app + +if __name__ == "__main__": + app() diff --git a/third_party/src/flux_edit/__pycache__/_math.cpython-310.pyc b/third_party/src/flux_edit/__pycache__/_math.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9020b1c267952f1f66cbf1486d5c7819d89503a8 Binary files /dev/null and b/third_party/src/flux_edit/__pycache__/_math.cpython-310.pyc differ diff --git a/third_party/src/flux_edit/__pycache__/controlnet.cpython-310.pyc b/third_party/src/flux_edit/__pycache__/controlnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f29fdf5c8805963b2c047b8c5f9c3ca68f863761 Binary files /dev/null and b/third_party/src/flux_edit/__pycache__/controlnet.cpython-310.pyc differ diff --git a/third_party/src/flux_edit/__pycache__/math.cpython-310.pyc b/third_party/src/flux_edit/__pycache__/math.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d460fb1dc13fd0bd2ac7a6decf499d9385870f0 Binary files /dev/null and b/third_party/src/flux_edit/__pycache__/math.cpython-310.pyc differ diff --git a/third_party/src/flux_edit/__pycache__/model.cpython-310.pyc b/third_party/src/flux_edit/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0835d8e1bcd7cb3d3913b8fdb903d05a6f814e9 Binary files /dev/null and b/third_party/src/flux_edit/__pycache__/model.cpython-310.pyc differ diff --git a/third_party/src/flux_edit/__pycache__/sampling.cpython-310.pyc b/third_party/src/flux_edit/__pycache__/sampling.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed95aa9693978bece22e2ae41d73186a46caec5f Binary files /dev/null and b/third_party/src/flux_edit/__pycache__/sampling.cpython-310.pyc differ diff --git a/third_party/src/flux_edit/__pycache__/util.cpython-310.pyc b/third_party/src/flux_edit/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b38ebaa98813d7809574ffc0a0d2ed8b0b82033 Binary files /dev/null and b/third_party/src/flux_edit/__pycache__/util.cpython-310.pyc differ diff --git a/third_party/src/flux_edit/_math.py b/third_party/src/flux_edit/_math.py new file mode 100644 index 0000000000000000000000000000000000000000..0156bb6a205dec340e029f0c87cf70ae8709ae12 --- /dev/null +++ b/third_party/src/flux_edit/_math.py @@ -0,0 +1,30 @@ +import torch +from einops import rearrange +from torch import Tensor + + +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: + q, k = apply_rope(q, k, pe) + + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "B H L D -> B L (H D)") + + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) + return out.float() + + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) diff --git a/third_party/src/flux_edit/annotator/__pycache__/util.cpython-310.pyc b/third_party/src/flux_edit/annotator/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e2b2b582d69cfdd33d86534fb71b8a190ca3251 Binary files /dev/null and b/third_party/src/flux_edit/annotator/__pycache__/util.cpython-310.pyc differ diff --git a/third_party/src/flux_edit/annotator/canny/__init__.py b/third_party/src/flux_edit/annotator/canny/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb0da951dc838ec9dec2131007e036113281800b --- /dev/null +++ b/third_party/src/flux_edit/annotator/canny/__init__.py @@ -0,0 +1,6 @@ +import cv2 + + +class CannyDetector: + def __call__(self, img, low_threshold, high_threshold): + return cv2.Canny(img, low_threshold, high_threshold) diff --git a/third_party/src/flux_edit/annotator/canny/__pycache__/__init__.cpython-310.pyc b/third_party/src/flux_edit/annotator/canny/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1aff6963463beb747d673cb78bd7729e734c6914 Binary files /dev/null and b/third_party/src/flux_edit/annotator/canny/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/src/flux_edit/annotator/ckpts/ckpts.txt b/third_party/src/flux_edit/annotator/ckpts/ckpts.txt new file mode 100644 index 0000000000000000000000000000000000000000..1978551fb2a9226814eaf58459f414fcfac4e69b --- /dev/null +++ b/third_party/src/flux_edit/annotator/ckpts/ckpts.txt @@ -0,0 +1 @@ +Weights here. \ No newline at end of file diff --git a/third_party/src/flux_edit/annotator/dwpose/__init__.py b/third_party/src/flux_edit/annotator/dwpose/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb6e172d05c9de3f1cdd61e330ad8d6dde46dfdd --- /dev/null +++ b/third_party/src/flux_edit/annotator/dwpose/__init__.py @@ -0,0 +1,68 @@ +# Openpose +# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose +# 2nd Edited by https://github.com/Hzzone/pytorch-openpose +# 3rd Edited by ControlNet +# 4th Edited by ControlNet (added face and correct hands) + +import os +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + +import torch +import numpy as np +from . import util +from .wholebody import Wholebody + +def draw_pose(pose, H, W): + bodies = pose['bodies'] + faces = pose['faces'] + hands = pose['hands'] + candidate = bodies['candidate'] + subset = bodies['subset'] + canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) + + canvas = util.draw_bodypose(canvas, candidate, subset) + + canvas = util.draw_handpose(canvas, hands) + + canvas = util.draw_facepose(canvas, faces) + + return canvas + + +class DWposeDetector: + def __init__(self, device): + + self.pose_estimation = Wholebody(device) + + def __call__(self, oriImg): + oriImg = oriImg.copy() + H, W, C = oriImg.shape + with torch.no_grad(): + candidate, subset = self.pose_estimation(oriImg) + nums, keys, locs = candidate.shape + candidate[..., 0] /= float(W) + candidate[..., 1] /= float(H) + body = candidate[:,:18].copy() + body = body.reshape(nums*18, locs) + score = subset[:,:18] + for i in range(len(score)): + for j in range(len(score[i])): + if score[i][j] > 0.3: + score[i][j] = int(18*i+j) + else: + score[i][j] = -1 + + un_visible = subset<0.3 + candidate[un_visible] = -1 + + foot = candidate[:,18:24] + + faces = candidate[:,24:92] + + hands = candidate[:,92:113] + hands = np.vstack([hands, candidate[:,113:]]) + + bodies = dict(candidate=body, subset=score) + pose = dict(bodies=bodies, hands=hands, faces=faces) + + return draw_pose(pose, H, W) diff --git a/third_party/src/flux_edit/annotator/dwpose/__pycache__/__init__.cpython-310.pyc b/third_party/src/flux_edit/annotator/dwpose/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be30c63bab6d7862d703ac96a3849ba7d7bcb96f Binary files /dev/null and b/third_party/src/flux_edit/annotator/dwpose/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/src/flux_edit/annotator/dwpose/__pycache__/onnxdet.cpython-310.pyc b/third_party/src/flux_edit/annotator/dwpose/__pycache__/onnxdet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8aa572792a22603d17c7bda459ac65d2d58065e2 Binary files /dev/null and b/third_party/src/flux_edit/annotator/dwpose/__pycache__/onnxdet.cpython-310.pyc differ diff --git a/third_party/src/flux_edit/annotator/dwpose/__pycache__/onnxpose.cpython-310.pyc b/third_party/src/flux_edit/annotator/dwpose/__pycache__/onnxpose.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e72be42702d881ce28fefa155cf41019db7109c Binary files /dev/null and b/third_party/src/flux_edit/annotator/dwpose/__pycache__/onnxpose.cpython-310.pyc differ diff --git a/third_party/src/flux_edit/annotator/dwpose/__pycache__/util.cpython-310.pyc b/third_party/src/flux_edit/annotator/dwpose/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19741f0465547d1fc4860c243053b35e3c94b144 Binary files /dev/null and b/third_party/src/flux_edit/annotator/dwpose/__pycache__/util.cpython-310.pyc differ diff --git a/third_party/src/flux_edit/annotator/dwpose/__pycache__/wholebody.cpython-310.pyc b/third_party/src/flux_edit/annotator/dwpose/__pycache__/wholebody.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..883fc926e39d18e40ce6262c1dd6532861017ebd Binary files /dev/null and b/third_party/src/flux_edit/annotator/dwpose/__pycache__/wholebody.cpython-310.pyc differ diff --git a/third_party/src/flux_edit/annotator/dwpose/onnxdet.py b/third_party/src/flux_edit/annotator/dwpose/onnxdet.py new file mode 100644 index 0000000000000000000000000000000000000000..e0411c96a5eef41e981bde5481ef7786b242f1fa --- /dev/null +++ b/third_party/src/flux_edit/annotator/dwpose/onnxdet.py @@ -0,0 +1,125 @@ +import cv2 +import numpy as np + +import onnxruntime + +def nms(boxes, scores, nms_thr): + """Single class NMS implemented in Numpy.""" + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= nms_thr)[0] + order = order[inds + 1] + + return keep + +def multiclass_nms(boxes, scores, nms_thr, score_thr): + """Multiclass NMS implemented in Numpy. Class-aware version.""" + final_dets = [] + num_classes = scores.shape[1] + for cls_ind in range(num_classes): + cls_scores = scores[:, cls_ind] + valid_score_mask = cls_scores > score_thr + if valid_score_mask.sum() == 0: + continue + else: + valid_scores = cls_scores[valid_score_mask] + valid_boxes = boxes[valid_score_mask] + keep = nms(valid_boxes, valid_scores, nms_thr) + if len(keep) > 0: + cls_inds = np.ones((len(keep), 1)) * cls_ind + dets = np.concatenate( + [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1 + ) + final_dets.append(dets) + if len(final_dets) == 0: + return None + return np.concatenate(final_dets, 0) + +def demo_postprocess(outputs, img_size, p6=False): + grids = [] + expanded_strides = [] + strides = [8, 16, 32] if not p6 else [8, 16, 32, 64] + + hsizes = [img_size[0] // stride for stride in strides] + wsizes = [img_size[1] // stride for stride in strides] + + for hsize, wsize, stride in zip(hsizes, wsizes, strides): + xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) + grid = np.stack((xv, yv), 2).reshape(1, -1, 2) + grids.append(grid) + shape = grid.shape[:2] + expanded_strides.append(np.full((*shape, 1), stride)) + + grids = np.concatenate(grids, 1) + expanded_strides = np.concatenate(expanded_strides, 1) + outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides + outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides + + return outputs + +def preprocess(img, input_size, swap=(2, 0, 1)): + if len(img.shape) == 3: + padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 + else: + padded_img = np.ones(input_size, dtype=np.uint8) * 114 + + r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) + resized_img = cv2.resize( + img, + (int(img.shape[1] * r), int(img.shape[0] * r)), + interpolation=cv2.INTER_LINEAR, + ).astype(np.uint8) + padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img + + padded_img = padded_img.transpose(swap) + padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) + return padded_img, r + +def inference_detector(session, oriImg): + input_shape = (640,640) + img, ratio = preprocess(oriImg, input_shape) + + ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]} + output = session.run(None, ort_inputs) + predictions = demo_postprocess(output[0], input_shape)[0] + + boxes = predictions[:, :4] + scores = predictions[:, 4:5] * predictions[:, 5:] + + boxes_xyxy = np.ones_like(boxes) + boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2. + boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2. + boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2. + boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2. + boxes_xyxy /= ratio + dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1) + if dets is not None: + final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5] + isscore = final_scores>0.3 + iscat = final_cls_inds == 0 + isbbox = [ i and j for (i, j) in zip(isscore, iscat)] + final_boxes = final_boxes[isbbox] + else: + final_boxes = np.array([]) + + return final_boxes diff --git a/third_party/src/flux_edit/annotator/dwpose/onnxpose.py b/third_party/src/flux_edit/annotator/dwpose/onnxpose.py new file mode 100644 index 0000000000000000000000000000000000000000..79cd4a06241123af81ea22446a4ca8816716443f --- /dev/null +++ b/third_party/src/flux_edit/annotator/dwpose/onnxpose.py @@ -0,0 +1,360 @@ +from typing import List, Tuple + +import cv2 +import numpy as np +import onnxruntime as ort + +def preprocess( + img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256) +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Do preprocessing for RTMPose model inference. + + Args: + img (np.ndarray): Input image in shape. + input_size (tuple): Input image size in shape (w, h). + + Returns: + tuple: + - resized_img (np.ndarray): Preprocessed image. + - center (np.ndarray): Center of image. + - scale (np.ndarray): Scale of image. + """ + # get shape of image + img_shape = img.shape[:2] + out_img, out_center, out_scale = [], [], [] + if len(out_bbox) == 0: + out_bbox = [[0, 0, img_shape[1], img_shape[0]]] + for i in range(len(out_bbox)): + x0 = out_bbox[i][0] + y0 = out_bbox[i][1] + x1 = out_bbox[i][2] + y1 = out_bbox[i][3] + bbox = np.array([x0, y0, x1, y1]) + + # get center and scale + center, scale = bbox_xyxy2cs(bbox, padding=1.25) + + # do affine transformation + resized_img, scale = top_down_affine(input_size, scale, center, img) + + # normalize image + mean = np.array([123.675, 116.28, 103.53]) + std = np.array([58.395, 57.12, 57.375]) + resized_img = (resized_img - mean) / std + + out_img.append(resized_img) + out_center.append(center) + out_scale.append(scale) + + return out_img, out_center, out_scale + + +def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray: + """Inference RTMPose model. + + Args: + sess (ort.InferenceSession): ONNXRuntime session. + img (np.ndarray): Input image in shape. + + Returns: + outputs (np.ndarray): Output of RTMPose model. + """ + all_out = [] + # build input + for i in range(len(img)): + input = [img[i].transpose(2, 0, 1)] + + # build output + sess_input = {sess.get_inputs()[0].name: input} + sess_output = [] + for out in sess.get_outputs(): + sess_output.append(out.name) + + # run model + outputs = sess.run(sess_output, sess_input) + all_out.append(outputs) + + return all_out + + +def postprocess(outputs: List[np.ndarray], + model_input_size: Tuple[int, int], + center: Tuple[int, int], + scale: Tuple[int, int], + simcc_split_ratio: float = 2.0 + ) -> Tuple[np.ndarray, np.ndarray]: + """Postprocess for RTMPose model output. + + Args: + outputs (np.ndarray): Output of RTMPose model. + model_input_size (tuple): RTMPose model Input image size. + center (tuple): Center of bbox in shape (x, y). + scale (tuple): Scale of bbox in shape (w, h). + simcc_split_ratio (float): Split ratio of simcc. + + Returns: + tuple: + - keypoints (np.ndarray): Rescaled keypoints. + - scores (np.ndarray): Model predict scores. + """ + all_key = [] + all_score = [] + for i in range(len(outputs)): + # use simcc to decode + simcc_x, simcc_y = outputs[i] + keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio) + + # rescale keypoints + keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2 + all_key.append(keypoints[0]) + all_score.append(scores[0]) + + return np.array(all_key), np.array(all_score) + + +def bbox_xyxy2cs(bbox: np.ndarray, + padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]: + """Transform the bbox format from (x,y,w,h) into (center, scale) + + Args: + bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted + as (left, top, right, bottom) + padding (float): BBox padding factor that will be multilied to scale. + Default: 1.0 + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or + (n, 2) + - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or + (n, 2) + """ + # convert single bbox from (4, ) to (1, 4) + dim = bbox.ndim + if dim == 1: + bbox = bbox[None, :] + + # get bbox center and scale + x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3]) + center = np.hstack([x1 + x2, y1 + y2]) * 0.5 + scale = np.hstack([x2 - x1, y2 - y1]) * padding + + if dim == 1: + center = center[0] + scale = scale[0] + + return center, scale + + +def _fix_aspect_ratio(bbox_scale: np.ndarray, + aspect_ratio: float) -> np.ndarray: + """Extend the scale to match the given aspect ratio. + + Args: + scale (np.ndarray): The image scale (w, h) in shape (2, ) + aspect_ratio (float): The ratio of ``w/h`` + + Returns: + np.ndarray: The reshaped image scale in (2, ) + """ + w, h = np.hsplit(bbox_scale, [1]) + bbox_scale = np.where(w > h * aspect_ratio, + np.hstack([w, w / aspect_ratio]), + np.hstack([h * aspect_ratio, h])) + return bbox_scale + + +def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray: + """Rotate a point by an angle. + + Args: + pt (np.ndarray): 2D point coordinates (x, y) in shape (2, ) + angle_rad (float): rotation angle in radian + + Returns: + np.ndarray: Rotated point in shape (2, ) + """ + sn, cs = np.sin(angle_rad), np.cos(angle_rad) + rot_mat = np.array([[cs, -sn], [sn, cs]]) + return rot_mat @ pt + + +def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray: + """To calculate the affine matrix, three pairs of points are required. This + function is used to get the 3rd point, given 2D points a & b. + + The 3rd point is defined by rotating vector `a - b` by 90 degrees + anticlockwise, using b as the rotation center. + + Args: + a (np.ndarray): The 1st point (x,y) in shape (2, ) + b (np.ndarray): The 2nd point (x,y) in shape (2, ) + + Returns: + np.ndarray: The 3rd point. + """ + direction = a - b + c = b + np.r_[-direction[1], direction[0]] + return c + + +def get_warp_matrix(center: np.ndarray, + scale: np.ndarray, + rot: float, + output_size: Tuple[int, int], + shift: Tuple[float, float] = (0., 0.), + inv: bool = False) -> np.ndarray: + """Calculate the affine transformation matrix that can warp the bbox area + in the input image to the output size. + + Args: + center (np.ndarray[2, ]): Center of the bounding box (x, y). + scale (np.ndarray[2, ]): Scale of the bounding box + wrt [width, height]. + rot (float): Rotation angle (degree). + output_size (np.ndarray[2, ] | list(2,)): Size of the + destination heatmaps. + shift (0-100%): Shift translation ratio wrt the width/height. + Default (0., 0.). + inv (bool): Option to inverse the affine transform direction. + (inv=False: src->dst or inv=True: dst->src) + + Returns: + np.ndarray: A 2x3 transformation matrix + """ + shift = np.array(shift) + src_w = scale[0] + dst_w = output_size[0] + dst_h = output_size[1] + + # compute transformation matrix + rot_rad = np.deg2rad(rot) + src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad) + dst_dir = np.array([0., dst_w * -0.5]) + + # get four corners of the src rectangle in the original image + src = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + scale * shift + src[1, :] = center + src_dir + scale * shift + src[2, :] = _get_3rd_point(src[0, :], src[1, :]) + + # get four corners of the dst rectangle in the input image + dst = np.zeros((3, 2), dtype=np.float32) + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir + dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return warp_mat + + +def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict, + img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get the bbox image as the model input by affine transform. + + Args: + input_size (dict): The input size of the model. + bbox_scale (dict): The bbox scale of the img. + bbox_center (dict): The bbox center of the img. + img (np.ndarray): The original image. + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: img after affine transform. + - np.ndarray[float32]: bbox scale after affine transform. + """ + w, h = input_size + warp_size = (int(w), int(h)) + + # reshape bbox to fixed aspect ratio + bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h) + + # get the affine matrix + center = bbox_center + scale = bbox_scale + rot = 0 + warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h)) + + # do affine transform + img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR) + + return img, bbox_scale + + +def get_simcc_maximum(simcc_x: np.ndarray, + simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get maximum response location and value from simcc representations. + + Note: + instance number: N + num_keypoints: K + heatmap height: H + heatmap width: W + + Args: + simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx) + simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy) + + Returns: + tuple: + - locs (np.ndarray): locations of maximum heatmap responses in shape + (K, 2) or (N, K, 2) + - vals (np.ndarray): values of maximum heatmap responses in shape + (K,) or (N, K) + """ + N, K, Wx = simcc_x.shape + simcc_x = simcc_x.reshape(N * K, -1) + simcc_y = simcc_y.reshape(N * K, -1) + + # get maximum value locations + x_locs = np.argmax(simcc_x, axis=1) + y_locs = np.argmax(simcc_y, axis=1) + locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32) + max_val_x = np.amax(simcc_x, axis=1) + max_val_y = np.amax(simcc_y, axis=1) + + # get maximum value across x and y axis + mask = max_val_x > max_val_y + max_val_x[mask] = max_val_y[mask] + vals = max_val_x + locs[vals <= 0.] = -1 + + # reshape + locs = locs.reshape(N, K, 2) + vals = vals.reshape(N, K) + + return locs, vals + + +def decode(simcc_x: np.ndarray, simcc_y: np.ndarray, + simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]: + """Modulate simcc distribution with Gaussian. + + Args: + simcc_x (np.ndarray[K, Wx]): model predicted simcc in x. + simcc_y (np.ndarray[K, Wy]): model predicted simcc in y. + simcc_split_ratio (int): The split ratio of simcc. + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2) + - np.ndarray[float32]: scores in shape (K,) or (n, K) + """ + keypoints, scores = get_simcc_maximum(simcc_x, simcc_y) + keypoints /= simcc_split_ratio + + return keypoints, scores + + +def inference_pose(session, out_bbox, oriImg): + h, w = session.get_inputs()[0].shape[2:] + model_input_size = (w, h) + resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size) + outputs = inference(session, resized_img) + keypoints, scores = postprocess(outputs, model_input_size, center, scale) + + return keypoints, scores \ No newline at end of file diff --git a/third_party/src/flux_edit/annotator/dwpose/util.py b/third_party/src/flux_edit/annotator/dwpose/util.py new file mode 100644 index 0000000000000000000000000000000000000000..73d7d0153b38d143eb8090e07a9784a274b619ed --- /dev/null +++ b/third_party/src/flux_edit/annotator/dwpose/util.py @@ -0,0 +1,297 @@ +import math +import numpy as np +import matplotlib +import cv2 + + +eps = 0.01 + + +def smart_resize(x, s): + Ht, Wt = s + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2) + + +def smart_resize_k(x, fx, fy): + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + Ht, Wt = Ho * fy, Wo * fx + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2) + + +def padRightDownCorner(img, stride, padValue): + h = img.shape[0] + w = img.shape[1] + + pad = 4 * [None] + pad[0] = 0 # up + pad[1] = 0 # left + pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down + pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right + + img_padded = img + pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1)) + img_padded = np.concatenate((pad_up, img_padded), axis=0) + pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1)) + img_padded = np.concatenate((pad_left, img_padded), axis=1) + pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1)) + img_padded = np.concatenate((img_padded, pad_down), axis=0) + pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1)) + img_padded = np.concatenate((img_padded, pad_right), axis=1) + + return img_padded, pad + + +def transfer(model, model_weights): + transfered_model_weights = {} + for weights_name in model.state_dict().keys(): + transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])] + return transfered_model_weights + + +def draw_bodypose(canvas, candidate, subset): + H, W, C = canvas.shape + candidate = np.array(candidate) + subset = np.array(subset) + + stickwidth = 4 + + limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ + [1, 16], [16, 18], [3, 17], [6, 18]] + + colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + + for i in range(17): + for n in range(len(subset)): + index = subset[n][np.array(limbSeq[i]) - 1] + if -1 in index: + continue + Y = candidate[index.astype(int), 0] * float(W) + X = candidate[index.astype(int), 1] * float(H) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + cv2.fillConvexPoly(canvas, polygon, colors[i]) + + canvas = (canvas * 0.6).astype(np.uint8) + + for i in range(18): + for n in range(len(subset)): + index = int(subset[n][i]) + if index == -1: + continue + x, y = candidate[index][0:2] + x = int(x * W) + y = int(y * H) + cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) + + return canvas + + +def draw_handpose(canvas, all_hand_peaks): + H, W, C = canvas.shape + + edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ + [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] + + for peaks in all_hand_peaks: + peaks = np.array(peaks) + + for ie, e in enumerate(edges): + x1, y1 = peaks[e[0]] + x2, y2 = peaks[e[1]] + x1 = int(x1 * W) + y1 = int(y1 * H) + x2 = int(x2 * W) + y2 = int(y2 * H) + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2) + + for i, keyponit in enumerate(peaks): + x, y = keyponit + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) + return canvas + + +def draw_facepose(canvas, all_lmks): + H, W, C = canvas.shape + for lmks in all_lmks: + lmks = np.array(lmks) + for lmk in lmks: + x, y = lmk + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1) + return canvas + + +# detect hand according to body pose keypoints +# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp +def handDetect(candidate, subset, oriImg): + # right hand: wrist 4, elbow 3, shoulder 2 + # left hand: wrist 7, elbow 6, shoulder 5 + ratioWristElbow = 0.33 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + for person in subset.astype(int): + # if any of three not detected + has_left = np.sum(person[[5, 6, 7]] == -1) == 0 + has_right = np.sum(person[[2, 3, 4]] == -1) == 0 + if not (has_left or has_right): + continue + hands = [] + #left hand + if has_left: + left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]] + x1, y1 = candidate[left_shoulder_index][:2] + x2, y2 = candidate[left_elbow_index][:2] + x3, y3 = candidate[left_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, True]) + # right hand + if has_right: + right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]] + x1, y1 = candidate[right_shoulder_index][:2] + x2, y2 = candidate[right_elbow_index][:2] + x3, y3 = candidate[right_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, False]) + + for x1, y1, x2, y2, x3, y3, is_left in hands: + # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox + # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]); + # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]); + # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow); + # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder); + # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder); + x = x3 + ratioWristElbow * (x3 - x2) + y = y3 + ratioWristElbow * (y3 - y2) + distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) + distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) + width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) + # x-y refers to the center --> offset to topLeft point + # handRectangle.x -= handRectangle.width / 2.f; + # handRectangle.y -= handRectangle.height / 2.f; + x -= width / 2 + y -= width / 2 # width = height + # overflow the image + if x < 0: x = 0 + if y < 0: y = 0 + width1 = width + width2 = width + if x + width > image_width: width1 = image_width - x + if y + width > image_height: width2 = image_height - y + width = min(width1, width2) + # the max hand box value is 20 pixels + if width >= 20: + detect_result.append([int(x), int(y), int(width), is_left]) + + ''' + return value: [[x, y, w, True if left hand else False]]. + width=height since the network require squared input. + x, y is the coordinate of top left + ''' + return detect_result + + +# Written by Lvmin +def faceDetect(candidate, subset, oriImg): + # left right eye ear 14 15 16 17 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + for person in subset.astype(int): + has_head = person[0] > -1 + if not has_head: + continue + + has_left_eye = person[14] > -1 + has_right_eye = person[15] > -1 + has_left_ear = person[16] > -1 + has_right_ear = person[17] > -1 + + if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear): + continue + + head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]] + + width = 0.0 + x0, y0 = candidate[head][:2] + + if has_left_eye: + x1, y1 = candidate[left_eye][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if has_right_eye: + x1, y1 = candidate[right_eye][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if has_left_ear: + x1, y1 = candidate[left_ear][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + if has_right_ear: + x1, y1 = candidate[right_ear][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + x, y = x0, y0 + + x -= width + y -= width + + if x < 0: + x = 0 + + if y < 0: + y = 0 + + width1 = width * 2 + width2 = width * 2 + + if x + width > image_width: + width1 = image_width - x + + if y + width > image_height: + width2 = image_height - y + + width = min(width1, width2) + + if width >= 20: + detect_result.append([int(x), int(y), int(width)]) + + return detect_result + + +# get max index of 2d array +def npmax(array): + arrayindex = array.argmax(1) + arrayvalue = array.max(1) + i = arrayvalue.argmax() + j = arrayindex[i] + return i, j diff --git a/third_party/src/flux_edit/annotator/dwpose/wholebody.py b/third_party/src/flux_edit/annotator/dwpose/wholebody.py new file mode 100644 index 0000000000000000000000000000000000000000..d73f19d61c238c47cf7de98d01385b2150a5361f --- /dev/null +++ b/third_party/src/flux_edit/annotator/dwpose/wholebody.py @@ -0,0 +1,48 @@ +import cv2 +import numpy as np + +import onnxruntime as ort +from huggingface_hub import hf_hub_download +from .onnxdet import inference_detector +from .onnxpose import inference_pose + + +class Wholebody: + def __init__(self, device="cuda:0"): + providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider'] + onnx_det = hf_hub_download("yzd-v/DWPose", "yolox_l.onnx") + onnx_pose = hf_hub_download("yzd-v/DWPose", "dw-ll_ucoco_384.onnx") + + self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers) + self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers) + + def __call__(self, oriImg): + det_result = inference_detector(self.session_det, oriImg) + keypoints, scores = inference_pose(self.session_pose, det_result, oriImg) + + keypoints_info = np.concatenate( + (keypoints, scores[..., None]), axis=-1) + # compute neck joint + neck = np.mean(keypoints_info[:, [5, 6]], axis=1) + # neck score when visualizing pred + neck[:, 2:4] = np.logical_and( + keypoints_info[:, 5, 2:4] > 0.3, + keypoints_info[:, 6, 2:4] > 0.3).astype(int) + new_keypoints_info = np.insert( + keypoints_info, 17, neck, axis=1) + mmpose_idx = [ + 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 + ] + openpose_idx = [ + 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 + ] + new_keypoints_info[:, openpose_idx] = \ + new_keypoints_info[:, mmpose_idx] + keypoints_info = new_keypoints_info + + keypoints, scores = keypoints_info[ + ..., :2], keypoints_info[..., 2] + + return keypoints, scores + + diff --git a/third_party/src/flux_edit/annotator/hed/__init__.py b/third_party/src/flux_edit/annotator/hed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d24a89aac52566caad511055b08137cd2a03d60a --- /dev/null +++ b/third_party/src/flux_edit/annotator/hed/__init__.py @@ -0,0 +1,95 @@ +# This is an improved version and model of HED edge detection with Apache License, Version 2.0. +# Please use this implementation in your products +# This implementation may produce slightly different results from Saining Xie's official implementations, +# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations. +# Different from official models and other implementations, this is an RGB-input model (rather than BGR) +# and in this way it works better for gradio's RGB protocol + +import os +import cv2 +import torch +import numpy as np + +from huggingface_hub import hf_hub_download +from einops import rearrange +from annotator.util import annotator_ckpts_path + + +class DoubleConvBlock(torch.nn.Module): + def __init__(self, input_channel, output_channel, layer_number): + super().__init__() + self.convs = torch.nn.Sequential() + self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) + for i in range(1, layer_number): + self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) + self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0) + + def __call__(self, x, down_sampling=False): + h = x + if down_sampling: + h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2)) + for conv in self.convs: + h = conv(h) + h = torch.nn.functional.relu(h) + return h, self.projection(h) + + +class ControlNetHED_Apache2(torch.nn.Module): + def __init__(self): + super().__init__() + self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1))) + self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2) + self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2) + self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3) + self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3) + self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3) + + def __call__(self, x): + h = x - self.norm + h, projection1 = self.block1(h) + h, projection2 = self.block2(h, down_sampling=True) + h, projection3 = self.block3(h, down_sampling=True) + h, projection4 = self.block4(h, down_sampling=True) + h, projection5 = self.block5(h, down_sampling=True) + return projection1, projection2, projection3, projection4, projection5 + + +class HEDdetector: + def __init__(self): + modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth") + if not os.path.exists(modelpath): + modelpath = hf_hub_download("lllyasviel/Annotators", "ControlNetHED.pth") + self.netNetwork = ControlNetHED_Apache2().float().cuda().eval() + self.netNetwork.load_state_dict(torch.load(modelpath)) + + def __call__(self, input_image): + assert input_image.ndim == 3 + H, W, C = input_image.shape + with torch.no_grad(): + image_hed = torch.from_numpy(input_image.copy()).float().cuda() + image_hed = rearrange(image_hed, 'h w c -> 1 c h w') + edges = self.netNetwork(image_hed) + edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges] + edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges] + edges = np.stack(edges, axis=2) + edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64))) + edge = (edge * 255.0).clip(0, 255).astype(np.uint8) + return edge + + +def nms(x, t, s): + x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) + + f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) + f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) + f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) + f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) + + y = np.zeros_like(x) + + for f in [f1, f2, f3, f4]: + np.putmask(y, cv2.dilate(x, kernel=f) == x, x) + + z = np.zeros_like(y, dtype=np.uint8) + z[y > t] = 255 + return z diff --git a/third_party/src/flux_edit/annotator/hed/__pycache__/__init__.cpython-310.pyc b/third_party/src/flux_edit/annotator/hed/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4773e3fc122ef86cc94834f5e80174091f0a882d Binary files /dev/null and b/third_party/src/flux_edit/annotator/hed/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/src/flux_edit/annotator/midas/LICENSE b/third_party/src/flux_edit/annotator/midas/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..277b5c11be103f028a8d10985139f1da10c2f08e --- /dev/null +++ b/third_party/src/flux_edit/annotator/midas/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/third_party/src/flux_edit/annotator/midas/__init__.py b/third_party/src/flux_edit/annotator/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36789767f35bcc169c2cbf096e2747539df4f14d --- /dev/null +++ b/third_party/src/flux_edit/annotator/midas/__init__.py @@ -0,0 +1,42 @@ +# Midas Depth Estimation +# From https://github.com/isl-org/MiDaS +# MIT LICENSE + +import cv2 +import numpy as np +import torch + +from einops import rearrange +from .api import MiDaSInference + + +class MidasDetector: + def __init__(self): + self.model = MiDaSInference(model_type="dpt_hybrid").cuda() + + def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1): + assert input_image.ndim == 3 + image_depth = input_image + with torch.no_grad(): + image_depth = torch.from_numpy(image_depth).float().cuda() + image_depth = image_depth / 127.5 - 1.0 + image_depth = rearrange(image_depth, 'h w c -> 1 c h w') + depth = self.model(image_depth)[0] + + depth_pt = depth.clone() + depth_pt -= torch.min(depth_pt) + depth_pt /= torch.max(depth_pt) + depth_pt = depth_pt.cpu().numpy() + depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8) + + depth_np = depth.cpu().numpy() + x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3) + y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3) + z = np.ones_like(x) * a + x[depth_pt < bg_th] = 0 + y[depth_pt < bg_th] = 0 + normal = np.stack([x, y, z], axis=2) + normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5 + normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8) + + return depth_image, normal_image diff --git a/third_party/src/flux_edit/annotator/midas/__pycache__/__init__.cpython-310.pyc b/third_party/src/flux_edit/annotator/midas/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a428326f0ce37c2246c6f293d316bbbfd581dcb6 Binary files /dev/null and b/third_party/src/flux_edit/annotator/midas/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/src/flux_edit/annotator/midas/__pycache__/api.cpython-310.pyc b/third_party/src/flux_edit/annotator/midas/__pycache__/api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ea28c08ac43739dbb07b03d84531755b1968bbf Binary files /dev/null and b/third_party/src/flux_edit/annotator/midas/__pycache__/api.cpython-310.pyc differ diff --git a/third_party/src/flux_edit/annotator/midas/api.py b/third_party/src/flux_edit/annotator/midas/api.py new file mode 100644 index 0000000000000000000000000000000000000000..5010d294f550905e241b696e9a031f69c9ef910a --- /dev/null +++ b/third_party/src/flux_edit/annotator/midas/api.py @@ -0,0 +1,168 @@ +# based on https://github.com/isl-org/MiDaS + +import cv2 +import os +import torch +import torch.nn as nn +from torchvision.transforms import Compose + +from huggingface_hub import hf_hub_download + +from .midas.dpt_depth import DPTDepthModel +from .midas.midas_net import MidasNet +from .midas.midas_net_custom import MidasNet_small +from .midas.transforms import Resize, NormalizeImage, PrepareForNet +from annotator.util import annotator_ckpts_path + + +ISL_PATHS = { + "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"), + "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"), + "midas_v21": "", + "midas_v21_small": "", +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def load_midas_transform(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load transform only + if model_type == "dpt_large": # DPT-Large + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + elif model_type == "midas_v21_small": + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + else: + assert False, f"model_type '{model_type}' not implemented, use: --model_type large" + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return transform + + +def load_model(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load network + model_path = ISL_PATHS[model_type] + if model_type == "dpt_large": # DPT-Large + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + if not os.path.exists(model_path): + model_path = hf_hub_download("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt") + + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return model.eval(), transform + + +class MiDaSInference(nn.Module): + MODEL_TYPES_TORCH_HUB = [ + "DPT_Large", + "DPT_Hybrid", + "MiDaS_small" + ] + MODEL_TYPES_ISL = [ + "dpt_large", + "dpt_hybrid", + "midas_v21", + "midas_v21_small", + ] + + def __init__(self, model_type): + super().__init__() + assert (model_type in self.MODEL_TYPES_ISL) + model, _ = load_model(model_type) + self.model = model + self.model.train = disabled_train + + def forward(self, x): + with torch.no_grad(): + prediction = self.model(x) + return prediction + diff --git a/third_party/src/flux_edit/annotator/midas/midas/__init__.py b/third_party/src/flux_edit/annotator/midas/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/src/flux_edit/annotator/midas/midas/base_model.py b/third_party/src/flux_edit/annotator/midas/midas/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd --- /dev/null +++ b/third_party/src/flux_edit/annotator/midas/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/third_party/src/flux_edit/annotator/midas/midas/blocks.py b/third_party/src/flux_edit/annotator/midas/midas/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78 --- /dev/null +++ b/third_party/src/flux_edit/annotator/midas/midas/blocks.py @@ -0,0 +1,342 @@ +import torch +import torch.nn as nn + +from .vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand==True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/third_party/src/flux_edit/annotator/midas/midas/dpt_depth.py b/third_party/src/flux_edit/annotator/midas/midas/dpt_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1 --- /dev/null +++ b/third_party/src/flux_edit/annotator/midas/midas/dpt_depth.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) + diff --git a/third_party/src/flux_edit/annotator/midas/midas/midas_net.py b/third_party/src/flux_edit/annotator/midas/midas/midas_net.py new file mode 100644 index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f --- /dev/null +++ b/third_party/src/flux_edit/annotator/midas/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/third_party/src/flux_edit/annotator/midas/midas/midas_net_custom.py b/third_party/src/flux_edit/annotator/midas/midas/midas_net_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c --- /dev/null +++ b/third_party/src/flux_edit/annotator/midas/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/third_party/src/flux_edit/annotator/midas/midas/transforms.py b/third_party/src/flux_edit/annotator/midas/midas/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- /dev/null +++ b/third_party/src/flux_edit/annotator/midas/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/third_party/src/flux_edit/annotator/midas/midas/vit.py b/third_party/src/flux_edit/annotator/midas/midas/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74 --- /dev/null +++ b/third_party/src/flux_edit/annotator/midas/midas/vit.py @@ -0,0 +1,491 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/third_party/src/flux_edit/annotator/midas/utils.py b/third_party/src/flux_edit/annotator/midas/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59 --- /dev/null +++ b/third_party/src/flux_edit/annotator/midas/utils.py @@ -0,0 +1,189 @@ +"""Utils for monoDepth.""" +import sys +import re +import numpy as np +import cv2 +import torch + + +def read_pfm(path): + """Read pfm file. + + Args: + path (str): path to file + + Returns: + tuple: (data, scale) + """ + with open(path, "rb") as file: + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file: " + path) + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + # little-endian + endian = "<" + scale = -scale + else: + # big-endian + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + + return data, scale + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + + +def resize_image(img): + """Resize image and make it fit for network. + + Args: + img (array): image + + Returns: + tensor: data ready for network + """ + height_orig = img.shape[0] + width_orig = img.shape[1] + + if width_orig > height_orig: + scale = width_orig / 384 + else: + scale = height_orig / 384 + + height = (np.ceil(height_orig / scale / 32) * 32).astype(int) + width = (np.ceil(width_orig / scale / 32) * 32).astype(int) + + img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) + + img_resized = ( + torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() + ) + img_resized = img_resized.unsqueeze(0) + + return img_resized + + +def resize_depth(depth, width, height): + """Resize depth map and bring to CPU (numpy). + + Args: + depth (tensor): depth + width (int): image width + height (int): image height + + Returns: + array: processed depth + """ + depth = torch.squeeze(depth[0, :, :, :]).to("cpu") + + depth_resized = cv2.resize( + depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC + ) + + return depth_resized + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.type) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return diff --git a/third_party/src/flux_edit/annotator/util.py b/third_party/src/flux_edit/annotator/util.py new file mode 100644 index 0000000000000000000000000000000000000000..90831643d19cc1b9b0940df3d4fd4d846ba74a05 --- /dev/null +++ b/third_party/src/flux_edit/annotator/util.py @@ -0,0 +1,38 @@ +import numpy as np +import cv2 +import os + + +annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') + + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + + +def resize_image(input_image, resolution): + H, W, C = input_image.shape + H = float(H) + W = float(W) + k = float(resolution) / min(H, W) + H *= k + W *= k + H = int(np.round(H / 64.0)) * 64 + W = int(np.round(W / 64.0)) * 64 + img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) + return img diff --git a/third_party/src/flux_edit/api.py b/third_party/src/flux_edit/api.py new file mode 100644 index 0000000000000000000000000000000000000000..b08202adb35d2ffae320bb9b47f567e538837836 --- /dev/null +++ b/third_party/src/flux_edit/api.py @@ -0,0 +1,194 @@ +import io +import os +import time +from pathlib import Path + +import requests +from PIL import Image + +API_ENDPOINT = "https://api.bfl.ml" + + +class ApiException(Exception): + def __init__(self, status_code: int, detail: str | list[dict] | None = None): + super().__init__() + self.detail = detail + self.status_code = status_code + + def __str__(self) -> str: + return self.__repr__() + + def __repr__(self) -> str: + if self.detail is None: + message = None + elif isinstance(self.detail, str): + message = self.detail + else: + message = "[" + ",".join(d["msg"] for d in self.detail) + "]" + return f"ApiException({self.status_code=}, {message=}, detail={self.detail})" + + +class ImageRequest: + def __init__( + self, + prompt: str, + width: int = 1024, + height: int = 1024, + name: str = "flux.1-pro", + num_steps: int = 50, + prompt_upsampling: bool = False, + seed: int | None = None, + validate: bool = True, + launch: bool = True, + api_key: str | None = None, + ): + """ + Manages an image generation request to the API. + + Args: + prompt: Prompt to sample + width: Width of the image in pixel + height: Height of the image in pixel + name: Name of the model + num_steps: Number of network evaluations + prompt_upsampling: Use prompt upsampling + seed: Fix the generation seed + validate: Run input validation + launch: Directly launches request + api_key: Your API key if not provided by the environment + + Raises: + ValueError: For invalid input + ApiException: For errors raised from the API + """ + if validate: + if name not in ["flux.1-pro"]: + raise ValueError(f"Invalid model {name}") + elif width % 32 != 0: + raise ValueError(f"width must be divisible by 32, got {width}") + elif not (256 <= width <= 1440): + raise ValueError(f"width must be between 256 and 1440, got {width}") + elif height % 32 != 0: + raise ValueError(f"height must be divisible by 32, got {height}") + elif not (256 <= height <= 1440): + raise ValueError(f"height must be between 256 and 1440, got {height}") + elif not (1 <= num_steps <= 50): + raise ValueError(f"steps must be between 1 and 50, got {num_steps}") + + self.request_json = { + "prompt": prompt, + "width": width, + "height": height, + "variant": name, + "steps": num_steps, + "prompt_upsampling": prompt_upsampling, + } + if seed is not None: + self.request_json["seed"] = seed + + self.request_id: str | None = None + self.result: dict | None = None + self._image_bytes: bytes | None = None + self._url: str | None = None + if api_key is None: + self.api_key = os.environ.get("BFL_API_KEY") + else: + self.api_key = api_key + + if launch: + self.request() + + def request(self): + """ + Request to generate the image. + """ + if self.request_id is not None: + return + response = requests.post( + f"{API_ENDPOINT}/v1/image", + headers={ + "accept": "application/json", + "x-key": self.api_key, + "Content-Type": "application/json", + }, + json=self.request_json, + ) + result = response.json() + if response.status_code != 200: + raise ApiException(status_code=response.status_code, detail=result.get("detail")) + self.request_id = response.json()["id"] + + def retrieve(self) -> dict: + """ + Wait for the generation to finish and retrieve response. + """ + if self.request_id is None: + self.request() + while self.result is None: + response = requests.get( + f"{API_ENDPOINT}/v1/get_result", + headers={ + "accept": "application/json", + "x-key": self.api_key, + }, + params={ + "id": self.request_id, + }, + ) + result = response.json() + if "status" not in result: + raise ApiException(status_code=response.status_code, detail=result.get("detail")) + elif result["status"] == "Ready": + self.result = result["result"] + elif result["status"] == "Pending": + time.sleep(0.5) + else: + raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'") + return self.result + + @property + def bytes(self) -> bytes: + """ + Generated image as bytes. + """ + if self._image_bytes is None: + response = requests.get(self.url) + if response.status_code == 200: + self._image_bytes = response.content + else: + raise ApiException(status_code=response.status_code) + return self._image_bytes + + @property + def url(self) -> str: + """ + Public url to retrieve the image from + """ + if self._url is None: + result = self.retrieve() + self._url = result["sample"] + return self._url + + @property + def image(self) -> Image.Image: + """ + Load the image as a PIL Image + """ + return Image.open(io.BytesIO(self.bytes)) + + def save(self, path: str): + """ + Save the generated image to a local path + """ + suffix = Path(self.url).suffix + if not path.endswith(suffix): + path = path + suffix + Path(path).resolve().parent.mkdir(parents=True, exist_ok=True) + with open(path, "wb") as file: + file.write(self.bytes) + + +if __name__ == "__main__": + from fire import Fire + + Fire(ImageRequest) diff --git a/third_party/src/flux_edit/cli.py b/third_party/src/flux_edit/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..25e18cfaa3a8041069782994e68867e3423e4482 --- /dev/null +++ b/third_party/src/flux_edit/cli.py @@ -0,0 +1,289 @@ + +import os +import re +import time +from dataclasses import dataclass +from glob import iglob +import argparse +import torch +from einops import rearrange +# from fire import Fire +from PIL import ExifTags, Image + +from sampling import denoise, get_noise, get_schedule, prepare, unpack +from util import (configs, load_ae, load_clip, + load_flow_model, load_t5) +from transformers import pipeline +from PIL import Image +import numpy as np + +import os +os.environ["FLUX_DEV"] = "/group/40034/hilljswang/flux/ckpt/flux1-dev.safetensors" +os.environ["FLUX_SCHNELL"] = "/group/40034/leizizhang/pretrained/FLUX.1-schnell/flux1-schnell.safetensors" +os.environ["AE"] = "/group/40034/hilljswang/flux/ckpt/ae.safetensors" +NSFW_THRESHOLD = 0.85 + +@dataclass +class SamplingOptions: + source_prompt: str + target_prompt: str + # prompt: str + width: int + height: int + num_steps: int + guidance: float + seed: int | None + +@torch.inference_mode() +def encode(init_image, torch_device, ae): + init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1 + init_image = init_image.unsqueeze(0) + init_image = init_image.to(torch_device) + init_image = ae.encode(init_image.to()).to(torch.bfloat16) + return init_image + +@torch.inference_mode() +def main( + args, + seed: int | None = None, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + num_steps: int | None = None, + loop: bool = False, + offload: bool = False, + add_sampling_metadata: bool = True, +): + """ + Sample the flux model. Either interactively (set `--loop`) or run for a + single image. + + Args: + name: Name of the model to load + height: height of the sample in pixels (should be a multiple of 16) + width: width of the sample in pixels (should be a multiple of 16) + seed: Set a seed for sampling + output_name: where to save the output image, `{idx}` will be replaced + by the index of the sample + prompt: Prompt used for sampling + device: Pytorch device + num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) + loop: start an interactive session and sample multiple times + guidance: guidance value used for guidance distillation + add_sampling_metadata: Add the prompt to the image Exif metadata + """ + torch.set_grad_enabled(False) + name = args.name + source_prompt = args.source_prompt + target_prompt = args.target_prompt + guidance = args.guidance + output_dir = args.output_dir + num_steps = args.num_steps + # import pdb;pdb.set_trace() + # use_solver = args.use_solver + offload = args.offload + + # nsfw_classifier = pipeline("image-classification", model="/group/40034/hilljswang/flux/nsfw_image_detection", device=device) + + if name not in configs: + available = ", ".join(configs.keys()) + raise ValueError(f"Got unknown model name: {name}, chose from {available}") + + torch_device = torch.device(device) + if num_steps is None: + num_steps = 4 if name == "flux-schnell" else 25 + + # init all components + t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512) + clip = load_clip(torch_device) + model = load_flow_model(name, device="cpu" if offload else torch_device) + ae = load_ae(name, device="cpu" if offload else torch_device) + + if offload: + model.cpu() + torch.cuda.empty_cache() + ae.encoder.to(torch_device) + + init_image = None + if os.path.isdir(args.source_img_dir): + for file_name in sorted(os.listdir(args.source_img_dir)): + path= os.path.join(args.source_img_dir, file_name) + if init_image is None: + init_image = np.array(Image.open(path)) + width, height = init_image.shape[0], init_image.shape[1] + init_image = encode(init_image, torch_device, ae) + else: + init_image = torch.cat((init_image, encode(np.array(Image.open(path)), torch_device, ae)), dim=0) + else: + init_image = np.array(Image.open(args.source_img_dir)) + shape = init_image.shape + # import pdb;pdb.set_trace() + + new_h = shape[0] if shape[0] % 16 == 0 else shape[0] - shape[0] % 16 + new_w = shape[1] if shape[1] % 16 == 0 else shape[1] - shape[1] % 16 + + init_image = init_image[:new_h, :new_w, :] + + width, height = init_image.shape[0], init_image.shape[1] + init_image = encode(init_image, torch_device, ae) + # import pdb;pdb.set_trace() + + rng = torch.Generator(device="cpu") + opts = SamplingOptions( + source_prompt=source_prompt, + target_prompt=target_prompt, + width=width, + height=height, + num_steps=num_steps, + guidance=guidance, + seed=seed, + ) + + if loop: + opts = parse_prompt(opts) + + while opts is not None: + if opts.seed is None: + opts.seed = rng.seed() + print(f"Generating with seed {opts.seed}:\n{opts.source_prompt}") + t0 = time.perf_counter() + + # prepare input + # x = get_noise( + # 1, + # opts.height, + # opts.width, + # device=torch_device, + # dtype=torch.bfloat16, + # seed=opts.seed, + # ) + + opts.seed = None + if offload: + ae = ae.cpu() + torch.cuda.empty_cache() + t5, clip = t5.to(torch_device), clip.to(torch_device) + + #############inverse####################### + info = {} + info['feature_path'] = args.feature_path + info['inject_type'] = args.inject_type + info['inject_step'] = args.inject + info['partial'] = args.partial + if not os.path.exists(args.feature_path): + os.mkdir(args.feature_path) + + inp = prepare(t5, clip, init_image, prompt=opts.source_prompt) + inp_target = prepare(t5, clip, init_image, prompt=opts.target_prompt) + timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) + + # offload TEs to CPU, load model to gpu + if offload: + t5, clip = t5.cpu(), clip.cpu() + torch.cuda.empty_cache() + model = model.to(torch_device) + + # inversion initial noise + # import pdb;pdb.set_trace() + z = denoise(model, **inp, timesteps=timesteps, guidance=1, inverse=True, info=info) + + # import pdb;pdb.set_trace() + inp_target["img"] = z + + timesteps = get_schedule(opts.num_steps, inp_target["img"].shape[1], shift=(name != "flux-schnell")) + + # denoise initial noise + x = denoise(model, **inp_target, timesteps=timesteps, guidance=guidance, inverse=False, info=info) + + # offload model, load autoencoder to gpu + if offload: + model.cpu() + torch.cuda.empty_cache() + ae.decoder.to(x.device) + + # decode latents to pixel space + batch_x = unpack(x.float(), opts.width, opts.height) + + for x in batch_x: + x = x.unsqueeze(0) + output_name = os.path.join(output_dir, "img_{idx}.jpg") + if not os.path.exists(output_dir): + os.makedirs(output_dir) + idx = 0 + else: + fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] + if len(fns) > 0: + idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 + else: + idx = 0 + + with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): + x = ae.decode(x) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + t1 = time.perf_counter() + + fn = output_name.format(idx=idx) + print(f"Done in {t1 - t0:.1f}s. Saving {fn}") + # bring into PIL format and save + x = x.clamp(-1, 1) + # x = embed_watermark(x.float()) + x = rearrange(x[0], "c h w -> h w c") + + img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) + # nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0] + img.save(fn) + # if nsfw_score < NSFW_THRESHOLD: + # exif_data = Image.Exif() + # exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux" + # exif_data[ExifTags.Base.Make] = "Black Forest Labs" + # exif_data[ExifTags.Base.Model] = name + # if add_sampling_metadata: + # exif_data[ExifTags.Base.ImageDescription] = source_prompt + # img.save(fn, exif=exif_data, quality=95, subsampling=0) + # idx += 1 + # else: + # print("Your generated image may contain NSFW content.") + + if loop: + print("-" * 80) + opts = parse_prompt(opts) + else: + opts = None + + +# def app(): +# Fire(main) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description='FLUX inference') + + parser.add_argument('--name', default='flux-dev', type=str, + help='flux model') + parser.add_argument('--source_img_dir', default='', type=str, + help='flux model') + parser.add_argument('--source_prompt', type=str, + help='source prompt') + parser.add_argument('--target_prompt', type=str, + help='source prompt') + parser.add_argument('--feature_path', type=str, + help='feature_path') + parser.add_argument('--guidance', type=int, default=5, + help='guidance scale') + parser.add_argument('--num_steps', type=int, default=25, + help='num_steps') + parser.add_argument('--inject', type=int, default=20, + help='inject') + parser.add_argument('--partial', type=int, default=None, + help='partial inject') + parser.add_argument('--output_dir', default='output', type=str, + help='output dir') + parser.add_argument('--inject_type', type=str, + help='source prompt') + # parser.add_argument('--use_solver', action='store_true', help='Use solver if flag is present') + parser.add_argument('--offload', action='store_true', help='Use solver if flag is present') + + args = parser.parse_args() + + main(args) diff --git a/third_party/src/flux_edit/controlnet.py b/third_party/src/flux_edit/controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..abadc32cc518a752dcc14aaf49bc407537072f99 --- /dev/null +++ b/third_party/src/flux_edit/controlnet.py @@ -0,0 +1,222 @@ +from dataclasses import dataclass + +import torch +from torch import Tensor, nn +from einops import rearrange + +from modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, + MLPEmbedder, SingleStreamBlock, + timestep_embedding) + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class ControlNetFlux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + _supports_gradient_checkpointing = True + + def __init__(self, params: FluxParams, controlnet_depth=2): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(controlnet_depth) + ] + ) + + # add ControlNet blocks + self.controlnet_blocks = nn.ModuleList([]) + for _ in range(controlnet_depth): + controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks.append(controlnet_block) + self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.gradient_checkpointing = False + self.input_hint_block = nn.Sequential( + nn.Conv2d(3, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1), + nn.SiLU(), + nn.Conv2d(16, 16, 3, padding=1, stride=2), + nn.SiLU(), + zero_module(nn.Conv2d(16, 16, 3, padding=1)) + ) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + + @property + def attn_processors(self): + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + controlnet_cond: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + controlnet_cond = self.input_hint_block(controlnet_cond) + controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_cond = self.pos_embed_input(controlnet_cond) + img = img + controlnet_cond + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + block_res_samples = () + + for block in self.double_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + img, + txt, + vec, + pe, + ) + else: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + block_res_samples = block_res_samples + (img,) + + controlnet_block_res_samples = () + for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks): + block_res_sample = controlnet_block(block_res_sample) + controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) + + return controlnet_block_res_samples diff --git a/third_party/src/flux_edit/model.py b/third_party/src/flux_edit/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e22115619b1225a53212d91bfb4ce3f837e54b4c --- /dev/null +++ b/third_party/src/flux_edit/model.py @@ -0,0 +1,228 @@ +from dataclasses import dataclass + +import torch +from torch import Tensor, nn +from einops import rearrange + +from modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, + MLPEmbedder, SingleStreamBlock, + timestep_embedding) + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + + +class Flux(nn.Module): + """ + Transformer model for flow matching on sequences. + """ + _supports_gradient_checkpointing = True + + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + for _ in range(params.depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + @property + def attn_processors(self): + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + block_controlnet_hidden_states=None, + guidance: Tensor | None = None, + image_proj: Tensor | None = None, + ip_scale: Tensor | float = 1.0, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError("Didn't get guidance strength for guidance distilled model.") + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + if block_controlnet_hidden_states is not None: + controlnet_depth = len(block_controlnet_hidden_states) + for index_block, block in enumerate(self.double_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + img, + txt, + vec, + pe, + image_proj, + ip_scale, + ) + else: + img, txt = block( + img=img, + txt=txt, + vec=vec, + pe=pe, + image_proj=image_proj, + ip_scale=ip_scale, + ) + # controlnet residual + if block_controlnet_hidden_states is not None: + img = img + block_controlnet_hidden_states[index_block % 2] + + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + img, + vec, + pe, + ) + else: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img diff --git a/third_party/src/flux_edit/sampling.py b/third_party/src/flux_edit/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..97e87e0996505ca0ff94a4993a8b2c21503b9532 --- /dev/null +++ b/third_party/src/flux_edit/sampling.py @@ -0,0 +1,188 @@ +import math +from typing import Callable + +import torch +from einops import rearrange, repeat +from torch import Tensor + +from model import Flux +from modules.conditioner import HFEmbedder + + +def get_noise( + num_samples: int, + height: int, + width: int, + device: torch.device, + dtype: torch.dtype, + seed: int, +): + return torch.randn( + num_samples, + 16, + # allow for packing + 2 * math.ceil(height / 16), + 2 * math.ceil(width / 16), + device=device, + dtype=dtype, + generator=torch.Generator(device=device).manual_seed(seed), + ) + + +def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]: + bs, c, h, w = img.shape + if bs == 1 and not isinstance(prompt, str): + bs = len(prompt) + + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + if img.shape[0] == 1 and bs > 1: + img = repeat(img, "1 ... -> bs ...", bs=bs) + + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + if isinstance(prompt, str): + prompt = [prompt] + txt = t5(prompt) + if txt.shape[0] == 1 and bs > 1: + txt = repeat(txt, "1 ... -> bs ...", bs=bs) + txt_ids = torch.zeros(bs, txt.shape[1], 3) + + vec = clip(prompt) + if vec.shape[0] == 1 and bs > 1: + vec = repeat(vec, "1 ... -> bs ...", bs=bs) + + # img = rearrange(img, "b c d -> 1 (b c) d") + # img_ids = rearrange(img_ids, "b c d -> 1 (b c) d") + # txt = txt[0].unsqueeze(0) + # txt_ids = txt_ids[0].unsqueeze(0) + # vec = vec[0].unsqueeze(0) + return { + "img": img, + "img_ids": img_ids.to(img.device), + "txt": txt.to(img.device), + "txt_ids": txt_ids.to(img.device), + "vec": vec.to(img.device), + } + + +def time_shift(mu: float, sigma: float, t: Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function( + x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 +) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # estimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +def denoise( + model: Flux, + # model input + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + vec: Tensor, + # sampling parameters + timesteps: list[float], + inverse, + info, + guidance: float = 4.0, + use_solver = True, +): + # this is ignored for schnell + inject_list = [True] * info['inject_step'] + [False] * (len(timesteps[:-1]) - info['inject_step']) + + if info['partial'] is not None: + timesteps = timesteps[:info['partial']] + + if inverse: + timesteps = timesteps[::-1] + inject_list = inject_list[::-1] + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + + for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])): + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) + info['t'] = t_prev if inverse else t_curr + info['inverse'] = inverse + info['second_order'] = False + info['inject'] = inject_list[i] + # import pdb;pdb.set_trace() + if use_solver: + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + info=info + ) + + img_mid = img + (t_prev - t_curr) / 2 * pred + + t_vec_mid = torch.full((img.shape[0],), (t_curr + (t_prev - t_curr) / 2), dtype=img.dtype, device=img.device) + info['second_order'] = True + pred_mid = model( + img=img_mid, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec_mid, + guidance=guidance_vec, + info=info + ) + # import pdb;pdb.set_trace() + first_order = (pred_mid - pred) / ((t_prev - t_curr) / 2) + + img = img + (t_prev - t_curr) * pred + 0.5 * (t_prev - t_curr) ** 2 * first_order + else: + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + info=info + ) + img = img + (t_prev - t_curr) * pred + return img + + +def unpack(x: Tensor, height: int, width: int) -> Tensor: + return rearrange( + x, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(height / 16), + w=math.ceil(width / 16), + ph=2, + pw=2, + ) diff --git a/third_party/src/flux_edit/test.sh b/third_party/src/flux_edit/test.sh new file mode 100644 index 0000000000000000000000000000000000000000..7b2d92850547fd41c82cd9d71de6240328197e89 --- /dev/null +++ b/third_party/src/flux_edit/test.sh @@ -0,0 +1,10 @@ +CUDA_VISIBLE_DEVICES=1 python cli.py --source_prompt "A dog standing." \ + --target_prompt "A dog sitting." \ + --guidance 3 \ + --source_img_dir '/home/jiayi_guo/OnlinePlay/img2img/output/init_image-1.png' \ + --num_steps 30 \ + --inject 5 \ + --name 'flux-dev' \ + --feature_path '/home/jiayi_guo/OnlinePlay/img2img/output/' \ + --output_dir '/home/jiayi_guo/OnlinePlay/third_party/src/flux_edit/results/' \ + --inject_type 'v' \ No newline at end of file diff --git a/third_party/src/flux_edit/util.py b/third_party/src/flux_edit/util.py new file mode 100644 index 0000000000000000000000000000000000000000..fc0cbe10f3d14e8135eed1430925f3ce8b47d241 --- /dev/null +++ b/third_party/src/flux_edit/util.py @@ -0,0 +1,433 @@ +import os +from dataclasses import dataclass + +import torch +import json +import cv2 +import numpy as np +from PIL import Image +from huggingface_hub import hf_hub_download +from safetensors import safe_open +from safetensors.torch import load_file as load_sft + +from optimum.quanto import requantize + +from model import Flux, FluxParams +from controlnet import ControlNetFlux +from modules.autoencoder import AutoEncoder, AutoEncoderParams +from modules.conditioner import HFEmbedder +from annotator.dwpose import DWposeDetector +from annotator.mlsd import MLSDdetector +from annotator.canny import CannyDetector +from annotator.midas import MidasDetector +from annotator.hed import HEDdetector +from annotator.tile import TileDetector +from annotator.zoe import ZoeDetector + + +def load_safetensors(path): + tensors = {} + with safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + tensors[key] = f.get_tensor(key) + return tensors + +def get_lora_rank(checkpoint): + for k in checkpoint.keys(): + if k.endswith(".down.weight"): + return checkpoint[k].shape[0] + +def load_checkpoint(local_path, repo_id, name): + if local_path is not None: + if '.safetensors' in local_path: + print(f"Loading .safetensors checkpoint from {local_path}") + checkpoint = load_safetensors(local_path) + else: + print(f"Loading checkpoint from {local_path}") + checkpoint = torch.load(local_path, map_location='cpu') + elif repo_id is not None and name is not None: + print(f"Loading checkpoint {name} from repo id {repo_id}") + checkpoint = load_from_repo_id(repo_id, name) + else: + raise ValueError( + "LOADING ERROR: you must specify local_path or repo_id with name in HF to download" + ) + return checkpoint + + +def c_crop(image): + width, height = image.size + new_size = min(width, height) + left = (width - new_size) / 2 + top = (height - new_size) / 2 + right = (width + new_size) / 2 + bottom = (height + new_size) / 2 + return image.crop((left, top, right, bottom)) + +def pad64(x): + return int(np.ceil(float(x) / 64.0) * 64 - x) + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + +def safer_memory(x): + # Fix many MAC/AMD problems + return np.ascontiguousarray(x.copy()).copy() + +#https://github.com/Mikubill/sd-webui-controlnet/blob/main/scripts/processor.py#L17 +#Added upscale_method, mode params +def resize_image_with_pad(input_image, resolution, skip_hwc3=False, mode='edge'): + if skip_hwc3: + img = input_image + else: + img = HWC3(input_image) + H_raw, W_raw, _ = img.shape + if resolution == 0: + return img, lambda x: x + k = float(resolution) / float(min(H_raw, W_raw)) + H_target = int(np.round(float(H_raw) * k)) + W_target = int(np.round(float(W_raw) * k)) + img = cv2.resize(img, (W_target, H_target), interpolation=cv2.INTER_AREA) + H_pad, W_pad = pad64(H_target), pad64(W_target) + img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode) + + def remove_pad(x): + return safer_memory(x[:H_target, :W_target, ...]) + + return safer_memory(img_padded), remove_pad + +class Annotator: + def __init__(self, name: str, device: str): + if name == "canny": + processor = CannyDetector() + elif name == "openpose": + processor = DWposeDetector(device) + elif name == "depth": + processor = MidasDetector() + elif name == "hed": + processor = HEDdetector() + elif name == "hough": + processor = MLSDdetector() + elif name == "tile": + processor = TileDetector() + elif name == "zoe": + processor = ZoeDetector() + self.name = name + self.processor = processor + + def __call__(self, image: Image, width: int, height: int): + image = np.array(image) + detect_resolution = max(width, height) + image, remove_pad = resize_image_with_pad(image, detect_resolution) + + image = np.array(image) + if self.name == "canny": + result = self.processor(image, low_threshold=100, high_threshold=200) + elif self.name == "hough": + result = self.processor(image, thr_v=0.05, thr_d=5) + elif self.name == "depth": + result = self.processor(image) + result, _ = result + else: + result = self.processor(image) + + result = HWC3(remove_pad(result)) + result = cv2.resize(result, (width, height)) + return result + + +@dataclass +class ModelSpec: + params: FluxParams + ae_params: AutoEncoderParams + ckpt_path: str | None + ae_path: str | None + repo_id: str | None + repo_flow: str | None + repo_ae: str | None + repo_id_ae: str | None + + +configs = { + "flux-dev": ModelSpec( + repo_id="black-forest-labs/FLUX.1-dev", + repo_id_ae="black-forest-labs/FLUX.1-dev", + repo_flow="flux1-dev.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-dev-fp8": ModelSpec( + repo_id="XLabs-AI/flux-dev-fp8", + repo_id_ae="black-forest-labs/FLUX.1-dev", + repo_flow="flux-dev-fp8.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV_FP8"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-schnell": ModelSpec( + repo_id="black-forest-labs/FLUX.1-schnell", + repo_id_ae="black-forest-labs/FLUX.1-dev", + repo_flow="flux1-schnell.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_SCHNELL"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), +} + + +def print_load_warning(missing: list[str], unexpected: list[str]) -> None: + if len(missing) > 0 and len(unexpected) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + print("\n" + "-" * 79 + "\n") + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + elif len(missing) > 0: + print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) + elif len(unexpected) > 0: + print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) + +def load_from_repo_id(repo_id, checkpoint_name): + ckpt_path = hf_hub_download(repo_id, checkpoint_name) + sd = load_sft(ckpt_path, device='cpu') + return sd + +def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True): + # Loading Flux + print("Init model") + ckpt_path = configs[name].ckpt_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) + + with torch.device("meta" if ckpt_path is not None else device): + model = Flux(configs[name].params).to(torch.bfloat16) + + if ckpt_path is not None: + print("Loading checkpoint") + # load_sft doesn't support torch.device + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return model + +def load_flow_model2(name: str, device: str | torch.device = "cuda", hf_download: bool = True): + # Loading Flux + print("Init model") + ckpt_path = configs[name].ckpt_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors")) + + with torch.device("meta" if ckpt_path is not None else device): + model = Flux(configs[name].params) + + if ckpt_path is not None: + print("Loading checkpoint") + # load_sft doesn't support torch.device + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return model + +def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf_download: bool = True): + # Loading Flux + print("Init model") + ckpt_path = configs[name].ckpt_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) + json_path = hf_hub_download(configs[name].repo_id, 'flux_dev_quantization_map.json') + + + model = Flux(configs[name].params).to(torch.bfloat16) + + print("Loading checkpoint") + # load_sft doesn't support torch.device + sd = load_sft(ckpt_path, device='cpu') + with open(json_path, "r") as f: + quantization_map = json.load(f) + print("Start a quantization process...") + requantize(model, sd, quantization_map, device=device) + print("Model is quantized!") + return model + +def load_controlnet(name, device, transformer=None): + with torch.device(device): + controlnet = ControlNetFlux(configs[name].params) + if transformer is not None: + controlnet.load_state_dict(transformer.state_dict(), strict=False) + return controlnet + +def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder: + # max length 64, 128, 256 and 512 should work (if your sequence is short enough) + return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device) + +def load_clip(device: str | torch.device = "cuda") -> HFEmbedder: + return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device) + + +def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder: + ckpt_path = configs[name].ae_path + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_ae is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae) + + # Loading the autoencoder + print("Init AE") + with torch.device("meta" if ckpt_path is not None else device): + ae = AutoEncoder(configs[name].ae_params) + + if ckpt_path is not None: + sd = load_sft(ckpt_path, device=str(device)) + missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) + print_load_warning(missing, unexpected) + return ae + + +class WatermarkEmbedder: + def __init__(self, watermark): + self.watermark = watermark + self.num_bits = len(WATERMARK_BITS) + self.encoder = WatermarkEncoder() + self.encoder.set_watermark("bits", self.watermark) + + def __call__(self, image: torch.Tensor) -> torch.Tensor: + """ + Adds a predefined watermark to the input image + + Args: + image: ([N,] B, RGB, H, W) in range [-1, 1] + + Returns: + same as input but watermarked + """ + image = 0.5 * image + 0.5 + squeeze = len(image.shape) == 4 + if squeeze: + image = image[None, ...] + n = image.shape[0] + image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1] + # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] + # watermarking libary expects input as cv2 BGR format + for k in range(image_np.shape[0]): + image_np[k] = self.encoder.encode(image_np[k], "dwtDct") + image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to( + image.device + ) + image = torch.clamp(image / 255, min=0.0, max=1.0) + if squeeze: + image = image[0] + image = 2 * image - 1 + return image + + +# A fixed 48-bit message that was choosen at random +WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110 +# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 +WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] diff --git a/third_party/src/flux_edit/xflux_pipeline.py b/third_party/src/flux_edit/xflux_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..579e46a890bfbc7b9c62d6effb04189f3f26fb0c --- /dev/null +++ b/third_party/src/flux_edit/xflux_pipeline.py @@ -0,0 +1,421 @@ +from PIL import Image, ExifTags +import numpy as np +import torch +from torch import Tensor + +from einops import rearrange +import uuid +import os + +from .modules.layers import ( + SingleStreamBlockProcessor, + DoubleStreamBlockProcessor, + SingleStreamBlockLoraProcessor, + DoubleStreamBlockLoraProcessor, + IPDoubleStreamBlockProcessor, + ImageProjModel, +) +from .sampling import denoise, denoise_controlnet, get_noise, get_schedule, prepare, unpack +from .util import ( + load_ae, + load_clip, + load_flow_model, + load_t5, + load_controlnet, + load_flow_model_quintized, + Annotator, + get_lora_rank, + load_checkpoint +) + +from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor + +class XFluxPipeline: + def __init__(self, model_type, device, offload: bool = False): + self.device = torch.device(device) + self.offload = offload + self.model_type = model_type + + self.clip = load_clip(self.device) + self.t5 = load_t5(self.device, max_length=512) + self.ae = load_ae(model_type, device="cpu" if offload else self.device) + if "fp8" in model_type: + self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.device) + else: + self.model = load_flow_model(model_type, device="cpu" if offload else self.device) + + self.image_encoder_path = "openai/clip-vit-large-patch14" + self.hf_lora_collection = "XLabs-AI/flux-lora-collection" + self.lora_types_to_names = { + "realism": "lora.safetensors", + } + self.controlnet_loaded = False + self.ip_loaded = False + + def set_ip(self, local_path: str = None, repo_id = None, name: str = None): + self.model.to(self.device) + + # unpack checkpoint + checkpoint = load_checkpoint(local_path, repo_id, name) + prefix = "double_blocks." + blocks = {} + proj = {} + + for key, value in checkpoint.items(): + if key.startswith(prefix): + blocks[key[len(prefix):].replace('.processor.', '.')] = value + if key.startswith("ip_adapter_proj_model"): + proj[key[len("ip_adapter_proj_model."):]] = value + + # load image encoder + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( + self.device, dtype=torch.float16 + ) + self.clip_image_processor = CLIPImageProcessor() + + # setup image embedding projection model + self.improj = ImageProjModel(4096, 768, 4) + self.improj.load_state_dict(proj) + self.improj = self.improj.to(self.device, dtype=torch.bfloat16) + + ip_attn_procs = {} + + for name, _ in self.model.attn_processors.items(): + ip_state_dict = {} + for k in checkpoint.keys(): + if name in k: + ip_state_dict[k.replace(f'{name}.', '')] = checkpoint[k] + if ip_state_dict: + ip_attn_procs[name] = IPDoubleStreamBlockProcessor(4096, 3072) + ip_attn_procs[name].load_state_dict(ip_state_dict) + ip_attn_procs[name].to(self.device, dtype=torch.bfloat16) + else: + ip_attn_procs[name] = self.model.attn_processors[name] + + self.model.set_attn_processor(ip_attn_procs) + self.ip_loaded = True + + def set_lora(self, local_path: str = None, repo_id: str = None, + name: str = None, lora_weight: int = 0.7): + checkpoint = load_checkpoint(local_path, repo_id, name) + self.update_model_with_lora(checkpoint, lora_weight) + + def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: int = 0.7): + checkpoint = load_checkpoint( + None, self.hf_lora_collection, self.lora_types_to_names[lora_type] + ) + self.update_model_with_lora(checkpoint, lora_weight) + + def update_model_with_lora(self, checkpoint, lora_weight): + rank = get_lora_rank(checkpoint) + lora_attn_procs = {} + + for name, _ in self.model.attn_processors.items(): + lora_state_dict = {} + for k in checkpoint.keys(): + if name in k: + lora_state_dict[k[len(name) + 1:]] = checkpoint[k] * lora_weight + + if len(lora_state_dict): + if name.startswith("single_blocks"): + lora_attn_procs[name] = SingleStreamBlockLoraProcessor(dim=3072, rank=rank) + else: + lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank) + lora_attn_procs[name].load_state_dict(lora_state_dict) + lora_attn_procs[name].to(self.device) + else: + if name.startswith("single_blocks"): + lora_attn_procs[name] = SingleStreamBlockProcessor() + else: + lora_attn_procs[name] = DoubleStreamBlockProcessor() + + self.model.set_attn_processor(lora_attn_procs) + + def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None, name: str = None): + self.model.to(self.device) + self.controlnet = load_controlnet(self.model_type, self.device).to(torch.bfloat16) + + checkpoint = load_checkpoint(local_path, repo_id, name) + self.controlnet.load_state_dict(checkpoint, strict=False) + self.annotator = Annotator(control_type, self.device) + self.controlnet_loaded = True + self.control_type = control_type + + def get_image_proj( + self, + image_prompt: Tensor, + ): + # encode image-prompt embeds + image_prompt = self.clip_image_processor( + images=image_prompt, + return_tensors="pt" + ).pixel_values + image_prompt = image_prompt.to(self.image_encoder.device) + image_prompt_embeds = self.image_encoder( + image_prompt + ).image_embeds.to( + device=self.device, dtype=torch.bfloat16, + ) + # encode image + image_proj = self.improj(image_prompt_embeds) + return image_proj + + def __call__(self, + prompt: str, + image_prompt: Image = None, + controlnet_image: Image = None, + width: int = 512, + height: int = 512, + guidance: float = 4, + num_steps: int = 50, + seed: int = 123456789, + true_gs: float = 3, + control_weight: float = 0.9, + ip_scale: float = 1.0, + neg_ip_scale: float = 1.0, + neg_prompt: str = '', + neg_image_prompt: Image = None, + timestep_to_start_cfg: int = 0, + ): + width = 16 * (width // 16) + height = 16 * (height // 16) + image_proj = None + neg_image_proj = None + if not (image_prompt is None and neg_image_prompt is None) : + assert self.ip_loaded, 'You must setup IP-Adapter to add image prompt as input' + + if image_prompt is None: + image_prompt = np.zeros((width, height, 3), dtype=np.uint8) + if neg_image_prompt is None: + neg_image_prompt = np.zeros((width, height, 3), dtype=np.uint8) + + image_proj = self.get_image_proj(image_prompt) + neg_image_proj = self.get_image_proj(neg_image_prompt) + + if self.controlnet_loaded: + controlnet_image = self.annotator(controlnet_image, width, height) + controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1) + controlnet_image = controlnet_image.permute( + 2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.device) + + return self.forward( + prompt, + width, + height, + guidance, + num_steps, + seed, + controlnet_image, + timestep_to_start_cfg=timestep_to_start_cfg, + true_gs=true_gs, + control_weight=control_weight, + neg_prompt=neg_prompt, + image_proj=image_proj, + neg_image_proj=neg_image_proj, + ip_scale=ip_scale, + neg_ip_scale=neg_ip_scale, + ) + + def generate_from_feat(self, + prompt: str, + image_proj: Tensor = None, + controlnet_image: Image = None, + width: int = 512, + height: int = 512, + guidance: float = 4, + num_steps: int = 50, + seed: int = 123456789, + true_gs: float = 3, + control_weight: float = 0.9, + ip_scale: float = 1.0, + neg_ip_scale: float = 1.0, + neg_prompt: str = '', + neg_image_proj: Tensor = None, + timestep_to_start_cfg: int = 0, + ): + width = 16 * (width // 16) + height = 16 * (height // 16) + # image_proj = None + # neg_image_proj = None + # if not (image_prompt is None and neg_image_prompt is None) : + # assert self.ip_loaded, 'You must setup IP-Adapter to add image prompt as input' + + # if image_prompt is None: + # image_prompt = np.zeros((width, height, 3), dtype=np.uint8) + # if neg_image_prompt is None: + # neg_image_prompt = np.zeros((width, height, 3), dtype=np.uint8) + + # image_proj = self.get_image_proj(image_prompt) + # neg_image_proj = self.get_image_proj(neg_image_prompt) + + # if self.controlnet_loaded: + # controlnet_image = self.annotator(controlnet_image, width, height) + # controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1) + # controlnet_image = controlnet_image.permute( + # 2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.device) + + return self.forward( + prompt, + width, + height, + guidance, + num_steps, + seed, + controlnet_image, + timestep_to_start_cfg=timestep_to_start_cfg, + true_gs=true_gs, + control_weight=control_weight, + neg_prompt=neg_prompt, + image_proj=image_proj, + neg_image_proj=neg_image_proj, + ip_scale=ip_scale, + neg_ip_scale=neg_ip_scale, + ) + + + @torch.inference_mode() + def gradio_generate(self, prompt, image_prompt, controlnet_image, width, height, guidance, + num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt, + neg_image_prompt, timestep_to_start_cfg, control_type, control_weight, + lora_weight, local_path, lora_local_path, ip_local_path): + if controlnet_image is not None: + controlnet_image = Image.fromarray(controlnet_image) + if ((self.controlnet_loaded and control_type != self.control_type) + or not self.controlnet_loaded): + if local_path is not None: + self.set_controlnet(control_type, local_path=local_path) + else: + self.set_controlnet(control_type, local_path=None, + repo_id=f"xlabs-ai/flux-controlnet-{control_type}-v3", + name=f"flux-{control_type}-controlnet-v3.safetensors") + if lora_local_path is not None: + self.set_lora(local_path=lora_local_path, lora_weight=lora_weight) + if image_prompt is not None: + image_prompt = Image.fromarray(image_prompt) + if neg_image_prompt is not None: + neg_image_prompt = Image.fromarray(neg_image_prompt) + if not self.ip_loaded: + if ip_local_path is not None: + self.set_ip(local_path=ip_local_path) + else: + self.set_ip(repo_id="xlabs-ai/flux-ip-adapter", + name="flux-ip-adapter.safetensors") + seed = int(seed) + if seed == -1: + seed = torch.Generator(device="cpu").seed() + + img = self(prompt, image_prompt, controlnet_image, width, height, guidance, + num_steps, seed, true_gs, control_weight, ip_scale, neg_ip_scale, neg_prompt, + neg_image_prompt, timestep_to_start_cfg) + + filename = f"output/gradio/{uuid.uuid4()}.jpg" + os.makedirs(os.path.dirname(filename), exist_ok=True) + exif_data = Image.Exif() + exif_data[ExifTags.Base.Make] = "XLabs AI" + exif_data[ExifTags.Base.Model] = self.model_type + img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0) + return img, filename + + def forward( + self, + prompt, + width, + height, + guidance, + num_steps, + seed, + controlnet_image = None, + timestep_to_start_cfg = 0, + true_gs = 3.5, + control_weight = 0.9, + neg_prompt="", + image_proj=None, + neg_image_proj=None, + ip_scale=1.0, + neg_ip_scale=1.0, + ): + x = get_noise( + 1, height, width, device=self.device, + dtype=torch.bfloat16, seed=seed + ) + timesteps = get_schedule( + num_steps, + (width // 8) * (height // 8) // (16 * 16), + shift=True, + ) + torch.manual_seed(seed) + with torch.no_grad(): + if self.offload: + self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device) + inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt) + neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt) + + if self.offload: + self.offload_model_to_cpu(self.t5, self.clip) + self.model = self.model.to(self.device) + if self.controlnet_loaded: + x = denoise_controlnet( + self.model, + **inp_cond, + controlnet=self.controlnet, + timesteps=timesteps, + guidance=guidance, + controlnet_cond=controlnet_image, + timestep_to_start_cfg=timestep_to_start_cfg, + neg_txt=neg_inp_cond['txt'], + neg_txt_ids=neg_inp_cond['txt_ids'], + neg_vec=neg_inp_cond['vec'], + true_gs=true_gs, + controlnet_gs=control_weight, + image_proj=image_proj, + neg_image_proj=neg_image_proj, + ip_scale=ip_scale, + neg_ip_scale=neg_ip_scale, + ) + else: + x = denoise( + self.model, + **inp_cond, + timesteps=timesteps, + guidance=guidance, + timestep_to_start_cfg=timestep_to_start_cfg, + neg_txt=neg_inp_cond['txt'], + neg_txt_ids=neg_inp_cond['txt_ids'], + neg_vec=neg_inp_cond['vec'], + true_gs=true_gs, + image_proj=image_proj, + neg_image_proj=neg_image_proj, + ip_scale=ip_scale, + neg_ip_scale=neg_ip_scale, + ) + + if self.offload: + self.offload_model_to_cpu(self.model) + self.ae.decoder.to(x.device) + x = unpack(x.float(), height, width) + x = self.ae.decode(x) + self.offload_model_to_cpu(self.ae.decoder) + + x1 = x.clamp(-1, 1) + x1 = rearrange(x1[-1], "c h w -> h w c") + output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy()) + return output_img + + def offload_model_to_cpu(self, *models): + if not self.offload: return + for model in models: + model.cpu() + torch.cuda.empty_cache() + + +class XFluxSampler(XFluxPipeline): + def __init__(self, clip, t5, ae, model, device): + self.clip = clip + self.t5 = t5 + self.ae = ae + self.model = model + self.model.eval() + self.device = device + self.controlnet_loaded = False + self.ip_loaded = False + self.offload = False