cangcz committed
Commit 34ee308 · verified · 1 Parent(s): 51e2bb1
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +53 -0
  2. README.md +5 -6
  3. __init__.py +0 -0
  4. anchorcrafter/__init__.py +0 -0
  5. anchorcrafter/dwpose/__init__.py +0 -0
  6. anchorcrafter/dwpose/dwpose_detector.py +71 -0
  7. anchorcrafter/dwpose/onnxdet.py +145 -0
  8. anchorcrafter/dwpose/onnxpose.py +375 -0
  9. anchorcrafter/dwpose/preprocess.py +85 -0
  10. anchorcrafter/dwpose/util.py +133 -0
  11. anchorcrafter/dwpose/wholebody.py +60 -0
  12. anchorcrafter/modules/__init__.py +0 -0
  13. anchorcrafter/modules/attention_processor.py +466 -0
  14. anchorcrafter/modules/obj_attn_net.py +47 -0
  15. anchorcrafter/modules/obj_proj_net.py +33 -0
  16. anchorcrafter/modules/pose_net.py +88 -0
  17. anchorcrafter/modules/track_net.py +76 -0
  18. anchorcrafter/modules/unet.py +509 -0
  19. anchorcrafter/pipelines/pipeline.py +739 -0
  20. anchorcrafter/utils/__init__.py +0 -0
  21. anchorcrafter/utils/geglu_patch.py +10 -0
  22. anchorcrafter/utils/loader.py +45 -0
  23. anchorcrafter/utils/utils.py +51 -0
  24. app.py +332 -0
  25. config/test.yaml +17 -0
  26. constants.py +4 -0
  27. data/anchor/1.jpg +0 -0
  28. data/anchor/2.jpg +0 -0
  29. data/anchor/3.jpg +3 -0
  30. data/anchor/4.jpg +3 -0
  31. data/anchor/5.jpg +3 -0
  32. data/depth_cut/cheese_1.mp4 +3 -0
  33. data/depth_cut/cheese_2.mp4 +3 -0
  34. data/depth_cut/cup_1.mp4 +3 -0
  35. data/depth_cut/cup_2.mp4 +3 -0
  36. data/depth_cut/earphone_1.mp4 +3 -0
  37. data/depth_cut/earphone_2.mp4 +3 -0
  38. data/depth_cut/hmbb_1.mp4 +3 -0
  39. data/depth_cut/hmbb_2.mp4 +3 -0
  40. data/depth_cut/mouse_1.mp4 +3 -0
  41. data/depth_cut/mouse_2.mp4 +3 -0
  42. data/hand_cut/cheese_1.mp4 +3 -0
  43. data/hand_cut/cheese_2.mp4 +3 -0
  44. data/hand_cut/cup_1.mp4 +3 -0
  45. data/hand_cut/cup_2.mp4 +3 -0
  46. data/hand_cut/earphone_1.mp4 +3 -0
  47. data/hand_cut/earphone_2.mp4 +3 -0
  48. data/hand_cut/hmbb_1.mp4 +3 -0
  49. data/hand_cut/hmbb_2.mp4 +3 -0
  50. data/hand_cut/mouse_1.mp4 +3 -0
.gitattributes CHANGED
@@ -33,3 +33,56 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ data/anchor/3.jpg filter=lfs diff=lfs merge=lfs -text
+ data/anchor/4.jpg filter=lfs diff=lfs merge=lfs -text
+ data/anchor/5.jpg filter=lfs diff=lfs merge=lfs -text
+ data/depth_cut/cheese_1.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/depth_cut/cheese_2.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/depth_cut/cup_1.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/depth_cut/cup_2.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/depth_cut/earphone_1.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/depth_cut/earphone_2.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/depth_cut/hmbb_1.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/depth_cut/hmbb_2.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/depth_cut/mouse_1.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/depth_cut/mouse_2.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/hand_cut/cheese_1.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/hand_cut/cheese_2.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/hand_cut/cup_1.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/hand_cut/cup_2.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/hand_cut/earphone_1.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/hand_cut/earphone_2.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/hand_cut/hmbb_1.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/hand_cut/hmbb_2.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/hand_cut/mouse_1.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/hand_cut/mouse_2.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/object/cheese_0.jpg filter=lfs diff=lfs merge=lfs -text
+ data/object/cheese_1.jpg filter=lfs diff=lfs merge=lfs -text
+ data/object/cheese_2.jpg filter=lfs diff=lfs merge=lfs -text
+ data/object/cup_0.jpg filter=lfs diff=lfs merge=lfs -text
+ data/object/cup_1.jpg filter=lfs diff=lfs merge=lfs -text
+ data/object/cup_2.jpg filter=lfs diff=lfs merge=lfs -text
+ data/object/earphone_0.jpg filter=lfs diff=lfs merge=lfs -text
+ data/object/earphone_1.jpg filter=lfs diff=lfs merge=lfs -text
+ data/object/earphone_2.jpg filter=lfs diff=lfs merge=lfs -text
+ data/object/hmbb_0.jpg filter=lfs diff=lfs merge=lfs -text
+ data/object/hmbb_1.jpg filter=lfs diff=lfs merge=lfs -text
+ data/object/hmbb_2.jpg filter=lfs diff=lfs merge=lfs -text
+ data/object/mouse_0.jpg filter=lfs diff=lfs merge=lfs -text
+ data/object/mouse_1.jpg filter=lfs diff=lfs merge=lfs -text
+ data/object/mouse_2.jpg filter=lfs diff=lfs merge=lfs -text
+ data/out/cheese.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/out/cup.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/out/ear.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/out/hmbb.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/out/mouse.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/video/cheese_1.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/video/cheese_2.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/video/cup_1.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/video/cup_2.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/video/earphone_1.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/video/earphone_2.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/video/hmbb_1.mp4 filter=lfs diff=lfs merge=lfs -text
+ data/video/hmbb_2.mp4 filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,12 @@
  ---
- title: Test
- emoji: 🏃
- colorFrom: indigo
- colorTo: blue
+ title: New Test
+ emoji: 📚
+ colorFrom: yellow
+ colorTo: indigo
  sdk: gradio
- sdk_version: 5.25.0
+ sdk_version: 5.24.0
  app_file: app.py
  pinned: false
- license: apache-2.0
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__init__.py ADDED
File without changes
anchorcrafter/__init__.py ADDED
File without changes
anchorcrafter/dwpose/__init__.py ADDED
File without changes
anchorcrafter/dwpose/dwpose_detector.py ADDED
@@ -0,0 +1,71 @@
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from .wholebody import Wholebody
7
+ from huggingface_hub import hf_hub_download
8
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+
11
+ class DWposeDetector:
12
+ """
13
+ A pose detection method for image-like data.
14
+
15
+ Parameters:
16
+ model_det: (str) serialized ONNX format model path,
17
+ such as https://huggingface.co/yzd-v/DWPose/blob/main/yolox_l.onnx
18
+ model_pose: (str) serialized ONNX format model path,
19
+ such as https://huggingface.co/yzd-v/DWPose/blob/main/dw-ll_ucoco_384.onnx
20
+ device: (str) 'cpu' or 'cuda:{device_id}'
21
+ """
22
+ def __init__(self, model_det, model_pose, device='cpu'):
23
+ self.args = model_det, model_pose, device
24
+ pose_estimation = Wholebody(*self.args)
25
+ self.pose_estimation = pose_estimation
26
+
27
+ def release_memory(self):
28
+ if hasattr(self, 'pose_estimation'):
29
+ del self.pose_estimation
30
+ import gc; gc.collect()
31
+
32
+ def __call__(self, oriImg):
33
+ oriImg = oriImg.copy()
34
+ H, W, C = oriImg.shape
35
+ with torch.no_grad():
36
+ candidate, score = self.pose_estimation(oriImg)
37
+ nums, _, locs = candidate.shape
38
+ candidate[..., 0] /= float(W)
39
+ candidate[..., 1] /= float(H)
40
+ body = candidate[:, :18].copy()
41
+ body = body.reshape(nums * 18, locs)
42
+ subset = score[:, :18].copy()
43
+ for i in range(len(subset)):
44
+ for j in range(len(subset[i])):
45
+ if subset[i][j] > 0.3:
46
+ subset[i][j] = int(18 * i + j)
47
+ else:
48
+ subset[i][j] = -1
49
+
50
+ faces = candidate[:, 24:92]
51
+
52
+ hands = candidate[:, 92:113]
53
+ hands = np.vstack([hands, candidate[:, 113:]])
54
+
55
+ faces_score = score[:, 24:92]
56
+ hands_score = np.vstack([score[:, 92:113], score[:, 113:]])
57
+
58
+ bodies = dict(candidate=body, subset=subset, score=score[:, :18])
59
+ pose = dict(bodies=bodies, hands=hands, hands_score=hands_score, faces=faces, faces_score=faces_score)
60
+
61
+ return pose
62
+
63
+
64
+
65
+ model_det_path = hf_hub_download(repo_id="yzd-v/DWPose", filename="yolox_l.onnx")
66
+ model_pose_path = hf_hub_download(repo_id="yzd-v/DWPose", filename="dw-ll_ucoco_384.onnx")
67
+
68
+ dwpose_detector = DWposeDetector(
69
+ model_det=model_det_path,
70
+ model_pose=model_pose_path,
71
+ device=device)
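
A minimal usage sketch for the module-level dwpose_detector singleton defined above. Importing the module downloads the two ONNX checkpoints from yzd-v/DWPose via hf_hub_download; the sample image is one of the anchor images added in this commit.

    import cv2
    from anchorcrafter.dwpose.dwpose_detector import dwpose_detector

    # load an RGB image as an H x W x 3 uint8 array
    img = cv2.cvtColor(cv2.imread("data/anchor/1.jpg"), cv2.COLOR_BGR2RGB)
    pose = dwpose_detector(img)
    body = pose["bodies"]["candidate"]   # (num_people * 18, 2) body keypoints, normalized to [0, 1]
    hands = pose["hands"]                # left + right hand keypoints, 21 points per hand
    faces = pose["faces"]                # 68 face landmarks per person
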
anchorcrafter/dwpose/onnxdet.py ADDED
@@ -0,0 +1,145 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+
5
+ def nms(boxes, scores, nms_thr):
6
+ """Single class NMS implemented in Numpy.
7
+
8
+ Args:
9
+ boxes (np.ndarray): shape=(N,4); N is number of boxes
10
+ scores (np.ndarray): the score of bboxes
11
+ nms_thr (float): the threshold in NMS
12
+
13
+ Returns:
14
+ List[int]: output bbox ids
15
+ """
16
+ x1 = boxes[:, 0]
17
+ y1 = boxes[:, 1]
18
+ x2 = boxes[:, 2]
19
+ y2 = boxes[:, 3]
20
+
21
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
22
+ order = scores.argsort()[::-1]
23
+
24
+ keep = []
25
+ while order.size > 0:
26
+ i = order[0]
27
+ keep.append(i)
28
+ xx1 = np.maximum(x1[i], x1[order[1:]])
29
+ yy1 = np.maximum(y1[i], y1[order[1:]])
30
+ xx2 = np.minimum(x2[i], x2[order[1:]])
31
+ yy2 = np.minimum(y2[i], y2[order[1:]])
32
+
33
+ w = np.maximum(0.0, xx2 - xx1 + 1)
34
+ h = np.maximum(0.0, yy2 - yy1 + 1)
35
+ inter = w * h
36
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
37
+
38
+ inds = np.where(ovr <= nms_thr)[0]
39
+ order = order[inds + 1]
40
+
41
+ return keep
42
+
43
+ def multiclass_nms(boxes, scores, nms_thr, score_thr):
44
+ """Multiclass NMS implemented in Numpy. Class-aware version.
45
+
46
+ Args:
47
+ boxes (np.ndarray): shape=(N,4); N is number of boxes
48
+ scores (np.ndarray): the score of bboxes
49
+ nms_thr (float): the threshold in NMS
50
+ score_thr (float): the threshold of cls score
51
+
52
+ Returns:
53
+ np.ndarray: outputs bboxes coordinate
54
+ """
55
+ final_dets = []
56
+ num_classes = scores.shape[1]
57
+ for cls_ind in range(num_classes):
58
+ cls_scores = scores[:, cls_ind]
59
+ valid_score_mask = cls_scores > score_thr
60
+ if valid_score_mask.sum() == 0:
61
+ continue
62
+ else:
63
+ valid_scores = cls_scores[valid_score_mask]
64
+ valid_boxes = boxes[valid_score_mask]
65
+ keep = nms(valid_boxes, valid_scores, nms_thr)
66
+ if len(keep) > 0:
67
+ cls_inds = np.ones((len(keep), 1)) * cls_ind
68
+ dets = np.concatenate(
69
+ [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
70
+ )
71
+ final_dets.append(dets)
72
+ if len(final_dets) == 0:
73
+ return None
74
+ return np.concatenate(final_dets, 0)
75
+
76
+ def demo_postprocess(outputs, img_size, p6=False):
77
+ grids = []
78
+ expanded_strides = []
79
+ strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
80
+
81
+ hsizes = [img_size[0] // stride for stride in strides]
82
+ wsizes = [img_size[1] // stride for stride in strides]
83
+
84
+ for hsize, wsize, stride in zip(hsizes, wsizes, strides):
85
+ xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
86
+ grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
87
+ grids.append(grid)
88
+ shape = grid.shape[:2]
89
+ expanded_strides.append(np.full((*shape, 1), stride))
90
+
91
+ grids = np.concatenate(grids, 1)
92
+ expanded_strides = np.concatenate(expanded_strides, 1)
93
+ outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
94
+ outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
95
+
96
+ return outputs
97
+
98
+ def preprocess(img, input_size, swap=(2, 0, 1)):
99
+ if len(img.shape) == 3:
100
+ padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
101
+ else:
102
+ padded_img = np.ones(input_size, dtype=np.uint8) * 114
103
+
104
+ r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
105
+ resized_img = cv2.resize(
106
+ img,
107
+ (int(img.shape[1] * r), int(img.shape[0] * r)),
108
+ interpolation=cv2.INTER_LINEAR,
109
+ ).astype(np.uint8)
110
+ padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
111
+
112
+ padded_img = padded_img.transpose(swap)
113
+ padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
114
+ return padded_img, r
115
+
116
+ def inference_detector(session, oriImg):
117
+ """Run anchor (person) detection with the YOLOX ONNX model.
118
+ """
119
+ input_shape = (640,640)
120
+ img, ratio = preprocess(oriImg, input_shape)
121
+
122
+ ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
123
+ output = session.run(None, ort_inputs)
124
+ predictions = demo_postprocess(output[0], input_shape)[0]
125
+
126
+ boxes = predictions[:, :4]
127
+ scores = predictions[:, 4:5] * predictions[:, 5:]
128
+
129
+ boxes_xyxy = np.ones_like(boxes)
130
+ boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
131
+ boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
132
+ boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
133
+ boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
134
+ boxes_xyxy /= ratio
135
+ dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
136
+ if dets is not None:
137
+ final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
138
+ isscore = final_scores>0.3
139
+ iscat = final_cls_inds == 0
140
+ isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
141
+ final_boxes = final_boxes[isbbox]
142
+ else:
143
+ final_boxes = np.array([])
144
+
145
+ return final_boxes
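
A hedged sketch of driving inference_detector directly with an ONNX Runtime session; the checkpoint download mirrors the hf_hub_download call in dwpose_detector.py, and the image path is illustrative:

    import cv2
    import onnxruntime as ort
    from huggingface_hub import hf_hub_download
    from anchorcrafter.dwpose.onnxdet import inference_detector

    det_path = hf_hub_download(repo_id="yzd-v/DWPose", filename="yolox_l.onnx")
    session = ort.InferenceSession(det_path, providers=["CPUExecutionProvider"])
    img = cv2.cvtColor(cv2.imread("data/anchor/1.jpg"), cv2.COLOR_BGR2RGB)
    boxes = inference_detector(session, img)   # (num_persons, 4) xyxy boxes in original image coordinates
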
anchorcrafter/dwpose/onnxpose.py ADDED
@@ -0,0 +1,375 @@
1
+ from typing import List, Tuple
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+
7
+ def preprocess(
8
+ img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
9
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
10
+ """Do preprocessing for RTMPose model inference.
11
+
12
+ Args:
13
+ img (np.ndarray): Input image in shape.
14
+ input_size (tuple): Input image size in shape (w, h).
15
+
16
+ Returns:
17
+ tuple:
18
+ - resized_img (np.ndarray): Preprocessed image.
19
+ - center (np.ndarray): Center of image.
20
+ - scale (np.ndarray): Scale of image.
21
+ """
22
+ # get shape of image
23
+ img_shape = img.shape[:2]
24
+ out_img, out_center, out_scale = [], [], []
25
+ if len(out_bbox) == 0:
26
+ out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
27
+ for i in range(len(out_bbox)):
28
+ x0 = out_bbox[i][0]
29
+ y0 = out_bbox[i][1]
30
+ x1 = out_bbox[i][2]
31
+ y1 = out_bbox[i][3]
32
+ bbox = np.array([x0, y0, x1, y1])
33
+
34
+ # get center and scale
35
+ center, scale = bbox_xyxy2cs(bbox, padding=1.25)
36
+
37
+ # do affine transformation
38
+ resized_img, scale = top_down_affine(input_size, scale, center, img)
39
+
40
+ # normalize image
41
+ mean = np.array([123.675, 116.28, 103.53])
42
+ std = np.array([58.395, 57.12, 57.375])
43
+ resized_img = (resized_img - mean) / std
44
+
45
+ out_img.append(resized_img)
46
+ out_center.append(center)
47
+ out_scale.append(scale)
48
+
49
+ return out_img, out_center, out_scale
50
+
51
+
52
+ def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
53
+ """Inference RTMPose model.
54
+
55
+ Args:
56
+ sess (ort.InferenceSession): ONNXRuntime session.
57
+ img (np.ndarray): Input image in shape.
58
+
59
+ Returns:
60
+ outputs (np.ndarray): Output of RTMPose model.
61
+ """
62
+ all_out = []
63
+ # build input
64
+ for i in range(len(img)):
65
+ input = [img[i].transpose(2, 0, 1)]
66
+
67
+ # build output
68
+ sess_input = {sess.get_inputs()[0].name: input}
69
+ sess_output = []
70
+ for out in sess.get_outputs():
71
+ sess_output.append(out.name)
72
+
73
+ # run model
74
+ outputs = sess.run(sess_output, sess_input)
75
+ all_out.append(outputs)
76
+
77
+ return all_out
78
+
79
+
80
+ def postprocess(outputs: List[np.ndarray],
81
+ model_input_size: Tuple[int, int],
82
+ center: Tuple[int, int],
83
+ scale: Tuple[int, int],
84
+ simcc_split_ratio: float = 2.0
85
+ ) -> Tuple[np.ndarray, np.ndarray]:
86
+ """Postprocess for RTMPose model output.
87
+
88
+ Args:
89
+ outputs (np.ndarray): Output of RTMPose model.
90
+ model_input_size (tuple): RTMPose model Input image size.
91
+ center (tuple): Center of bbox in shape (x, y).
92
+ scale (tuple): Scale of bbox in shape (w, h).
93
+ simcc_split_ratio (float): Split ratio of simcc.
94
+
95
+ Returns:
96
+ tuple:
97
+ - keypoints (np.ndarray): Rescaled keypoints.
98
+ - scores (np.ndarray): Model predict scores.
99
+ """
100
+ all_key = []
101
+ all_score = []
102
+ for i in range(len(outputs)):
103
+ # use simcc to decode
104
+ simcc_x, simcc_y = outputs[i]
105
+ keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
106
+
107
+ # rescale keypoints
108
+ keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
109
+ all_key.append(keypoints[0])
110
+ all_score.append(scores[0])
111
+
112
+ return np.array(all_key), np.array(all_score)
113
+
114
+
115
+ def bbox_xyxy2cs(bbox: np.ndarray,
116
+ padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
117
+ """Transform the bbox format from (x1, y1, x2, y2) into (center, scale).
118
+
119
+ Args:
120
+ bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
121
+ as (left, top, right, bottom)
122
+ padding (float): BBox padding factor that will be multiplied to scale.
123
+ Default: 1.0
124
+
125
+ Returns:
126
+ tuple: A tuple containing center and scale.
127
+ - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
128
+ (n, 2)
129
+ - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
130
+ (n, 2)
131
+ """
132
+ # convert single bbox from (4, ) to (1, 4)
133
+ dim = bbox.ndim
134
+ if dim == 1:
135
+ bbox = bbox[None, :]
136
+
137
+ # get bbox center and scale
138
+ x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
139
+ center = np.hstack([x1 + x2, y1 + y2]) * 0.5
140
+ scale = np.hstack([x2 - x1, y2 - y1]) * padding
141
+
142
+ if dim == 1:
143
+ center = center[0]
144
+ scale = scale[0]
145
+
146
+ return center, scale
147
+
148
+
149
+ def _fix_aspect_ratio(bbox_scale: np.ndarray,
150
+ aspect_ratio: float) -> np.ndarray:
151
+ """Extend the scale to match the given aspect ratio.
152
+
153
+ Args:
154
+ scale (np.ndarray): The image scale (w, h) in shape (2, )
155
+ aspect_ratio (float): The ratio of ``w/h``
156
+
157
+ Returns:
158
+ np.ndarray: The reshaped image scale in (2, )
159
+ """
160
+ w, h = np.hsplit(bbox_scale, [1])
161
+ bbox_scale = np.where(w > h * aspect_ratio,
162
+ np.hstack([w, w / aspect_ratio]),
163
+ np.hstack([h * aspect_ratio, h]))
164
+ return bbox_scale
165
+
166
+
167
+ def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
168
+ """Rotate a point by an angle.
169
+
170
+ Args:
171
+ pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
172
+ angle_rad (float): rotation angle in radian
173
+
174
+ Returns:
175
+ np.ndarray: Rotated point in shape (2, )
176
+ """
177
+ sn, cs = np.sin(angle_rad), np.cos(angle_rad)
178
+ rot_mat = np.array([[cs, -sn], [sn, cs]])
179
+ return rot_mat @ pt
180
+
181
+
182
+ def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
183
+ """To calculate the affine matrix, three pairs of points are required. This
184
+ function is used to get the 3rd point, given 2D points a & b.
185
+
186
+ The 3rd point is defined by rotating vector `a - b` by 90 degrees
187
+ anticlockwise, using b as the rotation center.
188
+
189
+ Args:
190
+ a (np.ndarray): The 1st point (x,y) in shape (2, )
191
+ b (np.ndarray): The 2nd point (x,y) in shape (2, )
192
+
193
+ Returns:
194
+ np.ndarray: The 3rd point.
195
+ """
196
+ direction = a - b
197
+ c = b + np.r_[-direction[1], direction[0]]
198
+ return c
199
+
200
+
201
+ def get_warp_matrix(center: np.ndarray,
202
+ scale: np.ndarray,
203
+ rot: float,
204
+ output_size: Tuple[int, int],
205
+ shift: Tuple[float, float] = (0., 0.),
206
+ inv: bool = False) -> np.ndarray:
207
+ """Calculate the affine transformation matrix that can warp the bbox area
208
+ in the input image to the output size.
209
+
210
+ Args:
211
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
212
+ scale (np.ndarray[2, ]): Scale of the bounding box
213
+ wrt [width, height].
214
+ rot (float): Rotation angle (degree).
215
+ output_size (np.ndarray[2, ] | list(2,)): Size of the
216
+ destination heatmaps.
217
+ shift (0-100%): Shift translation ratio wrt the width/height.
218
+ Default (0., 0.).
219
+ inv (bool): Option to inverse the affine transform direction.
220
+ (inv=False: src->dst or inv=True: dst->src)
221
+
222
+ Returns:
223
+ np.ndarray: A 2x3 transformation matrix
224
+ """
225
+ shift = np.array(shift)
226
+ src_w = scale[0]
227
+ dst_w = output_size[0]
228
+ dst_h = output_size[1]
229
+
230
+ # compute transformation matrix
231
+ rot_rad = np.deg2rad(rot)
232
+ src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
233
+ dst_dir = np.array([0., dst_w * -0.5])
234
+
235
+ # get four corners of the src rectangle in the original image
236
+ src = np.zeros((3, 2), dtype=np.float32)
237
+ src[0, :] = center + scale * shift
238
+ src[1, :] = center + src_dir + scale * shift
239
+ src[2, :] = _get_3rd_point(src[0, :], src[1, :])
240
+
241
+ # get four corners of the dst rectangle in the input image
242
+ dst = np.zeros((3, 2), dtype=np.float32)
243
+ dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
244
+ dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
245
+ dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
246
+
247
+ if inv:
248
+ warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
249
+ else:
250
+ warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
251
+
252
+ return warp_mat
253
+
254
+
255
+ def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict,
256
+ img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
257
+ """Get the bbox image as the model input by affine transform.
258
+
259
+ Args:
260
+ input_size (dict): The input size of the model.
261
+ bbox_scale (dict): The bbox scale of the img.
262
+ bbox_center (dict): The bbox center of the img.
263
+ img (np.ndarray): The original image.
264
+
265
+ Returns:
266
+ tuple: A tuple containing center and scale.
267
+ - np.ndarray[float32]: img after affine transform.
268
+ - np.ndarray[float32]: bbox scale after affine transform.
269
+ """
270
+ w, h = input_size
271
+ warp_size = (int(w), int(h))
272
+
273
+ # reshape bbox to fixed aspect ratio
274
+ bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
275
+
276
+ # get the affine matrix
277
+ center = bbox_center
278
+ scale = bbox_scale
279
+ rot = 0
280
+ warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
281
+
282
+ # do affine transform
283
+ img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
284
+
285
+ return img, bbox_scale
286
+
287
+
288
+ def get_simcc_maximum(simcc_x: np.ndarray,
289
+ simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
290
+ """Get maximum response location and value from simcc representations.
291
+
292
+ Note:
293
+ instance number: N
294
+ num_keypoints: K
295
+ heatmap height: H
296
+ heatmap width: W
297
+
298
+ Args:
299
+ simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
300
+ simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
301
+
302
+ Returns:
303
+ tuple:
304
+ - locs (np.ndarray): locations of maximum heatmap responses in shape
305
+ (K, 2) or (N, K, 2)
306
+ - vals (np.ndarray): values of maximum heatmap responses in shape
307
+ (K,) or (N, K)
308
+ """
309
+ N, K, Wx = simcc_x.shape
310
+ simcc_x = simcc_x.reshape(N * K, -1)
311
+ simcc_y = simcc_y.reshape(N * K, -1)
312
+
313
+ # get maximum value locations
314
+ x_locs = np.argmax(simcc_x, axis=1)
315
+ y_locs = np.argmax(simcc_y, axis=1)
316
+ locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
317
+ max_val_x = np.amax(simcc_x, axis=1)
318
+ max_val_y = np.amax(simcc_y, axis=1)
319
+
320
+ # get maximum value across x and y axis
321
+ mask = max_val_x > max_val_y
322
+ max_val_x[mask] = max_val_y[mask]
323
+ vals = max_val_x
324
+ locs[vals <= 0.] = -1
325
+
326
+ # reshape
327
+ locs = locs.reshape(N, K, 2)
328
+ vals = vals.reshape(N, K)
329
+
330
+ return locs, vals
331
+
332
+
333
+ def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
334
+ simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
335
+ """Decode SimCC distributions into keypoint locations and scores.
336
+
337
+ Args:
338
+ simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
339
+ simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
340
+ simcc_split_ratio (int): The split ratio of simcc.
341
+
342
+ Returns:
343
+ tuple: A tuple containing center and scale.
344
+ - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
345
+ - np.ndarray[float32]: scores in shape (K,) or (n, K)
346
+ """
347
+ keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
348
+ keypoints /= simcc_split_ratio
349
+
350
+ return keypoints, scores
351
+
352
+
353
+ def inference_pose(session, out_bbox, oriImg):
354
+ """Run whole-body pose estimation (RTMPose) on the detected boxes.
355
+
356
+ Args:
357
+ session (ort.InferenceSession): ONNXRuntime session.
358
+ out_bbox (np.ndarray): bbox list
359
+ oriImg (np.ndarray): Input image in shape.
360
+
361
+ Returns:
362
+ tuple:
363
+ - keypoints (np.ndarray): Rescaled keypoints.
364
+ - scores (np.ndarray): Model predict scores.
365
+ """
366
+ h, w = session.get_inputs()[0].shape[2:]
367
+ model_input_size = (w, h)
368
+ # preprocess for rtm-pose model inference.
369
+ resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
370
+ # run pose estimation for processed img
371
+ outputs = inference(session, resized_img)
372
+ # postprocess for rtm-pose model output.
373
+ keypoints, scores = postprocess(outputs, model_input_size, center, scale)
374
+
375
+ return keypoints, scores
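
To make the crop preprocessing concrete, a small worked example of bbox_xyxy2cs and _fix_aspect_ratio (the box values are arbitrary; 192x256 is the RTMPose input size used above):

    import numpy as np
    from anchorcrafter.dwpose.onnxpose import bbox_xyxy2cs, _fix_aspect_ratio

    center, scale = bbox_xyxy2cs(np.array([100., 200., 300., 600.]), padding=1.25)
    # center = [200., 400.], scale = [250., 500.]  (box width/height scaled by the 1.25 padding factor)
    scale = _fix_aspect_ratio(scale, aspect_ratio=192 / 256)
    # scale = [375., 500.]: the width is widened so the crop matches the 192x256 model aspect ratio
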
anchorcrafter/dwpose/preprocess.py ADDED
@@ -0,0 +1,85 @@
1
+ from tqdm import tqdm
2
+ import decord
3
+ import numpy as np
4
+
5
+ from .util import draw_pose
6
+ from .dwpose_detector import dwpose_detector as dwprocessor
7
+ import pickle
8
+ import os
9
+
10
+ def get_video_pose(
11
+ video_path: str,
12
+ ref_image: np.ndarray,
13
+ sample_stride: int=1,
14
+ total_frames: int=28,
15
+ ):
16
+ """preprocess ref image pose and video pose
17
+
18
+ Args:
19
+ video_path (str): video pose path
20
+ ref_image (np.ndarray): reference image
21
+ sample_stride (int, optional): Defaults to 1.
22
+ total_frames(int): Defaults to 28.
23
+ Returns:
24
+ np.ndarray: sequence of video pose
25
+ """
26
+ # select ref-keypoint from reference pose for pose rescale
27
+ ref_pose = dwprocessor(ref_image)
28
+ ref_keypoint_id = [0, 1, 2, 5, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
29
+ ref_keypoint_id = [i for i in ref_keypoint_id \
30
+ if len(ref_pose['bodies']['subset']) > 0 and ref_pose['bodies']['subset'][0][i] >= .0]
31
+ ref_body = ref_pose['bodies']['candidate'][ref_keypoint_id]
32
+
33
+ height, width, _ = ref_image.shape
34
+ print(f'h,w: {height}, {width}')
35
+
36
+ # read input video
37
+ vr = decord.VideoReader(video_path, ctx=decord.cpu(0))
38
+ pkl_path = "data/pose_pkl/" + video_path.split("/")[-1].split(".")[0] + ".pkl"
39
+ print("total frames:", total_frames)
40
+ if os.path.exists(pkl_path): # read pose from file
41
+ with open(pkl_path, "rb") as f:
42
+ poses_frames = pickle.load(f)
43
+ detected_poses = [poses_frames[frm] for frm in range(0, len(poses_frames), sample_stride)]
44
+ detected_poses = detected_poses[:total_frames]
45
+ else: # calculate pose
46
+ frames = vr.get_batch(list(range(0, len(vr), sample_stride))).asnumpy()
47
+ frames = frames[:total_frames]
48
+ detected_poses = [dwprocessor(frm) for frm in tqdm(frames, desc="DWPose")]
49
+
50
+ detected_bodies = np.stack(
51
+ [p['bodies']['candidate'] for p in detected_poses if p['bodies']['candidate'].shape[0] == 18])[:,
52
+ ref_keypoint_id]
53
+ # compute linear-rescale params
54
+ ay, by = np.polyfit(detected_bodies[:, :, 1].flatten(), np.tile(ref_body[:, 1], len(detected_bodies)), 1)
55
+ fh, fw, _ = vr[0].shape
56
+ ax = ay / (fh / fw / height * width)
57
+ bx = np.mean(np.tile(ref_body[:, 0], len(detected_bodies)) - detected_bodies[:, :, 0].flatten() * ax)
58
+ a = np.array([ax, ay])
59
+ b = np.array([bx, by])
60
+ output_pose = []
61
+ # pose rescale
62
+ for detected_pose in detected_poses:
63
+ detected_pose['bodies']['candidate'] = detected_pose['bodies']['candidate'] * a + b
64
+ detected_pose['faces'] = detected_pose['faces'] * a + b
65
+ detected_pose['hands'] = detected_pose['hands'] * a + b
66
+ im = draw_pose(detected_pose, height, width)
67
+ output_pose.append(np.array(im))
68
+
69
+ return np.stack(output_pose), a, b
70
+
71
+
72
+
73
+ def get_image_pose(ref_image):
74
+ """process image pose
75
+
76
+ Args:
77
+ ref_image (np.ndarray): reference image pixel value
78
+
79
+ Returns:
80
+ np.ndarray: pose visual image in RGB-mode
81
+ """
82
+ height, width, _ = ref_image.shape
83
+ ref_pose = dwprocessor(ref_image)
84
+ pose_img = draw_pose(ref_pose, height, width)
85
+ return np.array(pose_img)
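
A usage sketch for the two entry points above, using assets shipped in this commit (data/anchor/1.jpg and data/video/cup_1.mp4); the pose cache under data/pose_pkl/ is optional and is simply skipped when the pickle does not exist:

    import cv2
    from anchorcrafter.dwpose.preprocess import get_image_pose, get_video_pose

    ref = cv2.cvtColor(cv2.imread("data/anchor/1.jpg"), cv2.COLOR_BGR2RGB)
    ref_pose = get_image_pose(ref)   # (3, H, W) pose visualization of the reference image
    poses, a, b = get_video_pose("data/video/cup_1.mp4", ref, sample_stride=1, total_frames=28)
    # poses: (num_frames, 3, H, W) driving-pose frames rescaled to the reference body;
    # (a, b) are the per-axis linear rescale parameters estimated by the np.polyfit step above
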
anchorcrafter/dwpose/util.py ADDED
@@ -0,0 +1,133 @@
1
+ import math
2
+ import numpy as np
3
+ import matplotlib
4
+ import cv2
5
+
6
+
7
+ eps = 0.01
8
+
9
+ def alpha_blend_color(color, alpha):
10
+ """blend color according to point conf
11
+ """
12
+ return [int(c * alpha) for c in color]
13
+
14
+ def draw_bodypose(canvas, candidate, subset, score):
15
+ H, W, C = canvas.shape
16
+ candidate = np.array(candidate)
17
+ subset = np.array(subset)
18
+
19
+ stickwidth = 4
20
+
21
+ limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
22
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
23
+ [1, 16], [16, 18], [3, 17], [6, 18]]
24
+
25
+ colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
26
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
27
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
28
+
29
+ for i in range(17):
30
+ for n in range(len(subset)):
31
+ index = subset[n][np.array(limbSeq[i]) - 1]
32
+ conf = score[n][np.array(limbSeq[i]) - 1]
33
+ if conf[0] < 0.3 or conf[1] < 0.3:
34
+ continue
35
+ Y = candidate[index.astype(int), 0] * float(W)
36
+ X = candidate[index.astype(int), 1] * float(H)
37
+ mX = np.mean(X)
38
+ mY = np.mean(Y)
39
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
40
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
41
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
42
+ cv2.fillConvexPoly(canvas, polygon, alpha_blend_color(colors[i], conf[0] * conf[1]))
43
+
44
+ canvas = (canvas * 0.6).astype(np.uint8)
45
+
46
+ for i in range(18):
47
+ for n in range(len(subset)):
48
+ index = int(subset[n][i])
49
+ if index == -1:
50
+ continue
51
+ x, y = candidate[index][0:2]
52
+ conf = score[n][i]
53
+ x = int(x * W)
54
+ y = int(y * H)
55
+ cv2.circle(canvas, (int(x), int(y)), 4, alpha_blend_color(colors[i], conf), thickness=-1)
56
+
57
+ return canvas
58
+
59
+ def draw_handpose(canvas, all_hand_peaks, all_hand_scores):
60
+ H, W, C = canvas.shape
61
+
62
+ edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
63
+ [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
64
+
65
+ for peaks, scores in zip(all_hand_peaks, all_hand_scores):
66
+
67
+ for ie, e in enumerate(edges):
68
+ x1, y1 = peaks[e[0]]
69
+ x2, y2 = peaks[e[1]]
70
+ x1 = int(x1 * W)
71
+ y1 = int(y1 * H)
72
+ x2 = int(x2 * W)
73
+ y2 = int(y2 * H)
74
+ score = int(scores[e[0]] * scores[e[1]] * 255)
75
+ if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
76
+ cv2.line(canvas, (x1, y1), (x2, y2),
77
+ matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * score, thickness=2)
78
+
79
+ for i, keypoint in enumerate(peaks):
81
+ x, y = keypoint
81
+ x = int(x * W)
82
+ y = int(y * H)
83
+ score = int(scores[i] * 255)
84
+ if x > eps and y > eps:
85
+ cv2.circle(canvas, (x, y), 4, (0, 0, score), thickness=-1)
86
+ return canvas
87
+
88
+ def draw_facepose(canvas, all_lmks, all_scores):
89
+ H, W, C = canvas.shape
90
+ for lmks, scores in zip(all_lmks, all_scores):
91
+ for lmk, score in zip(lmks, scores):
92
+ x, y = lmk
93
+ x = int(x * W)
94
+ y = int(y * H)
95
+ conf = int(score * 255)
96
+ if x > eps and y > eps:
97
+ cv2.circle(canvas, (x, y), 3, (conf, conf, conf), thickness=-1)
98
+ return canvas
99
+
100
+ def draw_pose(pose, H, W, ref_w=2160):
101
+ """vis dwpose outputs
102
+
103
+ Args:
104
+ pose (List): DWposeDetector outputs in dwpose_detector.py
105
+ H (int): height
106
+ W (int): width
107
+ ref_w (int, optional) Defaults to 2160.
108
+
109
+ Returns:
110
+ np.ndarray: image pixel value in RGB mode
111
+ """
112
+ bodies = pose['bodies']
113
+ faces = pose['faces']
114
+ hands = pose['hands']
115
+ candidate = bodies['candidate']
116
+ subset = bodies['subset']
117
+
118
+ sz = min(H, W)
119
+ sr = (ref_w / sz) if sz != ref_w else 1
120
+
121
+ ########################################## create zero canvas ##################################################
122
+ canvas = np.zeros(shape=(int(H*sr), int(W*sr), 3), dtype=np.uint8)
123
+
124
+ ########################################### draw body pose #####################################################
125
+ canvas = draw_bodypose(canvas, candidate, subset, score=bodies['score'])
126
+
127
+ ########################################### draw hand pose #####################################################
128
+ canvas = draw_handpose(canvas, hands, pose['hands_score'])
129
+
130
+ ########################################### draw face pose #####################################################
131
+ canvas = draw_facepose(canvas, faces, pose['faces_score'])
132
+
133
+ return cv2.cvtColor(cv2.resize(canvas, (W, H)), cv2.COLOR_BGR2RGB).transpose(2, 0, 1)
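
A small numeric illustration of the ref_w upscaling in draw_pose (the 720x1280 frame size is arbitrary): the skeleton is rasterized on an enlarged canvas and only resized back down at the end, which keeps the fixed 4-pixel stick width visually thin.

    H, W, ref_w = 720, 1280, 2160
    sz = min(H, W)                                  # 720
    sr = (ref_w / sz) if sz != ref_w else 1         # 3.0
    canvas_shape = (int(H * sr), int(W * sr), 3)    # (2160, 3840, 3), resized back to (W, H) after drawing
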
anchorcrafter/dwpose/wholebody.py ADDED
@@ -0,0 +1,60 @@
1
+ import numpy as np
2
+ import onnxruntime as ort
3
+
4
+ from .onnxdet import inference_detector
5
+ from .onnxpose import inference_pose
6
+
7
+ import os
8
+
9
+ class Wholebody:
10
+ """Detect anchor pose with DWPose (YOLOX person detector + RTMPose whole-body estimator).
11
+ """
12
+ def __init__(self, model_det, model_pose, device="cpu"):
13
+ #print('wholebody init')
14
+ providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider']
15
+ provider_options = None if device == 'cpu' else [{'device_id': 3}]
16
+ #print('session create')
17
+ self.session_det = ort.InferenceSession(
18
+ path_or_bytes=model_det, providers=providers, provider_options=provider_options
19
+ )
20
+ #print('session_pose create')
21
+ self.session_pose = ort.InferenceSession(
22
+ path_or_bytes=model_pose, providers=providers, provider_options=provider_options
23
+ )
24
+
25
+ def __call__(self, oriImg):
26
+ """call to process dwpose-detect
27
+
28
+ Args:
29
+ oriImg (np.ndarray): detected image
30
+
31
+ """
32
+ det_result = inference_detector(self.session_det, oriImg)
33
+ keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
34
+
35
+ keypoints_info = np.concatenate(
36
+ (keypoints, scores[..., None]), axis=-1)
37
+ # compute neck joint
38
+ neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
39
+ # neck score when visualizing pred
40
+ neck[:, 2:4] = np.logical_and(
41
+ keypoints_info[:, 5, 2:4] > 0.3,
42
+ keypoints_info[:, 6, 2:4] > 0.3).astype(int)
43
+ new_keypoints_info = np.insert(
44
+ keypoints_info, 17, neck, axis=1)
45
+ mmpose_idx = [
46
+ 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
47
+ ]
48
+ openpose_idx = [
49
+ 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
50
+ ]
51
+ new_keypoints_info[:, openpose_idx] = \
52
+ new_keypoints_info[:, mmpose_idx]
53
+ keypoints_info = new_keypoints_info
54
+
55
+ keypoints, scores = keypoints_info[
56
+ ..., :2], keypoints_info[..., 2]
57
+
58
+ return keypoints, scores
59
+
60
+
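
For orientation, the 134-column keypoint layout that dwpose_detector.py slices downstream. The body, face, and hand ranges come from the slicing in that file; the feet range is assumed from the standard COCO-WholeBody ordering.

    # keypoints, scores = Wholebody(...)(img)  ->  keypoints: (num_people, 134, 2), scores: (num_people, 134)
    KEYPOINT_SLICES = {
        "body": slice(0, 18),          # OpenPose-ordered joints, including the inserted neck
        "feet": slice(18, 24),         # assumed from the COCO-WholeBody layout
        "face": slice(24, 92),         # 68 face landmarks
        "left_hand": slice(92, 113),   # 21 points
        "right_hand": slice(113, 134), # 21 points
    }
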
anchorcrafter/modules/__init__.py ADDED
File without changes
anchorcrafter/modules/attention_processor.py ADDED
@@ -0,0 +1,466 @@
1
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import matplotlib.pyplot as plt
6
+
7
+ try:
8
+ import xformers
9
+ import xformers.ops
10
+ xformers_available = True
11
+ except Exception as e:
12
+ xformers_available = False
13
+
14
+ class RegionControler(object):
15
+ def __init__(self) -> None:
16
+ self.prompt_image_conditioning = []
17
+ region_control = RegionControler()
18
+
19
+ class AttnProcessor(nn.Module):
20
+ r"""
21
+ Default processor for performing attention-related computations.
22
+ """
23
+ def __init__(
24
+ self,
25
+ hidden_size=None,
26
+ cross_attention_dim=None,
27
+ ):
28
+ super().__init__()
29
+
30
+ def forward(
31
+ self,
32
+ attn,
33
+ hidden_states,
34
+ encoder_hidden_states=None,
35
+ attention_mask=None,
36
+ temb=None,
37
+ ):
38
+ residual = hidden_states
39
+
40
+ if attn.spatial_norm is not None:
41
+ hidden_states = attn.spatial_norm(hidden_states, temb)
42
+
43
+ input_ndim = hidden_states.ndim
44
+
45
+ if input_ndim == 4:
46
+ batch_size, channel, height, width = hidden_states.shape
47
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
48
+
49
+ batch_size, sequence_length, _ = (
50
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
51
+ )
52
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
53
+
54
+ if attn.group_norm is not None:
55
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
56
+
57
+ query = attn.to_q(hidden_states)
58
+
59
+ if encoder_hidden_states is None:
60
+ encoder_hidden_states = hidden_states
61
+ elif attn.norm_cross:
62
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
63
+
64
+ key = attn.to_k(encoder_hidden_states)
65
+ value = attn.to_v(encoder_hidden_states)
66
+
67
+ query = attn.head_to_batch_dim(query)
68
+ key = attn.head_to_batch_dim(key)
69
+ value = attn.head_to_batch_dim(value)
70
+
71
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
72
+ hidden_states = torch.bmm(attention_probs, value)
73
+ hidden_states = attn.batch_to_head_dim(hidden_states)
74
+
75
+ # linear proj
76
+ hidden_states = attn.to_out[0](hidden_states)
77
+ # dropout
78
+ hidden_states = attn.to_out[1](hidden_states)
79
+
80
+ if input_ndim == 4:
81
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
82
+
83
+ if attn.residual_connection:
84
+ hidden_states = hidden_states + residual
85
+
86
+ hidden_states = hidden_states / attn.rescale_output_factor
87
+
88
+ return hidden_states
89
+
90
+
91
+ class IPAttnProcessor(nn.Module):
92
+ r"""
93
+ Attention processor for IP-Adapter.
94
+ Args:
95
+ hidden_size (`int`):
96
+ The hidden size of the attention layer.
97
+ cross_attention_dim (`int`):
98
+ The number of channels in the `encoder_hidden_states`.
99
+ scale (`float`, defaults to 1.0):
100
+ the weight scale of image prompt.
101
+ num_tokens (`int`, defaults to 4; use 16 for ip_adapter_plus):
102
+ The context length of the image features.
103
+ """
104
+
105
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
106
+ super().__init__()
107
+
108
+ self.hidden_size = hidden_size
109
+ self.cross_attention_dim = cross_attention_dim
110
+ self.scale = scale
111
+ self.num_tokens = num_tokens
112
+
113
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
114
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
115
+
116
+ def forward(
117
+ self,
118
+ attn,
119
+ hidden_states,
120
+ encoder_hidden_states=None,
121
+ attention_mask=None,
122
+ temb=None,
123
+ attn_bias=None,
124
+ ):
125
+ hidden_states=hidden_states.to(torch.float16)
126
+ encoder_hidden_states=encoder_hidden_states.to(torch.float16)
127
+ residual = hidden_states
128
+
129
+ if attn.spatial_norm is not None:
130
+ hidden_states = attn.spatial_norm(hidden_states, temb)
131
+
132
+ input_ndim = hidden_states.ndim
133
+
134
+ if input_ndim == 4:
135
+ batch_size, channel, height, width = hidden_states.shape
136
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
137
+
138
+ batch_size, sequence_length, _ = (
139
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
140
+ )
141
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
142
+
143
+ if attn.group_norm is not None:
144
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
145
+
146
+ query = attn.to_q(hidden_states)
147
+
148
+ if encoder_hidden_states is None:
149
+ encoder_hidden_states = hidden_states
150
+ else:
151
+ # get encoder_hidden_states, ip_hidden_states
152
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
153
+ encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
154
+ if attn.norm_cross:
155
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
156
+
157
+ key = attn.to_k(encoder_hidden_states)
158
+ value = attn.to_v(encoder_hidden_states)
159
+
160
+ query = attn.head_to_batch_dim(query)
161
+ key = attn.head_to_batch_dim(key)
162
+ value = attn.head_to_batch_dim(value)
163
+
164
+ if xformers_available:
165
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
166
+ else:
167
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
168
+ hidden_states = torch.bmm(attention_probs, value)
169
+
170
+ if attn_bias is not None:
171
+ # print(f'ipadapter attn_bias, shape: {attn_bias.shape} sum: {attn_bias.sum()}')
172
+ # the attention coefficient is 1 inside the target region and 0 elsewhere
173
+ mask = attn_bias.repeat(1, 1, hidden_states.shape[2]).to(hidden_states.dtype)
174
+ hidden_states = hidden_states * (1 - mask)
175
+
176
+ hidden_states = attn.batch_to_head_dim(hidden_states)
177
+
178
+ # for ip-adapter
179
+ ip_hidden_states=ip_hidden_states.to(torch.float16)
180
+ ip_key = self.to_k_ip(ip_hidden_states)
181
+ ip_value = self.to_v_ip(ip_hidden_states)
182
+
183
+ ip_key = attn.head_to_batch_dim(ip_key)
184
+ ip_value = attn.head_to_batch_dim(ip_value)
185
+
186
+ if xformers_available:
187
+ ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, attention_mask=None)
188
+ else:
189
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, attention_mask=None)
190
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
191
+
192
+ if attn_bias is not None:
193
+ # print(f'ipadapter attn_bias, shape: {attn_bias.shape} sum: {attn_bias.sum()}')
194
+ # the attention coefficient is 1 inside the target region and 0 elsewhere
195
+ mask = attn_bias.repeat(1, 1, ip_hidden_states.shape[2]).to(ip_hidden_states.dtype)
196
+ ip_hidden_states = ip_hidden_states * mask
197
+
198
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
199
+
200
+ # region control
201
+ if len(region_control.prompt_image_conditioning) == 1:
202
+ region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
203
+ if region_mask is not None:
204
+ h, w = region_mask.shape[:2]
205
+ ratio = (h * w / query.shape[1]) ** 0.5
206
+ mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
207
+ else:
208
+ mask = torch.ones_like(ip_hidden_states)
209
+ ip_hidden_states = ip_hidden_states * mask
210
+
211
+ hidden_states = hidden_states + self.scale * ip_hidden_states
212
+
213
+ # linear proj
214
+ hidden_states = attn.to_out[0](hidden_states)
215
+ # dropout
216
+ hidden_states = attn.to_out[1](hidden_states)
217
+
218
+ if input_ndim == 4:
219
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
220
+
221
+ if attn.residual_connection:
222
+ hidden_states = hidden_states + residual
223
+
224
+ hidden_states = hidden_states / attn.rescale_output_factor
225
+
226
+ return hidden_states
227
+
228
+
229
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
230
+ # TODO attention_mask
231
+ query = query.contiguous()
232
+ key = key.contiguous()
233
+ value = value.contiguous()
234
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
235
+ # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
236
+ return hidden_states
237
+
238
+
239
+ class AttnProcessor2_0(torch.nn.Module):
240
+ r"""
241
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
242
+ """
243
+ def __init__(
244
+ self,
245
+ hidden_size=None,
246
+ cross_attention_dim=None,
247
+ ):
248
+ super().__init__()
249
+ if not hasattr(F, "scaled_dot_product_attention"):
250
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
251
+
252
+ def forward(
253
+ self,
254
+ attn,
255
+ hidden_states,
256
+ encoder_hidden_states=None,
257
+ attention_mask=None,
258
+ temb=None,
259
+ ):
260
+ residual = hidden_states
261
+
262
+ if attn.spatial_norm is not None:
263
+ hidden_states = attn.spatial_norm(hidden_states, temb)
264
+
265
+ input_ndim = hidden_states.ndim
266
+
267
+ if input_ndim == 4:
268
+ batch_size, channel, height, width = hidden_states.shape
269
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
270
+
271
+ batch_size, sequence_length, _ = (
272
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
273
+ )
274
+
275
+ if attention_mask is not None:
276
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
277
+ # scaled_dot_product_attention expects attention_mask shape to be
278
+ # (batch, heads, source_length, target_length)
279
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
280
+
281
+ if attn.group_norm is not None:
282
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
283
+
284
+ query = attn.to_q(hidden_states)
285
+
286
+ if encoder_hidden_states is None:
287
+ encoder_hidden_states = hidden_states
288
+ elif attn.norm_cross:
289
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
290
+
291
+ key = attn.to_k(encoder_hidden_states)
292
+ value = attn.to_v(encoder_hidden_states)
293
+
294
+ inner_dim = key.shape[-1]
295
+ head_dim = inner_dim // attn.heads
296
+
297
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
298
+
299
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
300
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
301
+
302
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
303
+ # TODO: add support for attn.scale when we move to Torch 2.1
304
+ hidden_states = F.scaled_dot_product_attention(
305
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
306
+ )
307
+
308
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
309
+ hidden_states = hidden_states.to(query.dtype)
310
+
311
+ # linear proj
312
+ hidden_states = attn.to_out[0](hidden_states)
313
+ # dropout
314
+ hidden_states = attn.to_out[1](hidden_states)
315
+
316
+ if input_ndim == 4:
317
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
318
+
319
+ if attn.residual_connection:
320
+ hidden_states = hidden_states + residual
321
+
322
+ hidden_states = hidden_states / attn.rescale_output_factor
323
+
324
+ return hidden_states
325
+
326
+ class IPAttnProcessor2_0(torch.nn.Module):
327
+ r"""
328
+ Attention processor for IP-Adapter (PyTorch 2.0).
329
+ Args:
330
+ hidden_size (`int`):
331
+ The hidden size of the attention layer.
332
+ cross_attention_dim (`int`):
333
+ The number of channels in the `encoder_hidden_states`.
334
+ scale (`float`, defaults to 1.0):
335
+ the weight scale of image prompt.
336
+ num_tokens (`int`, defaults to 4; use 16 for ip_adapter_plus):
337
+ The context length of the image features.
338
+ """
339
+
340
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
341
+ super().__init__()
342
+
343
+ if not hasattr(F, "scaled_dot_product_attention"):
344
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
345
+
346
+ self.hidden_size = hidden_size
347
+ self.cross_attention_dim = cross_attention_dim
348
+ self.scale = scale
349
+ self.num_tokens = num_tokens
350
+
351
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
352
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
353
+
354
+ def forward(
355
+ self,
356
+ attn,
357
+ hidden_states,
358
+ encoder_hidden_states=None,
359
+ attention_mask=None,
360
+ temb=None,
361
+ ):
362
+ residual = hidden_states
363
+
364
+ if attn.spatial_norm is not None:
365
+ hidden_states = attn.spatial_norm(hidden_states, temb)
366
+
367
+ input_ndim = hidden_states.ndim
368
+
369
+ if input_ndim == 4:
370
+ batch_size, channel, height, width = hidden_states.shape
371
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
372
+
373
+ batch_size, sequence_length, _ = (
374
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
375
+ )
376
+
377
+ if attention_mask is not None:
378
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
379
+ # scaled_dot_product_attention expects attention_mask shape to be
380
+ # (batch, heads, source_length, target_length)
381
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
382
+
383
+ if attn.group_norm is not None:
384
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
385
+
386
+ query = attn.to_q(hidden_states)
387
+
388
+ if encoder_hidden_states is None:
389
+ encoder_hidden_states = hidden_states
390
+ else:
391
+ # get encoder_hidden_states, ip_hidden_states
392
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
393
+ encoder_hidden_states, ip_hidden_states = (
394
+ encoder_hidden_states[:, :end_pos, :],
395
+ encoder_hidden_states[:, end_pos:, :],
396
+ )
397
+ if attn.norm_cross:
398
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
399
+
400
+ key = attn.to_k(encoder_hidden_states)
401
+ value = attn.to_v(encoder_hidden_states)
402
+
403
+ inner_dim = key.shape[-1]
404
+ head_dim = inner_dim // attn.heads
405
+
406
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
407
+
408
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
409
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
410
+
411
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
412
+ # TODO: add support for attn.scale when we move to Torch 2.1
413
+ hidden_states = F.scaled_dot_product_attention(
414
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
415
+ )
416
+
417
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
418
+ hidden_states = hidden_states.to(query.dtype)
419
+
420
+ # for ip-adapter
421
+ ip_key = self.to_k_ip(ip_hidden_states)
422
+ ip_value = self.to_v_ip(ip_hidden_states)
423
+
424
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
425
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
426
+
427
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
428
+ # TODO: add support for attn.scale when we move to Torch 2.1
429
+ ip_hidden_states = F.scaled_dot_product_attention(
430
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
431
+ )
432
+ with torch.no_grad():
433
+ self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
434
+ #print(self.attn_map.shape)
435
+
436
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
437
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
438
+
439
+ # region control
440
+ if len(region_control.prompt_image_conditioning) == 1:
441
+ region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
442
+ if region_mask is not None:
443
+ query = query.reshape([-1, query.shape[-2], query.shape[-1]])
444
+ h, w = region_mask.shape[:2]
445
+ ratio = (h * w / query.shape[1]) ** 0.5
446
+ mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
447
+ else:
448
+ mask = torch.ones_like(ip_hidden_states)
449
+ ip_hidden_states = ip_hidden_states * mask
450
+
451
+ hidden_states = hidden_states + self.scale * ip_hidden_states
452
+
453
+ # linear proj
454
+ hidden_states = attn.to_out[0](hidden_states)
455
+ # dropout
456
+ hidden_states = attn.to_out[1](hidden_states)
457
+
458
+ if input_ndim == 4:
459
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
460
+
461
+ if attn.residual_connection:
462
+ hidden_states = hidden_states + residual
463
+
464
+ hidden_states = hidden_states / attn.rescale_output_factor
465
+
466
+ return hidden_states
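The region-control branch above rescales a 2-D mask so that its flattened length matches the number of attention tokens before it gates the IP-Adapter output. A minimal sketch of that resizing step, with made-up names and shapes (`tokens` and `region_mask` are illustrative only):

import torch
import torch.nn.functional as F

tokens = 24 * 24                      # flattened spatial resolution seen by this attention layer
region_mask = torch.rand(192, 192)    # hypothetical per-pixel mask, values in [0, 1]

h, w = region_mask.shape
ratio = (h * w / tokens) ** 0.5       # uniform scale so the resized mask has ~tokens elements
mask = F.interpolate(region_mask[None, None], scale_factor=1 / ratio, mode="nearest")
mask = mask.reshape(1, -1, 1)         # [1, tokens, 1], broadcastable over ip_hidden_states
assert mask.shape[1] == tokens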
anchorcrafter/modules/obj_attn_net.py ADDED
@@ -0,0 +1,47 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from diffusers.models.attention import BasicTransformerBlock
7
+
8
+
9
+ class ObjAttnNet(nn.Module):
10
+ """Object-centric attention network with dual transformer blocks
11
+
12
+ Args:
13
+ inner_dim (int): Dimension of internal representations (default: 1024)
14
+ num_heads (int): Number of attention heads (default: 32)
15
+ out_dim (int): Output dimension (default: 1024)
16
+ embedding_size (int): Base embedding size (default: 1370)
17
+ """
18
+ def __init__(self, inner_dim=1024, num_heads=32, out_dim=1024, embedding_size=1370):
19
+ super().__init__()
20
+ self.embedding_size = embedding_size
21
+ # Transformer blocks configuration
22
+ transformer_config = {
23
+ "dim": inner_dim,
24
+ "num_attention_heads": num_heads,
25
+ "attention_head_dim": inner_dim // num_heads
26
+ }
27
+ # Network components
28
+ self.space_transformer_1 = BasicTransformerBlock(**transformer_config)
29
+ self.space_transformer_2 = BasicTransformerBlock(**transformer_config)
30
+ self.proj_out = nn.Linear(inner_dim, out_dim)
31
+ self.norm = nn.LayerNorm(out_dim)
32
+
33
+ def forward(self, embeddings): # [b, n, c]
34
+ # First transformer processing
35
+ x = self.space_transformer_1(embeddings)
36
+
37
+ # Select middle embeddings segment
38
+ x = x[:, self.embedding_size: self.embedding_size * 2, :]
39
+
40
+ # Second transformer processing
41
+ x = self.space_transformer_2(x)
42
+
43
+ # Select final output tokens
44
+ x = x[:, :12, :]
45
+
46
+ # Project and normalize
47
+ return self.norm(self.proj_out(x))
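A minimal usage sketch with dummy shapes (in the pipeline, `embedding_size` is 1370 per-object DINO tokens for a 518x518 input; a small value is used here only to keep the example light). The middle segment selected in `forward` corresponds to the second of three concatenated object-embedding blocks:

import torch

net = ObjAttnNet(inner_dim=1024, num_heads=32, out_dim=1024, embedding_size=16)
dummy = torch.randn(2, 3 * 16, 1024)     # [batch, n_objects * tokens_per_object, channels]
out = net(dummy)
print(out.shape)                         # torch.Size([2, 12, 1024])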
anchorcrafter/modules/obj_proj_net.py ADDED
@@ -0,0 +1,33 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
+ class ObjProjNet(nn.Module):
9
+ """Projection network for CLIP embeddings to cross-attention space
10
+
11
+ Args:
12
+ cross_attention_dim (int): Dimension of cross-attention features (default: 1024)
13
+ clip_embeddings_dim (int): Dimension of input CLIP embeddings (default: 3072)
14
+ context_tokens (int): Number of additional context tokens (default: 4)
15
+ inner_dim (int): Intermediate projection dimension (default: 1024)
16
+ """
17
+
18
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=3072,
19
+ context_tokens=4, inner_dim=1024):
20
+ super().__init__()
21
+
22
+ self.cross_attention_dim = cross_attention_dim
23
+ self.context_tokens = context_tokens
24
+ self.proj_in = nn.Linear(clip_embeddings_dim, inner_dim)
25
+ self.proj_out = nn.Linear(inner_dim, self.context_tokens * cross_attention_dim)
26
+ self.norm = nn.LayerNorm(cross_attention_dim)
27
+
28
+ def forward(self, image_embeds):
29
+ x = self.proj_in(image_embeds)
30
+ x = self.proj_out(x).reshape(
31
+ -1, self.context_tokens, self.cross_attention_dim
32
+ )
33
+ return self.norm(x)
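A minimal sketch, assuming the 3072-dim input is three 1024-dim object class embeddings concatenated (as the pipeline does later):

import torch

net = ObjProjNet(cross_attention_dim=1024, clip_embeddings_dim=3072, context_tokens=4)
cls_embeds = torch.randn(2, 3072)        # [batch, 3 * hidden]
tokens = net(cls_embeds)
print(tokens.shape)                      # torch.Size([2, 4, 1024])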
anchorcrafter/modules/pose_net.py ADDED
@@ -0,0 +1,88 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import einops
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.init as init
9
+
10
+ from diffusers.utils.constants import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
11
+
12
+ from typing import Union, Optional
13
+
14
+
15
+ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
16
+ if variant is not None:
17
+ splits = weights_name.split(".")
18
+ splits = splits[:-1] + [variant] + splits[-1:]
19
+ weights_name = ".".join(splits)
20
+
21
+ return weights_name
22
+
23
+
24
+ class PoseNet(nn.Module):
25
+ """Convolutional network for processing pose sequence conditioning
26
+
27
+ Args:
28
+ latent_channels (int): Number of output latent channels (default: 320)
29
+ input_channels (int): Number of input pose channels (default: 6)
30
+ scale_factor (float): Initial output scaling factor (default: 2.0)
31
+ """
32
+ def __init__(
33
+ self,
34
+ latent_channels: int = 320,
35
+ input_channels: int = 6,
36
+ scale_factor: float = 2.0
37
+ ):
38
+ super().__init__()
39
+ # multiple convolution layers
40
+ self.conv_layers = nn.Sequential(
41
+ nn.Conv2d(input_channels, 6, kernel_size=3, padding=1),
42
+ nn.SiLU(),
43
+ nn.Conv2d(6, 16, kernel_size=4, stride=2, padding=1),
44
+ nn.SiLU(),
45
+
46
+ nn.Conv2d(16, 16, kernel_size=3, padding=1),
47
+ nn.SiLU(),
48
+ nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1),
49
+ nn.SiLU(),
50
+
51
+ nn.Conv2d(32, 32, kernel_size=3, padding=1),
52
+ nn.SiLU(),
53
+ nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
54
+ nn.SiLU(),
55
+
56
+ nn.Conv2d(64, 64, kernel_size=3, padding=1),
57
+ nn.SiLU(),
58
+ nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
59
+ nn.SiLU()
60
+ )
61
+
62
+ # Final projection layer
63
+ self.final_proj = nn.Conv2d(128, latent_channels, kernel_size=1)
64
+
65
+ # Initialize layers
66
+ self._initialize_weights()
67
+
68
+ self.scale = nn.Parameter(torch.tensor(scale_factor, dtype=torch.float16))
69
+
70
+ def _initialize_weights(self):
71
+ """Initialize weights with He. initialization and zero out the biases
72
+ """
73
+ for m in self.conv_layers:
74
+ if isinstance(m, nn.Conv2d):
75
+ n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
76
+ init.normal_(m.weight, mean=0.0, std=np.sqrt(2. / n))
77
+ if m.bias is not None:
78
+ init.zeros_(m.bias)
79
+ init.zeros_(self.final_proj.weight)
80
+ if self.final_proj.bias is not None:
81
+ init.zeros_(self.final_proj.bias)
82
+
83
+ def forward(self, x):
84
+ if x.ndim == 5:
85
+ x = einops.rearrange(x, "b f c h w -> (b f) c h w")
86
+ x = self.conv_layers(x)
87
+
88
+ return self.final_proj(x) * self.scale
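A minimal shape check with hypothetical sizes: the three stride-2 convolutions downsample the 6-channel pose+hand maps by 8x (matching the UNet's latent grid), and a 5-D input is flattened over batch and frames first:

import torch

net = PoseNet(latent_channels=320, input_channels=6)
pose = torch.randn(1, 2, 6, 64, 64)      # [b, f, c, H, W]
latents = net(pose)                      # final_proj output scaled by the learnable `scale`
print(latents.shape)                     # torch.Size([2, 320, 8, 8])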
anchorcrafter/modules/track_net.py ADDED
@@ -0,0 +1,76 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import einops
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.init as init
9
+
10
+ from typing import Optional
11
+
12
+
13
+ class TrackNet(nn.Module):
14
+ """Convolutional network for processing tracking sequence conditioning
15
+
16
+ Args:
17
+ latent_channels (int): Number of output latent channels (default: 320)
18
+ input_channels (int): Number of input tracking channels (default: 3)
19
+ scale_factor (float): Initial output scaling factor (default: 2.0)
20
+ """
21
+ def __init__(
22
+ self,
23
+ latent_channels=320,
24
+ input_channels: int = 3,
25
+ scale_factor: float = 2.0
26
+ ):
27
+ super().__init__()
28
+ # multiple convolution layers
29
+ self.conv_layers = nn.Sequential(
30
+ nn.Conv2d(input_channels, 3, kernel_size=3, padding=1),
31
+ nn.SiLU(),
32
+ nn.Conv2d(3, 16, kernel_size=4, stride=2, padding=1),
33
+ nn.SiLU(),
34
+
35
+ nn.Conv2d(16, 16, kernel_size=3, padding=1),
36
+ nn.SiLU(),
37
+ nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1),
38
+ nn.SiLU(),
39
+
40
+ nn.Conv2d(32, 32, kernel_size=3, padding=1),
41
+ nn.SiLU(),
42
+ nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
43
+ nn.SiLU(),
44
+
45
+ nn.Conv2d(64, 64, kernel_size=3, padding=1),
46
+ nn.SiLU(),
47
+ nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
48
+ nn.SiLU()
49
+ )
50
+
51
+ # Final projection layer
52
+ self.final_proj = nn.Conv2d(in_channels=128, out_channels=latent_channels, kernel_size=1)
53
+
54
+ # Initialize layers
55
+ self._initialize_weights()
56
+
57
+ self.scale = nn.Parameter(torch.tensor(scale_factor, dtype=torch.float16))
58
+
59
+ def _initialize_weights(self):
60
+ """Initialize weights with He. initialization and zero out the biases
61
+ """
62
+ for m in self.conv_layers:
63
+ if isinstance(m, nn.Conv2d):
64
+ n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
65
+ init.normal_(m.weight, mean=0.0, std=np.sqrt(2. / n))
66
+ if m.bias is not None:
67
+ init.zeros_(m.bias)
68
+ init.zeros_(self.final_proj.weight)
69
+ if self.final_proj.bias is not None:
70
+ init.zeros_(self.final_proj.bias)
71
+
72
+ def forward(self, x):
73
+ if x.ndim == 5:
74
+ x = einops.rearrange(x, "b f c h w -> (b f) c h w")
75
+ x = self.conv_layers(x)
76
+ return self.final_proj(x) * self.scale
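TrackNet mirrors PoseNet but consumes 3-channel object-tracking renders; a minimal shape check (illustrative sizes only):

import torch

track_net = TrackNet(latent_channels=320, input_channels=3)
track_latents = track_net(torch.randn(2, 3, 64, 64))   # [(b*f), c, H, W] tracking maps
print(track_latents.shape)                              # torch.Size([2, 320, 8, 8])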
anchorcrafter/modules/unet.py ADDED
@@ -0,0 +1,509 @@
1
+ from dataclasses import dataclass
2
+ from typing import Dict, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.loaders import UNet2DConditionLoadersMixin
8
+ from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
9
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
10
+ from diffusers.models.modeling_utils import ModelMixin
11
+ from diffusers.utils import BaseOutput, logging
12
+
13
+ from diffusers.models.unets.unet_3d_blocks import get_down_block, get_up_block, UNetMidBlockSpatioTemporal
14
+
15
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
16
+
17
+
18
+ @dataclass
19
+ class UNetSpatioTemporalConditionOutput(BaseOutput):
20
+ """
21
+ The output of [`UNetSpatioTemporalConditionModel`].
22
+
23
+ Args:
24
+ sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
25
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
26
+ """
27
+
28
+ sample: torch.FloatTensor = None
29
+
30
+
31
+ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
32
+ r"""
33
+ A conditional Spatio-Temporal UNet model that takes noisy video frames, a conditional state,
34
+ and a timestep, and returns a sample-shaped output.
35
+
36
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
37
+ for all models (such as downloading or saving).
38
+
39
+ Parameters:
40
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
41
+ Height and width of input/output sample.
42
+ in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
43
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
44
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal",
45
+ "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
46
+ The tuple of downsample blocks to use.
47
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal",
48
+ "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
49
+ The tuple of upsample blocks to use.
50
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
51
+ The tuple of output channels for each block.
52
+ addition_time_embed_dim: (`int`, defaults to 256):
53
+ Dimension to encode the additional time ids.
54
+ projection_class_embeddings_input_dim (`int`, defaults to 768):
55
+ The dimension of the projection of encoded `added_time_ids`.
56
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
57
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 2048):
58
+ The dimension of the cross attention features.
59
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
60
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
61
+ [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`],
62
+ [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
63
+ [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
64
+ num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
65
+ The number of attention heads.
66
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
67
+ """
68
+
69
+ _supports_gradient_checkpointing = True
70
+
71
+ @register_to_config
72
+ def __init__(
73
+ self,
74
+ sample_size: Optional[int] = None,
75
+ in_channels: int = 8,
76
+ out_channels: int = 4,
77
+ down_block_types: Tuple[str] = (
78
+ "CrossAttnDownBlockSpatioTemporal",
79
+ "CrossAttnDownBlockSpatioTemporal",
80
+ "CrossAttnDownBlockSpatioTemporal",
81
+ "DownBlockSpatioTemporal",
82
+ ),
83
+ up_block_types: Tuple[str] = (
84
+ "UpBlockSpatioTemporal",
85
+ "CrossAttnUpBlockSpatioTemporal",
86
+ "CrossAttnUpBlockSpatioTemporal",
87
+ "CrossAttnUpBlockSpatioTemporal",
88
+ ),
89
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
90
+ addition_time_embed_dim: int = 256,
91
+ projection_class_embeddings_input_dim: int = 768,
92
+ layers_per_block: Union[int, Tuple[int]] = 2,
93
+ cross_attention_dim: Union[int, Tuple[int]] = 2048,
94
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
95
+ num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
96
+ num_frames: int = 25,
97
+ ):
98
+ super().__init__()
99
+
100
+ self.sample_size = sample_size
101
+
102
+ # Check inputs
103
+ if len(down_block_types) != len(up_block_types):
104
+ raise ValueError(
105
+ f"Must provide the same number of `down_block_types` as `up_block_types`. " \
106
+ f"`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
107
+ )
108
+
109
+ if len(block_out_channels) != len(down_block_types):
110
+ raise ValueError(
111
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. " \
112
+ f"`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
113
+ )
114
+
115
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
116
+ raise ValueError(
117
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. " \
118
+ f"`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
119
+ )
120
+
121
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
122
+ raise ValueError(
123
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. " \
124
+ f"`cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
125
+ )
126
+
127
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
128
+ raise ValueError(
129
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. " \
130
+ f"`layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
131
+ )
132
+
133
+ # input
134
+ self.conv_in = nn.Conv2d(
135
+ in_channels,
136
+ block_out_channels[0],
137
+ kernel_size=3,
138
+ padding=1,
139
+ )
140
+
141
+ # time
142
+ time_embed_dim = block_out_channels[0] * 4
143
+
144
+ self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
145
+ timestep_input_dim = block_out_channels[0]
146
+
147
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
148
+
149
+ self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
150
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
151
+
152
+ self.down_blocks = nn.ModuleList([])
153
+ self.up_blocks = nn.ModuleList([])
154
+
155
+ if isinstance(num_attention_heads, int):
156
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
157
+
158
+ if isinstance(cross_attention_dim, int):
159
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
160
+
161
+ if isinstance(layers_per_block, int):
162
+ layers_per_block = [layers_per_block] * len(down_block_types)
163
+
164
+ if isinstance(transformer_layers_per_block, int):
165
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
166
+
167
+ blocks_time_embed_dim = time_embed_dim
168
+
169
+ # down
170
+ output_channel = block_out_channels[0]
171
+ for i, down_block_type in enumerate(down_block_types):
172
+ input_channel = output_channel
173
+ output_channel = block_out_channels[i]
174
+ is_final_block = i == len(block_out_channels) - 1
175
+
176
+ down_block = get_down_block(
177
+ down_block_type,
178
+ num_layers=layers_per_block[i],
179
+ transformer_layers_per_block=transformer_layers_per_block[i],
180
+ in_channels=input_channel,
181
+ out_channels=output_channel,
182
+ temb_channels=blocks_time_embed_dim,
183
+ add_downsample=not is_final_block,
184
+ resnet_eps=1e-5,
185
+ cross_attention_dim=cross_attention_dim[i],
186
+ num_attention_heads=num_attention_heads[i],
187
+ resnet_act_fn="silu",
188
+ )
189
+ self.down_blocks.append(down_block)
190
+
191
+ # mid
192
+ self.mid_block = UNetMidBlockSpatioTemporal(
193
+ block_out_channels[-1],
194
+ temb_channels=blocks_time_embed_dim,
195
+ transformer_layers_per_block=transformer_layers_per_block[-1],
196
+ cross_attention_dim=cross_attention_dim[-1],
197
+ num_attention_heads=num_attention_heads[-1],
198
+ )
199
+
200
+ # count how many layers upsample the images
201
+ self.num_upsamplers = 0
202
+
203
+ # up
204
+ reversed_block_out_channels = list(reversed(block_out_channels))
205
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
206
+ reversed_layers_per_block = list(reversed(layers_per_block))
207
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
208
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
209
+
210
+ output_channel = reversed_block_out_channels[0]
211
+ for i, up_block_type in enumerate(up_block_types):
212
+ is_final_block = i == len(block_out_channels) - 1
213
+
214
+ prev_output_channel = output_channel
215
+ output_channel = reversed_block_out_channels[i]
216
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
217
+
218
+ # add upsample block for all BUT final layer
219
+ if not is_final_block:
220
+ add_upsample = True
221
+ self.num_upsamplers += 1
222
+ else:
223
+ add_upsample = False
224
+
225
+ up_block = get_up_block(
226
+ up_block_type,
227
+ num_layers=reversed_layers_per_block[i] + 1,
228
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
229
+ in_channels=input_channel,
230
+ out_channels=output_channel,
231
+ prev_output_channel=prev_output_channel,
232
+ temb_channels=blocks_time_embed_dim,
233
+ add_upsample=add_upsample,
234
+ resnet_eps=1e-5,
235
+ resolution_idx=i,
236
+ cross_attention_dim=reversed_cross_attention_dim[i],
237
+ num_attention_heads=reversed_num_attention_heads[i],
238
+ resnet_act_fn="silu",
239
+ )
240
+ self.up_blocks.append(up_block)
241
+ prev_output_channel = output_channel
242
+
243
+ # out
244
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
245
+ self.conv_act = nn.SiLU()
246
+
247
+ self.conv_out = nn.Conv2d(
248
+ block_out_channels[0],
249
+ out_channels,
250
+ kernel_size=3,
251
+ padding=1,
252
+ )
253
+
254
+ @property
255
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
256
+ r"""
257
+ Returns:
258
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
259
+ indexed by its weight name.
260
+ """
261
+ # set recursively
262
+ processors = {}
263
+
264
+ def fn_recursive_add_processors(
265
+ name: str,
266
+ module: torch.nn.Module,
267
+ processors: Dict[str, AttentionProcessor],
268
+ ):
269
+ if hasattr(module, "get_processor"):
270
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
271
+
272
+ for sub_name, child in module.named_children():
273
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
274
+
275
+ return processors
276
+
277
+ for name, module in self.named_children():
278
+ fn_recursive_add_processors(name, module, processors)
279
+
280
+ return processors
281
+
282
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
283
+ r"""
284
+ Sets the attention processor to use to compute attention.
285
+
286
+ Parameters:
287
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
288
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
289
+ for **all** `Attention` layers.
290
+
291
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
292
+ processor. This is strongly recommended when setting trainable attention processors.
293
+
294
+ """
295
+ count = len(self.attn_processors.keys())
296
+
297
+ if isinstance(processor, dict) and len(processor) != count:
298
+ raise ValueError(
299
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
300
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
301
+ )
302
+
303
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
304
+ if hasattr(module, "set_processor"):
305
+ if not isinstance(processor, dict):
306
+ module.set_processor(processor)
307
+ else:
308
+ module.set_processor(processor.pop(f"{name}.processor"))
309
+
310
+ for sub_name, child in module.named_children():
311
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
312
+
313
+ for name, module in self.named_children():
314
+ fn_recursive_attn_processor(name, module, processor)
315
+
316
+ def set_default_attn_processor(self):
317
+ """
318
+ Disables custom attention processors and sets the default attention implementation.
319
+ """
320
+ if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
321
+ processor = AttnProcessor()
322
+ else:
323
+ raise ValueError(
324
+ f"Cannot call `set_default_attn_processor` " \
325
+ f"when attention processors are of type {next(iter(self.attn_processors.values()))}"
326
+ )
327
+
328
+ self.set_attn_processor(processor)
329
+
330
+ def _set_gradient_checkpointing(self, module, value=False):
331
+ if hasattr(module, "gradient_checkpointing"):
332
+ module.gradient_checkpointing = value
333
+
334
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
335
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
336
+ """
337
+ Sets the attention processor to use [feed forward
338
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
339
+
340
+ Parameters:
341
+ chunk_size (`int`, *optional*):
342
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
343
+ over each tensor of dim=`dim`.
344
+ dim (`int`, *optional*, defaults to `0`):
345
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
346
+ or dim=1 (sequence length).
347
+ """
348
+ if dim not in [0, 1]:
349
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
350
+
351
+ # By default chunk size is 1
352
+ chunk_size = chunk_size or 1
353
+
354
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
355
+ if hasattr(module, "set_chunk_feed_forward"):
356
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
357
+
358
+ for child in module.children():
359
+ fn_recursive_feed_forward(child, chunk_size, dim)
360
+
361
+ for module in self.children():
362
+ fn_recursive_feed_forward(module, chunk_size, dim)
363
+
364
+ def forward(
365
+ self,
366
+ sample: torch.FloatTensor,
367
+ timestep: Union[torch.Tensor, float, int],
368
+ encoder_hidden_states: torch.Tensor,
369
+ added_time_ids: torch.Tensor,
370
+ pose_latents: torch.Tensor = None,
371
+ image_only_indicator: bool = False,
372
+ return_dict: bool = True,
373
+ obj_track_latents: torch.Tensor = None,
374
+ ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
375
+ r"""
376
+ The [`UNetSpatioTemporalConditionModel`] forward method.
377
+
378
+ Args:
379
+ sample (`torch.FloatTensor`):
380
+ The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
381
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
382
+ encoder_hidden_states (`torch.FloatTensor`):
383
+ The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
384
+ added_time_ids: (`torch.FloatTensor`):
385
+ The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
386
+ embeddings and added to the time embeddings.
387
+ pose_latents: (`torch.FloatTensor`):
388
+ The additional latents for pose sequences.
389
+ image_only_indicator (`bool`, *optional*, defaults to `False`):
390
+ Whether or not training with all images.
391
+ return_dict (`bool`, *optional*, defaults to `True`):
392
+ Whether or not to return a [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`]
393
+ instead of a plain tuple.
394
+ Returns:
395
+ [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
396
+ If `return_dict` is True,
397
+ an [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] is returned,
398
+ otherwise a `tuple` is returned where the first element is the sample tensor.
399
+ """
400
+ # 1. time
401
+ timesteps = timestep
402
+ if not torch.is_tensor(timesteps):
403
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
404
+ # This would be a good case for the `match` statement (Python 3.10+)
405
+ is_mps = sample.device.type == "mps"
406
+ if isinstance(timestep, float):
407
+ dtype = torch.float32 if is_mps else torch.float64
408
+ else:
409
+ dtype = torch.int32 if is_mps else torch.int64
410
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
411
+ elif len(timesteps.shape) == 0:
412
+ timesteps = timesteps[None].to(sample.device)
413
+
414
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
415
+ batch_size, num_frames = sample.shape[:2]
416
+ timesteps = timesteps.expand(batch_size)
417
+
418
+ t_emb = self.time_proj(timesteps)
419
+
420
+ # `Timesteps` does not contain any weights and will always return f32 tensors
421
+ # but time_embedding might actually be running in fp16. so we need to cast here.
422
+ # there might be better ways to encapsulate this.
423
+ t_emb = t_emb.to(dtype=torch.float16)
424
+
425
+ emb = self.time_embedding(t_emb)
426
+
427
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
428
+ time_embeds = time_embeds.reshape((batch_size, -1))
429
+ time_embeds = time_embeds.to(emb.dtype)
430
+ aug_emb = self.add_embedding(time_embeds)
431
+ emb = emb + aug_emb
432
+
433
+ # Flatten the batch and frames dimensions
434
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
435
+ sample = sample.flatten(0, 1)
436
+ # Repeat the embeddings num_video_frames times
437
+ # emb: [batch, channels] -> [batch * frames, channels]
438
+ emb = emb.repeat_interleave(num_frames, dim=0)
439
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
440
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
441
+
442
+ # 2. pre-process
443
+ sample = self.conv_in(sample)
444
+ if pose_latents is not None:
445
+ sample = sample + pose_latents
446
+ if obj_track_latents is not None:
447
+ sample = sample + obj_track_latents
448
+ image_only_indicator = torch.ones(batch_size, num_frames, dtype=sample.dtype, device=sample.device) \
449
+ if image_only_indicator else torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
450
+
451
+ down_block_res_samples = (sample,)
452
+ for downsample_block in self.down_blocks:
453
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
454
+ sample, res_samples = downsample_block(
455
+ hidden_states=sample,
456
+ temb=emb,
457
+ encoder_hidden_states=encoder_hidden_states,
458
+ image_only_indicator=image_only_indicator,
459
+ )
460
+ else:
461
+ sample, res_samples = downsample_block(
462
+ hidden_states=sample,
463
+ temb=emb,
464
+ image_only_indicator=image_only_indicator,
465
+ )
466
+
467
+ down_block_res_samples += res_samples
468
+
469
+ # 4. mid
470
+ sample = self.mid_block(
471
+ hidden_states=sample,
472
+ temb=emb,
473
+ encoder_hidden_states=encoder_hidden_states,
474
+ image_only_indicator=image_only_indicator,
475
+ )
476
+
477
+ # 5. up
478
+ for i, upsample_block in enumerate(self.up_blocks):
479
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
480
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
481
+
482
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
483
+ sample = upsample_block(
484
+ hidden_states=sample,
485
+ temb=emb,
486
+ res_hidden_states_tuple=res_samples,
487
+ encoder_hidden_states=encoder_hidden_states,
488
+ image_only_indicator=image_only_indicator,
489
+ )
490
+ else:
491
+ sample = upsample_block(
492
+ hidden_states=sample,
493
+ temb=emb,
494
+ res_hidden_states_tuple=res_samples,
495
+ image_only_indicator=image_only_indicator,
496
+ )
497
+
498
+ # 6. post-process
499
+ sample = self.conv_norm_out(sample)
500
+ sample = self.conv_act(sample)
501
+ sample = self.conv_out(sample)
502
+
503
+ # 7. Reshape back to original shape
504
+ sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
505
+
506
+ if not return_dict:
507
+ return (sample,)
508
+
509
+ return UNetSpatioTemporalConditionOutput(sample=sample)
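A minimal sketch of step 2 ("pre-process") in the forward pass above: the pose and object-track latents are added element-wise to the `conv_in` activations, so they must match its output shape `[(b*f), block_out_channels[0], h, w]`. Shapes and the standalone `conv_in` below are illustrative only:

import torch
import torch.nn as nn

conv_in = nn.Conv2d(8, 320, kernel_size=3, padding=1)       # same signature as self.conv_in
sample = torch.randn(1 * 2, 8, 16, 16)                       # [(b*f), in_channels, latent_h, latent_w]
pose_latents = torch.randn(1 * 2, 320, 16, 16)               # hypothetical PoseNet output
obj_track_latents = torch.randn(1 * 2, 320, 16, 16)          # hypothetical TrackNet output

sample = conv_in(sample)
sample = sample + pose_latents + obj_track_latents
print(sample.shape)                                           # torch.Size([2, 320, 16, 16])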
anchorcrafter/pipelines/pipeline.py ADDED
@@ -0,0 +1,739 @@
1
+ import inspect
2
+ import math
3
+ import os.path
4
+ from dataclasses import dataclass
5
+ from typing import Callable, Dict, List, Optional, Union
6
+
7
+ import PIL.Image
8
+ import einops
9
+ import numpy as np
10
+ import torch
11
+ from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
12
+ from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
13
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
14
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
15
+ from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion \
16
+ import _resize_with_antialiasing, _append_dims
17
+ from diffusers.schedulers import EulerDiscreteScheduler
18
+ from diffusers.utils import BaseOutput, logging
19
+ from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
20
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
21
+ from anchorcrafter.modules.track_net import TrackNet
22
+ import torch.nn as nn
23
+ from transformers import AutoImageProcessor, AutoModel
24
+ import torch.nn.functional as F
25
+ from torchvision.transforms.functional import pil_to_tensor
26
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
27
+ import constants
28
+
29
+ from ..modules.obj_proj_net import ObjProjNet
30
+ from ..modules.obj_attn_net import ObjAttnNet
31
+ from ..modules.pose_net import PoseNet
32
+
33
+
34
+ def _append_dims(x, target_dims):
35
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
36
+ dims_to_append = target_dims - x.ndim
37
+ if dims_to_append < 0:
38
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
39
+ return x[(...,) + (None,) * dims_to_append]
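For reference, a quick check of this helper; the pipeline later uses it to broadcast a per-frame guidance scale over the 5-D latent tensor:

import torch

g = torch.linspace(1.0, 3.0, 16)        # [num_frames]
g = _append_dims(g, 5)
print(g.shape)                          # torch.Size([16, 1, 1, 1, 1])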
40
+
41
+
42
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
43
+ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
44
+ batch_size, channels, num_frames, height, width = video.shape
45
+ outputs = []
46
+ for batch_idx in range(batch_size):
47
+ batch_vid = video[batch_idx].permute(1, 0, 2, 3)
48
+ batch_output = processor.postprocess(batch_vid, output_type)
49
+
50
+ outputs.append(batch_output)
51
+
52
+ if output_type == "np":
53
+ outputs = np.stack(outputs)
54
+
55
+ elif output_type == "pt":
56
+ outputs = torch.stack(outputs)
57
+
58
+ elif not output_type == "pil":
59
+ raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
60
+
61
+ return outputs
62
+
63
+
64
+ @dataclass
65
+ class AnchorCrafterPipelineOutput(BaseOutput):
66
+ r"""
67
+ Output class for anchorcrafter pipeline.
68
+
69
+ Args:
70
+ frames (`[List[List[PIL.Image.Image]]`, `np.ndarray`, `torch.Tensor`]):
71
+ List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
72
+ num_frames, height, width, num_channels)`.
73
+ """
74
+
75
+ frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.Tensor]
76
+
77
+
78
+ class AnchorCrafterPipeline(DiffusionPipeline):
79
+ r"""
80
+ Pipeline to generate video from an input image using Stable Video Diffusion.
81
+
82
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
83
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
84
+
85
+ Args:
86
+ vae ([`AutoencoderKLTemporalDecoder`]):
87
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
88
+ image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
89
+ Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K]
90
+ (https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
91
+ unet ([`UNetSpatioTemporalConditionModel`]):
92
+ A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.
93
+ scheduler ([`EulerDiscreteScheduler`]):
94
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
95
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
96
+ A `CLIPImageProcessor` to extract features from generated images.
97
+ dino_feature_extractor(['AutoImageProcessor']):
98
+ An `AutoImageProcessor` to extract features from images.
99
+ pose_net ([`PoseNet`]):
100
+ A net to inject pose signals into unet.
101
+ track_net (['TrackNet']):
102
+ A net to inject object pose signals into unet.
103
+ obj_proj_net (['ObjProjNet']):
104
+ A linear projection network to extract object features.
105
+ obj_attn_net (['ObjAttnNet']):
106
+ A self-attention network to extract object features.
107
+
108
+ """
109
+
110
+ model_cpu_offload_seq = "image_encoder->unet->vae"
111
+ _callback_tensor_inputs = ["latents"]
112
+
113
+ def __init__(
114
+ self,
115
+ vae: AutoencoderKLTemporalDecoder,
116
+ image_encoder: CLIPVisionModelWithProjection,
117
+ obj_image_encoder: AutoModel,
118
+ unet: UNetSpatioTemporalConditionModel,
119
+ scheduler: EulerDiscreteScheduler,
120
+ feature_extractor: CLIPImageProcessor,
121
+ dino_feature_extractor: AutoImageProcessor,
122
+ pose_net: PoseNet,
123
+ track_net: TrackNet,
124
+ obj_proj_net: ObjProjNet,
125
+ obj_attn_net: ObjAttnNet
126
+ ):
127
+ super().__init__()
128
+
129
+ self.register_modules(
130
+ vae=vae,
131
+ image_encoder=image_encoder,
132
+ obj_image_encoder=obj_image_encoder,
133
+ unet=unet,
134
+ scheduler=scheduler,
135
+ feature_extractor=feature_extractor,
136
+ dino_feature_extractor=dino_feature_extractor,
137
+ pose_net=pose_net,
138
+ track_net=track_net,
139
+ obj_proj_net=obj_proj_net,
140
+ obj_attn_net=obj_attn_net
141
+ )
142
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
143
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
144
+
145
+ def _encode_image(
146
+ self,
147
+ image: PipelineImageInput,
148
+ obj_pixels: PipelineImageInput,
149
+ device: Union[str, torch.device],
150
+ num_videos_per_prompt: int,
151
+ do_classifier_free_guidance: bool):
152
+ dtype = next(self.image_encoder.parameters()).dtype
153
+ # print(image)
154
+ # print(obj_pixels)
155
+ if not isinstance(image, torch.Tensor):
156
+ image = self.image_processor.pil_to_numpy(image)
157
+ image = self.image_processor.numpy_to_pt(image)
158
+
159
+ # We normalize the image before resizing to match with the original implementation.
160
+ # Then we unnormalize it after resizing.
161
+ image = image * 2.0 - 1.0
162
+ image = _resize_with_antialiasing(image, (224, 224))
163
+ image = (image + 1.0) / 2.0
164
+
165
+ # Normalize the image for CLIP input
166
+ image = self.feature_extractor(
167
+ images=image,
168
+ do_normalize=True,
169
+ do_center_crop=False,
170
+ do_resize=False,
171
+ do_rescale=False,
172
+ return_tensors="pt",
173
+ ).pixel_values
174
+
175
+ image = image.to(device=device, dtype=dtype)
176
+ image = image.to(dtype=torch.float16)
177
+
178
+ image_embeddings = self.image_encoder(image).image_embeds
179
+
180
+ obj_all_embeddings = None
181
+ for obj in obj_pixels:
182
+ if not isinstance(obj, torch.Tensor):
183
+ obj = self.image_processor.pil_to_numpy(obj)
184
+ obj = self.image_processor.numpy_to_pt(obj)
185
+
186
+ # We normalize the image before resizing to match with the original implementation.
187
+ # Then we unnormalize it after resizing.
188
+ obj = obj * 2.0 - 1.0
189
+ obj = _resize_with_antialiasing(obj, (518, 518))
190
+ obj = (obj + 1.0) / 2.0
191
+
192
+ # Normalize the object image for DINO input
193
+ obj = self.dino_feature_extractor(
194
+ images=obj,
195
+ do_normalize=True,
196
+ do_center_crop=False,
197
+ do_resize=False,
198
+ do_rescale=False,
199
+ return_tensors="pt",
200
+ ).pixel_values
201
+
202
+ obj = obj.to(device=device, dtype=self.obj_image_encoder.dtype)
203
+ print("[dino feature extractor] output obj image:", obj.shape) # torch.Size([1, 3, 518, 518])
204
+
205
+ obj_pixels_embeddings = self.obj_image_encoder(obj).last_hidden_state # e.g. [1, 1370, hidden_size] for a 518x518 input
206
+ #obj_pixels_embeddings = obj_pixels_embeddings[:, 0, :] # 1,768
207
+
208
+ if obj_all_embeddings is None:
209
+ obj_all_embeddings = obj_pixels_embeddings
210
+ else:
211
+ obj_all_embeddings = torch.concat((obj_all_embeddings, obj_pixels_embeddings), dim=1)
212
+ image_embeddings = image_embeddings.unsqueeze(1)
213
+
214
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
215
+ bs_embed, seq_len, _ = image_embeddings.shape
216
+ image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
217
+ image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
218
+ print("obj_all_embeddings", obj_all_embeddings)
219
+ return image_embeddings, obj_all_embeddings
220
+
221
+ def _encode_vae_image(
222
+ self,
223
+ image: torch.Tensor,
224
+ device: Union[str, torch.device],
225
+ num_videos_per_prompt: int,
226
+ do_classifier_free_guidance: bool,
227
+ ):
228
+ image = image.to(device=device, dtype=self.vae.dtype)
229
+
230
+ # image_latents = torch.zeros((image.shape[0], 4, 96, 64)).to(device=device, dtype=self.vae.dtype)
231
+ image_latents = torch.zeros((image.shape[0], 4, 128, 72)).to(device=device, dtype=self.vae.dtype)
232
+ for i in range(0, image.shape[0], 16):
233
+ if i + 16 > image.shape[0]:
234
+ image_latents[i:] = self.vae.encode(image[i:]).latent_dist.mode()
235
+ else:
236
+ image_latents[i:i + 16] = self.vae.encode(image[i:i + 16]).latent_dist.mode()
237
+
238
+ if do_classifier_free_guidance:
239
+ negative_image_latents = torch.zeros_like(image_latents)
240
+
241
+ # For classifier free guidance, we need to do two forward passes.
242
+ # Here we concatenate the unconditional and text embeddings into a single batch
243
+ # to avoid doing two forward passes
244
+ image_latents = torch.cat([negative_image_latents, image_latents])
245
+
246
+ # duplicate image_latents for each generation per prompt, using mps friendly method
247
+ image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
248
+
249
+ return image_latents
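A minimal sketch of the classifier-free-guidance handling above: the unconditional half is simply zeroed image latents concatenated in front of the conditional ones, so a single UNet forward covers both branches. Shapes are illustrative:

import torch

image_latents = torch.randn(1, 4, 72, 128)                            # hypothetical conditional VAE latents
negative_image_latents = torch.zeros_like(image_latents)
image_latents = torch.cat([negative_image_latents, image_latents])    # [2 * b, 4, 72, 128]
print(image_latents.shape)                                             # torch.Size([2, 4, 72, 128])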
250
+
251
+ def _get_add_time_ids(
252
+ self,
253
+ fps: int,
254
+ motion_bucket_id: int,
255
+ noise_aug_strength: float,
256
+ dtype: torch.dtype,
257
+ batch_size: int,
258
+ num_videos_per_prompt: int,
259
+ do_classifier_free_guidance: bool,
260
+ ):
261
+ add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
262
+
263
+ passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
264
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
265
+
266
+ if expected_add_embed_dim != passed_add_embed_dim:
267
+ raise ValueError(
268
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, " \
269
+ f"but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. " \
270
+ f"Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
271
+ )
272
+
273
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
274
+ add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
275
+
276
+ if do_classifier_free_guidance:
277
+ add_time_ids = torch.cat([add_time_ids, add_time_ids])
278
+
279
+ return add_time_ids
280
+
281
+ def decode_latents(
282
+ self,
283
+ latents: torch.Tensor,
284
+ num_frames: int,
285
+ decode_chunk_size: int = 8):
286
+ # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
287
+ latents = latents.flatten(0, 1)
288
+
289
+ latents = 1 / self.vae.config.scaling_factor * latents
290
+
291
+ forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
292
+ accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
293
+
294
+ # decode decode_chunk_size frames at a time to avoid OOM
295
+ frames = []
296
+ for i in range(0, latents.shape[0], decode_chunk_size):
297
+ num_frames_in = latents[i: i + decode_chunk_size].shape[0]
298
+ decode_kwargs = {}
299
+ if accepts_num_frames:
300
+ # we only pass num_frames_in if it's expected
301
+ decode_kwargs["num_frames"] = num_frames_in
302
+
303
+ frame = self.vae.decode(latents[i: i + decode_chunk_size], **decode_kwargs).sample
304
+ frames.append(frame.cpu())
305
+ frames = torch.cat(frames, dim=0)
306
+
307
+ # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
308
+ frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
309
+
310
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
311
+ frames = frames.float()
312
+ return frames
313
+
314
+ def check_inputs(self, image, height, width):
315
+ if (
316
+ not isinstance(image, torch.Tensor)
317
+ and not isinstance(image, PIL.Image.Image)
318
+ and not isinstance(image, list)
319
+ ):
320
+ raise ValueError(
321
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
322
+ f" {type(image)}"
323
+ )
324
+
325
+ if height % 8 != 0 or width % 8 != 0:
326
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
327
+
328
+ def prepare_latents(
329
+ self,
330
+ batch_size: int,
331
+ num_frames: int,
332
+ num_channels_noise_latents: int,
333
+ height: int,
334
+ width: int,
335
+ dtype: torch.dtype,
336
+ device: Union[str, torch.device],
337
+ generator: torch.Generator,
338
+ latents: Optional[torch.Tensor] = None,
339
+ ):
340
+ shape = (
341
+ batch_size,
342
+ num_frames,
343
+ num_channels_noise_latents,
344
+ height // self.vae_scale_factor,
345
+ width // self.vae_scale_factor,
346
+ )
347
+ if isinstance(generator, list) and len(generator) != batch_size:
348
+ raise ValueError(
349
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
350
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
351
+ )
352
+
353
+ if latents is None:
354
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
355
+ else:
356
+ latents = latents.to(device)
357
+
358
+ # scale the initial noise by the standard deviation required by the scheduler
359
+ latents = latents * self.scheduler.init_noise_sigma
360
+ return latents
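Note that noise is sampled for a single tile of frames here; `__call__` later repeats it to cover the whole clip. A minimal sketch of that repetition with illustrative numbers:

import torch

tile_size, num_frames = 16, 49
latents = torch.randn(1, tile_size, 4, 9, 16)                             # [b, tile, c, h, w]
latents = latents.repeat(1, num_frames // tile_size + 1, 1, 1, 1)[:, :num_frames]
print(latents.shape)                                                       # torch.Size([1, 49, 4, 9, 16])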
361
+
362
+ @property
363
+ def guidance_scale(self):
364
+ return self._guidance_scale
365
+
366
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
367
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
368
+ # corresponds to doing no classifier free guidance.
369
+ @property
370
+ def do_classifier_free_guidance(self):
371
+ if isinstance(self.guidance_scale, (int, float)):
372
+ return self.guidance_scale > 1
373
+ return self.guidance_scale.max() > 1
374
+
375
+ @property
376
+ def num_timesteps(self):
377
+ return self._num_timesteps
378
+
379
+ def prepare_extra_step_kwargs(self, generator, eta):
380
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
381
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
382
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
383
+ # and should be between [0, 1]
384
+
385
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
386
+ extra_step_kwargs = {}
387
+ if accepts_eta:
388
+ extra_step_kwargs["eta"] = eta
389
+
390
+ # check if the scheduler accepts generator
391
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
392
+ if accepts_generator:
393
+ extra_step_kwargs["generator"] = generator
394
+ return extra_step_kwargs
395
+
396
+ @torch.no_grad()
397
+ def __call__(
398
+ self,
399
+ image_pixels: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
400
+ pose_pixels: Union[torch.FloatTensor],
401
+ obj_pixels: Union[torch.FloatTensor],
402
+ obj_track_pixels: Union[torch.FloatTensor],
403
+ hand_pixels: Union[torch.FloatTensor],
404
+ height: int = 576,
405
+ width: int = 1024,
406
+ num_frames: Optional[int] = None,
407
+ tile_size: Optional[int] = 16,
408
+ tile_overlap: Optional[int] = 4,
409
+ num_inference_steps: int = 25,
410
+ min_guidance_scale: float = 1.0,
411
+ max_guidance_scale: float = 3.0,
412
+ fps: int = 7,
413
+ motion_bucket_id: int = 127,
414
+ noise_aug_strength: float = 0.02,
415
+ image_only_indicator: bool = False,
416
+ decode_chunk_size: Optional[int] = None,
417
+ num_videos_per_prompt: Optional[int] = 1,
418
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
419
+ latents: Optional[torch.FloatTensor] = None,
420
+ output_type: Optional[str] = "pil",
421
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
422
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
423
+ return_dict: bool = True,
424
+ device: Union[str, torch.device] = None,
425
+ visual_output: bool = False,
426
+ ):
427
+ r"""
428
+ Pipeline execution function for video generation.
429
+
430
+ Args:
431
+ image_pixels: Input image(s) for guidance
432
+ pose_pixels: Pose data tensor
433
+ obj_pixels: Object reference tensor
434
+ obj_track_pixels: Object tracking data tensor
435
+ hand_pixels: Hand tracking data tensor
436
+ height: Output video height
437
+ width: Output video width
438
+ num_frames: Number of frames to generate
439
+ tile_size: Processing tile size
440
+ tile_overlap: Tile overlap size
441
+ num_inference_steps: Number of denoising steps
442
+ min_guidance_scale: Minimum CFG scale
443
+ max_guidance_scale: Maximum CFG scale
444
+ fps: Frames per second
445
+ motion_bucket_id: Motion control parameter
446
+ noise_aug_strength: Noise augmentation strength
447
+ image_only_indicator: Image-only processing flag
448
+ decode_chunk_size: Frame decoding chunk size
449
+ num_videos_per_prompt: Videos per prompt
450
+ generator: Random number generator
451
+ latents: Initial latent vectors
452
+ output_type: Output format
453
+ callback_on_step_end: Callback function
454
+ callback_on_step_end_tensor_inputs: Callback inputs
455
+ return_dict: Return type flag
456
+ device: Computation device
457
+ visual_output: Visualization flag
458
+
459
+ Returns:
460
+ Generated video output
461
+ """
462
+ pose_pixels = torch.cat([pose_pixels, hand_pixels], dim=1)
463
+
464
+ # 0. Default height and width to unet
465
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
466
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
467
+
468
+ num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
469
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
470
+
471
+ # 1. Check inputs. Raise error if not correct
472
+ self.check_inputs(image_pixels, height, width)
473
+
474
+ # 2. Define call parameters
475
+ if isinstance(image_pixels, PIL.Image.Image):
476
+ batch_size = 1
477
+ elif isinstance(image_pixels, list):
478
+ batch_size = len(image_pixels)
479
+ else:
480
+ batch_size = image_pixels.shape[0]
481
+ device = device if device is not None else self._execution_device
482
+
483
+
484
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
485
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
486
+ # corresponds to doing no classifier free guidance.
487
+ self._guidance_scale = max_guidance_scale
488
+
489
+ # 3. Encode input image
490
+ self.image_encoder.to(device)
491
+ self.obj_image_encoder.to(device)
492
+
493
+ encoder_hidden_states, obj_embeddings = self._encode_image(image_pixels, obj_pixels, device, num_videos_per_prompt,
494
+ self.do_classifier_free_guidance)
495
+ obj_embeddings = obj_embeddings.to(encoder_hidden_states.dtype)
496
+ # self.image_encoder.cpu()
497
+
498
+ self.image_encoder.cpu()
499
+ self.obj_image_encoder.cpu()
500
+ # NOTE: Stable Diffusion Video was conditioned on fps - 1, which
501
+ # is why it is reduced here.
502
+ fps = fps - 1
503
+
504
+ # 4. Encode input image using VAE
505
+ image_pixels = self.image_processor.preprocess(image_pixels, height=height, width=width).to(device)
506
+ obj_image = pil_to_tensor(obj_pixels[1])
507
+ h_pad = (image_pixels.shape[-2] - obj_image.shape[-2]) // 2
508
+ w_pad = (image_pixels.shape[-1] - obj_image.shape[-1]) // 2
509
+ obj_image = F.pad(obj_image, (w_pad, w_pad, h_pad, h_pad), mode='constant', value=0)
510
+ print(f'obj_image before process: {obj_image.shape}')
511
+ obj_image = self.image_processor.preprocess(obj_image, height=height, width=width).to(device)
512
+ print(f'obj_image after process: {obj_image.shape}')
513
+
514
+ noise = randn_tensor(image_pixels.shape, generator=generator, device=device, dtype=image_pixels.dtype)
515
+ image_pixels = image_pixels + noise_aug_strength * noise
516
+ obj_image = obj_image + noise_aug_strength * noise
517
+
518
+ self.vae.to(device)
519
+ image_latents = self._encode_vae_image(
520
+ image_pixels,
521
+ device=device,
522
+ num_videos_per_prompt=num_videos_per_prompt,
523
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
524
+ )
525
+ image_latents = image_latents.to(encoder_hidden_states.dtype)
526
+ obj_image_latents = self._encode_vae_image(
527
+ obj_image,
528
+ device=device,
529
+ num_videos_per_prompt=num_videos_per_prompt,
530
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
531
+ )
532
+ obj_image_latents = obj_image_latents.to(encoder_hidden_states.dtype)
533
+ #print(f'image_latents: {image_latents}')
534
+ self.vae.cpu()
535
+
536
+ # Repeat the image latents for each frame so we can concatenate them with the noise
537
+ # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
538
+ image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
539
+ obj_image_latents = obj_image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
540
+ # 5. Get Added Time IDs
541
+ added_time_ids = self._get_add_time_ids(
542
+ fps,
543
+ motion_bucket_id,
544
+ noise_aug_strength,
545
+ encoder_hidden_states.dtype,
546
+ batch_size,
547
+ num_videos_per_prompt,
548
+ self.do_classifier_free_guidance,
549
+ )
550
+ added_time_ids = added_time_ids.to(device)
551
+
552
+ # 4. Prepare timesteps
553
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None)
554
+
555
+ # 5. Prepare latent variables
556
+ # num_channels_latents = self.unet.config.in_channels
557
+ # print("latents",latents)
558
+ latents = self.prepare_latents(
559
+ batch_size * num_videos_per_prompt,
560
+ tile_size,
561
+ 4,
562
+ height,
563
+ width,
564
+ encoder_hidden_states.dtype,
565
+ device,
566
+ generator,
567
+ latents,
568
+ )
569
+ latents = latents.repeat(1, num_frames // tile_size + 1, 1, 1, 1)[:, :num_frames]
570
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
571
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, 0.0)
572
+
573
+ # 7. Prepare guidance scale
574
+ guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
575
+ guidance_scale = guidance_scale.to(device, latents.dtype)
576
+ guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
577
+ guidance_scale = _append_dims(guidance_scale, latents.ndim)
578
+
579
+ self._guidance_scale = guidance_scale
580
+
581
+ # 8. Denoising loop
582
+ self._num_timesteps = len(timesteps)
583
+
584
+ self.pose_net.to(device)
585
+ self.track_net.to(device)
586
+ self.unet.to(device)
587
+ self.obj_proj_net.to(device)
588
+ self.obj_attn_net.to(device)
589
+
590
+ with torch.cuda.device(device):
591
+ torch.cuda.empty_cache()
592
+
593
+ obj_cls_emb = torch.cat([
594
+ obj_embeddings[:, 0, :], obj_embeddings[:, 1370, :], obj_embeddings[:, 1370*2, :]
595
+ ], dim=1).to(torch.float16)
596
+ obj_cls_embeddings = self.obj_proj_net(obj_cls_emb)
597
+ obj_embeddings = obj_embeddings.to(torch.device('cuda'))
598
+ obj_attn_embeddings = self.obj_attn_net(obj_embeddings)
599
+ encoder_hidden_states = torch.concat([
600
+ encoder_hidden_states, obj_cls_embeddings, obj_attn_embeddings
601
+ ], dim=1)
602
+
603
+ if self.do_classifier_free_guidance:
604
+ negative_image_embeddings = torch.zeros_like(encoder_hidden_states)
605
+
606
+ # For classifier free guidance, we need to do two forward passes.
607
+ # Here we concatenate the unconditional and conditional embeddings into a single batch
608
+ # to avoid doing two forward passes
609
+ encoder_hidden_states = torch.cat([negative_image_embeddings, encoder_hidden_states])
610
+
611
+ def hook_function(module, inputdata, output):
612
+ if isinstance(output, tuple):
613
+ print(f"Module name: {module.__class__.__name__} Output shape: {output}")
614
+ else:
615
+ print(f"Module name: {module.__class__.__name__} Output shape: {output.shape}")
616
+ print("Output stats - mean: {}, std: {}, min: {}, max: {}".format(output.mean().item(), output.std().item(),
617
+ output.min().item(), output.max().item()))
618
+ if torch.isnan(output).any():
619
+ print(f"!!!!!!!!!!!!!!!!!!!!NaN detected after layer: {module.__class__.__name__}!!!!!!!!!!!!!!!!!!!!")
620
+ hooks = []
621
+ def register_hooks():
622
+ for name, module in self.unet.named_modules():
623
+ if isinstance(module, nn.Module):
624
+ hooks.append(module.register_forward_hook(hook_function))
625
+
626
+ bias_start = 1
627
+ bias_step = 4
628
+ with self.progress_bar(total=len(timesteps) * math.ceil((num_frames - 1) / (tile_size - 1))) as progress_bar:
629
+ for i, t in enumerate(timesteps):
630
+ # expand the latents if we are doing classifier free guidance
631
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
632
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
633
+ # Concatenate image_latents over channels dimension
634
+ print(f"{latent_model_input.shape} {image_latents.shape} {obj_image_latents.shape}")
635
+ latent_model_input = torch.cat([latent_model_input, image_latents, obj_image_latents], dim=2)
636
+
637
+ # predict the noise residual
638
+ noise_pred = torch.zeros_like(image_latents)
639
+ noise_pred_cnt = torch.zeros_like(image_latents)
640
+ weight = torch.ones_like(image_latents)
641
+
642
+ bias_start = (bias_start - 1) % (num_frames - 1) + 1
643
+ start_cur = bias_start
644
+ finished_len = 1
645
+ print(f'start_cur {start_cur}')
646
+ while finished_len < num_frames:
647
+ start_cur = (start_cur - 1) % (num_frames - 1) + 1
648
+ end_cur = start_cur + tile_size - 1
649
+
650
+ idx = [0, ]
651
+ idx.extend([(ii - 1) % (num_frames - 1) + 1 for ii in range(start_cur, end_cur)])
652
+ print(idx)
653
+ # classifier-free (unconditional) pass
654
+ pose_latents = self.pose_net(pose_pixels[idx].to(dtype=torch.float16).to(device))
655
+
656
+ track_latents = self.track_net(obj_track_pixels[idx].to(dtype=torch.float16).to(device))
657
+
658
+ if visual_output:
659
+ os.makedirs('./visual_spatio_attn', exist_ok=True)
660
+ for name, module in self.unet.named_modules():
661
+ if '.transformer_blocks.' in name and name.endswith('.attn2'):
662
+ module.visual_path = None
663
+
664
+ latent_model_input=latent_model_input.to(dtype=torch.float16)
665
+ encoder_hidden_states=encoder_hidden_states.to(dtype=torch.float16)
666
+ t=t.to(dtype=torch.float16)
667
+ _noise_pred = self.unet(
668
+ latent_model_input[:1, idx],
669
+ t,
670
+ encoder_hidden_states=encoder_hidden_states[:1],
671
+ added_time_ids=added_time_ids[:1],
672
+ pose_latents=None,
673
+ image_only_indicator=image_only_indicator,
674
+ return_dict=False,
675
+ obj_track_latents=None,
676
+ )[0]
677
+ noise_pred[:1, idx] += _noise_pred
678
+
679
+ # conditional pass
680
+
681
+ if visual_output:
682
+ os.makedirs('./visual_spatio_attn', exist_ok=True)
683
+ for name, module in self.unet.named_modules():
684
+ if '.transformer_blocks.' in name and name.endswith('.attn2'):
685
+ module.visual_path = os.path.join('./visual_spatio_attn', name[:-6] + '.png')
686
+
687
+ _noise_pred = self.unet(
688
+ latent_model_input[1:, idx],
689
+ t,
690
+ encoder_hidden_states=encoder_hidden_states[1:],
691
+ added_time_ids=added_time_ids[1:],
692
+ pose_latents=pose_latents,
693
+ image_only_indicator=image_only_indicator,
694
+ return_dict=False,
695
+ obj_track_latents= track_latents,
696
+ )[0]
697
+ noise_pred[1:, idx] += _noise_pred
698
+
699
+ noise_pred_cnt[:, idx] += weight[:, idx]
700
+ finished_len += tile_size - 1
701
+ start_cur += tile_size - 1
702
+ progress_bar.update()
703
+
704
+ bias_start += bias_step
705
+ noise_pred = noise_pred.div_(noise_pred_cnt)
706
+
707
+ # perform guidance
708
+ if self.do_classifier_free_guidance:
709
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
710
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
711
+
712
+ # compute the previous noisy sample x_t -> x_t-1
713
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
714
+
715
+ if callback_on_step_end is not None:
716
+ callback_kwargs = {}
717
+ for k in callback_on_step_end_tensor_inputs:
718
+ callback_kwargs[k] = locals()[k]
719
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
720
+
721
+ latents = callback_outputs.pop("latents", latents)
722
+
723
+ self.pose_net.cpu()
724
+ self.unet.cpu()
725
+ self.track_net.cpu()
726
+ self.obj_proj_net.cpu()
727
+ if output_type != "latent":
728
+ self.vae.decoder.to(device)
729
+ frames = self.decode_latents(latents, num_frames, decode_chunk_size)
730
+ frames = tensor2vid(frames, self.image_processor, output_type=output_type)
731
+ else:
732
+ frames = latents
733
+
734
+ self.maybe_free_model_hooks()
735
+
736
+ if not return_dict:
737
+ return frames
738
+
739
+ return AnchorCrafterPipelineOutput(frames=frames)
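For reference, the denoising loop above processes the video in overlapping tiles whose start frame rotates by `bias_step` at every timestep, always prepending frame 0 (the reference frame). A minimal standalone sketch of that wrap-around index computation, with illustrative values (`tile_indices` is a hypothetical helper, not a function in the repository):

    def tile_indices(num_frames: int, tile_size: int, bias_start: int):
        """Yield the frame indices of each tile; frame 0 is always included."""
        start_cur = (bias_start - 1) % (num_frames - 1) + 1
        finished_len = 1
        while finished_len < num_frames:
            start_cur = (start_cur - 1) % (num_frames - 1) + 1
            end_cur = start_cur + tile_size - 1
            yield [0] + [(ii - 1) % (num_frames - 1) + 1 for ii in range(start_cur, end_cur)]
            finished_len += tile_size - 1
            start_cur += tile_size - 1

    for tile in tile_indices(num_frames=28, tile_size=15, bias_start=1):
        print(tile)  # e.g. [0, 1, ..., 14], then [0, 15, ..., 27, 1]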
anchorcrafter/utils/__init__.py ADDED
File without changes
anchorcrafter/utils/geglu_patch.py ADDED
@@ -0,0 +1,10 @@
1
+ import diffusers.models.activations
2
+
3
+
4
+ def patch_geglu_inplace():
5
+ """Patch GEGLU with inplace multiplication to save GPU memory."""
6
+ def forward(self, hidden_states):
7
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
8
+ hidden_states = hidden_states.clone()
9
+ return hidden_states.mul_(self.gelu(gate))
10
+ diffusers.models.activations.GEGLU.forward = forward
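A minimal usage sketch for the patch above (assuming the repository root is on PYTHONPATH): call it once, before the first UNet forward pass, so every diffusers GEGLU layer uses the memory-saving in-place multiplication.

    from anchorcrafter.utils.geglu_patch import patch_geglu_inplace

    patch_geglu_inplace()  # monkey-patches diffusers.models.activations.GEGLU.forward globally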
anchorcrafter/utils/loader.py ADDED
@@ -0,0 +1,45 @@
1
+ import logging
2
+
3
+ import torch
4
+ import torch.utils.checkpoint
5
+ from diffusers.models import AutoencoderKLTemporalDecoder
6
+ from diffusers.schedulers import EulerDiscreteScheduler
7
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
8
+ from transformers import AutoImageProcessor, AutoModel
9
+
10
+ from ..modules.unet import UNetSpatioTemporalConditionModel
11
+ from ..modules.track_net import TrackNet
12
+ from ..modules.obj_proj_net import ObjProjNet
13
+ from ..modules.obj_attn_net import ObjAttnNet
14
+ from ..modules.pose_net import PoseNet
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ class AnchorCrafter(torch.nn.Module):
20
+ def __init__(self, base_model_path, dino_path):
21
+ """construnct base model components and load pretrained svd model except pose-net
22
+ Args:
23
+ base_model_path (str): pretrained svd model path
24
+ """
25
+ super().__init__()
26
+ unet_config = UNetSpatioTemporalConditionModel.load_config(base_model_path, subfolder="unet")
27
+ unet_config["in_channels"] = 12
28
+ self.unet = UNetSpatioTemporalConditionModel.from_config(unet_config).to(torch.float16)
29
+ self.vae = AutoencoderKLTemporalDecoder.from_pretrained(
30
+ base_model_path, subfolder="vae", torch_dtype=torch.float16, variant="fp16")
31
+ self.obj_image_encoder = AutoModel.from_pretrained(dino_path).to(torch.float16)
32
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
33
+ base_model_path, subfolder="image_encoder", torch_dtype=torch.float16, variant="fp16")
34
+ self.noise_scheduler = EulerDiscreteScheduler.from_pretrained(
35
+ base_model_path, subfolder="scheduler")
36
+ self.feature_extractor = CLIPImageProcessor.from_pretrained(
37
+ base_model_path, subfolder="feature_extractor")
38
+ self.dino_feature_extractor = AutoImageProcessor.from_pretrained(dino_path)
39
+
40
+ # pose_net
41
+ self.pose_net = PoseNet(latent_channels=self.unet.config.block_out_channels[0]).to(dtype=torch.float16)
42
+ # track_net
43
+ self.track_net = TrackNet(latent_channels=self.unet.config.block_out_channels[0]).to(dtype=torch.float16)
44
+ self.obj_proj_net = ObjProjNet(context_tokens=3).to(dtype=torch.float16)
45
+ self.obj_attn_net = ObjAttnNet().to(dtype=torch.float16)
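Note on `unet_config["in_channels"] = 12` above: the pipeline concatenates, per frame, 4 noise-latent channels, 4 reference-image latent channels, and 4 object-image latent channels before each UNet call (the `torch.cat(..., dim=2)` in pipeline.py). A small sanity-check sketch with illustrative shapes:

    import torch

    latents = torch.randn(1, 15, 4, 128, 72)        # noise latents [B, F, C, H, W]
    image_latents = torch.randn(1, 15, 4, 128, 72)  # reference-image latents, repeated per frame
    obj_latents = torch.randn(1, 15, 4, 128, 72)    # object-image latents, repeated per frame
    unet_input = torch.cat([latents, image_latents, obj_latents], dim=2)
    assert unet_input.shape[2] == 12                # matches in_channels = 12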
anchorcrafter/utils/utils.py ADDED
@@ -0,0 +1,51 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ import cv2
4
+ import tempfile
5
+ import torch
6
+ from typing import List, Union
7
+
8
+
9
+ def save_video_with_cv2(frames: Union[torch.Tensor, List[np.ndarray]], output_path: str, fps: int = 24):
10
+ """Save video using OpenCV (supports PyTorch tensors or numpy arrays input)"""
11
+ if isinstance(frames, torch.Tensor):
12
+ frames = frames.detach().cpu().numpy()
13
+
14
+ # Ensure data is uint8 type in 0-255 range
15
+ processed_frames = []
16
+ for frame in frames:
17
+ # Convert float types (assuming 0-1 range) to 0-255
18
+ if frame.dtype == np.float32 or frame.dtype == np.float64:
19
+ frame = (frame * 255).clip(0, 255).astype(np.uint8)
20
+ elif frame.dtype != np.uint8:
21
+ frame = frame.astype(np.uint8)
22
+
23
+ # Convert color channel order to BGR (OpenCV requirement)
24
+ if frame.ndim == 3 and frame.shape[2] == 3: # If RGB format
25
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
26
+
27
+ processed_frames.append(frame)
28
+
29
+ if not processed_frames:
30
+ raise ValueError("No valid video frames to save")
31
+
32
+ height, width = processed_frames[0].shape[:2]
33
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
34
+ writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
35
+
36
+ try:
37
+ for frame in processed_frames:
38
+ writer.write(frame)
39
+ finally:
40
+ writer.release()
41
+
42
+
43
+ def save_to_mp4(frames: Union[torch.Tensor, List[np.ndarray]], fps: int = 7) -> str:
44
+ """Save to MP4 and return temporary file path"""
45
+ # Adjust dimensions if input is PyTorch tensor (f, c, h, w) -> (f, h, w, c)
46
+ if isinstance(frames, torch.Tensor):
47
+ frames = frames.permute(0, 2, 3, 1)
48
+
49
+ temp_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
50
+ save_video_with_cv2(frames, temp_path, fps)
51
+ return temp_path
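A minimal usage sketch for the helpers above (dummy frames, assuming the repository root is on PYTHONPATH):

    import torch
    from anchorcrafter.utils.utils import save_to_mp4

    frames = torch.rand(16, 3, 64, 64)  # (f, c, h, w) float frames in [0, 1]
    path = save_to_mp4(frames, fps=7)   # permutes to (f, h, w, c), scales to uint8, writes with OpenCV
    print(path)                         # path of the temporary .mp4 file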
app.py ADDED
@@ -0,0 +1,332 @@
1
+ import logging
2
+ import spaces
3
+ import os
4
+ import gradio as gr
5
+ import numpy as np
6
+ from PIL import Image
7
+ from inference import process_inputs, run_pipeline
8
+
9
+ import torch
10
+ from omegaconf import OmegaConf
11
+ from anchorcrafter.utils.utils import save_to_mp4
12
+ from threading import Thread
13
+ from anchorcrafter.utils.loader import AnchorCrafter
14
+ from huggingface_hub import hf_hub_download
15
+ from diffusers.utils.import_utils import is_xformers_available
16
+ from diffusers.models.attention_processor import XFormersAttnProcessor
17
+ from anchorcrafter.modules.attention_processor import IPAttnProcessor
18
+ from anchorcrafter.pipelines.pipeline import AnchorCrafterPipeline
19
+ from packaging import version
20
+ import logging
21
+ logger = logging.getLogger(__name__)
22
+ css='''
23
+ .text-container {
24
+ background-color: #f0faff;
25
+ border: 1px solid #b3d8ff;
26
+ border-radius: 6px;
27
+ padding: 5px;
28
+ margin: 5px auto;
29
+ width: fit-content;
30
+ box-shadow: 2px 2px 6px rgba(0, 0, 0, 0.1);
31
+ }
32
+
33
+ .text-container h2 {
34
+ font-family: Arial, sans-serif;
35
+ color: #000000;
36
+ font-size: 18px;
37
+ font-weight: bold;
38
+ margin-bottom: 5px;
39
+ margin-top: 5px;
40
+ }
41
+
42
+ .text-container p {
43
+ font-family: Arial, sans-serif;
44
+ color: #444444;
45
+ font-size: 18px;
46
+ line-height: 1.5;
47
+ margin-top: 5px;
48
+ }
49
+ '''
50
+ global pipeline, infer_config, model_path, anchorcrafter_models
51
+ # Path mappings
52
+ IMAGE_VIDEO_MAP = {
53
+ 0: ["data/video/hmbb_1.mp4", "data/video/hmbb_2.mp4"],
54
+ 1: ["data/video/cheese_1.mp4", "data/video/cheese_2.mp4"],
55
+ 2: ["data/video/earphone_1.mp4", "data/video/earphone_2.mp4"],
56
+ 3: ["data/video/mouse_1.mp4", "data/video/mouse_2.mp4"],
57
+ 4: ["data/video/cup_1.mp4", "data/video/cup_2.mp4"],}
58
+
59
+ OBJECT_INDEX_MAP ={
60
+ "hmbb":0,"cheese":1,"earphone":2,"mouse":3,"cup":4
61
+ }
62
+ OUTPUT_PATH_MAP={
63
+ "hmbb":"data/out/hmbb.mp4","earphone":"data/out/earphone.mp4","cup":"data/out/cup.mp4","mouse":"data/out/mouse.mp4","cheese":"data/out/cheese.mp4"
64
+ }
65
+ POSE_TRACK_MAP = {
66
+ 0: [["data/depth_cut/hmbb_1.mp4", "data/hand_cut/hmbb_1.mp4"],
67
+ ["data/depth_cut/hmbb_2.mp4", "data/hand_cut/hmbb_2.mp4"]],
68
+ 1: [["data/depth_cut/cheese_1.mp4", "data/hand_cut/cheese_1.mp4"],
69
+ ["data/depth_cut/cheese_2.mp4", "data/hand_cut/cheese_2.mp4"]],
70
+ 2: [["data/depth_cut/earphone_1.mp4", "data/hand_cut/earphone_1.mp4"],
71
+ ["data/depth_cut/earphone_2.mp4", "data/hand_cut/earphone_2.mp4"]],
72
+ 3: [["data/depth_cut/mouse_1.mp4", "data/hand_cut/mouse_1.mp4"],
73
+ ["data/depth_cut/mouse_2.mp4", "data/hand_cut/mouse_2.mp4"]],
74
+ 4: [["data/depth_cut/cup_1.mp4", "data/hand_cut/cup_1.mp4"],
75
+ ["data/depth_cut/cup_2.mp4", "data/hand_cut/cup_2.mp4"]]}
76
+
77
+ EXAMPLE_IMAGES = [
78
+ "data/object/hmbb_1.jpg",
79
+ "data/object/cheese_1.jpg",
80
+ "data/object/earphone_1.jpg",
81
+ "data/object/mouse_1.jpg",
82
+ "data/object/cup_1.jpg",
83
+ ]
84
+
85
+ def update_video_choices(evt: gr.SelectData, selected_state):
86
+ """Update video choices based on gallery selection"""
87
+ selected_state = evt.index
88
+ video1, video2 = IMAGE_VIDEO_MAP[selected_state]
89
+ return (
90
+ gr.update(value=video1, visible=True),
91
+ gr.update(value=video2, visible=True),
92
+ selected_state
93
+ )
94
+
95
+ def clear_anchor():
96
+ """Clear anchor image input"""
97
+ return gr.update(value=None)
98
+
99
+ def select_button1(selected_state, video_state):
100
+ """Handle first video selection"""
101
+ return (
102
+ gr.update(variant="primary"),
103
+ gr.update(variant="secondary"),
104
+ 0,
105
+ selected_state
106
+ )
107
+
108
+ def select_button2(selected_state, video_state):
109
+ """Handle second video selection"""
110
+ return (
111
+ gr.update(variant="secondary"),
112
+ gr.update(variant="primary"),
113
+ 1,
114
+ selected_state
115
+ )
116
+
117
+ def load_model():
118
+ """Initialize model components"""
119
+ global pipeline, infer_config, model_path, anchorcrafter_models
120
+ infer_config = OmegaConf.load("config/test.yaml")
121
+ anchorcrafter_models = AnchorCrafter(infer_config.base_model_path, infer_config.dino_path)
122
+ # Download model weights
123
+ model_path = hf_hub_download(
124
+ repo_id=infer_config.anchorcrafter_path,
125
+ filename="pytorch_model.bin"
126
+ )
127
+
128
+ @spaces.GPU
129
+ def run(infer_config,image_pixels, pose_pixels, obj_pixels, obj_track_pixels,hand_pixels):
130
+ """Execute the generation pipeline"""
131
+ global anchorcrafter_models
132
+ device=torch.device('cuda')
133
+ if is_xformers_available():
134
+ import xformers
135
+ xformers_version = version.parse(xformers.__version__)
136
+ if xformers_version == version.parse("0.0.16"):
137
+ logger.warning(
138
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
139
+ )
140
+ anchorcrafter_models.unet.enable_xformers_memory_efficient_attention()
141
+ else:
142
+ raise ValueError(
143
+ "xformers is not available. Make sure it is installed correctly")
144
+
145
+ # Configure attention processors
146
+ attn_procs = {}
147
+ for name in anchorcrafter_models.unet.attn_processors.keys():
148
+ cross_attention_dim = None if name.endswith(
149
+ "attn1.processor") else anchorcrafter_models.unet.config.cross_attention_dim
150
+ hidden_size = None
151
+ if name.startswith("mid_block"):
152
+ hidden_size = anchorcrafter_models.unet.config.block_out_channels[-1]
153
+ elif name.startswith("up_blocks"):
154
+ block_id = int(name[len("up_blocks.")])
155
+ hidden_size = list(reversed(anchorcrafter_models.unet.config.block_out_channels))[block_id]
156
+ elif name.startswith("down_blocks"):
157
+ block_id = int(name[len("down_blocks.")])
158
+ hidden_size = anchorcrafter_models.unet.config.block_out_channels[block_id]
159
+ if cross_attention_dim is None:
160
+ attn_procs[name] = XFormersAttnProcessor()
161
+ else:
162
+ attn_procs[name] = IPAttnProcessor(
163
+ hidden_size=hidden_size,
164
+ cross_attention_dim=cross_attention_dim,
165
+ scale=1.0,
166
+ num_tokens=15
167
+ )
168
+ anchorcrafter_models.unet.set_attn_processor(attn_procs)
169
+ anchorcrafter_models=anchorcrafter_models.to(torch.float16)
170
+ # Load model weights
171
+ model_weights = torch.load(model_path)
172
+ missing, unexpected = anchorcrafter_models.load_state_dict(model_weights, strict=False)
173
+ logger.info(f"Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")
174
+ # Initialize pipeline
175
+ pipeline = AnchorCrafterPipeline(
176
+ vae=anchorcrafter_models.vae,
177
+ image_encoder=anchorcrafter_models.image_encoder,
178
+ obj_image_encoder=anchorcrafter_models.obj_image_encoder,
179
+ unet=anchorcrafter_models.unet,
180
+ scheduler=anchorcrafter_models.noise_scheduler,
181
+ feature_extractor=anchorcrafter_models.feature_extractor,
182
+ dino_feature_extractor=anchorcrafter_models.dino_feature_extractor,
183
+ pose_net=anchorcrafter_models.pose_net,
184
+ track_net=anchorcrafter_models.track_net,
185
+ obj_proj_net=anchorcrafter_models.obj_proj_net,
186
+ obj_attn_net=anchorcrafter_models.obj_attn_net
187
+ )
188
+
189
+ for task in infer_config.test_case:
190
+ _video_frames = run_pipeline(
191
+ pipeline,
192
+ image_pixels, pose_pixels, obj_pixels, obj_track_pixels,
193
+ hand_pixels=hand_pixels, total_frames=infer_config.total_frames,
194
+ device=device, task_config=task)
195
+
196
+ return _video_frames
197
+
198
+ def pre(selected_state, video_state, anchor_image):
199
+ """Process user inputs and generate video"""
200
+ if anchor_image is None:
201
+ raise gr.Error("Please upload an anchor image first!")
202
+ # Convert PIL Image to numpy array
203
+ if isinstance(anchor_image, Image.Image):
204
+ anchor_image = np.array(anchor_image)
205
+ logger.debug(f"Converted image shape: {anchor_image.shape}")
206
+
207
+ # Get resource paths
208
+ video_path = IMAGE_VIDEO_MAP[selected_state][video_state]
209
+ obj_path = EXAMPLE_IMAGES[selected_state]
210
+ obj_track_path = POSE_TRACK_MAP[selected_state][video_state][0]
211
+ hand_path = POSE_TRACK_MAP[selected_state][video_state][1]
212
+
213
+ # Preprocess inputs
214
+ pose_pixels, image_pixels, obj_pixels, obj_track_pixels, hand_pixels = process_inputs(
215
+ video_path=video_path,
216
+ image_pixels=anchor_image,
217
+ obj_path=obj_path,
218
+ obj_track_path=obj_track_path,
219
+ hand_path=hand_path,
220
+ total_frames=infer_config.total_frames,
221
+ )
222
+ # Generate video
223
+ _video_frames = run(infer_config, image_pixels, pose_pixels, obj_pixels, obj_track_pixels, hand_pixels)
224
+ temp_path = save_to_mp4(_video_frames, fps=infer_config.fps)
225
+ return temp_path
226
+
227
+
228
+ def find_file_index(target_file):
229
+ for category_id, group in IMAGE_VIDEO_MAP.items():
230
+ for file_idx, file_path in enumerate(group):
231
+ if target_file in file_path:
232
+ return category_id, file_idx
233
+ return None
234
+
235
+
236
+ def exam_result(anchor, object_exam, video_exam):
237
+ logging.info("Function entered")
238
+
239
+ filename = os.path.splitext(os.path.basename(video_exam))[0] # "hmbb_2"
240
+ prefix = filename.rsplit("_", 1)[0] # "hmbb"
241
+
242
+ file_idx = int(filename.split("_")[-1]) - 1 # e.g. "hmbb_2" -> 1
243
+ selected_state=OBJECT_INDEX_MAP[prefix]
244
+
245
+ video1 = IMAGE_VIDEO_MAP[selected_state][0]
246
+ video2 = IMAGE_VIDEO_MAP[selected_state][1]
247
+ out = OUTPUT_PATH_MAP[prefix]
248
+
249
+ return (
250
+ gr.update(value=video1), # video_preview1
251
+ gr.update(value=video2), # video_preview2
252
+ gr.update(variant="primary" if file_idx == 0 else "secondary"), # btn1
253
+ gr.update(variant="secondary" if file_idx == 0 else "primary"), # btn2
254
+ selected_state,
255
+ file_idx,
256
+ out
257
+ )
258
+ # Create Gradio interface
259
+ with gr.Blocks(title="AnchorCrafter", theme=gr.themes.Soft(), css=css) as demo:
260
+ selected_state = gr.State(0)
261
+ video_state = gr.State(0)
262
+ gr.Markdown("# AnchorCrafter: Animate Cyber-Anchors Selling Your Products via Human-Object Interacting Video Generation")
263
+ top_description = gr.HTML(f'''
264
+ <div class="text-container">
265
+ <h2>To reduce inference time, the generated video is limited to 28 frames, which takes approximately 5 minutes on an NVIDIA L4.</h2>
266
+ <p>For longer videos, please duplicate or download this Space, run it on a private GPU, and adjust config/test.yaml accordingly.</p>
267
+ </div>
268
+
269
+ ''', elem_id="top_description")
270
+ with gr.Row():
271
+ with gr.Column(scale=2):
272
+ gr.Markdown("## 1. Choose Object")
273
+ gallery = gr.Gallery(value=EXAMPLE_IMAGES, label="objects", columns=3, height=320, object_fit="contain")
274
+ gr.Markdown("## 3. Anchor Image")
275
+ anchor = gr.Image(label="anchor", image_mode="RGB", height=380, width=250, sources="upload")
276
+ with gr.Row():
277
+ clear_btn3 = gr.Button("🧹 Clear")
278
+ run_btn4 = gr.Button("🚀 Run")
279
+ with gr.Column(scale=3):
280
+ gr.Markdown("## 2. Control Video")
281
+ with gr.Row():
282
+ video_preview1 = gr.Video(label="video 1", height=260)
283
+ video_preview2 = gr.Video(label="video 2", height=260)
284
+ with gr.Row():
285
+ btn1 = gr.Button("choose video 1", variant="secondary")
286
+ btn2 = gr.Button("choose video 2", variant="secondary")
287
+ gr.Markdown("## 4. Results")
288
+ video_display = gr.Video(label="results", height=380)
289
+ video_exam= gr.Video(label="Control Video",visible=False)
290
+ object_exam = gr.Image(label="Object", visible=False)
291
+ examples = gr.Examples(
292
+ examples=[
293
+ ["data/anchor/1.jpg", "data/object/hmbb_1.jpg", "data/video/hmbb_2.mp4"],
294
+ ["data/anchor/2.jpg", "data/object/earphone_1.jpg", "data/video/earphone_1.mp4"],
295
+ ["data/anchor/3.jpg", "data/object/cup_1.jpg", "data/video/cup_2.mp4"],
296
+ ["data/anchor/4.jpg", "data/object/mouse_1.jpg", "data/video/mouse_1.mp4"],
297
+ ["data/anchor/5.jpg", "data/object/cheese_1.jpg", "data/video/cheese_2.mp4"],
298
+ ],
299
+ fn=exam_result,
300
+ run_on_click=True,
301
+ cache_examples=False,
302
+ inputs=[anchor, object_exam, video_exam],
303
+ outputs=[video_preview1,video_preview2,btn1, btn2, selected_state, video_state, video_display])
304
+
305
+ gallery.select(
306
+ update_video_choices,
307
+ inputs=[selected_state],
308
+ outputs=[video_preview1, video_preview2, selected_state]
309
+ )
310
+
311
+ btn1.click(
312
+ select_button1,
313
+ inputs=[selected_state, video_state],
314
+ outputs=[btn1, btn2, video_state, selected_state]
315
+ )
316
+
317
+ btn2.click(
318
+ select_button2,
319
+ inputs=[selected_state, video_state],
320
+ outputs=[btn1, btn2, video_state, selected_state]
321
+ )
322
+
323
+ clear_btn3.click(clear_anchor, outputs=[anchor])
324
+ run_btn4.click(
325
+ pre,
326
+ inputs=[selected_state, video_state, anchor],
327
+ outputs=[video_display]
328
+ )
329
+ # Initialize model in background
330
+ Thread(target=load_model, daemon=True).start()
331
+ if __name__ == "__main__":
332
+ demo.launch(server_name="0.0.0.0", server_port=7860)
config/test.yaml ADDED
@@ -0,0 +1,17 @@
1
+ # base svd model path
2
+ base_model_path: stabilityai/stable-video-diffusion-img2vid-xt
3
+ dino_path: facebook/dinov2-large
4
+ anchorcrafter_path: cangcz/test
5
+ fps: 11
6
+ total_frames: 28 # The final length of the generated video
7
+ test_case:
8
+ - num_frames: 15
9
+ resolution: 576
10
+ frames_overlap: 5
11
+ num_inference_steps: 30
12
+ noise_aug_strength: 0
13
+ guidance_scale: 4.0
14
+ sample_stride: 4
15
+ seed: 42
16
+
17
+
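For reference, this is how the demo consumes the file (a sketch mirroring the OmegaConf usage in app.py):

    from omegaconf import OmegaConf

    infer_config = OmegaConf.load("config/test.yaml")
    print(infer_config.total_frames)       # 28
    for task in infer_config.test_case:    # one generation setting per list entry
        print(task.num_frames, task.resolution, task.guidance_scale)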
constants.py ADDED
@@ -0,0 +1,4 @@
1
+ # w/h aspect ratio
2
+ #ASPECT_RATIO = 2 / 3 # 512*768
3
+ ASPECT_RATIO = 9 / 16 # 576*1024
4
+
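Illustration of the ratio (not the repository's own resize code): with the 576-pixel width noted in the comments above, a 9/16 w/h ratio yields 576x1024 portrait frames.

    ASPECT_RATIO = 9 / 16                  # width / height
    width = 576                            # "resolution" in config/test.yaml
    height = round(width / ASPECT_RATIO)   # 1024 -> 576x1024 portrait frames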
data/anchor/1.jpg ADDED
data/anchor/2.jpg ADDED
data/anchor/3.jpg ADDED

Git LFS Details

  • SHA256: 2dcb46014ee99f97e09ac8a7ba55178ea7c53598d6d091c7bf7f264c8d009bf2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.99 MB
data/anchor/4.jpg ADDED

Git LFS Details

  • SHA256: 7a2b0baaec289c488e8735482e2a4f6d34a7dfb80e476a5d81a8dd51cfd885ac
  • Pointer size: 131 Bytes
  • Size of remote file: 539 kB
data/anchor/5.jpg ADDED

Git LFS Details

  • SHA256: 316c3450e6152bf8131c1d84ddd7df0e1153e2137d3f46b2814e71f6e8f9618f
  • Pointer size: 131 Bytes
  • Size of remote file: 357 kB
data/depth_cut/cheese_1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ef63934b250bbb51d83a241ddc8cd03d2f3dd61b06a53cf5c3eec7d28bfea393
3
+ size 1832420
data/depth_cut/cheese_2.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:173c9cea9609806de733c000405d8380cb99dcf8ca0c0b25726ae308d81bd8e7
3
+ size 1718113
data/depth_cut/cup_1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ad2e78eb58bcc21e6ee7c74d272c6053641269ea88e046f6e0c8fee63dee1ff
3
+ size 1802707
data/depth_cut/cup_2.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03c2c1206c62381f4fa09e635f5d63bb834a9c007d07f79825b567da46e786f5
3
+ size 2013536
data/depth_cut/earphone_1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe45d51fe485ff8402bb37fb62e4c32fc892644b7828e70056d92336e9400de3
3
+ size 1075588
data/depth_cut/earphone_2.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de88bc15419cf5813eec169fcd950a70637b85c6d7616e3ff5f1d1797f3c9681
3
+ size 1528559
data/depth_cut/hmbb_1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:00ccdb08c58da95ef5aa8e1ad97583f6e2daec32a13921ce7664b6e1e44aa49e
3
+ size 2361117
data/depth_cut/hmbb_2.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bbc4db207fb5c163ab99594eb93499a70e1858c5e4f7d7133c663ac5b05c0704
3
+ size 4774735
data/depth_cut/mouse_1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b1cac2f4ee52e73fafbf14cc9411be564750978343ecfa175b2c59d70f592905
3
+ size 1618354
data/depth_cut/mouse_2.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5df30727ec4b8d1d0ea8532e95a72807ef2174107026e6f063182941acadc44b
3
+ size 1384678
data/hand_cut/cheese_1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e502009545f2ed2fdf0815cd96a6c4354aa8100e5ac5cfe84323c3b56533f661
3
+ size 2962933
data/hand_cut/cheese_2.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60b16aa66a5b6d79ad5722e6570aa5379c2be54cdb47c5cec264568931b94a25
3
+ size 2668692
data/hand_cut/cup_1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:729a51c840377d3634e008946eca0213d88c5ce388b76b7863cc5e6f6594946a
3
+ size 2412734
data/hand_cut/cup_2.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9cb95d1d1001237da04eac68c5439a5a00da548097016f975536271028f10dc7
3
+ size 2707637
data/hand_cut/earphone_1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d576867e34f29a55a90ed9aa0e7027530ec6c6b39aed28e681f719ad62169df9
3
+ size 1969306
data/hand_cut/earphone_2.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a97a4b66693b5562d99b2d9dec078aad7e362b6b92bd9623f898b787547b51bb
3
+ size 3502771
data/hand_cut/hmbb_1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2d5d40a92e5bb62ea8a3240bbb2e17b01e23b63b5047891f14ba729eeeb28b9b
3
+ size 1875012
data/hand_cut/hmbb_2.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:84150bd20ec8f19a31cde8b4c814a938aaf48140635ae32496fc7f27be2a3834
3
+ size 3162317
data/hand_cut/mouse_1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:816f0525711efe17bced6554efb32ebc28924cf1710546e2af39382fe3d915e8
3
+ size 2183209