DeepBeepMeep committed on
Commit
562fe58
·
1 Parent(s): c309a1f

Added Vace Inpainting Support and Create a Mask inside WanGP

Files changed (49)
  1. preprocessing/matanyone/__init__.py +0 -0
  2. preprocessing/matanyone/app.py +656 -0
  3. preprocessing/matanyone/matanyone/config/__init__.py +0 -0
  4. preprocessing/matanyone/matanyone/config/eval_matanyone_config.yaml +47 -0
  5. preprocessing/matanyone/matanyone/config/hydra/job_logging/custom-no-rank.yaml +22 -0
  6. preprocessing/matanyone/matanyone/config/hydra/job_logging/custom.yaml +22 -0
  7. preprocessing/matanyone/matanyone/config/model/base.yaml +58 -0
  8. preprocessing/matanyone/matanyone/inference/__init__.py +0 -0
  9. preprocessing/matanyone/matanyone/inference/image_feature_store.py +56 -0
  10. preprocessing/matanyone/matanyone/inference/inference_core.py +406 -0
  11. preprocessing/matanyone/matanyone/inference/kv_memory_store.py +348 -0
  12. preprocessing/matanyone/matanyone/inference/memory_manager.py +453 -0
  13. preprocessing/matanyone/matanyone/inference/object_info.py +24 -0
  14. preprocessing/matanyone/matanyone/inference/object_manager.py +149 -0
  15. preprocessing/matanyone/matanyone/inference/utils/__init__.py +0 -0
  16. preprocessing/matanyone/matanyone/inference/utils/args_utils.py +30 -0
  17. preprocessing/matanyone/matanyone/model/__init__.py +0 -0
  18. preprocessing/matanyone/matanyone/model/aux_modules.py +93 -0
  19. preprocessing/matanyone/matanyone/model/big_modules.py +365 -0
  20. preprocessing/matanyone/matanyone/model/channel_attn.py +39 -0
  21. preprocessing/matanyone/matanyone/model/group_modules.py +126 -0
  22. preprocessing/matanyone/matanyone/model/matanyone.py +333 -0
  23. preprocessing/matanyone/matanyone/model/modules.py +149 -0
  24. preprocessing/matanyone/matanyone/model/transformer/__init__.py +0 -0
  25. preprocessing/matanyone/matanyone/model/transformer/object_summarizer.py +89 -0
  26. preprocessing/matanyone/matanyone/model/transformer/object_transformer.py +206 -0
  27. preprocessing/matanyone/matanyone/model/transformer/positional_encoding.py +108 -0
  28. preprocessing/matanyone/matanyone/model/transformer/transformer_layers.py +161 -0
  29. preprocessing/matanyone/matanyone/model/utils/__init__.py +0 -0
  30. preprocessing/matanyone/matanyone/model/utils/memory_utils.py +107 -0
  31. preprocessing/matanyone/matanyone/model/utils/parameter_groups.py +72 -0
  32. preprocessing/matanyone/matanyone/model/utils/resnet.py +179 -0
  33. preprocessing/matanyone/matanyone_wrapper.py +73 -0
  34. preprocessing/matanyone/tools/__init__.py +0 -0
  35. preprocessing/matanyone/tools/base_segmenter.py +141 -0
  36. preprocessing/matanyone/tools/download_util.py +109 -0
  37. preprocessing/matanyone/tools/interact_tools.py +99 -0
  38. preprocessing/matanyone/tools/mask_painter.py +288 -0
  39. preprocessing/matanyone/tools/misc.py +131 -0
  40. preprocessing/matanyone/tools/painter.py +215 -0
  41. preprocessing/matanyone/utils/__init__.py +0 -0
  42. preprocessing/matanyone/utils/get_default_model.py +27 -0
  43. preprocessing/matanyone/utils/tensor_utils.py +62 -0
  44. requirements.txt +2 -0
  45. wan/modules/model.py +2 -47
  46. wan/text2video.py +6 -1
  47. wan/utils/utils.py +2 -3
  48. wan/utils/vace_preprocessor.py +1 -1
  49. wgp.py +122 -47
preprocessing/matanyone/__init__.py ADDED
File without changes
preprocessing/matanyone/app.py ADDED
@@ -0,0 +1,656 @@
1
+ import sys
2
+
3
+ import os
4
+ import json
5
+ import time
6
+ import psutil
7
+ import ffmpeg
8
+ import imageio
9
+ from PIL import Image
10
+
11
+ import cv2
12
+ import torch
13
+ import numpy as np
14
+ import gradio as gr
15
+ from .tools.painter import mask_painter
16
+ from .tools.interact_tools import SamControler
17
+ from .tools.misc import get_device
18
+ from .tools.download_util import load_file_from_url
19
+
20
+ from .utils.get_default_model import get_matanyone_model
21
+ from .matanyone.inference.inference_core import InferenceCore
22
+ from .matanyone_wrapper import matanyone
23
+
24
+ arg_device = "cuda"
25
+ arg_sam_model_type="vit_h"
26
+ arg_mask_save = False
27
+ model = None
28
+ matanyone_model = None
29
+
30
+ # SAM generator
31
+ class MaskGenerator():
32
+ def __init__(self, sam_checkpoint, device):
33
+ global arg_device
34
+ arg_device = device
35
+ self.samcontroler = SamControler(sam_checkpoint, arg_sam_model_type, arg_device)
36
+
37
+ def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
38
+ mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
39
+ return mask, logit, painted_image
40
+
41
+ # convert points input to prompt state
42
+ def get_prompt(click_state, click_input):
43
+ inputs = json.loads(click_input)
44
+ points = click_state[0]
45
+ labels = click_state[1]
46
+ for input in inputs:
47
+ points.append(input[:2])
48
+ labels.append(input[2])
49
+ click_state[0] = points
50
+ click_state[1] = labels
51
+ prompt = {
52
+ "prompt_type":["click"],
53
+ "input_point":click_state[0],
54
+ "input_label":click_state[1],
55
+ "multimask_output":"True",
56
+ }
57
+ return prompt
58
+
59
+ def get_frames_from_image(image_input, image_state):
60
+ """
61
+ Args:
62
+ image_input: np.ndarray, the uploaded image
63
+ image_state: dict, the UI state to (re)initialize
64
+ Return
65
+ image_state, image_info, the first frame and the Gradio component updates
66
+ """
67
+
68
+ user_name = time.time()
69
+ frames = [image_input] * 2 # hardcode: mimic a video with 2 frames
70
+ image_size = (frames[0].shape[0],frames[0].shape[1])
71
+ # initialize video_state
72
+ image_state = {
73
+ "user_name": user_name,
74
+ "image_name": "output.png",
75
+ "origin_images": frames,
76
+ "painted_images": frames.copy(),
77
+ "masks": [np.zeros((frames[0].shape[0],frames[0].shape[1]), np.uint8)]*len(frames),
78
+ "logits": [None]*len(frames),
79
+ "select_frame_number": 0,
80
+ "last_frame_numer": 0,
81
+ "fps": None
82
+ }
83
+ image_info = "Image Name: N/A,\nFPS: N/A,\nTotal Frames: {},\nImage Size:{}".format(len(frames), image_size)
84
+ model.samcontroler.sam_controler.reset_image()
85
+ model.samcontroler.sam_controler.set_image(image_state["origin_images"][0])
86
+ return image_state, image_info, image_state["origin_images"][0], \
87
+ gr.update(visible=True, maximum=10, value=10), gr.update(visible=True, maximum=len(frames), value=len(frames)), gr.update(visible=False, maximum=len(frames), value=len(frames)), \
88
+ gr.update(visible=True), gr.update(visible=True), \
89
+ gr.update(visible=True), gr.update(visible=True),\
90
+ gr.update(visible=True), gr.update(visible=True), \
91
+ gr.update(visible=True), gr.update(visible=False), \
92
+ gr.update(visible=False), gr.update(visible=True), \
93
+ gr.update(visible=True)
94
+
95
+ # extract frames from upload video
96
+ def get_frames_from_video(video_input, video_state):
97
+ """
98
+ Args:
99
+ video_input: str, path of the uploaded video
100
+ video_state: dict, the UI state to (re)initialize
101
+ Return
102
+ video_state, video_info, the first frame and the Gradio component updates
103
+ """
104
+
105
+ while model == None:
106
+ time.sleep(1)
107
+
108
+ video_path = video_input
109
+ frames = []
110
+ user_name = time.time()
111
+
112
+ # extract Audio
113
+ # try:
114
+ # audio_path = video_input.replace(".mp4", "_audio.wav")
115
+ # ffmpeg.input(video_path).output(audio_path, format='wav', acodec='pcm_s16le', ac=2, ar='44100').run(overwrite_output=True, quiet=True)
116
+ # except Exception as e:
117
+ # print(f"Audio extraction error: {str(e)}")
118
+ # audio_path = "" # Set to "" if extraction fails
119
+ # print(f'audio_path: {audio_path}')
120
+ audio_path = ""
121
+ # extract frames
122
+ try:
123
+ cap = cv2.VideoCapture(video_path)
124
+ fps = cap.get(cv2.CAP_PROP_FPS)
125
+ while cap.isOpened():
126
+ ret, frame = cap.read()
127
+ if ret == True:
128
+ current_memory_usage = psutil.virtual_memory().percent
129
+ frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
130
+ if current_memory_usage > 90:
131
+ break
132
+ else:
133
+ break
134
+ except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
135
+ print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
136
+ image_size = (frames[0].shape[0],frames[0].shape[1])
137
+
138
+ # resize if resolution too big
139
+ if image_size[0]>=1280 and image_size[1]>=1280:
140
+ scale = 1080 / min(image_size)
141
+ new_w = int(image_size[1] * scale)
142
+ new_h = int(image_size[0] * scale)
143
+ # update frames
144
+ frames = [cv2.resize(f, (new_w, new_h), interpolation=cv2.INTER_AREA) for f in frames]
145
+ # update image_size
146
+ image_size = (frames[0].shape[0],frames[0].shape[1])
147
+
148
+ # initialize video_state
149
+ video_state = {
150
+ "user_name": user_name,
151
+ "video_name": os.path.split(video_path)[-1],
152
+ "origin_images": frames,
153
+ "painted_images": frames.copy(),
154
+ "masks": [np.zeros((frames[0].shape[0],frames[0].shape[1]), np.uint8)]*len(frames),
155
+ "logits": [None]*len(frames),
156
+ "select_frame_number": 0,
157
+ "last_frame_number": 0,
158
+ "fps": fps,
159
+ "audio": audio_path
160
+ }
161
+ video_info = "Video Name: {},\nFPS: {},\nTotal Frames: {},\nImage Size:{}".format(video_state["video_name"], round(video_state["fps"], 0), len(frames), image_size)
162
+ model.samcontroler.sam_controler.reset_image()
163
+ model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
164
+ return video_state, video_info, video_state["origin_images"][0], \
165
+ gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), gr.update(visible=False, maximum=len(frames), value=len(frames)), \
166
+ gr.update(visible=True), gr.update(visible=True), \
167
+ gr.update(visible=True), gr.update(visible=True),\
168
+ gr.update(visible=True), gr.update(visible=True), \
169
+ gr.update(visible=True), gr.update(visible=False), \
170
+ gr.update(visible=False), gr.update(visible=True), \
171
+ gr.update(visible=True)
172
+
173
+ # get the select frame from gradio slider
174
+ def select_video_template(image_selection_slider, video_state, interactive_state):
175
+
176
+ image_selection_slider -= 1
177
+ video_state["select_frame_number"] = image_selection_slider
178
+
179
+ # once select a new template frame, set the image in sam
180
+ model.samcontroler.sam_controler.reset_image()
181
+ model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
182
+
183
+ return video_state["painted_images"][image_selection_slider], video_state, interactive_state
184
+
185
+ def select_image_template(image_selection_slider, video_state, interactive_state):
186
+
187
+ image_selection_slider = 0 # fixed for image
188
+ video_state["select_frame_number"] = image_selection_slider
189
+
190
+ # once select a new template frame, set the image in sam
191
+ model.samcontroler.sam_controler.reset_image()
192
+ model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
193
+
194
+ return video_state["painted_images"][image_selection_slider], video_state, interactive_state
195
+
196
+ # set the tracking end frame
197
+ def get_end_number(track_pause_number_slider, video_state, interactive_state):
198
+ interactive_state["track_end_number"] = track_pause_number_slider
199
+
200
+ return video_state["painted_images"][track_pause_number_slider],interactive_state
201
+
202
+ # use sam to get the mask
203
+ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData ): #
204
+ """
205
+ Args:
206
+ template_frame: PIL.Image
207
+ point_prompt: flag for positive or negative button click
208
+ click_state: [[points], [labels]]
209
+ """
210
+ if point_prompt == "Positive":
211
+ coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
212
+ interactive_state["positive_click_times"] += 1
213
+ else:
214
+ coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
215
+ interactive_state["negative_click_times"] += 1
216
+
217
+ # prompt for sam model
218
+ model.samcontroler.sam_controler.reset_image()
219
+ model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]])
220
+ prompt = get_prompt(click_state=click_state, click_input=coordinate)
221
+
222
+ mask, logit, painted_image = model.first_frame_click(
223
+ image=video_state["origin_images"][video_state["select_frame_number"]],
224
+ points=np.array(prompt["input_point"]),
225
+ labels=np.array(prompt["input_label"]),
226
+ multimask=prompt["multimask_output"],
227
+ )
228
+ video_state["masks"][video_state["select_frame_number"]] = mask
229
+ video_state["logits"][video_state["select_frame_number"]] = logit
230
+ video_state["painted_images"][video_state["select_frame_number"]] = painted_image
231
+
232
+ return painted_image, video_state, interactive_state
233
+
234
+ def add_multi_mask(video_state, interactive_state, mask_dropdown):
235
+ mask = video_state["masks"][video_state["select_frame_number"]]
236
+ interactive_state["multi_mask"]["masks"].append(mask)
237
+ interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
238
+ mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
239
+ select_frame = show_mask(video_state, interactive_state, mask_dropdown)
240
+
241
+ return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]]
242
+
243
+ def clear_click(video_state, click_state):
244
+ click_state = [[],[]]
245
+ template_frame = video_state["origin_images"][video_state["select_frame_number"]]
246
+ return template_frame, click_state
247
+
248
+ def remove_multi_mask(interactive_state, mask_dropdown):
249
+ interactive_state["multi_mask"]["mask_names"]= []
250
+ interactive_state["multi_mask"]["masks"] = []
251
+
252
+ return interactive_state, gr.update(choices=[],value=[])
253
+
254
+ def show_mask(video_state, interactive_state, mask_dropdown):
255
+ mask_dropdown.sort()
256
+ if video_state["origin_images"]:
257
+ select_frame = video_state["origin_images"][video_state["select_frame_number"]]
258
+ for i in range(len(mask_dropdown)):
259
+ mask_number = int(mask_dropdown[i].split("_")[1]) - 1
260
+ mask = interactive_state["multi_mask"]["masks"][mask_number]
261
+ select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
262
+
263
+ return select_frame
264
+
265
+
266
+ def save_video(frames, output_path, fps):
267
+
268
+ writer = imageio.get_writer( output_path, fps=fps, codec='libx264', quality=8)
269
+ for frame in frames:
270
+ writer.append_data(frame)
271
+ writer.close()
272
+
273
+ return output_path
274
+
275
+ # video matting
276
+ def video_matting(video_state, end_slider, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size):
277
+ matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
278
+ # if interactive_state["track_end_number"]:
279
+ # following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
280
+ # else:
281
+ end_slider = max(video_state["select_frame_number"] +1, end_slider)
282
+ following_frames = video_state["origin_images"][video_state["select_frame_number"]: end_slider]
283
+
284
+ if interactive_state["multi_mask"]["masks"]:
285
+ if len(mask_dropdown) == 0:
286
+ mask_dropdown = ["mask_001"]
287
+ mask_dropdown.sort()
288
+ template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1]))
289
+ for i in range(1,len(mask_dropdown)):
290
+ mask_number = int(mask_dropdown[i].split("_")[1]) - 1
291
+ template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1)
292
+ video_state["masks"][video_state["select_frame_number"]]= template_mask
293
+ else:
294
+ template_mask = video_state["masks"][video_state["select_frame_number"]]
295
+ fps = video_state["fps"]
296
+
297
+ audio_path = video_state["audio"]
298
+
299
+ # operation error
300
+ if len(np.unique(template_mask))==1:
301
+ template_mask[0][0]=1
302
+ foreground, alpha = matanyone(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size)
303
+ output_frames = []
304
+ for frame_origin, frame_alpha in zip(following_frames, alpha):
305
+ frame_alpha[frame_alpha > 127] = 255
306
+ frame_alpha[frame_alpha <= 127] = 0
307
+ output_frame = np.bitwise_and(frame_origin, 255-frame_alpha)
308
+ frame_grey = frame_alpha.copy()
309
+ frame_grey[frame_alpha == 255] = 127
310
+ output_frame += frame_grey
311
+ output_frames.append(output_frame)
312
+ foreground = output_frames
313
+
314
+ foreground_output = save_video(foreground, output_path="./results/{}_fg.mp4".format(video_state["video_name"]), fps=fps)
315
+ # foreground_output = generate_video_from_frames(foreground, output_path="./results/{}_fg.mp4".format(video_state["video_name"]), fps=fps, audio_path=audio_path) # import video_input to name the output video
316
+ alpha_output = save_video(alpha, output_path="./results/{}_alpha.mp4".format(video_state["video_name"]), fps=fps)
317
+ # alpha_output = generate_video_from_frames(alpha, output_path="./results/{}_alpha.mp4".format(video_state["video_name"]), fps=fps, gray2rgb=True, audio_path=audio_path) # import video_input to name the output video
318
+
319
+ return foreground_output, alpha_output
320
+
321
+
322
+ def add_audio_to_video(video_path, audio_path, output_path):
323
+ try:
324
+ video_input = ffmpeg.input(video_path)
325
+ audio_input = ffmpeg.input(audio_path)
326
+
327
+ _ = (
328
+ ffmpeg
329
+ .output(video_input, audio_input, output_path, vcodec="copy", acodec="aac")
330
+ .run(overwrite_output=True, capture_stdout=True, capture_stderr=True)
331
+ )
332
+ return output_path
333
+ except ffmpeg.Error as e:
334
+ print(f"FFmpeg error:\n{e.stderr.decode()}")
335
+ return None
336
+
337
+
338
+ def generate_video_from_frames(frames, output_path, fps=30, gray2rgb=False, audio_path=""):
339
+ """
340
+ Generates a video from a list of frames.
341
+
342
+ Args:
343
+ frames (list of numpy arrays): The frames to include in the video.
344
+ output_path (str): The path to save the generated video.
345
+ fps (int, optional): The frame rate of the output video. Defaults to 30.
346
+ """
347
+ frames = torch.from_numpy(np.asarray(frames))
348
+ _, h, w, _ = frames.shape
349
+ if gray2rgb:
350
+ frames = np.repeat(frames, 3, axis=3)
351
+
352
+ if not os.path.exists(os.path.dirname(output_path)):
353
+ os.makedirs(os.path.dirname(output_path))
354
+ video_temp_path = output_path.replace(".mp4", "_temp.mp4")
355
+
356
+ # resize back to ensure input resolution
357
+ imageio.mimwrite(video_temp_path, frames, fps=fps, quality=7,
358
+ codec='libx264', ffmpeg_params=["-vf", f"scale={w}:{h}"])
359
+
360
+ # add audio to video if audio path exists
361
+ if audio_path != "" and os.path.exists(audio_path):
362
+ output_path = add_audio_to_video(video_temp_path, audio_path, output_path)
363
+ os.remove(video_temp_path)
364
+ return output_path
365
+ else:
366
+ return video_temp_path
367
+
368
+ # reset all states for a new input
369
+ def restart():
370
+ return {
371
+ "user_name": "",
372
+ "video_name": "",
373
+ "origin_images": None,
374
+ "painted_images": None,
375
+ "masks": None,
376
+ "inpaint_masks": None,
377
+ "logits": None,
378
+ "select_frame_number": 0,
379
+ "fps": 30
380
+ }, {
381
+ "inference_times": 0,
382
+ "negative_click_times" : 0,
383
+ "positive_click_times": 0,
384
+ "mask_save": arg_mask_save,
385
+ "multi_mask": {
386
+ "mask_names": [],
387
+ "masks": []
388
+ },
389
+ "track_end_number": None,
390
+ }, [[],[]], None, None, \
391
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),\
392
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
393
+ gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
394
+ gr.update(visible=False), gr.update(visible=False, choices=[], value=[]), "", gr.update(visible=False)
395
+
396
+ def load_unload_models(selected):
397
+ global model
398
+ global matanyone_model
399
+ if selected:
400
+ # args, defined in track_anything.py
401
+ sam_checkpoint_url_dict = {
402
+ 'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
403
+ 'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
404
+ 'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
405
+ }
406
+ # os.path.join('.')
407
+
408
+ from mmgp import offload
409
+
410
+ # sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[arg_sam_model_type], ".")
411
+ sam_checkpoint = None
412
+
413
+ transfer_stream = torch.cuda.Stream()
414
+ with torch.cuda.stream(transfer_stream):
415
+ # initialize sams
416
+ model = MaskGenerator(sam_checkpoint, "cuda")
417
+ from .matanyone.model.matanyone import MatAnyone
418
+ matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone")
419
+ # pipe ={"mat" : matanyone_model, "sam" :model.samcontroler.sam_controler.model }
420
+ # offload.profile(pipe)
421
+ matanyone_model = matanyone_model.to(arg_device).eval()
422
+ matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
423
+ else:
424
+ import gc
425
+ model = None
426
+ matanyone_model = None
427
+ gc.collect()
428
+ torch.cuda.empty_cache()
429
+
430
+
431
+ def get_vmc_event_handler():
432
+ return load_unload_models
433
+
434
+ def export_to_vace_video_input(foreground_video_output):
435
+ gr.Info("Masked Video Input transferred to Vace For Inpainting")
436
+ return "V#" + str(time.time()), foreground_video_output
437
+
438
+ def export_to_vace_video_mask(foreground_video_output, alpha_video_output):
439
+ gr.Info("Masked Video Input and Full Mask transferred to Vace For Stronger Inpainting")
440
+ return "MV#" + str(time.time()), foreground_video_output, alpha_video_output
441
+
442
+ def display(vace_video_input, vace_video_mask, video_prompt_video_guide_trigger):
443
+ # my_tab.select(fn=load_unload_models, inputs=[], outputs=[])
444
+
445
+ media_url = "https://github.com/pq-yang/MatAnyone/releases/download/media/"
446
+
447
+ # download assets
448
+
449
+ gr.Markdown("Mask editing is provided by MatAnyone")
450
+
451
+ with gr.Column( visible=True):
452
+ with gr.Row():
453
+ with gr.Accordion("Video Tutorial (click to expand)", open=False, elem_classes="custom-bg"):
454
+ with gr.Row():
455
+ with gr.Column():
456
+ gr.Markdown("### Case 1: Single Target")
457
+ gr.Video(value="preprocessing/matanyone/tutorial_single_target.mp4", elem_classes="video")
458
+
459
+ with gr.Column():
460
+ gr.Markdown("### Case 2: Multiple Targets")
461
+ gr.Video(value="preprocessing/matanyone/tutorial_multi_targets.mp4", elem_classes="video")
462
+
463
+
464
+ click_state = gr.State([[],[]])
465
+
466
+ interactive_state = gr.State({
467
+ "inference_times": 0,
468
+ "negative_click_times" : 0,
469
+ "positive_click_times": 0,
470
+ "mask_save": arg_mask_save,
471
+ "multi_mask": {
472
+ "mask_names": [],
473
+ "masks": []
474
+ },
475
+ "track_end_number": None,
476
+ }
477
+ )
478
+
479
+ video_state = gr.State(
480
+ {
481
+ "user_name": "",
482
+ "video_name": "",
483
+ "origin_images": None,
484
+ "painted_images": None,
485
+ "masks": None,
486
+ "inpaint_masks": None,
487
+ "logits": None,
488
+ "select_frame_number": 0,
489
+ "fps": 16,
490
+ "audio": "",
491
+ }
492
+ )
493
+
494
+ with gr.Column( visible=True):
495
+ with gr.Row():
496
+ with gr.Accordion('MatAnyone Settings (click to expand)', open=False):
497
+ with gr.Row():
498
+ erode_kernel_size = gr.Slider(label='Erode Kernel Size',
499
+ minimum=0,
500
+ maximum=30,
501
+ step=1,
502
+ value=10,
503
+ info="Erosion on the added mask",
504
+ interactive=True)
505
+ dilate_kernel_size = gr.Slider(label='Dilate Kernel Size',
506
+ minimum=0,
507
+ maximum=30,
508
+ step=1,
509
+ value=10,
510
+ info="Dilation on the added mask",
511
+ interactive=True)
512
+
513
+ with gr.Row():
514
+ image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Start Frame", info="Choose the start frame for target assignment and video matting", visible=False)
515
+ end_selection_slider = gr.Slider(minimum=1, maximum=300, step=1, value=81, label="Last Frame to Process", info="Last Frame to Process", visible=False)
516
+
517
+ track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="End frame", visible=False)
518
+ with gr.Row():
519
+ point_prompt = gr.Radio(
520
+ choices=["Positive", "Negative"],
521
+ value="Positive",
522
+ label="Point Prompt",
523
+ info="Click to add positive or negative point for target mask",
524
+ interactive=True,
525
+ visible=False,
526
+ min_width=100,
527
+ scale=1)
528
+ mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask Selection", info="Choose 1~all mask(s) added in Step 2", visible=False)
529
+
530
+ gr.Markdown("---")
531
+
532
+ with gr.Column():
533
+ # input video
534
+ with gr.Row(equal_height=True):
535
+ with gr.Column(scale=2):
536
+ gr.Markdown("## Step1: Upload video")
537
+ with gr.Column(scale=2):
538
+ step2_title = gr.Markdown("## Step2: Add masks <small>(Several clicks then **`Add Mask`** <u>one by one</u>)</small>", visible=False)
539
+ with gr.Row(equal_height=True):
540
+ with gr.Column(scale=2):
541
+ video_input = gr.Video(label="Input Video", elem_classes="video")
542
+ extract_frames_button = gr.Button(value="Load Video", interactive=True, elem_classes="new_button")
543
+ with gr.Column(scale=2):
544
+ video_info = gr.Textbox(label="Video Info", visible=False)
545
+ template_frame = gr.Image(label="Start Frame", type="pil",interactive=True, elem_id="template_frame", visible=False, elem_classes="image")
546
+ with gr.Row():
547
+ clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False, min_width=100)
548
+ add_mask_button = gr.Button(value="Add Mask", interactive=True, visible=False, min_width=100)
549
+ remove_mask_button = gr.Button(value="Remove Mask", interactive=True, visible=False, min_width=100) # no use
550
+ matting_button = gr.Button(value="Video Matting", interactive=True, visible=False, min_width=100)
551
+ with gr.Row():
552
+ gr.Markdown("")
553
+
554
+ # output video
555
+ with gr.Row(equal_height=True) as output_row:
556
+ with gr.Column(scale=2):
557
+ foreground_video_output = gr.Video(label="Masked Video Output", visible=False, elem_classes="video")
558
+ foreground_output_button = gr.Button(value="Black & White Video Output", visible=False, elem_classes="new_button")
559
+ export_to_vace_video_input_btn = gr.Button("Export to Vace Video Input Video For Inpainting")
560
+ with gr.Column(scale=2):
561
+ alpha_video_output = gr.Video(label="B & W Mask Video Output", visible=False, elem_classes="video")
562
+ alpha_output_button = gr.Button(value="Alpha Mask Output", visible=False, elem_classes="new_button")
563
+ export_to_vace_video_mask_btn = gr.Button("Export to Vace Video Input and Video Mask for stronger Inpainting")
564
+
565
+ export_to_vace_video_input_btn.click(fn=export_to_vace_video_input, inputs= [foreground_video_output], outputs= [video_prompt_video_guide_trigger, vace_video_input])
566
+ export_to_vace_video_mask_btn.click(fn=export_to_vace_video_mask, inputs= [foreground_video_output, alpha_video_output], outputs= [video_prompt_video_guide_trigger, vace_video_input, vace_video_mask])
567
+ # first step: get the video information
568
+ extract_frames_button.click(
569
+ fn=get_frames_from_video,
570
+ inputs=[
571
+ video_input, video_state
572
+ ],
573
+ outputs=[video_state, video_info, template_frame,
574
+ image_selection_slider, end_selection_slider, track_pause_number_slider, point_prompt, clear_button_click, add_mask_button, matting_button, template_frame,
575
+ foreground_video_output, alpha_video_output, foreground_output_button, alpha_output_button, mask_dropdown, step2_title]
576
+ )
577
+
578
+ # second step: select images from slider
579
+ image_selection_slider.release(fn=select_video_template,
580
+ inputs=[image_selection_slider, video_state, interactive_state],
581
+ outputs=[template_frame, video_state, interactive_state], api_name="select_image")
582
+ track_pause_number_slider.release(fn=get_end_number,
583
+ inputs=[track_pause_number_slider, video_state, interactive_state],
584
+ outputs=[template_frame, interactive_state], api_name="end_image")
585
+
586
+ # click select image to get mask using sam
587
+ template_frame.select(
588
+ fn=sam_refine,
589
+ inputs=[video_state, point_prompt, click_state, interactive_state],
590
+ outputs=[template_frame, video_state, interactive_state]
591
+ )
592
+
593
+ # add different mask
594
+ add_mask_button.click(
595
+ fn=add_multi_mask,
596
+ inputs=[video_state, interactive_state, mask_dropdown],
597
+ outputs=[interactive_state, mask_dropdown, template_frame, click_state]
598
+ )
599
+
600
+ remove_mask_button.click(
601
+ fn=remove_multi_mask,
602
+ inputs=[interactive_state, mask_dropdown],
603
+ outputs=[interactive_state, mask_dropdown]
604
+ )
605
+
606
+ # video matting
607
+ matting_button.click(
608
+ fn=video_matting,
609
+ inputs=[video_state, end_selection_slider, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size],
610
+ outputs=[foreground_video_output, alpha_video_output]
611
+ )
612
+
613
+ # click to get mask
614
+ mask_dropdown.change(
615
+ fn=show_mask,
616
+ inputs=[video_state, interactive_state, mask_dropdown],
617
+ outputs=[template_frame]
618
+ )
619
+
620
+ # clear input
621
+ video_input.change(
622
+ fn=restart,
623
+ inputs=[],
624
+ outputs=[
625
+ video_state,
626
+ interactive_state,
627
+ click_state,
628
+ foreground_video_output, alpha_video_output,
629
+ template_frame,
630
+ image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click,
631
+ add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, video_info, step2_title
632
+ ],
633
+ queue=False,
634
+ show_progress=False)
635
+
636
+ video_input.clear(
637
+ fn=restart,
638
+ inputs=[],
639
+ outputs=[
640
+ video_state,
641
+ interactive_state,
642
+ click_state,
643
+ foreground_video_output, alpha_video_output,
644
+ template_frame,
645
+ image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click,
646
+ add_mask_button, matting_button, template_frame, foreground_video_output, alpha_video_output, remove_mask_button, foreground_output_button, alpha_output_button, mask_dropdown, video_info, step2_title
647
+ ],
648
+ queue=False,
649
+ show_progress=False)
650
+
651
+ # points clear
652
+ clear_button_click.click(
653
+ fn = clear_click,
654
+ inputs = [video_state, click_state,],
655
+ outputs = [template_frame,click_state],
656
+ )
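A minimal sketch of how a host Gradio app (in this commit, wgp.py) can mount the tab defined above; only `display` and `get_vmc_event_handler` come from this file, and the host-side component names are assumptions.

    import gradio as gr
    from preprocessing.matanyone.app import display, get_vmc_event_handler

    load_unload_models = get_vmc_event_handler()  # returns load_unload_models(selected)

    with gr.Blocks() as demo:
        # hypothetical host-side components that the matting tab writes back into
        video_prompt_trigger = gr.Text(visible=False)
        vace_video_input = gr.Video(label="Vace Video Input")
        vace_video_mask = gr.Video(label="Vace Video Mask")
        with gr.Tab("Create a Mask"):
            display(vace_video_input, vace_video_mask, video_prompt_trigger)

    load_unload_models(True)  # load SAM + MatAnyone (requires CUDA) before serving
    demo.launch()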
preprocessing/matanyone/matanyone/config/__init__.py ADDED
File without changes
preprocessing/matanyone/matanyone/config/eval_matanyone_config.yaml ADDED
@@ -0,0 +1,47 @@
1
+ defaults:
2
+ - _self_
3
+ - model: base
4
+ - override hydra/job_logging: custom-no-rank.yaml
5
+
6
+ hydra:
7
+ run:
8
+ dir: ../output/${exp_id}/${dataset}
9
+ output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra
10
+
11
+ amp: False
12
+ weights: pretrained_models/matanyone.pth # default (can be modified from outside)
13
+ output_dir: null # defaults to run_dir; specify this to override
14
+ flip_aug: False
15
+
16
+
17
+ # maximum shortest side of the input; -1 means no resizing
18
+ # With eval_vos.py, we usually just use the dataset's size (resizing done in dataloader)
19
+ # this parameter is added for the sole purpose for the GUI in the current codebase
20
+ # InferenceCore will downsize the input and restore the output to the original size if needed
21
+ # if you are using this code for some other project, you can also utilize this parameter
22
+ max_internal_size: -1
23
+
24
+ # these parameters, when set, override the dataset's default; useful for debugging
25
+ save_all: True
26
+ use_all_masks: False
27
+ use_long_term: False
28
+ mem_every: 5
29
+
30
+ # only relevant when long_term is not enabled
31
+ max_mem_frames: 5
32
+
33
+ # only relevant when long_term is enabled
34
+ long_term:
35
+ count_usage: True
36
+ max_mem_frames: 10
37
+ min_mem_frames: 5
38
+ num_prototypes: 128
39
+ max_num_tokens: 10000
40
+ buffer_tokens: 2000
41
+
42
+ top_k: 30
43
+ stagger_updates: 5
44
+ chunk_size: -1 # number of objects to process in parallel; -1 means unlimited
45
+ save_scores: False
46
+ save_aux: False
47
+ visualize: False
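For illustration only: the inference settings above can be loaded and overridden with OmegaConf before building an InferenceCore. Loading the file directly bypasses Hydra, so the defaults/hydra keys are simply carried along as plain entries.

    from omegaconf import OmegaConf

    cfg = OmegaConf.load(
        "preprocessing/matanyone/matanyone/config/eval_matanyone_config.yaml")
    print(cfg.mem_every, cfg.max_mem_frames, cfg.top_k)  # 5 5 30
    cfg.mem_every = 3  # example override before passing cfg to InferenceCore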
preprocessing/matanyone/matanyone/config/hydra/job_logging/custom-no-rank.yaml ADDED
@@ -0,0 +1,22 @@
1
+ # python logging configuration for tasks
2
+ version: 1
3
+ formatters:
4
+ simple:
5
+ format: '[%(asctime)s][%(levelname)s] - %(message)s'
6
+ datefmt: '%Y-%m-%d %H:%M:%S'
7
+ handlers:
8
+ console:
9
+ class: logging.StreamHandler
10
+ formatter: simple
11
+ stream: ext://sys.stdout
12
+ file:
13
+ class: logging.FileHandler
14
+ formatter: simple
15
+ # absolute file path
16
+ filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-eval.log
17
+ mode: w
18
+ root:
19
+ level: INFO
20
+ handlers: [console, file]
21
+
22
+ disable_existing_loggers: false
preprocessing/matanyone/matanyone/config/hydra/job_logging/custom.yaml ADDED
@@ -0,0 +1,22 @@
1
+ # python logging configuration for tasks
2
+ version: 1
3
+ formatters:
4
+ simple:
5
+ format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
6
+ datefmt: '%Y-%m-%d %H:%M:%S'
7
+ handlers:
8
+ console:
9
+ class: logging.StreamHandler
10
+ formatter: simple
11
+ stream: ext://sys.stdout
12
+ file:
13
+ class: logging.FileHandler
14
+ formatter: simple
15
+ # absolute file path
16
+ filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
17
+ mode: w
18
+ root:
19
+ level: INFO
20
+ handlers: [console, file]
21
+
22
+ disable_existing_loggers: false
preprocessing/matanyone/matanyone/config/model/base.yaml ADDED
@@ -0,0 +1,58 @@
1
+ pixel_mean: [0.485, 0.456, 0.406]
2
+ pixel_std: [0.229, 0.224, 0.225]
3
+
4
+ pixel_dim: 256
5
+ key_dim: 64
6
+ value_dim: 256
7
+ sensory_dim: 256
8
+ embed_dim: 256
9
+
10
+ pixel_encoder:
11
+ type: resnet50
12
+ ms_dims: [1024, 512, 256, 64, 3] # f16, f8, f4, f2, f1
13
+
14
+ mask_encoder:
15
+ type: resnet18
16
+ final_dim: 256
17
+
18
+ pixel_pe_scale: 32
19
+ pixel_pe_temperature: 128
20
+
21
+ object_transformer:
22
+ embed_dim: ${model.embed_dim}
23
+ ff_dim: 2048
24
+ num_heads: 8
25
+ num_blocks: 3
26
+ num_queries: 16
27
+ read_from_pixel:
28
+ input_norm: False
29
+ input_add_pe: False
30
+ add_pe_to_qkv: [True, True, False]
31
+ read_from_past:
32
+ add_pe_to_qkv: [True, True, False]
33
+ read_from_memory:
34
+ add_pe_to_qkv: [True, True, False]
35
+ read_from_query:
36
+ add_pe_to_qkv: [True, True, False]
37
+ output_norm: False
38
+ query_self_attention:
39
+ add_pe_to_qkv: [True, True, False]
40
+ pixel_self_attention:
41
+ add_pe_to_qkv: [True, True, False]
42
+
43
+ object_summarizer:
44
+ embed_dim: ${model.object_transformer.embed_dim}
45
+ num_summaries: ${model.object_transformer.num_queries}
46
+ add_pe: True
47
+
48
+ aux_loss:
49
+ sensory:
50
+ enabled: True
51
+ weight: 0.01
52
+ query:
53
+ enabled: True
54
+ weight: 0.01
55
+
56
+ mask_decoder:
57
+ # first value must equal embed_dim
58
+ up_dims: [256, 128, 128, 64, 16]
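The `${model...}` entries above are OmegaConf interpolations; they resolve once this file is composed under a `model:` group, as in this illustrative snippet.

    from omegaconf import OmegaConf

    model_cfg = OmegaConf.load("preprocessing/matanyone/matanyone/config/model/base.yaml")
    cfg = OmegaConf.create({"model": model_cfg})
    print(cfg.model.object_transformer.embed_dim)     # 256, via ${model.embed_dim}
    print(cfg.model.object_summarizer.num_summaries)  # 16, via ${model.object_transformer.num_queries}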
preprocessing/matanyone/matanyone/inference/__init__.py ADDED
File without changes
preprocessing/matanyone/matanyone/inference/image_feature_store.py ADDED
@@ -0,0 +1,56 @@
1
+ import warnings
2
+ from typing import Iterable
3
+ import torch
4
+ from ..model.matanyone import MatAnyone
5
+
6
+
7
+ class ImageFeatureStore:
8
+ """
9
+ A cache for image features.
10
+ These features might be reused at different parts of the inference pipeline.
11
+ This class provides an interface for reusing these features.
12
+ It is the user's responsibility to delete redundant features.
13
+
14
+ The features of a frame should be associated with a unique index -- typically the frame id.
15
+ """
16
+ def __init__(self, network: MatAnyone, no_warning: bool = False):
17
+ self.network = network
18
+ self._store = {}
19
+ self.no_warning = no_warning
20
+
21
+ def _encode_feature(self, index: int, image: torch.Tensor, last_feats=None) -> None:
22
+ ms_features, pix_feat = self.network.encode_image(image, last_feats=last_feats)
23
+ key, shrinkage, selection = self.network.transform_key(ms_features[0])
24
+ self._store[index] = (ms_features, pix_feat, key, shrinkage, selection)
25
+
26
+ def get_all_features(self, images: torch.Tensor) -> (Iterable[torch.Tensor], torch.Tensor):
27
+ seq_length = images.shape[0]
28
+ ms_features, pix_feat = self.network.encode_image(images, seq_length)
29
+ key, shrinkage, selection = self.network.transform_key(ms_features[0])
30
+ for index in range(seq_length):
31
+ self._store[index] = ([f[index].unsqueeze(0) for f in ms_features], pix_feat[index].unsqueeze(0), key[index].unsqueeze(0), shrinkage[index].unsqueeze(0), selection[index].unsqueeze(0))
32
+
33
+ def get_features(self, index: int,
34
+ image: torch.Tensor, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor):
35
+ if index not in self._store:
36
+ self._encode_feature(index, image, last_feats)
37
+
38
+ return self._store[index][:2]
39
+
40
+ def get_key(self, index: int,
41
+ image: torch.Tensor, last_feats=None) -> (torch.Tensor, torch.Tensor, torch.Tensor):
42
+ if index not in self._store:
43
+ self._encode_feature(index, image, last_feats)
44
+
45
+ return self._store[index][2:]
46
+
47
+ def delete(self, index: int) -> None:
48
+ if index in self._store:
49
+ del self._store[index]
50
+
51
+ def __len__(self):
52
+ return len(self._store)
53
+
54
+ def __del__(self):
55
+ if len(self._store) > 0 and not self.no_warning:
56
+ warnings.warn(f'Leaking {self._store.keys()} in the image feature store')
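A minimal usage sketch, assuming `network` is an already-loaded MatAnyone model and `device` its device; frame sizes must be multiples of 16.

    import torch
    from preprocessing.matanyone.matanyone.inference.image_feature_store import ImageFeatureStore

    store = ImageFeatureStore(network)                  # `network`: assumed MatAnyone instance
    frame = torch.rand(1, 3, 480, 832, device=device)   # 1*3*H*W, H and W divisible by 16

    ms_feat, pix_feat = store.get_features(0, frame)    # encodes once, caches under index 0
    key, shrinkage, selection = store.get_key(0, frame)  # reuses the cached encoding
    store.delete(0)                                      # caller must drop features it no longer needs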
preprocessing/matanyone/matanyone/inference/inference_core.py ADDED
@@ -0,0 +1,406 @@
1
+ from typing import List, Optional, Iterable
2
+ import logging
3
+ from omegaconf import DictConfig
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from .memory_manager import MemoryManager
10
+ from .object_manager import ObjectManager
11
+ from .image_feature_store import ImageFeatureStore
12
+ from ..model.matanyone import MatAnyone
13
+ from ...utils.tensor_utils import pad_divide_by, unpad, aggregate
14
+
15
+ log = logging.getLogger()
16
+
17
+
18
+ class InferenceCore:
19
+
20
+ def __init__(self,
21
+ network: MatAnyone,
22
+ cfg: DictConfig,
23
+ *,
24
+ image_feature_store: ImageFeatureStore = None):
25
+ self.network = network
26
+ self.cfg = cfg
27
+ self.mem_every = cfg.mem_every
28
+ stagger_updates = cfg.stagger_updates
29
+ self.chunk_size = cfg.chunk_size
30
+ self.save_aux = cfg.save_aux
31
+ self.max_internal_size = cfg.max_internal_size
32
+ self.flip_aug = cfg.flip_aug
33
+
34
+ self.curr_ti = -1
35
+ self.last_mem_ti = 0
36
+ # at which time indices should we update the sensory memory
37
+ if stagger_updates >= self.mem_every:
38
+ self.stagger_ti = set(range(1, self.mem_every + 1))
39
+ else:
40
+ self.stagger_ti = set(
41
+ np.round(np.linspace(1, self.mem_every, stagger_updates)).astype(int))
42
+ self.object_manager = ObjectManager()
43
+ self.memory = MemoryManager(cfg=cfg, object_manager=self.object_manager)
44
+
45
+ if image_feature_store is None:
46
+ self.image_feature_store = ImageFeatureStore(self.network)
47
+ else:
48
+ self.image_feature_store = image_feature_store
49
+
50
+ self.last_mask = None
51
+ self.last_pix_feat = None
52
+ self.last_msk_value = None
53
+
54
+ def clear_memory(self):
55
+ self.curr_ti = -1
56
+ self.last_mem_ti = 0
57
+ self.memory = MemoryManager(cfg=self.cfg, object_manager=self.object_manager)
58
+
59
+ def clear_non_permanent_memory(self):
60
+ self.curr_ti = -1
61
+ self.last_mem_ti = 0
62
+ self.memory.clear_non_permanent_memory()
63
+
64
+ def clear_sensory_memory(self):
65
+ self.curr_ti = -1
66
+ self.last_mem_ti = 0
67
+ self.memory.clear_sensory_memory()
68
+
69
+ def update_config(self, cfg):
70
+ self.mem_every = cfg['mem_every']
71
+ self.memory.update_config(cfg)
72
+
73
+ def clear_temp_mem(self):
74
+ self.memory.clear_work_mem()
75
+ # self.object_manager = ObjectManager()
76
+ self.memory.clear_obj_mem()
77
+ # self.memory.clear_sensory_memory()
78
+
79
+ def _add_memory(self,
80
+ image: torch.Tensor,
81
+ pix_feat: torch.Tensor,
82
+ prob: torch.Tensor,
83
+ key: torch.Tensor,
84
+ shrinkage: torch.Tensor,
85
+ selection: torch.Tensor,
86
+ *,
87
+ is_deep_update: bool = True,
88
+ force_permanent: bool = False) -> None:
89
+ """
90
+ Memorize the given segmentation in all memory stores.
91
+
92
+ The batch dimension is 1 if flip augmentation is not used.
93
+ image: RGB image, (1/2)*3*H*W
94
+ pix_feat: from the key encoder, (1/2)*_*H*W
95
+ prob: (1/2)*num_objects*H*W, in [0, 1]
96
+ key/shrinkage/selection: for anisotropic l2, (1/2)*_*H*W
97
+ selection can be None if not using long-term memory
98
+ is_deep_update: whether to use deep update (e.g. with the mask encoder)
99
+ force_permanent: whether to force the memory to be permanent
100
+ """
101
+ if prob.shape[1] == 0:
102
+ # nothing to add
103
+ log.warn('Trying to add an empty object mask to memory!')
104
+ return
105
+
106
+ if force_permanent:
107
+ as_permanent = 'all'
108
+ else:
109
+ as_permanent = 'first'
110
+
111
+ self.memory.initialize_sensory_if_needed(key, self.object_manager.all_obj_ids)
112
+ msk_value, sensory, obj_value, _ = self.network.encode_mask(
113
+ image,
114
+ pix_feat,
115
+ self.memory.get_sensory(self.object_manager.all_obj_ids),
116
+ prob,
117
+ deep_update=is_deep_update,
118
+ chunk_size=self.chunk_size,
119
+ need_weights=self.save_aux)
120
+ self.memory.add_memory(key,
121
+ shrinkage,
122
+ msk_value,
123
+ obj_value,
124
+ self.object_manager.all_obj_ids,
125
+ selection=selection,
126
+ as_permanent=as_permanent)
127
+ self.last_mem_ti = self.curr_ti
128
+ if is_deep_update:
129
+ self.memory.update_sensory(sensory, self.object_manager.all_obj_ids)
130
+ self.last_msk_value = msk_value
131
+
132
+ def _segment(self,
133
+ key: torch.Tensor,
134
+ selection: torch.Tensor,
135
+ pix_feat: torch.Tensor,
136
+ ms_features: Iterable[torch.Tensor],
137
+ update_sensory: bool = True) -> torch.Tensor:
138
+ """
139
+ Produce a segmentation using the given features and the memory
140
+
141
+ The batch dimension is 1 if flip augmentation is not used.
142
+ key/selection: for anisotropic l2: (1/2) * _ * H * W
143
+ pix_feat: from the key encoder, (1/2) * _ * H * W
144
+ ms_features: an iterable of multiscale features from the encoder, each is (1/2)*_*H*W
145
+ with strides 16, 8, and 4 respectively
146
+ update_sensory: whether to update the sensory memory
147
+
148
+ Returns: (num_objects+1)*H*W normalized probability; the first channel is the background
149
+ """
150
+ bs = key.shape[0]
151
+ if self.flip_aug:
152
+ assert bs == 2
153
+ else:
154
+ assert bs == 1
155
+
156
+ if not self.memory.engaged:
157
+ log.warn('Trying to segment without any memory!')
158
+ return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16),
159
+ device=key.device,
160
+ dtype=key.dtype)
161
+
162
+ uncert_output = None
163
+
164
+ if self.curr_ti == 0: # ONLY for the first frame for prediction
165
+ memory_readout = self.memory.read_first_frame(self.last_msk_value, pix_feat, self.last_mask, self.network, uncert_output=uncert_output)
166
+ else:
167
+ memory_readout = self.memory.read(pix_feat, key, selection, self.last_mask, self.network, uncert_output=uncert_output, last_msk_value=self.last_msk_value, ti=self.curr_ti,
168
+ last_pix_feat=self.last_pix_feat, last_pred_mask=self.last_mask)
169
+ memory_readout = self.object_manager.realize_dict(memory_readout)
170
+
171
+ sensory, _, pred_prob_with_bg = self.network.segment(ms_features,
172
+ memory_readout,
173
+ self.memory.get_sensory(
174
+ self.object_manager.all_obj_ids),
175
+ chunk_size=self.chunk_size,
176
+ update_sensory=update_sensory)
177
+ # remove batch dim
178
+ if self.flip_aug:
179
+ # average predictions of the non-flipped and flipped version
180
+ pred_prob_with_bg = (pred_prob_with_bg[0] +
181
+ torch.flip(pred_prob_with_bg[1], dims=[-1])) / 2
182
+ else:
183
+ pred_prob_with_bg = pred_prob_with_bg[0]
184
+ if update_sensory:
185
+ self.memory.update_sensory(sensory, self.object_manager.all_obj_ids)
186
+ return pred_prob_with_bg
187
+
188
+ def pred_all_flow(self, images):
189
+ self.total_len = images.shape[0]
190
+ images, self.pad = pad_divide_by(images, 16)
191
+ images = images.unsqueeze(0) # add the batch dimension: (1,t,c,h,w)
192
+
193
+ self.flows_forward, self.flows_backward = self.network.pred_forward_backward_flow(images)
194
+
195
+ def encode_all_images(self, images):
196
+ images, self.pad = pad_divide_by(images, 16)
197
+ self.image_feature_store.get_all_features(images) # t c h w
198
+ return images
199
+
200
+ def step(self,
201
+ image: torch.Tensor,
202
+ mask: Optional[torch.Tensor] = None,
203
+ objects: Optional[List[int]] = None,
204
+ *,
205
+ idx_mask: bool = False,
206
+ end: bool = False,
207
+ delete_buffer: bool = True,
208
+ force_permanent: bool = False,
209
+ matting: bool = True,
210
+ first_frame_pred: bool = False) -> torch.Tensor:
211
+ """
212
+ Take a step with a new incoming image.
213
+ If there is an incoming mask with new objects, we will memorize them.
214
+ If there is no incoming mask, we will segment the image using the memory.
215
+ In both cases, we will update the memory and return a segmentation.
216
+
217
+ image: 3*H*W
218
+ mask: H*W (if idx mask) or len(objects)*H*W or None
219
+ objects: list of object ids that are valid in the mask Tensor.
220
+ The ids themselves do not need to be consecutive/in order, but they need to be
221
+ in the same position in the list as the corresponding mask
222
+ in the tensor in non-idx-mask mode.
223
+ objects is ignored if the mask is None.
224
+ If idx_mask is False and objects is None, we sequentially infer the object ids.
225
+ idx_mask: if True, mask is expected to contain an object id at every pixel.
226
+ If False, mask should have multiple channels with each channel representing one object.
227
+ end: if we are at the end of the sequence, we do not need to update memory
228
+ if unsure just set it to False
229
+ delete_buffer: whether to delete the image feature buffer after this step
230
+ force_permanent: the memory recorded this frame will be added to the permanent memory
231
+ """
232
+ if objects is None and mask is not None:
233
+ assert not idx_mask
234
+ objects = list(range(1, mask.shape[0] + 1))
235
+
236
+ # resize input if needed -- currently only used for the GUI
237
+ resize_needed = False
238
+ if self.max_internal_size > 0:
239
+ h, w = image.shape[-2:]
240
+ min_side = min(h, w)
241
+ if min_side > self.max_internal_size:
242
+ resize_needed = True
243
+ new_h = int(h / min_side * self.max_internal_size)
244
+ new_w = int(w / min_side * self.max_internal_size)
245
+ image = F.interpolate(image.unsqueeze(0),
246
+ size=(new_h, new_w),
247
+ mode='bilinear',
248
+ align_corners=False)[0]
249
+ if mask is not None:
250
+ if idx_mask:
251
+ mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0).float(),
252
+ size=(new_h, new_w),
253
+ mode='nearest-exact',
254
+ align_corners=False)[0, 0].round().long()
255
+ else:
256
+ mask = F.interpolate(mask.unsqueeze(0),
257
+ size=(new_h, new_w),
258
+ mode='bilinear',
259
+ align_corners=False)[0]
260
+
261
+ self.curr_ti += 1
262
+
263
+ image, self.pad = pad_divide_by(image, 16) # DONE alreay for 3DCNN!!
264
+ image = image.unsqueeze(0) # add the batch dimension
265
+ if self.flip_aug:
266
+ image = torch.cat([image, torch.flip(image, dims=[-1])], dim=0)
267
+
268
+ # whether to update the working memory
269
+ is_mem_frame = ((self.curr_ti - self.last_mem_ti >= self.mem_every) or
270
+ (mask is not None)) and (not end)
271
+ # segment when there is no input mask or when the input mask is incomplete
272
+ need_segment = (mask is None) or (self.object_manager.num_obj > 0
273
+ and not self.object_manager.has_all(objects))
274
+ update_sensory = ((self.curr_ti - self.last_mem_ti) in self.stagger_ti) and (not end)
275
+
276
+ # reinit if it is the first frame for prediction
277
+ if first_frame_pred:
278
+ self.curr_ti = 0
279
+ self.last_mem_ti = 0
280
+ is_mem_frame = True
281
+ need_segment = True
282
+ update_sensory = True
283
+
284
+ # encoding the image
285
+ ms_feat, pix_feat = self.image_feature_store.get_features(self.curr_ti, image)
286
+ key, shrinkage, selection = self.image_feature_store.get_key(self.curr_ti, image)
287
+
288
+ # segmentation from memory if needed
289
+ if need_segment:
290
+ pred_prob_with_bg = self._segment(key,
291
+ selection,
292
+ pix_feat,
293
+ ms_feat,
294
+ update_sensory=update_sensory)
295
+
296
+ # use the input mask if provided
297
+ if mask is not None:
298
+ # inform the manager of the new objects, and get a list of temporary id
299
+ # temporary ids -- indicates the position of objects in the tensor
300
+ # (starts with 1 due to the background channel)
301
+ corresponding_tmp_ids, _ = self.object_manager.add_new_objects(objects)
302
+
303
+ mask, _ = pad_divide_by(mask, 16)
304
+ if need_segment:
305
+ # merge predicted mask with the incomplete input mask
306
+ pred_prob_no_bg = pred_prob_with_bg[1:]
307
+ # use the mutual exclusivity of segmentation
308
+ if idx_mask:
309
+ pred_prob_no_bg[:, mask > 0] = 0
310
+ else:
311
+ pred_prob_no_bg[:, mask.max(0) > 0.5] = 0
312
+
313
+ new_masks = []
314
+ for mask_id, tmp_id in enumerate(corresponding_tmp_ids):
315
+ if idx_mask:
316
+ this_mask = (mask == objects[mask_id]).type_as(pred_prob_no_bg)
317
+ else:
318
+ this_mask = mask[tmp_id]
319
+ if tmp_id > pred_prob_no_bg.shape[0]:
320
+ new_masks.append(this_mask.unsqueeze(0))
321
+ else:
322
+ # +1 for padding the background channel
323
+ pred_prob_no_bg[tmp_id - 1] = this_mask
324
+ # new_masks are always in the order of tmp_id
325
+ mask = torch.cat([pred_prob_no_bg, *new_masks], dim=0)
326
+ elif idx_mask:
327
+ # simply convert cls to one-hot representation
328
+ if len(objects) == 0:
329
+ if delete_buffer:
330
+ self.image_feature_store.delete(self.curr_ti)
331
+ log.warn('Trying to insert an empty mask as memory!')
332
+ return torch.zeros((1, key.shape[-2] * 16, key.shape[-1] * 16),
333
+ device=key.device,
334
+ dtype=key.dtype)
335
+ mask = torch.stack(
336
+ [mask == objects[mask_id] for mask_id, _ in enumerate(corresponding_tmp_ids)],
337
+ dim=0)
338
+ if matting:
339
+ mask = mask.unsqueeze(0).float() / 255.
340
+ pred_prob_with_bg = torch.cat([1-mask, mask], 0)
341
+ else:
342
+ pred_prob_with_bg = aggregate(mask, dim=0)
343
+ pred_prob_with_bg = torch.softmax(pred_prob_with_bg, dim=0)
344
+
345
+ self.last_mask = pred_prob_with_bg[1:].unsqueeze(0)
346
+ if self.flip_aug:
347
+ self.last_mask = torch.cat(
348
+ [self.last_mask, torch.flip(self.last_mask, dims=[-1])], dim=0)
349
+ self.last_pix_feat = pix_feat
350
+
351
+ # save as memory if needed
352
+ if is_mem_frame or force_permanent:
353
+ # clear the memory for given mask and add the first predicted mask
354
+ if first_frame_pred:
355
+ self.clear_temp_mem()
356
+ self._add_memory(image,
357
+ pix_feat,
358
+ self.last_mask,
359
+ key,
360
+ shrinkage,
361
+ selection,
362
+ force_permanent=force_permanent,
363
+ is_deep_update=True)
364
+ else: # compute self.last_msk_value for non-memory frame
365
+ msk_value, _, _, _ = self.network.encode_mask(
366
+ image,
367
+ pix_feat,
368
+ self.memory.get_sensory(self.object_manager.all_obj_ids),
369
+ self.last_mask,
370
+ deep_update=False,
371
+ chunk_size=self.chunk_size,
372
+ need_weights=self.save_aux)
373
+ self.last_msk_value = msk_value
374
+
375
+ if delete_buffer:
376
+ self.image_feature_store.delete(self.curr_ti)
377
+
378
+ output_prob = unpad(pred_prob_with_bg, self.pad)
379
+ if resize_needed:
380
+ # restore output to the original size
381
+ output_prob = F.interpolate(output_prob.unsqueeze(0),
382
+ size=(h, w),
383
+ mode='bilinear',
384
+ align_corners=False)[0]
385
+
386
+ return output_prob
387
+
388
+ def delete_objects(self, objects: List[int]) -> None:
389
+ """
390
+ Delete the given objects from the memory.
391
+ """
392
+ self.object_manager.delete_objects(objects)
393
+ self.memory.purge_except(self.object_manager.all_obj_ids)
394
+
395
+ def output_prob_to_mask(self, output_prob: torch.Tensor, matting: bool = True) -> torch.Tensor:
396
+ if matting:
397
+ new_mask = output_prob[1:].squeeze(0)
398
+ else:
399
+ mask = torch.argmax(output_prob, dim=0)
400
+
401
+ # index in tensor != object id -- remap the ids here
402
+ new_mask = torch.zeros_like(mask)
403
+ for tmp_id, obj in self.object_manager.tmp_id_to_obj.items():
404
+ new_mask[mask == tmp_id] = obj.id
405
+
406
+ return new_mask
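A sketch of a propagation loop built on the API above (the actual loop WanGP uses lives in matanyone_wrapper.py, added in this commit but not shown here; `matanyone_model`, `frames` and `first_mask` are assumptions):

    import torch
    from preprocessing.matanyone.matanyone.inference.inference_core import InferenceCore

    processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)

    alphas = []
    with torch.inference_mode():
        for ti, frame in enumerate(frames):  # frames: (T, 3, H, W) float tensor on the GPU
            if ti == 0:
                # seed memory with the first-frame mask (H*W, values in [0, 255]) for object id 1
                prob = processor.step(frame, first_mask, objects=[1], first_frame_pred=True)
            else:
                prob = processor.step(frame)
            alphas.append(processor.output_prob_to_mask(prob))  # soft alpha in [0, 1] (matting=True)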
preprocessing/matanyone/matanyone/inference/kv_memory_store.py ADDED
@@ -0,0 +1,348 @@
1
+ from typing import Dict, List, Optional, Literal
2
+ from collections import defaultdict
3
+ import torch
4
+
5
+
6
+ def _add_last_dim(dictionary, key, new_value, prepend=False):
7
+ # append/prepend a new value to the last dimension of a tensor in a dictionary
8
+ # if the key does not exist, put the new value in
9
+ # append by default
10
+ if key in dictionary:
11
+ dictionary[key] = torch.cat([dictionary[key], new_value], -1)
12
+ else:
13
+ dictionary[key] = new_value
14
+
15
+
16
+ class KeyValueMemoryStore:
17
+ """
18
+ Works for key/value pairs type storage
19
+ e.g., working and long-term memory
20
+ """
21
+ def __init__(self, save_selection: bool = False, save_usage: bool = False):
22
+ """
23
+ We store keys and values of objects that first appear in the same frame in a bucket.
24
+ Each bucket contains a set of object ids.
25
+ Each bucket is associated with a single key tensor
26
+ and a dictionary of value tensors indexed by object id.
27
+
28
+ The keys and values are stored as the concatenation of a permanent part and a temporary part.
29
+ """
30
+ self.save_selection = save_selection
31
+ self.save_usage = save_usage
32
+
33
+ self.global_bucket_id = 0 # does not reduce even if buckets are removed
34
+ self.buckets: Dict[int, List[int]] = {} # indexed by bucket id
35
+ self.k: Dict[int, torch.Tensor] = {} # indexed by bucket id
36
+ self.v: Dict[int, torch.Tensor] = {} # indexed by object id
37
+
38
+ # indexed by bucket id; the end point of permanent memory
39
+ self.perm_end_pt: Dict[int, int] = defaultdict(int)
40
+
41
+ # shrinkage and selection are just like the keys
42
+ self.s = {}
43
+ if self.save_selection:
44
+ self.e = {} # does not contain the permanent memory part
45
+
46
+ # usage
47
+ if self.save_usage:
48
+ self.use_cnt = {} # indexed by bucket id, does not contain the permanent memory part
49
+ self.life_cnt = {} # indexed by bucket id, does not contain the permanent memory part
50
+
51
+ def add(self,
52
+ key: torch.Tensor,
53
+ values: Dict[int, torch.Tensor],
54
+ shrinkage: torch.Tensor,
55
+ selection: torch.Tensor,
56
+ supposed_bucket_id: int = -1,
57
+ as_permanent: Literal['no', 'first', 'all'] = 'no') -> None:
58
+ """
59
+ key: (1/2)*C*N
60
+ values: dict of values ((1/2)*C*N), object ids are used as keys
61
+ shrinkage: (1/2)*1*N
62
+ selection: (1/2)*C*N
63
+
64
+ supposed_bucket_id: used to sync the bucket id between working and long-term memory
65
+ if provided, the input should all be in a single bucket indexed by this id
66
+ as_permanent: whether to store the input as permanent memory
67
+ 'no': don't
68
+ 'first': only store it as permanent memory if the bucket is empty
69
+ 'all': always store it as permanent memory
70
+ """
71
+ bs = key.shape[0]
72
+ ne = key.shape[-1]
73
+ assert len(key.shape) == 3
74
+ assert len(shrinkage.shape) == 3
75
+ assert not self.save_selection or len(selection.shape) == 3
76
+ assert as_permanent in ['no', 'first', 'all']
77
+
78
+ # add the value and create new buckets if necessary
79
+ if supposed_bucket_id >= 0:
80
+ enabled_buckets = [supposed_bucket_id]
81
+ bucket_exist = supposed_bucket_id in self.buckets
82
+ for obj, value in values.items():
83
+ if bucket_exist:
84
+ assert obj in self.v
85
+ assert obj in self.buckets[supposed_bucket_id]
86
+ _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all'))
87
+ else:
88
+ assert obj not in self.v
89
+ self.v[obj] = value
90
+ self.buckets[supposed_bucket_id] = list(values.keys())
91
+ else:
92
+ new_bucket_id = None
93
+ enabled_buckets = set()
94
+ for obj, value in values.items():
95
+ assert len(value.shape) == 3
96
+ if obj in self.v:
97
+ _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all'))
98
+ bucket_used = [
99
+ bucket_id for bucket_id, object_ids in self.buckets.items()
100
+ if obj in object_ids
101
+ ]
102
+ assert len(bucket_used) == 1 # each object should only be in one bucket
103
+ enabled_buckets.add(bucket_used[0])
104
+ else:
105
+ self.v[obj] = value
106
+ if new_bucket_id is None:
107
+ # create new bucket
108
+ new_bucket_id = self.global_bucket_id
109
+ self.global_bucket_id += 1
110
+ self.buckets[new_bucket_id] = []
111
+ # put the new object into the corresponding bucket
112
+ self.buckets[new_bucket_id].append(obj)
113
+ enabled_buckets.add(new_bucket_id)
114
+
115
+ # increment the permanent size if necessary
116
+ add_as_permanent = {} # indexed by bucket id
117
+ for bucket_id in enabled_buckets:
118
+ add_as_permanent[bucket_id] = False
119
+ if as_permanent == 'all':
120
+ self.perm_end_pt[bucket_id] += ne
121
+ add_as_permanent[bucket_id] = True
122
+ elif as_permanent == 'first':
123
+ if self.perm_end_pt[bucket_id] == 0:
124
+ self.perm_end_pt[bucket_id] = ne
125
+ add_as_permanent[bucket_id] = True
126
+
127
+ # create new counters for usage if necessary
128
+ if self.save_usage and as_permanent != 'all':
129
+ new_count = torch.zeros((bs, ne), device=key.device, dtype=torch.float32)
130
+ new_life = torch.zeros((bs, ne), device=key.device, dtype=torch.float32) + 1e-7
131
+
132
+ # add the key to every bucket
133
+ for bucket_id in self.buckets:
134
+ if bucket_id not in enabled_buckets:
135
+ # if we are not adding new values to a bucket, we should skip it
136
+ continue
137
+
138
+ _add_last_dim(self.k, bucket_id, key, prepend=add_as_permanent[bucket_id])
139
+ _add_last_dim(self.s, bucket_id, shrinkage, prepend=add_as_permanent[bucket_id])
140
+ if not add_as_permanent[bucket_id]:
141
+ if self.save_selection:
142
+ _add_last_dim(self.e, bucket_id, selection)
143
+ if self.save_usage:
144
+ _add_last_dim(self.use_cnt, bucket_id, new_count)
145
+ _add_last_dim(self.life_cnt, bucket_id, new_life)
146
+
147
+ def update_bucket_usage(self, bucket_id: int, usage: torch.Tensor) -> None:
148
+ # increase all life count by 1
149
+ # increase use of indexed elements
150
+ if not self.save_usage:
151
+ return
152
+
153
+ usage = usage[:, self.perm_end_pt[bucket_id]:]
154
+ if usage.shape[-1] == 0:
155
+ # if there is no temporary memory, we don't need to update
156
+ return
157
+ self.use_cnt[bucket_id] += usage.view_as(self.use_cnt[bucket_id])
158
+ self.life_cnt[bucket_id] += 1
159
+
160
+ def sieve_by_range(self, bucket_id: int, start: int, end: int, min_size: int) -> None:
161
+ # keep only the temporary elements *outside* of this range (with some boundary conditions)
162
+ # the permanent elements are ignored in this computation
163
+ # i.e., concat (a[:start], a[end:])
164
+ # bucket with size <= min_size are not modified
165
+
166
+ assert start >= 0
167
+ assert end <= 0
168
+
169
+ object_ids = self.buckets[bucket_id]
170
+ bucket_num_elements = self.k[bucket_id].shape[-1] - self.perm_end_pt[bucket_id]
171
+ if bucket_num_elements <= min_size:
172
+ return
173
+
174
+ if end == 0:
175
+ # negative 0 would not work as the end index!
176
+ # effectively make the second part an empty slice
177
+ end = self.k[bucket_id].shape[-1] + 1
178
+
179
+ p_size = self.perm_end_pt[bucket_id]
180
+ start = start + p_size
181
+
182
+ k = self.k[bucket_id]
183
+ s = self.s[bucket_id]
184
+ if self.save_selection:
185
+ e = self.e[bucket_id]
186
+ if self.save_usage:
187
+ use_cnt = self.use_cnt[bucket_id]
188
+ life_cnt = self.life_cnt[bucket_id]
189
+
190
+ self.k[bucket_id] = torch.cat([k[:, :, :start], k[:, :, end:]], -1)
191
+ self.s[bucket_id] = torch.cat([s[:, :, :start], s[:, :, end:]], -1)
192
+ if self.save_selection:
193
+ self.e[bucket_id] = torch.cat([e[:, :, :start - p_size], e[:, :, end:]], -1)
194
+ if self.save_usage:
195
+ self.use_cnt[bucket_id] = torch.cat([use_cnt[:, :start - p_size], use_cnt[:, end:]], -1)
196
+ self.life_cnt[bucket_id] = torch.cat([life_cnt[:, :start - p_size], life_cnt[:, end:]],
197
+ -1)
198
+ for obj_id in object_ids:
199
+ v = self.v[obj_id]
200
+ self.v[obj_id] = torch.cat([v[:, :, :start], v[:, :, end:]], -1)
201
+
202
+ def remove_old_memory(self, bucket_id: int, max_len: int) -> None:
203
+ self.sieve_by_range(bucket_id, 0, -max_len, max_len)
204
+
205
+ def remove_obsolete_features(self, bucket_id: int, max_size: int) -> None:
206
+ # for long-term memory only
207
+ object_ids = self.buckets[bucket_id]
208
+
209
+ assert self.perm_end_pt[bucket_id] == 0 # permanent memory should be empty in LT memory
210
+
211
+ # normalize with life duration
212
+ usage = self.get_usage(bucket_id)
213
+ bs = usage.shape[0]
214
+
215
+ survivals = []
216
+
217
+ for bi in range(bs):
218
+ _, survived = torch.topk(usage[bi], k=max_size)
219
+ survivals.append(survived.flatten())
220
+ assert survived.shape[-1] == survivals[0].shape[-1]
221
+
222
+ self.k[bucket_id] = torch.stack(
223
+ [self.k[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
224
+ self.s[bucket_id] = torch.stack(
225
+ [self.s[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
226
+
227
+ if self.save_selection:
228
+ # Long-term memory does not store selection so this should not be needed
229
+ self.e[bucket_id] = torch.stack(
230
+ [self.e[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
231
+ for obj_id in object_ids:
232
+ self.v[obj_id] = torch.stack(
233
+ [self.v[obj_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
234
+
235
+ self.use_cnt[bucket_id] = torch.stack(
236
+ [self.use_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0)
237
+ self.life_cnt[bucket_id] = torch.stack(
238
+ [self.life_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0)
239
+
240
+ def get_usage(self, bucket_id: int) -> torch.Tensor:
241
+ # return normalized usage
242
+ if not self.save_usage:
243
+ raise RuntimeError('I did not count usage!')
244
+ else:
245
+ usage = self.use_cnt[bucket_id] / self.life_cnt[bucket_id]
246
+ return usage
247
+
248
+ def get_all_sliced(
249
+ self, bucket_id: int, start: int, end: int
250
+ ) -> (torch.Tensor, torch.Tensor, torch.Tensor, Dict[int, torch.Tensor], torch.Tensor):
251
+ # return k, sk, ek, value, normalized usage in order, sliced by start and end
252
+ # this only queries the temporary memory
253
+
254
+ assert start >= 0
255
+ assert end <= 0
256
+
257
+ p_size = self.perm_end_pt[bucket_id]
258
+ start = start + p_size
259
+
260
+ if end == 0:
261
+ # negative 0 would not work as the end index!
262
+ k = self.k[bucket_id][:, :, start:]
263
+ sk = self.s[bucket_id][:, :, start:]
264
+ ek = self.e[bucket_id][:, :, start - p_size:] if self.save_selection else None
265
+ value = {obj_id: self.v[obj_id][:, :, start:] for obj_id in self.buckets[bucket_id]}
266
+ usage = self.get_usage(bucket_id)[:, start - p_size:] if self.save_usage else None
267
+ else:
268
+ k = self.k[bucket_id][:, :, start:end]
269
+ sk = self.s[bucket_id][:, :, start:end]
270
+ ek = self.e[bucket_id][:, :, start - p_size:end] if self.save_selection else None
271
+ value = {obj_id: self.v[obj_id][:, :, start:end] for obj_id in self.buckets[bucket_id]}
272
+ usage = self.get_usage(bucket_id)[:, start - p_size:end] if self.save_usage else None
273
+
274
+ return k, sk, ek, value, usage
275
+
276
+ def purge_except(self, obj_keep_idx: List[int]):
277
+ # purge certain objects from the memory except the ones listed
278
+ obj_keep_idx = set(obj_keep_idx)
279
+
280
+ # remove objects that are not in the keep list from the buckets
281
+ buckets_to_remove = []
282
+ for bucket_id, object_ids in self.buckets.items():
283
+ self.buckets[bucket_id] = [obj_id for obj_id in object_ids if obj_id in obj_keep_idx]
284
+ if len(self.buckets[bucket_id]) == 0:
285
+ buckets_to_remove.append(bucket_id)
286
+
287
+ # remove object values that are not in the keep list
288
+ self.v = {k: v for k, v in self.v.items() if k in obj_keep_idx}
289
+
290
+ # remove buckets that are empty
291
+ for bucket_id in buckets_to_remove:
292
+ del self.buckets[bucket_id]
293
+ del self.k[bucket_id]
294
+ del self.s[bucket_id]
295
+ if self.save_selection:
296
+ del self.e[bucket_id]
297
+ if self.save_usage:
298
+ del self.use_cnt[bucket_id]
299
+ del self.life_cnt[bucket_id]
300
+
301
+ def clear_non_permanent_memory(self):
302
+ # clear all non-permanent memory
303
+ for bucket_id in self.buckets:
304
+ self.sieve_by_range(bucket_id, 0, 0, 0)
305
+
306
+ def get_v_size(self, obj_id: int) -> int:
307
+ return self.v[obj_id].shape[-1]
308
+
309
+ def size(self, bucket_id: int) -> int:
310
+ if bucket_id not in self.k:
311
+ return 0
312
+ else:
313
+ return self.k[bucket_id].shape[-1]
314
+
315
+ def perm_size(self, bucket_id: int) -> int:
316
+ return self.perm_end_pt[bucket_id]
317
+
318
+ def non_perm_size(self, bucket_id: int) -> int:
319
+ return self.size(bucket_id) - self.perm_size(bucket_id)
320
+
321
+ def engaged(self, bucket_id: Optional[int] = None) -> bool:
322
+ if bucket_id is None:
323
+ return len(self.buckets) > 0
324
+ else:
325
+ return bucket_id in self.buckets
326
+
327
+ @property
328
+ def num_objects(self) -> int:
329
+ return len(self.v)
330
+
331
+ @property
332
+ def key(self) -> Dict[int, torch.Tensor]:
333
+ return self.k
334
+
335
+ @property
336
+ def value(self) -> Dict[int, torch.Tensor]:
337
+ return self.v
338
+
339
+ @property
340
+ def shrinkage(self) -> Dict[int, torch.Tensor]:
341
+ return self.s
342
+
343
+ @property
344
+ def selection(self) -> Dict[int, torch.Tensor]:
345
+ return self.e
346
+
347
+ def __contains__(self, key):
348
+ return key in self.v
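A small usage sketch of the bucket semantics above; the dotted import path mirrors the file layout of this commit and assumes the repository root is importable, and all tensor sizes are made up.

```python
import torch
from preprocessing.matanyone.matanyone.inference.kv_memory_store import KeyValueMemoryStore

store = KeyValueMemoryStore(save_selection=False, save_usage=False)
key = torch.randn(1, 64, 100)                                  # bs * C^k * N tokens
shrinkage = torch.rand(1, 1, 100) + 1
values = {1: torch.randn(1, 256, 100), 2: torch.randn(1, 256, 100)}

# the first frame is stored as permanent memory for the new bucket
store.add(key, values, shrinkage, selection=None, as_permanent='first')
# later frames are temporary and subject to FIFO trimming
store.add(torch.randn(1, 64, 100),
          {1: torch.randn(1, 256, 100), 2: torch.randn(1, 256, 100)},
          torch.rand(1, 1, 100) + 1, selection=None, as_permanent='no')

bucket_id = next(iter(store.buckets))
print(store.perm_size(bucket_id), store.non_perm_size(bucket_id))   # 100 100
store.remove_old_memory(bucket_id, max_len=50)    # drops the oldest temporary tokens
print(store.non_perm_size(bucket_id))             # 50
```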
preprocessing/matanyone/matanyone/inference/memory_manager.py ADDED
@@ -0,0 +1,453 @@
1
+ import logging
2
+ from omegaconf import DictConfig
3
+ from typing import List, Dict
4
+ import torch
5
+
6
+ from .object_manager import ObjectManager
7
+ from .kv_memory_store import KeyValueMemoryStore
8
+ from ..model.matanyone import MatAnyone
9
+ from ..model.utils.memory_utils import get_similarity, do_softmax
10
+
11
+ log = logging.getLogger()
12
+
13
+
14
+ class MemoryManager:
15
+ """
16
+ Manages all three memory stores and the transition between working/long-term memory
17
+ """
18
+ def __init__(self, cfg: DictConfig, object_manager: ObjectManager):
19
+ self.object_manager = object_manager
20
+ self.sensory_dim = cfg.model.sensory_dim
21
+ self.top_k = cfg.top_k
22
+ self.chunk_size = cfg.chunk_size
23
+
24
+ self.save_aux = cfg.save_aux
25
+
26
+ self.use_long_term = cfg.use_long_term
27
+ self.count_long_term_usage = cfg.long_term.count_usage
28
+ # subtract 1 because the first-frame is now counted as "permanent memory"
29
+ # and is not counted towards max_mem_frames
30
+ # but we want to keep the hyperparameters consistent as before for the same behavior
31
+ if self.use_long_term:
32
+ self.max_mem_frames = cfg.long_term.max_mem_frames - 1
33
+ self.min_mem_frames = cfg.long_term.min_mem_frames - 1
34
+ self.num_prototypes = cfg.long_term.num_prototypes
35
+ self.max_long_tokens = cfg.long_term.max_num_tokens
36
+ self.buffer_tokens = cfg.long_term.buffer_tokens
37
+ else:
38
+ self.max_mem_frames = cfg.max_mem_frames - 1
39
+
40
+ # dimensions will be inferred from input later
41
+ self.CK = self.CV = None
42
+ self.H = self.W = None
43
+
44
+ # The sensory memory is stored as a dictionary indexed by object ids
45
+ # each of shape bs * C^h * H * W
46
+ self.sensory = {}
47
+
48
+ # a dictionary indexed by object ids, each of shape bs * T * Q * C
49
+ self.obj_v = {}
50
+
51
+ self.work_mem = KeyValueMemoryStore(save_selection=self.use_long_term,
52
+ save_usage=self.use_long_term)
53
+ if self.use_long_term:
54
+ self.long_mem = KeyValueMemoryStore(save_usage=self.count_long_term_usage)
55
+
56
+ self.config_stale = True
57
+ self.engaged = False
58
+
59
+ def update_config(self, cfg: DictConfig) -> None:
60
+ self.config_stale = True
61
+ self.top_k = cfg['top_k']
62
+
63
+ assert self.use_long_term == cfg.use_long_term, 'cannot update this'
64
+ assert self.count_long_term_usage == cfg.long_term.count_usage, 'cannot update this'
65
+
66
+ self.use_long_term = cfg.use_long_term
67
+ self.count_long_term_usage = cfg.long_term.count_usage
68
+ if self.use_long_term:
69
+ self.max_mem_frames = cfg.long_term.max_mem_frames - 1
70
+ self.min_mem_frames = cfg.long_term.min_mem_frames - 1
71
+ self.num_prototypes = cfg.long_term.num_prototypes
72
+ self.max_long_tokens = cfg.long_term.max_num_tokens
73
+ self.buffer_tokens = cfg.long_term.buffer_tokens
74
+ else:
75
+ self.max_mem_frames = cfg.max_mem_frames - 1
76
+
77
+ def _readout(self, affinity, v, uncert_mask=None) -> torch.Tensor:
78
+ # affinity: bs*N*HW
79
+ # v: bs*C*N or bs*num_objects*C*N
80
+ # returns bs*C*HW or bs*num_objects*C*HW
81
+ if len(v.shape) == 3:
82
+ # single object
83
+ if uncert_mask is not None:
84
+ return v @ affinity * uncert_mask
85
+ else:
86
+ return v @ affinity
87
+ else:
88
+ bs, num_objects, C, N = v.shape
89
+ v = v.view(bs, num_objects * C, N)
90
+ out = v @ affinity
91
+ if uncert_mask is not None:
92
+ uncert_mask = uncert_mask.flatten(start_dim=2).expand(-1, C, -1)
93
+ out = out * uncert_mask
94
+ return out.view(bs, num_objects, C, -1)
95
+
96
+ def _get_mask_by_ids(self, mask: torch.Tensor, obj_ids: List[int]) -> torch.Tensor:
97
+ # -1 because the mask does not contain the background channel
98
+ return mask[:, [self.object_manager.find_tmp_by_id(obj) - 1 for obj in obj_ids]]
99
+
100
+ def _get_sensory_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
101
+ return torch.stack([self.sensory[obj] for obj in obj_ids], dim=1)
102
+
103
+ def _get_object_mem_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
104
+ return torch.stack([self.obj_v[obj] for obj in obj_ids], dim=1)
105
+
106
+ def _get_visual_values_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
107
+ # All the values that the object ids refer to should have the same shape
108
+ value = torch.stack([self.work_mem.value[obj] for obj in obj_ids], dim=1)
109
+ if self.use_long_term and obj_ids[0] in self.long_mem.value:
110
+ lt_value = torch.stack([self.long_mem.value[obj] for obj in obj_ids], dim=1)
111
+ value = torch.cat([lt_value, value], dim=-1)
112
+
113
+ return value
114
+
115
+ def read_first_frame(self, last_msk_value, pix_feat: torch.Tensor,
116
+ last_mask: torch.Tensor, network: MatAnyone, uncert_output=None) -> Dict[int, torch.Tensor]:
117
+ """
118
+ Read from all memory stores and returns a single memory readout tensor for each object
119
+
120
+ pix_feat: (1/2) x C x H x W
121
+ query_key: (1/2) x C^k x H x W
122
+ selection: (1/2) x C^k x H x W
123
+ last_mask: (1/2) x num_objects x H x W (at stride 16)
124
+ return a dict of memory readouts, indexed by object indices. Each readout is C*H*W
125
+ """
126
+ h, w = pix_feat.shape[-2:]
127
+ bs = pix_feat.shape[0]
128
+ assert last_mask.shape[0] == bs
129
+
130
+ """
131
+ Compute affinity and perform readout
132
+ """
133
+ all_readout_mem = {}
134
+ buckets = self.work_mem.buckets
135
+ for bucket_id, bucket in buckets.items():
136
+
137
+ if self.chunk_size < 1:
138
+ object_chunks = [bucket]
139
+ else:
140
+ object_chunks = [
141
+ bucket[i:i + self.chunk_size] for i in range(0, len(bucket), self.chunk_size)
142
+ ]
143
+
144
+ for objects in object_chunks:
145
+ this_sensory = self._get_sensory_by_ids(objects)
146
+ this_last_mask = self._get_mask_by_ids(last_mask, objects)
147
+ this_msk_value = self._get_visual_values_by_ids(objects) # (1/2)*num_objects*C*N
148
+ pixel_readout = network.pixel_fusion(pix_feat, last_msk_value, this_sensory,
149
+ this_last_mask)
150
+ this_obj_mem = self._get_object_mem_by_ids(objects).unsqueeze(2)
151
+ readout_memory, aux_features = network.readout_query(pixel_readout, this_obj_mem)
152
+ for i, obj in enumerate(objects):
153
+ all_readout_mem[obj] = readout_memory[:, i]
154
+
155
+ if self.save_aux:
156
+ aux_output = {
157
+ # 'sensory': this_sensory,
158
+ # 'pixel_readout': pixel_readout,
159
+ 'q_logits': aux_features['logits'] if aux_features else None,
160
+ # 'q_weights': aux_features['q_weights'] if aux_features else None,
161
+ # 'p_weights': aux_features['p_weights'] if aux_features else None,
162
+ # 'attn_mask': aux_features['attn_mask'].float() if aux_features else None,
163
+ }
164
+ self.aux = aux_output
165
+
166
+ return all_readout_mem
167
+
168
+ def read(self, pix_feat: torch.Tensor, query_key: torch.Tensor, selection: torch.Tensor,
169
+ last_mask: torch.Tensor, network: MatAnyone, uncert_output=None, last_msk_value=None, ti=None,
170
+ last_pix_feat=None, last_pred_mask=None) -> Dict[int, torch.Tensor]:
171
+ """
172
+ Read from all memory stores and returns a single memory readout tensor for each object
173
+
174
+ pix_feat: (1/2) x C x H x W
175
+ query_key: (1/2) x C^k x H x W
176
+ selection: (1/2) x C^k x H x W
177
+ last_mask: (1/2) x num_objects x H x W (at stride 16)
178
+ return a dict of memory readouts, indexed by object indices. Each readout is C*H*W
179
+ """
180
+ h, w = pix_feat.shape[-2:]
181
+ bs = pix_feat.shape[0]
182
+ assert query_key.shape[0] == bs
183
+ assert selection.shape[0] == bs
184
+ assert last_mask.shape[0] == bs
185
+
186
+ uncert_mask = uncert_output["mask"] if uncert_output is not None else None
187
+
188
+ query_key = query_key.flatten(start_dim=2) # bs*C^k*HW
189
+ selection = selection.flatten(start_dim=2) # bs*C^k*HW
190
+ """
191
+ Compute affinity and perform readout
192
+ """
193
+ all_readout_mem = {}
194
+ buckets = self.work_mem.buckets
195
+ for bucket_id, bucket in buckets.items():
196
+ if self.use_long_term and self.long_mem.engaged(bucket_id):
197
+ # Use long-term memory
198
+ long_mem_size = self.long_mem.size(bucket_id)
199
+ memory_key = torch.cat([self.long_mem.key[bucket_id], self.work_mem.key[bucket_id]],
200
+ -1)
201
+ shrinkage = torch.cat(
202
+ [self.long_mem.shrinkage[bucket_id], self.work_mem.shrinkage[bucket_id]], -1)
203
+
204
+ similarity = get_similarity(memory_key, shrinkage, query_key, selection)
205
+ affinity, usage = do_softmax(similarity,
206
+ top_k=self.top_k,
207
+ inplace=True,
208
+ return_usage=True)
209
+ """
210
+ Record memory usage for working and long-term memory
211
+ """
212
+ # ignore the index return for long-term memory
213
+ work_usage = usage[:, long_mem_size:]
214
+ self.work_mem.update_bucket_usage(bucket_id, work_usage)
215
+
216
+ if self.count_long_term_usage:
217
+ # ignore the index return for working memory
218
+ long_usage = usage[:, :long_mem_size]
219
+ self.long_mem.update_bucket_usage(bucket_id, long_usage)
220
+ else:
221
+ # no long-term memory
222
+ memory_key = self.work_mem.key[bucket_id]
223
+ shrinkage = self.work_mem.shrinkage[bucket_id]
224
+ similarity = get_similarity(memory_key, shrinkage, query_key, selection, uncert_mask=uncert_mask)
225
+
226
+ if self.use_long_term:
227
+ affinity, usage = do_softmax(similarity,
228
+ top_k=self.top_k,
229
+ inplace=True,
230
+ return_usage=True)
231
+ self.work_mem.update_bucket_usage(bucket_id, usage)
232
+ else:
233
+ affinity = do_softmax(similarity, top_k=self.top_k, inplace=True)
234
+
235
+ if self.chunk_size < 1:
236
+ object_chunks = [bucket]
237
+ else:
238
+ object_chunks = [
239
+ bucket[i:i + self.chunk_size] for i in range(0, len(bucket), self.chunk_size)
240
+ ]
241
+
242
+ for objects in object_chunks:
243
+ this_sensory = self._get_sensory_by_ids(objects)
244
+ this_last_mask = self._get_mask_by_ids(last_mask, objects)
245
+ this_msk_value = self._get_visual_values_by_ids(objects) # (1/2)*num_objects*C*N
246
+ visual_readout = self._readout(affinity,
247
+ this_msk_value, uncert_mask).view(bs, len(objects), self.CV, h, w)
248
+
249
+ uncert_output = network.pred_uncertainty(last_pix_feat, pix_feat, last_pred_mask, visual_readout[:,0]-last_msk_value[:,0])
250
+
251
+ if uncert_output is not None:
252
+ uncert_prob = uncert_output["prob"].unsqueeze(1) # b n 1 h w
253
+ visual_readout = visual_readout*uncert_prob + last_msk_value*(1-uncert_prob)
254
+
255
+ pixel_readout = network.pixel_fusion(pix_feat, visual_readout, this_sensory,
256
+ this_last_mask)
257
+ this_obj_mem = self._get_object_mem_by_ids(objects).unsqueeze(2)
258
+ readout_memory, aux_features = network.readout_query(pixel_readout, this_obj_mem)
259
+ for i, obj in enumerate(objects):
260
+ all_readout_mem[obj] = readout_memory[:, i]
261
+
262
+ if self.save_aux:
263
+ aux_output = {
264
+ # 'sensory': this_sensory,
265
+ # 'pixel_readout': pixel_readout,
266
+ 'q_logits': aux_features['logits'] if aux_features else None,
267
+ # 'q_weights': aux_features['q_weights'] if aux_features else None,
268
+ # 'p_weights': aux_features['p_weights'] if aux_features else None,
269
+ # 'attn_mask': aux_features['attn_mask'].float() if aux_features else None,
270
+ }
271
+ self.aux = aux_output
272
+
273
+ return all_readout_mem
274
+
275
+ def add_memory(self,
276
+ key: torch.Tensor,
277
+ shrinkage: torch.Tensor,
278
+ msk_value: torch.Tensor,
279
+ obj_value: torch.Tensor,
280
+ objects: List[int],
281
+ selection: torch.Tensor = None,
282
+ *,
283
+ as_permanent: bool = False) -> None:
284
+ # key: (1/2)*C*H*W
285
+ # msk_value: (1/2)*num_objects*C*H*W
286
+ # obj_value: (1/2)*num_objects*Q*C
287
+ # objects contains a list of object ids corresponding to the objects in msk_value/obj_value
288
+ bs = key.shape[0]
289
+ assert shrinkage.shape[0] == bs
290
+ assert msk_value.shape[0] == bs
291
+ assert obj_value.shape[0] == bs
292
+
293
+ self.engaged = True
294
+ if self.H is None or self.config_stale:
295
+ self.config_stale = False
296
+ self.H, self.W = msk_value.shape[-2:]
297
+ self.HW = self.H * self.W
298
+ # convert from num. frames to num. tokens
299
+ self.max_work_tokens = self.max_mem_frames * self.HW
300
+ if self.use_long_term:
301
+ self.min_work_tokens = self.min_mem_frames * self.HW
302
+
303
+ # key: bs*C*N
304
+ # value: bs*num_objects*C*N
305
+ key = key.flatten(start_dim=2)
306
+ shrinkage = shrinkage.flatten(start_dim=2)
307
+ self.CK = key.shape[1]
308
+
309
+ msk_value = msk_value.flatten(start_dim=3)
310
+ self.CV = msk_value.shape[2]
311
+
312
+ if selection is not None:
313
+ # not used in non-long-term mode
314
+ selection = selection.flatten(start_dim=2)
315
+
316
+ # insert object values into object memory
317
+ for obj_id, obj in enumerate(objects):
318
+ if obj in self.obj_v:
319
+ """streaming average
320
+ each self.obj_v[obj] is (1/2)*num_summaries*(embed_dim+1)
321
+ first embed_dim keeps track of the sum of embeddings
322
+ the last dim keeps the total count
323
+ averaging is done inside the object transformer
324
+
325
+ incoming obj_value is (1/2)*num_objects*num_summaries*(embed_dim+1)
326
+ self.obj_v[obj] = torch.cat([self.obj_v[obj], obj_value[:, obj_id]], dim=0)
327
+ """
328
+ last_acc = self.obj_v[obj][:, :, -1]
329
+ new_acc = last_acc + obj_value[:, obj_id, :, -1]
330
+
331
+ self.obj_v[obj][:, :, :-1] = (self.obj_v[obj][:, :, :-1] +
332
+ obj_value[:, obj_id, :, :-1])
333
+ self.obj_v[obj][:, :, -1] = new_acc
334
+ else:
335
+ self.obj_v[obj] = obj_value[:, obj_id]
336
+
337
+ # convert mask value tensor into a dict for insertion
338
+ msk_values = {obj: msk_value[:, obj_id] for obj_id, obj in enumerate(objects)}
339
+ self.work_mem.add(key,
340
+ msk_values,
341
+ shrinkage,
342
+ selection=selection,
343
+ as_permanent=as_permanent)
344
+
345
+ for bucket_id in self.work_mem.buckets.keys():
346
+ # long-term memory cleanup
347
+ if self.use_long_term:
348
+ # compress the working memory if needed
349
+ if self.work_mem.non_perm_size(bucket_id) >= self.max_work_tokens:
350
+ # Remove obsolete features if needed
351
+ if self.long_mem.non_perm_size(bucket_id) >= (self.max_long_tokens -
352
+ self.num_prototypes):
353
+ self.long_mem.remove_obsolete_features(
354
+ bucket_id,
355
+ self.max_long_tokens - self.num_prototypes - self.buffer_tokens)
356
+
357
+ self.compress_features(bucket_id)
358
+ else:
359
+ # FIFO
360
+ self.work_mem.remove_old_memory(bucket_id, self.max_work_tokens)
361
+
362
+ def purge_except(self, obj_keep_idx: List[int]) -> None:
363
+ # purge certain objects from the memory except the ones listed
364
+ self.work_mem.purge_except(obj_keep_idx)
365
+ if self.use_long_term and self.long_mem.engaged():
366
+ self.long_mem.purge_except(obj_keep_idx)
367
+ self.sensory = {k: v for k, v in self.sensory.items() if k in obj_keep_idx}
368
+
369
+ if not self.work_mem.engaged():
370
+ # everything is removed!
371
+ self.engaged = False
372
+
373
+ def compress_features(self, bucket_id: int) -> None:
374
+
375
+ # perform memory consolidation
376
+ prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
377
+ *self.work_mem.get_all_sliced(bucket_id, 0, -self.min_work_tokens))
378
+
379
+ # remove consolidated working memory
380
+ self.work_mem.sieve_by_range(bucket_id,
381
+ 0,
382
+ -self.min_work_tokens,
383
+ min_size=self.min_work_tokens)
384
+
385
+ # add to long-term memory
386
+ self.long_mem.add(prototype_key,
387
+ prototype_value,
388
+ prototype_shrinkage,
389
+ selection=None,
390
+ supposed_bucket_id=bucket_id)
391
+
392
+ def consolidation(self, candidate_key: torch.Tensor, candidate_shrinkage: torch.Tensor,
393
+ candidate_selection: torch.Tensor, candidate_value: Dict[int, torch.Tensor],
394
+ usage: torch.Tensor) -> (torch.Tensor, Dict[int, torch.Tensor], torch.Tensor):
395
+ # find the indices with max usage
396
+ bs = candidate_key.shape[0]
397
+ assert bs in [1, 2]
398
+
399
+ prototype_key = []
400
+ prototype_selection = []
401
+ for bi in range(bs):
402
+ _, max_usage_indices = torch.topk(usage[bi], k=self.num_prototypes, dim=-1, sorted=True)
403
+ prototype_indices = max_usage_indices.flatten()
404
+ prototype_key.append(candidate_key[bi, :, prototype_indices])
405
+ prototype_selection.append(candidate_selection[bi, :, prototype_indices])
406
+ prototype_key = torch.stack(prototype_key, dim=0)
407
+ prototype_selection = torch.stack(prototype_selection, dim=0)
408
+ """
409
+ Potentiation step
410
+ """
411
+ similarity = get_similarity(candidate_key, candidate_shrinkage, prototype_key,
412
+ prototype_selection)
413
+ affinity = do_softmax(similarity)
414
+
415
+ # readout the values
416
+ prototype_value = {k: self._readout(affinity, v) for k, v in candidate_value.items()}
417
+
418
+ # readout the shrinkage term
419
+ prototype_shrinkage = self._readout(affinity, candidate_shrinkage)
420
+
421
+ return prototype_key, prototype_value, prototype_shrinkage
422
+
423
+ def initialize_sensory_if_needed(self, sample_key: torch.Tensor, ids: List[int]):
424
+ for obj in ids:
425
+ if obj not in self.sensory:
426
+ # also initializes the sensory memory
427
+ bs, _, h, w = sample_key.shape
428
+ self.sensory[obj] = torch.zeros((bs, self.sensory_dim, h, w),
429
+ device=sample_key.device)
430
+
431
+ def update_sensory(self, sensory: torch.Tensor, ids: List[int]):
432
+ # sensory: 1*num_objects*C*H*W
433
+ for obj_id, obj in enumerate(ids):
434
+ self.sensory[obj] = sensory[:, obj_id]
435
+
436
+ def get_sensory(self, ids: List[int]):
437
+ # returns (1/2)*num_objects*C*H*W
438
+ return self._get_sensory_by_ids(ids)
439
+
440
+ def clear_non_permanent_memory(self):
441
+ self.work_mem.clear_non_permanent_memory()
442
+ if self.use_long_term:
443
+ self.long_mem.clear_non_permanent_memory()
444
+
445
+ def clear_sensory_memory(self):
446
+ self.sensory = {}
447
+
448
+ def clear_work_mem(self):
449
+ self.work_mem = KeyValueMemoryStore(save_selection=self.use_long_term,
450
+ save_usage=self.use_long_term)
451
+
452
+ def clear_obj_mem(self):
453
+ self.obj_v = {}
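The frame-to-token bookkeeping above (`- 1` for the permanent first frame, then frames converted to stride-16 tokens) in one illustrative calculation; the numbers are made up, not the shipped defaults.

```python
# one memory frame contributes H/16 * W/16 tokens to the working-memory budget
frame_h, frame_w = 480, 832                 # illustrative input resolution
HW = (frame_h // 16) * (frame_w // 16)      # 1560 tokens per frame
cfg_max_mem_frames = 5                      # illustrative cfg.max_mem_frames
max_mem_frames = cfg_max_mem_frames - 1     # first frame is permanent, not budgeted
max_work_tokens = max_mem_frames * HW
print(HW, max_work_tokens)                  # 1560 6240
```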
preprocessing/matanyone/matanyone/inference/object_info.py ADDED
@@ -0,0 +1,24 @@
+ class ObjectInfo:
+     """
+     Store meta information for an object
+     """
+     def __init__(self, id: int):
+         self.id = id
+         self.poke_count = 0  # count number of detections missed
+ 
+     def poke(self) -> None:
+         self.poke_count += 1
+ 
+     def unpoke(self) -> None:
+         self.poke_count = 0
+ 
+     def __hash__(self):
+         return hash(self.id)
+ 
+     def __eq__(self, other):
+         if type(other) == int:
+             return self.id == other
+         return self.id == other.id
+ 
+     def __repr__(self):
+         return f'(ID: {self.id})'
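A quick sketch of why the custom `__eq__`/`__hash__` matter: they let plain integer ids index structures keyed by `ObjectInfo`. The dotted import path assumes the repository root is importable.

```python
from preprocessing.matanyone.matanyone.inference.object_info import ObjectInfo

info = ObjectInfo(id=3)
info.poke()
info.poke()
print(info.poke_count)            # 2 consecutive missed detections
print(info == 3)                  # True: compares equal to its raw id
print(3 in {info: 'tmp id 1'})    # True: hashing falls back to the id as well
```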
preprocessing/matanyone/matanyone/inference/object_manager.py ADDED
@@ -0,0 +1,149 @@
1
+ from typing import Union, List, Dict
2
+
3
+ import torch
4
+ from .object_info import ObjectInfo
5
+
6
+
7
+ class ObjectManager:
8
+ """
9
+ Object IDs are immutable. The same ID always represent the same object.
10
+ Temporary IDs are the positions of each object in the tensor. It changes as objects get removed.
11
+ Temporary IDs start from 1.
12
+ """
13
+
14
+ def __init__(self):
15
+ self.obj_to_tmp_id: Dict[ObjectInfo, int] = {}
16
+ self.tmp_id_to_obj: Dict[int, ObjectInfo] = {}
17
+ self.obj_id_to_obj: Dict[int, ObjectInfo] = {}
18
+
19
+ self.all_historical_object_ids: List[int] = []
20
+
21
+ def _recompute_obj_id_to_obj_mapping(self) -> None:
22
+ self.obj_id_to_obj = {obj.id: obj for obj in self.obj_to_tmp_id}
23
+
24
+ def add_new_objects(
25
+ self, objects: Union[List[ObjectInfo], ObjectInfo,
26
+ List[int]]) -> (List[int], List[int]):
27
+ if not isinstance(objects, list):
28
+ objects = [objects]
29
+
30
+ corresponding_tmp_ids = []
31
+ corresponding_obj_ids = []
32
+ for obj in objects:
33
+ if isinstance(obj, int):
34
+ obj = ObjectInfo(id=obj)
35
+
36
+ if obj in self.obj_to_tmp_id:
37
+ # old object
38
+ corresponding_tmp_ids.append(self.obj_to_tmp_id[obj])
39
+ corresponding_obj_ids.append(obj.id)
40
+ else:
41
+ # new object
42
+ new_obj = ObjectInfo(id=obj.id)
43
+
44
+ # new object
45
+ new_tmp_id = len(self.obj_to_tmp_id) + 1
46
+ self.obj_to_tmp_id[new_obj] = new_tmp_id
47
+ self.tmp_id_to_obj[new_tmp_id] = new_obj
48
+ self.all_historical_object_ids.append(new_obj.id)
49
+ corresponding_tmp_ids.append(new_tmp_id)
50
+ corresponding_obj_ids.append(new_obj.id)
51
+
52
+ self._recompute_obj_id_to_obj_mapping()
53
+ assert corresponding_tmp_ids == sorted(corresponding_tmp_ids)
54
+ return corresponding_tmp_ids, corresponding_obj_ids
55
+
56
+ def delete_objects(self, obj_ids_to_remove: Union[int, List[int]]) -> None:
57
+ # delete an object or a list of objects
58
+ # re-sort the tmp ids
59
+ if isinstance(obj_ids_to_remove, int):
60
+ obj_ids_to_remove = [obj_ids_to_remove]
61
+
62
+ new_tmp_id = 1
63
+ total_num_id = len(self.obj_to_tmp_id)
64
+
65
+ local_obj_to_tmp_id = {}
66
+ local_tmp_to_obj_id = {}
67
+
68
+ for tmp_iter in range(1, total_num_id + 1):
69
+ obj = self.tmp_id_to_obj[tmp_iter]
70
+ if obj.id not in obj_ids_to_remove:
71
+ local_obj_to_tmp_id[obj] = new_tmp_id
72
+ local_tmp_to_obj_id[new_tmp_id] = obj
73
+ new_tmp_id += 1
74
+
75
+ self.obj_to_tmp_id = local_obj_to_tmp_id
76
+ self.tmp_id_to_obj = local_tmp_to_obj_id
77
+ self._recompute_obj_id_to_obj_mapping()
78
+
79
+ def purge_inactive_objects(self,
80
+ max_missed_detection_count: int) -> (bool, List[int], List[int]):
81
+ # remove tmp ids of objects that are removed
82
+ obj_id_to_be_deleted = []
83
+ tmp_id_to_be_deleted = []
84
+ tmp_id_to_keep = []
85
+ obj_id_to_keep = []
86
+
87
+ for obj in self.obj_to_tmp_id:
88
+ if obj.poke_count > max_missed_detection_count:
89
+ obj_id_to_be_deleted.append(obj.id)
90
+ tmp_id_to_be_deleted.append(self.obj_to_tmp_id[obj])
91
+ else:
92
+ tmp_id_to_keep.append(self.obj_to_tmp_id[obj])
93
+ obj_id_to_keep.append(obj.id)
94
+
95
+ purge_activated = len(obj_id_to_be_deleted) > 0
96
+ if purge_activated:
97
+ self.delete_objects(obj_id_to_be_deleted)
98
+ return purge_activated, tmp_id_to_keep, obj_id_to_keep
99
+
100
+ def tmp_to_obj_cls(self, mask) -> torch.Tensor:
101
+ # remap tmp id cls representation to the true object id representation
102
+ new_mask = torch.zeros_like(mask)
103
+ for tmp_id, obj in self.tmp_id_to_obj.items():
104
+ new_mask[mask == tmp_id] = obj.id
105
+ return new_mask
106
+
107
+ def get_tmp_to_obj_mapping(self) -> Dict[int, ObjectInfo]:
108
+ # returns the mapping in a dict format for saving it with pickle
109
+ return {tmp_id: obj for tmp_id, obj in self.tmp_id_to_obj.items()}
110
+
111
+ def realize_dict(self, obj_dict, dim=1) -> torch.Tensor:
112
+ # turns a dict indexed by obj id into a tensor, ordered by tmp IDs
113
+ output = []
114
+ for _, obj in self.tmp_id_to_obj.items():
115
+ if obj.id not in obj_dict:
116
+ raise NotImplementedError
117
+ output.append(obj_dict[obj.id])
118
+ output = torch.stack(output, dim=dim)
119
+ return output
120
+
121
+ def make_one_hot(self, cls_mask) -> torch.Tensor:
122
+ output = []
123
+ for _, obj in self.tmp_id_to_obj.items():
124
+ output.append(cls_mask == obj.id)
125
+ if len(output) == 0:
126
+ output = torch.zeros((0, *cls_mask.shape), dtype=torch.bool, device=cls_mask.device)
127
+ else:
128
+ output = torch.stack(output, dim=0)
129
+ return output
130
+
131
+ @property
132
+ def all_obj_ids(self) -> List[int]:
133
+ return [k.id for k in self.obj_to_tmp_id]
134
+
135
+ @property
136
+ def num_obj(self) -> int:
137
+ return len(self.obj_to_tmp_id)
138
+
139
+ def has_all(self, objects: List[int]) -> bool:
140
+ for obj in objects:
141
+ if obj not in self.obj_to_tmp_id:
142
+ return False
143
+ return True
144
+
145
+ def find_object_by_id(self, obj_id) -> ObjectInfo:
146
+ return self.obj_id_to_obj[obj_id]
147
+
148
+ def find_tmp_by_id(self, obj_id) -> int:
149
+ return self.obj_to_tmp_id[self.obj_id_to_obj[obj_id]]
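A minimal sketch of the object-id vs. temporary-id bookkeeping above; the import path is assumed as before.

```python
from preprocessing.matanyone.matanyone.inference.object_manager import ObjectManager

om = ObjectManager()
tmp_ids, obj_ids = om.add_new_objects([7, 11, 42])
print(tmp_ids, obj_ids)           # [1, 2, 3] [7, 11, 42]

om.delete_objects(11)
print(om.all_obj_ids)             # [7, 42]
print(om.find_tmp_by_id(42))      # 2 -- temporary ids are re-packed after a deletion
```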
preprocessing/matanyone/matanyone/inference/utils/__init__.py ADDED
File without changes
preprocessing/matanyone/matanyone/inference/utils/args_utils.py ADDED
@@ -0,0 +1,30 @@
+ import logging
+ from omegaconf import DictConfig
+ 
+ log = logging.getLogger()
+ 
+ 
+ def get_dataset_cfg(cfg: DictConfig):
+     dataset_name = cfg.dataset
+     data_cfg = cfg.datasets[dataset_name]
+ 
+     potential_overrides = [
+         'image_directory',
+         'mask_directory',
+         'json_directory',
+         'size',
+         'save_all',
+         'use_all_masks',
+         'use_long_term',
+         'mem_every',
+     ]
+ 
+     for override in potential_overrides:
+         if cfg[override] is not None:
+             log.info(f'Overriding config {override} from {data_cfg[override]} to {cfg[override]}')
+             data_cfg[override] = cfg[override]
+         # escalate all potential overrides to the top-level config
+         if override in data_cfg:
+             cfg[override] = data_cfg[override]
+ 
+     return data_cfg
preprocessing/matanyone/matanyone/model/__init__.py ADDED
File without changes
preprocessing/matanyone/matanyone/model/aux_modules.py ADDED
@@ -0,0 +1,93 @@
1
+ """
2
+ For computing auxiliary outputs for auxiliary losses
3
+ """
4
+ from typing import Dict
5
+ from omegaconf import DictConfig
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from .group_modules import GConv2d
10
+ from ...utils.tensor_utils import aggregate
11
+
12
+
13
+ class LinearPredictor(nn.Module):
14
+ def __init__(self, x_dim: int, pix_dim: int):
15
+ super().__init__()
16
+ self.projection = GConv2d(x_dim, pix_dim + 1, kernel_size=1)
17
+
18
+ def forward(self, pix_feat: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
19
+ # pixel_feat: B*pix_dim*H*W
20
+ # x: B*num_objects*x_dim*H*W
21
+ num_objects = x.shape[1]
22
+ x = self.projection(x)
23
+
24
+ pix_feat = pix_feat.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
25
+ logits = (pix_feat * x[:, :, :-1]).sum(dim=2) + x[:, :, -1]
26
+ return logits
27
+
28
+
29
+ class DirectPredictor(nn.Module):
30
+ def __init__(self, x_dim: int):
31
+ super().__init__()
32
+ self.projection = GConv2d(x_dim, 1, kernel_size=1)
33
+
34
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
35
+ # x: B*num_objects*x_dim*H*W
36
+ logits = self.projection(x).squeeze(2)
37
+ return logits
38
+
39
+
40
+ class AuxComputer(nn.Module):
41
+ def __init__(self, cfg: DictConfig):
42
+ super().__init__()
43
+
44
+ use_sensory_aux = cfg.model.aux_loss.sensory.enabled
45
+ self.use_query_aux = cfg.model.aux_loss.query.enabled
46
+ self.use_sensory_aux = use_sensory_aux
47
+
48
+ sensory_dim = cfg.model.sensory_dim
49
+ embed_dim = cfg.model.embed_dim
50
+
51
+ if use_sensory_aux:
52
+ self.sensory_aux = LinearPredictor(sensory_dim, embed_dim)
53
+
54
+ def _aggregate_with_selector(self, logits: torch.Tensor, selector: torch.Tensor) -> torch.Tensor:
55
+ prob = torch.sigmoid(logits)
56
+ if selector is not None:
57
+ prob = prob * selector
58
+ logits = aggregate(prob, dim=1)
59
+ return logits
60
+
61
+ def forward(self, pix_feat: torch.Tensor, aux_input: Dict[str, torch.Tensor],
62
+ selector: torch.Tensor, seg_pass=False) -> Dict[str, torch.Tensor]:
63
+ sensory = aux_input['sensory']
64
+ q_logits = aux_input['q_logits']
65
+
66
+ aux_output = {}
67
+ aux_output['attn_mask'] = aux_input['attn_mask']
68
+
69
+ if self.use_sensory_aux:
70
+ # B*num_objects*H*W
71
+ logits = self.sensory_aux(pix_feat, sensory)
72
+ aux_output['sensory_logits'] = self._aggregate_with_selector(logits, selector)
73
+ if self.use_query_aux:
74
+ # B*num_objects*num_levels*H*W
75
+ aux_output['q_logits'] = self._aggregate_with_selector(
76
+ torch.stack(q_logits, dim=2),
77
+ selector.unsqueeze(2) if selector is not None else None)
78
+
79
+ return aux_output
80
+
81
+ def compute_mask(self, aux_input: Dict[str, torch.Tensor],
82
+ selector: torch.Tensor) -> Dict[str, torch.Tensor]:
83
+ # sensory = aux_input['sensory']
84
+ q_logits = aux_input['q_logits']
85
+
86
+ aux_output = {}
87
+
88
+ # B*num_objects*num_levels*H*W
89
+ aux_output['q_logits'] = self._aggregate_with_selector(
90
+ torch.stack(q_logits, dim=2),
91
+ selector.unsqueeze(2) if selector is not None else None)
92
+
93
+ return aux_output
preprocessing/matanyone/matanyone/model/big_modules.py ADDED
@@ -0,0 +1,365 @@
1
+ """
2
+ big_modules.py - This file stores higher-level network blocks.
3
+
4
+ x - usually denotes features that are shared between objects.
5
+ g - usually denotes features that are not shared between objects
6
+ with an extra "num_objects" dimension (batch_size * num_objects * num_channels * H * W).
7
+
8
+ The trailing number of a variable usually denotes the stride
9
+ """
10
+
11
+ from typing import Iterable
12
+ from omegaconf import DictConfig
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
+ from .group_modules import MainToGroupDistributor, GroupFeatureFusionBlock, GConv2d
18
+ from .utils import resnet
19
+ from .modules import SensoryDeepUpdater, SensoryUpdater_fullscale, DecoderFeatureProcessor, MaskUpsampleBlock
20
+
21
+ class UncertPred(nn.Module):
22
+ def __init__(self, model_cfg: DictConfig):
23
+ super().__init__()
24
+ self.conv1x1_v2 = nn.Conv2d(model_cfg.pixel_dim*2 + 1 + model_cfg.value_dim, 64, kernel_size=1, stride=1, bias=False)
25
+ self.bn1 = nn.BatchNorm2d(64)
26
+ self.relu = nn.ReLU(inplace=True)
27
+ self.conv3x3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1)
28
+ self.bn2 = nn.BatchNorm2d(32)
29
+ self.conv3x3_out = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1)
30
+
31
+ def forward(self, last_frame_feat: torch.Tensor, cur_frame_feat: torch.Tensor, last_mask: torch.Tensor, mem_val_diff:torch.Tensor):
32
+ last_mask = F.interpolate(last_mask, size=last_frame_feat.shape[-2:], mode='area')
33
+ x = torch.cat([last_frame_feat, cur_frame_feat, last_mask, mem_val_diff], dim=1)
34
+ x = self.conv1x1_v2(x)
35
+ x = self.bn1(x)
36
+ x = self.relu(x)
37
+ x = self.conv3x3(x)
38
+ x = self.bn2(x)
39
+ x = self.relu(x)
40
+ x = self.conv3x3_out(x)
41
+ return x
42
+
43
+ # override the default train() to freeze BN statistics
44
+ def train(self, mode=True):
45
+ self.training = False
46
+ for module in self.children():
47
+ module.train(False)
48
+ return self
49
+
50
+ class PixelEncoder(nn.Module):
51
+ def __init__(self, model_cfg: DictConfig):
52
+ super().__init__()
53
+
54
+ self.is_resnet = 'resnet' in model_cfg.pixel_encoder.type
55
+ # if model_cfg.pretrained_resnet is set in the model_cfg we get the value
56
+ # else default to True
57
+ is_pretrained_resnet = getattr(model_cfg,"pretrained_resnet",True)
58
+ if self.is_resnet:
59
+ if model_cfg.pixel_encoder.type == 'resnet18':
60
+ network = resnet.resnet18(pretrained=is_pretrained_resnet)
61
+ elif model_cfg.pixel_encoder.type == 'resnet50':
62
+ network = resnet.resnet50(pretrained=is_pretrained_resnet)
63
+ else:
64
+ raise NotImplementedError
65
+ self.conv1 = network.conv1
66
+ self.bn1 = network.bn1
67
+ self.relu = network.relu
68
+ self.maxpool = network.maxpool
69
+
70
+ self.res2 = network.layer1
71
+ self.layer2 = network.layer2
72
+ self.layer3 = network.layer3
73
+ else:
74
+ raise NotImplementedError
75
+
76
+ def forward(self, x: torch.Tensor, seq_length=None) -> (torch.Tensor, torch.Tensor, torch.Tensor):
77
+ f1 = x
78
+ x = self.conv1(x)
79
+ x = self.bn1(x)
80
+ x = self.relu(x)
81
+ f2 = x
82
+ x = self.maxpool(x)
83
+ f4 = self.res2(x)
84
+ f8 = self.layer2(f4)
85
+ f16 = self.layer3(f8)
86
+
87
+ return f16, f8, f4, f2, f1
88
+
89
+ # override the default train() to freeze BN statistics
90
+ def train(self, mode=True):
91
+ self.training = False
92
+ for module in self.children():
93
+ module.train(False)
94
+ return self
95
+
96
+
97
+ class KeyProjection(nn.Module):
98
+ def __init__(self, model_cfg: DictConfig):
99
+ super().__init__()
100
+ in_dim = model_cfg.pixel_encoder.ms_dims[0]
101
+ mid_dim = model_cfg.pixel_dim
102
+ key_dim = model_cfg.key_dim
103
+
104
+ self.pix_feat_proj = nn.Conv2d(in_dim, mid_dim, kernel_size=1)
105
+ self.key_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1)
106
+ # shrinkage
107
+ self.d_proj = nn.Conv2d(mid_dim, 1, kernel_size=3, padding=1)
108
+ # selection
109
+ self.e_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1)
110
+
111
+ nn.init.orthogonal_(self.key_proj.weight.data)
112
+ nn.init.zeros_(self.key_proj.bias.data)
113
+
114
+ def forward(self, x: torch.Tensor, *, need_s: bool,
115
+ need_e: bool) -> (torch.Tensor, torch.Tensor, torch.Tensor):
116
+ x = self.pix_feat_proj(x)
117
+ shrinkage = self.d_proj(x)**2 + 1 if (need_s) else None
118
+ selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None
119
+
120
+ return self.key_proj(x), shrinkage, selection
121
+
122
+
123
+ class MaskEncoder(nn.Module):
124
+ def __init__(self, model_cfg: DictConfig, single_object=False):
125
+ super().__init__()
126
+ pixel_dim = model_cfg.pixel_dim
127
+ value_dim = model_cfg.value_dim
128
+ sensory_dim = model_cfg.sensory_dim
129
+ final_dim = model_cfg.mask_encoder.final_dim
130
+
131
+ self.single_object = single_object
132
+ extra_dim = 1 if single_object else 2
133
+
134
+ # if model_cfg.pretrained_resnet is set in the model_cfg we get the value
135
+ # else default to True
136
+ is_pretrained_resnet = getattr(model_cfg,"pretrained_resnet",True)
137
+ if model_cfg.mask_encoder.type == 'resnet18':
138
+ network = resnet.resnet18(pretrained=is_pretrained_resnet, extra_dim=extra_dim)
139
+ elif model_cfg.mask_encoder.type == 'resnet50':
140
+ network = resnet.resnet50(pretrained=is_pretrained_resnet, extra_dim=extra_dim)
141
+ else:
142
+ raise NotImplementedError
143
+ self.conv1 = network.conv1
144
+ self.bn1 = network.bn1
145
+ self.relu = network.relu
146
+ self.maxpool = network.maxpool
147
+
148
+ self.layer1 = network.layer1
149
+ self.layer2 = network.layer2
150
+ self.layer3 = network.layer3
151
+
152
+ self.distributor = MainToGroupDistributor()
153
+ self.fuser = GroupFeatureFusionBlock(pixel_dim, final_dim, value_dim)
154
+
155
+ self.sensory_update = SensoryDeepUpdater(value_dim, sensory_dim)
156
+
157
+ def forward(self,
158
+ image: torch.Tensor,
159
+ pix_feat: torch.Tensor,
160
+ sensory: torch.Tensor,
161
+ masks: torch.Tensor,
162
+ others: torch.Tensor,
163
+ *,
164
+ deep_update: bool = True,
165
+ chunk_size: int = -1) -> (torch.Tensor, torch.Tensor):
166
+ # ms_features are from the key encoder
167
+ # we only use the first one (lowest resolution), following XMem
168
+ if self.single_object:
169
+ g = masks.unsqueeze(2)
170
+ else:
171
+ g = torch.stack([masks, others], dim=2)
172
+
173
+ g = self.distributor(image, g)
174
+
175
+ batch_size, num_objects = g.shape[:2]
176
+ if chunk_size < 1 or chunk_size >= num_objects:
177
+ chunk_size = num_objects
178
+ fast_path = True
179
+ new_sensory = sensory
180
+ else:
181
+ if deep_update:
182
+ new_sensory = torch.empty_like(sensory)
183
+ else:
184
+ new_sensory = sensory
185
+ fast_path = False
186
+
187
+ # chunk-by-chunk inference
188
+ all_g = []
189
+ for i in range(0, num_objects, chunk_size):
190
+ if fast_path:
191
+ g_chunk = g
192
+ else:
193
+ g_chunk = g[:, i:i + chunk_size]
194
+ actual_chunk_size = g_chunk.shape[1]
195
+ g_chunk = g_chunk.flatten(start_dim=0, end_dim=1)
196
+
197
+ g_chunk = self.conv1(g_chunk)
198
+ g_chunk = self.bn1(g_chunk) # 1/2, 64
199
+ g_chunk = self.maxpool(g_chunk) # 1/4, 64
200
+ g_chunk = self.relu(g_chunk)
201
+
202
+ g_chunk = self.layer1(g_chunk) # 1/4
203
+ g_chunk = self.layer2(g_chunk) # 1/8
204
+ g_chunk = self.layer3(g_chunk) # 1/16
205
+
206
+ g_chunk = g_chunk.view(batch_size, actual_chunk_size, *g_chunk.shape[1:])
207
+ g_chunk = self.fuser(pix_feat, g_chunk)
208
+ all_g.append(g_chunk)
209
+ if deep_update:
210
+ if fast_path:
211
+ new_sensory = self.sensory_update(g_chunk, sensory)
212
+ else:
213
+ new_sensory[:, i:i + chunk_size] = self.sensory_update(
214
+ g_chunk, sensory[:, i:i + chunk_size])
215
+ g = torch.cat(all_g, dim=1)
216
+
217
+ return g, new_sensory
218
+
219
+ # override the default train() to freeze BN statistics
220
+ def train(self, mode=True):
221
+ self.training = False
222
+ for module in self.children():
223
+ module.train(False)
224
+ return self
225
+
226
+
227
+ class PixelFeatureFuser(nn.Module):
228
+ def __init__(self, model_cfg: DictConfig, single_object=False):
229
+ super().__init__()
230
+ value_dim = model_cfg.value_dim
231
+ sensory_dim = model_cfg.sensory_dim
232
+ pixel_dim = model_cfg.pixel_dim
233
+ embed_dim = model_cfg.embed_dim
234
+ self.single_object = single_object
235
+
236
+ self.fuser = GroupFeatureFusionBlock(pixel_dim, value_dim, embed_dim)
237
+ if self.single_object:
238
+ self.sensory_compress = GConv2d(sensory_dim + 1, value_dim, kernel_size=1)
239
+ else:
240
+ self.sensory_compress = GConv2d(sensory_dim + 2, value_dim, kernel_size=1)
241
+
242
+ def forward(self,
243
+ pix_feat: torch.Tensor,
244
+ pixel_memory: torch.Tensor,
245
+ sensory_memory: torch.Tensor,
246
+ last_mask: torch.Tensor,
247
+ last_others: torch.Tensor,
248
+ *,
249
+ chunk_size: int = -1) -> torch.Tensor:
250
+ batch_size, num_objects = pixel_memory.shape[:2]
251
+
252
+ if self.single_object:
253
+ last_mask = last_mask.unsqueeze(2)
254
+ else:
255
+ last_mask = torch.stack([last_mask, last_others], dim=2)
256
+
257
+ if chunk_size < 1:
258
+ chunk_size = num_objects
259
+
260
+ # chunk-by-chunk inference
261
+ all_p16 = []
262
+ for i in range(0, num_objects, chunk_size):
263
+ sensory_readout = self.sensory_compress(
264
+ torch.cat([sensory_memory[:, i:i + chunk_size], last_mask[:, i:i + chunk_size]], 2))
265
+ p16 = pixel_memory[:, i:i + chunk_size] + sensory_readout
266
+ p16 = self.fuser(pix_feat, p16)
267
+ all_p16.append(p16)
268
+ p16 = torch.cat(all_p16, dim=1)
269
+
270
+ return p16
271
+
272
+
273
+ class MaskDecoder(nn.Module):
274
+ def __init__(self, model_cfg: DictConfig):
275
+ super().__init__()
276
+ embed_dim = model_cfg.embed_dim
277
+ sensory_dim = model_cfg.sensory_dim
278
+ ms_image_dims = model_cfg.pixel_encoder.ms_dims
279
+ up_dims = model_cfg.mask_decoder.up_dims
280
+
281
+ assert embed_dim == up_dims[0]
282
+
283
+ self.sensory_update = SensoryUpdater_fullscale([up_dims[0], up_dims[1], up_dims[2], up_dims[3], up_dims[4] + 1], sensory_dim,
284
+ sensory_dim)
285
+
286
+ self.decoder_feat_proc = DecoderFeatureProcessor(ms_image_dims[1:], up_dims[:-1])
287
+ self.up_16_8 = MaskUpsampleBlock(up_dims[0], up_dims[1])
288
+ self.up_8_4 = MaskUpsampleBlock(up_dims[1], up_dims[2])
289
+ # newly add for alpha matte
290
+ self.up_4_2 = MaskUpsampleBlock(up_dims[2], up_dims[3])
291
+ self.up_2_1 = MaskUpsampleBlock(up_dims[3], up_dims[4])
292
+
293
+ self.pred_seg = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1)
294
+ self.pred_mat = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1)
295
+
296
+ def forward(self,
297
+ ms_image_feat: Iterable[torch.Tensor],
298
+ memory_readout: torch.Tensor,
299
+ sensory: torch.Tensor,
300
+ *,
301
+ chunk_size: int = -1,
302
+ update_sensory: bool = True,
303
+ seg_pass: bool = False,
304
+ last_mask=None,
305
+ sigmoid_residual=False) -> (torch.Tensor, torch.Tensor):
306
+
307
+ batch_size, num_objects = memory_readout.shape[:2]
308
+ f8, f4, f2, f1 = self.decoder_feat_proc(ms_image_feat[1:])
309
+ if chunk_size < 1 or chunk_size >= num_objects:
310
+ chunk_size = num_objects
311
+ fast_path = True
312
+ new_sensory = sensory
313
+ else:
314
+ if update_sensory:
315
+ new_sensory = torch.empty_like(sensory)
316
+ else:
317
+ new_sensory = sensory
318
+ fast_path = False
319
+
320
+ # chunk-by-chunk inference
321
+ all_logits = []
322
+ for i in range(0, num_objects, chunk_size):
323
+ if fast_path:
324
+ p16 = memory_readout
325
+ else:
326
+ p16 = memory_readout[:, i:i + chunk_size]
327
+ actual_chunk_size = p16.shape[1]
328
+
329
+ p8 = self.up_16_8(p16, f8)
330
+ p4 = self.up_8_4(p8, f4)
331
+ p2 = self.up_4_2(p4, f2)
332
+ p1 = self.up_2_1(p2, f1)
333
+ with torch.amp.autocast("cuda"):
334
+ if seg_pass:
335
+ if last_mask is not None:
336
+ res = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
337
+ if sigmoid_residual:
338
+ res = (torch.sigmoid(res) - 0.5) * 2 # regularization: (-1, 1) change on last mask
339
+ logits = last_mask + res
340
+ else:
341
+ logits = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
342
+ else:
343
+ if last_mask is not None:
344
+ res = self.pred_mat(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
345
+ if sigmoid_residual:
346
+ res = (torch.sigmoid(res) - 0.5) * 2 # regularization: (-1, 1) change on last mask
347
+ logits = last_mask + res
348
+ else:
349
+ logits = self.pred_mat(F.relu(p1.flatten(start_dim=0, end_dim=1).float()))
350
+ ## SensoryUpdater_fullscale
351
+ if update_sensory:
352
+ p1 = torch.cat(
353
+ [p1, logits.view(batch_size, actual_chunk_size, 1, *logits.shape[-2:])], 2)
354
+ if fast_path:
355
+ new_sensory = self.sensory_update([p16, p8, p4, p2, p1], sensory)
356
+ else:
357
+ new_sensory[:,
358
+ i:i + chunk_size] = self.sensory_update([p16, p8, p4, p2, p1],
359
+ sensory[:,
360
+ i:i + chunk_size])
361
+ all_logits.append(logits)
362
+ logits = torch.cat(all_logits, dim=0)
363
+ logits = logits.view(batch_size, num_objects, *logits.shape[-2:])
364
+
365
+ return new_sensory, logits
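The chunk-by-chunk loops in MaskEncoder, PixelFeatureFuser and MaskDecoder all follow the same shape; below is a generic, self-contained sketch of that pattern (the helper name is hypothetical, not part of this commit).

```python
import torch

def apply_chunked(g: torch.Tensor, fn, chunk_size: int = -1) -> torch.Tensor:
    # process at most chunk_size objects at a time to bound peak activation memory
    num_objects = g.shape[1]
    if chunk_size < 1 or chunk_size >= num_objects:
        return fn(g)                            # fast path: all objects at once
    outs = [fn(g[:, i:i + chunk_size]) for i in range(0, num_objects, chunk_size)]
    return torch.cat(outs, dim=1)

g = torch.randn(1, 5, 16, 8, 8)                 # batch * num_objects * C * H * W
print(torch.equal(apply_chunked(g, lambda t: t * 2, chunk_size=2), g * 2))   # True
```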
preprocessing/matanyone/matanyone/model/channel_attn.py ADDED
@@ -0,0 +1,39 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class CAResBlock(nn.Module):
8
+ def __init__(self, in_dim: int, out_dim: int, residual: bool = True):
9
+ super().__init__()
10
+ self.residual = residual
11
+ self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
12
+ self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1)
13
+
14
+ t = int((abs(math.log2(out_dim)) + 1) // 2)
15
+ k = t if t % 2 else t + 1
16
+ self.pool = nn.AdaptiveAvgPool2d(1)
17
+ self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
18
+
19
+ if self.residual:
20
+ if in_dim == out_dim:
21
+ self.downsample = nn.Identity()
22
+ else:
23
+ self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1)
24
+
25
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
26
+ r = x
27
+ x = self.conv1(F.relu(x))
28
+ x = self.conv2(F.relu(x))
29
+
30
+ b, c = x.shape[:2]
31
+ w = self.pool(x).view(b, 1, c)
32
+ w = self.conv(w).transpose(-1, -2).unsqueeze(-1).sigmoid() # B*C*1*1
33
+
34
+ if self.residual:
35
+ x = x * w + self.downsample(r)
36
+ else:
37
+ x = x * w
38
+
39
+ return x
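
CAResBlock gates its output with an ECA-style channel attention: global average pooling down to one scalar per channel, a cheap 1-D convolution across channels, then a sigmoid gate. A standalone sketch of just that gating step (shapes picked for illustration; k=5 is what the log2 rule in __init__ yields for 256 channels):

import torch
import torch.nn as nn

b, c, h, w = 2, 256, 14, 14
x = torch.randn(b, c, h, w)

k = 5                                             # odd kernel from the log2(c) rule above
pool = nn.AdaptiveAvgPool2d(1)
conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)

gate = pool(x).view(b, 1, c)                      # B x 1 x C: one scalar per channel
gate = conv(gate).transpose(-1, -2).unsqueeze(-1).sigmoid()   # B x C x 1 x 1
out = x * gate                                    # channel-reweighted features
print(out.shape)                                  # torch.Size([2, 256, 14, 14])
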
preprocessing/matanyone/matanyone/model/group_modules.py ADDED
@@ -0,0 +1,126 @@
1
+ from typing import Optional
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from .channel_attn import CAResBlock
6
+
7
+ def interpolate_groups(g: torch.Tensor, ratio: float, mode: str,
8
+ align_corners: bool) -> torch.Tensor:
9
+ batch_size, num_objects = g.shape[:2]
10
+ g = F.interpolate(g.flatten(start_dim=0, end_dim=1),
11
+ scale_factor=ratio,
12
+ mode=mode,
13
+ align_corners=align_corners)
14
+ g = g.view(batch_size, num_objects, *g.shape[1:])
15
+ return g
16
+
17
+
18
+ def upsample_groups(g: torch.Tensor,
19
+ ratio: float = 2,
20
+ mode: str = 'bilinear',
21
+ align_corners: bool = False) -> torch.Tensor:
22
+ return interpolate_groups(g, ratio, mode, align_corners)
23
+
24
+
25
+ def downsample_groups(g: torch.Tensor,
26
+ ratio: float = 1 / 2,
27
+ mode: str = 'area',
28
+ align_corners: bool = None) -> torch.Tensor:
29
+ return interpolate_groups(g, ratio, mode, align_corners)
30
+
31
+
32
+ class GConv2d(nn.Conv2d):
33
+ def forward(self, g: torch.Tensor) -> torch.Tensor:
34
+ batch_size, num_objects = g.shape[:2]
35
+ g = super().forward(g.flatten(start_dim=0, end_dim=1))
36
+ return g.view(batch_size, num_objects, *g.shape[1:])
37
+
38
+
39
+ class GroupResBlock(nn.Module):
40
+ def __init__(self, in_dim: int, out_dim: int):
41
+ super().__init__()
42
+
43
+ if in_dim == out_dim:
44
+ self.downsample = nn.Identity()
45
+ else:
46
+ self.downsample = GConv2d(in_dim, out_dim, kernel_size=1)
47
+
48
+ self.conv1 = GConv2d(in_dim, out_dim, kernel_size=3, padding=1)
49
+ self.conv2 = GConv2d(out_dim, out_dim, kernel_size=3, padding=1)
50
+
51
+ def forward(self, g: torch.Tensor) -> torch.Tensor:
52
+ out_g = self.conv1(F.relu(g))
53
+ out_g = self.conv2(F.relu(out_g))
54
+
55
+ g = self.downsample(g)
56
+
57
+ return out_g + g
58
+
59
+
60
+ class MainToGroupDistributor(nn.Module):
61
+ def __init__(self,
62
+ x_transform: Optional[nn.Module] = None,
63
+ g_transform: Optional[nn.Module] = None,
64
+ method: str = 'cat',
65
+ reverse_order: bool = False):
66
+ super().__init__()
67
+
68
+ self.x_transform = x_transform
69
+ self.g_transform = g_transform
70
+ self.method = method
71
+ self.reverse_order = reverse_order
72
+
73
+ def forward(self, x: torch.Tensor, g: torch.Tensor, skip_expand: bool = False) -> torch.Tensor:
74
+ num_objects = g.shape[1]
75
+
76
+ if self.x_transform is not None:
77
+ x = self.x_transform(x)
78
+
79
+ if self.g_transform is not None:
80
+ g = self.g_transform(g)
81
+
82
+ if not skip_expand:
83
+ x = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
84
+ if self.method == 'cat':
85
+ if self.reverse_order:
86
+ g = torch.cat([g, x], 2)
87
+ else:
88
+ g = torch.cat([x, g], 2)
89
+ elif self.method == 'add':
90
+ g = x + g
91
+ elif self.method == 'mulcat':
92
+ g = torch.cat([x * g, g], dim=2)
93
+ elif self.method == 'muladd':
94
+ g = x * g + g
95
+ else:
96
+ raise NotImplementedError
97
+
98
+ return g
99
+
100
+
101
+ class GroupFeatureFusionBlock(nn.Module):
102
+ def __init__(self, x_in_dim: int, g_in_dim: int, out_dim: int):
103
+ super().__init__()
104
+
105
+ x_transform = nn.Conv2d(x_in_dim, out_dim, kernel_size=1)
106
+ g_transform = GConv2d(g_in_dim, out_dim, kernel_size=1)
107
+
108
+ self.distributor = MainToGroupDistributor(x_transform=x_transform,
109
+ g_transform=g_transform,
110
+ method='add')
111
+ self.block1 = CAResBlock(out_dim, out_dim)
112
+ self.block2 = CAResBlock(out_dim, out_dim)
113
+
114
+ def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
115
+ batch_size, num_objects = g.shape[:2]
116
+
117
+ g = self.distributor(x, g)
118
+
119
+ g = g.flatten(start_dim=0, end_dim=1)
120
+
121
+ g = self.block1(g)
122
+ g = self.block2(g)
123
+
124
+ g = g.view(batch_size, num_objects, *g.shape[1:])
125
+
126
+ return g
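
MainToGroupDistributor is what broadcasts a single frame-level feature map to every object before fusion, which is where the 5-D group shapes used throughout these modules come from. A minimal sketch of the 'add' path used by GroupFeatureFusionBlock (random tensors, illustration only):

import torch

batch, num_objects, c, h, w = 1, 3, 64, 30, 54
x = torch.randn(batch, c, h, w)               # shared pixel features
g = torch.randn(batch, num_objects, c, h, w)  # per-object group features

# what the distributor does with method='add' and no transforms
x_expanded = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
fused = x_expanded + g
print(fused.shape)                            # torch.Size([1, 3, 64, 30, 54])
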
preprocessing/matanyone/matanyone/model/matanyone.py ADDED
@@ -0,0 +1,333 @@
1
+ from typing import List, Dict, Iterable
2
+ import logging
3
+ from omegaconf import DictConfig
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from omegaconf import OmegaConf
8
+ from huggingface_hub import PyTorchModelHubMixin
9
+
10
+ from .big_modules import PixelEncoder, UncertPred, KeyProjection, MaskEncoder, PixelFeatureFuser, MaskDecoder
11
+ from .aux_modules import AuxComputer
12
+ from .utils.memory_utils import get_affinity, readout
13
+ from .transformer.object_transformer import QueryTransformer
14
+ from .transformer.object_summarizer import ObjectSummarizer
15
+ from ...utils.tensor_utils import aggregate
16
+
17
+ log = logging.getLogger()
18
+ class MatAnyone(nn.Module,
19
+ PyTorchModelHubMixin,
20
+ library_name="matanyone",
21
+ repo_url="https://github.com/pq-yang/MatAnyone",
22
+ coders={
23
+ DictConfig: (
24
+ lambda x: OmegaConf.to_container(x),
25
+ lambda data: OmegaConf.create(data),
26
+ )
27
+ },
28
+ ):
29
+
30
+ def __init__(self, cfg: DictConfig, *, single_object=False):
31
+ super().__init__()
32
+ self.cfg = cfg
33
+ model_cfg = cfg.model
34
+ self.ms_dims = model_cfg.pixel_encoder.ms_dims
35
+ self.key_dim = model_cfg.key_dim
36
+ self.value_dim = model_cfg.value_dim
37
+ self.sensory_dim = model_cfg.sensory_dim
38
+ self.pixel_dim = model_cfg.pixel_dim
39
+ self.embed_dim = model_cfg.embed_dim
40
+ self.single_object = single_object
41
+
42
+ log.info(f'Single object: {self.single_object}')
43
+
44
+ self.pixel_encoder = PixelEncoder(model_cfg)
45
+ self.pix_feat_proj = nn.Conv2d(self.ms_dims[0], self.pixel_dim, kernel_size=1)
46
+ self.key_proj = KeyProjection(model_cfg)
47
+ self.mask_encoder = MaskEncoder(model_cfg, single_object=single_object)
48
+ self.mask_decoder = MaskDecoder(model_cfg)
49
+ self.pixel_fuser = PixelFeatureFuser(model_cfg, single_object=single_object)
50
+ self.object_transformer = QueryTransformer(model_cfg)
51
+ self.object_summarizer = ObjectSummarizer(model_cfg)
52
+ self.aux_computer = AuxComputer(cfg)
53
+ self.temp_sparity = UncertPred(model_cfg)
54
+
55
+ self.register_buffer("pixel_mean", torch.Tensor(model_cfg.pixel_mean).view(-1, 1, 1), False)
56
+ self.register_buffer("pixel_std", torch.Tensor(model_cfg.pixel_std).view(-1, 1, 1), False)
57
+
58
+ def _get_others(self, masks: torch.Tensor) -> torch.Tensor:
59
+ # for each object, return the sum of masks of all other objects
60
+ if self.single_object:
61
+ return None
62
+
63
+ num_objects = masks.shape[1]
64
+ if num_objects >= 1:
65
+ others = (masks.sum(dim=1, keepdim=True) - masks).clamp(0, 1)
66
+ else:
67
+ others = torch.zeros_like(masks)
68
+ return others
69
+
70
+ def pred_uncertainty(self, last_pix_feat: torch.Tensor, cur_pix_feat: torch.Tensor, last_mask: torch.Tensor, mem_val_diff:torch.Tensor):
71
+ logits = self.temp_sparity(last_frame_feat=last_pix_feat,
72
+ cur_frame_feat=cur_pix_feat,
73
+ last_mask=last_mask,
74
+ mem_val_diff=mem_val_diff)
75
+
76
+ prob = torch.sigmoid(logits)
77
+ mask = (prob > 0) + 0
78
+
79
+ uncert_output = {"logits": logits,
80
+ "prob": prob,
81
+ "mask": mask}
82
+
83
+ return uncert_output
84
+
85
+ def encode_image(self, image: torch.Tensor, seq_length=None, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor): # type: ignore
86
+ image = (image - self.pixel_mean) / self.pixel_std
87
+ ms_image_feat = self.pixel_encoder(image, seq_length) # f16, f8, f4, f2, f1
88
+ return ms_image_feat, self.pix_feat_proj(ms_image_feat[0])
89
+
90
+ def encode_mask(
91
+ self,
92
+ image: torch.Tensor,
93
+ ms_features: List[torch.Tensor],
94
+ sensory: torch.Tensor,
95
+ masks: torch.Tensor,
96
+ *,
97
+ deep_update: bool = True,
98
+ chunk_size: int = -1,
99
+ need_weights: bool = False) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
100
+ image = (image - self.pixel_mean) / self.pixel_std
101
+ others = self._get_others(masks)
102
+ mask_value, new_sensory = self.mask_encoder(image,
103
+ ms_features,
104
+ sensory,
105
+ masks,
106
+ others,
107
+ deep_update=deep_update,
108
+ chunk_size=chunk_size)
109
+ object_summaries, object_logits = self.object_summarizer(masks, mask_value, need_weights)
110
+ return mask_value, new_sensory, object_summaries, object_logits
111
+
112
+ def transform_key(self,
113
+ final_pix_feat: torch.Tensor,
114
+ *,
115
+ need_sk: bool = True,
116
+ need_ek: bool = True) -> (torch.Tensor, torch.Tensor, torch.Tensor):
117
+ key, shrinkage, selection = self.key_proj(final_pix_feat, need_s=need_sk, need_e=need_ek)
118
+ return key, shrinkage, selection
119
+
120
+ # Used in training only.
121
+ # This step is replaced by MemoryManager in test time
122
+ def read_memory(self, query_key: torch.Tensor, query_selection: torch.Tensor,
123
+ memory_key: torch.Tensor, memory_shrinkage: torch.Tensor,
124
+ msk_value: torch.Tensor, obj_memory: torch.Tensor, pix_feat: torch.Tensor,
125
+ sensory: torch.Tensor, last_mask: torch.Tensor,
126
+ selector: torch.Tensor, uncert_output=None, seg_pass=False,
127
+ last_pix_feat=None, last_pred_mask=None) -> (torch.Tensor, Dict[str, torch.Tensor]):
128
+ """
129
+ query_key : B * CK * H * W
130
+ query_selection : B * CK * H * W
131
+ memory_key : B * CK * T * H * W
132
+ memory_shrinkage: B * 1 * T * H * W
133
+ msk_value : B * num_objects * CV * T * H * W
134
+ obj_memory : B * num_objects * T * num_summaries * C
135
+ pixel_feature : B * C * H * W
136
+ """
137
+ batch_size, num_objects = msk_value.shape[:2]
138
+
139
+ uncert_mask = uncert_output["mask"] if uncert_output is not None else None
140
+
141
+ # read using visual attention
142
+ with torch.cuda.amp.autocast(enabled=False):
143
+ affinity = get_affinity(memory_key.float(), memory_shrinkage.float(), query_key.float(),
144
+ query_selection.float(), uncert_mask=uncert_mask)
145
+
146
+ msk_value = msk_value.flatten(start_dim=1, end_dim=2).float()
147
+
148
+ # B * (num_objects*CV) * H * W
149
+ pixel_readout = readout(affinity, msk_value, uncert_mask)
150
+ pixel_readout = pixel_readout.view(batch_size, num_objects, self.value_dim,
151
+ *pixel_readout.shape[-2:])
152
+
153
+ uncert_output = self.pred_uncertainty(last_pix_feat, pix_feat, last_pred_mask, pixel_readout[:,0]-msk_value[:,:,-1])
154
+ uncert_prob = uncert_output["prob"].unsqueeze(1) # b n 1 h w
155
+ pixel_readout = pixel_readout*uncert_prob + msk_value[:,:,-1].unsqueeze(1)*(1-uncert_prob)
156
+
157
+ pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask)
158
+
159
+
160
+ # read from query transformer
161
+ mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector, seg_pass=seg_pass)
162
+
163
+ aux_output = {
164
+ 'sensory': sensory,
165
+ 'q_logits': aux_features['logits'] if aux_features else None,
166
+ 'attn_mask': aux_features['attn_mask'] if aux_features else None,
167
+ }
168
+
169
+ return mem_readout, aux_output, uncert_output
170
+
171
+ def read_first_frame_memory(self, pixel_readout,
172
+ obj_memory: torch.Tensor, pix_feat: torch.Tensor,
173
+ sensory: torch.Tensor, last_mask: torch.Tensor,
174
+ selector: torch.Tensor, seg_pass=False) -> (torch.Tensor, Dict[str, torch.Tensor]):
175
+ """
176
+ query_key : B * CK * H * W
177
+ query_selection : B * CK * H * W
178
+ memory_key : B * CK * T * H * W
179
+ memory_shrinkage: B * 1 * T * H * W
180
+ msk_value : B * num_objects * CV * T * H * W
181
+ obj_memory : B * num_objects * T * num_summaries * C
182
+ pixel_feature : B * C * H * W
183
+ """
184
+
185
+ pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask)
186
+
187
+ # read from query transformer
188
+ mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector, seg_pass=seg_pass)
189
+
190
+ aux_output = {
191
+ 'sensory': sensory,
192
+ 'q_logits': aux_features['logits'] if aux_features else None,
193
+ 'attn_mask': aux_features['attn_mask'] if aux_features else None,
194
+ }
195
+
196
+ return mem_readout, aux_output
197
+
198
+ def pixel_fusion(self,
199
+ pix_feat: torch.Tensor,
200
+ pixel: torch.Tensor,
201
+ sensory: torch.Tensor,
202
+ last_mask: torch.Tensor,
203
+ *,
204
+ chunk_size: int = -1) -> torch.Tensor:
205
+ last_mask = F.interpolate(last_mask, size=sensory.shape[-2:], mode='area')
206
+ last_others = self._get_others(last_mask)
207
+ fused = self.pixel_fuser(pix_feat,
208
+ pixel,
209
+ sensory,
210
+ last_mask,
211
+ last_others,
212
+ chunk_size=chunk_size)
213
+ return fused
214
+
215
+ def readout_query(self,
216
+ pixel_readout,
217
+ obj_memory,
218
+ *,
219
+ selector=None,
220
+ need_weights=False,
221
+ seg_pass=False) -> (torch.Tensor, Dict[str, torch.Tensor]):
222
+ return self.object_transformer(pixel_readout,
223
+ obj_memory,
224
+ selector=selector,
225
+ need_weights=need_weights,
226
+ seg_pass=seg_pass)
227
+
228
+ def segment(self,
229
+ ms_image_feat: List[torch.Tensor],
230
+ memory_readout: torch.Tensor,
231
+ sensory: torch.Tensor,
232
+ *,
233
+ selector: bool = None,
234
+ chunk_size: int = -1,
235
+ update_sensory: bool = True,
236
+ seg_pass: bool = False,
237
+ clamp_mat: bool = True,
238
+ last_mask=None,
239
+ sigmoid_residual=False,
240
+ seg_mat=False) -> (torch.Tensor, torch.Tensor, torch.Tensor):
241
+ """
242
+ multi_scale_features is from the key encoder for skip-connection
243
+ memory_readout is from working/long-term memory
244
+ sensory is the sensory memory
245
+ last_mask is the mask from the last frame, supplementing sensory memory
246
+ selector is 1 if an object exists, and 0 otherwise. We use it to filter padded objects
247
+ during training.
248
+ """
249
+ #### use mat head for seg data
250
+ if seg_mat:
251
+ assert seg_pass
252
+ seg_pass = False
253
+ ####
254
+ sensory, logits = self.mask_decoder(ms_image_feat,
255
+ memory_readout,
256
+ sensory,
257
+ chunk_size=chunk_size,
258
+ update_sensory=update_sensory,
259
+ seg_pass = seg_pass,
260
+ last_mask=last_mask,
261
+ sigmoid_residual=sigmoid_residual)
262
+ if seg_pass:
263
+ prob = torch.sigmoid(logits)
264
+ if selector is not None:
265
+ prob = prob * selector
266
+
267
+ # Softmax over all objects
268
+ logits = aggregate(prob, dim=1)
269
+ prob = F.softmax(logits, dim=1)
270
+ else:
271
+ if clamp_mat:
272
+ logits = logits.clamp(0.0, 1.0)
273
+ logits = torch.cat([torch.prod(1 - logits, dim=1, keepdim=True), logits], 1)
274
+ prob = logits
275
+
276
+ return sensory, logits, prob
277
+
278
+ def compute_aux(self, pix_feat: torch.Tensor, aux_inputs: Dict[str, torch.Tensor],
279
+ selector: torch.Tensor, seg_pass=False) -> Dict[str, torch.Tensor]:
280
+ return self.aux_computer(pix_feat, aux_inputs, selector, seg_pass=seg_pass)
281
+
282
+ def forward(self, *args, **kwargs):
283
+ raise NotImplementedError
284
+
285
+ def load_weights(self, src_dict, init_as_zero_if_needed=False) -> None:
286
+ if not self.single_object:
287
+ # Map single-object weight to multi-object weight (4->5 out channels in conv1)
288
+ for k in list(src_dict.keys()):
289
+ if k == 'mask_encoder.conv1.weight':
290
+ if src_dict[k].shape[1] == 4:
291
+ log.info(f'Converting {k} from single object to multiple objects.')
292
+ pads = torch.zeros((64, 1, 7, 7), device=src_dict[k].device)
293
+ if not init_as_zero_if_needed:
294
+ nn.init.orthogonal_(pads)
295
+ log.info(f'Randomly initialized padding for {k}.')
296
+ else:
297
+ log.info(f'Zero-initialized padding for {k}.')
298
+ src_dict[k] = torch.cat([src_dict[k], pads], 1)
299
+ elif k == 'pixel_fuser.sensory_compress.weight':
300
+ if src_dict[k].shape[1] == self.sensory_dim + 1:
301
+ log.info(f'Converting {k} from single object to multiple objects.')
302
+ pads = torch.zeros((self.value_dim, 1, 1, 1), device=src_dict[k].device)
303
+ if not init_as_zero_if_needed:
304
+ nn.init.orthogonal_(pads)
305
+ log.info(f'Randomly initialized padding for {k}.')
306
+ else:
307
+ log.info(f'Zero-initialized padding for {k}.')
308
+ src_dict[k] = torch.cat([src_dict[k], pads], 1)
309
+ elif self.single_object:
310
+ """
311
+ If the model is multiple-object and we are training in single-object,
312
+ we strip the last channel of conv1.
313
+ This is not supposed to happen in standard training except when users are trying to
314
+ finetune a trained model with single object datasets.
315
+ """
316
+ if src_dict['mask_encoder.conv1.weight'].shape[1] == 5:
317
+ log.warning('Converting mask_encoder.conv1.weight from multiple objects to single object. '
318
+ 'This is not supposed to happen in standard training.')
319
+ src_dict['mask_encoder.conv1.weight'] = src_dict['mask_encoder.conv1.weight'][:, :-1]
320
+ src_dict['pixel_fuser.sensory_compress.weight'] = src_dict['pixel_fuser.sensory_compress.weight'][:, :-1]
321
+
322
+ for k in src_dict:
323
+ if k not in self.state_dict():
324
+ log.info(f'Key {k} found in src_dict but not in self.state_dict()!!!')
325
+ for k in self.state_dict():
326
+ if k not in src_dict:
327
+ log.info(f'Key {k} found in self.state_dict() but not in src_dict!!!')
328
+
329
+ self.load_state_dict(src_dict, strict=False)
330
+
331
+ @property
332
+ def device(self) -> torch.device:
333
+ return self.pixel_mean.device
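
Because MatAnyone mixes in PyTorchModelHubMixin (with a coder registered for DictConfig), the network and its config can be pulled from the Hugging Face Hub in a single call. A sketch of that, assuming the repo root is on PYTHONPATH; the checkpoint id below is an assumption, and the loader this commit actually uses lives in preprocessing/matanyone/utils/get_default_model.py:

from preprocessing.matanyone.matanyone.model.matanyone import MatAnyone

# repo id assumed here for illustration; see utils/get_default_model.py for the real loader
model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone").eval()
print(model.device)    # resolved through the pixel_mean buffer, as defined just above
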
preprocessing/matanyone/matanyone/model/modules.py ADDED
@@ -0,0 +1,149 @@
1
+ from typing import List, Iterable
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from .group_modules import MainToGroupDistributor, GroupResBlock, upsample_groups, GConv2d, downsample_groups
7
+
8
+
9
+ class UpsampleBlock(nn.Module):
10
+ def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2):
11
+ super().__init__()
12
+ self.out_conv = ResBlock(in_dim, out_dim)
13
+ self.scale_factor = scale_factor
14
+
15
+ def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor:
16
+ g = F.interpolate(in_g,
17
+ scale_factor=self.scale_factor,
18
+ mode='bilinear')
19
+ g = self.out_conv(g)
20
+ g = g + skip_f
21
+ return g
22
+
23
+ class MaskUpsampleBlock(nn.Module):
24
+ def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2):
25
+ super().__init__()
26
+ self.distributor = MainToGroupDistributor(method='add')
27
+ self.out_conv = GroupResBlock(in_dim, out_dim)
28
+ self.scale_factor = scale_factor
29
+
30
+ def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor:
31
+ g = upsample_groups(in_g, ratio=self.scale_factor)
32
+ g = self.distributor(skip_f, g)
33
+ g = self.out_conv(g)
34
+ return g
35
+
36
+
37
+ class DecoderFeatureProcessor(nn.Module):
38
+ def __init__(self, decoder_dims: List[int], out_dims: List[int]):
39
+ super().__init__()
40
+ self.transforms = nn.ModuleList([
41
+ nn.Conv2d(d_dim, p_dim, kernel_size=1) for d_dim, p_dim in zip(decoder_dims, out_dims)
42
+ ])
43
+
44
+ def forward(self, multi_scale_features: Iterable[torch.Tensor]) -> List[torch.Tensor]:
45
+ outputs = [func(x) for x, func in zip(multi_scale_features, self.transforms)]
46
+ return outputs
47
+
48
+
49
+ # @torch.jit.script
50
+ def _recurrent_update(h: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
51
+ # h: batch_size * num_objects * hidden_dim * h * w
52
+ # values: batch_size * num_objects * (hidden_dim*3) * h * w
53
+ dim = values.shape[2] // 3
54
+ forget_gate = torch.sigmoid(values[:, :, :dim])
55
+ update_gate = torch.sigmoid(values[:, :, dim:dim * 2])
56
+ new_value = torch.tanh(values[:, :, dim * 2:])
57
+ new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value
58
+ return new_h
59
+
60
+
61
+ class SensoryUpdater_fullscale(nn.Module):
62
+ # Used in the decoder, multi-scale feature + GRU
63
+ def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int):
64
+ super().__init__()
65
+ self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1)
66
+ self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1)
67
+ self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1)
68
+ self.g2_conv = GConv2d(g_dims[3], mid_dim, kernel_size=1)
69
+ self.g1_conv = GConv2d(g_dims[4], mid_dim, kernel_size=1)
70
+
71
+ self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
72
+
73
+ nn.init.xavier_normal_(self.transform.weight)
74
+
75
+ def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
76
+ g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
77
+ self.g4_conv(downsample_groups(g[2], ratio=1/4)) + \
78
+ self.g2_conv(downsample_groups(g[3], ratio=1/8)) + \
79
+ self.g1_conv(downsample_groups(g[4], ratio=1/16))
80
+
81
+ with torch.amp.autocast("cuda"):
82
+ g = g.float()
83
+ h = h.float()
84
+ values = self.transform(torch.cat([g, h], dim=2))
85
+ new_h = _recurrent_update(h, values)
86
+
87
+ return new_h
88
+
89
+ class SensoryUpdater(nn.Module):
90
+ # Used in the decoder, multi-scale feature + GRU
91
+ def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int):
92
+ super().__init__()
93
+ self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1)
94
+ self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1)
95
+ self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1)
96
+
97
+ self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
98
+
99
+ nn.init.xavier_normal_(self.transform.weight)
100
+
101
+ def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
102
+ g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
103
+ self.g4_conv(downsample_groups(g[2], ratio=1/4))
104
+
105
+ with torch.amp.autocast("cuda"):
106
+ g = g.float()
107
+ h = h.float()
108
+ values = self.transform(torch.cat([g, h], dim=2))
109
+ new_h = _recurrent_update(h, values)
110
+
111
+ return new_h
112
+
113
+
114
+ class SensoryDeepUpdater(nn.Module):
115
+ def __init__(self, f_dim: int, sensory_dim: int):
116
+ super().__init__()
117
+ self.transform = GConv2d(f_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
118
+
119
+ nn.init.xavier_normal_(self.transform.weight)
120
+
121
+ def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
122
+ with torch.amp.autocast("cuda"):
123
+ g = g.float()
124
+ h = h.float()
125
+ values = self.transform(torch.cat([g, h], dim=2))
126
+ new_h = _recurrent_update(h, values)
127
+
128
+ return new_h
129
+
130
+
131
+ class ResBlock(nn.Module):
132
+ def __init__(self, in_dim: int, out_dim: int):
133
+ super().__init__()
134
+
135
+ if in_dim == out_dim:
136
+ self.downsample = nn.Identity()
137
+ else:
138
+ self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1)
139
+
140
+ self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
141
+ self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1)
142
+
143
+ def forward(self, g: torch.Tensor) -> torch.Tensor:
144
+ out_g = self.conv1(F.relu(g))
145
+ out_g = self.conv2(F.relu(out_g))
146
+
147
+ g = self.downsample(g)
148
+
149
+ return out_g + g
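
All three sensory updaters above funnel into `_recurrent_update`, a GRU-style gate: the transform emits three stacked maps (forget gate, update gate, candidate value) and blends the hidden state as forget * h * (1 - update) + update * tanh(candidate). A tiny standalone sketch of the same recurrence:

import torch

hidden_dim = 4
h = torch.randn(1, 2, hidden_dim, 8, 8)             # batch * num_objects * hidden_dim * H * W
values = torch.randn(1, 2, hidden_dim * 3, 8, 8)     # stacked [forget | update | candidate]

forget = torch.sigmoid(values[:, :, :hidden_dim])
update = torch.sigmoid(values[:, :, hidden_dim:2 * hidden_dim])
candidate = torch.tanh(values[:, :, 2 * hidden_dim:])
new_h = forget * h * (1 - update) + update * candidate   # same rule as _recurrent_update
print(new_h.shape)                                        # torch.Size([1, 2, 4, 8, 8])
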
preprocessing/matanyone/matanyone/model/transformer/__init__.py ADDED
File without changes
preprocessing/matanyone/matanyone/model/transformer/object_summarizer.py ADDED
@@ -0,0 +1,89 @@
1
+ from typing import Optional
2
+ from omegaconf import DictConfig
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from .positional_encoding import PositionalEncoding
8
+
9
+
10
+ # @torch.jit.script
11
+ def _weighted_pooling(masks: torch.Tensor, value: torch.Tensor,
12
+ logits: torch.Tensor) -> (torch.Tensor, torch.Tensor):
13
+ # value: B*num_objects*H*W*value_dim
14
+ # logits: B*num_objects*H*W*num_summaries
15
+ # masks: B*num_objects*H*W*num_summaries: 1 if allowed
16
+ weights = logits.sigmoid() * masks
17
+ # B*num_objects*num_summaries*value_dim
18
+ sums = torch.einsum('bkhwq,bkhwc->bkqc', weights, value)
19
+ # B*num_objects*H*W*num_summaries -> B*num_objects*num_summaries*1
20
+ area = weights.flatten(start_dim=2, end_dim=3).sum(2).unsqueeze(-1)
21
+
22
+ # B*num_objects*num_summaries*value_dim
23
+ return sums, area
24
+
25
+
26
+ class ObjectSummarizer(nn.Module):
27
+ def __init__(self, model_cfg: DictConfig):
28
+ super().__init__()
29
+
30
+ this_cfg = model_cfg.object_summarizer
31
+ self.value_dim = model_cfg.value_dim
32
+ self.embed_dim = this_cfg.embed_dim
33
+ self.num_summaries = this_cfg.num_summaries
34
+ self.add_pe = this_cfg.add_pe
35
+ self.pixel_pe_scale = model_cfg.pixel_pe_scale
36
+ self.pixel_pe_temperature = model_cfg.pixel_pe_temperature
37
+
38
+ if self.add_pe:
39
+ self.pos_enc = PositionalEncoding(self.embed_dim,
40
+ scale=self.pixel_pe_scale,
41
+ temperature=self.pixel_pe_temperature)
42
+
43
+ self.input_proj = nn.Linear(self.value_dim, self.embed_dim)
44
+ self.feature_pred = nn.Sequential(
45
+ nn.Linear(self.embed_dim, self.embed_dim),
46
+ nn.ReLU(inplace=True),
47
+ nn.Linear(self.embed_dim, self.embed_dim),
48
+ )
49
+ self.weights_pred = nn.Sequential(
50
+ nn.Linear(self.embed_dim, self.embed_dim),
51
+ nn.ReLU(inplace=True),
52
+ nn.Linear(self.embed_dim, self.num_summaries),
53
+ )
54
+
55
+ def forward(self,
56
+ masks: torch.Tensor,
57
+ value: torch.Tensor,
58
+ need_weights: bool = False) -> (torch.Tensor, Optional[torch.Tensor]):
59
+ # masks: B*num_objects*(H0)*(W0)
60
+ # value: B*num_objects*value_dim*H*W
61
+ # -> B*num_objects*H*W*value_dim
62
+ h, w = value.shape[-2:]
63
+ masks = F.interpolate(masks, size=(h, w), mode='area')
64
+ masks = masks.unsqueeze(-1)
65
+ inv_masks = 1 - masks
66
+ repeated_masks = torch.cat([
67
+ masks.expand(-1, -1, -1, -1, self.num_summaries // 2),
68
+ inv_masks.expand(-1, -1, -1, -1, self.num_summaries // 2),
69
+ ],
70
+ dim=-1)
71
+
72
+ value = value.permute(0, 1, 3, 4, 2)
73
+ value = self.input_proj(value)
74
+ if self.add_pe:
75
+ pe = self.pos_enc(value)
76
+ value = value + pe
77
+
78
+ with torch.amp.autocast("cuda"):
79
+ value = value.float()
80
+ feature = self.feature_pred(value)
81
+ logits = self.weights_pred(value)
82
+ sums, area = _weighted_pooling(repeated_masks, feature, logits)
83
+
84
+ summaries = torch.cat([sums, area], dim=-1)
85
+
86
+ if need_weights:
87
+ return summaries, logits
88
+ else:
89
+ return summaries, None
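
The einsum inside `_weighted_pooling` collapses each object's H*W pixels into num_summaries weighted sums, with half of the summaries attending inside the mask and half outside (via repeated_masks). A shape-only sketch of that reduction:

import torch

b, k, h, w = 1, 2, 16, 16          # batch, objects, spatial
q, c = 8, 256                      # num_summaries, embed dim
weights = torch.rand(b, k, h, w, q)    # sigmoid(logits) * masks in the real module
value = torch.randn(b, k, h, w, c)

sums = torch.einsum('bkhwq,bkhwc->bkqc', weights, value)   # weighted sum per summary token
area = weights.flatten(2, 3).sum(2).unsqueeze(-1)          # total weight per summary token
print(sums.shape, area.shape)      # torch.Size([1, 2, 8, 256]) torch.Size([1, 2, 8, 1])
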
preprocessing/matanyone/matanyone/model/transformer/object_transformer.py ADDED
@@ -0,0 +1,206 @@
1
+ from typing import Dict, Optional
2
+ from omegaconf import DictConfig
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from ..group_modules import GConv2d
7
+ from ....utils.tensor_utils import aggregate
8
+ from .positional_encoding import PositionalEncoding
9
+ from .transformer_layers import CrossAttention, SelfAttention, FFN, PixelFFN
10
+
11
+
12
+ class QueryTransformerBlock(nn.Module):
13
+ def __init__(self, model_cfg: DictConfig):
14
+ super().__init__()
15
+
16
+ this_cfg = model_cfg.object_transformer
17
+ self.embed_dim = this_cfg.embed_dim
18
+ self.num_heads = this_cfg.num_heads
19
+ self.num_queries = this_cfg.num_queries
20
+ self.ff_dim = this_cfg.ff_dim
21
+
22
+ self.read_from_pixel = CrossAttention(self.embed_dim,
23
+ self.num_heads,
24
+ add_pe_to_qkv=this_cfg.read_from_pixel.add_pe_to_qkv)
25
+ self.self_attn = SelfAttention(self.embed_dim,
26
+ self.num_heads,
27
+ add_pe_to_qkv=this_cfg.query_self_attention.add_pe_to_qkv)
28
+ self.ffn = FFN(self.embed_dim, self.ff_dim)
29
+ self.read_from_query = CrossAttention(self.embed_dim,
30
+ self.num_heads,
31
+ add_pe_to_qkv=this_cfg.read_from_query.add_pe_to_qkv,
32
+ norm=this_cfg.read_from_query.output_norm)
33
+ self.pixel_ffn = PixelFFN(self.embed_dim)
34
+
35
+ def forward(
36
+ self,
37
+ x: torch.Tensor,
38
+ pixel: torch.Tensor,
39
+ query_pe: torch.Tensor,
40
+ pixel_pe: torch.Tensor,
41
+ attn_mask: torch.Tensor,
42
+ need_weights: bool = False) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
43
+ # x: (bs*num_objects)*num_queries*embed_dim
44
+ # pixel: bs*num_objects*C*H*W
45
+ # query_pe: (bs*num_objects)*num_queries*embed_dim
46
+ # pixel_pe: (bs*num_objects)*(H*W)*C
47
+ # attn_mask: (bs*num_objects*num_heads)*num_queries*(H*W)
48
+
49
+ # bs*num_objects*C*H*W -> (bs*num_objects)*(H*W)*C
50
+ pixel_flat = pixel.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous()
51
+ x, q_weights = self.read_from_pixel(x,
52
+ pixel_flat,
53
+ query_pe,
54
+ pixel_pe,
55
+ attn_mask=attn_mask,
56
+ need_weights=need_weights)
57
+ x = self.self_attn(x, query_pe)
58
+ x = self.ffn(x)
59
+
60
+ pixel_flat, p_weights = self.read_from_query(pixel_flat,
61
+ x,
62
+ pixel_pe,
63
+ query_pe,
64
+ need_weights=need_weights)
65
+ pixel = self.pixel_ffn(pixel, pixel_flat)
66
+
67
+ if need_weights:
68
+ bs, num_objects, _, h, w = pixel.shape
69
+ q_weights = q_weights.view(bs, num_objects, self.num_heads, self.num_queries, h, w)
70
+ p_weights = p_weights.transpose(2, 3).view(bs, num_objects, self.num_heads,
71
+ self.num_queries, h, w)
72
+
73
+ return x, pixel, q_weights, p_weights
74
+
75
+
76
+ class QueryTransformer(nn.Module):
77
+ def __init__(self, model_cfg: DictConfig):
78
+ super().__init__()
79
+
80
+ this_cfg = model_cfg.object_transformer
81
+ self.value_dim = model_cfg.value_dim
82
+ self.embed_dim = this_cfg.embed_dim
83
+ self.num_heads = this_cfg.num_heads
84
+ self.num_queries = this_cfg.num_queries
85
+
86
+ # query initialization and embedding
87
+ self.query_init = nn.Embedding(self.num_queries, self.embed_dim)
88
+ self.query_emb = nn.Embedding(self.num_queries, self.embed_dim)
89
+
90
+ # projection from object summaries to query initialization and embedding
91
+ self.summary_to_query_init = nn.Linear(self.embed_dim, self.embed_dim)
92
+ self.summary_to_query_emb = nn.Linear(self.embed_dim, self.embed_dim)
93
+
94
+ self.pixel_pe_scale = model_cfg.pixel_pe_scale
95
+ self.pixel_pe_temperature = model_cfg.pixel_pe_temperature
96
+ self.pixel_init_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1)
97
+ self.pixel_emb_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1)
98
+ self.spatial_pe = PositionalEncoding(self.embed_dim,
99
+ scale=self.pixel_pe_scale,
100
+ temperature=self.pixel_pe_temperature,
101
+ channel_last=False,
102
+ transpose_output=True)
103
+
104
+ # transformer blocks
105
+ self.num_blocks = this_cfg.num_blocks
106
+ self.blocks = nn.ModuleList(
107
+ QueryTransformerBlock(model_cfg) for _ in range(self.num_blocks))
108
+ self.mask_pred = nn.ModuleList(
109
+ nn.Sequential(nn.ReLU(), GConv2d(self.embed_dim, 1, kernel_size=1))
110
+ for _ in range(self.num_blocks + 1))
111
+
112
+ self.act = nn.ReLU(inplace=True)
113
+
114
+ def forward(self,
115
+ pixel: torch.Tensor,
116
+ obj_summaries: torch.Tensor,
117
+ selector: Optional[torch.Tensor] = None,
118
+ need_weights: bool = False,
119
+ seg_pass=False) -> (torch.Tensor, Dict[str, torch.Tensor]):
120
+ # pixel: B*num_objects*embed_dim*H*W
121
+ # obj_summaries: B*num_objects*T*num_queries*embed_dim
122
+ T = obj_summaries.shape[2]
123
+ bs, num_objects, _, H, W = pixel.shape
124
+
125
+ # normalize object values
126
+ # the last channel is the cumulative area of the object
127
+ obj_summaries = obj_summaries.view(bs * num_objects, T, self.num_queries,
128
+ self.embed_dim + 1)
129
+ # sum over time
130
+ # during inference, T=1 as we already did streaming average in memory_manager
131
+ obj_sums = obj_summaries[:, :, :, :-1].sum(dim=1)
132
+ obj_area = obj_summaries[:, :, :, -1:].sum(dim=1)
133
+ obj_values = obj_sums / (obj_area + 1e-4)
134
+ obj_init = self.summary_to_query_init(obj_values)
135
+ obj_emb = self.summary_to_query_emb(obj_values)
136
+
137
+ # positional embeddings for object queries
138
+ query = self.query_init.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_init
139
+ query_emb = self.query_emb.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_emb
140
+
141
+ # positional embeddings for pixel features
142
+ pixel_init = self.pixel_init_proj(pixel)
143
+ pixel_emb = self.pixel_emb_proj(pixel)
144
+ pixel_pe = self.spatial_pe(pixel.flatten(0, 1))
145
+ pixel_emb = pixel_emb.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous()
146
+ pixel_pe = pixel_pe.flatten(1, 2) + pixel_emb
147
+
148
+ pixel = pixel_init
149
+
150
+ # run the transformer
151
+ aux_features = {'logits': []}
152
+
153
+ # first aux output
154
+ aux_logits = self.mask_pred[0](pixel).squeeze(2)
155
+ attn_mask = self._get_aux_mask(aux_logits, selector, seg_pass=seg_pass)
156
+ aux_features['logits'].append(aux_logits)
157
+ for i in range(self.num_blocks):
158
+ query, pixel, q_weights, p_weights = self.blocks[i](query,
159
+ pixel,
160
+ query_emb,
161
+ pixel_pe,
162
+ attn_mask,
163
+ need_weights=need_weights)
164
+
165
+ if self.training or i <= self.num_blocks - 1 or need_weights:
166
+ aux_logits = self.mask_pred[i + 1](pixel).squeeze(2)
167
+ attn_mask = self._get_aux_mask(aux_logits, selector, seg_pass=seg_pass)
168
+ aux_features['logits'].append(aux_logits)
169
+
170
+ aux_features['q_weights'] = q_weights # last layer only
171
+ aux_features['p_weights'] = p_weights # last layer only
172
+
173
+ if self.training:
174
+ # no need to save all heads
175
+ aux_features['attn_mask'] = attn_mask.view(bs, num_objects, self.num_heads,
176
+ self.num_queries, H, W)[:, :, 0]
177
+
178
+ return pixel, aux_features
179
+
180
+ def _get_aux_mask(self, logits: torch.Tensor, selector: torch.Tensor, seg_pass=False) -> torch.Tensor:
181
+ # logits: batch_size*num_objects*H*W
182
+ # selector: batch_size*num_objects*1*1
183
+ # returns a mask of shape (batch_size*num_objects*num_heads)*num_queries*(H*W)
184
+ # where True means the attention is blocked
185
+
186
+ if selector is None:
187
+ prob = logits.sigmoid()
188
+ else:
189
+ prob = logits.sigmoid() * selector
190
+ logits = aggregate(prob, dim=1)
191
+
192
+ is_foreground = (logits[:, 1:] >= logits.max(dim=1, keepdim=True)[0])
193
+ foreground_mask = is_foreground.bool().flatten(start_dim=2)
194
+ inv_foreground_mask = ~foreground_mask
195
+ inv_background_mask = foreground_mask
196
+
197
+ aux_foreground_mask = inv_foreground_mask.unsqueeze(2).unsqueeze(2).repeat(
198
+ 1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2)
199
+ aux_background_mask = inv_background_mask.unsqueeze(2).unsqueeze(2).repeat(
200
+ 1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2)
201
+
202
+ aux_mask = torch.cat([aux_foreground_mask, aux_background_mask], dim=1)
203
+
204
+ aux_mask[torch.where(aux_mask.sum(-1) == aux_mask.shape[-1])] = False
205
+
206
+ return aux_mask
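
Before the transformer blocks run, the per-object summaries are turned back into mean features by dividing the accumulated weighted sums by the accumulated areas (the extra channel appended by ObjectSummarizer). A small sketch of that normalisation, detached from the module:

import torch

bs_objects, T, num_queries, embed_dim = 2, 1, 16, 256
obj_summaries = torch.randn(bs_objects, T, num_queries, embed_dim + 1)   # [... sums ..., area]

obj_sums = obj_summaries[:, :, :, :-1].sum(dim=1)    # sum over time (T=1 at inference)
obj_area = obj_summaries[:, :, :, -1:].sum(dim=1)
obj_values = obj_sums / (obj_area + 1e-4)            # per-query mean feature
print(obj_values.shape)                              # torch.Size([2, 16, 256])
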
preprocessing/matanyone/matanyone/model/transformer/positional_encoding.py ADDED
@@ -0,0 +1,108 @@
1
+ # Reference:
2
+ # https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/transformer_decoder/position_encoding.py
3
+ # https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/torch_encodings.py
4
+
5
+ import math
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch import nn
10
+
11
+
12
+ def get_emb(sin_inp: torch.Tensor) -> torch.Tensor:
13
+ """
14
+ Gets a base embedding for one dimension with sin and cos intertwined
15
+ """
16
+ emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
17
+ return torch.flatten(emb, -2, -1)
18
+
19
+
20
+ class PositionalEncoding(nn.Module):
21
+ def __init__(self,
22
+ dim: int,
23
+ scale: float = math.pi * 2,
24
+ temperature: float = 10000,
25
+ normalize: bool = True,
26
+ channel_last: bool = True,
27
+ transpose_output: bool = False):
28
+ super().__init__()
29
+ dim = int(np.ceil(dim / 4) * 2)
30
+ self.dim = dim
31
+ inv_freq = 1.0 / (temperature**(torch.arange(0, dim, 2).float() / dim))
32
+ self.register_buffer("inv_freq", inv_freq)
33
+ self.normalize = normalize
34
+ self.scale = scale
35
+ self.eps = 1e-6
36
+ self.channel_last = channel_last
37
+ self.transpose_output = transpose_output
38
+
39
+ self.cached_penc = None # the cache is irrespective of the number of objects
40
+
41
+ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
42
+ """
43
+ :param tensor: A 4/5d tensor of size
44
+ channel_last=True: (batch_size, h, w, c) or (batch_size, k, h, w, c)
45
+ channel_last=False: (batch_size, c, h, w) or (batch_size, k, c, h, w)
46
+ :return: positional encoding tensor that has the same shape as the input if the input is 4d
47
+ if the input is 5d, the output is broadcastable along the k-dimension
48
+ """
49
+ if len(tensor.shape) != 4 and len(tensor.shape) != 5:
50
+ raise RuntimeError(f'The input tensor has to be 4/5d, got {tensor.shape}!')
51
+
52
+ if len(tensor.shape) == 5:
53
+ # take a sample from the k dimension
54
+ num_objects = tensor.shape[1]
55
+ tensor = tensor[:, 0]
56
+ else:
57
+ num_objects = None
58
+
59
+ if self.channel_last:
60
+ batch_size, h, w, c = tensor.shape
61
+ else:
62
+ batch_size, c, h, w = tensor.shape
63
+
64
+ if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
65
+ if num_objects is None:
66
+ return self.cached_penc
67
+ else:
68
+ return self.cached_penc.unsqueeze(1)
69
+
70
+ self.cached_penc = None
71
+
72
+ pos_y = torch.arange(h, device=tensor.device, dtype=self.inv_freq.dtype)
73
+ pos_x = torch.arange(w, device=tensor.device, dtype=self.inv_freq.dtype)
74
+ if self.normalize:
75
+ pos_y = pos_y / (pos_y[-1] + self.eps) * self.scale
76
+ pos_x = pos_x / (pos_x[-1] + self.eps) * self.scale
77
+
78
+ sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
79
+ sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
80
+ emb_y = get_emb(sin_inp_y).unsqueeze(1)
81
+ emb_x = get_emb(sin_inp_x)
82
+
83
+ emb = torch.zeros((h, w, self.dim * 2), device=tensor.device, dtype=tensor.dtype)
84
+ emb[:, :, :self.dim] = emb_x
85
+ emb[:, :, self.dim:] = emb_y
86
+
87
+ if not self.channel_last and self.transpose_output:
88
+ # cancelled out
89
+ pass
90
+ elif (not self.channel_last) or (self.transpose_output):
91
+ emb = emb.permute(2, 0, 1)
92
+
93
+ self.cached_penc = emb.unsqueeze(0).repeat(batch_size, 1, 1, 1)
94
+ if num_objects is None:
95
+ return self.cached_penc
96
+ else:
97
+ return self.cached_penc.unsqueeze(1)
98
+
99
+
100
+ if __name__ == '__main__':
101
+ pe = PositionalEncoding(8).cuda()
102
+ input = torch.ones((1, 8, 8, 8)).cuda()
103
+ output = pe(input)
104
+ # print(output)
105
+ print(output[0, :, 0, 0])
106
+ print(output[0, :, 0, 5])
107
+ print(output[0, 0, :, 0])
108
+ print(output[0, 0, 0, :])
preprocessing/matanyone/matanyone/model/transformer/transformer_layers.py ADDED
@@ -0,0 +1,161 @@
1
+ # Modified from PyTorch nn.Transformer
2
+
3
+ from typing import List, Callable
4
+
5
+ import torch
6
+ from torch import Tensor
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from ...model.channel_attn import CAResBlock
10
+
11
+
12
+ class SelfAttention(nn.Module):
13
+ def __init__(self,
14
+ dim: int,
15
+ nhead: int,
16
+ dropout: float = 0.0,
17
+ batch_first: bool = True,
18
+ add_pe_to_qkv: List[bool] = [True, True, False]):
19
+ super().__init__()
20
+ self.self_attn = nn.MultiheadAttention(dim, nhead, dropout=dropout, batch_first=batch_first)
21
+ self.norm = nn.LayerNorm(dim)
22
+ self.dropout = nn.Dropout(dropout)
23
+ self.add_pe_to_qkv = add_pe_to_qkv
24
+
25
+ def forward(self,
26
+ x: torch.Tensor,
27
+ pe: torch.Tensor,
28
+ attn_mask: bool = None,
29
+ key_padding_mask: bool = None) -> torch.Tensor:
30
+ x = self.norm(x)
31
+ if any(self.add_pe_to_qkv):
32
+ x_with_pe = x + pe
33
+ q = x_with_pe if self.add_pe_to_qkv[0] else x
34
+ k = x_with_pe if self.add_pe_to_qkv[1] else x
35
+ v = x_with_pe if self.add_pe_to_qkv[2] else x
36
+ else:
37
+ q = k = v = x
38
+
39
+ r = x
40
+ x = self.self_attn(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask)[0]
41
+ return r + self.dropout(x)
42
+
43
+
44
+ # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
45
+ class CrossAttention(nn.Module):
46
+ def __init__(self,
47
+ dim: int,
48
+ nhead: int,
49
+ dropout: float = 0.0,
50
+ batch_first: bool = True,
51
+ add_pe_to_qkv: List[bool] = [True, True, False],
52
+ residual: bool = True,
53
+ norm: bool = True):
54
+ super().__init__()
55
+ self.cross_attn = nn.MultiheadAttention(dim,
56
+ nhead,
57
+ dropout=dropout,
58
+ batch_first=batch_first)
59
+ if norm:
60
+ self.norm = nn.LayerNorm(dim)
61
+ else:
62
+ self.norm = nn.Identity()
63
+ self.dropout = nn.Dropout(dropout)
64
+ self.add_pe_to_qkv = add_pe_to_qkv
65
+ self.residual = residual
66
+
67
+ def forward(self,
68
+ x: torch.Tensor,
69
+ mem: torch.Tensor,
70
+ x_pe: torch.Tensor,
71
+ mem_pe: torch.Tensor,
72
+ attn_mask: bool = None,
73
+ *,
74
+ need_weights: bool = False) -> (torch.Tensor, torch.Tensor):
75
+ x = self.norm(x)
76
+ if self.add_pe_to_qkv[0]:
77
+ q = x + x_pe
78
+ else:
79
+ q = x
80
+
81
+ if any(self.add_pe_to_qkv[1:]):
82
+ mem_with_pe = mem + mem_pe
83
+ k = mem_with_pe if self.add_pe_to_qkv[1] else mem
84
+ v = mem_with_pe if self.add_pe_to_qkv[2] else mem
85
+ else:
86
+ k = v = mem
87
+ r = x
88
+ x, weights = self.cross_attn(q,
89
+ k,
90
+ v,
91
+ attn_mask=attn_mask,
92
+ need_weights=need_weights,
93
+ average_attn_weights=False)
94
+
95
+ if self.residual:
96
+ return r + self.dropout(x), weights
97
+ else:
98
+ return self.dropout(x), weights
99
+
100
+
101
+ class FFN(nn.Module):
102
+ def __init__(self, dim_in: int, dim_ff: int, activation=F.relu):
103
+ super().__init__()
104
+ self.linear1 = nn.Linear(dim_in, dim_ff)
105
+ self.linear2 = nn.Linear(dim_ff, dim_in)
106
+ self.norm = nn.LayerNorm(dim_in)
107
+
108
+ if isinstance(activation, str):
109
+ self.activation = _get_activation_fn(activation)
110
+ else:
111
+ self.activation = activation
112
+
113
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
114
+ r = x
115
+ x = self.norm(x)
116
+ x = self.linear2(self.activation(self.linear1(x)))
117
+ x = r + x
118
+ return x
119
+
120
+
121
+ class PixelFFN(nn.Module):
122
+ def __init__(self, dim: int):
123
+ super().__init__()
124
+ self.dim = dim
125
+ self.conv = CAResBlock(dim, dim)
126
+
127
+ def forward(self, pixel: torch.Tensor, pixel_flat: torch.Tensor) -> torch.Tensor:
128
+ # pixel: batch_size * num_objects * dim * H * W
129
+ # pixel_flat: (batch_size*num_objects) * (H*W) * dim
130
+ bs, num_objects, _, h, w = pixel.shape
131
+ pixel_flat = pixel_flat.view(bs * num_objects, h, w, self.dim)
132
+ pixel_flat = pixel_flat.permute(0, 3, 1, 2).contiguous()
133
+
134
+ x = self.conv(pixel_flat)
135
+ x = x.view(bs, num_objects, self.dim, h, w)
136
+ return x
137
+
138
+
139
+ class OutputFFN(nn.Module):
140
+ def __init__(self, dim_in: int, dim_out: int, activation=F.relu):
141
+ super().__init__()
142
+ self.linear1 = nn.Linear(dim_in, dim_out)
143
+ self.linear2 = nn.Linear(dim_out, dim_out)
144
+
145
+ if isinstance(activation, str):
146
+ self.activation = _get_activation_fn(activation)
147
+ else:
148
+ self.activation = activation
149
+
150
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
151
+ x = self.linear2(self.activation(self.linear1(x)))
152
+ return x
153
+
154
+
155
+ def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
156
+ if activation == "relu":
157
+ return F.relu
158
+ elif activation == "gelu":
159
+ return F.gelu
160
+
161
+ raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
preprocessing/matanyone/matanyone/model/utils/__init__.py ADDED
File without changes
preprocessing/matanyone/matanyone/model/utils/memory_utils.py ADDED
@@ -0,0 +1,107 @@
1
+ import math
2
+ import torch
3
+ from typing import Optional, Union, Tuple
4
+
5
+
6
+ # @torch.jit.script
7
+ def get_similarity(mk: torch.Tensor,
8
+ ms: torch.Tensor,
9
+ qk: torch.Tensor,
10
+ qe: torch.Tensor,
11
+ add_batch_dim: bool = False,
12
+ uncert_mask = None) -> torch.Tensor:
13
+ # used for training/inference and memory reading/memory potentiation
14
+ # mk: B x CK x [N] - Memory keys
15
+ # ms: B x 1 x [N] - Memory shrinkage
16
+ # qk: B x CK x [HW/P] - Query keys
17
+ # qe: B x CK x [HW/P] - Query selection
18
+ # Dimensions in [] are flattened
19
+ # Return: B*N*HW
20
+ if add_batch_dim:
21
+ mk, ms = mk.unsqueeze(0), ms.unsqueeze(0)
22
+ qk, qe = qk.unsqueeze(0), qe.unsqueeze(0)
23
+
24
+ CK = mk.shape[1]
25
+
26
+ mk = mk.flatten(start_dim=2)
27
+ ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None
28
+ qk = qk.flatten(start_dim=2)
29
+ qe = qe.flatten(start_dim=2) if qe is not None else None
30
+
31
+ # query token selection based on temporal sparsity
32
+ if uncert_mask is not None:
33
+ uncert_mask = uncert_mask.flatten(start_dim=2)
34
+ uncert_mask = uncert_mask.expand(-1, 64, -1)
35
+ qk = qk * uncert_mask
36
+ qe = qe * uncert_mask
37
+
38
+ if qe is not None:
39
+ # See XMem's appendix for derivation
40
+ mk = mk.transpose(1, 2)
41
+ a_sq = (mk.pow(2) @ qe)
42
+ two_ab = 2 * (mk @ (qk * qe))
43
+ b_sq = (qe * qk.pow(2)).sum(1, keepdim=True)
44
+ similarity = (-a_sq + two_ab - b_sq)
45
+ else:
46
+ # similar to STCN if we don't have the selection term
47
+ a_sq = mk.pow(2).sum(1).unsqueeze(2)
48
+ two_ab = 2 * (mk.transpose(1, 2) @ qk)
49
+ similarity = (-a_sq + two_ab)
50
+
51
+ if ms is not None:
52
+ similarity = similarity * ms / math.sqrt(CK) # B*N*HW
53
+ else:
54
+ similarity = similarity / math.sqrt(CK) # B*N*HW
55
+
56
+ return similarity
57
+
58
+
59
+ def do_softmax(
60
+ similarity: torch.Tensor,
61
+ top_k: Optional[int] = None,
62
+ inplace: bool = False,
63
+ return_usage: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
64
+ # normalize similarity with top-k softmax
65
+ # similarity: B x N x [HW/P]
66
+ # use inplace with care
67
+ if top_k is not None:
68
+ values, indices = torch.topk(similarity, k=top_k, dim=1)
69
+
70
+ x_exp = values.exp_()
71
+ x_exp /= torch.sum(x_exp, dim=1, keepdim=True)
72
+ if inplace:
73
+ similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW
74
+ affinity = similarity
75
+ else:
76
+ affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp) # B*N*HW
77
+ else:
78
+ maxes = torch.max(similarity, dim=1, keepdim=True)[0]
79
+ x_exp = torch.exp(similarity - maxes)
80
+ x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
81
+ affinity = x_exp / x_exp_sum
82
+ indices = None
83
+
84
+ if return_usage:
85
+ return affinity, affinity.sum(dim=2)
86
+
87
+ return affinity
88
+
89
+
90
+ def get_affinity(mk: torch.Tensor, ms: torch.Tensor, qk: torch.Tensor,
91
+ qe: torch.Tensor, uncert_mask = None) -> torch.Tensor:
92
+ # shorthand used in training with no top-k
93
+ similarity = get_similarity(mk, ms, qk, qe, uncert_mask=uncert_mask)
94
+ affinity = do_softmax(similarity)
95
+ return affinity
96
+
97
+ def readout(affinity: torch.Tensor, mv: torch.Tensor, uncert_mask: torch.Tensor=None) -> torch.Tensor:
98
+ B, CV, T, H, W = mv.shape
99
+
100
+ mo = mv.view(B, CV, T * H * W)
101
+ mem = torch.bmm(mo, affinity)
102
+ if uncert_mask is not None:
103
+ uncert_mask = uncert_mask.flatten(start_dim=2).expand(-1, CV, -1)
104
+ mem = mem * uncert_mask
105
+ mem = mem.view(B, CV, H, W)
106
+
107
+ return mem
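
For reference, the anisotropic similarity built by get_similarity is XMem's negative selection-weighted squared distance, expanded so it only needs two matrix products. With memory key m, query key q and query selection e, per channel c:

S(m, q) = -\sum_c e_c (m_c - q_c)^2 = -\sum_c m_c^2 e_c + 2 \sum_c m_c q_c e_c - \sum_c e_c q_c^2

The three terms are a_sq, two_ab and b_sq in the code; the result is then scaled by the shrinkage ms and by 1/sqrt(C_K). In the branch without a selection term, the constant q^2 part is dropped, since it is the same for every memory position and cancels in the softmax applied by do_softmax.
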
preprocessing/matanyone/matanyone/model/utils/parameter_groups.py ADDED
@@ -0,0 +1,72 @@
1
+ import logging
2
+
3
+ log = logging.getLogger()
4
+
5
+
6
+ def get_parameter_groups(model, stage_cfg, print_log=False):
7
+ """
8
+ Assign different weight decays and learning rates to different parameters.
9
+ Returns a parameter group which can be passed to the optimizer.
10
+ """
11
+ weight_decay = stage_cfg.weight_decay
12
+ embed_weight_decay = stage_cfg.embed_weight_decay
13
+ backbone_lr_ratio = stage_cfg.backbone_lr_ratio
14
+ base_lr = stage_cfg.learning_rate
15
+
16
+ backbone_params = []
17
+ embed_params = []
18
+ other_params = []
19
+
20
+ embedding_names = ['summary_pos', 'query_init', 'query_emb', 'obj_pe']
21
+ embedding_names = [e + '.weight' for e in embedding_names]
22
+
23
+ # inspired by detectron2
24
+ memo = set()
25
+ for name, param in model.named_parameters():
26
+ if not param.requires_grad:
27
+ continue
28
+ # Avoid duplicating parameters
29
+ if param in memo:
30
+ continue
31
+ memo.add(param)
32
+
33
+ if name.startswith('module'):
34
+ name = name[7:]
35
+
36
+ inserted = False
37
+ if name.startswith('pixel_encoder.'):
38
+ backbone_params.append(param)
39
+ inserted = True
40
+ if print_log:
41
+ log.info(f'{name} counted as a backbone parameter.')
42
+ else:
43
+ for e in embedding_names:
44
+ if name.endswith(e):
45
+ embed_params.append(param)
46
+ inserted = True
47
+ if print_log:
48
+ log.info(f'{name} counted as an embedding parameter.')
49
+ break
50
+
51
+ if not inserted:
52
+ other_params.append(param)
53
+
54
+ parameter_groups = [
55
+ {
56
+ 'params': backbone_params,
57
+ 'lr': base_lr * backbone_lr_ratio,
58
+ 'weight_decay': weight_decay
59
+ },
60
+ {
61
+ 'params': embed_params,
62
+ 'lr': base_lr,
63
+ 'weight_decay': embed_weight_decay
64
+ },
65
+ {
66
+ 'params': other_params,
67
+ 'lr': base_lr,
68
+ 'weight_decay': weight_decay
69
+ },
70
+ ]
71
+
72
+ return parameter_groups
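
The three groups returned above drop straight into a standard optimizer, each keeping its own learning rate and weight decay. A hedged sketch (training is not part of this commit), with a SimpleNamespace standing in for the Hydra stage config fields this helper reads and a toy module standing in for the network:

import torch
import torch.nn as nn
from types import SimpleNamespace
from preprocessing.matanyone.matanyone.model.utils.parameter_groups import get_parameter_groups

# stand-in for stage_cfg; the field names match what get_parameter_groups reads
stage_cfg = SimpleNamespace(weight_decay=1e-3, embed_weight_decay=0.0,
                            backbone_lr_ratio=0.1, learning_rate=1e-4)

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())    # any nn.Module works for illustration
groups = get_parameter_groups(model, stage_cfg, print_log=False)
optimizer = torch.optim.AdamW(groups)                   # per-group lr / weight_decay preserved
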
preprocessing/matanyone/matanyone/model/utils/resnet.py ADDED
@@ -0,0 +1,179 @@
1
+ """
2
+ resnet.py - A modified ResNet structure
3
+ We append extra channels to the first conv by some network surgery
4
+ """
5
+
6
+ from collections import OrderedDict
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.utils import model_zoo
12
+
13
+
14
+ def load_weights_add_extra_dim(target, source_state, extra_dim=1):
15
+ new_dict = OrderedDict()
16
+
17
+ for k1, v1 in target.state_dict().items():
18
+ if 'num_batches_tracked' not in k1:
19
+ if k1 in source_state:
20
+ tar_v = source_state[k1]
21
+
22
+ if v1.shape != tar_v.shape:
23
+ # Init the new segmentation channel with zeros
24
+ # print(v1.shape, tar_v.shape)
25
+ c, _, w, h = v1.shape
26
+ pads = torch.zeros((c, extra_dim, w, h), device=tar_v.device)
27
+ nn.init.orthogonal_(pads)
28
+ tar_v = torch.cat([tar_v, pads], 1)
29
+
30
+ new_dict[k1] = tar_v
31
+
32
+ target.load_state_dict(new_dict)
33
+
34
+
35
+ model_urls = {
36
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
37
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
38
+ }
39
+
40
+
41
+ def conv3x3(in_planes, out_planes, stride=1, dilation=1):
42
+ return nn.Conv2d(in_planes,
43
+ out_planes,
44
+ kernel_size=3,
45
+ stride=stride,
46
+ padding=dilation,
47
+ dilation=dilation,
48
+ bias=False)
49
+
50
+
51
+ class BasicBlock(nn.Module):
52
+ expansion = 1
53
+
54
+ def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
55
+ super(BasicBlock, self).__init__()
56
+ self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation)
57
+ self.bn1 = nn.BatchNorm2d(planes)
58
+ self.relu = nn.ReLU(inplace=True)
59
+ self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation)
60
+ self.bn2 = nn.BatchNorm2d(planes)
61
+ self.downsample = downsample
62
+ self.stride = stride
63
+
64
+ def forward(self, x):
65
+ residual = x
66
+
67
+ out = self.conv1(x)
68
+ out = self.bn1(out)
69
+ out = self.relu(out)
70
+
71
+ out = self.conv2(out)
72
+ out = self.bn2(out)
73
+
74
+ if self.downsample is not None:
75
+ residual = self.downsample(x)
76
+
77
+ out += residual
78
+ out = self.relu(out)
79
+
80
+ return out
81
+
82
+
83
+ class Bottleneck(nn.Module):
84
+ expansion = 4
85
+
86
+ def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
87
+ super(Bottleneck, self).__init__()
88
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
89
+ self.bn1 = nn.BatchNorm2d(planes)
90
+ self.conv2 = nn.Conv2d(planes,
91
+ planes,
92
+ kernel_size=3,
93
+ stride=stride,
94
+ dilation=dilation,
95
+ padding=dilation,
96
+ bias=False)
97
+ self.bn2 = nn.BatchNorm2d(planes)
98
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
99
+ self.bn3 = nn.BatchNorm2d(planes * 4)
100
+ self.relu = nn.ReLU(inplace=True)
101
+ self.downsample = downsample
102
+ self.stride = stride
103
+
104
+ def forward(self, x):
105
+ residual = x
106
+
107
+ out = self.conv1(x)
108
+ out = self.bn1(out)
109
+ out = self.relu(out)
110
+
111
+ out = self.conv2(out)
112
+ out = self.bn2(out)
113
+ out = self.relu(out)
114
+
115
+ out = self.conv3(out)
116
+ out = self.bn3(out)
117
+
118
+ if self.downsample is not None:
119
+ residual = self.downsample(x)
120
+
121
+ out += residual
122
+ out = self.relu(out)
123
+
124
+ return out
125
+
126
+
127
+ class ResNet(nn.Module):
128
+ def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0):
129
+ self.inplanes = 64
130
+ super(ResNet, self).__init__()
131
+ self.conv1 = nn.Conv2d(3 + extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False)
132
+ self.bn1 = nn.BatchNorm2d(64)
133
+ self.relu = nn.ReLU(inplace=True)
134
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
135
+ self.layer1 = self._make_layer(block, 64, layers[0])
136
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
137
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
138
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
139
+
140
+ for m in self.modules():
141
+ if isinstance(m, nn.Conv2d):
142
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
143
+ m.weight.data.normal_(0, math.sqrt(2. / n))
144
+ elif isinstance(m, nn.BatchNorm2d):
145
+ m.weight.data.fill_(1)
146
+ m.bias.data.zero_()
147
+
148
+ def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
149
+ downsample = None
150
+ if stride != 1 or self.inplanes != planes * block.expansion:
151
+ downsample = nn.Sequential(
152
+ nn.Conv2d(self.inplanes,
153
+ planes * block.expansion,
154
+ kernel_size=1,
155
+ stride=stride,
156
+ bias=False),
157
+ nn.BatchNorm2d(planes * block.expansion),
158
+ )
159
+
160
+ layers = [block(self.inplanes, planes, stride, downsample)]
161
+ self.inplanes = planes * block.expansion
162
+ for i in range(1, blocks):
163
+ layers.append(block(self.inplanes, planes, dilation=dilation))
164
+
165
+ return nn.Sequential(*layers)
166
+
167
+
168
+ def resnet18(pretrained=True, extra_dim=0):
169
+ model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim)
170
+ if pretrained:
171
+ load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet18']), extra_dim)
172
+ return model
173
+
174
+
175
+ def resnet50(pretrained=True, extra_dim=0):
176
+ model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim)
177
+ if pretrained:
178
+ load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet50']), extra_dim)
179
+ return model
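A note on the two factory functions above: `load_weights_add_extra_dim` reuses the pretrained ImageNet weights even when `conv1` takes more than three input channels, padding the extra channel(s) with an orthogonal init. A minimal sketch of the effect (assuming the repository root is on the Python path):

    import torch
    from preprocessing.matanyone.matanyone.model.utils.resnet import resnet50

    # RGB plus one extra channel (e.g. a mask / segmentation map)
    backbone = resnet50(pretrained=True, extra_dim=1)
    print(backbone.conv1.weight.shape)   # torch.Size([64, 4, 7, 7])

    # The class defines no forward(); callers run the stages they need explicitly.
    x = torch.randn(1, 4, 64, 64)
    f = backbone.relu(backbone.bn1(backbone.conv1(x)))
    f = backbone.layer1(backbone.maxpool(f))   # 1/4-resolution features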
preprocessing/matanyone/matanyone_wrapper.py ADDED
@@ -0,0 +1,73 @@
1
+ import tqdm
2
+ import torch
3
+ from torchvision.transforms.functional import to_tensor
4
+ import numpy as np
5
+ import random
6
+ import cv2
7
+
8
+ def gen_dilate(alpha, min_kernel_size, max_kernel_size):
9
+ kernel_size = random.randint(min_kernel_size, max_kernel_size)
10
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size,kernel_size))
11
+ fg_and_unknown = np.array(np.not_equal(alpha, 0).astype(np.float32))
12
+ dilate = cv2.dilate(fg_and_unknown, kernel, iterations=1)*255
13
+ return dilate.astype(np.float32)
14
+
15
+ def gen_erosion(alpha, min_kernel_size, max_kernel_size):
16
+ kernel_size = random.randint(min_kernel_size, max_kernel_size)
17
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size,kernel_size))
18
+ fg = np.array(np.equal(alpha, 255).astype(np.float32))
19
+ erode = cv2.erode(fg, kernel, iterations=1)*255
20
+ return erode.astype(np.float32)
21
+
22
+ @torch.inference_mode()
23
+ @torch.amp.autocast('cuda')
24
+ def matanyone(processor, frames_np, mask, r_erode=0, r_dilate=0, n_warmup=10):
25
+ """
26
+ Args:
27
+ frames_np: [(H,W,C)]*n, uint8
28
+ mask: (H,W), uint8
29
+ Outputs:
30
+ com: [(H,W,C)]*n, uint8
31
+ pha: [(H,W,C)]*n, uint8
32
+ """
33
+
34
+ # print(f'===== [r_erode] {r_erode}; [r_dilate] {r_dilate} =====')
35
+ bgr = (np.array([120, 255, 155], dtype=np.float32)/255).reshape((1, 1, 3))
36
+ objects = [1]
37
+
38
+ # [optional] erode & dilate on given seg mask
39
+ if r_dilate > 0:
40
+ mask = gen_dilate(mask, r_dilate, r_dilate)
41
+ if r_erode > 0:
42
+ mask = gen_erosion(mask, r_erode, r_erode)
43
+
44
+ mask = torch.from_numpy(mask).cuda()
45
+
46
+ frames_np = [frames_np[0]]* n_warmup + frames_np
47
+
48
+ frames = []
49
+ phas = []
50
+ for ti, frame_single in tqdm.tqdm(enumerate(frames_np)):
51
+ image = to_tensor(frame_single).cuda().float()
52
+
53
+ if ti == 0:
54
+ output_prob = processor.step(image, mask, objects=objects) # encode given mask
55
+ output_prob = processor.step(image, first_frame_pred=True) # clear past memory for warmup frames
56
+ else:
57
+ if ti <= n_warmup:
58
+ output_prob = processor.step(image, first_frame_pred=True) # clear past memory for warmup frames
59
+ else:
60
+ output_prob = processor.step(image)
61
+
62
+ # convert output probabilities to an object mask
63
+ mask = processor.output_prob_to_mask(output_prob)
64
+
65
+ pha = mask.unsqueeze(2).cpu().numpy()
66
+ com_np = frame_single / 255. * pha + bgr * (1 - pha)
67
+
68
+ # DONOT save the warmup frames
69
+ if ti > (n_warmup-1):
70
+ frames.append((com_np*255).astype(np.uint8))
71
+ phas.append((pha*255).astype(np.uint8))
72
+
73
+ return frames, phas
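The wrapper above streams frames through MatAnyone one at a time, repeating the first frame `n_warmup` times so the memory bank stabilises before the real frames are matted, and returns a green-screen composite plus per-frame alpha mattes. A calling sketch (illustrative only; file names are placeholders, and `processor` is the MatAnyone InferenceCore built as shown next to utils/get_default_model.py further down):

    import cv2
    import numpy as np
    from preprocessing.matanyone.matanyone_wrapper import matanyone

    cap = cv2.VideoCapture("control.mp4")          # placeholder path
    frames = []
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))   # (H, W, 3) uint8 RGB
    cap.release()

    mask = np.zeros(frames[0].shape[:2], dtype=np.uint8)        # 0/255 first-frame mask, e.g. from the SAM click tool
    fg_frames, alpha_frames = matanyone(processor, frames, mask, r_erode=10, r_dilate=10)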
preprocessing/matanyone/tools/__init__.py ADDED
File without changes
preprocessing/matanyone/tools/base_segmenter.py ADDED
@@ -0,0 +1,141 @@
1
+ import time
2
+ import torch
3
+ import cv2
4
+ from PIL import Image, ImageDraw, ImageOps
5
+ import numpy as np
6
+ from typing import Union
7
+ from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
8
+ import matplotlib.pyplot as plt
9
+ import PIL
10
+ from .mask_painter import mask_painter
11
+
12
+
13
+ class BaseSegmenter:
14
+ def __init__(self, SAM_checkpoint, model_type, device='cuda:0'):
15
+ """
16
+ device: model device
17
+ SAM_checkpoint: path of SAM checkpoint
18
+ model_type: vit_b, vit_l, vit_h
19
+ """
20
+ print(f"Initializing BaseSegmenter to {device}")
21
+ assert model_type in ['vit_b', 'vit_l', 'vit_h'], 'model_type must be vit_b, vit_l, or vit_h'
22
+
23
+ self.device = device
24
+ # SAM_checkpoint = None
25
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
26
+ from accelerate import init_empty_weights
27
+
28
+ # self.model = sam_model_registry[model_type](checkpoint=SAM_checkpoint)
29
+ with init_empty_weights():
30
+ self.model = sam_model_registry[model_type](checkpoint=SAM_checkpoint)
31
+ from mmgp import offload
32
+ # self.model.to(torch.float16)
33
+ # offload.save_model(self.model, "ckpts/mask/sam_vit_h_4b8939_fp16.safetensors")
34
+
35
+ offload.load_model_data(self.model, "ckpts/mask/sam_vit_h_4b8939_fp16.safetensors")
36
+ self.model.to(torch.float32) # need to be optimized, if not f32 crappy precision
37
+ self.model.to(device=self.device)
38
+ self.predictor = SamPredictor(self.model)
39
+ self.embedded = False
40
+
41
+ @torch.no_grad()
42
+ def set_image(self, image: np.ndarray):
43
+ # PIL.open(image_path) 3channel: RGB
44
+ # image embedding: avoid encode the same image multiple times
45
+ self.orignal_image = image
46
+ if self.embedded:
47
+ print('repeat embedding, please reset_image.')
48
+ return
49
+ self.predictor.set_image(image)
50
+ self.embedded = True
51
+ return
52
+
53
+ @torch.no_grad()
54
+ def reset_image(self):
55
+ # reset image embedding
56
+ self.predictor.reset_image()
57
+ self.embedded = False
58
+
59
+ def predict(self, prompts, mode, multimask=True):
60
+ """
61
+ image: numpy array, h, w, 3
62
+ prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input'
63
+ prompts['point_coords']: numpy array [N,2]
64
+ prompts['point_labels']: numpy array [1,N]
65
+ prompts['mask_input']: numpy array [1,256,256]
66
+ mode: 'point' (points only), 'mask' (mask only), 'both' (consider both)
67
+ mask_outputs: True (return 3 masks), False (return 1 mask only)
68
+ when mask_outputs=True, mask_input=logits[np.argmax(scores), :, :][None, :, :]
69
+ """
70
+ assert self.embedded, 'prediction is called before set_image (feature embedding).'
71
+ assert mode in ['point', 'mask', 'both'], 'mode must be point, mask, or both'
72
+
73
+ with torch.autocast(device_type='cuda', dtype=torch.float16):
74
+ if mode == 'point':
75
+ masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
76
+ point_labels=prompts['point_labels'],
77
+ multimask_output=multimask)
78
+ elif mode == 'mask':
79
+ masks, scores, logits = self.predictor.predict(mask_input=prompts['mask_input'],
80
+ multimask_output=multimask)
81
+ elif mode == 'both': # both
82
+ masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'],
83
+ point_labels=prompts['point_labels'],
84
+ mask_input=prompts['mask_input'],
85
+ multimask_output=multimask)
86
+ else:
87
+ raise NotImplementedError("mode must be 'point', 'mask' or 'both'")
88
+ # masks (n, h, w), scores (n,), logits (n, 256, 256)
89
+ return masks, scores, logits
90
+
91
+
92
+ if __name__ == "__main__":
93
+ # load and show an image
94
+ image = cv2.imread('/hhd3/gaoshang/truck.jpg')
95
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # numpy array (h, w, 3)
96
+
97
+ # initialise BaseSegmenter
98
+ SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
99
+ model_type = 'vit_h'
100
+ device = "cuda:4"
101
+ base_segmenter = BaseSegmenter(SAM_checkpoint=SAM_checkpoint, model_type=model_type, device=device)
102
+
103
+ # image embedding (once embedded, multiple prompts can be applied)
104
+ base_segmenter.set_image(image)
105
+
106
+ # examples
107
+ # point only ------------------------
108
+ mode = 'point'
109
+ prompts = {
110
+ 'point_coords': np.array([[500, 375], [1125, 625]]),
111
+ 'point_labels': np.array([1, 1]),
112
+ }
113
+ masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False) # masks (n, h, w), scores (n,), logits (n, 256, 256)
114
+ painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
115
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
116
+ cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
117
+
118
+ # both ------------------------
119
+ mode = 'both'
120
+ mask_input = logits[np.argmax(scores), :, :]
121
+ prompts = {'mask_input': mask_input [None, :, :]}
122
+ prompts = {
123
+ 'point_coords': np.array([[500, 375], [1125, 625]]),
124
+ 'point_labels': np.array([1, 0]),
125
+ 'mask_input': mask_input[None, :, :]
126
+ }
127
+ masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
128
+ painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
129
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
130
+ cv2.imwrite('/hhd3/gaoshang/truck_both.jpg', painted_image)
131
+
132
+ # mask only ------------------------
133
+ mode = 'mask'
134
+ mask_input = logits[np.argmax(scores), :, :]
135
+
136
+ prompts = {'mask_input': mask_input[None, :, :]}
137
+
138
+ masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256)
139
+ painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
140
+ painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
141
+ cv2.imwrite('/hhd3/gaoshang/truck_mask.jpg', painted_image)
preprocessing/matanyone/tools/download_util.py ADDED
@@ -0,0 +1,109 @@
1
+ import math
2
+ import os
3
+ import requests
4
+ from torch.hub import download_url_to_file, get_dir
5
+ from tqdm import tqdm
6
+ from urllib.parse import urlparse
7
+
8
+ def sizeof_fmt(size, suffix='B'):
9
+ """Get human readable file size.
10
+
11
+ Args:
12
+ size (int): File size.
13
+ suffix (str): Suffix. Default: 'B'.
14
+
15
+ Return:
16
+ str: Formatted file size.
17
+ """
18
+ for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
19
+ if abs(size) < 1024.0:
20
+ return f'{size:3.1f} {unit}{suffix}'
21
+ size /= 1024.0
22
+ return f'{size:3.1f} Y{suffix}'
23
+
24
+
25
+ def download_file_from_google_drive(file_id, save_path):
26
+ """Download files from google drive.
27
+ Ref:
28
+ https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
29
+ Args:
30
+ file_id (str): File id.
31
+ save_path (str): Save path.
32
+ """
33
+
34
+ session = requests.Session()
35
+ URL = 'https://docs.google.com/uc?export=download'
36
+ params = {'id': file_id}
37
+
38
+ response = session.get(URL, params=params, stream=True)
39
+ token = get_confirm_token(response)
40
+ if token:
41
+ params['confirm'] = token
42
+ response = session.get(URL, params=params, stream=True)
43
+
44
+ # get file size
45
+ response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
46
+ print(response_file_size)
47
+ if 'Content-Range' in response_file_size.headers:
48
+ file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
49
+ else:
50
+ file_size = None
51
+
52
+ save_response_content(response, save_path, file_size)
53
+
54
+
55
+ def get_confirm_token(response):
56
+ for key, value in response.cookies.items():
57
+ if key.startswith('download_warning'):
58
+ return value
59
+ return None
60
+
61
+
62
+ def save_response_content(response, destination, file_size=None, chunk_size=32768):
63
+ if file_size is not None:
64
+ pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
65
+
66
+ readable_file_size = sizeof_fmt(file_size)
67
+ else:
68
+ pbar = None
69
+
70
+ with open(destination, 'wb') as f:
71
+ downloaded_size = 0
72
+ for chunk in response.iter_content(chunk_size):
73
+ downloaded_size += chunk_size
74
+ if pbar is not None:
75
+ pbar.update(1)
76
+ pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
77
+ if chunk: # filter out keep-alive new chunks
78
+ f.write(chunk)
79
+ if pbar is not None:
80
+ pbar.close()
81
+
82
+
83
+ def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
84
+ """Load file from http url, will download models if necessary.
85
+ Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
86
+ Args:
87
+ url (str): URL to be downloaded.
88
+ model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
89
+ Default: None.
90
+ progress (bool): Whether to show the download progress. Default: True.
91
+ file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
92
+ Returns:
93
+ str: The path to the downloaded file.
94
+ """
95
+ if model_dir is None: # use the pytorch hub_dir
96
+ hub_dir = get_dir()
97
+ model_dir = os.path.join(hub_dir, 'checkpoints')
98
+
99
+ os.makedirs(model_dir, exist_ok=True)
100
+
101
+ parts = urlparse(url)
102
+ filename = os.path.basename(parts.path)
103
+ if file_name is not None:
104
+ filename = file_name
105
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
106
+ if not os.path.exists(cached_file):
107
+ print(f'Downloading: "{url}" to {cached_file}\n')
108
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
109
+ return cached_file
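Usage sketch for the cache helper (URL and folder are placeholders):

    ckpt = load_file_from_url(
        url="https://example.com/checkpoints/some_model.pth",
        model_dir="ckpts",        # falls back to the torch hub checkpoints dir when None
        progress=True,
    )
    print(ckpt)                   # absolute path; a second call skips the download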
preprocessing/matanyone/tools/interact_tools.py ADDED
@@ -0,0 +1,99 @@
1
+ import time
2
+ import torch
3
+ import cv2
4
+ from PIL import Image, ImageDraw, ImageOps
5
+ import numpy as np
6
+ from typing import Union
7
+ from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
8
+ import matplotlib.pyplot as plt
9
+ import PIL
10
+ from .mask_painter import mask_painter as mask_painter2
11
+ from .base_segmenter import BaseSegmenter
12
+ from .painter import mask_painter, point_painter
13
+ import os
14
+ import requests
15
+ import sys
16
+
17
+
18
+ mask_color = 3
19
+ mask_alpha = 0.7
20
+ contour_color = 1
21
+ contour_width = 5
22
+ point_color_ne = 8
23
+ point_color_ps = 50
24
+ point_alpha = 0.9
25
+ point_radius = 15
26
+ contour_color = 2
27
+ contour_width = 5
28
+
29
+
30
+ class SamControler():
31
+ def __init__(self, SAM_checkpoint, model_type, device):
32
+ '''
33
+ initialize sam controler
34
+ '''
35
+ self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
36
+
37
+
38
+ # def seg_again(self, image: np.ndarray):
39
+ # '''
40
+ # it is used when interact in video
41
+ # '''
42
+ # self.sam_controler.reset_image()
43
+ # self.sam_controler.set_image(image)
44
+ # return
45
+
46
+
47
+ def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True,mask_color=3):
48
+ '''
49
+ it is used in first frame in video
50
+ return: mask, logit, painted image(mask+point)
51
+ '''
52
+ # self.sam_controler.set_image(image)
53
+ origal_image = self.sam_controler.orignal_image
54
+ neg_flag = labels[-1]
55
+ if neg_flag==1:
56
+ #find neg
57
+ prompts = {
58
+ 'point_coords': points,
59
+ 'point_labels': labels,
60
+ }
61
+ masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
62
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
63
+ prompts = {
64
+ 'point_coords': points,
65
+ 'point_labels': labels,
66
+ 'mask_input': logit[None, :, :]
67
+ }
68
+ masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
69
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
70
+ else:
71
+ #find positive
72
+ prompts = {
73
+ 'point_coords': points,
74
+ 'point_labels': labels,
75
+ }
76
+ masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
77
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
78
+
79
+
80
+ assert len(points)==len(labels)
81
+
82
+ painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
83
+ painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
84
+ painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
85
+ painted_image = Image.fromarray(painted_image)
86
+
87
+ return mask, logit, painted_image
88
+
89
+
90
+
91
+
92
+
93
+
94
+
95
+
96
+
97
+
98
+
99
+
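Click-to-mask sketch for the controller above. BaseSegmenter loads its weights from ckpts/mask/sam_vit_h_4b8939_fp16.safetensors itself, so the checkpoint argument is shown as None here; the frame is a stand-in for the first frame of the control video:

    import numpy as np

    controler = SamControler(SAM_checkpoint=None, model_type="vit_h", device="cuda:0")

    first_frame = np.zeros((480, 832, 3), dtype=np.uint8)     # stand-in for a real RGB frame
    controler.sam_controler.set_image(first_frame)            # embed once, then click as many times as needed

    points = np.array([[320, 240]])                           # positive click at (x, y)
    labels = np.array([1])
    mask, logit, painted = controler.first_frame_click(first_frame, points, labels, multimask=True)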
preprocessing/matanyone/tools/mask_painter.py ADDED
@@ -0,0 +1,288 @@
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import copy
6
+ import time
7
+
8
+
9
+ def colormap(rgb=True):
10
+ color_list = np.array(
11
+ [
12
+ 0.000, 0.000, 0.000,
13
+ 1.000, 1.000, 1.000,
14
+ 1.000, 0.498, 0.313,
15
+ 0.392, 0.581, 0.929,
16
+ 0.000, 0.447, 0.741,
17
+ 0.850, 0.325, 0.098,
18
+ 0.929, 0.694, 0.125,
19
+ 0.494, 0.184, 0.556,
20
+ 0.466, 0.674, 0.188,
21
+ 0.301, 0.745, 0.933,
22
+ 0.635, 0.078, 0.184,
23
+ 0.300, 0.300, 0.300,
24
+ 0.600, 0.600, 0.600,
25
+ 1.000, 0.000, 0.000,
26
+ 1.000, 0.500, 0.000,
27
+ 0.749, 0.749, 0.000,
28
+ 0.000, 1.000, 0.000,
29
+ 0.000, 0.000, 1.000,
30
+ 0.667, 0.000, 1.000,
31
+ 0.333, 0.333, 0.000,
32
+ 0.333, 0.667, 0.000,
33
+ 0.333, 1.000, 0.000,
34
+ 0.667, 0.333, 0.000,
35
+ 0.667, 0.667, 0.000,
36
+ 0.667, 1.000, 0.000,
37
+ 1.000, 0.333, 0.000,
38
+ 1.000, 0.667, 0.000,
39
+ 1.000, 1.000, 0.000,
40
+ 0.000, 0.333, 0.500,
41
+ 0.000, 0.667, 0.500,
42
+ 0.000, 1.000, 0.500,
43
+ 0.333, 0.000, 0.500,
44
+ 0.333, 0.333, 0.500,
45
+ 0.333, 0.667, 0.500,
46
+ 0.333, 1.000, 0.500,
47
+ 0.667, 0.000, 0.500,
48
+ 0.667, 0.333, 0.500,
49
+ 0.667, 0.667, 0.500,
50
+ 0.667, 1.000, 0.500,
51
+ 1.000, 0.000, 0.500,
52
+ 1.000, 0.333, 0.500,
53
+ 1.000, 0.667, 0.500,
54
+ 1.000, 1.000, 0.500,
55
+ 0.000, 0.333, 1.000,
56
+ 0.000, 0.667, 1.000,
57
+ 0.000, 1.000, 1.000,
58
+ 0.333, 0.000, 1.000,
59
+ 0.333, 0.333, 1.000,
60
+ 0.333, 0.667, 1.000,
61
+ 0.333, 1.000, 1.000,
62
+ 0.667, 0.000, 1.000,
63
+ 0.667, 0.333, 1.000,
64
+ 0.667, 0.667, 1.000,
65
+ 0.667, 1.000, 1.000,
66
+ 1.000, 0.000, 1.000,
67
+ 1.000, 0.333, 1.000,
68
+ 1.000, 0.667, 1.000,
69
+ 0.167, 0.000, 0.000,
70
+ 0.333, 0.000, 0.000,
71
+ 0.500, 0.000, 0.000,
72
+ 0.667, 0.000, 0.000,
73
+ 0.833, 0.000, 0.000,
74
+ 1.000, 0.000, 0.000,
75
+ 0.000, 0.167, 0.000,
76
+ 0.000, 0.333, 0.000,
77
+ 0.000, 0.500, 0.000,
78
+ 0.000, 0.667, 0.000,
79
+ 0.000, 0.833, 0.000,
80
+ 0.000, 1.000, 0.000,
81
+ 0.000, 0.000, 0.167,
82
+ 0.000, 0.000, 0.333,
83
+ 0.000, 0.000, 0.500,
84
+ 0.000, 0.000, 0.667,
85
+ 0.000, 0.000, 0.833,
86
+ 0.000, 0.000, 1.000,
87
+ 0.143, 0.143, 0.143,
88
+ 0.286, 0.286, 0.286,
89
+ 0.429, 0.429, 0.429,
90
+ 0.571, 0.571, 0.571,
91
+ 0.714, 0.714, 0.714,
92
+ 0.857, 0.857, 0.857
93
+ ]
94
+ ).astype(np.float32)
95
+ color_list = color_list.reshape((-1, 3)) * 255
96
+ if not rgb:
97
+ color_list = color_list[:, ::-1]
98
+ return color_list
99
+
100
+
101
+ color_list = colormap()
102
+ color_list = color_list.astype('uint8').tolist()
103
+
104
+
105
+ def vis_add_mask(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha):
106
+ background_color = np.array(background_color)
107
+ contour_color = np.array(contour_color)
108
+
109
+ # background_mask = 1 - background_mask
110
+ # contour_mask = 1 - contour_mask
111
+
112
+ for i in range(3):
113
+ image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \
114
+ + background_color[i] * (background_alpha-background_mask*background_alpha)
115
+
116
+ image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \
117
+ + contour_color[i] * (contour_alpha-contour_mask*contour_alpha)
118
+
119
+ return image.astype('uint8')
120
+
121
+
122
+ def mask_generator_00(mask, background_radius, contour_radius):
123
+ # no background width when '00'
124
+ # distance map
125
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
126
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
127
+ dist_map = dist_transform_fore - dist_transform_back
128
+ # ...:::!!!:::...
129
+ contour_radius += 2
130
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
131
+ contour_mask = contour_mask / np.max(contour_mask)
132
+ contour_mask[contour_mask>0.5] = 1.
133
+
134
+ return mask, contour_mask
135
+
136
+
137
+ def mask_generator_01(mask, background_radius, contour_radius):
138
+ # no background width when '01'
139
+ # distance map
140
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
141
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
142
+ dist_map = dist_transform_fore - dist_transform_back
143
+ # ...:::!!!:::...
144
+ contour_radius += 2
145
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
146
+ contour_mask = contour_mask / np.max(contour_mask)
147
+ return mask, contour_mask
148
+
149
+
150
+ def mask_generator_10(mask, background_radius, contour_radius):
151
+ # distance map
152
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
153
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
154
+ dist_map = dist_transform_fore - dist_transform_back
155
+ # .....:::::!!!!!
156
+ background_mask = np.clip(dist_map, -background_radius, background_radius)
157
+ background_mask = (background_mask - np.min(background_mask))
158
+ background_mask = background_mask / np.max(background_mask)
159
+ # ...:::!!!:::...
160
+ contour_radius += 2
161
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
162
+ contour_mask = contour_mask / np.max(contour_mask)
163
+ contour_mask[contour_mask>0.5] = 1.
164
+ return background_mask, contour_mask
165
+
166
+
167
+ def mask_generator_11(mask, background_radius, contour_radius):
168
+ # distance map
169
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
170
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
171
+ dist_map = dist_transform_fore - dist_transform_back
172
+ # .....:::::!!!!!
173
+ background_mask = np.clip(dist_map, -background_radius, background_radius)
174
+ background_mask = (background_mask - np.min(background_mask))
175
+ background_mask = background_mask / np.max(background_mask)
176
+ # ...:::!!!:::...
177
+ contour_radius += 2
178
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
179
+ contour_mask = contour_mask / np.max(contour_mask)
180
+ return background_mask, contour_mask
181
+
182
+
183
+ def mask_painter(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'):
184
+ """
185
+ Input:
186
+ input_image: numpy array
187
+ input_mask: numpy array
188
+ background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
189
+ background_blur_radius: radius of background blur, must be odd number
190
+ contour_width: width of mask contour, must be odd number
191
+ contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
192
+ contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
193
+ mode: painting mode, '00', no blur, '01' only blur contour, '10' only blur background, '11' blur both
194
+
195
+ Output:
196
+ painted_image: numpy array
197
+ """
198
+ assert input_image.shape[:2] == input_mask.shape, 'different shape'
199
+ assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
200
+ assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11'
201
+
202
+ # downsample input image and mask
203
+ width, height = input_image.shape[0], input_image.shape[1]
204
+ res = 1024
205
+ ratio = min(1.0 * res / max(width, height), 1.0)
206
+ input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio)))
207
+ input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio)))
208
+
209
+ # 0: background, 1: foreground
210
+ msk = np.clip(input_mask, 0, 1)
211
+
212
+ # generate masks for background and contour pixels
213
+ background_radius = (background_blur_radius - 1) // 2
214
+ contour_radius = (contour_width - 1) // 2
215
+ generator_dict = {'00':mask_generator_00, '01':mask_generator_01, '10':mask_generator_10, '11':mask_generator_11}
216
+ background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
217
+
218
+ # paint
219
+ painted_image = vis_add_mask\
220
+ (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background
221
+
222
+ return painted_image
223
+
224
+
225
+ if __name__ == '__main__':
226
+
227
+ background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing
228
+ background_blur_radius = 31 # radius of background blur, must be odd number
229
+ contour_width = 11 # contour width, must be odd number
230
+ contour_color = 3 # id in color map, 0: black, 1: white, >1: others
231
+ contour_alpha = 1 # transparency of background, 0: no contour highlighted
232
+
233
+ # load input image and mask
234
+ input_image = np.array(Image.open('./test_img/painter_input_image.jpg').convert('RGB'))
235
+ input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P'))
236
+
237
+ # paint
238
+ overall_time_1 = 0
239
+ overall_time_2 = 0
240
+ overall_time_3 = 0
241
+ overall_time_4 = 0
242
+ overall_time_5 = 0
243
+
244
+ for i in range(50):
245
+ t2 = time.time()
246
+ painted_image_00 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00')
247
+ e2 = time.time()
248
+
249
+ t3 = time.time()
250
+ painted_image_10 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='10')
251
+ e3 = time.time()
252
+
253
+ t1 = time.time()
254
+ painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha)
255
+ e1 = time.time()
256
+
257
+ t4 = time.time()
258
+ painted_image_01 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='01')
259
+ e4 = time.time()
260
+
261
+ t5 = time.time()
262
+ painted_image_11 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='11')
263
+ e5 = time.time()
264
+
265
+ overall_time_1 += (e1 - t1)
266
+ overall_time_2 += (e2 - t2)
267
+ overall_time_3 += (e3 - t3)
268
+ overall_time_4 += (e4 - t4)
269
+ overall_time_5 += (e5 - t5)
270
+
271
+ print(f'average time w gaussian: {overall_time_1/50}')
272
+ print(f'average time w/o gaussian00: {overall_time_2/50}')
273
+ print(f'average time w/o gaussian10: {overall_time_3/50}')
274
+ print(f'average time w/o gaussian01: {overall_time_4/50}')
275
+ print(f'average time w/o gaussian11: {overall_time_5/50}')
276
+
277
+ # save
278
+ painted_image_00 = Image.fromarray(painted_image_00)
279
+ painted_image_00.save('./test_img/painter_output_image_00.png')
280
+
281
+ painted_image_10 = Image.fromarray(painted_image_10)
282
+ painted_image_10.save('./test_img/painter_output_image_10.png')
283
+
284
+ painted_image_01 = Image.fromarray(painted_image_01)
285
+ painted_image_01.save('./test_img/painter_output_image_01.png')
286
+
287
+ painted_image_11 = Image.fromarray(painted_image_11)
288
+ painted_image_11.save('./test_img/painter_output_image_11.png')
preprocessing/matanyone/tools/misc.py ADDED
@@ -0,0 +1,131 @@
1
+ import os
2
+ import re
3
+ import random
4
+ import time
5
+ import torch
6
+ import torch.nn as nn
7
+ import logging
8
+ import numpy as np
9
+ from os import path as osp
10
+
11
+ def constant_init(module, val, bias=0):
12
+ if hasattr(module, 'weight') and module.weight is not None:
13
+ nn.init.constant_(module.weight, val)
14
+ if hasattr(module, 'bias') and module.bias is not None:
15
+ nn.init.constant_(module.bias, bias)
16
+
17
+ initialized_logger = {}
18
+ def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
19
+ """Get the root logger.
20
+ The logger will be initialized if it has not been initialized. By default a
21
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
22
+ also be added.
23
+ Args:
24
+ logger_name (str): root logger name. Default: 'basicsr'.
25
+ log_file (str | None): The log filename. If specified, a FileHandler
26
+ will be added to the root logger.
27
+ log_level (int): The root logger level. Note that only the process of
28
+ rank 0 is affected, while other processes will set the level to
29
+ "Error" and be silent most of the time.
30
+ Returns:
31
+ logging.Logger: The root logger.
32
+ """
33
+ logger = logging.getLogger(logger_name)
34
+ # if the logger has been initialized, just return it
35
+ if logger_name in initialized_logger:
36
+ return logger
37
+
38
+ format_str = '%(asctime)s %(levelname)s: %(message)s'
39
+ stream_handler = logging.StreamHandler()
40
+ stream_handler.setFormatter(logging.Formatter(format_str))
41
+ logger.addHandler(stream_handler)
42
+ logger.propagate = False
43
+
44
+ if log_file is not None:
45
+ logger.setLevel(log_level)
46
+ # add file handler
47
+ # file_handler = logging.FileHandler(log_file, 'w')
48
+ file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log
49
+ file_handler.setFormatter(logging.Formatter(format_str))
50
+ file_handler.setLevel(log_level)
51
+ logger.addHandler(file_handler)
52
+ initialized_logger[logger_name] = True
53
+ return logger
54
+
55
+
56
+ IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
57
+ torch.__version__)[0][:3])] >= [1, 12, 0]
58
+
59
+ def gpu_is_available():
60
+ if IS_HIGH_VERSION:
61
+ if torch.backends.mps.is_available():
62
+ return True
63
+ return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
64
+
65
+ def get_device(gpu_id=None):
66
+ if gpu_id is None:
67
+ gpu_str = ''
68
+ elif isinstance(gpu_id, int):
69
+ gpu_str = f':{gpu_id}'
70
+ else:
71
+ raise TypeError('Input should be int value.')
72
+
73
+ if IS_HIGH_VERSION:
74
+ if torch.backends.mps.is_available():
75
+ return torch.device('mps'+gpu_str)
76
+ return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
77
+
78
+
79
+ def set_random_seed(seed):
80
+ """Set random seeds."""
81
+ random.seed(seed)
82
+ np.random.seed(seed)
83
+ torch.manual_seed(seed)
84
+ torch.cuda.manual_seed(seed)
85
+ torch.cuda.manual_seed_all(seed)
86
+
87
+
88
+ def get_time_str():
89
+ return time.strftime('%Y%m%d_%H%M%S', time.localtime())
90
+
91
+
92
+ def scandir(dir_path, suffix=None, recursive=False, full_path=False):
93
+ """Scan a directory to find the interested files.
94
+
95
+ Args:
96
+ dir_path (str): Path of the directory.
97
+ suffix (str | tuple(str), optional): File suffix that we are
98
+ interested in. Default: None.
99
+ recursive (bool, optional): If set to True, recursively scan the
100
+ directory. Default: False.
101
+ full_path (bool, optional): If set to True, include the dir_path.
102
+ Default: False.
103
+
104
+ Returns:
105
+ A generator for all the interested files with relative paths.
106
+ """
107
+
108
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
109
+ raise TypeError('"suffix" must be a string or tuple of strings')
110
+
111
+ root = dir_path
112
+
113
+ def _scandir(dir_path, suffix, recursive):
114
+ for entry in os.scandir(dir_path):
115
+ if not entry.name.startswith('.') and entry.is_file():
116
+ if full_path:
117
+ return_path = entry.path
118
+ else:
119
+ return_path = osp.relpath(entry.path, root)
120
+
121
+ if suffix is None:
122
+ yield return_path
123
+ elif return_path.endswith(suffix):
124
+ yield return_path
125
+ else:
126
+ if recursive:
127
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
128
+ else:
129
+ continue
130
+
131
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
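Small usage sketch for the helpers above (the output folder is a placeholder):

    device = get_device()     # cuda (or mps on recent torch) when available, else cpu
    for rel_path in scandir("outputs", suffix=(".mp4", ".png"), recursive=True):
        print(rel_path)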
preprocessing/matanyone/tools/painter.py ADDED
@@ -0,0 +1,215 @@
1
+ # paint masks, contours, or points on images, with specified colors
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ import copy
7
+ import time
8
+
9
+
10
+ def colormap(rgb=True):
11
+ color_list = np.array(
12
+ [
13
+ 0.000, 0.000, 0.000,
14
+ 1.000, 1.000, 1.000,
15
+ 1.000, 0.498, 0.313,
16
+ 0.392, 0.581, 0.929,
17
+ 0.000, 0.447, 0.741,
18
+ 0.850, 0.325, 0.098,
19
+ 0.929, 0.694, 0.125,
20
+ 0.494, 0.184, 0.556,
21
+ 0.466, 0.674, 0.188,
22
+ 0.301, 0.745, 0.933,
23
+ 0.635, 0.078, 0.184,
24
+ 0.300, 0.300, 0.300,
25
+ 0.600, 0.600, 0.600,
26
+ 1.000, 0.000, 0.000,
27
+ 1.000, 0.500, 0.000,
28
+ 0.749, 0.749, 0.000,
29
+ 0.000, 1.000, 0.000,
30
+ 0.000, 0.000, 1.000,
31
+ 0.667, 0.000, 1.000,
32
+ 0.333, 0.333, 0.000,
33
+ 0.333, 0.667, 0.000,
34
+ 0.333, 1.000, 0.000,
35
+ 0.667, 0.333, 0.000,
36
+ 0.667, 0.667, 0.000,
37
+ 0.667, 1.000, 0.000,
38
+ 1.000, 0.333, 0.000,
39
+ 1.000, 0.667, 0.000,
40
+ 1.000, 1.000, 0.000,
41
+ 0.000, 0.333, 0.500,
42
+ 0.000, 0.667, 0.500,
43
+ 0.000, 1.000, 0.500,
44
+ 0.333, 0.000, 0.500,
45
+ 0.333, 0.333, 0.500,
46
+ 0.333, 0.667, 0.500,
47
+ 0.333, 1.000, 0.500,
48
+ 0.667, 0.000, 0.500,
49
+ 0.667, 0.333, 0.500,
50
+ 0.667, 0.667, 0.500,
51
+ 0.667, 1.000, 0.500,
52
+ 1.000, 0.000, 0.500,
53
+ 1.000, 0.333, 0.500,
54
+ 1.000, 0.667, 0.500,
55
+ 1.000, 1.000, 0.500,
56
+ 0.000, 0.333, 1.000,
57
+ 0.000, 0.667, 1.000,
58
+ 0.000, 1.000, 1.000,
59
+ 0.333, 0.000, 1.000,
60
+ 0.333, 0.333, 1.000,
61
+ 0.333, 0.667, 1.000,
62
+ 0.333, 1.000, 1.000,
63
+ 0.667, 0.000, 1.000,
64
+ 0.667, 0.333, 1.000,
65
+ 0.667, 0.667, 1.000,
66
+ 0.667, 1.000, 1.000,
67
+ 1.000, 0.000, 1.000,
68
+ 1.000, 0.333, 1.000,
69
+ 1.000, 0.667, 1.000,
70
+ 0.167, 0.000, 0.000,
71
+ 0.333, 0.000, 0.000,
72
+ 0.500, 0.000, 0.000,
73
+ 0.667, 0.000, 0.000,
74
+ 0.833, 0.000, 0.000,
75
+ 1.000, 0.000, 0.000,
76
+ 0.000, 0.167, 0.000,
77
+ 0.000, 0.333, 0.000,
78
+ 0.000, 0.500, 0.000,
79
+ 0.000, 0.667, 0.000,
80
+ 0.000, 0.833, 0.000,
81
+ 0.000, 1.000, 0.000,
82
+ 0.000, 0.000, 0.167,
83
+ 0.000, 0.000, 0.333,
84
+ 0.000, 0.000, 0.500,
85
+ 0.000, 0.000, 0.667,
86
+ 0.000, 0.000, 0.833,
87
+ 0.000, 0.000, 1.000,
88
+ 0.143, 0.143, 0.143,
89
+ 0.286, 0.286, 0.286,
90
+ 0.429, 0.429, 0.429,
91
+ 0.571, 0.571, 0.571,
92
+ 0.714, 0.714, 0.714,
93
+ 0.857, 0.857, 0.857
94
+ ]
95
+ ).astype(np.float32)
96
+ color_list = color_list.reshape((-1, 3)) * 255
97
+ if not rgb:
98
+ color_list = color_list[:, ::-1]
99
+ return color_list
100
+
101
+
102
+ color_list = colormap()
103
+ color_list = color_list.astype('uint8').tolist()
104
+
105
+
106
+ def vis_add_mask(image, mask, color, alpha):
107
+ color = np.array(color_list[color])
108
+ mask = mask > 0.5
109
+ image[mask] = image[mask] * (1-alpha) + color * alpha
110
+ return image.astype('uint8')
111
+
112
+ def point_painter(input_image, input_points, point_color=5, point_alpha=0.9, point_radius=15, contour_color=2, contour_width=5):
113
+ h, w = input_image.shape[:2]
114
+ point_mask = np.zeros((h, w)).astype('uint8')
115
+ for point in input_points:
116
+ point_mask[point[1], point[0]] = 1
117
+
118
+ kernel = cv2.getStructuringElement(2, (point_radius, point_radius))
119
+ point_mask = cv2.dilate(point_mask, kernel)
120
+
121
+ contour_radius = (contour_width - 1) // 2
122
+ dist_transform_fore = cv2.distanceTransform(point_mask, cv2.DIST_L2, 3)
123
+ dist_transform_back = cv2.distanceTransform(1-point_mask, cv2.DIST_L2, 3)
124
+ dist_map = dist_transform_fore - dist_transform_back
125
+ # ...:::!!!:::...
126
+ contour_radius += 2
127
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
128
+ contour_mask = contour_mask / np.max(contour_mask)
129
+ contour_mask[contour_mask>0.5] = 1.
130
+
131
+ # paint mask
132
+ painted_image = vis_add_mask(input_image.copy(), point_mask, point_color, point_alpha)
133
+ # paint contour
134
+ painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1)
135
+ return painted_image
136
+
137
+ def mask_painter(input_image, input_mask, mask_color=5, mask_alpha=0.7, contour_color=1, contour_width=3):
138
+ assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask'
139
+ # 0: background, 1: foreground
140
+ mask = np.clip(input_mask, 0, 1)
141
+ contour_radius = (contour_width - 1) // 2
142
+
143
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
144
+ dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
145
+ dist_map = dist_transform_fore - dist_transform_back
146
+ # ...:::!!!:::...
147
+ contour_radius += 2
148
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
149
+ contour_mask = contour_mask / np.max(contour_mask)
150
+ contour_mask[contour_mask>0.5] = 1.
151
+
152
+ # paint mask
153
+ painted_image = vis_add_mask(input_image.copy(), mask.copy(), mask_color, mask_alpha)
154
+ # paint contour
155
+ painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1)
156
+
157
+ return painted_image
158
+
159
+ def background_remover(input_image, input_mask):
160
+ """
161
+ input_image: H, W, 3, np.array
162
+ input_mask: H, W, np.array
163
+
164
+ image_wo_background: PIL.Image
165
+ """
166
+ assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask'
167
+ # 0: background, 1: foreground
168
+ mask = np.expand_dims(np.clip(input_mask, 0, 1), axis=2)*255
169
+ image_wo_background = np.concatenate([input_image, mask], axis=2) # H, W, 4
170
+ image_wo_background = Image.fromarray(image_wo_background).convert('RGBA')
171
+
172
+ return image_wo_background
173
+
174
+ if __name__ == '__main__':
175
+ input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
176
+ input_mask = np.array(Image.open('images/painter_input_mask.jpg').convert('P'))
177
+
178
+ # example of mask painter
179
+ mask_color = 3
180
+ mask_alpha = 0.7
181
+ contour_color = 1
182
+ contour_width = 5
183
+
184
+ # save
185
+ painted_image = Image.fromarray(input_image)
186
+ painted_image.save('images/original.png')
187
+
188
+ painted_image = mask_painter(input_image, input_mask, mask_color, mask_alpha, contour_color, contour_width)
189
+ # save
190
+ painted_image = Image.fromarray(painted_image)
191
+ painted_image.save('images/original1.png')
192
+
193
+ # example of point painter
194
+ input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
195
+ input_points = np.array([[500, 375], [70, 600]]) # x, y
196
+ point_color = 5
197
+ point_alpha = 0.9
198
+ point_radius = 15
199
+ contour_color = 2
200
+ contour_width = 5
201
+ painted_image_1 = point_painter(input_image, input_points, point_color, point_alpha, point_radius, contour_color, contour_width)
202
+ # save
203
+ painted_image = Image.fromarray(painted_image_1)
204
+ painted_image.save('images/point_painter_1.png')
205
+
206
+ input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB'))
207
+ painted_image_2 = point_painter(input_image, input_points, point_color=9, point_radius=20, contour_color=29)
208
+ # save
209
+ painted_image = Image.fromarray(painted_image_2)
210
+ painted_image.save('images/point_painter_2.png')
211
+
212
+ # example of background remover
213
+ input_image = np.array(Image.open('images/original.png').convert('RGB'))
214
+ image_wo_background = background_remover(input_image, input_mask) # return PIL.Image
215
+ image_wo_background.save('images/image_wo_background.png')
preprocessing/matanyone/utils/__init__.py ADDED
File without changes
preprocessing/matanyone/utils/get_default_model.py ADDED
@@ -0,0 +1,27 @@
1
+ """
2
+ A helper function to get a default model for quick testing
3
+ """
4
+ from omegaconf import open_dict
5
+ from hydra import compose, initialize
6
+
7
+ import torch
8
+ from ..matanyone.model.matanyone import MatAnyone
9
+
10
+ def get_matanyone_model(ckpt_path, device=None) -> MatAnyone:
11
+ initialize(version_base='1.3.2', config_path="../config", job_name="eval_our_config")
12
+ cfg = compose(config_name="eval_matanyone_config")
13
+
14
+ with open_dict(cfg):
15
+ cfg['weights'] = ckpt_path
16
+
17
+ # Load the network weights
18
+ if device is not None:
19
+ matanyone = MatAnyone(cfg, single_object=True).to(device).eval()
20
+ model_weights = torch.load(cfg.weights, map_location=device)
21
+ else: # if device is not specified, `.cuda()` by default
22
+ matanyone = MatAnyone(cfg, single_object=True).cuda().eval()
23
+ model_weights = torch.load(cfg.weights)
24
+
25
+ matanyone.load_weights(model_weights)
26
+
27
+ return matanyone
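This helper pairs with matanyone_wrapper.py earlier in this commit: the loaded network is wrapped in the streaming InferenceCore before being handed to `matanyone(processor, ...)`. A sketch; the checkpoint path and the InferenceCore call mirror the upstream MatAnyone demo and are assumptions here:

    from preprocessing.matanyone.matanyone.inference.inference_core import InferenceCore
    from preprocessing.matanyone.utils.get_default_model import get_matanyone_model

    matanyone_model = get_matanyone_model("ckpts/mask/matanyone.pth")        # .cuda().eval() when no device is given
    processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)      # assumed constructor, as in the upstream demo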
preprocessing/matanyone/utils/tensor_utils.py ADDED
@@ -0,0 +1,62 @@
1
+ from typing import List, Iterable
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+
6
+ # STM
7
+ def pad_divide_by(in_img: torch.Tensor, d: int) -> (torch.Tensor, Iterable[int]):
8
+ h, w = in_img.shape[-2:]
9
+
10
+ if h % d > 0:
11
+ new_h = h + d - h % d
12
+ else:
13
+ new_h = h
14
+ if w % d > 0:
15
+ new_w = w + d - w % d
16
+ else:
17
+ new_w = w
18
+ lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2)
19
+ lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2)
20
+ pad_array = (int(lw), int(uw), int(lh), int(uh))
21
+ out = F.pad(in_img, pad_array)
22
+ return out, pad_array
23
+
24
+
25
+ def unpad(img: torch.Tensor, pad: Iterable[int]) -> torch.Tensor:
26
+ if len(img.shape) == 4:
27
+ if pad[2] + pad[3] > 0:
28
+ img = img[:, :, pad[2]:-pad[3], :]
29
+ if pad[0] + pad[1] > 0:
30
+ img = img[:, :, :, pad[0]:-pad[1]]
31
+ elif len(img.shape) == 3:
32
+ if pad[2] + pad[3] > 0:
33
+ img = img[:, pad[2]:-pad[3], :]
34
+ if pad[0] + pad[1] > 0:
35
+ img = img[:, :, pad[0]:-pad[1]]
36
+ elif len(img.shape) == 5:
37
+ if pad[2] + pad[3] > 0:
38
+ img = img[:, :, :, pad[2]:-pad[3], :]
39
+ if pad[0] + pad[1] > 0:
40
+ img = img[:, :, :, :, pad[0]:-pad[1]]
41
+ else:
42
+ raise NotImplementedError
43
+ return img
44
+
45
+
46
+ # @torch.jit.script
47
+ def aggregate(prob: torch.Tensor, dim: int) -> torch.Tensor:
48
+ with torch.amp.autocast("cuda"):
49
+ prob = prob.float()
50
+ new_prob = torch.cat([torch.prod(1 - prob, dim=dim, keepdim=True), prob],
51
+ dim).clamp(1e-7, 1 - 1e-7)
52
+ logits = torch.log((new_prob / (1 - new_prob))) # (0, 1) --> (-inf, inf)
53
+
54
+ return logits
55
+
56
+
57
+ # @torch.jit.script
58
+ def cls_to_one_hot(cls_gt: torch.Tensor, num_objects: int) -> torch.Tensor:
59
+ # cls_gt: B*1*H*W
60
+ B, _, H, W = cls_gt.shape
61
+ one_hot = torch.zeros(B, num_objects + 1, H, W, device=cls_gt.device).scatter_(1, cls_gt, 1)
62
+ return one_hot
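`pad_divide_by` and `unpad` are a matched pair: inputs are padded to a stride multiple before encoding, and predictions are cropped back afterwards. Round-trip sketch:

    import torch

    frame = torch.randn(1, 3, 480, 852)
    padded, pads = pad_divide_by(frame, 16)     # -> (1, 3, 480, 864), pads = (left, right, top, bottom)
    restored = unpad(padded, pads)
    assert restored.shape == frame.shape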
requirements.txt CHANGED
@@ -24,4 +24,6 @@ onnxruntime-gpu
  rembg[gpu]==2.0.65
  matplotlib
  timm
+ segment-anything
+ ffmpeg-python
  # rembg==2.0.65
wan/modules/model.py CHANGED
@@ -482,7 +482,6 @@ class WanAttentionBlock(nn.Module):
482
  y *= 1 + e[4]
483
  y += e[3]
484
 
485
-
486
  ffn = self.ffn[0]
487
  gelu = self.ffn[1]
488
  ffn2= self.ffn[2]
@@ -500,8 +499,6 @@ class WanAttentionBlock(nn.Module):
500
 
501
  x.addcmul_(y, e[5])
502
 
503
-
504
-
505
  if hint is not None:
506
  if context_scale == 1:
507
  x.add_(hint)
@@ -539,24 +536,13 @@ class VaceWanAttentionBlock(WanAttentionBlock):
539
  c = hints[0]
540
  hints[0] = None
541
  if self.block_id == 0:
542
- c = self.before_proj(c) + x
 
543
  c = super().forward(c, **kwargs)
544
  c_skip = self.after_proj(c)
545
  hints[0] = c
546
  return c_skip
547
 
548
- # def forward(self, c, x, **kwargs):
549
- # # behold dbm magic !
550
- # if self.block_id == 0:
551
- # c = self.before_proj(c) + x
552
- # all_c = []
553
- # else:
554
- # all_c = c
555
- # c = all_c.pop(-1)
556
- # c = super().forward(c, **kwargs)
557
- # c_skip = self.after_proj(c)
558
- # all_c += [c_skip, c]
559
- # return all_c
560
 
561
  class Head(nn.Module):
562
 
@@ -793,37 +779,6 @@ class WanModel(ModelMixin, ConfigMixin):
793
  print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
794
  return best_threshold
795
 
796
-
797
-
798
- # def forward_vace(
799
- # self,
800
- # x,
801
- # vace_context,
802
- # seq_len,
803
- # context,
804
- # e,
805
- # kwargs
806
- # ):
807
- # # embeddings
808
- # c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
809
- # c = [u.flatten(2).transpose(1, 2) for u in c]
810
- # if (len(c) == 1 and seq_len == c[0].size(1)):
811
- # c = c[0]
812
- # else:
813
- # c = torch.cat([
814
- # torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
815
- # dim=1) for u in c
816
- # ])
817
-
818
- # # arguments
819
- # new_kwargs = dict(x=x)
820
- # new_kwargs.update(kwargs)
821
-
822
- # for block in self.vace_blocks:
823
- # c = block(c, context= context, e= e, **new_kwargs)
824
- # hints = c[:-1]
825
-
826
- # return hints
827
 
828
  def forward(
829
  self,
 
482
  y *= 1 + e[4]
483
  y += e[3]
484
 
 
485
  ffn = self.ffn[0]
486
  gelu = self.ffn[1]
487
  ffn2= self.ffn[2]
 
499
 
500
  x.addcmul_(y, e[5])
501
 
 
 
502
  if hint is not None:
503
  if context_scale == 1:
504
  x.add_(hint)
 
536
  c = hints[0]
537
  hints[0] = None
538
  if self.block_id == 0:
539
+ c = self.before_proj(c)
540
+ c += x
541
  c = super().forward(c, **kwargs)
542
  c_skip = self.after_proj(c)
543
  hints[0] = c
544
  return c_skip
545
 
 
 
 
 
 
 
 
 
 
 
 
 
546
 
547
  class Head(nn.Module):
548
 
 
779
  print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
780
  return best_threshold
781
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
782
 
783
  def forward(
784
  self,
wan/text2video.py CHANGED
@@ -209,8 +209,9 @@ class WanT2V:
  def vace_latent(self, z, m):
  return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
 
- def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device, original_video = False, trim_video= 0):
+ def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device, original_video = False, keep_frames= []):
  image_sizes = []
+ trim_video = len(keep_frames)
  for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
  if sub_src_mask is not None and sub_src_video is not None:
  src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video)
@@ -237,6 +238,10 @@
  src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
  src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
  image_sizes.append(src_video[i].shape[2:])
+ for k, keep in enumerate(keep_frames):
+ if not keep:
+ src_video[i][:, k:k+1] = 0
+ src_mask[i][:, k:k+1] = 1
 
  for i, ref_images in enumerate(src_ref_images):
  if ref_images is not None:
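The prepare_source change replaces the old trim_video count with a per-frame keep_frames list: its length takes over the old trim_video role, and every frame flagged False is zeroed in the source video and fully masked so VACE regenerates it, which is what backs the new inpainting / partial-reuse options wired up in wgp.py. Illustrative call (resolution tuple and variable names are placeholders):

    video_length = 81
    keep_frames = [True] * 16 + [False] * (video_length - 16)    # reuse the first 16 control frames

    src_video, src_mask, src_ref_images = wan_model.prepare_source(
        [video_guide], [video_mask], [image_refs],
        video_length, (480, 832), "cpu",
        original_video=False,
        keep_frames=keep_frames,
    )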
wan/utils/utils.py CHANGED
@@ -37,12 +37,11 @@ def resample(video_fps, video_frames_count, max_frames, target_fps):
  break
  add_frames_count = math.ceil( (target_time -cur_time) / video_frame_duration )
  frame_no += add_frames_count
+ if frame_no >= video_frames_count:
+ break
  frame_ids.append(frame_no)
  cur_time += add_frames_count * video_frame_duration
  target_time += target_frame_duration
- if frame_no >= video_frames_count -1:
- break
- frame_ids = frame_ids[:video_frames_count]
  return frame_ids
 
  def get_video_frame(file_name, frame_no):
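The resample fix moves the end-of-video check before the append, so the returned indices can no longer point past the last frame (the old code appended first, then broke and truncated the list length, which did not bound the values). Quick check, assuming the rest of the function is unchanged:

    frame_ids = resample(video_fps=30, video_frames_count=90, max_frames=200, target_fps=16)
    assert max(frame_ids) < 90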
wan/utils/vace_preprocessor.py CHANGED
@@ -254,7 +254,7 @@ class VaceVideoProcessor(object):
 
  if src_video != None:
  fps = 16
- length = src_video.shape[1]
+ length = src_video.shape[0]
  if len(readers) > 0:
  min_readers = min([len(r) for r in readers])
  length = min(length, min_readers )
wgp.py CHANGED
@@ -153,7 +153,7 @@ def process_prompt_and_add_tasks(state, model_choice):
153
  if "Vace" in model_filename and "1.3B" in model_filename :
154
  resolution_reformated = str(height) + "*" + str(width)
155
  if not resolution_reformated in VACE_SIZE_CONFIGS:
156
- res = VACE_SIZE_CONFIGS.keys().join(" and ")
157
  gr.Info(f"Video Resolution for Vace model is not supported. Only {res} resolutions are allowed.")
158
  return
159
  if "I" in video_prompt_type:
@@ -175,12 +175,19 @@ def process_prompt_and_add_tasks(state, model_choice):
175
  else:
176
  video_mask = None
177
  if "O" in video_prompt_type :
178
- max_frames= inputs["max_frames"]
179
  video_length = inputs["video_length"]
180
- if max_frames ==0:
181
  gr.Info(f"Warning : you have asked to reuse all the frames of the control Video in the Alternate Video Ending it. Please make sure the number of frames of the control Video is lower than the total number of frames to generate otherwise it won't make a difference.")
182
- elif max_frames >= video_length:
183
- gr.Info(f"The number of frames in the control Video to reuse ({max_frames}) in Alternate Video Ending can not be bigger than the total number of frames ({video_length}) to generate.")
 
 
 
 
 
 
 
184
  return
185
 
186
  if isinstance(image_refs, list):
@@ -1540,8 +1547,8 @@ def download_models(transformer_filename, text_encoder_filename):
1540
 
1541
  from huggingface_hub import hf_hub_download, snapshot_download
1542
  repoId = "DeepBeepMeep/Wan2.1"
1543
- sourceFolderList = ["xlm-roberta-large", "pose", "depth", "", ]
1544
- fileList = [ [], [],[], ["Wan2.1_VAE_bf16.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
1545
  targetRoot = "ckpts/"
1546
  for sourceFolder, files in zip(sourceFolderList,fileList ):
1547
  if len(files)==0:
@@ -1782,25 +1789,6 @@ def get_model_name(model_filename):
1782
 
1783
  return model_name
1784
 
1785
- # def generate_header(model_filename, compile, attention_mode):
1786
-
1787
- # header = "<div class='title-with-lines'><div class=line></div><h2>"
1788
-
1789
- # model_name = get_model_name(model_filename)
1790
-
1791
- # header += model_name
1792
- # header += " (attention mode: " + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
1793
- # if attention_mode not in attention_modes_installed:
1794
- # header += " -NOT INSTALLED-"
1795
- # elif attention_mode not in attention_modes_supported:
1796
- # header += " -NOT SUPPORTED-"
1797
-
1798
- # if compile:
1799
- # header += ", pytorch compilation ON"
1800
- # header += ") </h2><div class=line></div> "
1801
-
1802
-
1803
- # return header
1804
 
1805
 
1806
  def generate_header(model_filename, compile, attention_mode):
@@ -2122,6 +2110,57 @@ def preprocess_video(process_type, height, width, video_in, max_frames):
2122
 
2123
  return torch.stack(torch_frames)
2124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2125
  def generate_video(
2126
  task_id,
2127
  progress,
@@ -2147,7 +2186,7 @@ def generate_video(
2147
  image_refs,
2148
  video_guide,
2149
  video_mask,
2150
- max_frames,
2151
  remove_background_image_ref,
2152
  temporal_upsampling,
2153
  spatial_upsampling,
@@ -2325,12 +2364,16 @@ def generate_video(
2325
  gen["progress_args"] = progress_args
2326
  video_guide = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length)
2327
  image_refs = image_refs.copy() if image_refs != None else None # required since prepare_source do inplace modifications
2328
  src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide],
2329
  [video_mask],
2330
  [image_refs],
2331
  video_length, VACE_SIZE_CONFIGS[resolution_reformated], "cpu",
2332
  original_video= "O" in video_prompt_type,
2333
- trim_video=max_frames)
2334
  else:
2335
  src_video, src_mask, src_ref_images = None, None, None
2336
 
@@ -2635,6 +2678,8 @@ def process_tasks(state, progress=gr.Progress()):
2635
  finally:
2636
  if not ok:
2637
  queue.clear()
2638
  yield status
2639
 
2640
  queue[:] = [item for item in queue if item['id'] != task['id']]
@@ -3014,7 +3059,7 @@ def prepare_inputs_dict(target, inputs ):
3014
 
3015
 
3016
  if not "Vace" in model_filename:
3017
- unsaved_params = ["video_prompt_type", "max_frames", "remove_background_image_ref"]
3018
  for k in unsaved_params:
3019
  inputs.pop(k)
3020
 
@@ -3056,7 +3101,7 @@ def save_inputs(
3056
  image_refs,
3057
  video_guide,
3058
  video_mask,
3059
- max_frames,
3060
  remove_background_image_ref,
3061
  temporal_upsampling,
3062
  spatial_upsampling,
@@ -3246,6 +3291,13 @@ def refresh_video_prompt_type_video_guide(video_prompt_type, video_prompt_type_v
3246
  visible = "V" in video_prompt_type
3247
  return video_prompt_type, gr.update(visible = visible), gr.update(visible = visible), gr.update(visible= "M" in video_prompt_type )
3248
3249
 
3250
  def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_choice = None, header = None):
3251
  global inputs_names #, advanced
@@ -3365,12 +3417,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
3365
  ("Transfer Depth from the Control Video", "DV"),
3366
  ("Recolorize the Control Video", "CV"),
3367
  # ("Alternate Video Ending", "OV"),
3368
- ("(adv) Video contains Open Pose, Depth or Black & White ", "V"),
3369
- ("(adv) Inpainting of Control Video using Mask Video ", "MV"),
3370
  ],
3371
  value=filter_letters(video_prompt_type_value, "ODPCMV"),
3372
  label="Video to Video", scale = 3
3373
  )
3374
 
3375
  video_prompt_type_image_refs = gr.Dropdown(
3376
  choices=[
@@ -3384,8 +3437,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
3384
  # video_prompt_type_image_refs = gr.Checkbox(value="I" in video_prompt_type_value , label= "Use References Images (Faces, Objects) to customize New Video", scale =1 )
3385
 
3386
  video_guide = gr.Video(label= "Control Video", visible= "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None),)
3387
- max_frames = gr.Slider(0, 100, value=ui_defaults.get("max_frames",0), step=1, label="Nb of frames in Control Video to use (0 = max)", visible= "V" in video_prompt_type_value, scale = 2 )
3388
-
3389
  image_refs = gr.Gallery( label ="Reference Images",
3390
  type ="pil", show_label= True,
3391
  columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value,
@@ -3798,9 +3851,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
3798
  target_settings = gr.Text(value = "settings", interactive= False, visible= False)
3799
 
3800
  image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end])
3801
- # video_prompt_type.change(fn=refresh_video_prompt_type, inputs=[state, video_prompt_type], outputs=[image_refs, video_guide, video_mask, max_frames, remove_background_image_ref])
3802
  video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_image_ref ])
3803
- video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [video_prompt_type, video_prompt_type_video_guide], outputs = [video_prompt_type, video_guide, max_frames, video_mask])
3804
 
3805
  show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
3806
  fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
@@ -3903,12 +3956,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
3903
  return (
3904
  loras_choices, lset_name, state, queue_df, current_gen_column,
3905
  gen_status, output, abort_btn, generate_btn, add_to_queue_btn,
3906
- gen_info,
3907
- prompt, wizard_prompt, wizard_prompt_activated_var, wizard_variables_var,
3908
- prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars,
3909
- advanced_row, image_prompt_column, video_prompt_column, queue_accordion,
3910
- *prompt_vars
3911
- )
3912
 
3913
 
3914
  def generate_download_tab(lset_name,loras_choices, state):
@@ -4132,8 +4181,30 @@ def generate_dropdown_model_list():
4132
  )
4133
 
4134
4135
 
4136
  def create_demo():
4137
  css = """
4138
  #model_list{
4139
  background-color:black;
@@ -4370,6 +4441,8 @@ def create_demo():
4370
  gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
4371
  global model_list
4372
4373
  with gr.Tabs(selected="video_gen", ) as main_tabs:
4374
  with gr.Tab("Video Generator", id="video_gen") as t2v_tab:
4375
  with gr.Row():
@@ -4386,14 +4459,15 @@ def create_demo():
4386
  (
4387
  loras_choices, lset_name, state, queue_df, current_gen_column,
4388
  gen_status, output, abort_btn, generate_btn, add_to_queue_btn,
4389
- gen_info,
4390
- prompt, wizard_prompt, wizard_prompt_activated_var, wizard_variables_var,
4391
- prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars,
4392
- advanced_row, image_prompt_column, video_prompt_column, queue_accordion,
4393
- *prompt_vars_outputs
4394
  ) = generate_video_tab(model_choice=model_choice, header=header)
4395
  with gr.Tab("Informations"):
4396
  generate_info_tab()
4397
  if not args.lock_config:
4398
  with gr.Tab("Downloads", id="downloads") as downloads_tab:
4399
  generate_download_tab(lset_name, loras_choices, state)
@@ -4420,6 +4494,7 @@ def create_demo():
4420
  trigger_mode="always_last"
4421
  )
4422
4423
  return demo
4424
 
4425
  if __name__ == "__main__":
 
153
  if "Vace" in model_filename and "1.3B" in model_filename :
154
  resolution_reformated = str(height) + "*" + str(width)
155
  if not resolution_reformated in VACE_SIZE_CONFIGS:
156
+ res = (" and ").join(VACE_SIZE_CONFIGS.keys())
157
  gr.Info(f"Video Resolution for Vace model is not supported. Only {res} resolutions are allowed.")
158
  return
159
  if "I" in video_prompt_type:
 
175
  else:
176
  video_mask = None
177
  if "O" in video_prompt_type :
178
+ keep_frames= inputs["keep_frames"]
179
  video_length = inputs["video_length"]
180
+ if len(keep_frames) ==0:
181
  gr.Info(f"Warning : you have asked to reuse all the frames of the control Video in the Alternate Video Ending it. Please make sure the number of frames of the control Video is lower than the total number of frames to generate otherwise it won't make a difference.")
182
+ # elif keep_frames >= video_length:
183
+ # gr.Info(f"The number of frames in the control Video to reuse ({keep_frames}) in Alternate Video Ending can not be bigger than the total number of frames ({video_length}) to generate.")
184
+ # return
185
+ elif "V" in video_prompt_type:
186
+ keep_frames= inputs["keep_frames"]
187
+ video_length = inputs["video_length"]
188
+ _, error = parse_keep_frames(keep_frames, video_length)
189
+ if len(error) > 0:
190
+ gr.Info(f"Invalid Keep Frames property: {error}")
191
  return
192
 
193
  if isinstance(image_refs, list):
 
1547
 
1548
  from huggingface_hub import hf_hub_download, snapshot_download
1549
  repoId = "DeepBeepMeep/Wan2.1"
1550
+ sourceFolderList = ["xlm-roberta-large", "pose", "depth", "mask", "", ]
1551
+ fileList = [ [], [],[], ["sam_vit_h_4b8939_fp16.safetensors"], ["Wan2.1_VAE_bf16.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
1552
  targetRoot = "ckpts/"
1553
  for sourceFolder, files in zip(sourceFolderList,fileList ):
1554
  if len(files)==0:
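Note: the two lists added above pair each subfolder of the DeepBeepMeep/Wan2.1 repo with the files to fetch from it; the new "mask" entry pulls the fp16 SAM ViT-H checkpoint needed by the new mask tooling, and an empty file list means the whole subfolder gets mirrored. The loop body is truncated by the diff context; the sketch below is only a minimal illustration of how such a folder/file pairing can be materialized with huggingface_hub (fetch_checkpoints is a hypothetical helper, not the repository's actual code):

from huggingface_hub import hf_hub_download, snapshot_download

def fetch_checkpoints(repo_id, source_folders, file_lists, target_root="ckpts/"):
    for folder, files in zip(source_folders, file_lists):
        if len(files) == 0:
            # no explicit file list: mirror the whole subfolder under target_root
            snapshot_download(repo_id=repo_id,
                              allow_patterns=(folder + "/*") if folder else "*",
                              local_dir=target_root)
        else:
            for name in files:
                # fetch a single checkpoint, keeping the subfolder layout
                hf_hub_download(repo_id=repo_id, filename=name,
                                subfolder=folder if folder else None,
                                local_dir=target_root)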
 
1789
 
1790
  return model_name
1791
1792
 
1793
 
1794
  def generate_header(model_filename, compile, attention_mode):
 
2110
 
2111
  return torch.stack(torch_frames)
2112
 
2113
+ def parse_keep_frames(keep_frames, video_length):
2114
+ def is_integer(n):
2115
+ try:
2116
+ float(n)
2117
+ except ValueError:
2118
+ return False
2119
+ else:
2120
+ return float(n).is_integer()
2121
+
2122
+ def absolute(n):
2123
+ if n==0:
2124
+ return 0
2125
+ elif n < 0:
2126
+ return max(0, video_length + n)
2127
+ else:
2128
+ return min(n-1, video_length-1)
2129
+
2130
+ if len(keep_frames) == 0:
2131
+ return [True] *video_length, ""
2132
+ frames =[False] *video_length
2133
+ error = ""
2134
+ sections = keep_frames.split(" ")
2135
+ for section in sections:
2136
+ section = section.strip()
2137
+ if ":" in section:
2138
+ parts = section.split(":")
2139
+ if not is_integer(parts[0]):
2140
+ error =f"Invalid integer {parts[0]}"
2141
+ break
2142
+ start_range = absolute(int(parts[0]))
2143
+ if not is_integer(parts[1]):
2144
+ error =f"Invalid integer {parts[1]}"
2145
+ break
2146
+ end_range = absolute(int(parts[1]))
2147
+ for i in range(start_range, end_range + 1):
2148
+ frames[i] = True
2149
+ else:
2150
+ if not is_integer(section):
2151
+ error =f"Invalid integer {section}"
2152
+ break
2153
+ index = absolute(int(section))
2154
+ frames[index] = True
2155
+
2156
+ if len(error ) > 0:
2157
+ return [], error
2158
+ for i in range(len(frames)-1, 0, -1):
2159
+ if frames[i]:
2160
+ break
2161
+ frames= frames[0: i+1]
2162
+ return frames, error
2163
+
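The parse_keep_frames helper added above turns the free-form "Frames to keep in Control Video" string into a boolean mask over the frames to generate: an empty string keeps everything, a bare number keeps that 1-based frame, a:b keeps a range, negative values count back from the end, and the mask is truncated after the last kept frame. A short usage illustration (values derived by tracing the code above, shown only as a note, not part of the commit):

# keep frames 1-5 and frame 10 of a 20-frame generation
frames, error = parse_keep_frames("1:5 10", 20)
assert error == ""
assert len(frames) == 10            # truncated after the last kept frame
assert frames[:5] == [True] * 5 and frames[9]

# empty string keeps every frame
assert parse_keep_frames("", 20) == ([True] * 20, "")

# a non-integer section is reported as an error string instead of raising
_, error = parse_keep_frames("abc", 20)
assert error == "Invalid integer abc"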
2164
  def generate_video(
2165
  task_id,
2166
  progress,
 
2186
  image_refs,
2187
  video_guide,
2188
  video_mask,
2189
+ keep_frames,
2190
  remove_background_image_ref,
2191
  temporal_upsampling,
2192
  spatial_upsampling,
 
2364
  gen["progress_args"] = progress_args
2365
  video_guide = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length)
2366
  image_refs = image_refs.copy() if image_refs != None else None # required since prepare_source do inplace modifications
2367
+ keep_frames_parsed, error = parse_keep_frames(keep_frames, video_length)
2368
+ if len(error) > 0:
2369
+ raise gr.Error(f"invalid keep frames {keep_frames}")
2370
+
2371
  src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide],
2372
  [video_mask],
2373
  [image_refs],
2374
  video_length, VACE_SIZE_CONFIGS[resolution_reformated], "cpu",
2375
  original_video= "O" in video_prompt_type,
2376
+ keep_frames=keep_frames_parsed)
2377
  else:
2378
  src_video, src_mask, src_ref_images = None, None, None
2379
 
 
2678
  finally:
2679
  if not ok:
2680
  queue.clear()
2681
+ gen["prompts_max"] = 0
2682
+ gen["prompt"] = ""
2683
  yield status
2684
 
2685
  queue[:] = [item for item in queue if item['id'] != task['id']]
 
3059
 
3060
 
3061
  if not "Vace" in model_filename:
3062
+ unsaved_params = ["video_prompt_type", "keep_frames", "remove_background_image_ref"]
3063
  for k in unsaved_params:
3064
  inputs.pop(k)
3065
 
 
3101
  image_refs,
3102
  video_guide,
3103
  video_mask,
3104
+ keep_frames,
3105
  remove_background_image_ref,
3106
  temporal_upsampling,
3107
  spatial_upsampling,
 
3291
  visible = "V" in video_prompt_type
3292
  return video_prompt_type, gr.update(visible = visible), gr.update(visible = visible), gr.update(visible= "M" in video_prompt_type )
3293
 
3294
+ def refresh_video_prompt_video_guide_trigger(video_prompt_type, video_prompt_type_video_guide):
3295
+ video_prompt_type_video_guide = video_prompt_type_video_guide.split("#")[0]
3296
+ video_prompt_type = del_in_sequence(video_prompt_type, "ODPCMV")
3297
+ video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
3298
+
3299
+ return video_prompt_type, video_prompt_type_video_guide, gr.update(visible= "V" in video_prompt_type ), gr.update(visible= "M" in video_prompt_type) , gr.update(visible= "V" in video_prompt_type )
3300
+
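refresh_video_prompt_video_guide_trigger is driven by the hidden video_prompt_video_guide_trigger Text declared further down: whatever writes into that component pushes a control-mode string, the handler strips everything after "#", replaces the "ODPCMV" letters in video_prompt_type with the new ones, and toggles the visibility of the control video, mask video and keep-frames fields. A hedged reading of the "#" convention (not spelled out in this diff) is that a changing suffix lets the same mode string be pushed twice and still fire the .change event. A hypothetical producer on the mask-creator side could look like this (push_mask_result, the "MV" value and the uuid nonce are illustrative assumptions):

import uuid
import gradio as gr

def push_mask_result(video_path, mask_path):
    # values returned to (video_guide, video_mask, video_prompt_video_guide_trigger)
    return (
        gr.update(value=video_path),   # control video produced by the mask tool
        gr.update(value=mask_path),    # matching mask video
        "MV#" + uuid.uuid4().hex,      # mode letters + nonce so .change always fires
    )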
3301
 
3302
  def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_choice = None, header = None):
3303
  global inputs_names #, advanced
 
3417
  ("Transfer Depth from the Control Video", "DV"),
3418
  ("Recolorize the Control Video", "CV"),
3419
  # ("Alternate Video Ending", "OV"),
3420
+ ("Video contains Open Pose, Depth, Black & White, Inpainting ", "V"),
3421
+ ("Control Video and Mask video for stronger Inpainting ", "MV"),
3422
  ],
3423
  value=filter_letters(video_prompt_type_value, "ODPCMV"),
3424
  label="Video to Video", scale = 3
3425
  )
3426
+ video_prompt_video_guide_trigger = gr.Text(visible=False, value="")
3427
 
3428
  video_prompt_type_image_refs = gr.Dropdown(
3429
  choices=[
 
3437
  # video_prompt_type_image_refs = gr.Checkbox(value="I" in video_prompt_type_value , label= "Use References Images (Faces, Objects) to customize New Video", scale =1 )
3438
 
3439
  video_guide = gr.Video(label= "Control Video", visible= "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None),)
3440
+ # keep_frames = gr.Slider(0, 100, value=ui_defaults.get("keep_frames",0), step=1, label="Nb of frames in Control Video to use (0 = max)", visible= "V" in video_prompt_type_value, scale = 2 )
3441
+ keep_frames = gr.Text(value=ui_defaults.get("keep_frames","") , visible= "V" in video_prompt_type_value, scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last
3442
  image_refs = gr.Gallery( label ="Reference Images",
3443
  type ="pil", show_label= True,
3444
  columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value,
 
3851
  target_settings = gr.Text(value = "settings", interactive= False, visible= False)
3852
 
3853
  image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end])
3854
+ video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, video_mask, keep_frames])
3855
  video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_image_ref ])
3856
+ video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [video_prompt_type, video_prompt_type_video_guide], outputs = [video_prompt_type, video_guide, keep_frames, video_mask])
3857
 
3858
  show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
3859
  fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
 
3956
  return (
3957
  loras_choices, lset_name, state, queue_df, current_gen_column,
3958
  gen_status, output, abort_btn, generate_btn, add_to_queue_btn,
3959
+ gen_info, queue_accordion, video_guide, video_mask, video_prompt_video_guide_trigger
3960
+ )
3961
 
3962
 
3963
  def generate_download_tab(lset_name,loras_choices, state):
 
4181
  )
4182
 
4183
 
4184
+ def select_tab(tab_state, evt:gr.SelectData):
4185
+ tab_video_mask_creator = 2
4186
+
4187
+ old_tab_no = tab_state.get("tab_no",0)
4188
+ new_tab_no = evt.index
4189
+ if old_tab_no == tab_video_mask_creator:
4190
+ vmc_event_handler(False)
4191
+ elif new_tab_no == tab_video_mask_creator:
4192
+ if gen_in_progress:
4193
+ gr.Info("Unable to access this Tab while a Generation is in Progress. Please come back later")
4194
+ tab_state["tab_auto"]=old_tab_no
4195
+ else:
4196
+ vmc_event_handler(True)
4197
+ tab_state["tab_no"] = new_tab_no
4198
+ def select_tab_auto(tab_state):
4199
+ old_tab_no = tab_state.pop("tab_auto", -1)
4200
+ if old_tab_no>= 0:
4201
+ tab_state["tab_auto"]=old_tab_no
4202
+ return gr.Tabs(selected=old_tab_no) # !! doesnt work !!
4203
+ return gr.Tab()
4204
+
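select_tab above gates the new Video Mask Creator tab (index 2): entering it calls vmc_event_handler(True), leaving it calls vmc_event_handler(False), and while a generation is running the switch is refused and the previous tab index is parked in tab_state["tab_auto"] for select_tab_auto (whose inline comment notes that bouncing back via gr.Tabs(selected=...) does not take effect). The handler returned by matanyone_app.get_vmc_event_handler() is presumably a load/unload toggle along these lines; this is a sketch only, with hypothetical loader and unloader names, not the actual matanyone code:

def load_matanyone_models():      # hypothetical placeholder: allocate SAM + MatAnyone
    pass

def release_matanyone_models():   # hypothetical placeholder: free VRAM for the generator
    pass

def get_vmc_event_handler():
    state = {"active": False}
    def handler(active: bool):
        # only react on actual transitions between active/inactive
        if active and not state["active"]:
            load_matanyone_models()
        elif not active and state["active"]:
            release_matanyone_models()
        state["active"] = active
    return handler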
4205
 
4206
  def create_demo():
4207
+ global vmc_event_handler
4208
  css = """
4209
  #model_list{
4210
  background-color:black;
 
4441
  gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
4442
  global model_list
4443
 
4444
+ tab_state = gr.State({ "tab_no":0 })
4445
+
4446
  with gr.Tabs(selected="video_gen", ) as main_tabs:
4447
  with gr.Tab("Video Generator", id="video_gen") as t2v_tab:
4448
  with gr.Row():
 
4459
  (
4460
  loras_choices, lset_name, state, queue_df, current_gen_column,
4461
  gen_status, output, abort_btn, generate_btn, add_to_queue_btn,
4462
+ gen_info, queue_accordion, video_guide, video_mask, video_prompt_type_video_trigger
4463
  ) = generate_video_tab(model_choice=model_choice, header=header)
4464
  with gr.Tab("Informations"):
4465
  generate_info_tab()
4466
+ with gr.Tab("Video Mask Creator", id="video_mask_creator") as video_mask_creator:
4467
+ from preprocessing.matanyone import app as matanyone_app
4468
+ vmc_event_handler = matanyone_app.get_vmc_event_handler()
4469
+
4470
+ matanyone_app.display(video_guide, video_mask, video_prompt_type_video_trigger)
4471
  if not args.lock_config:
4472
  with gr.Tab("Downloads", id="downloads") as downloads_tab:
4473
  generate_download_tab(lset_name, loras_choices, state)
 
4494
  trigger_mode="always_last"
4495
  )
4496
 
4497
+ main_tabs.select(fn=select_tab, inputs= [tab_state], outputs= None).then(fn=select_tab_auto, inputs= [tab_state], outputs=[main_tabs])
4498
  return demo
4499
 
4500
  if __name__ == "__main__":