PengLiu committed on
Commit 6302644 · 1 Parent(s): c4a1381
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. app.py +9 -0
  2. demo/gradio_demo2.py +369 -0
  3. demo/sam3_examples/init.py +0 -0
  4. detect_tools/upn/__init__.py +45 -0
  5. detect_tools/upn/builder.py +39 -0
  6. detect_tools/upn/configs/upn_large.py +73 -0
  7. detect_tools/upn/inference_wrapper.py +237 -0
  8. detect_tools/upn/models/architecture/__init__.py +4 -0
  9. detect_tools/upn/models/architecture/deformable_transformer.py +336 -0
  10. detect_tools/upn/models/architecture/upn_model.py +343 -0
  11. detect_tools/upn/models/backbone/__init__.py +4 -0
  12. detect_tools/upn/models/backbone/swin.py +814 -0
  13. detect_tools/upn/models/backbone/wrapper.py +297 -0
  14. detect_tools/upn/models/decoder/__init__.py +3 -0
  15. detect_tools/upn/models/decoder/upn_decoder.py +378 -0
  16. detect_tools/upn/models/encoder/__init__.py +3 -0
  17. detect_tools/upn/models/encoder/upn_encoder.py +288 -0
  18. detect_tools/upn/models/module/__init__.py +5 -0
  19. detect_tools/upn/models/module/contrastive.py +29 -0
  20. detect_tools/upn/models/module/mlp.py +18 -0
  21. detect_tools/upn/models/module/nested_tensor.py +199 -0
  22. detect_tools/upn/models/utils/__init__.py +23 -0
  23. detect_tools/upn/models/utils/detr_utils.py +415 -0
  24. detect_tools/upn/ops/functions/__init__.py +10 -0
  25. detect_tools/upn/ops/functions/ms_deform_attn_func.py +61 -0
  26. detect_tools/upn/ops/modules/__init__.py +9 -0
  27. detect_tools/upn/ops/modules/ms_deform_attn.py +204 -0
  28. detect_tools/upn/ops/modules/ms_deform_attn_key_aware.py +130 -0
  29. detect_tools/upn/ops/setup.py +73 -0
  30. detect_tools/upn/ops/src/cpu/ms_deform_attn_cpu.cpp +41 -0
  31. detect_tools/upn/ops/src/cpu/ms_deform_attn_cpu.h +33 -0
  32. detect_tools/upn/ops/src/cuda/ms_deform_attn_cuda.cu +153 -0
  33. detect_tools/upn/ops/src/cuda/ms_deform_attn_cuda.h +30 -0
  34. detect_tools/upn/ops/src/cuda/ms_deform_im2col_cuda.cuh +1327 -0
  35. detect_tools/upn/ops/src/ms_deform_attn.h +62 -0
  36. detect_tools/upn/ops/src/vision.cpp +16 -0
  37. detect_tools/upn/ops/test.py +89 -0
  38. detect_tools/upn/requirments.txt +1 -0
  39. detect_tools/upn/transforms/transform.py +142 -0
  40. requirements.txt +19 -0
  41. resources/__init__.py +0 -0
  42. vlm_fo1/__init__.py +1 -0
  43. vlm_fo1/constants.py +29 -0
  44. vlm_fo1/mm_utils.py +660 -0
  45. vlm_fo1/model/__init__.py +1 -0
  46. vlm_fo1/model/builder.py +89 -0
  47. vlm_fo1/model/language_model/omchat_qwen2_5_vl.py +576 -0
  48. vlm_fo1/model/multimodal_encoder/__init__.py +0 -0
  49. vlm_fo1/model/multimodal_encoder/base_encoder.py +33 -0
  50. vlm_fo1/model/multimodal_encoder/builder.py +38 -0
app.py ADDED
@@ -0,0 +1,9 @@
+ import gradio as gr
+ import spaces
+
+ @spaces.GPU
+ def greet(name):
+     return "Hello " + name + "!!"
+
+ demo = gr.Interface(fn=greet, inputs="text", outputs="text")
+ demo.launch()
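Once this placeholder app is running as a Space, it can be exercised remotely with the official gradio_client package. A minimal sketch; the Space id "user/space" is a placeholder, and "/predict" is assumed to be the default endpoint name that gr.Interface exposes:

# Hedged sketch of calling the Interface above over the network.
from gradio_client import Client

client = Client("user/space")                     # placeholder Space id
result = client.predict("World", api_name="/predict")
print(result)                                     # -> "Hello World!!"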
demo/gradio_demo2.py ADDED
@@ -0,0 +1,369 @@
+ import gradio as gr
+ import spaces
+ from PIL import Image, ImageDraw, ImageFont
+ import re
+ import random
+ import numpy as np
+ from skimage.measure import label, regionprops
+ from skimage.morphology import binary_dilation, disk
+ from sam3.model_builder import build_sam3_image_model
+ from sam3.model.sam3_image_processor import Sam3Processor
+ from sam3.visualization_utils import plot_bbox, plot_mask, COLORS
+ import matplotlib.pyplot as plt
+
+ from detect_tools.upn import UPNWrapper
+ from vlm_fo1.model.builder import load_pretrained_model
+ from vlm_fo1.mm_utils import (
+     prepare_inputs,
+     extract_predictions_to_indexes,
+ )
+ from vlm_fo1.task_templates import *
+ import torch
+ import os
+ from copy import deepcopy
+
+
+ EXAMPLES = [
+     ["demo/sam3_examples/00000-72.jpg", "airplane with letter AE on its body"],
+     ["demo/sam3_examples/00000-32.jpg", "the lying cat which is not black"],
+     ["demo/sam3_examples/00000-22.jpg", "person wearing a black top"],
+     ["demo/sam3_examples/000000378453.jpg", "zebra inside the mud puddle"],
+ ]
+
+
+ def get_valid_examples():
+     valid_examples = []
+     demo_dir = os.path.dirname(os.path.abspath(__file__))
+     for example in EXAMPLES:
+         img_path = example[0]
+         full_path = os.path.join(demo_dir, img_path)
+         if os.path.exists(full_path):
+             valid_examples.append([
+                 full_path,
+                 example[1],
+             ])
+         elif os.path.exists(img_path):
+             valid_examples.append([
+                 img_path,
+                 example[1],
+             ])
+     return valid_examples
+
+
+ def detect_model_upn(image, threshold=0.3):
+     proposals = upn_model.inference(image)
+     filtered_proposals = upn_model.filter(proposals, min_score=threshold)
+     picked_proposals = filtered_proposals['original_xyxy_boxes'][0][:100]
+     return picked_proposals
+
+
+ def detect_model_sam3(image, text, threshold=0.3):
+     inference_state = sam3_processor.set_image(image)
+     output = sam3_processor.set_text_prompt(state=inference_state, prompt=text)
+     boxes, scores, masks = output["boxes"], output["scores"], output["masks"]
+     sorted_indices = torch.argsort(scores, descending=True)
+     boxes = boxes[sorted_indices][:100, :]
+     scores = scores[sorted_indices][:100]
+     masks = masks[sorted_indices][:100]
+
+     output = {
+         "boxes": boxes,
+         "scores": scores,
+         "masks": masks,
+     }
+     return boxes.tolist(), scores.tolist(), masks.tolist(), output
+
+
+ def multimodal_model(image, bboxes, text, scores=None):
+     if len(bboxes) == 0:
+         return None, [], []
+
+     if '<image>' in text:
+         print(text)
+         parts = [part.replace('\\n', '\n') for part in re.split(r'(<image>)', text) if part.strip()]
+         print(parts)
+         content = []
+         for part in parts:
+             if part == '<image>':
+                 content.append({"type": "image_url", "image_url": {"url": image}})
+             else:
+                 content.append({"type": "text", "text": part})
+     else:
+         content = [{
+             "type": "image_url",
+             "image_url": {
+                 "url": image
+             }
+         }, {
+             "type": "text",
+             "text": text
+         }]
+
+     messages = [
+         {
+             "role": "user",
+             "content": content,
+             "bbox_list": bboxes
+         }
+     ]
+     generation_kwargs = prepare_inputs(model_path, model, image_processors, tokenizer, messages,
+                                        max_tokens=4096, top_p=0.05, temperature=0.0, do_sample=False, image_size=1024)
+     with torch.inference_mode():
+         output_ids = model.generate(**generation_kwargs)
+     outputs = tokenizer.decode(output_ids[0, generation_kwargs['inputs'].shape[1]:]).strip()
+     print("========output========\n", outputs)
+
+     if '<ground>' in outputs:
+         prediction_dict = extract_predictions_to_indexes(outputs)
+     else:
+         match_pattern = r"<region(\d+)>"
+         matches = re.findall(match_pattern, outputs)
+         prediction_dict = {f"<region{m}>": {int(m)} for m in matches}
+
+     ans_bbox_json = []
+     ans_bbox_list = []
+     for k, v in prediction_dict.items():
+         for box_index in v:
+             box_index = int(box_index)
+             if box_index < len(bboxes):
+                 current_bbox = bboxes[box_index]
+                 item = {
+                     "region_index": f"<region{box_index}>",
+                     "xmin": current_bbox[0],
+                     "ymin": current_bbox[1],
+                     "xmax": current_bbox[2],
+                     "ymax": current_bbox[3],
+                     "label": k,
+                 }
+                 if scores is not None and box_index < len(scores):
+                     item["score"] = scores[box_index]
+
+                 ans_bbox_json.append(item)
+                 ans_bbox_list.append(current_bbox)
+
+     return outputs, ans_bbox_json, ans_bbox_list
+
+
+ def draw_sam3_results(img, results):
+     fig, ax = plt.subplots(figsize=(12, 8))
+     ax.imshow(img)
+     nb_objects = len(results["scores"])
+     print(f"found {nb_objects} object(s)")
+     for i in range(nb_objects):
+         color = COLORS[i % len(COLORS)]
+         plot_mask(results["masks"][i].squeeze(0).cpu(), color=color)
+         w, h = img.size
+         prob = results["scores"][i].item()
+         plot_bbox(
+             h,
+             w,
+             results["boxes"][i].cpu(),
+             text=f"(id={i}, {prob=:.2f})",
+             box_format="XYXY",
+             color=color,
+             relative_coords=False,
+         )
+     ax.axis("off")
+     fig.tight_layout(pad=0)
+
+     # Convert matplotlib figure to PIL Image
+     fig.canvas.draw()
+     buf = fig.canvas.buffer_rgba()
+     pil_img = Image.frombytes('RGBA', fig.canvas.get_width_height(), buf)
+     plt.close(fig)
+
+     return pil_img
+
+
+ def draw_bboxes_simple(image, bboxes, labels=None):
+     image = image.copy()
+     draw = ImageDraw.Draw(image)
+
+     for bbox in bboxes:
+         draw.rectangle(bbox, outline="red", width=3)
+     return image
+
+
+ @spaces.GPU
+ def process(image, prompt, threshold=0.3):
+     if image is None:
+         print("Error: image is None, please upload an image or select a valid example.")
+         return None, None, None, None, []
+
+     try:
+         image = image.convert('RGB')
+     except Exception as e:
+         print(f"Error: cannot process image - {e}")
+         return None, None, None, None, []
+
+     # --- SAM3 Pipeline ---
+     print("Running SAM3 Pipeline...")
+     sam3_bboxes, sam3_scores, masks, sam3_output = detect_model_sam3(image, prompt, threshold)
+
+     # Generate SAM3 outputs (directly from SAM3, no VLM-FO1)
+     sam3_detection_image = draw_sam3_results(image, sam3_output)
+
+     sam3_annotated_bboxes = []
+     sam3_ans_bbox_json = []
+
+     img_width, img_height = image.size
+     for i, bbox in enumerate(sam3_bboxes):
+         xmin = max(0, min(img_width, int(bbox[0])))
+         ymin = max(0, min(img_height, int(bbox[1])))
+         xmax = max(0, min(img_width, int(bbox[2])))
+         ymax = max(0, min(img_height, int(bbox[3])))
+         score = sam3_scores[i]
+
+         # Format label with score
+         label_text = f"{prompt} {score:.2f}"
+
+         sam3_annotated_bboxes.append(
+             ((xmin, ymin, xmax, ymax), label_text)
+         )
+
+         sam3_ans_bbox_json.append({
+             "region_index": i,
+             "xmin": bbox[0],
+             "ymin": bbox[1],
+             "xmax": bbox[2],
+             "ymax": bbox[3],
+             "label": prompt,
+             "score": score
+         })
+
+     sam3_annotated_image = (image, sam3_annotated_bboxes)
+
+     # --- UPN Pipeline ---
+     print("Running UPN Pipeline...")
+     upn_bboxes = detect_model_upn(image, threshold=0.3)  # use default threshold for UPN
+
+     fo1_prompt_upn = OD_template.format(prompt)
+     upn_bboxes = upn_bboxes[::-1]
+     upn_ans, upn_ans_bbox_json, upn_ans_bbox_list = multimodal_model(image, upn_bboxes, fo1_prompt_upn)
+
+     upn_detection_image = draw_bboxes_simple(image, upn_bboxes)
+
+     upn_annotated_bboxes = []
+     if len(upn_ans_bbox_json) > 0:
+         img_width, img_height = image.size
+         for item in upn_ans_bbox_json:
+             xmin = max(0, min(img_width, int(item['xmin'])))
+             ymin = max(0, min(img_height, int(item['ymin'])))
+             xmax = max(0, min(img_width, int(item['xmax'])))
+             ymax = max(0, min(img_height, int(item['ymax'])))
+             upn_annotated_bboxes.append(
+                 ((xmin, ymin, xmax, ymax), item['label'])
+             )
+     upn_annotated_image = (image, upn_annotated_bboxes)
+
+     return sam3_annotated_image, sam3_detection_image, \
+         upn_annotated_image, upn_detection_image, upn_ans_bbox_json
+
+
+ def update_btn(is_processing):
+     if is_processing:
+         return gr.update(value="Processing...", interactive=False)
+     else:
+         return gr.update(value="Submit", interactive=True)
+
+
+ def launch_demo():
+     with gr.Blocks() as demo:
+         gr.Markdown("# 🚀 VLM-FO1 vs SAM3 Demo")
+         gr.Markdown("""
+         ### 📋 Instructions
+         Compare the detection performance of **SAM3** vs **VLM-FO1**.
+
+         **How it works**
+         1. Upload or pick an example image.
+         2. Describe the target object in natural language.
+         3. Hit **Submit** to run both pipelines.
+         """)
+
+         with gr.Row():
+             with gr.Column():
+                 img_input_draw = gr.Image(
+                     label="Image Input",
+                     type="pil",
+                     sources=['upload'],
+                 )
+
+                 gr.Markdown("### Prompt")
+
+                 prompt_input = gr.Textbox(
+                     label="Label Prompt",
+                     lines=2,
+                 )
+
+                 submit_btn = gr.Button("Submit", variant="primary")
+
+                 examples = gr.Examples(
+                     examples=EXAMPLES,
+                     inputs=[img_input_draw, prompt_input],
+                     label="Click to load example",
+                     examples_per_page=5
+                 )
+
+             with gr.Column():
+                 gr.Markdown("### SAM3 Result")
+                 with gr.Accordion("SAM3 Masks & Boxes", open=False):
+                     sam3_detection_output = gr.Image(label="SAM3 Visualization", height=300)
+
+                 sam3_final_output = gr.AnnotatedImage(label="SAM3 Detections", height=400)
+
+             with gr.Column():
+                 gr.Markdown("### VLM-FO1 Result")
+                 with gr.Accordion("Bbox Proposals", open=False):
+                     upn_detection_output = gr.Image(label="Bboxes", height=300)
+
+                 upn_final_output = gr.AnnotatedImage(label="VLM-FO1 Final", height=400)
+                 upn_json_output = gr.JSON(label="VLM-FO1 Details")
+
+         submit_btn.click(
+             update_btn,
+             inputs=[gr.State(True)],
+             outputs=[submit_btn],
+             queue=False
+         ).then(
+             process,
+             inputs=[img_input_draw, prompt_input],
+             outputs=[
+                 sam3_final_output, sam3_detection_output,
+                 upn_final_output, upn_detection_output, upn_json_output
+             ],
+             queue=True
+         ).then(
+             update_btn,
+             inputs=[gr.State(False)],
+             outputs=[submit_btn],
+             queue=False
+         )
+
+     return demo
+
+
+ if __name__ == "__main__":
+     import os
+     exit_code = os.system("wget -c https://airesources.oss-cn-hangzhou.aliyuncs.com/lp/wheel/sam3.pt")
+
+     model_path = 'omlab/VLM-FO1_Qwen2.5-VL-3B-v01'
+     upn_ckpt_path = "./resources/upn_large.pth"
+
+     # Load FO1
+     tokenizer, model, image_processors = load_pretrained_model(
+         model_path=model_path,
+         device="cuda:0",
+     )
+
+     # Load SAM3
+     sam3_model = build_sam3_image_model(checkpoint_path='./sam3.pt', device="cuda", bpe_path='/home/user/app/resources/bpe_simple_vocab_16e6.txt.gz')
+     sam3_processor = Sam3Processor(sam3_model, confidence_threshold=0.0, device="cuda")
+
+     # Load UPN
+     upn_model = UPNWrapper(upn_ckpt_path)
+
+     demo = launch_demo()
+     demo.launch()
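The Submit handler above uses a three-step event chain: a fast queue=False step disables the button, the slow pipelines run queued, and a final queue=False step restores the button. A self-contained sketch of just that pattern, with slow_step standing in for the real SAM3/UPN work (not part of the demo):

import time
import gradio as gr

def slow_step(x):
    time.sleep(2)  # stand-in for the heavy inference step
    return x.upper()

with gr.Blocks() as sketch:
    inp = gr.Textbox()
    out = gr.Textbox()
    btn = gr.Button("Submit", variant="primary")
    btn.click(
        lambda: gr.update(value="Processing...", interactive=False),
        outputs=[btn], queue=False,
    ).then(
        slow_step, inputs=[inp], outputs=[out], queue=True,
    ).then(
        lambda: gr.update(value="Submit", interactive=True),
        outputs=[btn], queue=False,
    )

if __name__ == "__main__":
    sketch.launch()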
demo/sam3_examples/init.py ADDED
File without changes
detect_tools/upn/__init__.py ADDED
@@ -0,0 +1,45 @@
+ from . import models
+ from .builder import (
+     ARCHITECTURES,
+     BACKBONES,
+     DECODERS,
+     ENCODERS,
+     POS_EMBEDDINGS,
+     build_architecture,
+     build_backbone,
+     build_decoder,
+     build_encoder,
+     build_position_embedding,
+ )
+ from .inference_wrapper import UPNWrapper
+ from .models.architecture import *
+ from .models.backbone import *
+ from .models.decoder import *
+ from .models.encoder import *
+ from .models.module import *
+ from .models.utils import *
+
+ __all__ = [
+     "BACKBONES",
+     "POS_EMBEDDINGS",
+     "ENCODERS",
+     "DECODERS",
+     "ARCHITECTURES",
+     "build_backbone",
+     "build_position_embedding",
+     "build_encoder",
+     "build_decoder",
+     "build_architecture",
+     "UPNWrapper",
+ ]
+
+ # Re-export the public names of each submodule.
+ __all__ += (
+     models.architecture.__all__
+     + models.backbone.__all__
+     + models.encoder.__all__
+     + models.decoder.__all__
+     + models.module.__all__
+     + models.utils.__all__
+ )
detect_tools/upn/builder.py ADDED
@@ -0,0 +1,39 @@
+ from mmengine import Registry, build_from_cfg
+
+ BACKBONES = Registry("backbone")
+ POS_EMBEDDINGS = Registry("position_embedding")
+ FUSERS = Registry("fuser")
+ ENCODERS = Registry("encoder")
+ DECODERS = Registry("decoder")
+ ARCHITECTURES = Registry("architecture")
+
+
+ def build_backbone(cfg):
+     """Build backbone."""
+     return build_from_cfg(cfg, BACKBONES)
+
+
+ def build_position_embedding(cfg):
+     """Build position embedding."""
+     return build_from_cfg(cfg, POS_EMBEDDINGS)
+
+
+ def build_fuser(cfg):
+     """Build fuser."""
+     return build_from_cfg(cfg, FUSERS)
+
+
+ def build_encoder(cfg):
+     """Build encoder."""
+     return build_from_cfg(cfg, ENCODERS)
+
+
+ def build_decoder(cfg):
+     """Build decoder."""
+     return build_from_cfg(cfg, DECODERS)
+
+
+ def build_architecture(cfg):
+     """Build architecture."""
+     return build_from_cfg(cfg, ARCHITECTURES)
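A minimal sketch of how these registries are consumed elsewhere in this diff (e.g. @ARCHITECTURES.register_module() on DeformableTransformer and UPN): a class registers itself under its name, and build_from_cfg instantiates it from a dict whose "type" key selects the class. The TOYS registry and ToyModel below are hypothetical illustrations, not part of the repo:

from mmengine import Registry, build_from_cfg

TOYS = Registry("toy")  # hypothetical registry, for illustration only

@TOYS.register_module()
class ToyModel:
    def __init__(self, depth: int):
        self.depth = depth

cfg = dict(type="ToyModel", depth=3)   # "type" selects the registered class
model = build_from_cfg(cfg, TOYS)      # -> ToyModel(depth=3)
assert model.depth == 3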
detect_tools/upn/configs/upn_large.py ADDED
@@ -0,0 +1,73 @@
+ transformer_cfg = dict(
+     type="DeformableTransformer",
+     num_queries=900,
+     encoder_cfg=dict(
+         type="UPNEncoder",
+         encoder_layer_cfg=dict(
+             type="DeformableTransformerEncoderLayer",
+             activation="relu",
+             d_model=256,
+             dropout=0.0,
+             d_ffn=2048,
+             n_heads=8,
+             n_levels=5,
+         ),
+         d_model=256,
+         num_layers=6,
+         use_checkpoint=False,
+         use_transformer_ckpt=False,
+     ),
+     decoder_cfg=dict(
+         type="UPNDecoder",
+         decoder_layer_cfg=dict(
+             type="DeformableTransformerDecoderLayer",
+             activation="relu",
+             d_model=256,
+             n_heads=8,
+             dropout=0.0,
+             d_ffn=2048,
+             n_levels=5,
+         ),
+         d_model=256,
+         return_intermediate=True,
+         num_layers=6,
+         rm_dec_query_scale=True,
+         use_detached_boxes_dec_out=False,
+     ),
+     learnable_tgt_init=True,
+     random_refpoints_xy=False,
+     num_feature_levels=5,
+     two_stage_bbox_embed_share=False,
+     two_stage_class_embed_share=False,
+     two_stage_keep_all_tokens=False,
+     two_stage_learn_wh=False,
+     two_stage_type="standard",
+     binary_query_selection=False,
+ )
+
+ vision_backbone = dict(
+     type="SwinWrapper",
+     backbone_cfg="swin_L_384_22k",
+     lr_backbone=1e-05,
+     dilation=False,
+     return_interm_indices=[0, 1, 2, 3],
+     backbone_freeze_keywords=None,
+     backbone_ckpt_path=None,
+     use_checkpoint=False,
+     position_embedding_cfg=dict(
+         type="PositionEmbeddingSineHW",
+         normalize=True,
+         num_pos_feats=128,
+         temperatureH=20,
+         temperatureW=20,
+     ),
+ )
+
+ model = dict(
+     type="UPN",
+     vision_backbone_cfg=vision_backbone,
+     transformer_cfg=transformer_cfg,
+     num_queries=900,
+     dec_pred_bbox_embed_share=True,
+     dec_pred_class_embed_share=True,
+ )
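How this config file is consumed (mirroring build_model in inference_wrapper.py below): mmengine parses the .py file into a dict-like Config object, and the nested "type" keys drive the registry builders. A short sketch, assuming it is run from the repo root:

from mmengine import Config

cfg = Config.fromfile("detect_tools/upn/configs/upn_large.py")
print(cfg.model.type)                          # "UPN"
print(cfg.model.transformer_cfg.num_queries)   # 900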
detect_tools/upn/inference_wrapper.py ADDED
@@ -0,0 +1,237 @@
+ import copy
+ import os
+ from typing import Dict, List, Union
+
+ import numpy as np
+ import torch
+ from mmengine import Config
+ from PIL import Image
+ from torchvision.ops import nms
+
+ import detect_tools.upn.transforms.transform as T
+ from detect_tools.upn import build_architecture
+ from detect_tools.upn.models.module import nested_tensor_from_tensor_list
+
+
+ def build_model(
+     ckpt_path: str,
+ ):
+     current_path = os.path.dirname(os.path.abspath(__file__))
+     config_file = "configs/upn_large.py"
+     config_path = os.path.join(current_path, config_file)
+     model_cfg = Config.fromfile(config_path).model
+     model = build_architecture(model_cfg)
+     checkpoint = torch.load(ckpt_path, map_location="cpu")
+     model.load_state_dict(checkpoint["model"], strict=False)
+     return model
+
+
+ class UPNWrapper:
+     """A wrapper class for the UPN model.
+
+     Args:
+         ckpt_path (str): The path to the model checkpoint.
+     """
+
+     def __init__(self, ckpt_path: str):
+         self.model = build_model(ckpt_path)
+         self.model.eval()
+         self.model.to("cuda")
+
+     def inference(
+         self,
+         image: List[Union[str, Image.Image]],
+         prompt_type: str = 'fine_grained_prompt',
+     ):
+         """Single image prediction.
+
+         Args:
+             image (List[Union[str, Image.Image]]): A list of image paths or
+                 PIL.Image.Image objects.
+             prompt_type (str): The type of prompt to use for the prediction. One of
+                 ['fine_grained_prompt', 'coarse_grained_prompt'].
+
+         Returns:
+             Dict: Result dict in the format:
+                 {
+                     "original_xyxy_boxes": (np.ndarray): Original prediction boxes in shape (batch_size, 900, 4),
+                     "scores": (np.ndarray): Scores in shape (batch_size, N)
+                 }
+         """
+         if not isinstance(image, list):
+             image = [image]
+         input_images, image_sizes = self.construct_input(image)
+         outputs = self._inference(input_images, prompt_type)
+         post_processed_outputs = self.postprocess(outputs, image_sizes)
+         return post_processed_outputs
+
+     def _inference(self, input_images: List[torch.Tensor], prompt_type: str):
+         """Run UPN model inference.
+
+         Args:
+             input_images (List[torch.Tensor]): Transformed images.
+
+         Returns:
+             (Dict): Result dict with keys:
+                 - query_features: (torch.Tensor): Query features in shape (batch_size, N, 256)
+                 - pred_boxes: (torch.Tensor): Normalized prediction boxes in shape (batch_size, N, 4),
+                     in cxcywh format
+         """
+         input_images = nested_tensor_from_tensor_list(input_images)
+         input_images = input_images.to("cuda")
+         with torch.no_grad():
+             outputs = self.model(input_images, prompt_type)
+         return outputs
+
+     def construct_input(self, image: List[Union[str, Image.Image]]):
+         """Construct input for the model.
+
+         Args:
+             image (List[Union[str, Image.Image]]): A list of image paths or
+                 PIL.Image.Image objects. If the length of the list is more than 1,
+                 the model will perform batch inference.
+
+         Returns:
+             Tuple[List[torch.Tensor], List[List[int]]]: A tuple containing the
+                 input images and the sizes of the input images.
+         """
+         input_images = []
+         image_sizes = []
+         for img in image:
+             if isinstance(img, str):
+                 img = Image.open(img)
+             elif not isinstance(img, Image.Image):
+                 raise ValueError(
+                     "image must be either a string or a PIL.Image.Image object"
+                 )
+             W, H = img.size
+             image_sizes.append([H, W])
+             # add image in tensor format
+             input_images.append(self.transform_image(img))
+         return input_images, image_sizes
+
+     def transform_image(self, image_pil: Image.Image) -> torch.Tensor:
+         """Apply a set of transformations to a PIL image.
+
+         Args:
+             image_pil (PIL.Image.Image): The input image.
+
+         Returns:
+             torch.Tensor: The transformed image as a PyTorch tensor.
+         """
+         transform = T.Compose(
+             [
+                 T.RandomResize([800], max_size=1333),
+                 T.ToTensor(),
+                 T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+             ]
+         )
+         transformed_image, _ = transform(image_pil, None)  # 3, h, w
+         return transformed_image
+
+     def postprocess(
+         self,
+         outputs: Dict[str, torch.Tensor],
+         image_pil_sizes: List[List[int]] = None,
+     ):
+         boxes = outputs["pred_boxes"].cpu()
+         scores = (
+             outputs["pred_logits"].sigmoid().cpu() if "pred_logits" in outputs else None
+         )
+         normalized_xyxy_boxes = []
+         original_xyxy_boxes = []
+         for batch_idx, (H, W) in enumerate(image_pil_sizes):
+             batch_boxes = boxes[batch_idx]  # (num_queries, 4)
+             # from (cx, cy, w, h) to (x1, y1, x2, y2)
+             batch_boxes[:, 0] = batch_boxes[:, 0] - batch_boxes[:, 2] / 2
+             batch_boxes[:, 1] = batch_boxes[:, 1] - batch_boxes[:, 3] / 2
+             batch_boxes[:, 2] = batch_boxes[:, 0] + batch_boxes[:, 2]
+             batch_boxes[:, 3] = batch_boxes[:, 1] + batch_boxes[:, 3]
+             normalized_xyxy_boxes.append(copy.deepcopy(batch_boxes))
+             # scale boxes
+             original_boxes = (
+                 batch_boxes.clone()
+             )  # copy the normalized boxes to scale to original sizes
+             original_boxes[:, 0] = original_boxes[:, 0] * W
+             original_boxes[:, 1] = original_boxes[:, 1] * H
+             original_boxes[:, 2] = original_boxes[:, 2] * W
+             original_boxes[:, 3] = original_boxes[:, 3] * H
+             original_xyxy_boxes.append(original_boxes)
+
+         original_xyxy_boxes = torch.stack(original_xyxy_boxes)
+         original_xyxy_boxes = original_xyxy_boxes.numpy()
+
+         # sort everything by score from highest to lowest
+         sorted_original_boxes = []
+         sorted_scores = []
+         for i in range(len(normalized_xyxy_boxes)):
+             scores_i = scores[i] if scores is not None else None
+             # sort in descending order
+             sorted_indices = scores_i.squeeze(-1).argsort(descending=True)
+             sorted_original_boxes.append(original_xyxy_boxes[i][sorted_indices])
+             sorted_scores.append(scores_i[sorted_indices])
+
+         original_xyxy_boxes = np.stack(sorted_original_boxes)
+         scores = torch.stack(sorted_scores)
+
+         return dict(
+             original_xyxy_boxes=original_xyxy_boxes,
+             scores=scores,
+         )
+
+     def filter(self, result: Dict, min_score: float, nms_value: float = 0.8):
+         """Filter the UPN detection result. Only keep boxes with score above min_score
+         and apply Non-Maximum Suppression (NMS) to filter overlapping boxes.
+
+         Args:
+             result (Dict): A dictionary containing detection results with 'original_xyxy_boxes' and 'scores'.
+             min_score (float): Minimum score threshold for keeping a box.
+             nms_value (float): NMS IoU threshold for filtering boxes.
+
+         Returns:
+             Dict: Filtered result containing 'original_xyxy_boxes' and 'scores' with the filtered boxes.
+         """
+         filtered_result = {"original_xyxy_boxes": [], "scores": []}
+
+         for boxes, scores in zip(
+             np.array(result["original_xyxy_boxes"]), result["scores"].numpy()
+         ):
+             # Filter out boxes with score below min_score
+             keep = scores >= min_score
+             boxes = boxes[keep[:, 0]]
+             scores = scores[keep[:, 0]][:, 0]
+
+             if len(boxes) == 0:
+                 # keep per-image alignment when batch inference is used
+                 filtered_result["original_xyxy_boxes"].append([])
+                 filtered_result["scores"].append([])
+                 continue
+
+             # Convert to torch tensors
+             boxes = torch.tensor(boxes, dtype=torch.float32)
+             scores = torch.tensor(scores, dtype=torch.float32)
+
+             # Apply Non-Maximum Suppression (NMS)
+             if nms_value > 0:
+                 keep_indices = nms(boxes, scores, nms_value)
+             else:
+                 keep_indices = torch.arange(len(boxes))
+
+             # Keep only the boxes that passed NMS
+             filtered_boxes = boxes[keep_indices].numpy().astype(np.int32)
+             filtered_scores = scores[keep_indices].numpy()
+
+             # Sort the boxes by score in descending order
+             sorted_indices = np.argsort(filtered_scores)[::-1]
+             filtered_boxes = filtered_boxes[sorted_indices]
+             filtered_scores = filtered_scores[sorted_indices]
+
+             # Round the scores to two decimal places
+             filtered_scores = [round(float(score), 2) for score in filtered_scores]
+
+             # Store the filtered boxes and scores in the result dictionary
+             filtered_result["original_xyxy_boxes"].append(filtered_boxes.tolist())
+             filtered_result["scores"].append(filtered_scores)
+
+         return filtered_result
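An end-to-end usage sketch of this wrapper, mirroring detect_model_upn in demo/gradio_demo2.py above (the checkpoint path and example image are the ones referenced there; a CUDA device is required by __init__):

from PIL import Image
from detect_tools.upn import UPNWrapper

upn = UPNWrapper("./resources/upn_large.pth")      # moves the model to CUDA
image = Image.open("demo/sam3_examples/00000-22.jpg")

proposals = upn.inference(image, prompt_type="fine_grained_prompt")
filtered = upn.filter(proposals, min_score=0.3, nms_value=0.8)
boxes = filtered["original_xyxy_boxes"][0][:100]   # top-100 xyxy boxes
scores = filtered["scores"][0][:100]               # matching scores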
detect_tools/upn/models/architecture/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .deformable_transformer import DeformableTransformer
+ from .upn_model import UPN
+
+ __all__ = ["UPN", "DeformableTransformer"]
detect_tools/upn/models/architecture/deformable_transformer.py ADDED
@@ -0,0 +1,336 @@
+ import math
+ from typing import Dict, List
+
+ import torch
+ import torch.nn as nn
+
+ from detect_tools.upn import ARCHITECTURES, build_decoder, build_encoder
+ from detect_tools.upn.models.utils import (gen_encoder_output_proposals,
+                                            inverse_sigmoid)
+ from detect_tools.upn.ops.modules import MSDeformAttn
+
+
+ @ARCHITECTURES.register_module()
+ class DeformableTransformer(nn.Module):
+     """Implementation of Deformable DETR.
+
+     Args:
+         encoder_cfg (Dict): Config for the TransformerEncoder.
+         decoder_cfg (Dict): Config for the TransformerDecoder.
+         num_queries (int): Number of queries for the matching part. Default: 900.
+         d_model (int): Dimension of the model. Default: 256.
+         num_feature_levels (int): Number of feature levels. Default: 4.
+         binary_query_selection (bool): Whether to use binary query selection. Default: False.
+             When using binary query selection, a linear layer with out channel = 1 is used to
+             select the top-k proposals. Otherwise, ContrastiveAssign selects the top-k proposals.
+         learnable_tgt_init (bool): Whether to use a learnable target init. Default: True. If False,
+             the top-k encoder features are used as the target init.
+         random_refpoints_xy (bool): Whether to use random refpoints xy. Only used when
+             two_stage_type is not 'no'. Default: False. If True, refpoints xy are initialized randomly.
+         two_stage_type (str): Type of two stage. Default: 'standard'. Options: 'no', 'standard'.
+         two_stage_learn_wh (bool): Whether to learn the width and height of anchor boxes. Default: False.
+         two_stage_keep_all_tokens (bool): If False, the returned hs_enc, ref_enc, init_box_proposal
+             will only be the top-k proposals. Otherwise, all proposals from the encoder are
+             returned. Default: False.
+         two_stage_bbox_embed_share (bool): Whether to share the bbox embedding between the two stages.
+             Default: False.
+         two_stage_class_embed_share (bool): Whether to share the class embedding between the two stages.
+         rm_self_attn_layers (List[int]): The indices of the decoder layers from which to remove
+             self-attention. Default: None.
+         rm_detach (bool): Whether to detach the decoder output. Default: None.
+     """
+
+     def __init__(
+         self,
+         encoder_cfg: Dict,
+         decoder_cfg: Dict,
+         mask_decoder_cfg: Dict = None,
+         num_queries: int = 900,
+         d_model: int = 256,
+         num_feature_levels: int = 4,
+         binary_query_selection: bool = False,
+         # init query (target)
+         learnable_tgt_init=True,
+         random_refpoints_xy=False,
+         # for two stage
+         two_stage_type: str = "standard",
+         two_stage_learn_wh: bool = False,
+         two_stage_keep_all_tokens: bool = False,
+         two_stage_bbox_embed_share: bool = False,
+         two_stage_class_embed_share: bool = False,
+         # evolution of #anchors
+         rm_self_attn_layers: List[int] = None,
+         # for detach
+         rm_detach: bool = None,
+         with_encoder_out: bool = True,
+     ) -> None:
+         super().__init__()
+         self.binary_query_selection = binary_query_selection
+         self.num_queries = num_queries
+         self.num_feature_levels = num_feature_levels
+         self.rm_self_attn_layers = rm_self_attn_layers
+         self.d_model = d_model
+         self.two_stage_bbox_embed_share = two_stage_bbox_embed_share
+         self.two_stage_class_embed_share = two_stage_class_embed_share
+
+         if self.binary_query_selection:
+             self.binary_query_selection_layer = nn.Linear(d_model, 1)
+
+         # build encoder
+         self.encoder = build_encoder(encoder_cfg)
+
+         # build decoder
+         self.decoder = build_decoder(decoder_cfg)
+         self.num_decoder_layers = self.decoder.num_layers
+
+         # build standalone mask decoder
+         if mask_decoder_cfg is not None:
+             self.mask_decoder = build_decoder(mask_decoder_cfg)
+         else:
+             self.mask_decoder = None
+         # level embedding
+         if num_feature_levels > 1:
+             self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
+
+         # learnable target embedding
+         self.learnable_tgt_init = learnable_tgt_init
+         assert learnable_tgt_init, "only learnable_tgt_init=True is supported"
+
+         self.tgt_embed = nn.Embedding(num_queries, d_model)
+         nn.init.normal_(self.tgt_embed.weight.data)
+
+         # for two stage
+         self.two_stage_type = two_stage_type
+         self.two_stage_learn_wh = two_stage_learn_wh
+         self.two_stage_keep_all_tokens = two_stage_keep_all_tokens
+         assert two_stage_type in [
+             "no",
+             "standard",
+         ], "unknown param {} of two_stage_type".format(two_stage_type)
+         self.with_encoder_out = with_encoder_out
+         if two_stage_type == "standard":
+             # anchor selection at the output of the encoder
+             if with_encoder_out:
+                 self.enc_output = nn.Linear(d_model, d_model)
+                 self.enc_output_norm = nn.LayerNorm(d_model)
+
+             if two_stage_learn_wh:
+                 self.two_stage_wh_embedding = nn.Embedding(1, 2)
+             else:
+                 self.two_stage_wh_embedding = None
+
+         elif two_stage_type == "no":
+             self.init_ref_points(
+                 num_queries, random_refpoints_xy
+             )  # init self.refpoint_embed
+
+         self.enc_out_class_embed = None  # initialized outside of the model
+         self.enc_out_bbox_embed = None  # initialized outside of the model
+
+         # remove some self_attn_layers or rm_detach
+         self._reset_parameters()
+
+         self.rm_self_attn_layers = rm_self_attn_layers
+         if rm_self_attn_layers is not None:
+             print(
+                 "Removing the self-attn in {} decoder layers".format(
+                     rm_self_attn_layers
+                 )
+             )
+             for lid, dec_layer in enumerate(self.decoder.layers):
+                 if lid in rm_self_attn_layers:
+                     dec_layer.rm_self_attn_modules()
+
+         self.rm_detach = rm_detach
+         if self.rm_detach:
+             assert isinstance(rm_detach, list)
+             assert any([i in ["enc_ref", "enc_tgt", "dec"] for i in rm_detach])
+             self.decoder.rm_detach = rm_detach
+
+     def _reset_parameters(self):
+         for p in self.parameters():
+             if p.dim() > 1:
+                 nn.init.xavier_uniform_(p)
+         for m in self.modules():
+             if isinstance(m, MSDeformAttn):
+                 m._reset_parameters()
+         if self.num_feature_levels > 1 and self.level_embed is not None:
+             nn.init.normal_(self.level_embed)
+
+         if self.two_stage_learn_wh:
+             nn.init.constant_(
+                 self.two_stage_wh_embedding.weight, math.log(0.05 / (1 - 0.05))
+             )
+
+     def init_ref_points(self, num_queries: int, random_refpoints_xy: bool = False):
+         """Initialize learnable reference points for each query.
+
+         Args:
+             num_queries (int): number of queries
+             random_refpoints_xy (bool, optional): whether to init the refpoints randomly.
+                 Defaults to False.
+         """
+         self.refpoint_embed = nn.Embedding(num_queries, 4)
+         if random_refpoints_xy:
+             self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
+             self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(
+                 self.refpoint_embed.weight.data[:, :2]
+             )
+             self.refpoint_embed.weight.data[:, :2].requires_grad = False
+
+     def get_valid_ratio(self, mask):
+         _, H, W = mask.shape
+         valid_H = torch.sum(~mask[:, :, 0], 1)
+         valid_W = torch.sum(~mask[:, 0, :], 1)
+         valid_ratio_h = valid_H.float() / H
+         valid_ratio_w = valid_W.float() / W
+         valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+         return valid_ratio
+
+     def forward(
+         self,
+         src_flatten: torch.Tensor,
+         lvl_pos_embed_flatten: torch.Tensor,
+         level_start_index: List[int],
+         spatial_shapes: List[torch.Tensor],
+         valid_ratios: List[torch.Tensor],
+         mask_flatten: torch.Tensor,
+         prompt_type: str,
+     ) -> List[torch.Tensor]:
+         """Forward function."""
+         memory = self.encoder(
+             src_flatten,
+             pos=lvl_pos_embed_flatten,
+             level_start_index=level_start_index,
+             spatial_shapes=spatial_shapes,
+             valid_ratios=valid_ratios,
+             key_padding_mask=mask_flatten,
+         )
+         batch_size = src_flatten.shape[0]
+         crop_region_features = torch.zeros(batch_size, 1, self.d_model).to(
+             memory.device
+         )
+         if prompt_type == "fine_grained_prompt":
+             crop_region_features = (
+                 self.fine_grained_prompt.weight[0]
+                 .unsqueeze(0)
+                 .unsqueeze(0)
+                 .repeat(batch_size, 1, 1)
+             )
+         elif prompt_type == "coarse_grained_prompt":
+             crop_region_features = (
+                 self.coarse_grained_prompt.weight[0]
+                 .unsqueeze(0)
+                 .unsqueeze(0)
+                 .repeat(batch_size, 1, 1)
+             )
+         pad_mask = torch.ones(batch_size, 1).to(crop_region_features.device).bool()
+         self_attn_mask = torch.ones(batch_size, 1, 1).to(crop_region_features.device)
+         ref_dict = dict(
+             encoded_ref_feature=crop_region_features,
+             pad_mask=pad_mask,
+             self_attn_mask=self_attn_mask,
+             prompt_type="universal_prompt",
+         )
+
+         (
+             refpoint_embed,
+             tgt,
+             init_box_proposal,
+         ) = self.get_two_stage_proposal(memory, mask_flatten, spatial_shapes, ref_dict)
+
+         hs, references = self.decoder(
+             tgt=tgt.transpose(0, 1),
+             tgt_key_padding_mask=None,
+             memory=memory.transpose(0, 1),
+             memory_key_padding_mask=mask_flatten,
+             pos=lvl_pos_embed_flatten.transpose(0, 1),
+             refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
+             level_start_index=level_start_index,
+             spatial_shapes=spatial_shapes,
+             valid_ratios=valid_ratios,
+             tgt_mask=None,
+             # we ~ the mask: False means use the token; True means pad the token
+         )
+         return (
+             hs,
+             references,
+             ref_dict,
+         )
+
+     def get_two_stage_proposal(
+         self,
+         memory: torch.Tensor,
+         mask_flatten: torch.Tensor,
+         spatial_shapes: List[torch.Tensor],
+         ref_dict: Dict,
+     ) -> List[torch.Tensor]:
+         """Two-stage proposal generation for the decoder.
+
+         Args:
+             memory (torch.Tensor): Image encoded feature. [bs, n, 256]
+             mask_flatten (torch.Tensor): Flattened mask. [bs, n]
+             spatial_shapes (List[torch.Tensor]): Spatial shapes of each feature map. [bs, num_levels, 2]
+             ref_dict (Dict): A dict containing all kinds of reference image related features.
+         """
+         bs = memory.shape[0]
+         input_hw = None
+         output_memory, output_proposals = gen_encoder_output_proposals(
+             memory, mask_flatten, spatial_shapes, input_hw
+         )
+         output_memory = self.enc_output_norm(self.enc_output(output_memory))
+
+         if self.binary_query_selection:  # unused
+             topk_logits = self.binary_query_selection_layer(output_memory).squeeze(-1)
+         else:
+             if ref_dict is not None:
+                 enc_outputs_class_unselected = self.enc_out_class_embed(
+                     output_memory, ref_dict
+                 )  # not a linear prediction layer but a contrastive similarity, shape [B, len_image, len_text]
+             else:
+                 enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)
+             topk_logits = enc_outputs_class_unselected.max(-1)[
+                 0
+             ]  # shape [B, len_image]
+         enc_outputs_coord_unselected = (
+             self.enc_out_bbox_embed(output_memory) + output_proposals
+         )  # (bs, \sum{hw}, 4) unsigmoid
+         topk = self.num_queries
+
+         try:
+             topk_proposals = torch.topk(topk_logits, topk, dim=1)[1]  # bs, nq
+         except Exception:
+             raise ValueError(
+                 f"torch.topk failed on logits of shape {topk_logits.shape}"
+             )
+
+         # gather boxes
+         refpoint_embed_undetach = torch.gather(
+             enc_outputs_coord_unselected,
+             1,
+             topk_proposals.unsqueeze(-1).repeat(1, 1, 4),
+         )  # unsigmoid
+         refpoint_embed_ = refpoint_embed_undetach.detach()
+         init_box_proposal = torch.gather(
+             output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
+         ).sigmoid()  # sigmoid
+         # gather tgt
+         tgt_undetach = torch.gather(
+             output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model)
+         )
+         tgt_ = (
+             self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
+         )  # nq, bs, d_model
+         refpoint_embed, tgt = refpoint_embed_, tgt_
+
+         return (
+             refpoint_embed,
+             tgt,
+             init_box_proposal,
+         )
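A toy illustration of the top-k selection used in get_two_stage_proposal above: score every encoder token, keep the k best, then gather the matching 4-d box parameters with the same indices. All names below are illustrative; only the torch calls mirror the method:

import torch

bs, n_tokens, k = 2, 10, 3
logits = torch.randn(bs, n_tokens)            # per-token objectness logits
coords = torch.randn(bs, n_tokens, 4)         # per-token unsigmoided boxes

topk_idx = torch.topk(logits, k, dim=1)[1]    # indices of the k best tokens, (bs, k)
topk_boxes = torch.gather(coords, 1, topk_idx.unsqueeze(-1).repeat(1, 1, 4))
assert topk_boxes.shape == (bs, k, 4)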
detect_tools/upn/models/architecture/upn_model.py ADDED
@@ -0,0 +1,343 @@
+ import copy
+ from typing import Dict, List, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from detect_tools.upn import ARCHITECTURES, build_architecture, build_backbone
+ from detect_tools.upn.models.module import (MLP, ContrastiveAssign, NestedTensor,
+                                             nested_tensor_from_tensor_list)
+ from detect_tools.upn.models.utils import inverse_sigmoid
+
+
+ class LayerNorm2d(nn.Module):
+     def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(num_channels))
+         self.bias = nn.Parameter(torch.zeros(num_channels))
+         self.eps = eps
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         u = x.mean(1, keepdim=True)
+         s = (x - u).pow(2).mean(1, keepdim=True)
+         x = (x - u) / torch.sqrt(s + self.eps)
+         x = self.weight[:, None, None] * x + self.bias[:, None, None]
+         return x
+
+
+ @ARCHITECTURES.register_module()
+ class UPN(nn.Module):
+     """Implementation of UPN."""
+
+     def __init__(
+         self,
+         vision_backbone_cfg: Dict,
+         transformer_cfg: Dict,
+         num_queries: int,
+         dec_pred_class_embed_share=True,
+         dec_pred_bbox_embed_share=True,
+         decoder_sa_type="sa",
+     ):
+         super().__init__()
+         # build vision backbone
+         self.backbone = build_backbone(vision_backbone_cfg)
+         # build transformer
+         self.transformer = build_architecture(transformer_cfg)
+
+         self.hidden_dim = self.transformer.d_model
+
+         # for dn training
+         self.num_queries = num_queries
+         self.num_feature_levels = self.transformer.num_feature_levels
+
+         # prepare projection layer for vision features
+         self.input_proj = self.prepare_vision_feature_projection_layer(
+             self.backbone,
+             self.transformer.num_feature_levels,
+             self.hidden_dim,
+             self.transformer.two_stage_type,
+         )
+         # prepare prediction head
+         self.prepare_prediction_head(
+             dec_pred_class_embed_share,
+             dec_pred_bbox_embed_share,
+             self.hidden_dim,
+             self.transformer.num_decoder_layers,
+         )
+
+         self.decoder_sa_type = decoder_sa_type
+         assert decoder_sa_type in ["sa", "ca_label", "ca_content"]
+
+         for layer in self.transformer.decoder.layers:
+             layer.label_embedding = None
+         self.label_embedding = None
+
+         # build a universal token
+         self.transformer.fine_grained_prompt = nn.Embedding(1, self.hidden_dim)
+         self.transformer.coarse_grained_prompt = nn.Embedding(1, self.hidden_dim)
+
+         self._reset_parameters()
+
+     def forward(self, samples: NestedTensor, prompt_type: str = None) -> Dict:
+         """Forward function."""
+         self.device = samples.device
+
+         (
+             src_flatten,
+             lvl_pos_embed_flatten,
+             level_start_index,
+             spatial_shapes,
+             valid_ratios,
+             mask_flatten,
+         ) = self.forward_backbone_encoder(samples)
+
+         (
+             hs,
+             reference,
+             ref_dict,
+         ) = self.transformer(
+             src_flatten,
+             lvl_pos_embed_flatten,
+             level_start_index,
+             spatial_shapes,
+             valid_ratios,
+             mask_flatten,
+             prompt_type,
+         )
+
+         # deformable-detr-like anchor update
+         outputs_coord_list = []
+         outputs_class = []
+
+         for layer_idx, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(
+             zip(reference[:-1], self.bbox_embed, hs)
+         ):
+             layer_delta_unsig = layer_bbox_embed(layer_hs)
+             layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
+             layer_outputs_unsig = layer_outputs_unsig.sigmoid()
+             outputs_coord_list.append(layer_outputs_unsig)
+
+         outputs_coord_list = torch.stack(outputs_coord_list)
+
+         if ref_dict is None:
+             # build a mock outputs_class for mask_dn training
+             outputs_class = torch.zeros(
+                 outputs_coord_list.shape[0],
+                 outputs_coord_list.shape[1],
+                 outputs_coord_list.shape[2],
+                 self.hidden_dim,
+             )
+         else:
+             outputs_class = torch.stack(
+                 [
+                     layer_cls_embed(layer_hs, ref_dict)
+                     for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
+                 ]
+             )
+
+         out = {
+             "pred_logits": outputs_class[-1],
+             "pred_boxes": outputs_coord_list[-1],
+         }
+         out["ref_dict"] = ref_dict
+         return out
+
+     def forward_backbone_encoder(self, samples: NestedTensor) -> Tuple:
+         # pass through the backbone
+         if isinstance(samples, (list, torch.Tensor)):
+             samples = nested_tensor_from_tensor_list(samples)
+         features, poss = self.backbone(samples)
+         # project features
+         srcs = []
+         masks = []
+         for l, feat in enumerate(features):
+             src, mask = feat.decompose()
+             srcs.append(self.input_proj[l](src))  # project the feature map to 256 channels
+             masks.append(mask)
+             assert mask is not None
+
+         if self.num_feature_levels > len(
+             srcs
+         ):  # add more feature levels by downsampling the last feature map
+             _len_srcs = len(srcs)
+             for l in range(_len_srcs, self.num_feature_levels):
+                 if l == _len_srcs:
+                     src = self.input_proj[l](features[-1].tensors)
+                 else:
+                     src = self.input_proj[l](srcs[-1])
+                 m = samples.mask
+                 mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(
+                     torch.bool
+                 )[0]
+                 pos_l = self.backbone.forward_pos_embed_only(
+                     NestedTensor(src, mask)
+                 ).to(src.dtype)
+                 srcs.append(src)
+                 masks.append(mask)
+                 poss.append(pos_l)
+
+         # prepare input for the encoder with the following steps:
+         # 1. flatten the feature maps and masks
+         # 2. add positional embedding and level embedding
+         # 3. calculate the valid ratio of each feature map based on the mask
+         src_flatten = []
+         mask_flatten = []
+         lvl_pos_embed_flatten = []
+         spatial_shapes = []
+         for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, poss)):
+             bs, c, h, w = src.shape
+             spatial_shape = (h, w)
+             spatial_shapes.append(spatial_shape)
+
+             src = src.flatten(2).transpose(1, 2)  # bs, hw, c
+             mask = mask.flatten(1)  # bs, hw
+             pos_embed = pos_embed.flatten(2).transpose(1, 2)  # bs, hw, c
+             if self.num_feature_levels > 1 and self.transformer.level_embed is not None:
+                 lvl_pos_embed = pos_embed + self.transformer.level_embed[lvl].view(
+                     1, 1, -1
+                 )
+             else:
+                 lvl_pos_embed = pos_embed
+             lvl_pos_embed_flatten.append(lvl_pos_embed)
+             src_flatten.append(src)
+             mask_flatten.append(mask)
+         src_flatten = torch.cat(src_flatten, 1)
+         mask_flatten = torch.cat(mask_flatten, 1)
+         lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
+         spatial_shapes = torch.as_tensor(
+             spatial_shapes, dtype=torch.long, device=src_flatten.device
+         )
+         level_start_index = torch.cat(
+             (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
+         )
+         valid_ratios = torch.stack(
+             [self.transformer.get_valid_ratio(m) for m in masks], 1
+         )
+
+         return (
+             src_flatten,
+             lvl_pos_embed_flatten,
+             level_start_index,
+             spatial_shapes,
+             valid_ratios,
+             mask_flatten,
+         )
+
+     def prepare_vision_feature_projection_layer(
+         self,
+         backbone: nn.Module,
+         num_feature_levels: int,
+         hidden_dim: int,
+         two_stage_type: str,
+     ) -> nn.ModuleList:
+         """Prepare projection layers to map backbone features to the hidden dim.
+
+         Args:
+             backbone (nn.Module): Backbone.
+             num_feature_levels (int): Number of feature levels.
+             hidden_dim (int): Hidden dim.
+             two_stage_type (str): Type of two stage.
+
+         Returns:
+             nn.ModuleList: Projection layers.
+         """
+         if num_feature_levels > 1:
+             num_backbone_outs = len(backbone.num_channels)
+             input_proj_list = []
+             for i in range(num_backbone_outs):
+                 in_channels = backbone.num_channels[i]
+                 input_proj_list.append(
+                     nn.Sequential(
+                         nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
+                         nn.GroupNorm(32, hidden_dim),
+                     )
+                 )
+             for _ in range(num_feature_levels - num_backbone_outs):
+                 input_proj_list.append(
+                     nn.Sequential(
+                         nn.Conv2d(
+                             in_channels, hidden_dim, kernel_size=3, stride=2, padding=1
+                         ),
+                         nn.GroupNorm(32, hidden_dim),
+                     )
+                 )
+                 in_channels = hidden_dim
+             input_proj = nn.ModuleList(input_proj_list)
+         else:
+             assert (
+                 two_stage_type == "no"
+             ), "two_stage_type should be no if num_feature_levels=1 !!!"
+             input_proj = nn.ModuleList(
+                 [
+                     nn.Sequential(
+                         nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1),
+                         nn.GroupNorm(32, hidden_dim),
+                     )
+                 ]
+             )
+         return input_proj
+
+     def prepare_prediction_head(
+         self,
+         dec_pred_class_embed_share: bool,
+         dec_pred_bbox_embed_share: bool,
+         hidden_dim: int,
+         num_decoder_layers: int,
+     ) -> None:
+         """Prepare the prediction head, i.e. the class embed and the bbox embed.
+
+         Args:
+             dec_pred_class_embed_share (bool): Whether to share the class embed across all decoder layers.
+             dec_pred_bbox_embed_share (bool): Whether to share the bbox embed across all decoder layers.
+             hidden_dim (int): Hidden dim.
+             num_decoder_layers (int): Number of decoder layers.
+         """
+         _class_embed = ContrastiveAssign()
+         _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
+         nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
+         nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
+         if dec_pred_bbox_embed_share:
+             box_embed_layerlist = [_bbox_embed for _ in range(num_decoder_layers)]
+         else:
+             box_embed_layerlist = [
+                 copy.deepcopy(_bbox_embed) for _ in range(num_decoder_layers)
+             ]
+         if dec_pred_class_embed_share:
+             class_embed_layerlist = [_class_embed for _ in range(num_decoder_layers)]
+         else:
+             class_embed_layerlist = [
+                 copy.deepcopy(_class_embed) for _ in range(num_decoder_layers)
+             ]
+         bbox_embed = nn.ModuleList(box_embed_layerlist)
+         class_embed = nn.ModuleList(class_embed_layerlist)
+         self.bbox_embed = bbox_embed
+         self.class_embed = class_embed
+
+         # initialize bbox embed and class embed in the transformer
+         self.transformer.decoder.bbox_embed = bbox_embed
+         self.transformer.decoder.class_embed = class_embed
+
+         if self.transformer.two_stage_type != "no":
+             if self.transformer.two_stage_bbox_embed_share:
+                 assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
+                 self.transformer.enc_out_bbox_embed = _bbox_embed
+             else:
+                 self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)
+
+             if self.transformer.two_stage_class_embed_share:
+                 assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
+                 self.transformer.enc_out_class_embed = _class_embed
+             else:
+                 self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)
+
+         self.refpoint_embed = None
+
+     def _reset_parameters(self):
+         # init input_proj
+         for proj in self.input_proj:
+             nn.init.xavier_uniform_(proj[0].weight, gain=1)
+             nn.init.constant_(proj[0].bias, 0)
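A worked example of the level_start_index computation in forward_backbone_encoder above: with three feature maps, the flattened token sequence is laid out level by level, and each level starts at the cumulative sum of the preceding levels' h*w. The shapes below are made up for illustration:

import torch

spatial_shapes = torch.tensor([[4, 4], [2, 2], [1, 1]])  # (h, w) per level
level_start_index = torch.cat(
    (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
)
print(level_start_index)  # tensor([ 0, 16, 20]) -> levels of 16, 4, and 1 tokens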
detect_tools/upn/models/backbone/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .swin import SwinTransformer
+ from .wrapper import SwinWrapper
+
+ __all__ = ["SwinWrapper", "SwinTransformer"]
detect_tools/upn/models/backbone/swin.py ADDED
@@ -0,0 +1,814 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.utils.checkpoint as checkpoint
8
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
9
+
10
+ from detect_tools.upn import BACKBONES
11
+ from detect_tools.upn.models.module import NestedTensor
12
+
13
+
14
+ class Mlp(nn.Module):
15
+ """Multilayer perceptron."""
16
+
17
+ def __init__(
18
+ self,
19
+ in_features,
20
+ hidden_features=None,
21
+ out_features=None,
22
+ act_layer=nn.GELU,
23
+ drop=0.0,
24
+ ):
25
+ super().__init__()
26
+ out_features = out_features or in_features
27
+ hidden_features = hidden_features or in_features
28
+ self.fc1 = nn.Linear(in_features, hidden_features)
29
+ self.act = act_layer()
30
+ self.fc2 = nn.Linear(hidden_features, out_features)
31
+ self.drop = nn.Dropout(drop)
32
+
33
+ def forward(self, x):
34
+ x = self.fc1(x)
35
+ x = self.act(x)
36
+ x = self.drop(x)
37
+ x = self.fc2(x)
38
+ x = self.drop(x)
39
+ return x
40
+
41
+
42
+ def window_partition(x, window_size):
43
+ """
44
+ Args:
45
+ x: (B, H, W, C)
46
+ window_size (int): window size
47
+ Returns:
48
+ windows: (num_windows*B, window_size, window_size, C)
49
+ """
50
+ B, H, W, C = x.shape
51
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
52
+ windows = (
53
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
54
+ )
55
+ return windows
56
+
57
+
58
+ def window_reverse(windows, window_size, H, W):
59
+ """
60
+ Args:
61
+ windows: (num_windows*B, window_size, window_size, C)
62
+ window_size (int): Window size
63
+ H (int): Height of image
64
+ W (int): Width of image
65
+ Returns:
66
+ x: (B, H, W, C)
67
+ """
68
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
69
+ x = windows.view(
70
+ B, H // window_size, W // window_size, window_size, window_size, -1
71
+ )
72
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
73
+ return x
74
+
75
+
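
Shape sketch for the two helpers above (assuming H and W are already multiples of window_size, which the callers guarantee via padding): partitioning and then reversing is a lossless round trip.

```python
import torch

x = torch.randn(2, 14, 14, 96)                 # (B, H, W, C)
windows = window_partition(x, window_size=7)   # (2 * 4, 7, 7, 96): 4 windows per image
y = window_reverse(windows, window_size=7, H=14, W=14)
assert torch.equal(x, y)                       # exact reconstruction
```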
76
+ class WindowAttention(nn.Module):
77
+ """Window based multi-head self attention (W-MSA) module with relative position bias.
78
+ It supports both shifted and non-shifted windows.
79
+ Args:
80
+ dim (int): Number of input channels.
81
+ window_size (tuple[int]): The height and width of the window.
82
+ num_heads (int): Number of attention heads.
83
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
84
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
85
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
86
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
87
+ """
88
+
89
+ def __init__(
90
+ self,
91
+ dim,
92
+ window_size,
93
+ num_heads,
94
+ qkv_bias=True,
95
+ qk_scale=None,
96
+ attn_drop=0.0,
97
+ proj_drop=0.0,
98
+ ):
99
+
100
+ super().__init__()
101
+ self.dim = dim
102
+ self.window_size = window_size # Wh, Ww
103
+ self.num_heads = num_heads
104
+ head_dim = dim // num_heads
105
+ self.scale = qk_scale or head_dim**-0.5
106
+
107
+ # define a parameter table of relative position bias
108
+ self.relative_position_bias_table = nn.Parameter(
109
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
110
+ ) # 2*Wh-1 * 2*Ww-1, nH
111
+
112
+ # get pair-wise relative position index for each token inside the window
113
+ coords_h = torch.arange(self.window_size[0])
114
+ coords_w = torch.arange(self.window_size[1])
115
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
116
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
117
+ relative_coords = (
118
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
119
+ ) # 2, Wh*Ww, Wh*Ww
120
+ relative_coords = relative_coords.permute(
121
+ 1, 2, 0
122
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
123
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
124
+ relative_coords[:, :, 1] += self.window_size[1] - 1
125
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
126
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
127
+ self.register_buffer("relative_position_index", relative_position_index)
128
+
129
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
130
+ self.attn_drop = nn.Dropout(attn_drop)
131
+ self.proj = nn.Linear(dim, dim)
132
+ self.proj_drop = nn.Dropout(proj_drop)
133
+
134
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
135
+ self.softmax = nn.Softmax(dim=-1)
136
+
137
+ def forward(self, x, mask=None):
138
+ """Forward function.
139
+ Args:
140
+ x: input features with shape of (num_windows*B, N, C)
141
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
142
+ """
143
+ B_, N, C = x.shape
144
+ qkv = (
145
+ self.qkv(x)
146
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
147
+ .permute(2, 0, 3, 1, 4)
148
+ )
149
+ q, k, v = (
150
+ qkv[0],
151
+ qkv[1],
152
+ qkv[2],
153
+ ) # make torchscript happy (cannot use tensor as tuple)
154
+
155
+ q = q * self.scale
156
+ attn = q @ k.transpose(-2, -1)
157
+
158
+ relative_position_bias = self.relative_position_bias_table[
159
+ self.relative_position_index.view(-1)
160
+ ].view(
161
+ self.window_size[0] * self.window_size[1],
162
+ self.window_size[0] * self.window_size[1],
163
+ -1,
164
+ ) # Wh*Ww,Wh*Ww,nH
165
+ relative_position_bias = relative_position_bias.permute(
166
+ 2, 0, 1
167
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
168
+ attn = attn + relative_position_bias.unsqueeze(0)
169
+
170
+ if mask is not None:
171
+ nW = mask.shape[0]
172
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
173
+ 1
174
+ ).unsqueeze(0)
175
+ attn = attn.view(-1, self.num_heads, N, N)
176
+ attn = self.softmax(attn)
177
+ else:
178
+ attn = self.softmax(attn)
179
+
180
+ attn = self.attn_drop(attn)
181
+
182
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
183
+ x = self.proj(x)
184
+ x = self.proj_drop(x)
185
+ return x
186
+
187
+
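
The relative-position-bias indexing in __init__ above can be hard to follow; here is the same recipe run standalone for a tiny 2x2 window, where the (2*2-1)^2 = 9 possible offsets index rows of the bias table:

```python
import torch

Wh, Ww = 2, 2
coords = torch.stack(torch.meshgrid([torch.arange(Wh), torch.arange(Ww)]))  # (2, Wh, Ww)
flat = torch.flatten(coords, 1)                                             # (2, Wh*Ww)
rel = (flat[:, :, None] - flat[:, None, :]).permute(1, 2, 0).contiguous()   # (N, N, 2)
rel[:, :, 0] += Wh - 1        # shift row offsets to start at 0
rel[:, :, 1] += Ww - 1        # shift column offsets to start at 0
rel[:, :, 0] *= 2 * Ww - 1    # row-major flattening of the 2D offset
index = rel.sum(-1)           # (4, 4) index matrix, values in [0, 8]
```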
188
+ class SwinTransformerBlock(nn.Module):
189
+ """Swin Transformer Block.
190
+ Args:
191
+ dim (int): Number of input channels.
192
+ num_heads (int): Number of attention heads.
193
+ window_size (int): Window size.
194
+ shift_size (int): Shift size for SW-MSA.
195
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
196
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
197
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
198
+ drop (float, optional): Dropout rate. Default: 0.0
199
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
200
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
201
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
202
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
203
+ """
204
+
205
+ def __init__(
206
+ self,
207
+ dim,
208
+ num_heads,
209
+ window_size=7,
210
+ shift_size=0,
211
+ mlp_ratio=4.0,
212
+ qkv_bias=True,
213
+ qk_scale=None,
214
+ drop=0.0,
215
+ attn_drop=0.0,
216
+ drop_path=0.0,
217
+ act_layer=nn.GELU,
218
+ norm_layer=nn.LayerNorm,
219
+ ):
220
+ super().__init__()
221
+ self.dim = dim
222
+ self.num_heads = num_heads
223
+ self.window_size = window_size
224
+ self.shift_size = shift_size
225
+ self.mlp_ratio = mlp_ratio
226
+ assert (
227
+ 0 <= self.shift_size < self.window_size
228
+ ), "shift_size must in 0-window_size"
229
+
230
+ self.norm1 = norm_layer(dim)
231
+ self.attn = WindowAttention(
232
+ dim,
233
+ window_size=to_2tuple(self.window_size),
234
+ num_heads=num_heads,
235
+ qkv_bias=qkv_bias,
236
+ qk_scale=qk_scale,
237
+ attn_drop=attn_drop,
238
+ proj_drop=drop,
239
+ )
240
+
241
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
242
+ self.norm2 = norm_layer(dim)
243
+ mlp_hidden_dim = int(dim * mlp_ratio)
244
+ self.mlp = Mlp(
245
+ in_features=dim,
246
+ hidden_features=mlp_hidden_dim,
247
+ act_layer=act_layer,
248
+ drop=drop,
249
+ )
250
+
251
+ self.H = None
252
+ self.W = None
253
+
254
+ def forward(self, x, mask_matrix):
255
+ """Forward function.
256
+ Args:
257
+ x: Input feature, tensor size (B, H*W, C).
258
+ H, W (taken from self.H and self.W): Spatial resolution of the input feature.
259
+ mask_matrix: Attention mask for cyclic shift.
260
+ """
261
+ B, L, C = x.shape
262
+ H, W = self.H, self.W
263
+ assert L == H * W, "input feature has wrong size"
264
+
265
+ shortcut = x
266
+ x = self.norm1(x)
267
+ x = x.view(B, H, W, C)
268
+
269
+ # pad feature maps to multiples of window size
270
+ pad_l = pad_t = 0
271
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
272
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
273
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
274
+ _, Hp, Wp, _ = x.shape
275
+
276
+ # cyclic shift
277
+ if self.shift_size > 0:
278
+ shifted_x = torch.roll(
279
+ x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
280
+ )
281
+ attn_mask = mask_matrix
282
+ else:
283
+ shifted_x = x
284
+ attn_mask = None
285
+
286
+ # partition windows
287
+ x_windows = window_partition(
288
+ shifted_x, self.window_size
289
+ ) # nW*B, window_size, window_size, C
290
+ x_windows = x_windows.view(
291
+ -1, self.window_size * self.window_size, C
292
+ ) # nW*B, window_size*window_size, C
293
+
294
+ # W-MSA/SW-MSA
295
+ attn_windows = self.attn(
296
+ x_windows, mask=attn_mask
297
+ ) # nW*B, window_size*window_size, C
298
+
299
+ # merge windows
300
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
301
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
302
+
303
+ # reverse cyclic shift
304
+ if self.shift_size > 0:
305
+ x = torch.roll(
306
+ shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
307
+ )
308
+ else:
309
+ x = shifted_x
310
+
311
+ if pad_r > 0 or pad_b > 0:
312
+ x = x[:, :H, :W, :].contiguous()
313
+
314
+ x = x.view(B, H * W, C)
315
+
316
+ # FFN
317
+ x = shortcut + self.drop_path(x)
318
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
319
+
320
+ return x
321
+
322
+
323
+ class PatchMerging(nn.Module):
324
+ """Patch Merging Layer
325
+ Args:
326
+ dim (int): Number of input channels.
327
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
328
+ """
329
+
330
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
331
+ super().__init__()
332
+ self.dim = dim
333
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
334
+ self.norm = norm_layer(4 * dim)
335
+
336
+ def forward(self, x, H, W):
337
+ """Forward function.
338
+ Args:
339
+ x: Input feature, tensor size (B, H*W, C).
340
+ H, W: Spatial resolution of the input feature.
341
+ """
342
+ B, L, C = x.shape
343
+ assert L == H * W, "input feature has wrong size"
344
+
345
+ x = x.view(B, H, W, C)
346
+
347
+ # padding
348
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
349
+ if pad_input:
350
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
351
+
352
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
353
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
354
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
355
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
356
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
357
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
358
+
359
+ x = self.norm(x)
360
+ x = self.reduction(x)
361
+
362
+ return x
363
+
364
+
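
Shape sketch for PatchMerging above (a hypothetical standalone call): the spatial resolution halves while the channel count doubles.

```python
import torch

pm = PatchMerging(dim=96)
x = torch.randn(2, 14 * 14, 96)   # (B, H*W, C)
y = pm(x, H=14, W=14)
print(y.shape)                    # torch.Size([2, 49, 192]): 7x7 tokens, 2x channels
```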
365
+ class BasicLayer(nn.Module):
366
+ """A basic Swin Transformer layer for one stage.
367
+ Args:
368
+ dim (int): Number of feature channels
369
+ depth (int): Depths of this stage.
370
+ num_heads (int): Number of attention head.
371
+ window_size (int): Local window size. Default: 7.
372
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
373
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
374
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
375
+ drop (float, optional): Dropout rate. Default: 0.0
376
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
377
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
378
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
379
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
380
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
381
+ """
382
+
383
+ def __init__(
384
+ self,
385
+ dim,
386
+ depth,
387
+ num_heads,
388
+ window_size=7,
389
+ mlp_ratio=4.0,
390
+ qkv_bias=True,
391
+ qk_scale=None,
392
+ drop=0.0,
393
+ attn_drop=0.0,
394
+ drop_path=0.0,
395
+ norm_layer=nn.LayerNorm,
396
+ downsample=None,
397
+ use_checkpoint=False,
398
+ ):
399
+ super().__init__()
400
+ self.window_size = window_size
401
+ self.shift_size = window_size // 2
402
+ self.depth = depth
403
+ self.use_checkpoint = use_checkpoint
404
+
405
+ # build blocks
406
+ self.blocks = nn.ModuleList(
407
+ [
408
+ SwinTransformerBlock(
409
+ dim=dim,
410
+ num_heads=num_heads,
411
+ window_size=window_size,
412
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
413
+ mlp_ratio=mlp_ratio,
414
+ qkv_bias=qkv_bias,
415
+ qk_scale=qk_scale,
416
+ drop=drop,
417
+ attn_drop=attn_drop,
418
+ drop_path=(
419
+ drop_path[i] if isinstance(drop_path, list) else drop_path
420
+ ),
421
+ norm_layer=norm_layer,
422
+ )
423
+ for i in range(depth)
424
+ ]
425
+ )
426
+
427
+ # patch merging layer
428
+ if downsample is not None:
429
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
430
+ else:
431
+ self.downsample = None
432
+
433
+ def forward(self, x, H, W):
434
+ """Forward function.
435
+ Args:
436
+ x: Input feature, tensor size (B, H*W, C).
437
+ H, W: Spatial resolution of the input feature.
438
+ """
439
+
440
+ # calculate attention mask for SW-MSA
441
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
442
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
443
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
444
+ h_slices = (
445
+ slice(0, -self.window_size),
446
+ slice(-self.window_size, -self.shift_size),
447
+ slice(-self.shift_size, None),
448
+ )
449
+ w_slices = (
450
+ slice(0, -self.window_size),
451
+ slice(-self.window_size, -self.shift_size),
452
+ slice(-self.shift_size, None),
453
+ )
454
+ cnt = 0
455
+ for h in h_slices:
456
+ for w in w_slices:
457
+ img_mask[:, h, w, :] = cnt
458
+ cnt += 1
459
+
460
+ mask_windows = window_partition(
461
+ img_mask, self.window_size
462
+ ) # nW, window_size, window_size, 1
463
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
464
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
465
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
466
+ attn_mask == 0, float(0.0)
467
+ )
468
+
469
+ for blk in self.blocks:
470
+ blk.H, blk.W = H, W
471
+ if self.use_checkpoint:
472
+ x = checkpoint.checkpoint(blk, x, attn_mask)
473
+ else:
474
+ x = blk(x, attn_mask)
475
+ if self.downsample is not None:
476
+ x_down = self.downsample(x, H, W)
477
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
478
+ return x, H, W, x_down, Wh, Ww
479
+ else:
480
+ return x, H, W, x, H, W
481
+
482
+
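
The attention-mask construction in BasicLayer.forward above is the standard shifted-window recipe; a standalone toy run (4x4 map, window_size=2, shift_size=1) shows how token pairs from different shifted regions receive -100:

```python
import torch

Hp = Wp = 4
window_size, shift_size = 2, 1
img_mask = torch.zeros((1, Hp, Wp, 1))
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt   # label each shifted region with an id
        cnt += 1
mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
print(attn_mask.shape)  # torch.Size([4, 4, 4]): one (N, N) additive mask per window
```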
483
+ class PatchEmbed(nn.Module):
484
+ """Image to Patch Embedding
485
+ Args:
486
+ patch_size (int): Patch token size. Default: 4.
487
+ in_chans (int): Number of input image channels. Default: 3.
488
+ embed_dim (int): Number of linear projection output channels. Default: 96.
489
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
490
+ """
491
+
492
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
493
+ super().__init__()
494
+ patch_size = to_2tuple(patch_size)
495
+ self.patch_size = patch_size
496
+
497
+ self.in_chans = in_chans
498
+ self.embed_dim = embed_dim
499
+
500
+ self.proj = nn.Conv2d(
501
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
502
+ )
503
+ if norm_layer is not None:
504
+ self.norm = norm_layer(embed_dim)
505
+ else:
506
+ self.norm = None
507
+
508
+ def forward(self, x):
509
+ """Forward function."""
510
+ # padding
511
+ _, _, H, W = x.size()
512
+ if W % self.patch_size[1] != 0:
513
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
514
+ if H % self.patch_size[0] != 0:
515
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
516
+
517
+ x = self.proj(x) # B C Wh Ww
518
+ if self.norm is not None:
519
+ Wh, Ww = x.size(2), x.size(3)
520
+ x = x.flatten(2).transpose(1, 2)
521
+ x = self.norm(x)
522
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
523
+
524
+ return x
525
+
526
+
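
Shape sketch for PatchEmbed above: a 224x224 RGB image becomes a 56x56 grid of 96-dim patch tokens (224 / patch_size 4 = 56).

```python
import torch

pe = PatchEmbed(patch_size=4, in_chans=3, embed_dim=96)
print(pe(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 96, 56, 56])
```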
527
+ @BACKBONES.register_module()
528
+ class SwinTransformer(nn.Module):
529
+ """Swin Transformer backbone.
530
+ A PyTorch implementation of `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
531
+ https://arxiv.org/pdf/2103.14030
532
+ Args:
533
+ pretrain_img_size (int): Input image size for training the pretrained model,
534
+ used in absolute position embedding. Default: 224.
535
+ patch_size (int | tuple(int)): Patch size. Default: 4.
536
+ in_chans (int): Number of input image channels. Default: 3.
537
+ embed_dim (int): Number of linear projection output channels. Default: 96.
538
+ depths (tuple[int]): Depths of each Swin Transformer stage.
539
+ num_heads (tuple[int]): Number of attention head of each stage.
540
+ window_size (int): Window size. Default: 7.
541
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
542
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
543
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
544
+ drop_rate (float): Dropout rate.
545
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
546
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
547
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
548
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
549
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
550
+ out_indices (Sequence[int]): Output from which stages.
551
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
552
+ -1 means not freezing any parameters.
553
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
554
+ dilation (bool): If True, the final output is 16x downsampled; otherwise 32x downsampled.
555
+ """
556
+
557
+ def __init__(
558
+ self,
559
+ pretrain_img_size=224,
560
+ patch_size=4,
561
+ in_chans=3,
562
+ embed_dim=96,
563
+ depths=[2, 2, 6, 2],
564
+ num_heads=[3, 6, 12, 24],
565
+ window_size=7,
566
+ mlp_ratio=4.0,
567
+ qkv_bias=True,
568
+ qk_scale=None,
569
+ drop_rate=0.0,
570
+ attn_drop_rate=0.0,
571
+ drop_path_rate=0.2,
572
+ norm_layer=nn.LayerNorm,
573
+ ape=False,
574
+ patch_norm=True,
575
+ out_indices=(0, 1, 2, 3),
576
+ frozen_stages=-1,
577
+ dilation=False,
578
+ use_checkpoint=False,
579
+ ):
580
+ super().__init__()
581
+
582
+ self.pretrain_img_size = pretrain_img_size
583
+ self.num_layers = len(depths)
584
+ self.embed_dim = embed_dim
585
+ self.ape = ape
586
+ self.patch_norm = patch_norm
587
+ self.out_indices = out_indices
588
+ self.frozen_stages = frozen_stages
589
+ self.dilation = dilation
590
+
591
+ if use_checkpoint:
592
+ print("use_checkpoint!!!!!!!!!!!!!!!!!!!!!!!!")
593
+
594
+ # split image into non-overlapping patches
595
+ self.patch_embed = PatchEmbed(
596
+ patch_size=patch_size,
597
+ in_chans=in_chans,
598
+ embed_dim=embed_dim,
599
+ norm_layer=norm_layer if self.patch_norm else None,
600
+ )
601
+
602
+ # absolute position embedding
603
+ if self.ape:
604
+ pretrain_img_size = to_2tuple(pretrain_img_size)
605
+ patch_size = to_2tuple(patch_size)
606
+ patches_resolution = [
607
+ pretrain_img_size[0] // patch_size[0],
608
+ pretrain_img_size[1] // patch_size[1],
609
+ ]
610
+
611
+ self.absolute_pos_embed = nn.Parameter(
612
+ torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
613
+ )
614
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
615
+
616
+ self.pos_drop = nn.Dropout(p=drop_rate)
617
+
618
+ # stochastic depth
619
+ dpr = [
620
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
621
+ ] # stochastic depth decay rule
622
+
623
+ # build layers
624
+ self.layers = nn.ModuleList()
625
+ # prepare downsample list
626
+ downsamplelist = [PatchMerging for i in range(self.num_layers)]
627
+ downsamplelist[-1] = None
628
+ num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
629
+ if self.dilation:
630
+ downsamplelist[-2] = None
631
+ num_features[-1] = int(embed_dim * 2 ** (self.num_layers - 1)) // 2
632
+ for i_layer in range(self.num_layers):
633
+ layer = BasicLayer(
634
+ # dim=int(embed_dim * 2 ** i_layer),
635
+ dim=num_features[i_layer],
636
+ depth=depths[i_layer],
637
+ num_heads=num_heads[i_layer],
638
+ window_size=window_size,
639
+ mlp_ratio=mlp_ratio,
640
+ qkv_bias=qkv_bias,
641
+ qk_scale=qk_scale,
642
+ drop=drop_rate,
643
+ attn_drop=attn_drop_rate,
644
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
645
+ norm_layer=norm_layer,
646
+ # downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
647
+ downsample=downsamplelist[i_layer],
648
+ use_checkpoint=use_checkpoint,
649
+ )
650
+ self.layers.append(layer)
651
+
652
+ # num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
653
+ self.num_features = num_features
654
+
655
+ # add a norm layer for each output
656
+ for i_layer in out_indices:
657
+ layer = norm_layer(num_features[i_layer])
658
+ layer_name = f"norm{i_layer}"
659
+ self.add_module(layer_name, layer)
660
+
661
+ self._freeze_stages()
662
+
663
+ def _freeze_stages(self):
664
+ if self.frozen_stages >= 0:
665
+ self.patch_embed.eval()
666
+ for param in self.patch_embed.parameters():
667
+ param.requires_grad = False
668
+
669
+ if self.frozen_stages >= 1 and self.ape:
670
+ self.absolute_pos_embed.requires_grad = False
671
+
672
+ if self.frozen_stages >= 2:
673
+ self.pos_drop.eval()
674
+ for i in range(0, self.frozen_stages - 1):
675
+ m = self.layers[i]
676
+ m.eval()
677
+ for param in m.parameters():
678
+ param.requires_grad = False
679
+
680
+ def forward_raw(self, x: torch.Tensor) -> List[torch.Tensor]:
681
+ """Forward function."""
682
+ x = self.patch_embed(x)
683
+
684
+ Wh, Ww = x.size(2), x.size(3)
685
+ if self.ape:
686
+ # interpolate the position embedding to the corresponding size
687
+ absolute_pos_embed = F.interpolate(
688
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
689
+ )
690
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
691
+ else:
692
+ x = x.flatten(2).transpose(1, 2)
693
+ x = self.pos_drop(x)
694
+
695
+ outs = []
696
+ for i in range(self.num_layers):
697
+ layer = self.layers[i]
698
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
700
+
701
+ if i in self.out_indices:
702
+ norm_layer = getattr(self, f"norm{i}")
703
+ x_out = norm_layer(x_out)
704
+
705
+ out = (
706
+ x_out.view(-1, H, W, self.num_features[i])
707
+ .permute(0, 3, 1, 2)
708
+ .contiguous()
709
+ )
710
+ outs.append(out)
711
+
712
+ return tuple(outs)
713
+
714
+ def forward(self, tensor_list: NestedTensor) -> Dict:
715
+ """Forward function.
716
+
717
+ Args:
718
+ tensor_list (NestedTensor): NestedTensor object containing tensors and masks.
719
+
720
+ Returns:
721
+ Dict: Dict containing output tensors. The structure is as follows.
722
+ - 0: NestedTensor from stage 0.
723
+ - 1: NestedTensor from stage 1.
724
+ - 2: NestedTensor from stage 2.
725
+ - 3: NestedTensor from stage 3.
726
+ """
727
+ x = tensor_list.tensors
728
+
729
+ x = self.patch_embed(x)
730
+
731
+ Wh, Ww = x.size(2), x.size(3)
732
+ if self.ape:
733
+ # interpolate the position embedding to the corresponding size
734
+ absolute_pos_embed = F.interpolate(
735
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
736
+ )
737
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
738
+ else:
739
+ x = x.flatten(2).transpose(1, 2)
740
+ x = self.pos_drop(x)
741
+
742
+ outs = []
743
+ for i in range(self.num_layers):
744
+ layer = self.layers[i]
745
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
746
+
747
+ if i in self.out_indices:
748
+ norm_layer = getattr(self, f"norm{i}")
749
+ x_out = norm_layer(x_out)
750
+
751
+ out = (
752
+ x_out.view(-1, H, W, self.num_features[i])
753
+ .permute(0, 3, 1, 2)
754
+ .contiguous()
755
+ )
756
+ outs.append(out)
757
+
758
+ # collect for nesttensors
759
+ outs_dict = {}
760
+ for idx, out_i in enumerate(outs):
761
+ m = tensor_list.mask
762
+ assert m is not None
763
+ mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[
764
+ 0
765
+ ]
766
+ outs_dict[idx] = NestedTensor(out_i, mask)
767
+
768
+ return outs_dict
769
+
770
+ def train(self, mode=True):
771
+ """Convert the model into training mode while keep layers freezed."""
772
+ super(SwinTransformer, self).train(mode)
773
+ self._freeze_stages()
774
+
775
+
776
+ def build_swin_transformer(modelname, pretrain_img_size, **kw):
777
+ assert modelname in [
778
+ "swin_T_224_1k",
779
+ "swin_B_224_22k",
780
+ "swin_B_384_22k",
781
+ "swin_L_224_22k",
782
+ "swin_L_384_22k",
783
+ ]
784
+
785
+ model_para_dict = {
786
+ "swin_T_224_1k": dict(
787
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7
788
+ ),
789
+ "swin_B_224_22k": dict(
790
+ embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=7
791
+ ),
792
+ "swin_B_384_22k": dict(
793
+ embed_dim=128,
794
+ depths=[2, 2, 18, 2],
795
+ num_heads=[4, 8, 16, 32],
796
+ window_size=12,
797
+ ),
798
+ "swin_L_224_22k": dict(
799
+ embed_dim=192,
800
+ depths=[2, 2, 18, 2],
801
+ num_heads=[6, 12, 24, 48],
802
+ window_size=7,
803
+ ),
804
+ "swin_L_384_22k": dict(
805
+ embed_dim=192,
806
+ depths=[2, 2, 18, 2],
807
+ num_heads=[6, 12, 24, 48],
808
+ window_size=12,
809
+ ),
810
+ }
811
+ kw_cfg = model_para_dict[modelname]
812
+ kw_cfg.update(kw)
813
+ model = SwinTransformer(pretrain_img_size=pretrain_img_size, **kw_cfg)
814
+ return model
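
A usage sketch for the factory above (weights are randomly initialized here; shapes assume the default out_indices=(0, 1, 2, 3) and a 224x224 input):

```python
import torch

model = build_swin_transformer("swin_T_224_1k", pretrain_img_size=224)
feats = model.forward_raw(torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in feats])
# [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)]
```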
detect_tools/upn/models/backbone/wrapper.py ADDED
@@ -0,0 +1,297 @@
1
+ from typing import Dict, List, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from detect_tools.upn import BACKBONES, build_backbone, build_position_embedding
7
+ from detect_tools.upn.models.module import NestedTensor
8
+ from detect_tools.upn.models.utils import clean_state_dict
9
+
10
+
11
+ class FrozenBatchNorm2d(torch.nn.Module):
12
+ """
13
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
14
+
15
+ Copy-paste from torchvision.misc.ops with added eps before rsqrt,
16
+ without which any other models than torchvision.models.resnet[18,34,50,101]
17
+ produce nans.
18
+ """
19
+
20
+ def __init__(self, n):
21
+ super(FrozenBatchNorm2d, self).__init__()
22
+ self.register_buffer("weight", torch.ones(n))
23
+ self.register_buffer("bias", torch.zeros(n))
24
+ self.register_buffer("running_mean", torch.zeros(n))
25
+ self.register_buffer("running_var", torch.ones(n))
26
+
27
+ def _load_from_state_dict(
28
+ self,
29
+ state_dict,
30
+ prefix,
31
+ local_metadata,
32
+ strict,
33
+ missing_keys,
34
+ unexpected_keys,
35
+ error_msgs,
36
+ ):
37
+ num_batches_tracked_key = prefix + "num_batches_tracked"
38
+ if num_batches_tracked_key in state_dict:
39
+ del state_dict[num_batches_tracked_key]
40
+
41
+ super(FrozenBatchNorm2d, self)._load_from_state_dict(
42
+ state_dict,
43
+ prefix,
44
+ local_metadata,
45
+ strict,
46
+ missing_keys,
47
+ unexpected_keys,
48
+ error_msgs,
49
+ )
50
+
51
+ def forward(self, x):
52
+ # move reshapes to the beginning
53
+ # to make it fuser-friendly
54
+ w = self.weight.reshape(1, -1, 1, 1)
55
+ b = self.bias.reshape(1, -1, 1, 1)
56
+ rv = self.running_var.reshape(1, -1, 1, 1)
57
+ rm = self.running_mean.reshape(1, -1, 1, 1)
58
+ eps = 1e-5
59
+ scale = w * (rv + eps).rsqrt()
60
+ bias = b - rm * scale
61
+ return x * scale + bias
62
+
63
+
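
A sanity sketch for FrozenBatchNorm2d above: with the same buffers it reproduces an eval-mode nn.BatchNorm2d, since both use eps=1e-5.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(8).eval()
fbn = FrozenBatchNorm2d(8)
# copy the BN buffers; num_batches_tracked is not a buffer of the frozen module
fbn.load_state_dict(
    {k: v for k, v in bn.state_dict().items() if k != "num_batches_tracked"}
)
x = torch.randn(2, 8, 4, 4)
assert torch.allclose(bn(x), fbn(x), atol=1e-5)
```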
64
+ class Joiner(nn.Module):
65
+ """A wrapper for the backbone and the position embedding.
66
+
67
+ Args:
68
+ backbone (nn.Module): backbone module producing a dict of NestedTensor features.
69
+ position_embedding (nn.Module): position embedding module applied per feature level.
70
+ """
71
+
72
+ def __init__(self, backbone: nn.Module, position_embedding: nn.Module) -> None:
73
+ super().__init__()
74
+ self.backbone = backbone
75
+ self.pos_embed = position_embedding
76
+
77
+ def forward(
78
+ self, tensor_list: NestedTensor
79
+ ) -> Tuple[List[NestedTensor], List[torch.Tensor]]:
80
+ """Forward function.
81
+
82
+ Args:
83
+ tensor_list (NestedTensor): NestedTensor wrapping the input tensor.
84
+
85
+ Returns:
86
+ List[NestedTensor]: A list of feature maps in NestedTensor format.
87
+ List[torch.Tensor]: A list of position encodings.
88
+ """
89
+
90
+ xs = self.backbone(tensor_list)
91
+ out: List[NestedTensor] = []
92
+ pos = []
93
+ for layer_idx, x in xs.items():
94
+ out.append(x)
95
+ # position encoding
96
+ pos.append(self.pos_embed(x).to(x.tensors.dtype))
97
+
98
+ return out, pos
99
+
100
+ def forward_pos_embed_only(self, x: NestedTensor) -> torch.Tensor:
101
+ """Forward function for position embedding only. This is used to generate additional layer
102
+
103
+ Args:
104
+ x (NestedTensor): NestedTensor wrapping the input tensor.
105
+
106
+ Returns:
107
+ torch.Tensor: The position encoding.
108
+ """
109
+ return self.pos_embed(x)
110
+
111
+
112
+ @BACKBONES.register_module()
113
+ class SwinWrapper(nn.Module):
114
+ """A wrapper for swin transformer.
115
+
116
+ Args:
117
+ backbone_cfg (Union[Dict, str]): Config dict to build backbone. If given a str name, we
118
+ will call `get_swin_config` to get the config dict.
119
+ dilation (bool): Whether to use dilation in stage 4.
120
+ position_embedding_cfg (Dict): Config dict to build position embedding.
121
+ lr_backbone (float): Learning rate of the backbone.
122
+ return_interm_layers (List[int]): Which layers to return.
123
+ backbone_freeze_keywords (List[str]): List of keywords to freeze the backbone.
124
+ use_checkpoint (bool): Whether to use checkpoint. Default: False.
125
+ ckpt_path (str): Checkpoint path. Default: None.
126
+ use_pretrained_ckpt (bool): Whether to use pretrained checkpoint. Default: True.
127
+ """
128
+
129
+ def __init__(
130
+ self,
131
+ backbone_cfg: Union[Dict, str],
132
+ dilation: bool,
133
+ position_embedding_cfg: Dict,
134
+ lr_backbone: float,
135
+ return_interm_indices: List[int],
136
+ backbone_freeze_keywords: List[str],
137
+ use_checkpoint: bool = False,
138
+ backbone_ckpt_path: str = None,
139
+ ) -> None:
140
+ super(SwinWrapper, self).__init__()
141
+ pos_embedding = build_position_embedding(position_embedding_cfg)
142
+ train_backbone = lr_backbone > 0
143
+ if not train_backbone:
144
+ raise ValueError("Please set lr_backbone > 0")
145
+ assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]
146
+
147
+ # build backbone
148
+ if isinstance(backbone_cfg, str):
149
+ assert (
150
+ backbone_cfg
151
+ in [
153
+ "swin_T_224_1k",
154
+ "swin_B_224_22k",
155
+ "swin_B_384_22k",
156
+ "swin_L_224_22k",
157
+ "swin_L_384_22k",
158
+ ]
159
+ )
160
+ pretrain_img_size = int(backbone_cfg.split("_")[-2])
161
+ backbone_cfg = get_swin_config(
162
+ backbone_cfg,
163
+ pretrain_img_size,
164
+ out_indices=tuple(return_interm_indices),
165
+ dilation=dilation,
166
+ use_checkpoint=use_checkpoint,
167
+ )
168
+ backbone = build_backbone(backbone_cfg)
169
+
170
+ # freeze some layers
171
+ if backbone_freeze_keywords is not None:
172
+ for name, parameter in backbone.named_parameters():
173
+ for keyword in backbone_freeze_keywords:
174
+ if keyword in name:
175
+ parameter.requires_grad_(False)
176
+ break
177
+
178
+ # load checkpoint
179
+ if backbone_ckpt_path is not None:
180
+ print("Loading backbone checkpoint from {}".format(backbone_ckpt_path))
181
+ checkpoint = torch.load(backbone_ckpt_path, map_location="cpu")["model"]
182
+ from collections import OrderedDict
183
+
184
+ def key_select_function(keyname):
185
+ if "head" in keyname:
186
+ return False
187
+ if dilation and "layers.3" in keyname:
188
+ return False
189
+ return True
190
+
191
+ _tmp_st = OrderedDict(
192
+ {
193
+ k: v
194
+ for k, v in clean_state_dict(checkpoint).items()
195
+ if key_select_function(k)
196
+ }
197
+ )
198
+ _tmp_st_output = backbone.load_state_dict(_tmp_st, strict=False)
199
+ print(str(_tmp_st_output))
200
+
201
+ bb_num_channels = backbone.num_features[4 - len(return_interm_indices) :]
202
+ assert len(bb_num_channels) == len(
203
+ return_interm_indices
204
+ ), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"
205
+
206
+ model = Joiner(backbone, pos_embedding)
207
+ model.num_channels = bb_num_channels
208
+ self.num_channels = bb_num_channels
209
+ self.model = model
210
+
211
+ def forward(
212
+ self, tensor_list: NestedTensor
213
+ ) -> Tuple[List[NestedTensor], List[torch.Tensor]]:
214
+ """Forward function.
215
+
216
+ Args:
217
+ tensor_list (NestedTensor): NestedTensor wrapping the input tensor.
218
+
219
+ Returns:
220
+ List[NestedTensor]: A list of feature maps in NestedTensor format.
221
+ List[torch.Tensor]: A list of position encodings.
222
+ """
223
+
224
+ return self.model(tensor_list)
225
+
226
+ def forward_pos_embed_only(self, tensor_list: NestedTensor) -> torch.Tensor:
227
+ """Forward function to get position embedding only.
228
+
229
+ Args:
230
+ tensor_list (NestedTensor): NestedTensor wrapping the input tensor.
231
+
232
+ Returns:
233
+ torch.Tensor: Position embedding.
234
+ """
235
+ return self.model.forward_pos_embed_only(tensor_list)
236
+
237
+
238
+ def get_swin_config(modelname: str, pretrain_img_size: Tuple[int, int], **kw):
239
+ """Get swin config dict.
240
+
241
+ Args:
242
+ modelname (str): Name of the model.
243
+ pretrain_img_size (Tuple[int, int]): Image size of the pretrain model.
244
+ kw (Dict): Other key word arguments.
245
+
246
+ Returns:
247
+ Dict: Config dict.
249
+ """
250
+ assert modelname in [
251
+ "swin_T_224_1k",
252
+ "swin_B_224_22k",
253
+ "swin_B_384_22k",
254
+ "swin_L_224_22k",
255
+ "swin_L_384_22k",
256
+ ]
257
+ model_para_dict = {
258
+ "swin_T_224_1k": dict(
259
+ type="SwinTransformer",
260
+ embed_dim=96,
261
+ depths=[2, 2, 6, 2],
262
+ num_heads=[3, 6, 12, 24],
263
+ window_size=7,
264
+ ),
265
+ "swin_B_224_22k": dict(
266
+ type="SwinTransformer",
267
+ embed_dim=128,
268
+ depths=[2, 2, 18, 2],
269
+ num_heads=[4, 8, 16, 32],
270
+ window_size=7,
271
+ ),
272
+ "swin_B_384_22k": dict(
273
+ type="SwinTransformer",
274
+ embed_dim=128,
275
+ depths=[2, 2, 18, 2],
276
+ num_heads=[4, 8, 16, 32],
277
+ window_size=12,
278
+ ),
279
+ "swin_L_224_22k": dict(
280
+ type="SwinTransformer",
281
+ embed_dim=192,
282
+ depths=[2, 2, 18, 2],
283
+ num_heads=[6, 12, 24, 48],
284
+ window_size=7,
285
+ ),
286
+ "swin_L_384_22k": dict(
287
+ type="SwinTransformer",
288
+ embed_dim=192,
289
+ depths=[2, 2, 18, 2],
290
+ num_heads=[6, 12, 24, 48],
291
+ window_size=12,
292
+ ),
293
+ }
294
+ kw_cfg = model_para_dict[modelname]
295
+ kw_cfg.update(kw)
296
+ kw_cfg.update(dict(pretrain_img_size=pretrain_img_size))
297
+ return kw_cfg
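
Usage sketch for get_swin_config above: the returned dict carries a registry "type" key, so it can be fed straight to build_backbone.

```python
cfg = get_swin_config("swin_L_384_22k", pretrain_img_size=384, out_indices=(1, 2, 3))
print(cfg["type"], cfg["embed_dim"], cfg["window_size"])  # SwinTransformer 192 12
```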
detect_tools/upn/models/decoder/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .upn_decoder import UPNDecoder, DeformableTransformerDecoderLayer
2
+
3
+ __all__ = ["UPNDecoder", "DeformableTransformerDecoderLayer"]
detect_tools/upn/models/decoder/upn_decoder.py ADDED
@@ -0,0 +1,378 @@
1
+ from typing import Dict
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from detect_tools.upn import DECODERS, build_decoder
7
+ from detect_tools.upn.models.module import MLP
8
+ from detect_tools.upn.models.utils import (gen_sineembed_for_position,
9
+ get_activation_fn, get_clones,
10
+ inverse_sigmoid)
11
+ from detect_tools.upn.ops.modules import MSDeformAttn
12
+
13
+
14
+ @DECODERS.register_module()
15
+ class DeformableTransformerDecoderLayer(nn.Module):
16
+ """Deformable Transformer Decoder Layer. This is a modified version in Grounding DINO.
17
+ After the query is attented to the image feature, it is further attented to the text feature.
18
+ The execute order is: self_attn -> cross_attn to text -> cross_attn to image -> ffn
19
+ Args:
20
+ d_model (int): The dimension of keys/values/queries in :class:`MultiheadAttention`.
21
+ d_ffn (int): The dimension of the feedforward network model.
22
+ dropout (float): Probability of an element to be zeroed.
23
+ activation (str): Activation function in the feedforward network.
24
+ 'relu' and 'gelu' are supported.
25
+ n_levels (int): The number of levels in Multi-Scale Deformable Attention.
26
+ n_heads (int): Parallel attention heads.
27
+ n_points (int): Number of sampling points in Multi-Scale Deformable Attention.
28
+ ffn_extra_layernorm (bool): If True, add an extra layernorm after ffn.
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ d_model: int = 256,
34
+ d_ffn: int = 1024,
35
+ dropout: float = 0.1,
36
+ activation: str = "relu",
37
+ n_levels: int = 4,
38
+ n_heads: int = 8,
39
+ n_points: int = 4,
40
+ ffn_extra_layernorm: bool = False,
41
+ ) -> None:
42
+ super().__init__()
43
+
44
+ # cross attention for visual features
45
+ self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
46
+ self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
47
+ self.norm1 = nn.LayerNorm(d_model)
48
+
49
+ # self attention for query
50
+ self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
51
+ self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
52
+ self.norm2 = nn.LayerNorm(d_model)
53
+
54
+ # ffn
55
+ self.linear1 = nn.Linear(d_model, d_ffn)
56
+ self.activation = get_activation_fn(activation, d_model=d_ffn, batch_dim=1)
57
+ self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
58
+ self.linear2 = nn.Linear(d_ffn, d_model)
59
+ self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
60
+ self.norm3 = nn.LayerNorm(d_model)
61
+ if ffn_extra_layernorm:
62
+ raise NotImplementedError("ffn_extra_layernorm not implemented")
63
+ self.norm_ext = nn.LayerNorm(d_ffn)
64
+ else:
65
+ self.norm_ext = None
66
+
67
+ self.key_aware_proj = None
68
+
69
+ def rm_self_attn_modules(self):
70
+ self.self_attn = None
71
+ self.dropout2 = None
72
+ self.norm2 = None
73
+
74
+ @staticmethod
75
+ def with_pos_embed(tensor, pos):
76
+ return tensor if pos is None else tensor + pos
77
+
78
+ def forward_ffn(self, tgt):
79
+ tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
80
+
81
+ tgt = tgt + self.dropout4(tgt2)
82
+ tgt = self.norm3(tgt)
83
+ return tgt
84
+
85
+ def forward(
86
+ self,
87
+ tgt: torch.Tensor,
88
+ tgt_query_pos: torch.Tensor = None,
89
+ tgt_reference_points: torch.Tensor = None,
90
+ memory: torch.Tensor = None,
91
+ memory_key_padding_mask: torch.Tensor = None,
92
+ memory_level_start_index: torch.Tensor = None,
93
+ memory_spatial_shapes: torch.Tensor = None,
94
+ self_attn_mask: torch.Tensor = None,
95
+ cross_attn_mask: torch.Tensor = None,
96
+ ) -> torch.Tensor:
97
+ """Forward function
98
+
99
+ Args:
100
+ tgt (torch.Tensor): Input target in shape (B, T, C)
101
+ tgt_query_pos (torch.Tensor): Positional encoding of the query.
102
+ tgt_query_sine_embed (torch.Tensor): Sine positional encoding of the query. Unused.
103
+ tgt_reference_points (torch.Tensor): Reference points for the query in shape (B, T, 4).
109
+ memory_key_padding_mask (torch.Tensor): Mask for image feature in shape (B, HW)
110
+ memory_level_start_index (torch.Tensor): Starting index of each level in memory.
111
+ memory_spatial_shapes (torch.Tensor): Spatial shape of each level in memory.
112
+ memory_pos (torch.Tensor): Positional encoding of memory. Unused.
114
+ cross_attn_mask (torch.Tensor): Mask used for cross-attention.
115
+
116
+ Returns:
117
+ torch.Tensor: Output tensor in shape (B, T, C)
118
+ """
119
+ assert cross_attn_mask is None
120
+
121
+ # self attention
122
+ if self.self_attn is not None:
123
+ q = k = self.with_pos_embed(tgt, tgt_query_pos)
124
+ tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
125
+ tgt = tgt + self.dropout2(tgt2)
126
+ tgt = self.norm2(tgt)
127
+
128
+ # attend to image features
129
+ tgt2 = self.cross_attn(
130
+ self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
131
+ tgt_reference_points.transpose(0, 1).contiguous(),
132
+ memory.transpose(0, 1),
133
+ memory_spatial_shapes,
134
+ memory_level_start_index,
135
+ memory_key_padding_mask,
136
+ ).transpose(0, 1)
137
+ tgt = tgt + self.dropout1(tgt2)
138
+ tgt = self.norm1(tgt)
139
+ # ffn
140
+ tgt = self.forward_ffn(tgt)
141
+
142
+ return tgt
143
+
144
+
145
+ @DECODERS.register_module()
146
+ class UPNDecoder(nn.Module):
147
+ """Decoder used in UPN. Each layer is a DeformableTransformerDecoderLayer. The query
148
+ is able to attend to the image features. The execution order is:
149
+ self_attn -> cross_attn to image -> ffn
150
+
151
+ Args:
152
+ decoder_layer_cfg (Dict): Config for the DeformableTransformerDecoderLayer.
153
+ num_layers (int): number of layers
154
+ norm (str, optional): type of the final normalization layer; only "layernorm" is supported. Defaults to "layernorm".
155
+ return_intermediate (bool, optional): whether to return intermediate results.
156
+ Defaults to True.
157
+ d_model (int, optional): dimension of the model. Defaults to 256.
158
+ query_dim (int, optional): dimension of the query. Defaults to 4.
159
+ modulate_hw_attn (bool, optional): whether modulate the attention weights
160
+ by the height and width of the image feature. Defaults to False.
161
+ num_feature_levels (int, optional): number of feature levels. Defaults to 1.
162
+ deformable_decoder (bool, optional): whether to use the deformable decoder. Defaults to True.
163
+ decoder_query_perturber (callable, optional): hook to perturb the query reference points. Defaults to None.
164
+ dec_layer_number (list[int], optional): per-layer query counts; must match num_layers. Defaults to None.
165
+ rm_dec_query_scale (bool, optional): whether to remove the per-layer query scale MLP. Defaults to True.
166
+ dec_layer_share (bool, optional): whether all decoder layers share weights. Defaults to False.
167
+ dec_layer_dropout_prob (list[float], optional): per-layer dropout probabilities in [0, 1]. Defaults to None.
168
+ """
169
+
170
+ def __init__(
171
+ self,
172
+ decoder_layer_cfg: Dict,
173
+ num_layers: int,
174
+ norm: str = "layernorm",
175
+ return_intermediate: bool = True,
176
+ d_model: int = 256,
177
+ query_dim: int = 4,
178
+ modulate_hw_attn: bool = False,
179
+ num_feature_levels: int = 1,
180
+ deformable_decoder: bool = True,
181
+ decoder_query_perturber=None,
182
+ dec_layer_number=None,
183
+ rm_dec_query_scale: bool = True,
184
+ dec_layer_share: bool = False,
185
+ dec_layer_dropout_prob=None,
186
+ use_detached_boxes_dec_out: bool = False,
187
+ ):
188
+ super().__init__()
189
+
190
+ decoder_layer = build_decoder(decoder_layer_cfg)
191
+ if num_layers > 0:
192
+ self.layers = get_clones(
193
+ decoder_layer, num_layers, layer_share=dec_layer_share
194
+ )
195
+ else:
196
+ self.layers = []
197
+ self.num_layers = num_layers
198
+ if norm == "layernorm":
199
+ self.norm = nn.LayerNorm(d_model)
+ else:
+ raise NotImplementedError(f"norm type {norm} is not supported")
200
+ self.return_intermediate = return_intermediate
201
+ self.query_dim = query_dim
202
+ assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim)
203
+ self.num_feature_levels = num_feature_levels
204
+ self.use_detached_boxes_dec_out = use_detached_boxes_dec_out
205
+
206
+ self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
207
+ self.ref_point_head_point = MLP(
208
+ d_model, d_model, d_model, 2
209
+ ) # for point reference only
210
+ if not deformable_decoder:
211
+ self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2)
212
+ else:
213
+ self.query_pos_sine_scale = None
214
+
215
+ if rm_dec_query_scale:
216
+ self.query_scale = None
217
+ else:
218
+ raise NotImplementedError("rm_dec_query_scale=False is not supported")
220
+ self.bbox_embed = None
221
+ self.class_embed = None
222
+
223
+ self.d_model = d_model
224
+ self.modulate_hw_attn = modulate_hw_attn
225
+ self.deformable_decoder = deformable_decoder
226
+
227
+ if not deformable_decoder and modulate_hw_attn:
228
+ self.ref_anchor_head = MLP(d_model, d_model, 2, 2)
229
+ else:
230
+ self.ref_anchor_head = None
231
+
232
+ self.decoder_query_perturber = decoder_query_perturber
233
+ self.box_pred_damping = None
234
+
235
+ self.dec_layer_number = dec_layer_number
236
+ if dec_layer_number is not None:
237
+ assert isinstance(dec_layer_number, list)
238
+ assert len(dec_layer_number) == num_layers
239
+
240
+ self.dec_layer_dropout_prob = dec_layer_dropout_prob
241
+ if dec_layer_dropout_prob is not None:
242
+ assert isinstance(dec_layer_dropout_prob, list)
243
+ assert len(dec_layer_dropout_prob) == num_layers
244
+ for i in dec_layer_dropout_prob:
245
+ assert 0.0 <= i <= 1.0
246
+
247
+ self.rm_detach = None
248
+
249
+ def forward(
250
+ self,
251
+ tgt: torch.Tensor,
252
+ memory: torch.Tensor,
253
+ tgt_mask: torch.Tensor = None,
254
+ memory_mask: torch.Tensor = None,
255
+ tgt_key_padding_mask: torch.Tensor = None,
256
+ memory_key_padding_mask: torch.Tensor = None,
257
+ pos: torch.Tensor = None,
258
+ refpoints_unsigmoid: torch.Tensor = None,
259
+ level_start_index: torch.Tensor = None,
260
+ spatial_shapes: torch.Tensor = None,
261
+ valid_ratios: torch.Tensor = None,
262
+ memory_ref_image: torch.Tensor = None,
263
+ refImg_padding_mask: torch.Tensor = None,
264
+ memory_visual_prompt: torch.Tensor = None,
265
+ ):
266
+ """Forward function.
267
+
268
+ Args:
269
+ tgt (torch.Tensor): target feature, [bs, num_queries, d_model]
270
+ memory (torch.Tensor): Image feature, [bs, hw, d_model]
271
+ tgt_mask (torch.Tensor, optional): target mask for attention. Defaults to None.
272
+ memory_mask (torch.Tensor, optional): image mask for attention. Defaults to None.
273
+ tgt_key_padding_mask (torch.Tensor, optional): target mask for padding. Defaults to None.
274
+ memory_key_padding_mask (torch.Tensor, optional): image mask for padding. Defaults to None.
275
+ pos (torch.Tensor, optional): query position embedding
276
+ refpoints_unsigmoid (torch.Tensor, optional): reference points. Defaults to None.
277
+ level_start_index (torch.Tensor, optional): start index of each level. Defaults to None.
278
+ spatial_shapes (torch.Tensor, optional): spatial shape of each level. Defaults to None.
279
+ valid_ratios (torch.Tensor, optional): valid ratio of each level. Defaults to None.
280
+ memory_ref_image (torch.Tensor, optional): reference image feature, [bs, num_ref, d_model]. Defaults to None.
281
+ refImg_padding_mask (torch.Tensor, optional): padding mask for attention. Defaults to None.
282
+ """
283
+ output = tgt
284
+
285
+ intermediate = []
286
+ reference_points = refpoints_unsigmoid.sigmoid()
287
+ ref_points = [reference_points]
288
+
289
+ for layer_id, layer in enumerate(self.layers):
290
+
291
+ if reference_points.shape[-1] == 4:
292
+ reference_points_input = (
293
+ reference_points[:, :, None]
294
+ * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
295
+ ) # nq, bs, nlevel, 4
296
+ else:
297
+ assert reference_points.shape[-1] == 2
298
+ reference_points_input = (
299
+ reference_points[:, :, None] * valid_ratios[None, :]
300
+ )
301
+ query_sine_embed = gen_sineembed_for_position(
302
+ reference_points_input[:, :, 0, :]
303
+ ) # nq, bs, 256*2
304
+
305
+ # conditional query
306
+ if query_sine_embed.shape[-1] == 512:
307
+ raw_query_pos = (
308
+ self.ref_point_head(query_sine_embed)
309
+ + self.ref_point_head_point(
310
+ torch.zeros_like(query_sine_embed)[:, :, :256]
311
+ )
312
+ * 0.0
313
+ )
314
+ else:
315
+ raw_query_pos = (
316
+ self.ref_point_head_point(query_sine_embed)
317
+ + self.ref_point_head(
318
+ torch.zeros(
319
+ query_sine_embed.shape[0],
320
+ query_sine_embed.shape[1],
321
+ 512,
322
+ device=query_sine_embed.device,
323
+ )
324
+ )
325
+ * 0.0
326
+ )
327
+ pos_scale = self.query_scale(output) if self.query_scale is not None else 1
328
+ query_pos = pos_scale * raw_query_pos
329
+
330
+ # main process
331
+ output = layer(
332
+ tgt=output,
333
+ tgt_query_pos=query_pos,
334
+ tgt_reference_points=reference_points_input,
335
+ memory=memory,
336
+ memory_key_padding_mask=memory_key_padding_mask,
337
+ memory_level_start_index=level_start_index,
338
+ memory_spatial_shapes=spatial_shapes,
339
+ self_attn_mask=tgt_mask,
340
+ cross_attn_mask=memory_mask,
341
+ )
342
+ if output.isnan().any() | output.isinf().any():
343
+ print(f"output layer_id {layer_id} is nan")
344
+ try:
345
+ num_nan = output.isnan().sum().item()
346
+ num_inf = output.isinf().sum().item()
347
+ print(f"num_nan {num_nan}, num_inf {num_inf}")
348
+ except Exception as e:
349
+ print(e)
350
+
351
+ # iter update
352
+ if self.bbox_embed is not None:
353
+
354
+ reference_before_sigmoid = inverse_sigmoid(reference_points)
355
+ delta_unsig = self.bbox_embed[layer_id](output)
356
+ outputs_unsig = delta_unsig + reference_before_sigmoid
357
+ new_reference_points = outputs_unsig.sigmoid()
358
+
359
+ if self.rm_detach and "dec" in self.rm_detach:
360
+ reference_points = new_reference_points
361
+ else:
362
+ reference_points = new_reference_points.detach()
363
+
364
+ if self.use_detached_boxes_dec_out:
365
+ ref_points.append(reference_points)
366
+ else:
367
+ ref_points.append(new_reference_points)
368
+
369
+ if self.return_intermediate:
370
+ intermediate.append(self.norm(output))
371
+
372
+ if self.return_intermediate:
373
+ return [
374
+ [itm_out.transpose(0, 1) for itm_out in intermediate],
375
+ [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points],
376
+ ]
377
+ else:
378
+ return self.norm(output).transpose(0, 1)
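
The per-layer reference-point scaling in UPNDecoder.forward above is pure broadcasting; a standalone shape sketch:

```python
import torch

nq, bs, n_levels = 300, 2, 4
reference_points = torch.rand(nq, bs, 4)    # sigmoid-space (cx, cy, w, h)
valid_ratios = torch.rand(bs, n_levels, 2)  # per-level (w_ratio, h_ratio)
ref_input = (
    reference_points[:, :, None]
    * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
)
print(ref_input.shape)  # torch.Size([300, 2, 4, 4]): nq, bs, n_levels, 4
```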
detect_tools/upn/models/encoder/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .upn_encoder import DeformableTransformerEncoderLayer, UPNEncoder
2
+
3
+ __all__ = ["UPNEncoder", "DeformableTransformerEncoderLayer"]
detect_tools/upn/models/encoder/upn_encoder.py ADDED
@@ -0,0 +1,288 @@
1
+ from typing import Dict
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.utils.checkpoint as checkpoint
6
+
7
+ from detect_tools.upn import ENCODERS, build_encoder
8
+ from detect_tools.upn.models.utils import get_activation_fn, get_clones
9
+ from detect_tools.upn.ops.modules import MSDeformAttn
10
+
11
+
12
+ @ENCODERS.register_module()
13
+ class DeformableTransformerEncoderLayer(nn.Module):
14
+ """Deformable Transformer Encoder Layer.
15
+
16
+ Args:
17
+ d_model (int): The dimension of keys/values/queries in
18
+ :class:`MultiheadAttention`.
19
+ d_ffn (int): The dimension of the feedforward network model.
20
+ dropout (float): Probability of an element to be zeroed.
21
+ activation (str): Activation function in the feedforward network.
22
+ 'relu' and 'gelu' are supported.
23
+ n_levels (int): The number of levels in Multi-Scale Deformable Attention.
24
+ n_heads (int): Parallel attention heads.
25
+ n_points (int): Number of sampling points in Multi-Scale Deformable Attention.
26
+ add_channel_attention (bool): If True, add channel attention.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ d_model: int = 256,
32
+ d_ffn: int = 1024,
33
+ dropout: float = 0.1,
34
+ activation: str = "relu",
35
+ n_levels: int = 4,
36
+ n_heads: int = 8,
37
+ n_points: int = 4,
38
+ add_channel_attention: bool = False,
39
+ ) -> None:
40
+ super().__init__()
41
+
42
+ # self attention
43
+ self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
44
+ self.dropout1 = nn.Dropout(dropout)
45
+ self.norm1 = nn.LayerNorm(d_model)
46
+
47
+ # ffn
48
+ self.linear1 = nn.Linear(d_model, d_ffn)
49
+ self.activation = get_activation_fn(activation, d_model=d_ffn)
50
+ self.dropout2 = nn.Dropout(dropout)
51
+ self.linear2 = nn.Linear(d_ffn, d_model)
52
+ self.dropout3 = nn.Dropout(dropout)
53
+ self.norm2 = nn.LayerNorm(d_model)
54
+
55
+ # channel attention
56
+ self.add_channel_attention = add_channel_attention
57
+ if add_channel_attention:
58
+ self.activ_channel = get_activation_fn("dyrelu", d_model=d_model)
59
+ self.norm_channel = nn.LayerNorm(d_model)
60
+
61
+ @staticmethod
62
+ def with_pos_embed(tensor, pos):
63
+ return tensor if pos is None else tensor + pos
64
+
65
+ def forward_ffn(self, src: torch.Tensor) -> torch.Tensor:
66
+ src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
67
+ src = src + self.dropout3(src2)
68
+ src = self.norm2(src)
69
+ return src
70
+
71
+ def forward(
72
+ self,
73
+ src: torch.Tensor,
74
+ pos: torch.Tensor,
75
+ reference_points: torch.Tensor,
76
+ spatial_shapes: torch.Tensor,
77
+ level_start_index: torch.Tensor,
78
+ key_padding_mask: torch.Tensor = None,
79
+ ) -> torch.Tensor:
80
+ """Forward function for `DeformableTransformerEncoderLayer`.
81
+
82
+ Args:
83
+ src (torch.Tensor): The input sequence of shape (S, N, E).
84
+ pos (torch.Tensor): The position embedding of shape (S, N, E).
85
+ reference_points (torch.Tensor): The reference points of shape (N, L, 2).
86
+ spatial_shapes (torch.Tensor): The spatial shapes of feature levels.
87
+ level_start_index (torch.Tensor): The start index of each level.
88
+ key_padding_mask (torch.Tensor): The mask for keys with shape (N, S).
89
+ """
90
+ # self attention
92
+ src2 = self.self_attn(
93
+ self.with_pos_embed(src, pos),
94
+ reference_points,
95
+ src,
96
+ spatial_shapes,
97
+ level_start_index,
98
+ key_padding_mask,
99
+ )
100
+ src = src + self.dropout1(src2)
101
+ src = self.norm1(src)
102
+
103
+ # ffn
104
+ src = self.forward_ffn(src)
105
+
106
+ # channel attn
107
+ if self.add_channel_attention:
108
+ src = self.norm_channel(src + self.activ_channel(src))
109
+
110
+ return src
111
+
112
+
113
+ @ENCODERS.register_module()
114
+ class UPNEncoder(nn.Module):
115
+ """Implementation of UPN Encoder.
116
+
117
+ Args:
118
+ num_layers (int): The number of layers in the TransformerEncoder.
119
+ d_model (int, optional): The dimension of the input feature. Defaults to 256.
120
+ encoder_layer_cfg (Dict): Config for the DeformableEncoderLayer.
121
+ use_checkpoint (bool, optional): Whether to use checkpoint in the fusion layer for
122
+ memory saving. Defaults to True.
123
+ use_transformer_ckpt (bool, optional): Whether to use checkpointing for the deformable encoder. Defaults to True.
124
+ enc_layer_share (bool, optional): Whether to share the same memory for the encoder_layer.
125
+ Defaults to False. This is used for all the sub-layers in the basic block.
126
+ """
127
+
128
+ def __init__(
129
+ self,
130
+ num_layers: int,
131
+ d_model: int = 256,
132
+ encoder_layer_cfg: Dict = None,
133
+ use_checkpoint: bool = True,
134
+ use_transformer_ckpt: bool = True,
135
+ enc_layer_share: bool = False,
136
+ multi_level_encoder_fusion: str = None,
137
+ ):
138
+ super().__init__()
139
+ # prepare layers
140
+ self.layers = []
141
+ self.refImg_layers = []
142
+ self.fusion_layers = []
143
+ encoder_layer = build_encoder(encoder_layer_cfg)
144
+
145
+ self.multi_level_encoder_fusion = multi_level_encoder_fusion
146
+        self._initialize_memory_fusion_layers(
147
+ multi_level_encoder_fusion, num_layers, d_model
148
+ )
149
+
150
+ if num_layers > 0:
151
+ self.layers = get_clones(
152
+ encoder_layer, num_layers, layer_share=enc_layer_share
153
+ )
154
+ else:
155
+ self.layers = []
156
+ del encoder_layer
157
+
158
+ self.query_scale = None
159
+ self.num_layers = num_layers
160
+ self.d_model = d_model
161
+
162
+ self.use_checkpoint = use_checkpoint
163
+ self.use_transformer_ckpt = use_transformer_ckpt
164
+
165
+    def _initialize_memory_fusion_layers(self, fusion_type, num_layers, d_model):
166
+ if fusion_type is None:
167
+ self.memory_fusion_layer = None
168
+ return
169
+
170
+ assert fusion_type in ["dense_net_fusion", "stable_dense_fusion"]
171
+ if fusion_type == "stable_dense_fusion":
172
+ self.memory_fusion_layer = nn.Sequential(
173
+ nn.Linear(d_model * (num_layers + 1), d_model),
174
+ nn.LayerNorm(d_model),
175
+ )
176
+ nn.init.constant_(self.memory_fusion_layer[0].bias, 0)
177
+ elif fusion_type == "dense_net_fusion":
178
+ self.memory_fusion_layer = nn.ModuleList()
179
+ for i in range(num_layers):
180
+ self.memory_fusion_layer.append(
181
+ nn.Sequential(
182
+ nn.Linear(
183
+ d_model * (i + 2), d_model
184
+ ), # from second encoder layer, 512 -> 256 / 3rd: 768 -> 256
185
+ nn.LayerNorm(d_model),
186
+ )
187
+ )
188
+ for layer in self.memory_fusion_layer:
189
+ nn.init.constant_(layer[0].bias, 0)
190
+ else:
191
+ raise NotImplementedError
192
+
193
+ @staticmethod
194
+ def get_reference_points(spatial_shapes, valid_ratios, device):
195
+ reference_points_list = []
196
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
197
+
198
+ ref_y, ref_x = torch.meshgrid(
199
+ torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
200
+ torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
201
+ )
202
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
203
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
204
+ ref = torch.stack((ref_x, ref_y), -1)
205
+ reference_points_list.append(ref)
206
+ reference_points = torch.cat(reference_points_list, 1)
207
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
208
+ return reference_points
209
+
210
+ def forward(
211
+ self,
212
+ src: torch.Tensor,
213
+ pos: torch.Tensor,
214
+ spatial_shapes: torch.Tensor,
215
+ level_start_index: torch.Tensor,
216
+ valid_ratios: torch.Tensor,
217
+ key_padding_mask: torch.Tensor = None,
218
+ ):
219
+ """Forward function
220
+
221
+ Args:
222
+ src (torch.Tensor): Flattened Image features in shape [bs, sum(hi*wi), 256]
223
+ pos (torch.Tensor): Position embedding for image feature in shape [bs, sum(hi*wi), 256]
224
+ spatial_shapes (torch.Tensor): Spatial shape of each level in shape [num_level, 2]
225
+ level_start_index (torch.Tensor): Start index of each level in shape [num_level]
226
+ valid_ratios (torch.Tensor): Valid ratio of each level in shape [bs, num_level, 2]
227
+ key_padding_mask (torch.Tensor): Padding mask for image feature in shape [bs, sum(hi*wi)]
228
+        Outputs:
+            torch.Tensor: Encoded image feature in shape [bs, sum(hi*wi), 256]
239
+ """
240
+
241
+ output = src
242
+ # preparation and reshape
243
+ if self.num_layers > 0:
244
+ reference_points = self.get_reference_points(
245
+ spatial_shapes, valid_ratios, device=src.device
246
+ )
247
+
248
+ # multi-level dense fusion
249
+ output_list = [output]
250
+ # main process
251
+ for layer_id, layer in enumerate(self.layers):
252
+ # main process
253
+ if self.use_transformer_ckpt:
254
+ output = checkpoint.checkpoint(
255
+ layer,
256
+ output,
257
+ pos,
258
+ reference_points,
259
+ spatial_shapes,
260
+ level_start_index,
261
+ key_padding_mask,
262
+ )
263
+ else:
264
+ output = layer(
265
+ src=output,
266
+ pos=pos,
267
+ reference_points=reference_points,
268
+ spatial_shapes=spatial_shapes,
269
+ level_start_index=level_start_index,
270
+ key_padding_mask=key_padding_mask,
271
+ )
272
+
273
+ output_list.append(output)
274
+ if (
275
+ self.multi_level_encoder_fusion is not None
276
+ and self.multi_level_encoder_fusion == "dense_net_fusion"
277
+ ):
278
+ output = self.memory_fusion_layer[layer_id](
279
+ torch.cat(output_list, dim=-1)
280
+ )
281
+
282
+ if (
283
+ self.multi_level_encoder_fusion is not None
284
+ and self.multi_level_encoder_fusion == "stable_dense_fusion"
285
+ ):
286
+ output = self.memory_fusion_layer(torch.cat(output_list, dim=-1))
287
+
288
+ return output
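
For orientation, here is a small, self-contained sketch (not part of the diff) of what `UPNEncoder.get_reference_points` computes; the two feature-map sizes and the batch size are made up for illustration, and no padding is assumed (valid ratios of 1):

    import torch

    # Two illustrative feature levels and a batch of one image with no padding.
    spatial_shapes = [(4, 6), (2, 3)]            # (H, W) per level
    valid_ratios = torch.ones(1, 2, 2)           # (bs, num_levels, 2); 1.0 = no padding
    device = "cpu"

    reference_points_list = []
    for lvl, (H_, W_) in enumerate(spatial_shapes):
        # Cell centers, normalized by the valid (unpadded) extent of each level.
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
            torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
        )
        ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
        ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
        reference_points_list.append(torch.stack((ref_x, ref_y), -1))

    reference_points = torch.cat(reference_points_list, 1)[:, :, None] * valid_ratios[:, None]
    print(reference_points.shape)  # torch.Size([1, 30, 2, 2]) -> (bs, sum(H*W), num_levels, 2)

Each row is the normalized center of one feature cell, expanded over levels so the deformable attention can sample offsets around it at every scale.
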
detect_tools/upn/models/module/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .contrastive import ContrastiveAssign
+ from .mlp import MLP
+ from .nested_tensor import NestedTensor, nested_tensor_from_tensor_list
+
+ __all__ = ["MLP", "NestedTensor", "nested_tensor_from_tensor_list", "ContrastiveAssign"]
detect_tools/upn/models/module/contrastive.py ADDED
@@ -0,0 +1,29 @@
1
+ from typing import Dict
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class ContrastiveAssign(nn.Module):
9
+
10
+ def __init__(
11
+ self,
12
+ cal_bias: nn.Module = None,
13
+ ) -> None:
14
+        """Language-Image Contrastive Assignment used to calculate the similarity between
15
+ the text and the image.
16
+
17
+ Args:
18
+ cal_bias (nn.Module, optional): The bias used to calculate the similarity.
19
+ Defaults to None.
21
+ """
22
+ super().__init__()
23
+ self.cal_bias = cal_bias
24
+
25
+ def forward(self, x: torch.Tensor, ref_dict: Dict):
26
+
27
+ y = ref_dict["encoded_ref_feature"]
28
+ res = x @ y.transpose(-1, -2)
29
+ return res
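
A usage sketch (not part of the diff, assuming the repository root is importable): `ContrastiveAssign` is a plain dot product between query features and the encoded reference features, so the output is one similarity logit per (query, reference) pair. All shapes below are illustrative.

    import torch
    from detect_tools.upn.models.module.contrastive import ContrastiveAssign

    assign = ContrastiveAssign()
    queries = torch.randn(2, 900, 256)                          # (bs, num_queries, d_model)
    ref_dict = {"encoded_ref_feature": torch.randn(2, 5, 256)}  # (bs, n_ref, d_model)
    logits = assign(queries, ref_dict)
    print(logits.shape)  # torch.Size([2, 900, 5])
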
detect_tools/upn/models/module/mlp.py ADDED
@@ -0,0 +1,18 @@
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
+
5
+ class MLP(nn.Module):
6
+ """ Very simple multi-layer perceptron (also called FFN)"""
7
+
8
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
9
+ super().__init__()
10
+ self.num_layers = num_layers
11
+ h = [hidden_dim] * (num_layers - 1)
12
+ self.layers = nn.ModuleList(
13
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
14
+
15
+ def forward(self, x):
16
+ for i, layer in enumerate(self.layers):
17
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
18
+ return x
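
A usage sketch (not part of the diff): the `MLP` above is the standard DETR-style prediction head; the 256 -> 256 -> 256 -> 4 box-head configuration below is only an illustrative assumption.

    import torch
    from detect_tools.upn.models.module.mlp import MLP

    # A 3-layer box head (sizes are illustrative).
    bbox_head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
    x = torch.randn(2, 900, 256)          # (bs, num_queries, d_model)
    print(bbox_head(x).shape)             # torch.Size([2, 900, 4])
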
detect_tools/upn/models/module/nested_tensor.py ADDED
@@ -0,0 +1,199 @@
1
+ from typing import List, Union
2
+
3
+ import torch
4
+ import torchvision
5
+
6
+
7
+ class NestedTensor(object):
8
+ """Define a NestedTensor class
9
+
10
+ Args:
11
+ tensors (torch.Tensor): Tensor with shape [batch, C, H, W] or [C, H, W]
12
+ mask (Union[torch.Tensor, str]): mask with shape [batch, H, W] or [H, W]. If mask
13
+ is 'auto', it will be generated automatically by summing the tensor along
14
+ the channel dimension. Mask is used to indicate the padding area.
15
+ """
16
+
17
+ def __init__(
18
+ self, tensors: torch.Tensor, mask: Union[torch.Tensor, str] = "auto"
19
+ ) -> None:
20
+ self.tensors = tensors
21
+ self.mask = mask
22
+ if mask == "auto":
23
+ self.mask = torch.zeros_like(tensors).to(tensors.device)
24
+ if self.mask.dim() == 3:
25
+ self.mask = self.mask.sum(0).to(bool)
26
+ elif self.mask.dim() == 4:
27
+ self.mask = self.mask.sum(1).to(bool)
28
+ else:
29
+ raise ValueError(
30
+ "tensors dim must be 3 or 4 but {}({})".format(
31
+ self.tensors.dim(), self.tensors.shape
32
+ )
33
+ )
34
+
35
+ def imgsize(self) -> List[torch.Tensor]:
36
+ """get the img size of the tensor
37
+
38
+ Returns:
39
+ list[torch.Tensor]: list of tensor with shape [2] which is [H, W]
40
+ """
41
+ res = []
42
+ for i in range(self.tensors.shape[0]):
43
+ mask = self.mask[i]
44
+ maxH = (~mask).sum(0).max()
45
+ maxW = (~mask).sum(1).max()
46
+ res.append(torch.Tensor([maxH, maxW]))
47
+ return res
48
+
49
+ def to(self, device: torch.device):
50
+ """Move tensors and mask to the given device
51
+
52
+ Args:
53
+ device (torch.device): device to move
54
+
55
+ Returns:
56
+ NestedTensor: moved NestedTensor
57
+ """
58
+ cast_tensor = self.tensors.to(device)
59
+ mask = self.mask
60
+ if mask is not None:
61
+ assert mask is not None
62
+ cast_mask = mask.to(device)
63
+ else:
64
+ cast_mask = None
65
+ return NestedTensor(cast_tensor, cast_mask)
66
+
67
+ def to_img_list_single(
68
+ self, tensor: torch.Tensor, mask: torch.Tensor
69
+ ) -> torch.Tensor:
70
+ """remove the padding for one image
71
+
72
+ Args:
73
+ tensor (torch.Tensor): tensor with shape [C, H, W]
74
+ mask (torch.Tensor): mask with shape [H, W]
75
+
76
+ Returns:
77
+ torch.Tensor: tensor with shape [C, maxH, maxW]
78
+ """
79
+ assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format(
80
+ tensor.dim()
81
+ )
82
+ maxH = (~mask).sum(0).max()
83
+ maxW = (~mask).sum(1).max()
84
+ img = tensor[:, :maxH, :maxW]
85
+ return img
86
+
87
+ def to_img_list(self) -> List[torch.Tensor]:
88
+ """remove the padding and convert to img list
89
+
90
+ Returns:
91
+ list[torch.Tensor]: list of tensor with shape [C, maxH, maxW]
92
+ """
93
+ if self.tensors.dim() == 3:
94
+ return self.to_img_list_single(self.tensors, self.mask)
95
+ else:
96
+ res = []
97
+ for i in range(self.tensors.shape[0]):
98
+ tensor_i = self.tensors[i]
99
+ mask_i = self.mask[i]
100
+ res.append(self.to_img_list_single(tensor_i, mask_i))
101
+ return res
102
+
103
+ @property
104
+ def device(self):
105
+ return self.tensors.device
106
+
107
+ def decompose(self):
108
+ return self.tensors, self.mask
109
+
110
+ def __repr__(self):
111
+ return str(self.tensors)
112
+
113
+ @property
114
+ def shape(self):
115
+ return {"tensors.shape": self.tensors.shape, "mask.shape": self.mask.shape}
116
+
117
+
118
+ def _max_by_axis(the_list):
119
+ # type: (List[List[int]]) -> List[int]
120
+ maxes = the_list[0]
121
+ for sublist in the_list[1:]:
122
+ for index, item in enumerate(sublist):
123
+ maxes[index] = max(maxes[index], item)
124
+ return maxes
125
+
126
+
127
+ def nested_tensor_from_tensor_list(
128
+ tensor_list: List[torch.Tensor], fixed_img_size=None
129
+ ):
130
+ if fixed_img_size is not None:
131
+ if isinstance(fixed_img_size, (list, tuple)):
132
+ assert (
133
+ len(fixed_img_size) == 2
134
+ ), "image size should be a tuple or list with two elements"
135
+ elif isinstance(fixed_img_size, int):
136
+ fixed_img_size = [fixed_img_size, fixed_img_size]
137
+
138
+ if tensor_list[0].ndim == 3:
139
+ if torchvision._is_tracing():
140
+ # nested_tensor_from_tensor_list() does not export well to ONNX
141
+ # call _onnx_nested_tensor_from_tensor_list() instead
142
+ return _onnx_nested_tensor_from_tensor_list(tensor_list)
143
+
144
+ # TODO make it support different-sized images
145
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
146
+
147
+ if fixed_img_size is not None:
148
+ c, orig_h, orig_w = max_size
149
+ assert (
150
+ orig_h <= fixed_img_size[0] and orig_w <= fixed_img_size[1]
151
+ ), f"{orig_h} {orig_w} the fixed output image size should be larger than original image"
152
+ max_size = [c, fixed_img_size[0], fixed_img_size[1]]
153
+
154
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
155
+ batch_shape = [len(tensor_list)] + max_size
156
+ b, c, h, w = batch_shape
157
+ dtype = tensor_list[0].dtype
158
+ device = tensor_list[0].device
159
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
160
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
161
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
162
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
163
+ m[: img.shape[1], : img.shape[2]] = False
164
+ else:
165
+ raise ValueError("not supported")
166
+ return NestedTensor(tensor, mask)
167
+
168
+
169
+ @torch.jit.unused
170
+ def _onnx_nested_tensor_from_tensor_list(
171
+ tensor_list: List[torch.Tensor],
172
+ ) -> NestedTensor:
173
+ max_size = []
174
+ for i in range(tensor_list[0].dim()):
175
+ max_size_i = torch.max(
176
+ torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
177
+ ).to(torch.int64)
178
+ max_size.append(max_size_i)
179
+ max_size = tuple(max_size)
180
+
181
+ padded_imgs = []
182
+ padded_masks = []
183
+ for img in tensor_list:
184
+ padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
185
+ padded_img = torch.nn.functional.pad(
186
+ img, (0, padding[2], 0, padding[1], 0, padding[0])
187
+ )
188
+ padded_imgs.append(padded_img)
189
+
190
+ m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
191
+ padded_mask = torch.nn.functional.pad(
192
+ m, (0, padding[2], 0, padding[1]), "constant", 1
193
+ )
194
+ padded_masks.append(padded_mask.to(torch.bool))
195
+
196
+ tensor = torch.stack(padded_imgs)
197
+ mask = torch.stack(padded_masks)
198
+
199
+ return NestedTensor(tensor, mask=mask)
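
A usage sketch (not part of the diff, with made-up image sizes): `nested_tensor_from_tensor_list` pads a list of differently sized images to a common shape and records the padding in a boolean mask.

    import torch
    from detect_tools.upn.models.module.nested_tensor import nested_tensor_from_tensor_list

    imgs = [torch.randn(3, 480, 640), torch.randn(3, 512, 500)]   # two differently sized images
    batch = nested_tensor_from_tensor_list(imgs)
    tensors, mask = batch.decompose()
    print(tensors.shape)  # torch.Size([2, 3, 512, 640]): padded to the per-axis maximum
    print(mask.shape)     # torch.Size([2, 512, 640]): True where a pixel is padding
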
detect_tools/upn/models/utils/__init__.py ADDED
@@ -0,0 +1,23 @@
+ from .detr_utils import (
+     PositionEmbeddingLearned,
+     PositionEmbeddingSine,
+     PositionEmbeddingSineHW,
+     clean_state_dict,
+     gen_encoder_output_proposals,
+     gen_sineembed_for_position,
+     get_activation_fn,
+     get_clones,
+     inverse_sigmoid,
+ )
+
+ __all__ = [
+     "inverse_sigmoid",
+     "gen_encoder_output_proposals",
+     "get_clones",
+     "gen_sineembed_for_position",
+     "get_activation_fn",
+     "clean_state_dict",
+     "PositionEmbeddingSine",
+     "PositionEmbeddingSineHW",
+     "PositionEmbeddingLearned",
+ ]
detect_tools/upn/models/utils/detr_utils.py ADDED
@@ -0,0 +1,415 @@
1
+ import copy
2
+ import math
3
+ from collections import OrderedDict
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+ from detect_tools.upn import POS_EMBEDDINGS
11
+ from detect_tools.upn.models.module import NestedTensor
12
+
13
+
14
+ @POS_EMBEDDINGS.register_module()
15
+ class PositionEmbeddingSine(nn.Module):
16
+ """This is a more standard version of the position embedding, very similar to the one
17
+ used by the Attention is all you need paper, generalized to work on images.
18
+
19
+ Args:
20
+ num_pos_feats (int): The channel of positional embeddings.
21
+ temperature (float): The temperature used in positional embeddings.
22
+ normalize (bool): Whether to normalize the positional embeddings.
23
+ scale (float): The scale factor of positional embeddings.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ num_pos_feats: int = 64,
29
+ temperature: int = 10000,
30
+ normalize: bool = False,
31
+ scale: float = None,
32
+ ) -> None:
33
+ super().__init__()
34
+ self.num_pos_feats = num_pos_feats
35
+ self.temperature = temperature
36
+ self.normalize = normalize
37
+ if scale is not None and normalize is False:
38
+ raise ValueError("normalize should be True if scale is passed")
39
+ if scale is None:
40
+ scale = 2 * math.pi
41
+ self.scale = scale
42
+
43
+ def forward(self, tensor_list: NestedTensor) -> torch.Tensor:
44
+ """Forward function.
45
+
46
+ Args:
47
+ tensor_list (NestedTensor): NestedTensor wrapping the input tensor.
48
+
49
+ Returns:
50
+ torch.Tensor: Positional encoding in shape (B, num_pos_feats*2, H, W)
51
+ """
52
+ x = tensor_list.tensors
53
+ mask = tensor_list.mask
54
+ assert mask is not None
55
+ not_mask = ~mask
56
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
57
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
58
+ if self.normalize:
59
+ eps = 1e-6
60
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
61
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
62
+
63
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
64
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
65
+
66
+ pos_x = x_embed[:, :, :, None] / dim_t
67
+ pos_y = y_embed[:, :, :, None] / dim_t
68
+ pos_x = torch.stack(
69
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
70
+ ).flatten(3)
71
+ pos_y = torch.stack(
72
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
73
+ ).flatten(3)
74
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
75
+ return pos
76
+
77
+
78
+ @POS_EMBEDDINGS.register_module()
79
+ class PositionEmbeddingSineHW(nn.Module):
80
+ """This is a more standard version of the position embedding, very similar to the one
81
+ used by the Attention is all you need paper, generalized to work on images.
82
+
83
+ Args:
84
+ num_pos_feats (int): The channel of positional embeddings.
85
+ temperatureH (float): The temperature used in positional embeddings.
86
+ temperatureW (float): The temperature used in positional embeddings.
87
+ normalize (bool): Whether to normalize the positional embeddings.
88
+ scale (float): The scale factor of positional embeddings.
89
+ """
90
+
91
+ def __init__(
92
+ self,
93
+ num_pos_feats: int = 64,
94
+ temperatureH: int = 10000,
95
+ temperatureW: int = 10000,
96
+ normalize: bool = False,
97
+ scale: float = None,
98
+ ) -> None:
99
+ super().__init__()
100
+ self.num_pos_feats = num_pos_feats
101
+ self.temperatureH = temperatureH
102
+ self.temperatureW = temperatureW
103
+ self.normalize = normalize
104
+ if scale is not None and normalize is False:
105
+ raise ValueError("normalize should be True if scale is passed")
106
+ if scale is None:
107
+ scale = 2 * math.pi
108
+ self.scale = scale
109
+
110
+ def forward(self, tensor_list: NestedTensor) -> torch.Tensor:
111
+ """Forward function.
112
+
113
+ Args:
114
+ tensor_list (NestedTensor): NestedTensor wrapping the input tensor.
115
+
116
+ Returns:
117
+ torch.Tensor: Positional encoding in shape (B, num_pos_feats*2, H, W)
118
+ """
119
+ x = tensor_list.tensors
120
+ mask = tensor_list.mask
121
+ assert mask is not None
122
+ not_mask = ~mask
123
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
124
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
125
+
126
+ if self.normalize:
127
+ eps = 1e-6
128
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
129
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
130
+
131
+ dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
132
+ dim_tx = self.temperatureW ** (2 * (dim_tx // 2) / self.num_pos_feats)
133
+ pos_x = x_embed[:, :, :, None] / dim_tx
134
+
135
+ dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
136
+ dim_ty = self.temperatureH ** (2 * (dim_ty // 2) / self.num_pos_feats)
137
+ pos_y = y_embed[:, :, :, None] / dim_ty
138
+
139
+ pos_x = torch.stack(
140
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
141
+ ).flatten(3)
142
+ pos_y = torch.stack(
143
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
144
+ ).flatten(3)
145
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
146
+
147
+ return pos
148
+
149
+
150
+ @POS_EMBEDDINGS.register_module()
151
+ class PositionEmbeddingLearned(nn.Module):
152
+ """Absolute pos embedding, learned.
153
+
154
+ Args:
155
+ num_pos_feats (int): The channel dimension of positional embeddings.
156
+ num_row (int): The number of rows of the input feature map.
157
+ num_col (int): The number of columns of the input feature map.
158
+ """
159
+
160
+ def __init__(
161
+ self, num_row: int = 50, num_col: int = 50, num_pos_feats: int = 256
162
+ ) -> None:
163
+ super().__init__()
164
+ self.row_embed = nn.Embedding(num_row, num_pos_feats)
165
+ self.col_embed = nn.Embedding(num_col, num_pos_feats)
166
+ self.reset_parameters()
167
+
168
+ def reset_parameters(self):
169
+ nn.init.uniform_(self.row_embed.weight)
170
+ nn.init.uniform_(self.col_embed.weight)
171
+
172
+ def forward(self, tensor_list: NestedTensor) -> torch.Tensor:
173
+ """Forward function.
174
+
175
+ Args:
176
+ tensor_list (NestedTensor): NestedTensor wrapping the input tensor.
177
+
178
+ Returns:
179
+ torch.Tensor: Positional encoding in shape (B, num_pos_feats*2, H, W)
180
+ """
181
+ x = tensor_list.tensors
182
+ h, w = x.shape[-2:]
183
+ i = torch.arange(w, device=x.device)
184
+ j = torch.arange(h, device=x.device)
185
+ x_emb = self.col_embed(i)
186
+ y_emb = self.row_embed(j)
187
+ pos = (
188
+ torch.cat(
189
+ [
190
+ x_emb.unsqueeze(0).repeat(h, 1, 1),
191
+ y_emb.unsqueeze(1).repeat(1, w, 1),
192
+ ],
193
+ dim=-1,
194
+ )
195
+ .permute(2, 0, 1)
196
+ .unsqueeze(0)
197
+ .repeat(x.shape[0], 1, 1, 1)
198
+ )
199
+ return pos
200
+
201
+
202
+ def build_position_encoding(args):
203
+ N_steps = args.hidden_dim // 2
204
+ if args.position_embedding in ("v2", "sine"):
205
+ # TODO find a better way of exposing other arguments
206
+ position_embedding = PositionEmbeddingSineHW(
207
+ N_steps,
208
+ temperatureH=args.pe_temperatureH,
209
+ temperatureW=args.pe_temperatureW,
210
+ normalize=True,
211
+ )
212
+ elif args.position_embedding in ("v3", "learned"):
213
+ position_embedding = PositionEmbeddingLearned(N_steps)
214
+ else:
215
+ raise ValueError(f"not supported {args.position_embedding}")
216
+
217
+ return position_embedding
218
+
219
+
220
+ def clean_state_dict(state_dict):
221
+ new_state_dict = OrderedDict()
222
+ for k, v in state_dict.items():
223
+ if k[:7] == "module.":
224
+ k = k[7:] # remove `module.`
225
+ new_state_dict[k] = v
226
+ return new_state_dict
227
+
228
+
229
+ def get_activation_fn(activation: str, d_model: int = 256, batch_dim: int = 0):
230
+ """Return an activation function given a string
231
+
232
+ Args:
233
+ activation (str): activation function name
234
+ d_model (int, optional): d_model. Defaults to 256.
235
+ batch_dim (int, optional): batch dimension. Defaults to 0.
236
+
237
+ Returns:
238
+ F: activation function
239
+ """
240
+ if activation == "relu":
241
+ return F.relu
242
+ if activation == "gelu":
243
+ return F.gelu
244
+ if activation == "glu":
245
+ return F.glu
246
+ if activation == "prelu":
247
+ return nn.PReLU()
248
+ if activation == "selu":
249
+ return F.selu
250
+
251
+    raise RuntimeError(f"activation should be relu/gelu/glu/prelu/selu, not {activation}.")
252
+
253
+
254
+ def get_clones(module: nn.Module, N: int, layer_share: bool = False):
255
+ """Copy module N times
256
+
257
+ Args:
258
+ module (nn.Module): module to copy
259
+ N (int): number of copies
260
+ layer_share (bool, optional): share the same layer. If true, the modules will
261
+ share the same memory. Defaults to False.
262
+ """
263
+ if layer_share:
264
+ return nn.ModuleList([module for _ in range(N)])
265
+ else:
266
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
267
+
268
+
269
+ def inverse_sigmoid(x, eps=1e-3):
270
+ x = x.clamp(min=0, max=1)
271
+ x1 = x.clamp(min=eps)
272
+ x2 = (1 - x).clamp(min=eps)
273
+ return torch.log(x1 / x2)
274
+
275
+
276
+ def gen_sineembed_for_position(pos_tensor):
277
+ # n_query, bs, _ = pos_tensor.size()
278
+ # sineembed_tensor = torch.zeros(n_query, bs, 256)
279
+ scale = 2 * math.pi
280
+ dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
281
+ dim_t = 10000 ** (2 * (dim_t // 2) / 128)
282
+ x_embed = pos_tensor[:, :, 0] * scale
283
+ y_embed = pos_tensor[:, :, 1] * scale
284
+ pos_x = x_embed[:, :, None] / dim_t
285
+ pos_y = y_embed[:, :, None] / dim_t
286
+ pos_x = torch.stack(
287
+ (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3
288
+ ).flatten(2)
289
+ pos_y = torch.stack(
290
+ (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3
291
+ ).flatten(2)
292
+ if pos_tensor.size(-1) == 2:
293
+ pos = torch.cat((pos_y, pos_x), dim=2)
294
+ elif pos_tensor.size(-1) == 4:
295
+ w_embed = pos_tensor[:, :, 2] * scale
296
+ pos_w = w_embed[:, :, None] / dim_t
297
+ pos_w = torch.stack(
298
+ (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3
299
+ ).flatten(2)
300
+
301
+ h_embed = pos_tensor[:, :, 3] * scale
302
+ pos_h = h_embed[:, :, None] / dim_t
303
+ pos_h = torch.stack(
304
+ (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3
305
+ ).flatten(2)
306
+
307
+ pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
308
+ else:
309
+ raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
310
+ return pos
311
+
312
+
313
+ def get_sine_pos_embed(
314
+ pos_tensor: torch.Tensor,
315
+ num_pos_feats: int = 128,
316
+ temperature: int = 10000,
317
+ exchange_xy: bool = True,
318
+ ):
319
+ """generate sine position embedding from a position tensor
320
+ Args:
321
+ pos_tensor (torch.Tensor): shape: [..., n].
322
+ num_pos_feats (int): projected shape for each float in the tensor.
323
+ temperature (int): temperature in the sine/cosine function.
324
+ exchange_xy (bool, optional): exchange pos x and pos y. \
325
+ For example, input tensor is [x,y], the results will be [pos(y), pos(x)]. Defaults to True.
326
+ Returns:
327
+ pos_embed (torch.Tensor): shape: [..., n*num_pos_feats].
328
+ """
329
+ scale = 2 * math.pi
330
+ dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos_tensor.device)
331
+ dim_t = temperature ** (
332
+ 2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats
333
+ )
334
+
335
+ def sine_func(x: torch.Tensor):
336
+ sin_x = x * scale / dim_t
337
+ sin_x = torch.stack(
338
+ (sin_x[..., 0::2].sin(), sin_x[..., 1::2].cos()), dim=3
339
+ ).flatten(2)
340
+ return sin_x
341
+
342
+ pos_res = [
343
+ sine_func(x) for x in pos_tensor.split([1] * pos_tensor.shape[-1], dim=-1)
344
+ ]
345
+ if exchange_xy:
346
+ pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
347
+ pos_res = torch.cat(pos_res, dim=-1)
348
+ return pos_res
349
+
350
+
351
+ def gen_encoder_output_proposals(
352
+ memory: torch.Tensor,
353
+ memory_padding_mask: torch.Tensor,
354
+ spatial_shapes: torch.Tensor,
355
+ learnedwh=None,
356
+ ):
357
+ """
358
+ Input:
359
+ - memory: bs, \sum{hw}, d_model
360
+ - memory_padding_mask: bs, \sum{hw}
361
+ - spatial_shapes: nlevel, 2
362
+ - learnedwh: 2
363
+ Output:
364
+ - output_memory: bs, \sum{hw}, d_model
365
+ - output_proposals: bs, \sum{hw}, 4
366
+ """
367
+ N_, S_, C_ = memory.shape
368
+ base_scale = 4.0
369
+ proposals = []
370
+ _cur = 0
371
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
372
+ mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view(
373
+ N_, H_, W_, 1
374
+ )
375
+ valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
376
+ valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
377
+ grid_y, grid_x = torch.meshgrid(
378
+ torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
379
+ torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
380
+ )
381
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2
382
+
383
+ scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(
384
+ N_, 1, 1, 2
385
+ )
386
+ grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
387
+
388
+ if learnedwh is not None:
389
+ wh = torch.ones_like(grid) * learnedwh.sigmoid() * (2.0**lvl)
390
+ else:
391
+ wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
392
+
393
+ proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
394
+ proposals.append(proposal)
395
+ _cur += H_ * W_
396
+
397
+ output_proposals = torch.cat(proposals, 1)
398
+ output_proposals_valid = (
399
+ (output_proposals > 0.01) & (output_proposals < 0.99)
400
+ ).all(-1, keepdim=True)
401
+ output_proposals = torch.log(output_proposals / (1 - output_proposals)) # unsigmoid
402
+ output_proposals = output_proposals.masked_fill(
403
+ memory_padding_mask.unsqueeze(-1), float("inf")
404
+ )
405
+ output_proposals = output_proposals.masked_fill(
406
+ ~output_proposals_valid, float("inf")
407
+ )
408
+
409
+ output_memory = memory
410
+ output_memory = output_memory.masked_fill(
411
+ memory_padding_mask.unsqueeze(-1), float(0)
412
+ )
413
+ output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
414
+
415
+ return output_memory, output_proposals
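
A short sketch (not part of the diff; it assumes the `detect_tools.upn` package imports resolve) of two of the helpers above, `inverse_sigmoid` and `get_clones`:

    import torch
    import torch.nn as nn
    from detect_tools.upn.models.utils.detr_utils import get_clones, inverse_sigmoid

    # inverse_sigmoid maps normalized box coordinates back to unbounded logits,
    # with clamping so exact 0/1 values stay finite.
    boxes = torch.tensor([[0.25, 0.50, 0.10, 0.90]])
    logits = inverse_sigmoid(boxes)
    print(torch.allclose(logits.sigmoid(), boxes, atol=1e-2))  # True (up to the eps clamp)

    # get_clones builds N copies of a layer: independent deep copies by default,
    # or the same shared module when layer_share=True.
    layers = get_clones(nn.Linear(256, 256), N=6)
    print(len(layers), layers[0] is layers[1])  # 6 False
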
detect_tools/upn/ops/functions/__init__.py ADDED
@@ -0,0 +1,10 @@
+ # ------------------------------------------------------------------------------------------------
+ # Deformable DETR
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ # ------------------------------------------------------------------------------------------------
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+ # ------------------------------------------------------------------------------------------------
+
+ from .ms_deform_attn_func import MSDeformAttnFunction
+
detect_tools/upn/ops/functions/ms_deform_attn_func.py ADDED
@@ -0,0 +1,61 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ from __future__ import absolute_import
10
+ from __future__ import print_function
11
+ from __future__ import division
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torch.autograd import Function
16
+ from torch.autograd.function import once_differentiable
17
+
18
+ import MultiScaleDeformableAttention as MSDA
19
+
20
+
21
+ class MSDeformAttnFunction(Function):
22
+ @staticmethod
23
+ def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
24
+ ctx.im2col_step = im2col_step
25
+ output = MSDA.ms_deform_attn_forward(
26
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
27
+ ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
28
+ return output
29
+
30
+ @staticmethod
31
+ @once_differentiable
32
+ def backward(ctx, grad_output):
33
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
34
+ grad_value, grad_sampling_loc, grad_attn_weight = \
35
+ MSDA.ms_deform_attn_backward(
36
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
37
+
38
+ return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
39
+
40
+
41
+ def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
42
+ # for debug and test only,
43
+ # need to use cuda version instead
44
+ N_, S_, M_, D_ = value.shape
45
+ _, Lq_, M_, L_, P_, _ = sampling_locations.shape
46
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
47
+ sampling_grids = 2 * sampling_locations - 1
48
+ sampling_value_list = []
49
+ for lid_, (H_, W_) in enumerate(value_spatial_shapes):
50
+ # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
51
+ value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
52
+ # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
53
+ sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
54
+ # N_*M_, D_, Lq_, P_
55
+ sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
56
+ mode='bilinear', padding_mode='zeros', align_corners=False)
57
+ sampling_value_list.append(sampling_value_l_)
58
+ # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
59
+ attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
60
+ output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
61
+ return output.transpose(1, 2).contiguous()
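
An illustrative sketch (not part of the diff) of the pure-PyTorch reference path `ms_deform_attn_core_pytorch`; note that importing this module still requires the compiled `MultiScaleDeformableAttention` extension because of the top-level import, and every shape below is an assumption made for the example.

    import torch
    # Importing this module requires the compiled MultiScaleDeformableAttention
    # extension (see ops/setup.py), even for the pure-PyTorch reference below.
    from detect_tools.upn.ops.functions.ms_deform_attn_func import ms_deform_attn_core_pytorch

    N, M, D = 2, 8, 32                         # batch, heads, channels per head
    shapes = [(16, 16), (8, 8)]                # (H, W) per feature level (illustrative)
    S = sum(h * w for h, w in shapes)          # 320 flattened key locations
    Lq, L, P = 100, 2, 4                       # queries, levels, sampling points

    value = torch.randn(N, S, M, D)
    sampling_locations = torch.rand(N, Lq, M, L, P, 2)        # normalized to [0, 1]
    attention_weights = torch.softmax(
        torch.randn(N, Lq, M, L * P), dim=-1
    ).view(N, Lq, M, L, P)

    out = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights)
    print(out.shape)  # torch.Size([2, 100, 256]) == (N, Lq, M * D)
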
detect_tools/upn/ops/modules/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # ------------------------------------------------------------------------------------------------
+ # Deformable DETR
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ # ------------------------------------------------------------------------------------------------
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+ # ------------------------------------------------------------------------------------------------
+
+ from .ms_deform_attn import MSDeformAttn
detect_tools/upn/ops/modules/ms_deform_attn.py ADDED
@@ -0,0 +1,204 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ from __future__ import absolute_import
10
+ from __future__ import print_function
11
+ from __future__ import division
12
+
13
+ import warnings
14
+ import math, os
15
+
16
+ import torch
17
+ from torch import nn
18
+ import torch.nn.functional as F
19
+ from torch.nn.init import xavier_uniform_, constant_
20
+
21
+ try:
22
+ from ..functions import MSDeformAttnFunction
23
+ except:
24
+ warnings.warn("Failed to import MSDeformAttnFunction.")
25
+
26
+
27
+ def _is_power_of_2(n):
28
+ if (not isinstance(n, int)) or (n < 0):
29
+ raise ValueError(
30
+ "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))
31
+ )
32
+ return (n & (n - 1) == 0) and n != 0
33
+
34
+
35
+ class MSDeformAttn(nn.Module):
36
+ def __init__(
37
+ self, d_model=256, n_levels=4, n_heads=8, n_points=4, use_4D_normalizer=False
38
+ ):
39
+ """
40
+ Multi-Scale Deformable Attention Module
41
+ :param d_model hidden dimension
42
+ :param n_levels number of feature levels
43
+ :param n_heads number of attention heads
44
+ :param n_points number of sampling points per attention head per feature level
45
+ """
46
+ super().__init__()
47
+ if d_model % n_heads != 0:
48
+ raise ValueError(
49
+ "d_model must be divisible by n_heads, but got {} and {}".format(
50
+ d_model, n_heads
51
+ )
52
+ )
53
+ _d_per_head = d_model // n_heads
54
+ # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
55
+ if not _is_power_of_2(_d_per_head):
56
+ warnings.warn(
57
+ "You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
58
+ "which is more efficient in our CUDA implementation."
59
+ )
60
+
61
+ self.im2col_step = 64
62
+
63
+ self.d_model = d_model
64
+ self.n_levels = n_levels
65
+ self.n_heads = n_heads
66
+ self.n_points = n_points
67
+
68
+ self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
69
+ self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
70
+ self.value_proj = nn.Linear(d_model, d_model)
71
+ self.output_proj = nn.Linear(d_model, d_model)
72
+
73
+ self.use_4D_normalizer = use_4D_normalizer
74
+
75
+ self._reset_parameters()
76
+
77
+ def _reset_parameters(self):
78
+ constant_(self.sampling_offsets.weight.data, 0.0)
79
+ thetas = torch.arange(self.n_heads, dtype=torch.float32) * (
80
+ 2.0 * math.pi / self.n_heads
81
+ )
82
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
83
+ grid_init = (
84
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
85
+ .view(self.n_heads, 1, 1, 2)
86
+ .repeat(1, self.n_levels, self.n_points, 1)
87
+ )
88
+ for i in range(self.n_points):
89
+ grid_init[:, :, i, :] *= i + 1
90
+ with torch.no_grad():
91
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
92
+ constant_(self.attention_weights.weight.data, 0.0)
93
+ constant_(self.attention_weights.bias.data, 0.0)
94
+ xavier_uniform_(self.value_proj.weight.data)
95
+ constant_(self.value_proj.bias.data, 0.0)
96
+ xavier_uniform_(self.output_proj.weight.data)
97
+ constant_(self.output_proj.bias.data, 0.0)
98
+
99
+ @torch.cuda.amp.autocast(enabled=False)
100
+ def forward(
101
+ self,
102
+ query,
103
+ reference_points,
104
+ input_flatten,
105
+ input_spatial_shapes,
106
+ input_level_start_index,
107
+ input_padding_mask=None,
108
+ ):
109
+ """
110
+ :param query (N, Length_{query}, C)
111
+ :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
112
+ or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
113
+ :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
114
+ :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
115
+ :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
116
+ :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
117
+
118
+ :return output (N, Length_{query}, C)
119
+ """
120
+ N, Len_q, _ = query.shape
121
+ N, Len_in, _ = input_flatten.shape
122
+ assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
123
+
124
+ value = self.value_proj(input_flatten)
125
+ if input_padding_mask is not None:
126
+ value = value.masked_fill(input_padding_mask[..., None], float(0))
127
+ value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
128
+ sampling_offsets = self.sampling_offsets(query).view(
129
+ N, Len_q, self.n_heads, self.n_levels, self.n_points, 2
130
+ )
131
+ attention_weights = self.attention_weights(query).view(
132
+ N, Len_q, self.n_heads, self.n_levels * self.n_points
133
+ )
134
+ attention_weights = F.softmax(attention_weights, -1).view(
135
+ N, Len_q, self.n_heads, self.n_levels, self.n_points
136
+ )
137
+ # N, Len_q, n_heads, n_levels, n_points, 2
138
+
139
+ # if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO':
140
+ # import ipdb; ipdb.set_trace()
141
+
142
+ if reference_points.shape[-1] == 2:
143
+ offset_normalizer = torch.stack(
144
+ [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1
145
+ )
146
+ sampling_locations = (
147
+ reference_points[:, :, None, :, None, :]
148
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
149
+ )
150
+ elif reference_points.shape[-1] == 4:
151
+ if self.use_4D_normalizer:
152
+ offset_normalizer = torch.stack(
153
+ [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1
154
+ )
155
+ sampling_locations = (
156
+ reference_points[:, :, None, :, None, :2]
157
+ + sampling_offsets
158
+ / offset_normalizer[None, None, None, :, None, :]
159
+ * reference_points[:, :, None, :, None, 2:]
160
+ * 0.5
161
+ )
162
+ else:
163
+ sampling_locations = (
164
+ reference_points[:, :, None, :, None, :2]
165
+ + sampling_offsets
166
+ / self.n_points
167
+ * reference_points[:, :, None, :, None, 2:]
168
+ * 0.5
169
+ )
170
+ else:
171
+ raise ValueError(
172
+ "Last dim of reference_points must be 2 or 4, but get {} instead.".format(
173
+ reference_points.shape[-1]
174
+ )
175
+ )
176
+
177
+ # if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO':
178
+ # import ipdb; ipdb.set_trace()
179
+
180
+ # for amp
181
+ if value.dtype == torch.float16:
182
+ # for mixed precision
183
+ output = MSDeformAttnFunction.apply(
184
+ value.to(torch.float32),
185
+ input_spatial_shapes,
186
+ input_level_start_index,
187
+ sampling_locations.to(torch.float32),
188
+ attention_weights,
189
+ self.im2col_step,
190
+ )
191
+ output = output.to(torch.float16)
192
+ output = self.output_proj(output)
193
+ return output
194
+
195
+ output = MSDeformAttnFunction.apply(
196
+ value,
197
+ input_spatial_shapes,
198
+ input_level_start_index,
199
+ sampling_locations,
200
+ attention_weights,
201
+ self.im2col_step,
202
+ )
203
+ output = self.output_proj(output)
204
+ return output
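
A usage sketch (not part of the diff) for `MSDeformAttn`; it requires the compiled CUDA extension and a GPU, and all shapes are illustrative assumptions.

    import torch
    from detect_tools.upn.ops.modules.ms_deform_attn import MSDeformAttn

    # Needs the compiled CUDA op and a GPU; every shape here is illustrative.
    attn = MSDeformAttn(d_model=256, n_levels=2, n_heads=8, n_points=4).cuda()

    shapes = torch.as_tensor([[16, 16], [8, 8]], dtype=torch.long).cuda()   # (n_levels, 2)
    level_start_index = torch.cat(
        (shapes.new_zeros(1), (shapes[:, 0] * shapes[:, 1]).cumsum(0)[:-1])
    )
    S = int((shapes[:, 0] * shapes[:, 1]).sum())                            # 320

    query = torch.randn(2, 100, 256).cuda()                 # (N, Len_q, C)
    reference_points = torch.rand(2, 100, 2, 2).cuda()      # (N, Len_q, n_levels, 2) in [0, 1]
    src = torch.randn(2, S, 256).cuda()                     # flattened multi-scale features

    out = attn(query, reference_points, src, shapes, level_start_index)
    print(out.shape)  # torch.Size([2, 100, 256])
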
detect_tools/upn/ops/modules/ms_deform_attn_key_aware.py ADDED
@@ -0,0 +1,130 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ from __future__ import absolute_import
10
+ from __future__ import print_function
11
+ from __future__ import division
12
+
13
+ import warnings
14
+ import math, os
15
+
16
+ import torch
17
+ from torch import nn
18
+ import torch.nn.functional as F
19
+ from torch.nn.init import xavier_uniform_, constant_
20
+
21
+ try:
22
+ from ..functions import MSDeformAttnFunction
23
+ except:
24
+ warnings.warn('Failed to import MSDeformAttnFunction.')
25
+
26
+
27
+ def _is_power_of_2(n):
28
+ if (not isinstance(n, int)) or (n < 0):
29
+ raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
30
+ return (n & (n-1) == 0) and n != 0
31
+
32
+
33
+ class MSDeformAttn(nn.Module):
34
+ def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, use_4D_normalizer=False):
35
+ """
36
+ Multi-Scale Deformable Attention Module
37
+ :param d_model hidden dimension
38
+ :param n_levels number of feature levels
39
+ :param n_heads number of attention heads
40
+ :param n_points number of sampling points per attention head per feature level
41
+ """
42
+ super().__init__()
43
+ if d_model % n_heads != 0:
44
+ raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
45
+ _d_per_head = d_model // n_heads
46
+ # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
47
+ if not _is_power_of_2(_d_per_head):
48
+ warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
49
+ "which is more efficient in our CUDA implementation.")
50
+
51
+ self.im2col_step = 64
52
+
53
+ self.d_model = d_model
54
+ self.n_levels = n_levels
55
+ self.n_heads = n_heads
56
+ self.n_points = n_points
57
+
58
+ self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
59
+ self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
60
+ self.value_proj = nn.Linear(d_model, d_model)
61
+ self.output_proj = nn.Linear(d_model, d_model)
62
+
63
+ self.use_4D_normalizer = use_4D_normalizer
64
+
65
+ self._reset_parameters()
66
+
67
+ def _reset_parameters(self):
68
+ constant_(self.sampling_offsets.weight.data, 0.)
69
+ thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
70
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
71
+ grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
72
+ for i in range(self.n_points):
73
+ grid_init[:, :, i, :] *= i + 1
74
+ with torch.no_grad():
75
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
76
+ constant_(self.attention_weights.weight.data, 0.)
77
+ constant_(self.attention_weights.bias.data, 0.)
78
+ xavier_uniform_(self.value_proj.weight.data)
79
+ constant_(self.value_proj.bias.data, 0.)
80
+ xavier_uniform_(self.output_proj.weight.data)
81
+ constant_(self.output_proj.bias.data, 0.)
82
+
83
+ def forward(self, query, key, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
84
+ """
85
+ :param query (N, Length_{query}, C)
86
+ :param key (N, 1, C)
87
+ :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
88
+ or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
89
+ :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
90
+ :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
91
+ :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
92
+ :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
93
+
94
+ :return output (N, Length_{query}, C)
95
+ """
96
+ N, Len_q, _ = query.shape
97
+ N, Len_in, _ = input_flatten.shape
98
+ assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
99
+
100
+ value = self.value_proj(input_flatten)
101
+ if input_padding_mask is not None:
102
+ value = value.masked_fill(input_padding_mask[..., None], float(0))
103
+ value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
104
+ sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
105
+ attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
106
+ attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
107
+ # N, Len_q, n_heads, n_levels, n_points, 2
108
+
109
+ # if os.environ.get('IPDB_DEBUG_SHILONG', False) == 'INFO':
110
+ # import ipdb; ipdb.set_trace()
111
+
112
+ if reference_points.shape[-1] == 2:
113
+ offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
114
+ sampling_locations = reference_points[:, :, None, :, None, :] \
115
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
116
+ elif reference_points.shape[-1] == 4:
117
+ if self.use_4D_normalizer:
118
+ offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
119
+ sampling_locations = reference_points[:, :, None, :, None, :2] \
120
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :] * reference_points[:, :, None, :, None, 2:] * 0.5
121
+ else:
122
+ sampling_locations = reference_points[:, :, None, :, None, :2] \
123
+ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
124
+ else:
125
+ raise ValueError(
126
+ 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
127
+ output = MSDeformAttnFunction.apply(
128
+ value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
129
+ output = self.output_proj(output)
130
+ return output
detect_tools/upn/ops/setup.py ADDED
@@ -0,0 +1,73 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ import os
10
+ import glob
11
+
12
+ import torch
13
+
14
+ from torch.utils.cpp_extension import CUDA_HOME
15
+ from torch.utils.cpp_extension import CppExtension
16
+ from torch.utils.cpp_extension import CUDAExtension
17
+
18
+ from setuptools import find_packages
19
+ from setuptools import setup
20
+
21
+ requirements = ["torch", "torchvision"]
22
+
23
+ def get_extensions():
24
+ this_dir = os.path.dirname(os.path.abspath(__file__))
25
+ extensions_dir = os.path.join(this_dir, "src")
26
+
27
+ main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
28
+ source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
29
+ source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
30
+
31
+ sources = main_file + source_cpu
32
+ extension = CppExtension
33
+ extra_compile_args = {"cxx": []}
34
+ define_macros = []
35
+
36
+ # import ipdb; ipdb.set_trace()
37
+
38
+ if torch.cuda.is_available() and CUDA_HOME is not None:
39
+ extension = CUDAExtension
40
+ sources += source_cuda
41
+ define_macros += [("WITH_CUDA", None)]
42
+ extra_compile_args["nvcc"] = [
43
+ "-DCUDA_HAS_FP16=1",
44
+ "-D__CUDA_NO_HALF_OPERATORS__",
45
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
46
+ "-D__CUDA_NO_HALF2_OPERATORS__",
47
+ ]
48
+ else:
49
+        raise NotImplementedError('CUDA is not available')
50
+
51
+ sources = [os.path.join(extensions_dir, s) for s in sources]
52
+ include_dirs = [extensions_dir]
53
+ ext_modules = [
54
+ extension(
55
+ "MultiScaleDeformableAttention",
56
+ sources,
57
+ include_dirs=include_dirs,
58
+ define_macros=define_macros,
59
+ extra_compile_args=extra_compile_args,
60
+ )
61
+ ]
62
+ return ext_modules
63
+
64
+ setup(
65
+ name="MultiScaleDeformableAttention",
66
+ version="1.0",
67
+ author="Weijie Su",
68
+ url="https://github.com/fundamentalvision/Deformable-DETR",
69
+ description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
70
+ packages=find_packages(exclude=("configs", "tests",)),
71
+ ext_modules=get_extensions(),
72
+ cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
73
+ )
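
For reference, a hedged sketch (not part of the diff) of how this extension is typically built; it simply invokes the setup.py above from its own directory.

    # Run from detect_tools/upn/ops; builds and installs the
    # MultiScaleDeformableAttention extension defined in setup.py above.
    import subprocess
    import sys

    subprocess.run([sys.executable, "setup.py", "build", "install"], check=True)
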
detect_tools/upn/ops/src/cpu/ms_deform_attn_cpu.cpp ADDED
@@ -0,0 +1,41 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include <vector>
12
+
13
+ #include <ATen/ATen.h>
14
+ #include <ATen/cuda/CUDAContext.h>
15
+
16
+
17
+ at::Tensor
18
+ ms_deform_attn_cpu_forward(
19
+ const at::Tensor &value,
20
+ const at::Tensor &spatial_shapes,
21
+ const at::Tensor &level_start_index,
22
+ const at::Tensor &sampling_loc,
23
+ const at::Tensor &attn_weight,
24
+ const int im2col_step)
25
+ {
26
+    AT_ERROR("Not implemented on CPU");
27
+ }
28
+
29
+ std::vector<at::Tensor>
30
+ ms_deform_attn_cpu_backward(
31
+ const at::Tensor &value,
32
+ const at::Tensor &spatial_shapes,
33
+ const at::Tensor &level_start_index,
34
+ const at::Tensor &sampling_loc,
35
+ const at::Tensor &attn_weight,
36
+ const at::Tensor &grad_output,
37
+ const int im2col_step)
38
+ {
39
+    AT_ERROR("Not implemented on CPU");
40
+ }
41
+
detect_tools/upn/ops/src/cpu/ms_deform_attn_cpu.h ADDED
@@ -0,0 +1,33 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+ #include <torch/extension.h>
13
+
14
+ at::Tensor
15
+ ms_deform_attn_cpu_forward(
16
+ const at::Tensor &value,
17
+ const at::Tensor &spatial_shapes,
18
+ const at::Tensor &level_start_index,
19
+ const at::Tensor &sampling_loc,
20
+ const at::Tensor &attn_weight,
21
+ const int im2col_step);
22
+
23
+ std::vector<at::Tensor>
24
+ ms_deform_attn_cpu_backward(
25
+ const at::Tensor &value,
26
+ const at::Tensor &spatial_shapes,
27
+ const at::Tensor &level_start_index,
28
+ const at::Tensor &sampling_loc,
29
+ const at::Tensor &attn_weight,
30
+ const at::Tensor &grad_output,
31
+ const int im2col_step);
32
+
33
+
detect_tools/upn/ops/src/cuda/ms_deform_attn_cuda.cu ADDED
@@ -0,0 +1,153 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include <vector>
12
+ #include "cuda/ms_deform_im2col_cuda.cuh"
13
+
14
+ #include <ATen/ATen.h>
15
+ #include <ATen/cuda/CUDAContext.h>
16
+ #include <cuda.h>
17
+ #include <cuda_runtime.h>
18
+
19
+
20
+ at::Tensor ms_deform_attn_cuda_forward(
21
+ const at::Tensor &value,
22
+ const at::Tensor &spatial_shapes,
23
+ const at::Tensor &level_start_index,
24
+ const at::Tensor &sampling_loc,
25
+ const at::Tensor &attn_weight,
26
+ const int im2col_step)
27
+ {
28
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
29
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
30
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
31
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
32
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
33
+
34
+ AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
35
+ AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
36
+ AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
37
+ AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
38
+ AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
39
+
40
+ const int batch = value.size(0);
41
+ const int spatial_size = value.size(1);
42
+ const int num_heads = value.size(2);
43
+ const int channels = value.size(3);
44
+
45
+ const int num_levels = spatial_shapes.size(0);
46
+
47
+ const int num_query = sampling_loc.size(1);
48
+ const int num_point = sampling_loc.size(4);
49
+
50
+ const int im2col_step_ = std::min(batch, im2col_step);
51
+
52
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must be divisible by im2col_step(%d)", batch, im2col_step_);
53
+
54
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
55
+
56
+ const int batch_n = im2col_step_;
57
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
58
+ auto per_value_size = spatial_size * num_heads * channels;
59
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
60
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
61
+ for (int n = 0; n < batch/im2col_step_; ++n)
62
+ {
63
+ auto columns = output_n.select(0, n);
64
+ AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
65
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
66
+ value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
67
+ spatial_shapes.data_ptr<int64_t>(),
68
+ level_start_index.data_ptr<int64_t>(),
69
+ sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
70
+ attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
71
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
72
+ columns.data_ptr<scalar_t>());
73
+
74
+ }));
75
+ }
76
+
77
+ output = output.view({batch, num_query, num_heads*channels});
78
+
79
+ return output;
80
+ }
81
+
82
+
83
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
84
+ const at::Tensor &value,
85
+ const at::Tensor &spatial_shapes,
86
+ const at::Tensor &level_start_index,
87
+ const at::Tensor &sampling_loc,
88
+ const at::Tensor &attn_weight,
89
+ const at::Tensor &grad_output,
90
+ const int im2col_step)
91
+ {
92
+
93
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
94
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
95
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
96
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
97
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
98
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
99
+
100
+ AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
101
+ AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
102
+ AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
103
+ AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
104
+ AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
105
+ AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor");
106
+
107
+ const int batch = value.size(0);
108
+ const int spatial_size = value.size(1);
109
+ const int num_heads = value.size(2);
110
+ const int channels = value.size(3);
111
+
112
+ const int num_levels = spatial_shapes.size(0);
113
+
114
+ const int num_query = sampling_loc.size(1);
115
+ const int num_point = sampling_loc.size(4);
116
+
117
+ const int im2col_step_ = std::min(batch, im2col_step);
118
+
119
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must be divisible by im2col_step(%d)", batch, im2col_step_);
120
+
121
+ auto grad_value = at::zeros_like(value);
122
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
123
+ auto grad_attn_weight = at::zeros_like(attn_weight);
124
+
125
+ const int batch_n = im2col_step_;
126
+ auto per_value_size = spatial_size * num_heads * channels;
127
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
128
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
129
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
130
+
131
+ for (int n = 0; n < batch/im2col_step_; ++n)
132
+ {
133
+ auto grad_output_g = grad_output_n.select(0, n);
134
+ AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
135
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
136
+ grad_output_g.data_ptr<scalar_t>(),
137
+ value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
138
+ spatial_shapes.data_ptr<int64_t>(),
139
+ level_start_index.data_ptr<int64_t>(),
140
+ sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
141
+ attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
142
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
143
+ grad_value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
144
+ grad_sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
145
+ grad_attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
146
+
147
+ }));
148
+ }
149
+
150
+ return {
151
+ grad_value, grad_sampling_loc, grad_attn_weight
152
+ };
153
+ }
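Read off the size() calls above, the forward kernel assumes a fixed tensor contract. A small hedged example of building inputs with the expected shapes (all sizes here are illustrative, not values used by this repo):

import torch

N, M, D = 2, 8, 32             # batch, num_heads, channels
shapes = [(64, 64), (32, 32)]  # (H_l, W_l) per feature level
L, Lq, P = len(shapes), 100, 4  # num_levels, num_query, num_point

spatial_shapes = torch.as_tensor(shapes, dtype=torch.long, device="cuda")
level_start_index = torch.cat((spatial_shapes.new_zeros(1),
                               spatial_shapes.prod(1).cumsum(0)[:-1]))
S = int(sum(h * w for h, w in shapes))  # total flattened spatial size

value = torch.rand(N, S, M, D, device="cuda")                # must be contiguous
sampling_loc = torch.rand(N, Lq, M, L, P, 2, device="cuda")  # normalized (x, y)
attn_weight = torch.rand(N, Lq, M, L, P, device="cuda")
# normalize attention weights over (levels, points)
attn_weight = attn_weight / attn_weight.sum(-1, keepdim=True).sum(-2, keepdim=True)
# forward output shape: (N, Lq, M * D); im2col_step must evenly divide N.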
detect_tools/upn/ops/src/cuda/ms_deform_attn_cuda.h ADDED
@@ -0,0 +1,30 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+ #include <torch/extension.h>
13
+
14
+ at::Tensor ms_deform_attn_cuda_forward(
15
+ const at::Tensor &value,
16
+ const at::Tensor &spatial_shapes,
17
+ const at::Tensor &level_start_index,
18
+ const at::Tensor &sampling_loc,
19
+ const at::Tensor &attn_weight,
20
+ const int im2col_step);
21
+
22
+ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
23
+ const at::Tensor &value,
24
+ const at::Tensor &spatial_shapes,
25
+ const at::Tensor &level_start_index,
26
+ const at::Tensor &sampling_loc,
27
+ const at::Tensor &attn_weight,
28
+ const at::Tensor &grad_output,
29
+ const int im2col_step);
30
+
detect_tools/upn/ops/src/cuda/ms_deform_im2col_cuda.cuh ADDED
@@ -0,0 +1,1327 @@
1
+ /*!
2
+ **************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************
7
+ * Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
8
+ * Copyright (c) 2018 Microsoft
9
+ **************************************************************************
10
+ */
11
+
12
+ #include <cstdio>
13
+ #include <algorithm>
14
+ #include <cstring>
15
+
16
+ #include <ATen/ATen.h>
17
+ #include <ATen/cuda/CUDAContext.h>
18
+
19
+ #include <THC/THCAtomics.cuh>
20
+
21
+ #define CUDA_KERNEL_LOOP(i, n) \
22
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
23
+ i < (n); \
24
+ i += blockDim.x * gridDim.x)
25
+
26
+ const int CUDA_NUM_THREADS = 1024;
27
+ inline int GET_BLOCKS(const int N, const int num_threads)
28
+ {
29
+ return (N + num_threads - 1) / num_threads;
30
+ }
31
+
32
+
33
+ template <typename scalar_t>
34
+ __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
35
+ const int &height, const int &width, const int &nheads, const int &channels,
36
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
37
+ {
38
+ const int h_low = floor(h);
39
+ const int w_low = floor(w);
40
+ const int h_high = h_low + 1;
41
+ const int w_high = w_low + 1;
42
+
43
+ const scalar_t lh = h - h_low;
44
+ const scalar_t lw = w - w_low;
45
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
46
+
47
+ const int w_stride = nheads * channels;
48
+ const int h_stride = width * w_stride;
49
+ const int h_low_ptr_offset = h_low * h_stride;
50
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
51
+ const int w_low_ptr_offset = w_low * w_stride;
52
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
53
+ const int base_ptr = m * channels + c;
54
+
55
+ scalar_t v1 = 0;
56
+ if (h_low >= 0 && w_low >= 0)
57
+ {
58
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
59
+ v1 = bottom_data[ptr1];
60
+ }
61
+ scalar_t v2 = 0;
62
+ if (h_low >= 0 && w_high <= width - 1)
63
+ {
64
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
65
+ v2 = bottom_data[ptr2];
66
+ }
67
+ scalar_t v3 = 0;
68
+ if (h_high <= height - 1 && w_low >= 0)
69
+ {
70
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
71
+ v3 = bottom_data[ptr3];
72
+ }
73
+ scalar_t v4 = 0;
74
+ if (h_high <= height - 1 && w_high <= width - 1)
75
+ {
76
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
77
+ v4 = bottom_data[ptr4];
78
+ }
79
+
80
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
81
+
82
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
83
+ return val;
84
+ }
85
+
86
+
87
+ template <typename scalar_t>
88
+ __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
89
+ const int &height, const int &width, const int &nheads, const int &channels,
90
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
91
+ const scalar_t &top_grad,
92
+ const scalar_t &attn_weight,
93
+ scalar_t* &grad_value,
94
+ scalar_t* grad_sampling_loc,
95
+ scalar_t* grad_attn_weight)
96
+ {
97
+ const int h_low = floor(h);
98
+ const int w_low = floor(w);
99
+ const int h_high = h_low + 1;
100
+ const int w_high = w_low + 1;
101
+
102
+ const scalar_t lh = h - h_low;
103
+ const scalar_t lw = w - w_low;
104
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
105
+
106
+ const int w_stride = nheads * channels;
107
+ const int h_stride = width * w_stride;
108
+ const int h_low_ptr_offset = h_low * h_stride;
109
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
110
+ const int w_low_ptr_offset = w_low * w_stride;
111
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
112
+ const int base_ptr = m * channels + c;
113
+
114
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
115
+ const scalar_t top_grad_value = top_grad * attn_weight;
116
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
117
+
118
+ scalar_t v1 = 0;
119
+ if (h_low >= 0 && w_low >= 0)
120
+ {
121
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
122
+ v1 = bottom_data[ptr1];
123
+ grad_h_weight -= hw * v1;
124
+ grad_w_weight -= hh * v1;
125
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
126
+ }
127
+ scalar_t v2 = 0;
128
+ if (h_low >= 0 && w_high <= width - 1)
129
+ {
130
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
131
+ v2 = bottom_data[ptr2];
132
+ grad_h_weight -= lw * v2;
133
+ grad_w_weight += hh * v2;
134
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
135
+ }
136
+ scalar_t v3 = 0;
137
+ if (h_high <= height - 1 && w_low >= 0)
138
+ {
139
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
140
+ v3 = bottom_data[ptr3];
141
+ grad_h_weight += hw * v3;
142
+ grad_w_weight -= lh * v3;
143
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
144
+ }
145
+ scalar_t v4 = 0;
146
+ if (h_high <= height - 1 && w_high <= width - 1)
147
+ {
148
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
149
+ v4 = bottom_data[ptr4];
150
+ grad_h_weight += lw * v4;
151
+ grad_w_weight += lh * v4;
152
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
153
+ }
154
+
155
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
156
+ *grad_attn_weight = top_grad * val;
157
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
158
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
159
+ }
160
+
161
+
162
+ template <typename scalar_t>
163
+ __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
164
+ const int &height, const int &width, const int &nheads, const int &channels,
165
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
166
+ const scalar_t &top_grad,
167
+ const scalar_t &attn_weight,
168
+ scalar_t* &grad_value,
169
+ scalar_t* grad_sampling_loc,
170
+ scalar_t* grad_attn_weight)
171
+ {
172
+ const int h_low = floor(h);
173
+ const int w_low = floor(w);
174
+ const int h_high = h_low + 1;
175
+ const int w_high = w_low + 1;
176
+
177
+ const scalar_t lh = h - h_low;
178
+ const scalar_t lw = w - w_low;
179
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
180
+
181
+ const int w_stride = nheads * channels;
182
+ const int h_stride = width * w_stride;
183
+ const int h_low_ptr_offset = h_low * h_stride;
184
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
185
+ const int w_low_ptr_offset = w_low * w_stride;
186
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
187
+ const int base_ptr = m * channels + c;
188
+
189
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
190
+ const scalar_t top_grad_value = top_grad * attn_weight;
191
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
192
+
193
+ scalar_t v1 = 0;
194
+ if (h_low >= 0 && w_low >= 0)
195
+ {
196
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
197
+ v1 = bottom_data[ptr1];
198
+ grad_h_weight -= hw * v1;
199
+ grad_w_weight -= hh * v1;
200
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
201
+ }
202
+ scalar_t v2 = 0;
203
+ if (h_low >= 0 && w_high <= width - 1)
204
+ {
205
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
206
+ v2 = bottom_data[ptr2];
207
+ grad_h_weight -= lw * v2;
208
+ grad_w_weight += hh * v2;
209
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
210
+ }
211
+ scalar_t v3 = 0;
212
+ if (h_high <= height - 1 && w_low >= 0)
213
+ {
214
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
215
+ v3 = bottom_data[ptr3];
216
+ grad_h_weight += hw * v3;
217
+ grad_w_weight -= lh * v3;
218
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
219
+ }
220
+ scalar_t v4 = 0;
221
+ if (h_high <= height - 1 && w_high <= width - 1)
222
+ {
223
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
224
+ v4 = bottom_data[ptr4];
225
+ grad_h_weight += lw * v4;
226
+ grad_w_weight += lh * v4;
227
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
228
+ }
229
+
230
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
231
+ atomicAdd(grad_attn_weight, top_grad * val);
232
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
233
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
234
+ }
235
+
236
+
237
+ template <typename scalar_t>
238
+ __global__ void ms_deformable_im2col_gpu_kernel(const int n,
239
+ const scalar_t *data_value,
240
+ const int64_t *data_spatial_shapes,
241
+ const int64_t *data_level_start_index,
242
+ const scalar_t *data_sampling_loc,
243
+ const scalar_t *data_attn_weight,
244
+ const int batch_size,
245
+ const int spatial_size,
246
+ const int num_heads,
247
+ const int channels,
248
+ const int num_levels,
249
+ const int num_query,
250
+ const int num_point,
251
+ scalar_t *data_col)
252
+ {
253
+ CUDA_KERNEL_LOOP(index, n)
254
+ {
255
+ int _temp = index;
256
+ const int c_col = _temp % channels;
257
+ _temp /= channels;
258
+ const int sampling_index = _temp;
259
+ const int m_col = _temp % num_heads;
260
+ _temp /= num_heads;
261
+ const int q_col = _temp % num_query;
262
+ _temp /= num_query;
263
+ const int b_col = _temp;
264
+
265
+ scalar_t *data_col_ptr = data_col + index;
266
+ int data_weight_ptr = sampling_index * num_levels * num_point;
267
+ int data_loc_w_ptr = data_weight_ptr << 1;
268
+ const int qid_stride = num_heads * channels;
269
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
270
+ scalar_t col = 0;
271
+
272
+ for (int l_col=0; l_col < num_levels; ++l_col)
273
+ {
274
+ const int level_start_id = data_level_start_index[l_col];
275
+ const int spatial_h_ptr = l_col << 1;
276
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
277
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
278
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
279
+ for (int p_col=0; p_col < num_point; ++p_col)
280
+ {
281
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
282
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
283
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
284
+
285
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
286
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
287
+
288
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
289
+ {
290
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
291
+ }
292
+
293
+ data_weight_ptr += 1;
294
+ data_loc_w_ptr += 2;
295
+ }
296
+ }
297
+ *data_col_ptr = col;
298
+ }
299
+ }
300
+
301
+ template <typename scalar_t, unsigned int blockSize>
302
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
303
+ const scalar_t *grad_col,
304
+ const scalar_t *data_value,
305
+ const int64_t *data_spatial_shapes,
306
+ const int64_t *data_level_start_index,
307
+ const scalar_t *data_sampling_loc,
308
+ const scalar_t *data_attn_weight,
309
+ const int batch_size,
310
+ const int spatial_size,
311
+ const int num_heads,
312
+ const int channels,
313
+ const int num_levels,
314
+ const int num_query,
315
+ const int num_point,
316
+ scalar_t *grad_value,
317
+ scalar_t *grad_sampling_loc,
318
+ scalar_t *grad_attn_weight)
319
+ {
320
+ CUDA_KERNEL_LOOP(index, n)
321
+ {
322
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
323
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
324
+ unsigned int tid = threadIdx.x;
325
+ int _temp = index;
326
+ const int c_col = _temp % channels;
327
+ _temp /= channels;
328
+ const int sampling_index = _temp;
329
+ const int m_col = _temp % num_heads;
330
+ _temp /= num_heads;
331
+ const int q_col = _temp % num_query;
332
+ _temp /= num_query;
333
+ const int b_col = _temp;
334
+
335
+ const scalar_t top_grad = grad_col[index];
336
+
337
+ int data_weight_ptr = sampling_index * num_levels * num_point;
338
+ int data_loc_w_ptr = data_weight_ptr << 1;
339
+ const int grad_sampling_ptr = data_weight_ptr;
340
+ grad_sampling_loc += grad_sampling_ptr << 1;
341
+ grad_attn_weight += grad_sampling_ptr;
342
+ const int grad_weight_stride = 1;
343
+ const int grad_loc_stride = 2;
344
+ const int qid_stride = num_heads * channels;
345
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
346
+
347
+ for (int l_col=0; l_col < num_levels; ++l_col)
348
+ {
349
+ const int level_start_id = data_level_start_index[l_col];
350
+ const int spatial_h_ptr = l_col << 1;
351
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
352
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
353
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
354
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
355
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
356
+
357
+ for (int p_col=0; p_col < num_point; ++p_col)
358
+ {
359
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
360
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
361
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
362
+
363
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
364
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
365
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
366
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
367
+ *(cache_grad_attn_weight+threadIdx.x)=0;
368
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
369
+ {
370
+ ms_deform_attn_col2im_bilinear(
371
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
372
+ top_grad, weight, grad_value_ptr,
373
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
374
+ }
375
+
376
+ __syncthreads();
377
+ if (tid == 0)
378
+ {
379
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
380
+ int sid=2;
381
+ for (unsigned int tid = 1; tid < blockSize; ++tid)
382
+ {
383
+ _grad_w += cache_grad_sampling_loc[sid];
384
+ _grad_h += cache_grad_sampling_loc[sid + 1];
385
+ _grad_a += cache_grad_attn_weight[tid];
386
+ sid += 2;
387
+ }
388
+
389
+
390
+ *grad_sampling_loc = _grad_w;
391
+ *(grad_sampling_loc + 1) = _grad_h;
392
+ *grad_attn_weight = _grad_a;
393
+ }
394
+ __syncthreads();
395
+
396
+ data_weight_ptr += 1;
397
+ data_loc_w_ptr += 2;
398
+ grad_attn_weight += grad_weight_stride;
399
+ grad_sampling_loc += grad_loc_stride;
400
+ }
401
+ }
402
+ }
403
+ }
404
+
405
+
406
+ template <typename scalar_t, unsigned int blockSize>
407
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
408
+ const scalar_t *grad_col,
409
+ const scalar_t *data_value,
410
+ const int64_t *data_spatial_shapes,
411
+ const int64_t *data_level_start_index,
412
+ const scalar_t *data_sampling_loc,
413
+ const scalar_t *data_attn_weight,
414
+ const int batch_size,
415
+ const int spatial_size,
416
+ const int num_heads,
417
+ const int channels,
418
+ const int num_levels,
419
+ const int num_query,
420
+ const int num_point,
421
+ scalar_t *grad_value,
422
+ scalar_t *grad_sampling_loc,
423
+ scalar_t *grad_attn_weight)
424
+ {
425
+ CUDA_KERNEL_LOOP(index, n)
426
+ {
427
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
428
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
429
+ unsigned int tid = threadIdx.x;
430
+ int _temp = index;
431
+ const int c_col = _temp % channels;
432
+ _temp /= channels;
433
+ const int sampling_index = _temp;
434
+ const int m_col = _temp % num_heads;
435
+ _temp /= num_heads;
436
+ const int q_col = _temp % num_query;
437
+ _temp /= num_query;
438
+ const int b_col = _temp;
439
+
440
+ const scalar_t top_grad = grad_col[index];
441
+
442
+ int data_weight_ptr = sampling_index * num_levels * num_point;
443
+ int data_loc_w_ptr = data_weight_ptr << 1;
444
+ const int grad_sampling_ptr = data_weight_ptr;
445
+ grad_sampling_loc += grad_sampling_ptr << 1;
446
+ grad_attn_weight += grad_sampling_ptr;
447
+ const int grad_weight_stride = 1;
448
+ const int grad_loc_stride = 2;
449
+ const int qid_stride = num_heads * channels;
450
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
451
+
452
+ for (int l_col=0; l_col < num_levels; ++l_col)
453
+ {
454
+ const int level_start_id = data_level_start_index[l_col];
455
+ const int spatial_h_ptr = l_col << 1;
456
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
457
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
458
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
459
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
460
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
461
+
462
+ for (int p_col=0; p_col < num_point; ++p_col)
463
+ {
464
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
465
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
466
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
467
+
468
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
469
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
470
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
471
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
472
+ *(cache_grad_attn_weight+threadIdx.x)=0;
473
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
474
+ {
475
+ ms_deform_attn_col2im_bilinear(
476
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
477
+ top_grad, weight, grad_value_ptr,
478
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
479
+ }
480
+
481
+ __syncthreads();
482
+
483
+ for (unsigned int s=blockSize/2; s>0; s>>=1)
484
+ {
485
+ if (tid < s) {
486
+ const unsigned int xid1 = tid << 1;
487
+ const unsigned int xid2 = (tid + s) << 1;
488
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
489
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
490
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
491
+ }
492
+ __syncthreads();
493
+ }
494
+
495
+ if (tid == 0)
496
+ {
497
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
498
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
499
+ *grad_attn_weight = cache_grad_attn_weight[0];
500
+ }
501
+ __syncthreads();
502
+
503
+ data_weight_ptr += 1;
504
+ data_loc_w_ptr += 2;
505
+ grad_attn_weight += grad_weight_stride;
506
+ grad_sampling_loc += grad_loc_stride;
507
+ }
508
+ }
509
+ }
510
+ }
511
+
512
+
513
+ template <typename scalar_t>
514
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
515
+ const scalar_t *grad_col,
516
+ const scalar_t *data_value,
517
+ const int64_t *data_spatial_shapes,
518
+ const int64_t *data_level_start_index,
519
+ const scalar_t *data_sampling_loc,
520
+ const scalar_t *data_attn_weight,
521
+ const int batch_size,
522
+ const int spatial_size,
523
+ const int num_heads,
524
+ const int channels,
525
+ const int num_levels,
526
+ const int num_query,
527
+ const int num_point,
528
+ scalar_t *grad_value,
529
+ scalar_t *grad_sampling_loc,
530
+ scalar_t *grad_attn_weight)
531
+ {
532
+ CUDA_KERNEL_LOOP(index, n)
533
+ {
534
+ extern __shared__ int _s[];
535
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
536
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
537
+ unsigned int tid = threadIdx.x;
538
+ int _temp = index;
539
+ const int c_col = _temp % channels;
540
+ _temp /= channels;
541
+ const int sampling_index = _temp;
542
+ const int m_col = _temp % num_heads;
543
+ _temp /= num_heads;
544
+ const int q_col = _temp % num_query;
545
+ _temp /= num_query;
546
+ const int b_col = _temp;
547
+
548
+ const scalar_t top_grad = grad_col[index];
549
+
550
+ int data_weight_ptr = sampling_index * num_levels * num_point;
551
+ int data_loc_w_ptr = data_weight_ptr << 1;
552
+ const int grad_sampling_ptr = data_weight_ptr;
553
+ grad_sampling_loc += grad_sampling_ptr << 1;
554
+ grad_attn_weight += grad_sampling_ptr;
555
+ const int grad_weight_stride = 1;
556
+ const int grad_loc_stride = 2;
557
+ const int qid_stride = num_heads * channels;
558
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
559
+
560
+ for (int l_col=0; l_col < num_levels; ++l_col)
561
+ {
562
+ const int level_start_id = data_level_start_index[l_col];
563
+ const int spatial_h_ptr = l_col << 1;
564
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
565
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
566
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
567
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
568
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
569
+
570
+ for (int p_col=0; p_col < num_point; ++p_col)
571
+ {
572
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
573
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
574
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
575
+
576
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
577
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
578
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
579
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
580
+ *(cache_grad_attn_weight+threadIdx.x)=0;
581
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
582
+ {
583
+ ms_deform_attn_col2im_bilinear(
584
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
585
+ top_grad, weight, grad_value_ptr,
586
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
587
+ }
588
+
589
+ __syncthreads();
590
+ if (tid == 0)
591
+ {
592
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
593
+ int sid=2;
594
+ for (unsigned int tid = 1; tid < blockDim.x; ++tid)
595
+ {
596
+ _grad_w += cache_grad_sampling_loc[sid];
597
+ _grad_h += cache_grad_sampling_loc[sid + 1];
598
+ _grad_a += cache_grad_attn_weight[tid];
599
+ sid += 2;
600
+ }
601
+
602
+
603
+ *grad_sampling_loc = _grad_w;
604
+ *(grad_sampling_loc + 1) = _grad_h;
605
+ *grad_attn_weight = _grad_a;
606
+ }
607
+ __syncthreads();
608
+
609
+ data_weight_ptr += 1;
610
+ data_loc_w_ptr += 2;
611
+ grad_attn_weight += grad_weight_stride;
612
+ grad_sampling_loc += grad_loc_stride;
613
+ }
614
+ }
615
+ }
616
+ }
617
+
618
+ template <typename scalar_t>
619
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
620
+ const scalar_t *grad_col,
621
+ const scalar_t *data_value,
622
+ const int64_t *data_spatial_shapes,
623
+ const int64_t *data_level_start_index,
624
+ const scalar_t *data_sampling_loc,
625
+ const scalar_t *data_attn_weight,
626
+ const int batch_size,
627
+ const int spatial_size,
628
+ const int num_heads,
629
+ const int channels,
630
+ const int num_levels,
631
+ const int num_query,
632
+ const int num_point,
633
+ scalar_t *grad_value,
634
+ scalar_t *grad_sampling_loc,
635
+ scalar_t *grad_attn_weight)
636
+ {
637
+ CUDA_KERNEL_LOOP(index, n)
638
+ {
639
+ extern __shared__ int _s[];
640
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
641
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
642
+ unsigned int tid = threadIdx.x;
643
+ int _temp = index;
644
+ const int c_col = _temp % channels;
645
+ _temp /= channels;
646
+ const int sampling_index = _temp;
647
+ const int m_col = _temp % num_heads;
648
+ _temp /= num_heads;
649
+ const int q_col = _temp % num_query;
650
+ _temp /= num_query;
651
+ const int b_col = _temp;
652
+
653
+ const scalar_t top_grad = grad_col[index];
654
+
655
+ int data_weight_ptr = sampling_index * num_levels * num_point;
656
+ int data_loc_w_ptr = data_weight_ptr << 1;
657
+ const int grad_sampling_ptr = data_weight_ptr;
658
+ grad_sampling_loc += grad_sampling_ptr << 1;
659
+ grad_attn_weight += grad_sampling_ptr;
660
+ const int grad_weight_stride = 1;
661
+ const int grad_loc_stride = 2;
662
+ const int qid_stride = num_heads * channels;
663
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
664
+
665
+ for (int l_col=0; l_col < num_levels; ++l_col)
666
+ {
667
+ const int level_start_id = data_level_start_index[l_col];
668
+ const int spatial_h_ptr = l_col << 1;
669
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
670
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
671
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
672
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
673
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
674
+
675
+ for (int p_col=0; p_col < num_point; ++p_col)
676
+ {
677
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
678
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
679
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
680
+
681
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
682
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
683
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
684
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
685
+ *(cache_grad_attn_weight+threadIdx.x)=0;
686
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
687
+ {
688
+ ms_deform_attn_col2im_bilinear(
689
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
690
+ top_grad, weight, grad_value_ptr,
691
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
692
+ }
693
+
694
+ __syncthreads();
695
+
696
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
697
+ {
698
+ if (tid < s) {
699
+ const unsigned int xid1 = tid << 1;
700
+ const unsigned int xid2 = (tid + s) << 1;
701
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
702
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
703
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
704
+ if (tid + (s << 1) < spre)
705
+ {
706
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
707
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
708
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
709
+ }
710
+ }
711
+ __syncthreads();
712
+ }
713
+
714
+ if (tid == 0)
715
+ {
716
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
717
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
718
+ *grad_attn_weight = cache_grad_attn_weight[0];
719
+ }
720
+ __syncthreads();
721
+
722
+ data_weight_ptr += 1;
723
+ data_loc_w_ptr += 2;
724
+ grad_attn_weight += grad_weight_stride;
725
+ grad_sampling_loc += grad_loc_stride;
726
+ }
727
+ }
728
+ }
729
+ }
730
+
731
+ template <typename scalar_t>
732
+ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
733
+ const scalar_t *grad_col,
734
+ const scalar_t *data_value,
735
+ const int64_t *data_spatial_shapes,
736
+ const int64_t *data_level_start_index,
737
+ const scalar_t *data_sampling_loc,
738
+ const scalar_t *data_attn_weight,
739
+ const int batch_size,
740
+ const int spatial_size,
741
+ const int num_heads,
742
+ const int channels,
743
+ const int num_levels,
744
+ const int num_query,
745
+ const int num_point,
746
+ scalar_t *grad_value,
747
+ scalar_t *grad_sampling_loc,
748
+ scalar_t *grad_attn_weight)
749
+ {
750
+ CUDA_KERNEL_LOOP(index, n)
751
+ {
752
+ extern __shared__ int _s[];
753
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
754
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
755
+ unsigned int tid = threadIdx.x;
756
+ int _temp = index;
757
+ const int c_col = _temp % channels;
758
+ _temp /= channels;
759
+ const int sampling_index = _temp;
760
+ const int m_col = _temp % num_heads;
761
+ _temp /= num_heads;
762
+ const int q_col = _temp % num_query;
763
+ _temp /= num_query;
764
+ const int b_col = _temp;
765
+
766
+ const scalar_t top_grad = grad_col[index];
767
+
768
+ int data_weight_ptr = sampling_index * num_levels * num_point;
769
+ int data_loc_w_ptr = data_weight_ptr << 1;
770
+ const int grad_sampling_ptr = data_weight_ptr;
771
+ grad_sampling_loc += grad_sampling_ptr << 1;
772
+ grad_attn_weight += grad_sampling_ptr;
773
+ const int grad_weight_stride = 1;
774
+ const int grad_loc_stride = 2;
775
+ const int qid_stride = num_heads * channels;
776
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
777
+
778
+ for (int l_col=0; l_col < num_levels; ++l_col)
779
+ {
780
+ const int level_start_id = data_level_start_index[l_col];
781
+ const int spatial_h_ptr = l_col << 1;
782
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
783
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
784
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
785
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
786
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
787
+
788
+ for (int p_col=0; p_col < num_point; ++p_col)
789
+ {
790
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
791
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
792
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
793
+
794
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
795
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
796
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
797
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
798
+ *(cache_grad_attn_weight+threadIdx.x)=0;
799
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
800
+ {
801
+ ms_deform_attn_col2im_bilinear(
802
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
803
+ top_grad, weight, grad_value_ptr,
804
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
805
+ }
806
+
807
+ __syncthreads();
808
+
809
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
810
+ {
811
+ if (tid < s) {
812
+ const unsigned int xid1 = tid << 1;
813
+ const unsigned int xid2 = (tid + s) << 1;
814
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
815
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
816
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
817
+ if (tid + (s << 1) < spre)
818
+ {
819
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
820
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
821
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
822
+ }
823
+ }
824
+ __syncthreads();
825
+ }
826
+
827
+ if (tid == 0)
828
+ {
829
+ atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
830
+ atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
831
+ atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
832
+ }
833
+ __syncthreads();
834
+
835
+ data_weight_ptr += 1;
836
+ data_loc_w_ptr += 2;
837
+ grad_attn_weight += grad_weight_stride;
838
+ grad_sampling_loc += grad_loc_stride;
839
+ }
840
+ }
841
+ }
842
+ }
843
+
844
+
845
+ template <typename scalar_t>
846
+ __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
847
+ const scalar_t *grad_col,
848
+ const scalar_t *data_value,
849
+ const int64_t *data_spatial_shapes,
850
+ const int64_t *data_level_start_index,
851
+ const scalar_t *data_sampling_loc,
852
+ const scalar_t *data_attn_weight,
853
+ const int batch_size,
854
+ const int spatial_size,
855
+ const int num_heads,
856
+ const int channels,
857
+ const int num_levels,
858
+ const int num_query,
859
+ const int num_point,
860
+ scalar_t *grad_value,
861
+ scalar_t *grad_sampling_loc,
862
+ scalar_t *grad_attn_weight)
863
+ {
864
+ CUDA_KERNEL_LOOP(index, n)
865
+ {
866
+ int _temp = index;
867
+ const int c_col = _temp % channels;
868
+ _temp /= channels;
869
+ const int sampling_index = _temp;
870
+ const int m_col = _temp % num_heads;
871
+ _temp /= num_heads;
872
+ const int q_col = _temp % num_query;
873
+ _temp /= num_query;
874
+ const int b_col = _temp;
875
+
876
+ const scalar_t top_grad = grad_col[index];
877
+
878
+ int data_weight_ptr = sampling_index * num_levels * num_point;
879
+ int data_loc_w_ptr = data_weight_ptr << 1;
880
+ const int grad_sampling_ptr = data_weight_ptr;
881
+ grad_sampling_loc += grad_sampling_ptr << 1;
882
+ grad_attn_weight += grad_sampling_ptr;
883
+ const int grad_weight_stride = 1;
884
+ const int grad_loc_stride = 2;
885
+ const int qid_stride = num_heads * channels;
886
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
887
+
888
+ for (int l_col=0; l_col < num_levels; ++l_col)
889
+ {
890
+ const int level_start_id = data_level_start_index[l_col];
891
+ const int spatial_h_ptr = l_col << 1;
892
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
893
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
894
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
895
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
896
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
897
+
898
+ for (int p_col=0; p_col < num_point; ++p_col)
899
+ {
900
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
901
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
902
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
903
+
904
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
905
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
906
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
907
+ {
908
+ ms_deform_attn_col2im_bilinear_gm(
909
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
910
+ top_grad, weight, grad_value_ptr,
911
+ grad_sampling_loc, grad_attn_weight);
912
+ }
913
+ data_weight_ptr += 1;
914
+ data_loc_w_ptr += 2;
915
+ grad_attn_weight += grad_weight_stride;
916
+ grad_sampling_loc += grad_loc_stride;
917
+ }
918
+ }
919
+ }
920
+ }
921
+
922
+
923
+ template <typename scalar_t>
924
+ void ms_deformable_im2col_cuda(cudaStream_t stream,
925
+ const scalar_t* data_value,
926
+ const int64_t* data_spatial_shapes,
927
+ const int64_t* data_level_start_index,
928
+ const scalar_t* data_sampling_loc,
929
+ const scalar_t* data_attn_weight,
930
+ const int batch_size,
931
+ const int spatial_size,
932
+ const int num_heads,
933
+ const int channels,
934
+ const int num_levels,
935
+ const int num_query,
936
+ const int num_point,
937
+ scalar_t* data_col)
938
+ {
939
+ const int num_kernels = batch_size * num_query * num_heads * channels;
940
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
941
+ const int num_threads = CUDA_NUM_THREADS;
942
+ ms_deformable_im2col_gpu_kernel<scalar_t>
943
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
944
+ 0, stream>>>(
945
+ num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
946
+ batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
947
+
948
+ cudaError_t err = cudaGetLastError();
949
+ if (err != cudaSuccess)
950
+ {
951
+ printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
952
+ }
953
+
954
+ }
955
+
956
+ template <typename scalar_t>
957
+ void ms_deformable_col2im_cuda(cudaStream_t stream,
958
+ const scalar_t* grad_col,
959
+ const scalar_t* data_value,
960
+ const int64_t * data_spatial_shapes,
961
+ const int64_t * data_level_start_index,
962
+ const scalar_t * data_sampling_loc,
963
+ const scalar_t * data_attn_weight,
964
+ const int batch_size,
965
+ const int spatial_size,
966
+ const int num_heads,
967
+ const int channels,
968
+ const int num_levels,
969
+ const int num_query,
970
+ const int num_point,
971
+ scalar_t* grad_value,
972
+ scalar_t* grad_sampling_loc,
973
+ scalar_t* grad_attn_weight)
974
+ {
975
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
976
+ const int num_kernels = batch_size * num_query * num_heads * channels;
977
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
978
+ if (channels > 1024)
979
+ {
980
+ if ((channels & 1023) == 0)
981
+ {
982
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
983
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
984
+ num_threads*3*sizeof(scalar_t), stream>>>(
985
+ num_kernels,
986
+ grad_col,
987
+ data_value,
988
+ data_spatial_shapes,
989
+ data_level_start_index,
990
+ data_sampling_loc,
991
+ data_attn_weight,
992
+ batch_size,
993
+ spatial_size,
994
+ num_heads,
995
+ channels,
996
+ num_levels,
997
+ num_query,
998
+ num_point,
999
+ grad_value,
1000
+ grad_sampling_loc,
1001
+ grad_attn_weight);
1002
+ }
1003
+ else
1004
+ {
1005
+ ms_deformable_col2im_gpu_kernel_gm<scalar_t>
1006
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1007
+ 0, stream>>>(
1008
+ num_kernels,
1009
+ grad_col,
1010
+ data_value,
1011
+ data_spatial_shapes,
1012
+ data_level_start_index,
1013
+ data_sampling_loc,
1014
+ data_attn_weight,
1015
+ batch_size,
1016
+ spatial_size,
1017
+ num_heads,
1018
+ channels,
1019
+ num_levels,
1020
+ num_query,
1021
+ num_point,
1022
+ grad_value,
1023
+ grad_sampling_loc,
1024
+ grad_attn_weight);
1025
+ }
1026
+ }
1027
+ else{
1028
+ switch(channels)
1029
+ {
1030
+ case 1:
1031
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
1032
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1033
+ 0, stream>>>(
1034
+ num_kernels,
1035
+ grad_col,
1036
+ data_value,
1037
+ data_spatial_shapes,
1038
+ data_level_start_index,
1039
+ data_sampling_loc,
1040
+ data_attn_weight,
1041
+ batch_size,
1042
+ spatial_size,
1043
+ num_heads,
1044
+ channels,
1045
+ num_levels,
1046
+ num_query,
1047
+ num_point,
1048
+ grad_value,
1049
+ grad_sampling_loc,
1050
+ grad_attn_weight);
1051
+ break;
1052
+ case 2:
1053
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
1054
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1055
+ 0, stream>>>(
1056
+ num_kernels,
1057
+ grad_col,
1058
+ data_value,
1059
+ data_spatial_shapes,
1060
+ data_level_start_index,
1061
+ data_sampling_loc,
1062
+ data_attn_weight,
1063
+ batch_size,
1064
+ spatial_size,
1065
+ num_heads,
1066
+ channels,
1067
+ num_levels,
1068
+ num_query,
1069
+ num_point,
1070
+ grad_value,
1071
+ grad_sampling_loc,
1072
+ grad_attn_weight);
1073
+ break;
1074
+ case 4:
1075
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
1076
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1077
+ 0, stream>>>(
1078
+ num_kernels,
1079
+ grad_col,
1080
+ data_value,
1081
+ data_spatial_shapes,
1082
+ data_level_start_index,
1083
+ data_sampling_loc,
1084
+ data_attn_weight,
1085
+ batch_size,
1086
+ spatial_size,
1087
+ num_heads,
1088
+ channels,
1089
+ num_levels,
1090
+ num_query,
1091
+ num_point,
1092
+ grad_value,
1093
+ grad_sampling_loc,
1094
+ grad_attn_weight);
1095
+ break;
1096
+ case 8:
1097
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
1098
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1099
+ 0, stream>>>(
1100
+ num_kernels,
1101
+ grad_col,
1102
+ data_value,
1103
+ data_spatial_shapes,
1104
+ data_level_start_index,
1105
+ data_sampling_loc,
1106
+ data_attn_weight,
1107
+ batch_size,
1108
+ spatial_size,
1109
+ num_heads,
1110
+ channels,
1111
+ num_levels,
1112
+ num_query,
1113
+ num_point,
1114
+ grad_value,
1115
+ grad_sampling_loc,
1116
+ grad_attn_weight);
1117
+ break;
1118
+ case 16:
1119
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
1120
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1121
+ 0, stream>>>(
1122
+ num_kernels,
1123
+ grad_col,
1124
+ data_value,
1125
+ data_spatial_shapes,
1126
+ data_level_start_index,
1127
+ data_sampling_loc,
1128
+ data_attn_weight,
1129
+ batch_size,
1130
+ spatial_size,
1131
+ num_heads,
1132
+ channels,
1133
+ num_levels,
1134
+ num_query,
1135
+ num_point,
1136
+ grad_value,
1137
+ grad_sampling_loc,
1138
+ grad_attn_weight);
1139
+ break;
1140
+ case 32:
1141
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
1142
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1143
+ 0, stream>>>(
1144
+ num_kernels,
1145
+ grad_col,
1146
+ data_value,
1147
+ data_spatial_shapes,
1148
+ data_level_start_index,
1149
+ data_sampling_loc,
1150
+ data_attn_weight,
1151
+ batch_size,
1152
+ spatial_size,
1153
+ num_heads,
1154
+ channels,
1155
+ num_levels,
1156
+ num_query,
1157
+ num_point,
1158
+ grad_value,
1159
+ grad_sampling_loc,
1160
+ grad_attn_weight);
1161
+ break;
1162
+ case 64:
1163
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
1164
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1165
+ 0, stream>>>(
1166
+ num_kernels,
1167
+ grad_col,
1168
+ data_value,
1169
+ data_spatial_shapes,
1170
+ data_level_start_index,
1171
+ data_sampling_loc,
1172
+ data_attn_weight,
1173
+ batch_size,
1174
+ spatial_size,
1175
+ num_heads,
1176
+ channels,
1177
+ num_levels,
1178
+ num_query,
1179
+ num_point,
1180
+ grad_value,
1181
+ grad_sampling_loc,
1182
+ grad_attn_weight);
1183
+ break;
1184
+ case 128:
1185
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
1186
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1187
+ 0, stream>>>(
1188
+ num_kernels,
1189
+ grad_col,
1190
+ data_value,
1191
+ data_spatial_shapes,
1192
+ data_level_start_index,
1193
+ data_sampling_loc,
1194
+ data_attn_weight,
1195
+ batch_size,
1196
+ spatial_size,
1197
+ num_heads,
1198
+ channels,
1199
+ num_levels,
1200
+ num_query,
1201
+ num_point,
1202
+ grad_value,
1203
+ grad_sampling_loc,
1204
+ grad_attn_weight);
1205
+ break;
1206
+ case 256:
1207
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
1208
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1209
+ 0, stream>>>(
1210
+ num_kernels,
1211
+ grad_col,
1212
+ data_value,
1213
+ data_spatial_shapes,
1214
+ data_level_start_index,
1215
+ data_sampling_loc,
1216
+ data_attn_weight,
1217
+ batch_size,
1218
+ spatial_size,
1219
+ num_heads,
1220
+ channels,
1221
+ num_levels,
1222
+ num_query,
1223
+ num_point,
1224
+ grad_value,
1225
+ grad_sampling_loc,
1226
+ grad_attn_weight);
1227
+ break;
1228
+ case 512:
1229
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
1230
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1231
+ 0, stream>>>(
1232
+ num_kernels,
1233
+ grad_col,
1234
+ data_value,
1235
+ data_spatial_shapes,
1236
+ data_level_start_index,
1237
+ data_sampling_loc,
1238
+ data_attn_weight,
1239
+ batch_size,
1240
+ spatial_size,
1241
+ num_heads,
1242
+ channels,
1243
+ num_levels,
1244
+ num_query,
1245
+ num_point,
1246
+ grad_value,
1247
+ grad_sampling_loc,
1248
+ grad_attn_weight);
1249
+ break;
1250
+ case 1024:
1251
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
1252
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1253
+ 0, stream>>>(
1254
+ num_kernels,
1255
+ grad_col,
1256
+ data_value,
1257
+ data_spatial_shapes,
1258
+ data_level_start_index,
1259
+ data_sampling_loc,
1260
+ data_attn_weight,
1261
+ batch_size,
1262
+ spatial_size,
1263
+ num_heads,
1264
+ channels,
1265
+ num_levels,
1266
+ num_query,
1267
+ num_point,
1268
+ grad_value,
1269
+ grad_sampling_loc,
1270
+ grad_attn_weight);
1271
+ break;
1272
+ default:
1273
+ if (channels < 64)
1274
+ {
1275
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
1276
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1277
+ num_threads*3*sizeof(scalar_t), stream>>>(
1278
+ num_kernels,
1279
+ grad_col,
1280
+ data_value,
1281
+ data_spatial_shapes,
1282
+ data_level_start_index,
1283
+ data_sampling_loc,
1284
+ data_attn_weight,
1285
+ batch_size,
1286
+ spatial_size,
1287
+ num_heads,
1288
+ channels,
1289
+ num_levels,
1290
+ num_query,
1291
+ num_point,
1292
+ grad_value,
1293
+ grad_sampling_loc,
1294
+ grad_attn_weight);
1295
+ }
1296
+ else
1297
+ {
1298
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
1299
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1300
+ num_threads*3*sizeof(scalar_t), stream>>>(
1301
+ num_kernels,
1302
+ grad_col,
1303
+ data_value,
1304
+ data_spatial_shapes,
1305
+ data_level_start_index,
1306
+ data_sampling_loc,
1307
+ data_attn_weight,
1308
+ batch_size,
1309
+ spatial_size,
1310
+ num_heads,
1311
+ channels,
1312
+ num_levels,
1313
+ num_query,
1314
+ num_point,
1315
+ grad_value,
1316
+ grad_sampling_loc,
1317
+ grad_attn_weight);
1318
+ }
1319
+ }
1320
+ }
1321
+ cudaError_t err = cudaGetLastError();
1322
+ if (err != cudaSuccess)
1323
+ {
1324
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
1325
+ }
1326
+
1327
+ }
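ms_deformable_col2im_cuda selects among the kernels above based on the per-head channel count. A descriptive mirror of that dispatch, for orientation only (this Python does not ship with the repo):

def col2im_kernel_choice(channels: int) -> str:
    """Mirrors the kernel selection in ms_deformable_col2im_cuda above."""
    if channels > 1024:
        # Too many channels for one block's shared memory.
        if channels % 1024 == 0:
            return "shm_reduce_v2_multi_blocks"  # shared-memory tree + atomicAdd
        return "gm"                              # plain atomicAdd into global memory
    if channels in (1, 2, 4, 8, 16, 32):
        return "shm_blocksize_aware_reduce_v1"   # thread 0 does a linear scan
    if channels in (64, 128, 256, 512, 1024):
        return "shm_blocksize_aware_reduce_v2"   # tree reduction in shared memory
    # Other channel counts fall back to runtime-sized shared memory.
    return "shm_reduce_v1" if channels < 64 else "shm_reduce_v2"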
detect_tools/upn/ops/src/ms_deform_attn.h ADDED
@@ -0,0 +1,62 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #pragma once
12
+
13
+ #include "cpu/ms_deform_attn_cpu.h"
14
+
15
+ #ifdef WITH_CUDA
16
+ #include "cuda/ms_deform_attn_cuda.h"
17
+ #endif
18
+
19
+
20
+ at::Tensor
21
+ ms_deform_attn_forward(
22
+ const at::Tensor &value,
23
+ const at::Tensor &spatial_shapes,
24
+ const at::Tensor &level_start_index,
25
+ const at::Tensor &sampling_loc,
26
+ const at::Tensor &attn_weight,
27
+ const int im2col_step)
28
+ {
29
+ if (value.is_cuda())
30
+ {
31
+ #ifdef WITH_CUDA
32
+ return ms_deform_attn_cuda_forward(
33
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
34
+ #else
35
+ AT_ERROR("Not compiled with GPU support");
36
+ #endif
37
+ }
38
+ AT_ERROR("Not implemented on the CPU");
39
+ }
40
+
41
+ std::vector<at::Tensor>
42
+ ms_deform_attn_backward(
43
+ const at::Tensor &value,
44
+ const at::Tensor &spatial_shapes,
45
+ const at::Tensor &level_start_index,
46
+ const at::Tensor &sampling_loc,
47
+ const at::Tensor &attn_weight,
48
+ const at::Tensor &grad_output,
49
+ const int im2col_step)
50
+ {
51
+ if (value.is_cuda())
52
+ {
53
+ #ifdef WITH_CUDA
54
+ return ms_deform_attn_cuda_backward(
55
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
56
+ #else
57
+ AT_ERROR("Not compiled with GPU support");
58
+ #endif
59
+ }
60
+ AT_ERROR("Not implemented on the CPU");
61
+ }
62
+
detect_tools/upn/ops/src/vision.cpp ADDED
@@ -0,0 +1,16 @@
1
+ /*!
2
+ **************************************************************************************************
3
+ * Deformable DETR
4
+ * Copyright (c) 2020 SenseTime. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6
+ **************************************************************************************************
7
+ * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8
+ **************************************************************************************************
9
+ */
10
+
11
+ #include "ms_deform_attn.h"
12
+
13
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
14
+ m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
15
+ m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
16
+ }
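vision.cpp is the pybind11 entry point that exposes the two C++ functions above to Python. A hedged sketch of driving the raw ops directly, assuming the extension builds under the module name MultiScaleDeformableAttention (suggested by the wheel pinned in requirements.txt; in normal use, MSDeformAttnFunction in functions/ms_deform_attn_func.py wraps these calls):

import torch
import MultiScaleDeformableAttention as MSDA  # module name is an assumption

N, M, D, Lq, P = 1, 2, 4, 2, 2
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
level_start_index = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1]))
S = int(shapes.prod(1).sum())  # total spatial positions over all levels

value = torch.rand(N, S, M, D).cuda()
sampling_locations = torch.rand(N, Lq, M, len(shapes), P, 2).cuda()
attention_weights = torch.rand(N, Lq, M, len(shapes), P).cuda()
attention_weights /= attention_weights.sum((-1, -2), keepdim=True)

out = MSDA.ms_deform_attn_forward(
    value, shapes, level_start_index, sampling_locations, attention_weights, 2)
print(out.shape)  # expected (N, Lq, M*D)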
detect_tools/upn/ops/test.py ADDED
@@ -0,0 +1,89 @@
1
+ # ------------------------------------------------------------------------------------------------
2
+ # Deformable DETR
3
+ # Copyright (c) 2020 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------------------
6
+ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7
+ # ------------------------------------------------------------------------------------------------
8
+
9
+ from __future__ import absolute_import
10
+ from __future__ import print_function
11
+ from __future__ import division
12
+
13
+ import time
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.autograd import gradcheck
17
+
18
+ from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
19
+
20
+
21
+ N, M, D = 1, 2, 2
22
+ Lq, L, P = 2, 2, 2
23
+ shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
24
+ level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
25
+ S = sum([(H*W).item() for H, W in shapes])
26
+
27
+
28
+ torch.manual_seed(3)
29
+
30
+
31
+ @torch.no_grad()
32
+ def check_forward_equal_with_pytorch_double():
33
+ value = torch.rand(N, S, M, D).cuda() * 0.01
34
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
35
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
36
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
37
+ im2col_step = 2
38
+ output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
39
+ output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
40
+ fwdok = torch.allclose(output_cuda, output_pytorch)
41
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
42
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
43
+
44
+ print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
45
+
46
+
47
+ @torch.no_grad()
48
+ def check_forward_equal_with_pytorch_float():
49
+ value = torch.rand(N, S, M, D).cuda() * 0.01
50
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
51
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
52
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
53
+ im2col_step = 2
54
+ output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
55
+ output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
56
+ fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
57
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
58
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
59
+
60
+ print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
61
+
62
+
63
+ def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
64
+
65
+ value = torch.rand(N, S, M, channels).cuda() * 0.01
66
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
67
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
68
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
69
+ im2col_step = 2
70
+ func = MSDeformAttnFunction.apply
71
+
72
+ value.requires_grad = grad_value
73
+ sampling_locations.requires_grad = grad_sampling_loc
74
+ attention_weights.requires_grad = grad_attn_weight
75
+
76
+ gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
77
+
78
+ print(f'* {gradok} check_gradient_numerical(D={channels})')
79
+
80
+
81
+ if __name__ == '__main__':
82
+ check_forward_equal_with_pytorch_double()
83
+ check_forward_equal_with_pytorch_float()
84
+
85
+ for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
86
+ check_gradient_numerical(channels, True, True, True)
87
+
88
+
89
+
detect_tools/upn/requirments.txt ADDED
@@ -0,0 +1 @@
 
+ mmengine==0.8.2
detect_tools/upn/transforms/transform.py ADDED
@@ -0,0 +1,142 @@
 
1
+ import random
2
+ import torch
3
+ import torchvision.transforms.functional as F
4
+
5
+
6
+ def resize(image, target, size, max_size=None):
7
+ # size can be min_size (scalar) or (w, h) tuple
8
+
9
+ def get_size_with_aspect_ratio(image_size, size, max_size=None):
10
+ w, h = image_size
11
+ if max_size is not None:
12
+ min_original_size = float(min((w, h)))
13
+ max_original_size = float(max((w, h)))
14
+ if max_original_size / min_original_size * size > max_size:
15
+ size = int(round(max_size * min_original_size / max_original_size))
16
+
17
+ if (w <= h and w == size) or (h <= w and h == size):
18
+ return (h, w)
19
+
20
+ if w < h:
21
+ ow = size
22
+ oh = int(size * h / w)
23
+ else:
24
+ oh = size
25
+ ow = int(size * w / h)
26
+
27
+ return (oh, ow)
28
+
29
+ def get_size(image_size, size, max_size=None):
30
+ if isinstance(size, (list, tuple)):
31
+ return size[::-1]
32
+ else:
33
+ return get_size_with_aspect_ratio(image_size, size, max_size)
34
+
35
+ size = get_size(image.size, size, max_size)
36
+ rescaled_image = F.resize(image, size)
37
+
38
+ if target is None:
39
+ return rescaled_image, None
40
+
41
+ ratios = tuple(
42
+ float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)
43
+ )
44
+ ratio_width, ratio_height = ratios
45
+
46
+ target = target.copy()
47
+ if "exampler_box" in target:
48
+ boxes = target["exampler_box"]
49
+ if isinstance(boxes, torch.Tensor):
50
+ scaled_boxes = boxes * torch.as_tensor(
51
+ [ratio_width, ratio_height, ratio_width, ratio_height]
52
+ )
53
+ target["exampler_box"] = scaled_boxes
54
+ elif isinstance(boxes, dict):
55
+ for k, v in boxes.items():
56
+ scaled_boxes = v * torch.as_tensor(
57
+ [ratio_width, ratio_height, ratio_width, ratio_height]
58
+ )
59
+ target["exampler_box"][k] = scaled_boxes
60
+
61
+ if "demo_pos_exampler_box" in target:
62
+ boxes = target["demo_pos_exampler_box"]
63
+ scaled_boxes = boxes * torch.as_tensor(
64
+ [ratio_width, ratio_height, ratio_width, ratio_height]
65
+ )
66
+ target["demo_pos_exampler_box"] = scaled_boxes
67
+
68
+ if "demo_neg_exampler_box" in target:
69
+ boxes = target["demo_neg_exampler_box"]
70
+ scaled_boxes = boxes * torch.as_tensor(
71
+ [ratio_width, ratio_height, ratio_width, ratio_height]
72
+ )
73
+ target["demo_neg_exampler_box"] = scaled_boxes
74
+
75
+ if "boxes" in target:
76
+ boxes = target["boxes"]
77
+ scaled_boxes = boxes * torch.as_tensor(
78
+ [ratio_width, ratio_height, ratio_width, ratio_height]
79
+ )
80
+ target["boxes"] = scaled_boxes
81
+
82
+ if "area" in target:
83
+ area = target["area"]
84
+ scaled_area = area * (ratio_width * ratio_height)
85
+ target["area"] = scaled_area
86
+
87
+ h, w = size
88
+ target["size"] = torch.tensor([h, w])
89
+
90
+ return rescaled_image, target
91
+
92
+
93
+ class RandomResize(object):
94
+
95
+ def __init__(self, sizes, max_size=None):
96
+ assert isinstance(sizes, (list, tuple))
97
+ self.sizes = sizes
98
+ self.max_size = max_size
99
+
100
+ def __call__(self, img, target=None):
101
+ size = random.choice(self.sizes)
102
+ return resize(img, target, size, self.max_size)
103
+
104
+
105
+ class ToTensor(object):
106
+
107
+ def __call__(self, img, target):
108
+ return F.to_tensor(img), target
109
+
110
+
111
+ class Normalize(object):
112
+
113
+ def __init__(self, mean, std):
114
+ self.mean = mean
115
+ self.std = std
116
+
117
+ def __call__(self, image, target=None):
118
+ image = F.normalize(image, mean=self.mean, std=self.std)
119
+ if target is None:
120
+ return image, None
121
+ target = target.copy()
122
+ h, w = image.shape[-2:]
123
+ return image, target
124
+
125
+
126
+ class Compose(object):
127
+
128
+ def __init__(self, transforms):
129
+ self.transforms = transforms
130
+
131
+ def __call__(self, image, target):
132
+ for t in self.transforms:
133
+ image, target = t(image, target)
134
+ return image, target
135
+
136
+ def __repr__(self):
137
+ format_string = self.__class__.__name__ + "("
138
+ for t in self.transforms:
139
+ format_string += "\n"
140
+ format_string += " {0}".format(t)
141
+ format_string += "\n)"
142
+ return format_string
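These classes follow the torchvision-style (image, target) convention and chain with Compose; a minimal usage sketch (the mean/std values are the usual ImageNet statistics, assumed here rather than taken from this repo's configs):

import torch
from PIL import Image

pipeline = Compose([
    RandomResize([800], max_size=1333),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
img, tgt = pipeline(Image.new("RGB", (640, 480)),
                    {"boxes": torch.tensor([[10., 10., 100., 100.]])})
print(img.shape, tgt["boxes"], tgt["size"])  # boxes are rescaled with the image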
requirements.txt ADDED
@@ -0,0 +1,19 @@
1
+ torch==2.6.0
2
+ torchvision==0.21.0
3
+ transformers==4.50.1
4
+ https://airesources.oss-cn-hangzhou.aliyuncs.com/lp/wheel/sam3-0.1.0-py3-none-any.whl
5
+ timm==1.0.9
6
+ accelerate==1.4.0
7
+ gradio
8
+ mmengine==0.8.2
9
+ einops
10
+ ninja
11
+ scikit-image
12
+ decord
13
+ scikit-learn
14
+ matplotlib
15
+ modelscope
16
+ https://airesources.oss-cn-hangzhou.aliyuncs.com/lp/wheel/multiscaledeformableattention-1.0-cp310-cp310-linux_x86_64.whl
17
+ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
18
+ pycocotools
19
+ opencv-python
resources/__init__.py ADDED
File without changes
vlm_fo1/__init__.py ADDED
@@ -0,0 +1 @@
1
+
vlm_fo1/constants.py ADDED
@@ -0,0 +1,29 @@
1
+ LOGDIR = "."
2
+
3
+ global DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
4
+ # Model Constants
5
+ IGNORE_INDEX = -100
6
+ IMAGE_TOKEN_INDEX = -200 #151656 #151655 #-200
7
+ DEFAULT_IMAGE_TOKEN = "<image>"
8
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
9
+ DEFAULT_IM_START_TOKEN = "<im_start>"
10
+ DEFAULT_IM_END_TOKEN = "<im_end>"
11
+
12
+ # For Qwen2_5_VL
13
+ QWEN2_5_VL_IMAGE_TOKEN = "<|image_pad|>"
14
+ QWEN2_5_VL_IMAGE_TOKEN_INDEX = 151655
15
+
16
+ # For regions
17
+ DEFAULT_REGION_TOKEN = "<region<i>>"
18
+ DEFAULT_REGION_FEATURE_TOKEN = "<regionfeat>"
19
+ DEFAULT_REGION_INDEX = -300 #151654 #151654 #-300
20
+
21
+ # For Grounding
22
+ DEFAULT_GROUNDING_START = "<ground>"
23
+ DEFAULT_GROUNDING_END = "</ground>"
24
+ DEFAULT_GROUNDING_OBJECTS_START = "<objects>"
25
+ DEFAULT_GROUNDING_OBJECTS_END = "</objects>"
26
+
27
+ # For Think
28
+ DEFAULT_THINK_START = "<think>"
29
+ DEFAULT_THINK_END = "</think>"
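The grounding tokens compose into the markup that mm_utils.py later parses back out of model responses; for example, grounding the phrase "dog" to regions 0 and 2:

resp = (DEFAULT_GROUNDING_START + "dog" + DEFAULT_GROUNDING_END
        + DEFAULT_GROUNDING_OBJECTS_START + "<region0><region2>"
        + DEFAULT_GROUNDING_OBJECTS_END)
# -> "<ground>dog</ground><objects><region0><region2></objects>"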
vlm_fo1/mm_utils.py ADDED
@@ -0,0 +1,660 @@
1
+ from PIL import Image
2
+ from PIL import ImageDraw, ImageOps
3
+ from io import BytesIO
4
+ import base64
5
+ import re
6
+ import torch
7
+ from transformers import StoppingCriteria
8
+ from vlm_fo1.constants import IMAGE_TOKEN_INDEX, DEFAULT_REGION_INDEX
9
+ import requests
10
+ from vlm_fo1.constants import (
11
+ IMAGE_TOKEN_INDEX,
12
+ DEFAULT_IMAGE_TOKEN,
13
+ DEFAULT_IM_START_TOKEN,
14
+ DEFAULT_IM_END_TOKEN,
15
+ IGNORE_INDEX,
16
+ DEFAULT_REGION_TOKEN,
17
+ DEFAULT_REGION_FEATURE_TOKEN
18
+ )
19
+ import torch
20
+ from transformers import TextStreamer
21
+ import random
22
+ import re
23
+ from typing import List, Tuple
24
+ import io
25
+ import base64
26
+
27
+
28
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
29
+ """
30
+ Tokenizes prompts containing <image> or <image_0>... special tokens.
31
+
32
+ If the prompt uses <image_0>, <image_1>, ..., each is replaced with a placeholder index (-200).
33
+ If the prompt uses <image>, it is replaced with image_token_index.
34
+
35
+ Args:
36
+ prompt (str): The prompt potentially containing image tokens.
37
+ tokenizer: The tokenizer object.
38
+ image_token_index (int): Token id to use when encountering <image> token.
39
+ return_tensors (Optional[str]): If 'pt', return a torch tensor.
40
+
41
+ Returns:
42
+ List[int] or torch.Tensor: The tokenized input with image token indices inserted appropriately.
43
+ """
44
+ if "<image_0>" in prompt:
45
+ # Case: prompt contains indexed image tokens like <image_0>, <image_1>, etc.
46
+ image_token_pattern = re.compile(r"<image_(\d+)>")
47
+ prompt_chunks = re.split(r'<image_[0-9]+>', prompt)
48
+ image_tags = image_token_pattern.findall(prompt)
49
+
50
+ input_ids = []
51
+ for i, chunk in enumerate(prompt_chunks):
52
+ input_ids.extend(tokenizer(chunk).input_ids)
53
+ if i < len(image_tags):
54
+ # Insert placeholder where <image_n> token was.
55
+ input_ids.append(-200)
56
+ else:
57
+ # Case: prompt contains plain <image> tokens.
58
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
59
+
60
+ def insert_separator(X, sep):
61
+ # Helper function to insert a separator token between chunks.
62
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
63
+
64
+ input_ids = []
65
+ offset = 0
66
+ # If first chunk starts with <bos> token, make sure to keep it only once.
67
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
68
+ offset = 1
69
+ input_ids.append(prompt_chunks[0][0])
70
+
71
+ # Insert image_token_index between chunks.
72
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
73
+ input_ids.extend(x[offset:])
74
+ # Optionally convert output to PyTorch tensor.
75
+ if return_tensors is not None:
76
+ if return_tensors == 'pt':
77
+ return torch.tensor(input_ids, dtype=torch.long)
78
+ else:
79
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
80
+
81
+ return input_ids
82
+
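A toy illustration of the plain <image> path, using a stub tokenizer so the sketch stays self-contained (real callers pass a Hugging Face tokenizer):

from types import SimpleNamespace

class _StubTokenizer:
    bos_token_id = None
    def __call__(self, text):
        return SimpleNamespace(input_ids=[ord(c) for c in text])

ids = tokenizer_image_token("Hi<image>!", _StubTokenizer())
assert ids == [72, 105, -200, 33]  # IMAGE_TOKEN_INDEX (-200) replaces <image>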
83
+ def tokenizer_image_region_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, region_token_index=DEFAULT_REGION_INDEX, return_tensors=None):
84
+ """
85
+ Tokenizes prompts containing both <image> and <regionfeat> delimiters, inserting specified token indices.
86
+
87
+ Each <image> chunk is split, and within that chunk, <regionfeat> locations receive region_token_index.
88
+
89
+ Args:
90
+ prompt (str): The prompt with <image> and <regionfeat> delimiters.
91
+ tokenizer: The tokenizer object.
92
+ image_token_index (int): Insert this at <image> splits.
93
+ region_token_index (int): Insert this at <regionfeat> splits.
94
+ return_tensors (Optional[str]): If 'pt', return torch tensor.
95
+
96
+ Returns:
97
+ List[int] or torch.Tensor: The tokenized input with region/image tokens placed.
98
+ """
99
+ # Split by <image> tags first.
100
+ image_chunks = prompt.split('<image>')
101
+
102
+ prompt_chunks = []
103
+ for chunk in image_chunks:
104
+ # Split each image chunk by <regionfeat>.
105
+ obj_chunks = chunk.split('<regionfeat>')
106
+ # Tokenize each subchunk.
107
+ token_chunks = [tokenizer(c).input_ids for c in obj_chunks]
108
+ prompt_chunks.append(token_chunks)
109
+
110
+ input_ids = []
111
+ offset = 0
112
+
113
+ # If first chunk starts with <bos> token, include only once.
114
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and len(prompt_chunks[0][0]) > 0 and prompt_chunks[0][0][0] == tokenizer.bos_token_id:
115
+ offset = 1
116
+ input_ids.append(prompt_chunks[0][0][0])
117
+
118
+ # Stitch together all chunks with region/image tokens at appropriate locations.
119
+ for i, chunk_group in enumerate(prompt_chunks):
120
+ if len(chunk_group) > 0:
121
+ input_ids.extend(chunk_group[0][offset:])
122
+ for chunk in chunk_group[1:]:
123
+ input_ids.append(region_token_index)
124
+ input_ids.extend(chunk)
125
+ # Insert <image> token except after the last image chunk.
126
+ if i < len(prompt_chunks) - 1:
127
+ input_ids.append(image_token_index)
128
+ # Optionally convert to PyTorch tensor.
129
+ if return_tensors is not None:
130
+ if return_tensors == 'pt':
131
+ return torch.tensor(input_ids, dtype=torch.long)
132
+ else:
133
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
134
+
135
+ return input_ids
136
+
137
+ class KeywordsStoppingCriteria(StoppingCriteria):
138
+ """
139
+ Implements custom stopping criteria for generation based on keywords:
140
+ If the generated output contains any of the keywords, generation stops.
141
+ """
142
+ def __init__(self, keywords, tokenizer, input_ids):
143
+ self.keywords = keywords
144
+ self.keyword_ids = []
145
+ self.max_keyword_len = 0
146
+ for keyword in keywords:
147
+ cur_keyword_ids = tokenizer(keyword).input_ids
148
+ # Remove BOS if present except for single token
149
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
150
+ cur_keyword_ids = cur_keyword_ids[1:]
151
+ if len(cur_keyword_ids) > self.max_keyword_len:
152
+ self.max_keyword_len = len(cur_keyword_ids)
153
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
154
+ self.tokenizer = tokenizer
155
+ # Track the generation start length
156
+ self.start_len = input_ids.shape[1]
157
+
158
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
159
+ """
160
+ Checks if a keyword exists in the latest generated output ids for a single batch element.
161
+ """
162
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
163
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
164
+ for keyword_id in self.keyword_ids:
165
+ truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
166
+ if torch.equal(truncated_output_ids, keyword_id):
167
+ return True
168
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
169
+ for keyword in self.keywords:
170
+ if keyword in outputs:
171
+ return True
172
+ return False
173
+
174
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
175
+ """
176
+ Checks for keywords in each batch item; stops when all have satisfied the keyword condition.
177
+ """
178
+ outputs = []
179
+ for i in range(output_ids.shape[0]):
180
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
181
+ return all(outputs)
182
+
183
+ def load_image(image_file):
184
+ """
185
+ Loads an image from a local path, base64 string, URL, or PIL.Image.
186
+
187
+ If the input image is smaller than 28x28, it will be resized to at least that size.
188
+
189
+ Args:
190
+ image_file (str or PIL.Image.Image): Image source.
191
+
192
+ Returns:
193
+ PIL.Image.Image: Loaded image in RGB mode, at least 28x28 in size.
194
+ """
195
+ if isinstance(image_file, Image.Image):
196
+ image = image_file.convert("RGB")
197
+ # Case: load from URL
198
+ elif image_file.startswith("http") or image_file.startswith("https"):
199
+ response = requests.get(image_file)
200
+ image = Image.open(BytesIO(response.content)).convert("RGB")
201
+ # Case: load from base64-encoded string
202
+ elif image_file.startswith("data:image/"):
203
+ image = re.sub(r"^data:image/[^;]+;base64,", "", image_file)  # strip any data-URI prefix, not only JPEG
204
+ image_data = base64.b64decode(image)
205
+ image = Image.open(BytesIO(image_data)).convert("RGB")
206
+ elif isinstance(image_file, str):
207
+ # Case: load from local file path
208
+ image = Image.open(image_file).convert("RGB")
209
+ else:
210
+ raise ValueError(f"Unsupported image type: {type(image_file)}")
211
+
212
+ # Ensure minimum size 28x28
213
+ if image.width < 28 or image.height < 28:
214
+ image = image.resize((max(28, image.width), max(28, image.height)))
215
+ return image
216
+
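A quick check of the minimum-size guarantee (an in-memory PIL image avoids needing a file on disk):

from PIL import Image

img = load_image(Image.new("RGB", (10, 40)))
assert img.mode == "RGB" and img.size == (28, 40)  # only the short side is bumped up to 28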
217
+ def image_to_base64(img_pil):
218
+ """
219
+ Encodes a PIL Image as JPEG in base64 format.
220
+
221
+ Args:
222
+ img_pil (PIL.Image.Image): Source image.
223
+
224
+ Returns:
225
+ str: base64-encoded JPEG image string.
226
+ """
227
+ with io.BytesIO() as buffer:
228
+ img_pil.save(buffer, format="JPEG")
229
+ base64_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
230
+ return base64_image
231
+
232
+ def draw_bboxes_and_save(
233
+ image: Image.Image,
234
+ fo1_bboxes: dict = {},
235
+ detection_bboxes: List[Tuple[int, int, int, int]] = [],
236
+ output_path: str = 'output.jpg',
237
+ color: str = 'red',
238
+ total_color: str = 'green',
239
+ width: int = 2
240
+ ) -> None:
241
+ """
242
+ Draws bounding boxes (both ground-truth/proposed and detection) on a PIL image and saves result.
243
+
244
+ Args:
245
+ image (PIL.Image.Image): Input PIL image object.
246
+ fo1_bboxes (dict): Label -> List[bbox] mapping for annotation bboxes.
247
+ detection_bboxes (List[Tuple]): List of detection bounding boxes; each bbox is (x_min, y_min, x_max, y_max).
248
+ output_path (str): Path to save the output image.
249
+ color (str): Color for fo1_bboxes.
250
+ total_color (str): Color for detection_bboxes.
251
+ width (int): Rectangle outline width.
252
+
253
+ Returns:
254
+ None
255
+ """
256
+ draw = ImageDraw.Draw(image)
257
+
258
+ # Draw detection boxes with `total_color`
259
+ for bbox in detection_bboxes:
260
+ if len(bbox) != 4:
261
+ print(f"Warning: skipping invalid bbox {bbox}")
262
+ continue
263
+ shape = [(bbox[0], bbox[1]), (bbox[2], bbox[3])]
264
+ draw.rectangle(shape, outline=total_color, width=width)
265
+
266
+ # Draw annotated bboxes with labels and `color`
267
+ for bbox_label, bbox_list in fo1_bboxes.items():
268
+ for bbox in bbox_list:
269
+ if len(bbox) != 4:
270
+ print(f"Warning: skipping invalid bbox {bbox}")
271
+ continue
272
+ shape = [(bbox[0], bbox[1]), (bbox[2], bbox[3])]
273
+ draw.rectangle(shape, outline=color, width=width)
274
+ draw.text((bbox[0], bbox[1]), bbox_label, fill=color)
275
+
276
+ # Save output image (catching common IO exceptions).
277
+ try:
278
+ image.save(output_path)
279
+ print(f"The image has been successfully saved to: {output_path}")
280
+ except IOError as e:
281
+ print(f"Error: failed to save the image to {output_path}. Reason: {e}")
282
+
283
+ def adjust_bbox(bbox_list, original_h, original_w, resize_h, resize_w):
284
+ """
285
+ Adjusts bounding boxes from original image size to resized image size, compensating for scaling.
286
+
287
+ Args:
288
+ bbox_list (List[List[float]]): List of original boxes [x1, y1, x2, y2].
289
+ original_h (int): Original image height.
290
+ original_w (int): Original image width.
291
+ resize_h (int): Resized image height.
292
+ resize_w (int): Resized image width.
293
+
294
+ Returns:
295
+ List[List[float]]: Bounding boxes transformed to resized image coordinates.
296
+ """
297
+ output_list = []
298
+ def adjust_bbox_range(bbox, width, height):
299
+ # Ensure all coordinates are within the original image border.
300
+ x1, y1, x2, y2 = bbox
301
+ x1 = max(0, min(width, x1))
302
+ y1 = max(0, min(height, y1))
303
+ x2 = max(0, min(width, x2))
304
+ y2 = max(0, min(height, y2))
305
+ return [x1, y1, x2, y2]
306
+
307
+ for bbox in bbox_list:
308
+ bbox = adjust_bbox_range(bbox, original_w, original_h)
309
+ bbox[0] = bbox[0] * resize_w / original_w
310
+ bbox[1] = bbox[1] * resize_h / original_h
311
+ bbox[2] = bbox[2] * resize_w / original_w
312
+ bbox[3] = bbox[3] * resize_h / original_h
313
+ output_list.append(bbox)
314
+ return output_list
315
+
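For example, halving both image dimensions halves every coordinate (after clamping to the original borders):

boxes = adjust_bbox([[10, 20, 110, 220]], original_h=400, original_w=600,
                    resize_h=200, resize_w=300)
assert boxes == [[5.0, 10.0, 55.0, 110.0]]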
316
+ def extract_predictions_to_bboxes(prediction: str, bbox_list):
317
+ """
318
+ Parse prediction string in the expected format and map each ground label
319
+ to its corresponding bounding boxes using bbox_list.
320
+
321
+ Args:
322
+ prediction (str): Model output string with <ground>...<objects>... markup.
323
+ bbox_list (List[List[float]]): Full list of predicted or reference bounding boxes.
324
+
325
+ Returns:
326
+ dict: label -> list of bboxes
327
+ """
328
+ label_to_indexes = {}
329
+ label_to_bboxes = {}
330
+
331
+ match_pattern = r"<ground>(.*?)<\/ground><objects>(.*?)<\/objects>"
332
+ matches = re.findall(match_pattern, prediction)
333
+
334
+ for label_text, indexes in matches:
335
+ label_text = label_text.strip()
336
+ indexes_tags = re.findall(r"<region\d+>", indexes)
337
+ region_indexes = set([int(index.split("<region")[-1].split(">")[0]) for index in indexes_tags])
338
+ if label_text not in label_to_indexes:
339
+ label_to_indexes[label_text] = region_indexes
340
+ else:
341
+ label_to_indexes[label_text] = label_to_indexes[label_text] | region_indexes
342
+
343
+ for label, indexes in label_to_indexes.items():
344
+ label_to_bboxes[label] = [bbox_list[index] for index in indexes]
345
+
346
+ return label_to_bboxes
347
+
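For instance, with the markup produced by the grounding tokens in constants.py:

pred = "<ground>dog</ground><objects><region0><region2></objects>"
bbox_list = [[0, 0, 10, 10], [5, 5, 20, 20], [30, 30, 60, 60]]
print(extract_predictions_to_bboxes(pred, bbox_list))
# {'dog': [[0, 0, 10, 10], [30, 30, 60, 60]]}  (bbox order may vary: indexes are a set)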
348
+ def extract_predictions_to_indexes(prediction: str):
349
+ """
350
+ Parse prediction string, returning label -> set-of-indexes mapping.
351
+
352
+ Args:
353
+ prediction (str): Model prediction output.
354
+
355
+ Returns:
356
+ dict: label -> set(int)
357
+ """
358
+ label_to_indexes = {}
359
+ match_pattern = r"<ground>(.*?)<\/ground><objects>(.*?)<\/objects>"
360
+ matches = re.findall(match_pattern, prediction)
361
+
362
+ for label_text, indexes in matches:
363
+ label_text = label_text.strip()
364
+ indexes_tags = re.findall(r"<region\d+>", indexes)
365
+ region_indexes = set([int(index.split("<region")[-1].split(">")[0]) for index in indexes_tags])
366
+ if label_text not in label_to_indexes:
367
+ label_to_indexes[label_text] = region_indexes
368
+ else:
369
+ label_to_indexes[label_text] = label_to_indexes[label_text] | region_indexes
370
+
371
+ return label_to_indexes
372
+
373
+ def resize_shortest_edge_images_and_bboxes(
374
+ image_list: List[Image.Image],
375
+ bbox_lists: List,
376
+ candidate_sizes: List[int] = [],
377
+ max_size: int = 2048
378
+ ):
379
+ """
380
+ Randomly selects a size for the shortest edge, and proportionally resizes both images and bounding boxes.
381
+
382
+ The function maintains the image aspect ratio and ensures that the resized dimensions do not exceed the specified max_size.
383
+ Bounding boxes are transformed accordingly.
384
+
385
+ Args:
386
+ image_list (List[Image.Image]): A list of PIL Image objects.
387
+ bbox_lists (List[List[List[float]]]): A list of lists of bounding boxes per image.
388
+ candidate_sizes (List[int]): Optional list of sizes to choose the target short edge from.
389
+ max_size (int): Maximum allowed long edge after resizing.
390
+
391
+ Returns:
392
+ Tuple[List[Image.Image], List[List[List[float]]]]:
393
+ ([resized_image1, ...], [bbox_list1, ...]); if a single flat bbox list was passed in, the single transformed list is returned instead of a list of lists (see the unpacking at the end of this function).
394
+
395
+ Raises:
396
+ ValueError: on input list length mismatch or emptiness.
397
+ """
398
+ bbox_tensor = torch.tensor(bbox_lists)
399
+ # Normalize input: wrap bbox_lists into list-of-list, if needed.
400
+ if len(bbox_tensor.shape) == 2 and bbox_tensor.shape[1] == 4:
401
+ bbox_lists = [bbox_lists]
402
+
403
+ if not image_list or not bbox_lists:
404
+ raise ValueError("Input lists cannot be empty.")
405
+ if len(image_list) != len(bbox_lists):
406
+ raise ValueError("The lengths of the image list and the bounding box list must be the same.")
407
+
408
+ # Randomly select short edge size (if given candidate sizes)
409
+ if len(candidate_sizes) > 0:
410
+ target_size = random.choice(candidate_sizes)
411
+ else:
412
+ target_size = None
413
+
414
+ resized_images = []
415
+ transformed_bbox_lists = []
416
+
417
+ # Process each image and its corresponding bbox list
418
+ for img, bboxes in zip(image_list, bbox_lists):
419
+ original_width, original_height = img.size
420
+
421
+ # Determine scaling factor to bring short edge to target_size
422
+ shortest_side = min(original_width, original_height)
423
+ if target_size:
424
+ scale = target_size / shortest_side
425
+ else:
426
+ scale = 1.0
427
+
428
+ # Propose new height and width with this scale
429
+ new_height, new_width = int(original_height * scale), int(original_width * scale)
430
+
431
+ # If resulting long edge exceeds max_size, rescale down so that it fits.
432
+ longest_side = max(new_height, new_width)
433
+ if longest_side > max_size:
434
+ scale = max_size / longest_side
435
+ new_height, new_width = int(new_height * scale), int(new_width * scale)
436
+ # Ensure images are at least 28x28 (model may expect it)
437
+ new_width = max(28, new_width)
438
+ new_height = max(28, new_height)
439
+
440
+ # Resize image, using BICUBIC for quality if shape changes
441
+ if new_width == original_width and new_height == original_height:
442
+ resized_img = img
443
+ else:
444
+ resized_img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
445
+ resized_images.append(resized_img)
446
+
447
+ # Transform bounding boxes
448
+ current_transformed_bboxes = []
449
+ scale_ratio_x = new_width / original_width
450
+ scale_ratio_y = new_height / original_height
451
+ for bbox in bboxes:
452
+ x1, y1, x2, y2 = bbox
453
+ new_x1 = x1 * scale_ratio_x
454
+ new_y1 = y1 * scale_ratio_y
455
+ new_x2 = x2 * scale_ratio_x
456
+ new_y2 = y2 * scale_ratio_y
457
+ current_transformed_bboxes.append([new_x1, new_y1, new_x2, new_y2])
458
+ transformed_bbox_lists.append(current_transformed_bboxes)
459
+
460
+ # If original input was a single image (not list), unpack.
461
+ if len(bbox_tensor.shape) == 2 and bbox_tensor.shape[1] == 4:
462
+ return resized_images, transformed_bbox_lists[0]
463
+ else:
464
+ return resized_images, transformed_bbox_lists
465
+
466
+ def make_message_context(tokenizer, message, chat_format="chatml"):
467
+ """
468
+ Given a message dict, construct the prompt, tokenized context tokens, image URLs, and bbox_list.
469
+
470
+ Handles both standard string 'content' and multi-part (list) content, appropriately placing image/region tokens.
471
+
472
+ Args:
473
+ tokenizer: tokenizer object
474
+ message (dict): Contains role, content, and optionally bbox_list.
475
+ chat_format (str): Optionally select chat format (default 'chatml').
476
+
477
+ Returns:
478
+ tuple: (inp, context_tokens, image_urls, bbox_list)
479
+ """
480
+ image_urls = []
481
+ if chat_format == "chatml":
482
+ im_start, im_end = "<|im_start|>", "<|im_end|>"
483
+ im_start_tokens = [151644]
484
+ im_end_tokens = [151645]
485
+ nl_tokens = tokenizer.encode("\n")
486
+ role = message["role"]
487
+ content = message["content"]
488
+ bbox_list = message.get("bbox_list", None)
489
+
490
+ if role == "system":
491
+ inp = f"{im_start}{role}\n{content}{im_end}\n"
492
+ context_tokens = tokenizer.encode(
493
+ role, allowed_special=set()) + nl_tokens + tokenizer.encode(content, allowed_special=set())
494
+ context_tokens = im_start_tokens + context_tokens + im_end_tokens
495
+
496
+ if role == "user":
497
+ if isinstance(content, str):
498
+ # Plain string message
499
+ inp = f"{im_start}{role}\n{content}{im_end}\n"
500
+ context_tokens = tokenizer.encode(
501
+ role, allowed_special=set()) + nl_tokens + tokenizer.encode(content,
502
+ allowed_special=set())
503
+ context_tokens = im_start_tokens + context_tokens + im_end_tokens
504
+ if isinstance(content, list):
505
+ # Multi-part message (text and image_url parts, maybe region tokens)
506
+ inp = f"{im_start}{role}\n"
507
+ image_count = 1
508
+ for message_part in content:
509
+ if message_part["type"] == "text":
510
+ inp += f"{message_part['text']}"
511
+
512
+ if message_part["type"] == "image_url":
513
+ # Insert special vision/image tokens, possibly region tokens
514
+ inp += DEFAULT_IM_START_TOKEN + '<image>' + DEFAULT_IM_END_TOKEN + '\n'
515
+ # If regions exist, add per-region special token.
516
+ if bbox_list and len(bbox_list) > 0:
517
+ for idx, bbox in enumerate(bbox_list):
518
+ inp += DEFAULT_REGION_TOKEN.replace('<i>', str(idx)) + DEFAULT_REGION_FEATURE_TOKEN
519
+ inp += '\n'
520
+
521
+ image_urls.append(message_part['image_url']['url'])
522
+ image_count += 1
523
+ inp += f"{im_end}\n"
524
+
525
+ # Choose tokenizer logic based on whether bbox (region) list exists
526
+ if bbox_list and len(bbox_list) > 0:
527
+ context_tokens = tokenizer_image_region_token(inp, tokenizer)
528
+ else:
529
+ context_tokens = tokenizer_image_token(inp, tokenizer, image_token_index=IMAGE_TOKEN_INDEX)
530
+ return inp, context_tokens, image_urls, bbox_list
531
+
532
+ def prepare_inputs(model_name, model, image_processors, tokenizer, messages, device="cuda", max_tokens=512, top_p=1.0, temperature=0.0, do_sample=False, image_size=None):
533
+ """
534
+ Fully prepares keyword arguments for model.generate (and compatible API) from messages and model specs.
535
+
536
+ Handles prompt assembly, tokenization, image loading/preprocessing, region support, streaming, etc.
537
+ Supports specific tweak for Qwen2.5-VL style vision tokens.
538
+
539
+ Args:
540
+ model_name (str): Model identifier string.
541
+ model: Model/config object.
542
+ image_processors (tuple): (primary, auxiliary) image processors.
543
+ tokenizer: Tokenizer object.
544
+ messages (list): Multi-message input list (chat history).
545
+ device (str): Target (usually 'cuda' or 'cpu').
546
+ max_tokens, top_p, temperature, do_sample: Standard generation kwargs.
547
+
548
+ Returns:
549
+ dict: ready-to-use argument dict for model.generate().
550
+ """
551
+ # For Qwen2.5-VL, patch vision special tokens globally.
552
+ if 'qwen2.5-vl' in model_name.lower() or 'qwen2_5_vl' in model_name.lower():
553
+ global DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
554
+ DEFAULT_IM_START_TOKEN = "<|vision_start|>"
555
+ DEFAULT_IM_END_TOKEN = "<|vision_end|>"
556
+
557
+ primary_image_processor, auxiliary_image_processor = image_processors
558
+
559
+ prompt = ""
560
+ input_tokens = []
561
+ image_urls = []
562
+ # Compose prompt and accumulate all components from provided messages
563
+ for message in messages:
564
+ inp, context_tokens, image_urls, bbox_list = make_message_context(tokenizer, message)
565
+ prompt += inp
566
+ input_tokens.extend(context_tokens)
567
+
568
+ # Ensure a system prompt at start, if not already present.
569
+ if "system" not in prompt:
570
+ system_content = "system\nYou are a helpful assistant."
571
+ system_prompt = "<|im_start|>" + system_content + "<|im_end|>" + "\n"
572
+ prompt = system_prompt + prompt
573
+ system_tokens = [151644] + tokenizer(system_content).input_ids + [151645] + tokenizer("\n").input_ids
574
+ input_tokens = system_tokens + input_tokens
575
+
576
+ # Ensure prompt ends with assistant's turn.
577
+ if not prompt.endswith("<|im_start|>assistant"):
578
+ last_assistant_prompt = "<|im_start|>" + "assistant" + "\n"
579
+ prompt += last_assistant_prompt
580
+ # last_assistant_tokens = [6] + self.tokenizer("assistant\n").input_ids
581
+ last_assistant_tokens = [151644] + tokenizer("assistant\n").input_ids
582
+ input_tokens.extend(last_assistant_tokens)
583
+
584
+ primary_images_tensor = None
585
+ auxiliary_images_tensor = None
586
+ primary_image_grid_thws = None
587
+ if image_urls:
588
+ # Load images, resize them, and update bbox_list downstream
589
+ images = [load_image(i) for i in image_urls]
590
+ if image_size is not None:
591
+ images, bbox_list = resize_shortest_edge_images_and_bboxes(images, bbox_list, candidate_sizes=[image_size], max_size=2048)
592
+ else:
593
+ images, bbox_list = resize_shortest_edge_images_and_bboxes(images, bbox_list, max_size=2048)
594
+
595
+
596
+ # When region-indexed tokens are enabled
597
+ if getattr(model.config, 'mm_use_region_index_token', False):
598
+ origin_image_size = [image.size for image in images]
599
+ aux_images = images.copy()
600
+ auxiliary_images_tensor = [auxiliary_image_processor.preprocess(i, return_tensors='pt')['pixel_values'][0].to(device) for i in aux_images]
601
+
602
+ if bbox_list and len(bbox_list) > 0:
603
+ # Limit number of bbox (for computational constraints, etc.)
604
+ bbox_list = bbox_list[:100]
605
+ resize_h, resize_w = auxiliary_images_tensor[0].shape[-2:]
606
+ original_w, original_h = origin_image_size[0]
607
+ # Adjust bbox to match resized images (post pre-processing)
608
+ bbox_list = adjust_bbox(bbox_list, original_h, original_w, resize_h, resize_w)
609
+ bbox_list = [torch.tensor(bbox_list)]
610
+ else:
611
+ bbox_list = None
612
+ else:
613
+ auxiliary_images_tensor = None
614
+
615
+ # Preprocess primary images for main vision model branch
616
+ primary_images = []
617
+ primary_image_grid_thws = []
618
+ for im in images:
619
+ processed_data = primary_image_processor.preprocess(im, return_tensors="pt")
620
+ image_i = processed_data['pixel_values']
621
+ image_grid_thw_i = processed_data['image_grid_thw']
622
+ primary_images.append(image_i)
623
+ primary_image_grid_thws.append(image_grid_thw_i)
624
+ primary_images_tensor = [image_i.to(device) for image_i in primary_images]
625
+
626
+ # For Qwen-style, force specific end-token as stopping criterion
627
+ if "qwen" in model_name.lower():
628
+ input_ids = torch.tensor([input_tokens]).to(device)
629
+ keywords = ["<|im_end|>"]
630
+
631
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
632
+ streamer = TextStreamer(
633
+ tokenizer, skip_prompt=True, skip_special_tokens=True
634
+ )
635
+
636
+ # Default: greedy decoding if temperature=0. Else: enable sampling.
637
+ if temperature == 0.0:
638
+ do_sample = False
639
+ else:
640
+ do_sample = True
641
+
642
+ print("question:================\n", prompt, "\n=================")
643
+ # print("input ids:========", input_ids, "========")
644
+ generation_kwargs = dict(
645
+ inputs=input_ids,
646
+ images=primary_images_tensor,
647
+ images_aux=auxiliary_images_tensor,
648
+ image_grid_thws=primary_image_grid_thws,
649
+ bbox_list=bbox_list,
650
+ do_sample=do_sample,
651
+ temperature=temperature,
652
+ max_new_tokens=max_tokens,
653
+ streamer=streamer,
654
+ top_p=top_p,
655
+ use_cache=True,
656
+ stopping_criteria=[stopping_criteria],
657
+ pad_token_id=tokenizer.pad_token_id
658
+ )
659
+ return generation_kwargs
660
+
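Putting the pieces together, prepare_inputs expects OpenAI-style chat messages. A hedged sketch of a region-grounded request; the image path and box coordinates are placeholders, and model_name, model, tokenizer, and image_processors come from load_pretrained_model in model/builder.py below:

messages = [{
    "role": "user",
    "content": [
        {"type": "image_url", "image_url": {"url": "samples/street.jpg"}},  # placeholder
        {"type": "text", "text": "Which regions contain a pedestrian?"},
    ],
    "bbox_list": [[48, 60, 120, 300], [200, 80, 260, 310]],  # pixel x1, y1, x2, y2
}]
kwargs = prepare_inputs(model_name, model, image_processors, tokenizer, messages)
output_ids = model.generate(**kwargs)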
vlm_fo1/model/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .language_model.omchat_qwen2_5_vl import OmChatQwen25VLForCausalLM, OmChatQwen25VLConfig
vlm_fo1/model/builder.py ADDED
@@ -0,0 +1,89 @@
1
+ from transformers import AutoTokenizer
2
+ import torch
3
+ from vlm_fo1.model import *
4
+ from safetensors.torch import load_file
5
+ import os
6
+
7
+
8
+ def load_pretrained_model(model_path, load_8bit=False, load_4bit=False, device="cuda"):
9
+ """
10
+ Loads a pretrained model along with its vision towers (and associated image processors).
11
+ This function supports loading in 8bit/4bit precision and explicit device placement.
12
+
13
+ Args:
14
+ model_path (str): Path to the pretrained model directory.
15
+ load_8bit (bool): Whether to load the model in 8bit mode.
16
+ load_4bit (bool): Whether to load the model in 4bit mode.
17
+ device (str): Device to load model onto, e.g., "cuda" or "cpu".
18
+
19
+ Returns:
20
+ tuple: (tokenizer, model, image_processor)
21
+ """
22
+ kwargs = {"device_map": device}
23
+
24
+ # Set model loading parameters for quantization or floating point
25
+ if load_8bit:
26
+ kwargs['load_in_8bit'] = True
27
+ elif load_4bit:
28
+ kwargs['load_in_4bit'] = True
29
+ else:
30
+ kwargs['torch_dtype'] = torch.bfloat16
31
+
32
+ # print(model_path)
33
+
34
+ # Only proceed for vlm-fo1 models
35
+ if 'vlm-fo1' in model_path.lower():
36
+ # Load tokenizer (slow tokenizer enforced)
37
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
38
+ # If this is the Qwen2.5-VL variant, load with additional kwargs
39
+ if 'qwen2.5-vl' in model_path.lower() or 'qwen2_5_vl' in model_path.lower():
40
+ model, loading_info = OmChatQwen25VLForCausalLM.from_pretrained(
41
+ model_path,
42
+ low_cpu_mem_usage=True,
43
+ output_loading_info=True,
44
+ attn_implementation="flash_attention_2",
45
+ **kwargs
46
+ )
47
+ # print(f'OmChatQwen25VLForCausalLM loading_info: {loading_info}')
48
+ # (For other variants of vlm-fo1, model loading detail may need additional condition.)
49
+
50
+ if 'vlm-fo1' in model_path.lower():
51
+ # --- Vision Tower Loading ---
52
+ # Load the main vision tower weights from model_path if it is not yet loaded
53
+ primary_vision_tower = model.get_vision_tower()
54
+ if primary_vision_tower and not primary_vision_tower.is_loaded:
55
+ primary_vision_tower.load_model(model_path=model_path, is_train=False)
56
+ primary_vision_tower.to(device=device, dtype=torch.bfloat16) # Move to correct device/dtype
57
+
58
+ # Grab primary image processor from vision tower, if present
59
+ if primary_vision_tower:
60
+ primary_image_processor = primary_vision_tower.image_processor
61
+
62
+ # --- Auxiliary Vision Tower Handling (Qwen2.5-VL case only) ---
63
+ if 'qwen2.5-vl' in model_path.lower() or 'qwen2_5_vl' in model_path.lower():
64
+ try:
65
+ aux_image_size = model.config.aux_image_size
66
+ except Exception:
67
+ # If aux_image_size is missing from the config, fall back to 768
68
+ aux_image_size = 768
69
+
70
+ aux_image_aspect_ratio = model.config.aux_image_aspect_ratio
71
+ aux_vision_tower = model.get_vision_tower_aux()
72
+ # Only load if not already loaded
73
+ if aux_vision_tower and not aux_vision_tower.is_loaded:
74
+ aux_vision_tower.load_model(image_size=aux_image_size, is_train=False, aspect_ratio=aux_image_aspect_ratio)
75
+ aux_vision_tower.to(device=device, dtype=torch.bfloat16)
76
+
77
+ # Get auxiliary image processor if there is an aux vision tower
78
+ if aux_vision_tower:
79
+ aux_image_processor = aux_vision_tower.image_processor
80
+ else:
81
+ aux_image_processor = None  # no auxiliary vision tower, so no aux processor
82
+
83
+ # image_processor returned as a tuple of (primary, aux)
84
+ image_processor = (primary_image_processor, aux_image_processor)
85
+
86
+ # Set model to eval mode and move to correct device before returning
87
+ model.eval()
88
+ model.to(device=device, dtype=torch.bfloat16)
89
+ return tokenizer, model, image_processor
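A typical call, with the checkpoint path as a placeholder (the lowercased path must contain 'vlm-fo1' for the branches above to run):

tokenizer, model, image_processors = load_pretrained_model(
    "checkpoints/VLM-FO1-Qwen2.5-VL-7B",  # placeholder path
    device="cuda",
)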
vlm_fo1/model/language_model/omchat_qwen2_5_vl.py ADDED
@@ -0,0 +1,576 @@
 
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from transformers import Qwen2_5_VLConfig, AutoConfig, AutoModelForCausalLM
7
+ from vlm_fo1.model.multimodal_encoder.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel, Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLCausalLMOutputWithPast
8
+ from vlm_fo1.model.multimodal_encoder.qwen2_5_vl_encoder import Qwen2_5_VlVisionTower
9
+ from vlm_fo1.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_REGION_INDEX, QWEN2_5_VL_IMAGE_TOKEN, QWEN2_5_VL_IMAGE_TOKEN_INDEX
10
+
11
+ from ..omchat_arch import OmChatMetaModel, OmChatMetaForCausalLM
12
+
13
+ # Custom config which extends Qwen2_5_VLConfig for OmChat multimodal model
14
+ class OmChatQwen25VLConfig(Qwen2_5_VLConfig):
15
+ model_type = "omchat_qwen2_5_vl"
16
+ rotary_type = "normal_rotary"
17
+ multi_scale_im = None
18
+ vision_tower_aux = None
19
+
20
+ # Core model definition: inherits from OmChat and Qwen multimodal base
21
+ class OmChatQwen25VLModel(OmChatMetaModel, Qwen2_5_VLModel):
22
+ config_class = OmChatQwen25VLConfig
23
+
24
+ def __init__(self, config: Qwen2_5_VLConfig):
25
+ super(OmChatQwen25VLModel, self).__init__(config)
26
+
27
+ # Main class for multimodal CausalLM
28
+ class OmChatQwen25VLForCausalLM(Qwen2_5_VLForConditionalGeneration, OmChatMetaForCausalLM):
29
+ config_class = OmChatQwen25VLConfig
30
+
31
+ def __init__(self, config, delay_load=True):
32
+ # Ensure config has delay_load property
33
+ if not hasattr(config, 'delay_load'):
34
+ config.delay_load = delay_load
35
+ super(Qwen2_5_VLForConditionalGeneration, self).__init__(config)
36
+ self.model = OmChatQwen25VLModel(config)
37
+ self.vocab_size = config.vocab_size
38
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
39
+ self.rope_deltas = None # cache rope_deltas here
40
+
41
+ self.post_init()
42
+
43
+ # Encode input images into feature representations
44
+ def encode_images(self, images, images_grid_thw=None):
45
+ # If vision_tower is Qwen2.5-specific, use its custom forward signature
46
+ if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower):
47
+ image_features = self.get_model().get_vision_tower()(images, images_grid_thw)
48
+ image_features, image_grid_thws, multi_level_features = image_features
49
+ # If multiple images, handle concatenation
50
+ if type(image_features) is list:
51
+ # List has items of shape (1, seq_len, dim)
52
+ token_length_list = [i.shape[1] for i in image_features]
53
+ image_features = torch.cat(image_features, dim=1) # Concatenate to (1, total_seq_len, dim)
54
+ else:
55
+ image_features = self.get_model().get_vision_tower()(images)
56
+ image_grid_thws = None
57
+ multi_level_features = None
58
+
59
+ image_features = self.get_model().mm_projector(image_features)
60
+
61
+ # Split concatenated image features back by original lengths (for multi-image case)
62
+ if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower):
63
+ start = 0
64
+ new_image_features = []
65
+ # Split according to token_length_list
66
+ for length in token_length_list:
67
+ end = start + length
68
+ new_image_features.append(image_features[:, start:end, :].squeeze(0))
69
+ start = end
70
+ image_features = new_image_features
71
+
72
+ return image_features, image_grid_thws, multi_level_features
73
+
74
+ # Encode regions (bounding boxes) into features, optionally using the auxiliary vision tower
75
+ def encode_regions(self, images, bbox_list, vt_multi_level_features=None, vt_images_size=None):
76
+ aux_image_features_list = self.get_model().get_vision_tower_aux()(images)
77
+ region_features = []
78
+ if getattr(self.config, "mm_use_vision_tower_region_feature", False):
79
+ image_features_list = vt_multi_level_features
80
+ for batch_idx, (image_features, aux_image_features) in enumerate(zip(image_features_list, aux_image_features_list)):
81
+
82
+ if getattr(self.config, "mm_use_simpleFPN_for_vt", False):
83
+ multilevel_visual_feats = image_features[-1]
84
+ else:
85
+ multilevel_visual_feats = image_features
86
+ multilevel_aux_visual_feats = aux_image_features["image_features"]
87
+ boxes = bbox_list[batch_idx]
88
+
89
+ # If no boxes provided, use dummy box (covers tiny region)
90
+ if boxes is None or len(boxes) == 0:
91
+ boxes = torch.tensor([[0, 10, 0, 10]], device=multilevel_aux_visual_feats[0].device, dtype=torch.float32)
92
+
93
+ boxes = boxes.to(torch.float32).to(multilevel_aux_visual_feats[0].device)
94
+ current_image_height, current_image_width = images[batch_idx].shape[-2:]
95
+ original_height, original_width = vt_images_size[batch_idx]
96
+ # Scale bounding boxes from original image size to processed size
97
+ scale_height = original_height / current_image_height
98
+ scale_width = original_width / current_image_width
99
+ vt_boxes = boxes * torch.tensor([scale_width, scale_height, scale_width, scale_height], device=boxes.device)
100
+
101
+ extracted_region_feat = self.get_model().object_vp_extractor(
102
+ aux_multi_level_features=multilevel_aux_visual_feats,
103
+ vt_multi_level_features=multilevel_visual_feats,
104
+ aux_boxes=[boxes],
105
+ vt_boxes=[vt_boxes]
106
+ ).squeeze(0).to(multilevel_aux_visual_feats[0].dtype)
107
+ region_feat = self.get_model().mm_projector_aux(extracted_region_feat) # [num_bbox, 2048]
108
+ region_features.append(region_feat)
109
+ else:
110
+ # Extract region features only from auxiliary vision tower
111
+ for batch_idx, image_features in enumerate(aux_image_features_list):
112
+ multilevel_visual_feats = image_features["image_features"]
113
+ last_feat = image_features["last_feat"]
114
+ boxes = bbox_list[batch_idx]
115
+
116
+ if boxes is None or len(boxes) == 0:
117
+ boxes = torch.tensor([[0, 10, 0, 10]], device=multilevel_visual_feats[0].device, dtype=torch.float32)
118
+
119
+ multi_level_aux_features = multilevel_visual_feats
120
+ boxes = boxes.to(torch.float32).to(multi_level_aux_features[0].device)
121
+ extracted_region_feat = self.get_model().object_vp_extractor(
122
+ multi_level_aux_features,
123
+ [boxes],
124
+ ).squeeze(0).to(multi_level_aux_features[0].dtype)
125
+ region_feat = self.get_model().mm_projector_aux(extracted_region_feat) # [num_bbox, 2880]
126
+ region_features.append(region_feat)
127
+
128
+ return region_features
129
+
130
+ def get_model(self):
131
+ # Getter for model. Used to access backbone/model internals.
132
+ return self.model
133
+
134
+ # Convert sequence of input_ids/labels/images/boxes to multimodal embedding and associated masks/ids for transformer input.
135
+ def prepare_inputs_labels_for_qwen2_5_vl_multimodal(
136
+ self, input_ids, position_ids, attention_mask, past_key_values, labels, images, images_aux=None, bbox_list=None, image_grid_thws=None
137
+ ):
138
+ # ========================== Input parsing and batching =============================
139
+ vision_tower = self.get_vision_tower()
140
+ video_tower = self.get_video_tower()
141
+ vision_tower_aux = self.get_vision_tower_aux()
142
+ # Fast-path for non-multimodal case or first step in generation (i.e. only one token in input)
143
+ if (vision_tower is None and video_tower is None) or images is None or input_ids.shape[1] == 1:
144
+ if past_key_values is not None and (vision_tower is not None or video_tower is not None) and images is not None and input_ids.shape[1] == 1:
145
+
146
+ target_shape = past_key_values[-1][-1].shape[-2] + 1
147
+ attention_mask = torch.cat((attention_mask, torch.ones(
148
+ (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
149
+ dtype=attention_mask.dtype,
150
+ device=attention_mask.device
151
+ )), dim=1)
152
+
153
+ position_ids=None
154
+ cache_position = torch.tensor([target_shape - 1],device=attention_mask.device)
155
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels, None, cache_position
156
+
157
+ # Indices for images (3D or 2D tensors) and videos (4D tensors)
+ image_idx = [idx for idx, img in enumerate(images) if img.ndim == 3 or img.ndim == 2]
+ is_all_image = len(image_idx) == len(images)
+ video_idx = [idx for idx, vid in enumerate(images) if vid.ndim == 4]
+
+ # Stack image and video tensors accordingly for mini-batch processing
+ if isinstance(vision_tower, Qwen2_5_VlVisionTower):
+ images_minibatch = [images[idx] for idx in image_idx] if len(image_idx) > 0 else [] # list of [c,h,w], shapes may vary
+ else:
+ images_minibatch = torch.stack([images[idx] for idx in image_idx]) if len(image_idx) > 0 else [] # tensor [mini_b, c, h, w]
+ videos_minibatch = torch.stack([images[idx] for idx in video_idx]) if len(video_idx) > 0 else [] # tensor [mini_b, c, t, h, w]
+
+ # Auxiliary batch for region encoding, if relevant
+ if vision_tower_aux is not None and images_aux is not None:
+ images_minibatch_aux = [images_aux[idx].unsqueeze(0) for idx in image_idx] if len(image_idx) > 0 else [] # list of [1, c, h, w]
+
+ # tmp_image_features is indexed to scatter extracted image/video features back into their original batch positions
+ tmp_image_features = [None] * (len(image_idx) + len(video_idx))
+ if getattr(images_minibatch, 'ndim', 0) == 4 or (type(images_minibatch) is list and len(images_minibatch) > 0): # batch consists of images, [mini_b, c, h, w]
+ if vision_tower is not None:
+ image_features_minibatch, image_grid_thws_minibatch, vt_multi_level_features_minibatch = self.encode_images(images_minibatch, image_grid_thws) # [mini_b, l, c]
+ else:
+ image_features_minibatch = torch.randn(1).to(self.device) # dummy feature when no image tower is used (e.g. video-only tuning)
+
+ # Map extracted image features back to their places in the original batch
+ for i, pos in enumerate(image_idx):
+ tmp_image_features[pos] = image_features_minibatch[i]
+
+ # Handle auxiliary region features if enabled and boxes provided
+ if vision_tower_aux is not None and bbox_list is not None and len(bbox_list) > 0:
+ if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower):
+ patch_size = self.get_model().get_vision_tower().config.patch_size
+ vt_images_size_minibatch = [im_grid_thw[0][-2:]*patch_size for im_grid_thw in image_grid_thws]
+ region_features = self.encode_regions(images_minibatch_aux, bbox_list, vt_multi_level_features_minibatch, vt_images_size_minibatch) # [mini_b, l, c]
+ else:
+ region_features = None
+
+ # Same as above, but for video features if any
+ if getattr(videos_minibatch, 'ndim', 0) == 5: # batch consists of videos, [mini_b, c, t, h, w]
+ video_features_minibatch = self.encode_videos(videos_minibatch) # list of per-video features, [mini_b, t, l, c]
+ for i, pos in enumerate(video_idx):
+ tmp_image_features[pos] = video_features_minibatch[i]
+
+ # Flatten the feature slot list into proper order for the current batch
+ new_tmp = []
+ for image in tmp_image_features:
+ # If an item holds multiple images, flatten them out
+ if isinstance(image, list):
+ t = len(image)
+ for i in range(t):
+ new_tmp.append(image[i])
+ else:
+ new_tmp.append(image)
+ image_features = new_tmp
+
+ # =========================== Now, build multimodal input & target sequences (a toy walkthrough of this splice-and-pad pipeline follows the method) =========================
+
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
+ raise NotImplementedError
+
+ _labels = labels
+ _position_ids = position_ids
+ _attention_mask = attention_mask
+
+ # Default construction of masks, position ids, and labels when not provided
+ if attention_mask is None:
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+ else:
+ attention_mask = attention_mask.bool()
+ if position_ids is None:
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
+ if labels is None:
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
+
+ # For each batch item, strip padded tokens based on attention_mask
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
+
+ # If neither the region auxiliary tower nor bboxes are present: process classic image-text input
+ if vision_tower_aux is None and (bbox_list is None or all(x is None for x in bbox_list)):
+ new_input_embeds = []
+ new_labels = []
+ new_input_ids = []
+ cur_image_idx = 0
+ image_nums_in_batch = []
+
+ for batch_idx, cur_input_ids in enumerate(input_ids):
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
+ image_nums_in_batch.append(num_images)
+ # If there are no image markers, just embed the text
+ if num_images == 0:
+ cur_image_features = image_features[cur_image_idx]
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
+ new_input_embeds.append(cur_input_embeds)
+ new_labels.append(labels[batch_idx])
+ new_input_ids.append(cur_input_ids)
+ cur_image_idx += 1
+ continue
+
+ # Split on image token indices and replace each marker with its image features
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
+ cur_input_ids_noim = []
+ cur_labels = labels[batch_idx]
+ cur_labels_noim = []
+ for i in range(len(image_token_indices) - 1):
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
+ cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
+
+ cur_new_input_embeds = []
+ cur_new_labels = []
+ cur_new_input_ids = []
+ for i in range(num_images + 1):
+ # Interleave text and image features
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
+ cur_new_labels.append(cur_labels_noim[i])
+ cur_new_input_ids.append(cur_input_ids_noim[i])
+ if i < num_images:
+ cur_image_features = image_features[cur_image_idx].to(self.device)
+ cur_image_idx += 1
+ cur_new_input_embeds.append(cur_image_features)
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
+ cur_new_input_ids.append(torch.full((cur_image_features.shape[0],), self.config.image_token_id, device=cur_labels.device, dtype=cur_labels.dtype))
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
+ cur_new_labels = torch.cat(cur_new_labels)
+ cur_new_input_ids = torch.cat(cur_new_input_ids)
+
+ new_input_embeds.append(cur_new_input_embeds)
+ new_labels.append(cur_new_labels)
+ new_input_ids.append(cur_new_input_ids)
+ # Otherwise: region markers and/or the auxiliary region tower are in play
+ else:
+ new_input_embeds = []
+ new_labels = []
+ new_input_ids = []
+ cur_image_idx = 0
+ image_nums_in_batch = []
+
+ for batch_idx, cur_input_ids in enumerate(input_ids):
+ cur_region_idx = 0
+ # Count image and region special tokens
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
+ num_regions = (cur_input_ids == DEFAULT_REGION_INDEX).sum() if DEFAULT_REGION_INDEX in cur_input_ids else 0
+ image_nums_in_batch.append(num_images)
+
+ # If no markers, just embed the text for this item
+ if num_images == 0 and num_regions == 0:
+ cur_image_features = image_features[cur_image_idx]
+ cur_region_features = region_features[cur_region_idx]
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_region_features[0:0]], dim=0)
+ new_input_embeds.append(cur_input_embeds)
+ new_labels.append(labels[batch_idx])
+ new_input_ids.append(cur_input_ids)
+ cur_image_idx += 1
+ continue
+
+ # Get all special marker indices (image/region)
+ image_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
+ region_indices = torch.where(cur_input_ids == DEFAULT_REGION_INDEX)[0].tolist() if num_regions > 0 else []
+ all_special_indices = sorted([-1] + image_indices + region_indices + [cur_input_ids.shape[0]])
+
+ # Split out plain text chunks between special markers
+ cur_input_ids_segments = []
+ cur_labels = labels[batch_idx]
+ cur_labels_segments = []
+
+ for i in range(len(all_special_indices) - 1):
+ cur_input_ids_segments.append(cur_input_ids[all_special_indices[i]+1:all_special_indices[i+1]])
+ cur_labels_segments.append(cur_labels[all_special_indices[i]+1:all_special_indices[i+1]])
+
+ # Project text ids to word embeddings
+ split_sizes = [x.shape[0] for x in cur_labels_segments]
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_segments))
+ if num_regions == 0 and vision_tower_aux is not None and region_features is not None:
+ cur_region_features = region_features[cur_region_idx]
+ temp_input_embeds = torch.cat([cur_input_embeds, cur_region_features[0:0]], dim=0)
+ cur_input_embeds = temp_input_embeds
+
+ cur_input_embeds_segments = torch.split(cur_input_embeds, split_sizes, dim=0)
+
+ # Reassemble text and image/region segments in order
+ cur_new_input_embeds = []
+ cur_new_labels = []
+ cur_new_input_ids = []
+
+ for i in range(len(all_special_indices) - 1):
+ # Insert the current text segment
+ cur_new_input_embeds.append(cur_input_embeds_segments[i])
+ cur_new_labels.append(cur_labels_segments[i])
+ cur_new_input_ids.append(cur_input_ids_segments[i])
+ # If the next marker is an image token, insert its feature representation
+ if all_special_indices[i+1] in image_indices:
+ cur_image_features = image_features[cur_image_idx].to(self.device)
+ cur_image_idx += 1
+ cur_new_input_embeds.append(cur_image_features)
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
+ cur_new_input_ids.append(torch.full((cur_image_features.shape[0],), self.config.image_token_id, device=cur_labels.device, dtype=cur_labels.dtype))
+
+ # If the next marker is a region token, insert the extracted region features
+ elif all_special_indices[i+1] in region_indices:
+ cur_region_features = region_features[batch_idx][cur_region_idx].to(self.device).unsqueeze(0)
+ cur_region_idx += 1
+ cur_new_input_embeds.append(cur_region_features)
+ cur_new_labels.append(torch.full((cur_region_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
+ cur_new_input_ids.append(torch.full((cur_region_features.shape[0],), DEFAULT_REGION_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
+ # Combine for this batch item
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
+ cur_new_labels = torch.cat(cur_new_labels)
+ cur_new_input_ids = torch.cat(cur_new_input_ids)
+ new_input_embeds.append(cur_new_input_embeds)
+ new_labels.append(cur_new_labels)
+ new_input_ids.append(cur_new_input_ids)
+ # Truncate sequences to the maximum model length if image/region tokens caused overflow
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
+ if tokenizer_model_max_length is not None:
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
+
+ # Pad sequences in the batch to the same length; compute batch masks
+ max_len = max(x.shape[0] for x in new_input_embeds)
+ batch_size = len(new_input_embeds)
+
+ new_input_embeds_padded = []
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
+ new_input_ids_padded = torch.full((batch_size, max_len), self.config.bos_token_id, dtype=new_input_ids[0].dtype, device=new_input_ids[0].device)
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
+
+ # Left or right padding as per config; fill the padded tensors
+ for i, (cur_new_embed, cur_new_labels, cur_new_input_ids) in enumerate(zip(new_input_embeds, new_labels, new_input_ids)):
+ cur_len = cur_new_embed.shape[0]
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
+ # Left pad: add zeros before text tokens/features
+ new_input_embeds_padded.append(torch.cat((
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
+ cur_new_embed
+ ), dim=0))
+ if cur_len > 0:
+ new_labels_padded[i, -cur_len:] = cur_new_labels
+ new_input_ids_padded[i, -cur_len:] = cur_new_input_ids # keep ids in sync with labels/mask on left padding
+ attention_mask[i, -cur_len:] = True
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
+ else:
+ # Right pad: add zeros after text tokens/features
+ new_input_embeds_padded.append(torch.cat((
+ cur_new_embed,
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
+ ), dim=0))
+ if cur_len > 0:
+ new_labels_padded[i, :cur_len] = cur_new_labels
+ new_input_ids_padded[i, :cur_len] = cur_new_input_ids
+ attention_mask[i, :cur_len] = True
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
+
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
+ new_input_ids = new_input_ids_padded
+
+ # Only return labels if the original labels were not None
+ if _labels is None:
+ new_labels = None
+ else:
+ new_labels = new_labels_padded
+
+ # Similarly honor the caller-provided attention_mask/position_ids overrides
+ if _attention_mask is None:
+ attention_mask = None
+ else:
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
+
+ if _position_ids is None:
+ position_ids = None
+
+ # For Qwen2.5 vision towers, concatenate image_grid_thws for the rotary position computations
+ if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower):
+ image_grid_thws = []
+ cur_image_idx = 0
+ for num_images in image_nums_in_batch:
+ if num_images == 0:
+ cur_image_idx += 1
+ continue
+ image_grid_thws += image_grid_thws_minibatch[cur_image_idx:cur_image_idx+num_images]
+ cur_image_idx += num_images
+
+ if len(image_grid_thws) > 0:
+ image_grid_thws = torch.cat(image_grid_thws, dim=0)
+ else:
+ image_grid_thws = None
+
+ rope_index_kwargs = {
+ "input_ids": new_input_ids,
+ "image_grid_thw": image_grid_thws,
+ "video_grid_thw": None,
+ "attention_mask": attention_mask,
+ }
+
+ # Compute new position_ids and rope_deltas for the transformer's rotary embeddings
+ position_ids, rope_deltas = self.get_rope_index(**rope_index_kwargs)
+ cache_position = torch.arange(new_input_embeds.shape[1], device=new_input_embeds.device)
+ else:
+ rope_deltas = None
+ cache_position = None
+ # The final output mimics the HuggingFace prepare_inputs_for_generation return tuple
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels, rope_deltas, cache_position
+
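Taken together, the method above does three things: split each prompt at the image/region sentinels, splice the visual features into the embedding sequence (with IGNORE_INDEX labels over the spliced spans), and pad the batch to a common length. A self-contained toy version of that pipeline, assuming illustrative sentinel values and a tiny embedding table (the real constants live in vlm_fo1.constants):

import torch

IMAGE_TOKEN_INDEX = -200   # illustrative sentinel
IGNORE_INDEX = -100

embed = torch.nn.Embedding(100, 4)
input_ids = torch.tensor([5, 6, IMAGE_TOKEN_INDEX, 7, 8])
image_features = torch.randn(3, 4)            # 3 visual tokens, hidden size 4

# 1) split at sentinels
cuts = [-1] + torch.where(input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [input_ids.shape[0]]
chunks = [input_ids[cuts[i] + 1:cuts[i + 1]] for i in range(len(cuts) - 1)]

# 2) interleave text embeddings and visual features
embeds, labels = [], []
for i, chunk in enumerate(chunks):
    embeds.append(embed(chunk))               # text tokens -> word embeddings
    labels.append(chunk.clone())              # text keeps its labels
    if i < len(chunks) - 1:                   # splice image features at each cut
        embeds.append(image_features)
        labels.append(torch.full((image_features.shape[0],), IGNORE_INDEX))
embeds, labels = torch.cat(embeds), torch.cat(labels)   # [7, 4], [7]

# 3) right-pad to a fixed batch length
max_len = 10
padded = torch.cat([embeds, torch.zeros(max_len - embeds.shape[0], embeds.shape[1])])
mask = torch.zeros(max_len, dtype=torch.bool)
mask[:embeds.shape[0]] = True
print(padded.shape, mask.tolist())            # torch.Size([10, 4]) [True x7, False x3]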
+ # Patch forward() of the HF CausalLM to allow multimodal embedding with images/regions
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ rope_deltas: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ second_per_grid_ts: Optional[torch.Tensor] = None,
+ images: Optional[torch.FloatTensor] = None,
+ images_aux: Optional[torch.FloatTensor] = None,
+ bbox_list: Optional[torch.FloatTensor] = None,
+ image_grid_thws: Optional[torch.FloatTensor] = None,
+ ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
+
+ if inputs_embeds is None:
+ (
+ input_ids,
+ position_ids,
+ attention_mask,
+ past_key_values,
+ inputs_embeds,
+ labels,
+ rope_deltas,
+ cache_position
+ ) = self.prepare_inputs_labels_for_qwen2_5_vl_multimodal(
+ input_ids,
+ position_ids,
+ attention_mask,
+ past_key_values,
+ labels,
+ images,
+ images_aux,
+ bbox_list,
+ image_grid_thws
+ )
+
+ if rope_deltas is not None:
+ self.rope_deltas = rope_deltas
+
+ # Call the base CausalLM forward, with possibly replaced multimodal embeddings
+ out = super().forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ rope_deltas=rope_deltas,
+ cache_position=cache_position,
+ second_per_grid_ts=second_per_grid_ts,
+ return_dict=return_dict
+ )
+ return out
+
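A note on the decoding fast path at the top of prepare_inputs_labels_for_qwen2_5_vl_multimodal: once the KV cache holds the processed prompt, each generation step feeds a single new token, so the attention mask is simply extended with ones to the new cache length and cache_position points at the final slot. A shapes-only toy of that bookkeeping (no model involved):

import torch

past_len = 10                                    # tokens already in the KV cache
attention_mask = torch.ones(2, past_len, dtype=torch.bool)
target_shape = past_len + 1                      # cache length after the new token
attention_mask = torch.cat(
    (attention_mask, torch.ones(2, target_shape - attention_mask.shape[1], dtype=torch.bool)),
    dim=1,
)
cache_position = torch.tensor([target_shape - 1])
print(attention_mask.shape, cache_position)      # torch.Size([2, 11]) tensor([10])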
+ # Prepare the model input dict for autoregressive generation (used by generate())
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ pixel_values=None,
+ pixel_values_videos=None,
+ image_grid_thw=None,
+ video_grid_thw=None,
+ second_per_grid_ts=None,
+ images: Optional[torch.FloatTensor] = None,
+ images_aux: Optional[torch.FloatTensor] = None,
+ bbox_list: Optional[torch.FloatTensor] = None,
+ image_grid_thws: Optional[torch.FloatTensor] = None,
+ **kwargs,
+ ):
+ # Wrap the parent logic so the extra multimodal kwargs are preserved
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ second_per_grid_ts=second_per_grid_ts,
+ images=images,
+ images_aux=images_aux,
+ bbox_list=bbox_list,
+ image_grid_thws=image_grid_thws,
+ )
+ return model_inputs
+
+ # Register our config and model with the HuggingFace transformers registry
+ AutoConfig.register("omchat_qwen2_5_vl", OmChatQwen25VLConfig)
+ AutoModelForCausalLM.register(OmChatQwen25VLConfig, OmChatQwen25VLForCausalLM)
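Once registered, the model loads through the standard Auto classes whenever this module has been imported. A hedged usage sketch (the checkpoint path is a placeholder, not a published repo id):

import torch
from transformers import AutoModelForCausalLM
import vlm_fo1.model.language_model.omchat_qwen2_5_vl  # noqa: F401, executes the register() calls above

model = AutoModelForCausalLM.from_pretrained(
    "/path/to/omchat_qwen2_5_vl_checkpoint",  # hypothetical local checkpoint
    torch_dtype=torch.bfloat16,
)
# The multimodal kwargs (images, images_aux, bbox_list, image_grid_thws) are then
# threaded through model.generate(...) by prepare_inputs_for_generation above.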
vlm_fo1/model/multimodal_encoder/__init__.py ADDED
File without changes
vlm_fo1/model/multimodal_encoder/base_encoder.py ADDED
@@ -0,0 +1,33 @@
+ import torch
+ import torch.nn as nn
+
+
+ class AbsVisionTower(nn.Module):
+ # Abstract interface that every vision tower implementation must satisfy.
+ @torch.no_grad()
+ def forward(self, images):
+ raise NotImplementedError
+
+ @property
+ def dummy_feature(self):
+ raise NotImplementedError
+
+ @property
+ def dtype(self):
+ raise NotImplementedError
+
+ @property
+ def device(self):
+ raise NotImplementedError
+
+ @property
+ def config(self):
+ raise NotImplementedError
+
+ @property
+ def hidden_size(self):
+ raise NotImplementedError
+
+ @property
+ def num_patches(self):
+ raise NotImplementedError
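AbsVisionTower is a pure interface: concrete towers override forward and expose dtype/device/hidden_size so the rest of the stack can treat any backbone uniformly. A minimal conforming dummy (illustrative only, not one of the shipped encoders):

import torch
import torch.nn as nn

class DummyVisionTower(nn.Module):
    # Mirrors the AbsVisionTower contract with a single patchify convolution.
    def __init__(self, hidden=64):
        super().__init__()
        self.proj = nn.Conv2d(3, hidden, kernel_size=14, stride=14)

    @torch.no_grad()
    def forward(self, images):                        # images: [B, 3, H, W]
        feats = self.proj(images)                     # [B, hidden, H/14, W/14]
        return feats.flatten(2).transpose(1, 2)       # [B, num_patches, hidden]

    @property
    def dtype(self):
        return self.proj.weight.dtype

    @property
    def device(self):
        return self.proj.weight.device

    @property
    def hidden_size(self):
        return self.proj.out_channels

tower = DummyVisionTower()
print(tower(torch.randn(1, 3, 224, 224)).shape)       # torch.Size([1, 256, 64])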
vlm_fo1/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,38 @@
+ # Builders for the different vision tower backbones (multimodal encoder visual modules)
+ from .qwen2_5_vl_encoder import Qwen2_5_VlVisionTower # Main Qwen2.5 vision tower
+ from .davit_aux_encoder import DavitVisionTower as DavitVisionTowerAux # Auxiliary DaViT vision tower
+
+ def build_vision_tower(vision_tower_cfg, **kwargs):
+ """
+ Construct the main vision tower from the model config.
+
+ vision_tower_cfg: should have attribute mm_vision_tower
+ Returns: instance of the configured vision backbone
+ """
+ vision_tower_name = getattr(vision_tower_cfg, 'mm_vision_tower', None)
+ # print(vision_tower_cfg) # Debug print of the config being used
+
+ # Check for the Qwen2.5-VL vision model in the tower name
+ if "qwen2.5-vl" in vision_tower_name.lower():
+ return Qwen2_5_VlVisionTower(vision_tower_name, args=vision_tower_cfg, **kwargs)
+
+ # Raise a clear error for unknown towers
+ raise ValueError(f'Unknown vision tower: {vision_tower_name}')
+
+ def build_vision_tower_aux(vision_tower_cfg, **kwargs):
+ """
+ Construct the auxiliary (helper) vision tower from the model config.
+
+ vision_tower_cfg: should have attribute mm_vision_tower_aux
+ Returns: instance of the configured auxiliary vision backbone
+ """
+ vision_tower_aux = getattr(vision_tower_cfg, 'mm_vision_tower_aux', None)
+ # Optionally print the config for debugging
+ # print(vision_tower_cfg)
+
+ # Check for the DaViT auxiliary vision model in the tower name
+ if 'davit' in vision_tower_aux.lower():
+ return DavitVisionTowerAux(vision_tower_aux, args=vision_tower_cfg, **kwargs)
+
+ # Raise a clear error if the tower type is unknown
+ raise ValueError(f'Unknown aux vision tower: {vision_tower_aux}')
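Both builders dispatch on a substring of the configured tower name, so any object with the right attributes works as a config. A usage sketch with a plain namespace (the tower names are illustrative placeholders, not verified repo ids):

from types import SimpleNamespace
from vlm_fo1.model.multimodal_encoder.builder import build_vision_tower, build_vision_tower_aux

cfg = SimpleNamespace(
    mm_vision_tower="Qwen/Qwen2.5-VL-7B-Instruct",  # matched via the "qwen2.5-vl" substring
    mm_vision_tower_aux="local/davit-aux-encoder",  # hypothetical name matched via "davit"
)
vision_tower = build_vision_tower(cfg)
vision_tower_aux = build_vision_tower_aux(cfg)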